diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 38c385018d..dee29528a4 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -12,6 +12,10 @@ import os
 
 import numpy as np
 
+import sys
+import types
+
+from fairseq.data.data_utils_fast import batch_by_size_fast
 
 
 def infer_language_pair(path):
@@ -196,45 +200,13 @@ def batch_by_size(
         required_batch_size_multiple (int, optional): require batch size to
             be a multiple of N (default: 1).
     """
-    max_tokens = max_tokens if max_tokens is not None else float('Inf')
-    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
     bsz_mult = required_batch_size_multiple
 
-    batch = []
-
-    def is_batch_full(num_tokens):
-        if len(batch) == 0:
-            return False
-        if len(batch) == max_sentences:
-            return True
-        if num_tokens > max_tokens:
-            return True
-        return False
-
-    sample_len = 0
-    sample_lens = []
-    for idx in indices:
-        sample_lens.append(num_tokens_fn(idx))
-        sample_len = max(sample_len, sample_lens[-1])
-        assert sample_len <= max_tokens, (
-            "sentence at index {} of size {} exceeds max_tokens "
-            "limit of {}!".format(idx, sample_len, max_tokens)
-        )
-        num_tokens = (len(batch) + 1) * sample_len
-        if is_batch_full(num_tokens):
-            mod_len = max(
-                bsz_mult * (len(batch) // bsz_mult),
-                len(batch) % bsz_mult,
-            )
-            yield batch[:mod_len]
-            batch = batch[mod_len:]
-            sample_lens = sample_lens[mod_len:]
-            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-        batch.append(idx)
-
-    if len(batch) > 0:
-        yield batch
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+    return batch_by_size_fast(indices, num_tokens_fn, max_tokens, max_sentences, bsz_mult)
 
 
 def process_bpe_symbol(sentence: str, bpe_symbol: str):
diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx
new file mode 100644
index 0000000000..a9c6e57b34
--- /dev/null
+++ b/fairseq/data/data_utils_fast.pyx
@@ -0,0 +1,67 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+cimport cython
+cimport numpy as np
+
+DTYPE = np.int64
+ctypedef np.int64_t DTYPE_t
+
+
+cdef _is_batch_full(list batch, long num_tokens, long max_tokens, long max_sentences):
+    if len(batch) == 0:
+        return 0
+    if len(batch) == max_sentences:
+        return 1
+    if num_tokens > max_tokens:
+        return 1
+    return 0
+
+
+@cython.cdivision(True)
+cpdef list batch_by_size_fast(
+    np.ndarray[DTYPE_t, ndim=1] indices,
+    num_tokens_fn,
+    long max_tokens,
+    long max_sentences,
+    int bsz_mult,
+):
+    cdef long sample_len = 0
+    cdef list sample_lens = []
+    cdef list batch = []
+    cdef list batches = []
+    cdef long mod_len
+    cdef long i
+    cdef long idx
+    cdef long num_tokens
+    cdef DTYPE_t[:] indices_view = indices
+
+    for i in range(len(indices_view)):
+        idx = indices_view[i]
+        num_tokens = num_tokens_fn(idx)
+        sample_lens.append(num_tokens)
+        sample_len = max(sample_len, num_tokens)
+
+        assert sample_len <= max_tokens, (
+            "sentence at index {} of size {} exceeds max_tokens "
+            "limit of {}!".format(idx, sample_len, max_tokens)
+        )
+        num_tokens = (len(batch) + 1) * sample_len
+
+        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            batches.append(batch[:mod_len])
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+        batch.append(idx)
+    if len(batch) > 0:
+        batches.append(batch)
+    return batches
diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py
index 6dd2cc8615..eddbea43ba 100644
--- a/fairseq/data/token_block_dataset.py
+++ b/fairseq/data/token_block_dataset.py
@@ -3,11 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
-
 import numpy as np
 import torch
 
+from fairseq.data.token_block_utils_fast import (
+    _get_slice_indices_fast,
+    _get_block_to_dataset_index_fast,
+)
+
 from fairseq.data import FairseqDataset, plasma_utils
 
 
@@ -33,7 +36,6 @@ class TokenBlockDataset(FairseqDataset):
             'complete_doc' break mode). Typically 1 if the sentences have eos
             and 0 otherwise.
""" - def __init__( self, dataset, @@ -50,70 +52,22 @@ def __init__( self.pad = pad self.eos = eos self.include_targets = include_targets - slice_indices = [] assert len(dataset) == len(sizes) assert len(dataset) > 0 - sizes = np.array(sizes, dtype=int) - if break_mode is None or break_mode == "none": - total_size = sum(sizes) - length = math.ceil(total_size / block_size) + if isinstance(sizes, list): + sizes = np.array(sizes, dtype=np.int64) + else: + sizes = sizes.astype(np.int64) - def block_at(i): - start = i * block_size - end = min(start + block_size, total_size) - return (start, end) + break_mode = break_mode if break_mode is not None else 'none' - slice_indices = [block_at(i) for i in range(length)] - elif break_mode == "complete": - tok_idx = 0 - sz_idx = 0 - curr_size = 0 - while sz_idx < len(sizes): - if curr_size + sizes[sz_idx] <= block_size or curr_size == 0: - curr_size += sizes[sz_idx] - sz_idx += 1 - else: - slice_indices.append((tok_idx, tok_idx + curr_size)) - tok_idx += curr_size - curr_size = 0 - if curr_size > 0: - slice_indices.append((tok_idx, tok_idx + curr_size)) - elif break_mode == "complete_doc": - tok_idx = 0 - sz_idx = 0 - curr_size = 0 - while sz_idx < len(sizes): - if ( - (curr_size + sizes[sz_idx] <= block_size or curr_size == 0) - # an empty sentence indicates end-of-document: - and sizes[sz_idx] != document_sep_len - ): - curr_size += sizes[sz_idx] - sz_idx += 1 - else: - if curr_size > 1: - slice_indices.append((tok_idx, tok_idx + curr_size)) - tok_idx += curr_size - curr_size = 0 - if sizes[sz_idx] == document_sep_len: - tok_idx += sizes[sz_idx] - sz_idx += 1 - if curr_size > 1: - slice_indices.append((tok_idx, tok_idx + curr_size)) - elif break_mode == "eos": - slice_indices = np.empty((len(sizes), 2), dtype=int) - if not torch.is_tensor(sizes): - sizes = torch.tensor(sizes) - cumsum = torch.cumsum(sizes, dim=0) - slice_indices[0] = [0, sizes[0]] - if len(cumsum) > 1: - slice_indices[1:] = cumsum.unfold(0, 2, 1) - else: - raise ValueError("Invalid break_mode: " + break_mode) + # For "eos" break-mode, block_size is not required parameters. 
+ if break_mode == "eos" and block_size is None: + block_size = 0 - slice_indices = np.array(slice_indices, dtype=int) + slice_indices = _get_slice_indices_fast(sizes, break_mode, block_size, document_sep_len) self._sizes = slice_indices[:, 1] - slice_indices[:, 0] # build index mapping block indices to the underlying dataset indices @@ -130,23 +84,10 @@ def block_at(i): 1, ) else: - ds = DatasetSearcher(sizes) - block_to_dataset_index = np.empty((len(slice_indices), 3), dtype=int) - for i, (s, e) in enumerate(slice_indices): - ds.seek(s) - start_ds_idx = ds.current_index - start_offset = ds.current_offset - if e <= s: - end_ds_idx = start_ds_idx - else: - ds.seek(e - 1) - end_ds_idx = ds.current_index - block_to_dataset_index[i] = ( - start_ds_idx, # starting index in dataset - start_offset, # starting offset within starting index - end_ds_idx, # ending index in dataset - ) - + block_to_dataset_index = _get_block_to_dataset_index_fast( + sizes, + slice_indices, + ) self._slice_indices = plasma_utils.PlasmaArray(slice_indices) self._sizes = plasma_utils.PlasmaArray(self._sizes) self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index) @@ -215,42 +156,3 @@ def prefetch(self, indices): for ds_idx in range(start_ds_idx, end_ds_idx + 1) } ) - - -class DatasetSearcher(object): - """Helper for mapping "flat" indices to indices and offsets in an - underlying dataset.""" - - def __init__(self, sizes): - self.sizes = sizes - self.reset() - - def reset(self): - self.current_index = 0 # index in underlying dataset - self.current_offset = 0 # offset within current index in underlying dataset - self.current_i = 0 # "flat" index - - def seek(self, i): - assert i >= 0 - - def step(): - if i < self.current_i: - self.reset() - if i > self.current_i: - to_consume = i - self.current_i - remaining = self.sizes[self.current_index] - self.current_offset - if remaining > to_consume: - self.current_offset += to_consume - self.current_i += to_consume - else: - assert remaining > 0 - self.current_i += remaining - self.current_index += 1 - self.current_offset = 0 - return True - return False - - not_done = True - while not_done: - not_done = step() - assert self.current_i == i diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx new file mode 100644 index 0000000000..bf3b0ecf07 --- /dev/null +++ b/fairseq/data/token_block_utils_fast.pyx @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+import numpy as np
+import torch
+from itertools import chain
+from libc.math cimport ceil
+
+cimport cython
+cimport numpy as np
+
+DTYPE = np.int64
+ctypedef np.int64_t DTYPE_t
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+@cython.nonecheck(False)
+cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size):
+    cdef DTYPE_t total_size = sizes.sum()
+    # cast to double before dividing, so ceil sees the fractional part
+    # instead of an already-truncated integer quotient
+    cdef DTYPE_t length = <DTYPE_t> ceil(total_size / <double> block_size)
+    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE)
+    cdef DTYPE_t[:, :] slice_indices_view = slice_indices
+    cdef DTYPE_t i
+    cdef DTYPE_t start
+    cdef DTYPE_t end
+    for i in range(length):
+        start = i * block_size
+        end = min(start + block_size, total_size)
+        slice_indices_view[i][0] = start
+        slice_indices_view[i][1] = end
+    return slice_indices
+
+
+cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list):
+    """
+    Faster function to convert a list of lists of DTYPE_t to a 2-D array.
+    Only fast when there is a huge number of rows and a low number of columns.
+    """
+    cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1)
+    return flat.reshape((len(list_of_list), -1))
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+@cython.nonecheck(False)
+cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len):
+    cdef DTYPE_t tok_idx = 0
+    cdef DTYPE_t sz_idx = 0
+    cdef DTYPE_t curr_size = 0
+    cdef DTYPE_t i = 0
+    cdef DTYPE_t length
+    cdef DTYPE_t total_size
+    cdef DTYPE_t[:] sizes_view = sizes
+    cdef np.ndarray[DTYPE_t, ndim=2] slice_indices
+    cdef list slice_indices_list = []
+
+    if break_mode is None or break_mode == 'none':
+        slice_indices = _get_slice_indices_none_mode(sizes, block_size)
+    elif break_mode == 'complete':
+        while sz_idx < len(sizes_view):
+            if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0:
+                curr_size += sizes_view[sz_idx]
+                sz_idx += 1
+            else:
+                slice_indices_list.append((tok_idx, tok_idx + curr_size))
+                tok_idx += curr_size
+                curr_size = 0
+        if curr_size > 0:
+            slice_indices_list.append((tok_idx, tok_idx + curr_size))
+        slice_indices = _fast_convert_to_np_array(slice_indices_list)
+    elif break_mode == 'complete_doc':
+        while sz_idx < len(sizes_view):
+            if (
+                (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0)
+                # an empty sentence indicates end-of-document:
+                and sizes_view[sz_idx] != document_sep_len
+            ):
+                curr_size += sizes_view[sz_idx]
+                sz_idx += 1
+            else:
+                # Only keep non-empty documents.
+ if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + tok_idx += curr_size + curr_size = 0 + if sizes_view[sz_idx] == document_sep_len: + tok_idx += sizes_view[sz_idx] + sz_idx += 1 + if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + slice_indices = _fast_convert_to_np_array(slice_indices_list) + elif break_mode == 'eos': + slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + cumsum = sizes.cumsum(axis=0) + slice_indices[1:, 0] = cumsum[:-1] + slice_indices[:, 1] = cumsum + else: + raise ValueError('Invalid break_mode: ' + break_mode) + return slice_indices + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): + cdef DTYPE_t start_ds_idx + cdef DTYPE_t start_offset + cdef DTYPE_t end_ds_idx + cdef DTYPE_t i + cdef DTYPE_t s + cdef DTYPE_t e + cdef DatasetSearcher ds = DatasetSearcher(sizes) + cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + cdef DTYPE_t[:, :] slice_indices_view = slice_indices + cdef Py_ssize_t x_max = slice_indices.shape[0] + + for i in range(x_max): + s = slice_indices_view[i][0] + e = slice_indices_view[i][1] + ds.seek(s) + start_ds_idx = ds.current_index + start_offset = ds.current_offset + if e <= s: + end_ds_idx = start_ds_idx + else: + ds.seek(e - 1) + end_ds_idx = ds.current_index + block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + return block_to_dataset_index + + +cdef class DatasetSearcher(object): + """Helper for mapping "flat" indices to indices and offsets in an + underlying dataset.""" + cdef DTYPE_t current_i + cdef DTYPE_t current_offset + cdef DTYPE_t current_index + cdef DTYPE_t[:] sizes + + def __init__(self, DTYPE_t[:] sizes): + self.sizes = sizes + self.reset() + + cdef reset(self): + self.current_offset = 0 # offset within current index in underlying dataset + self.current_i = 0 # "flat" index + self.current_index = 0 # index in underlying dataset + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef int step(self, DTYPE_t i): + cdef DTYPE_t to_consume + cdef DTYPE_t remaining + if i < self.current_i: + self.reset() + if i > self.current_i: + to_consume = i - self.current_i + remaining = self.sizes[self.current_index] - self.current_offset + if remaining > to_consume: + self.current_offset += to_consume + self.current_i += to_consume + else: + assert remaining > 0 + self.current_i += remaining + self.current_index += 1 + self.current_offset = 0 + return 1 + return 0 + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef seek(self, DTYPE_t i): + cdef int not_done = 1 + while not_done == 1: + not_done = self.step(i) + assert self.current_i == i diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 3dea071629..cb7c1a8966 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import numpy as np
 import torch
 
 from fairseq import tokenizer
@@ -134,6 +135,7 @@ def get_batch_iterator(
             indices = data_utils.filter_by_size(
                 indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
             )
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
 
         # create mini-batches with given size constraints
         batch_sampler = data_utils.batch_by_size(
diff --git a/setup.py b/setup.py
index 59d3410af0..537898dcc6 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from setuptools import setup, find_packages, Extension
+from Cython.Build import cythonize
 import sys
 
 
@@ -27,6 +28,8 @@
     extra_compile_args=extra_compile_args,
 )
 
+token_block_utils = cythonize("fairseq/data/token_block_utils_fast.pyx")
+data_utils_fast = cythonize("fairseq/data/data_utils_fast.pyx", language="c++")
 
 setup(
     name='fairseq',
@@ -52,7 +55,7 @@
         'tqdm',
     ],
     packages=find_packages(exclude=['scripts', 'tests']),
-    ext_modules=[bleu],
+    ext_modules=token_block_utils + data_utils_fast + [bleu],
    test_suite='tests',
     entry_points={
         'console_scripts': [
@@ -65,4 +68,5 @@
             'fairseq-validate = fairseq_cli.validate:cli_main',
         ],
     },
+    zip_safe=False,
 )
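
Build note (an assumption about the environment, not part of the patch): both new .pyx modules cimport numpy, so compiling the generated C/C++ sources needs numpy's headers on the include path, which the setup.py change above does not add explicitly. If the build fails with "numpy/arrayobject.h: No such file or directory", one possible sketch is to append numpy's include directory to the Extension objects that cythonize() returns:

    import numpy as np
    from Cython.Build import cythonize

    extensions = (
        cythonize("fairseq/data/token_block_utils_fast.pyx")
        + cythonize("fairseq/data/data_utils_fast.pyx", language="c++")
    )
    for ext in extensions:
        # cythonize() returns setuptools Extension objects whose
        # include_dirs attribute is a plain list we can extend
        ext.include_dirs.append(np.get_include())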
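
A minimal usage sketch for the patched batch_by_size, which now returns a list of batches instead of yielding them. It assumes the extensions have been built (for example with python setup.py build_ext --inplace); the lengths list and the lambda are made-up stand-ins for a real dataset's num_tokens():

    import numpy as np
    from fairseq.data import data_utils

    lengths = [5, 7, 3, 9, 4, 6]  # hypothetical per-sample token counts
    indices = np.arange(len(lengths), dtype=np.int64)  # int64, as batch_by_size_fast expects

    batches = data_utils.batch_by_size(
        indices,
        num_tokens_fn=lambda i: lengths[i],
        max_tokens=16,  # caps len(batch) * (longest sample in the batch)
        max_sentences=None,
        required_batch_size_multiple=1,
    )
    # print each batch together with its padded token count
    for b in batches:
        print(b, len(b) * max(lengths[i] for i in b))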
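
A small smoke test for the 'eos' branch of _get_slice_indices_fast (illustrative only; it also needs the built extension). Under this break mode every sentence becomes its own block, so the slice boundaries are just the cumulative sizes; block_size is passed as 0 per the comment in TokenBlockDataset above, and document_sep_len as 1, which the 'eos' branch never reads:

    import numpy as np
    from fairseq.data.token_block_utils_fast import _get_slice_indices_fast

    sizes = np.array([2, 3, 4], dtype=np.int64)
    slices = _get_slice_indices_fast(sizes, 'eos', 0, 1)
    # one (start, end) pair per sentence, following the cumulative sums
    assert slices.tolist() == [[0, 2], [2, 5], [5, 9]]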