Cythonize token block dataset (facebookresearch#834)
Summary:
Cythonized the token block dataset code; it's `> 100x` faster. Building token blocks for the entire `bookwiki+CC+stories+openweb` corpus takes just ~`39.9` seconds.

TODO:
1) I think I can make it another ~2x faster.
2) Cleanup.

EDIT History:
~~First pass at parallelizing `token_block_dataset`. The code feels somewhat complicated and cluttered.
It is 2-3x faster, though, in my tests on the `bookwiki` dataset with both `complete` and `complete_doc` modes.
myleott Can you take a look for correctness, as I am still not 100% sure that I am not missing corner cases?~~
Pull Request resolved: fairinternal/fairseq-py#834

Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Test workflow: f133816198

Reviewed By: myleott

Differential Revision: D16970257

Pulled By: myleott

fbshipit-source-id: ec45a308193c9e9f3e7075336c15df4723228d6f
Naman Goyal authored and facebook-github-bot committed Aug 23, 2019
1 parent 6e2bd79 commit 4fc3953
Showing 6 changed files with 285 additions and 154 deletions.
46 changes: 9 additions & 37 deletions fairseq/data/data_utils.py
@@ -12,6 +12,10 @@
 import os
 
 import numpy as np
+import sys
+import types
+
+from fairseq.data.data_utils_fast import batch_by_size_fast
 
 
 def infer_language_pair(path):
@@ -196,45 +200,13 @@ def batch_by_size(
         required_batch_size_multiple (int, optional): require batch size to
             be a multiple of N (default: 1).
     """
-    max_tokens = max_tokens if max_tokens is not None else float('Inf')
-    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
     bsz_mult = required_batch_size_multiple
 
-    batch = []
-
-    def is_batch_full(num_tokens):
-        if len(batch) == 0:
-            return False
-        if len(batch) == max_sentences:
-            return True
-        if num_tokens > max_tokens:
-            return True
-        return False
-
-    sample_len = 0
-    sample_lens = []
-    for idx in indices:
-        sample_lens.append(num_tokens_fn(idx))
-        sample_len = max(sample_len, sample_lens[-1])
-        assert sample_len <= max_tokens, (
-            "sentence at index {} of size {} exceeds max_tokens "
-            "limit of {}!".format(idx, sample_len, max_tokens)
-        )
-        num_tokens = (len(batch) + 1) * sample_len
-        if is_batch_full(num_tokens):
-            mod_len = max(
-                bsz_mult * (len(batch) // bsz_mult),
-                len(batch) % bsz_mult,
-            )
-            yield batch[:mod_len]
-            batch = batch[mod_len:]
-            sample_lens = sample_lens[mod_len:]
-            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-        batch.append(idx)
-
-    if len(batch) > 0:
-        yield batch
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+    return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
 
 
 def process_bpe_symbol(sentence: str, bpe_symbol: str):
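Reviewer note: a minimal usage sketch of the rewritten `batch_by_size` (the toy lengths and parameter values below are illustrative, not from this commit, and running it requires the compiled extension). It shows the behavioral wrinkles the new code introduces: `indices` should be an `int64` NumPy array or a generator (converted via `np.fromiter`), and the function now returns a list of batches instead of yielding them.

```python
# Sketch only: toy lengths; values are illustrative, not from the commit.
import numpy as np
from fairseq.data import data_utils

lengths = [12, 15, 7, 20, 9]                    # tokens per sample
indices = np.argsort(lengths).astype(np.int64)  # batch length-sorted indices

batches = data_utils.batch_by_size(
    indices,
    num_tokens_fn=lambda i: lengths[i],  # cost of one sample, in tokens
    max_tokens=40,                       # cap on len(batch) * longest sample
    max_sentences=None,
    required_batch_size_multiple=1,
)
print(batches)  # -> [[2, 4, 0], [1, 3]]: each batch stays within 40 padded tokens
```

When an over-full batch is split, `mod_len = max(bsz_mult * (len(batch) // bsz_mult), len(batch) % bsz_mult)` rounds the emitted batch down to the size multiple when possible: with `bsz_mult=8` and 10 accumulated sentences, 8 are emitted and 2 carry over; with only 5 accumulated, all 5 are emitted.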
67 changes: 67 additions & 0 deletions fairseq/data/data_utils_fast.pyx
@@ -0,0 +1,67 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

DTYPE = np.int64
ctypedef np.int64_t DTYPE_t


cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
    if len(batch) == 0:
        return 0
    if len(batch) == max_sentences:
        return 1
    if num_tokens > max_tokens:
        return 1
    return 0


@cython.cdivision(True)
cpdef list batch_by_size_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    long max_tokens,
    long max_sentences,
    int bsz_mult,
):
    cdef long sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef long mod_len
    cdef long i
    cdef long idx
    cdef long num_tokens
    cdef DTYPE_t[:] indices_view = indices

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
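Reviewer note: the changed files not shown on this page presumably carry the build-system changes for the new `.pyx` modules. As a hedged sketch, the standard Cython flow for compiling such a module looks like this; nothing below is taken from the commit itself.

```python
# setup.py -- illustrative only; the commit's real build changes are in
# files not loaded on this page.
import numpy as np
from setuptools import setup
from Cython.Build import cythonize

setup(
    ext_modules=cythonize("fairseq/data/data_utils_fast.pyx"),
    include_dirs=[np.get_include()],  # needed because the .pyx does `cimport numpy`
)
```

Built in place with `python setup.py build_ext --inplace`, after which `from fairseq.data.data_utils_fast import batch_by_size_fast` resolves to the compiled extension.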
134 changes: 18 additions & 116 deletions fairseq/data/token_block_dataset.py
@@ -3,11 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
-
 import numpy as np
 import torch
 
+from fairseq.data.token_block_utils_fast import (
+    _get_slice_indices_fast,
+    _get_block_to_dataset_index_fast,
+)
+
 from fairseq.data import FairseqDataset, plasma_utils
 
 
@@ -33,7 +36,6 @@ class TokenBlockDataset(FairseqDataset):
             'complete_doc' break mode). Typically 1 if the sentences have eos
             and 0 otherwise.
     """
-
     def __init__(
         self,
         dataset,
@@ -50,70 +52,22 @@ def __init__(
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
-        slice_indices = []
 
         assert len(dataset) == len(sizes)
         assert len(dataset) > 0
-        sizes = np.array(sizes, dtype=int)
-
-        if break_mode is None or break_mode == "none":
-            total_size = sum(sizes)
-            length = math.ceil(total_size / block_size)
+        if isinstance(sizes, list):
+            sizes = np.array(sizes, dtype=np.int64)
+        else:
+            sizes = sizes.astype(np.int64)
 
-            def block_at(i):
-                start = i * block_size
-                end = min(start + block_size, total_size)
-                return (start, end)
+        break_mode = break_mode if break_mode is not None else 'none'
 
-            slice_indices = [block_at(i) for i in range(length)]
-        elif break_mode == "complete":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-            if curr_size > 0:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "complete_doc":
-            tok_idx = 0
-            sz_idx = 0
-            curr_size = 0
-            while sz_idx < len(sizes):
-                if (
-                    (curr_size + sizes[sz_idx] <= block_size or curr_size == 0)
-                    # an empty sentence indicates end-of-document:
-                    and sizes[sz_idx] != document_sep_len
-                ):
-                    curr_size += sizes[sz_idx]
-                    sz_idx += 1
-                else:
-                    if curr_size > 1:
-                        slice_indices.append((tok_idx, tok_idx + curr_size))
-                    tok_idx += curr_size
-                    curr_size = 0
-                    if sizes[sz_idx] == document_sep_len:
-                        tok_idx += sizes[sz_idx]
-                        sz_idx += 1
-            if curr_size > 1:
-                slice_indices.append((tok_idx, tok_idx + curr_size))
-        elif break_mode == "eos":
-            slice_indices = np.empty((len(sizes), 2), dtype=int)
-            if not torch.is_tensor(sizes):
-                sizes = torch.tensor(sizes)
-            cumsum = torch.cumsum(sizes, dim=0)
-            slice_indices[0] = [0, sizes[0]]
-            if len(cumsum) > 1:
-                slice_indices[1:] = cumsum.unfold(0, 2, 1)
-        else:
-            raise ValueError("Invalid break_mode: " + break_mode)
+        # For "eos" break mode, block_size is not a required parameter.
+        if break_mode == "eos" and block_size is None:
+            block_size = 0
 
-        slice_indices = np.array(slice_indices, dtype=int)
+        slice_indices = _get_slice_indices_fast(sizes, break_mode, block_size, document_sep_len)
         self._sizes = slice_indices[:, 1] - slice_indices[:, 0]
 
         # build index mapping block indices to the underlying dataset indices
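Reviewer note: a small reference sketch of the slice semantics that `_get_slice_indices_fast` must preserve, derived from the removed Python branches above; the toy numbers and the helper name are mine, not the commit's.

```python
import math

def slice_indices_none(sizes, block_size):
    """Mirrors the removed break_mode='none' branch: fixed-size blocks
    over the flattened token stream, last block possibly shorter."""
    total_size = sum(sizes)
    length = math.ceil(total_size / block_size)
    return [(i * block_size, min((i + 1) * block_size, total_size))
            for i in range(length)]

# Three sentences of 3, 4, and 2 tokens, block_size=5:
assert slice_indices_none([3, 4, 2], 5) == [(0, 5), (5, 9)]
# 'complete' mode instead keeps sentences whole, so the same input
# yields [(0, 3), (3, 7), (7, 9)] per the removed branch above.
```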
@@ -130,23 +84,10 @@ def block_at(i):
                 1,
             )
         else:
-            ds = DatasetSearcher(sizes)
-            block_to_dataset_index = np.empty((len(slice_indices), 3), dtype=int)
-            for i, (s, e) in enumerate(slice_indices):
-                ds.seek(s)
-                start_ds_idx = ds.current_index
-                start_offset = ds.current_offset
-                if e <= s:
-                    end_ds_idx = start_ds_idx
-                else:
-                    ds.seek(e - 1)
-                    end_ds_idx = ds.current_index
-                block_to_dataset_index[i] = (
-                    start_ds_idx,  # starting index in dataset
-                    start_offset,  # starting offset within starting index
-                    end_ds_idx,  # ending index in dataset
-                )
-
+            block_to_dataset_index = _get_block_to_dataset_index_fast(
+                sizes,
+                slice_indices,
+            )
         self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
         self._sizes = plasma_utils.PlasmaArray(self._sizes)
         self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
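Reviewer note: what each `(start_ds_idx, start_offset, end_ds_idx)` row of `block_to_dataset_index` encodes, per the removed inline comments. `flat_to_item` below is a hypothetical helper mirroring the removed `DatasetSearcher` (deleted in the next hunk); it is not part of the commit.

```python
def flat_to_item(sizes, pos):
    """Map a flat token position to (dataset_index, offset);
    a linear-scan stand-in for the removed DatasetSearcher."""
    for idx, size in enumerate(sizes):
        if pos < size:
            return idx, pos
        pos -= size
    raise IndexError(pos)

# Sentences of 3, 4, 2 tokens cover flat positions 0-2, 3-6, 7-8.
# The block spanning flat tokens [3, 7) starts in item 1 at offset 0,
# and its last token (6) also lies in item 1, so its row is (1, 0, 1):
assert flat_to_item([3, 4, 2], 3) == (1, 0)
assert flat_to_item([3, 4, 2], 7 - 1) == (1, 3)
```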
@@ -215,42 +156,3 @@ def prefetch(self, indices):
                 for ds_idx in range(start_ds_idx, end_ds_idx + 1)
             }
         )
-
-
-class DatasetSearcher(object):
-    """Helper for mapping "flat" indices to indices and offsets in an
-    underlying dataset."""
-
-    def __init__(self, sizes):
-        self.sizes = sizes
-        self.reset()
-
-    def reset(self):
-        self.current_index = 0  # index in underlying dataset
-        self.current_offset = 0  # offset within current index in underlying dataset
-        self.current_i = 0  # "flat" index
-
-    def seek(self, i):
-        assert i >= 0
-
-        def step():
-            if i < self.current_i:
-                self.reset()
-            if i > self.current_i:
-                to_consume = i - self.current_i
-                remaining = self.sizes[self.current_index] - self.current_offset
-                if remaining > to_consume:
-                    self.current_offset += to_consume
-                    self.current_i += to_consume
-                else:
-                    assert remaining > 0
-                    self.current_i += remaining
-                    self.current_index += 1
-                    self.current_offset = 0
-                return True
-            return False
-
-        not_done = True
-        while not_done:
-            not_done = step()
-        assert self.current_i == i
(The remaining 3 changed files, including the new fairseq/data/token_block_utils_fast.pyx imported above, did not load.)
