diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py
index cb23ce4278..7fddf285e4 100644
--- a/fairseq/data/indexed_dataset.py
+++ b/fairseq/data/indexed_dataset.py
@@ -15,9 +15,16 @@ from . import FairseqDataset
 
 
-def make_builder(out_file, impl):
+def __best_fitting_dtype(vocab_size=None):
+    if vocab_size is not None and vocab_size < 65500:
+        return np.uint16
+    else:
+        return np.int32
+
+
+def make_builder(out_file, impl, vocab_size=None):
     if impl == 'mmap':
-        return MMapIndexedDatasetBuilder(out_file)
+        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
     else:
         return IndexedDatasetBuilder(out_file)
 
 
@@ -63,6 +70,7 @@ def write_longs(f, a):
     5: np.int64,
     6: np.float,
     7: np.double,
+    8: np.uint16
 }
 
 
@@ -143,7 +151,7 @@ def size(self, index):
     @staticmethod
     def exists(path):
         return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
         )
 
     @property
@@ -440,11 +448,11 @@ def __len__(self):
 
     def __getitem__(self, i):
         ptr, size = self._index[i]
-        tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
-        if tensor.dtype == torch.int64:
-            return tensor
-        else:
-            return tensor.long()
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
+        if self._index.dtype != np.int64:
+            np_array = np_array.astype(np.int64)
+
+        return torch.from_numpy(np_array)
 
     @property
     def sizes(self):
@@ -457,7 +465,7 @@ def supports_prefetch(self):
     @staticmethod
     def exists(path):
         return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
         )
diff --git a/preprocess.py b/preprocess.py
index 1e1b92da9b..c4ac37cf43 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -129,7 +129,8 @@ def merge_result(worker_result):
                 )
             pool.close()
 
-        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
+        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
+                                          impl=args.dataset_impl, vocab_size=len(vocab))
         merge_result(
             Binarizer.binarize(
                 input_file, vocab, lambda t: ds.add_item(t),
@@ -231,7 +232,8 @@ def make_all(lang, vocab):
 
 
 def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
-    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
+    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
+                                      impl=args.dataset_impl, vocab_size=len(vocab))
 
     def consumer(tensor):
         ds.add_item(tensor)
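
Note: the snippet below is not part of the patch; it is a minimal sketch of the round trip the change enables, assuming a hypothetical vocabulary of 32,000 types. Token IDs are written with the smallest dtype that fits the vocabulary (np.uint16 when the size is below 65500, per __best_fitting_dtype) and are cast back to int64 on read, mirroring the updated MMapIndexedDataset.__getitem__.

# Illustrative sketch (not part of the patch).
import numpy as np
import torch

vocab_size = 32000                       # hypothetical vocabulary size
dtype = np.uint16 if vocab_size < 65500 else np.int32

# "Write" side: token IDs serialized with the compact dtype (2 bytes per token).
tokens = np.array([5, 17, 31999, 2], dtype=dtype)
buffer = tokens.tobytes()                # stands in for the mmap'ed .bin file

# "Read" side: reinterpret the raw bytes, then upcast so callers always
# receive a torch.int64 (Long) tensor regardless of the on-disk dtype.
np_array = np.frombuffer(buffer, dtype=dtype, count=len(tokens), offset=0)
if np_array.dtype != np.int64:
    np_array = np_array.astype(np.int64)

tensor = torch.from_numpy(np_array)
assert tensor.dtype == torch.int64
print(tensor)  # tensor([    5,    17, 31999,     2])

Casting before torch.from_numpy keeps downstream consumers unchanged (they still see a LongTensor), while the .bin file stores 2 bytes per token instead of 8 under the builder's previous int64 default, roughly a 4x size reduction for sub-65500 vocabularies.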