diff --git a/fairseq/data/replace_dataset.py b/fairseq/data/replace_dataset.py index 670b812f45..3bc52f0fb5 100644 --- a/fairseq/data/replace_dataset.py +++ b/fairseq/data/replace_dataset.py @@ -7,20 +7,30 @@ class ReplaceDataset(BaseWrapperDataset): - def __init__(self, dataset, replace_map, offset=0): + """Replaces tokens found in the dataset by a specified replacement token + + Args: + dataset (~torch.utils.data.Dataset): dataset to replace tokens in + replace_map(Dictionary[int,int]): map of token to replace -> replacement token + offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be + as many as the number of objects returned by the underlying dataset __getitem__ method. + """ + + def __init__(self, dataset, replace_map, offsets): super().__init__(dataset) assert len(replace_map) > 0 self.replace_map = replace_map - self.offset = offset + self.offsets = offsets def __getitem__(self, index): item = self.dataset[index] is_tuple = isinstance(item, tuple) - src = item[0] if is_tuple else item + srcs = item if is_tuple else [item] - for k, v in self.replace_map.items(): - src_off = src[self.offset:] - src_off.masked_fill_(src_off == k, v) + for offset, src in zip(self.offsets, srcs): + for k, v in self.replace_map.items(): + src_off = src[offset:] if offset >= 0 else src[:offset] + src_off.masked_fill_(src_off == k, v) - item = tuple((src,) + item[1:]) if is_tuple else src + item = srcs if is_tuple else srcs[0] return item diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py index 983a611393..f1c2942e52 100644 --- a/fairseq/data/subsample_dataset.py +++ b/fairseq/data/subsample_dataset.py @@ -9,15 +9,24 @@ class SubsampleDataset(BaseWrapperDataset): + """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples + + Args: + dataset (~torch.utils.data.Dataset): dataset to subsample + size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) + """ + def __init__(self, dataset, size_ratio): super().__init__(dataset) assert size_ratio < 1 self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) self.indices = np.random.choice( - range(len(self.dataset)), self.actual_size, replace=False + list(range(len(self.dataset))), self.actual_size, replace=False ) print( - "subsampled dataset from {} to {} (ratio={})".format(len(self.dataset), self.actual_size, size_ratio) + "subsampled dataset from {} to {} (ratio={})".format( + len(self.dataset), self.actual_size, size_ratio + ) ) def __getitem__(self, index): diff --git a/train.py b/train.py index afe9c10232..e4f0f7a5d2 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import collections import math +import numpy as np import random import torch @@ -28,6 +29,7 @@ def main(args, init_distributed=False): # Initialize CUDA and distributed training if torch.cuda.is_available() and not args.cpu: torch.cuda.set_device(args.device_id) + np.random.seed(args.seed) torch.manual_seed(args.seed) if init_distributed: args.distributed_rank = distributed_utils.distributed_init(args)