Skip to content

Commit

Permalink
set numpy seed explicitly + other minor fixes (#850)
Browse files Browse the repository at this point in the history
Summary:
not setting the numpy seed explicitly at the beginning was an extremely annoying bug to find. it it caused different gpus to have a different view of data if some randomization was used in the dataset (e.g. subsample dataset)
Pull Request resolved: fairinternal/fairseq-py#850

Differential Revision: D17085006

Pulled By: alexeib

fbshipit-source-id: 62bb2116369fb703df878e6bc24c06f1ea4e75a0
  • Loading branch information
alexeib authored and facebook-github-bot committed Aug 30, 2019
1 parent 8777465 commit 4a7cd58
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
24 changes: 17 additions & 7 deletions fairseq/data/replace_dataset.py
Expand Up @@ -7,20 +7,30 @@


class ReplaceDataset(BaseWrapperDataset):
def __init__(self, dataset, replace_map, offset=0):
"""Replaces tokens found in the dataset by a specified replacement token
Args:
dataset (~torch.utils.data.Dataset): dataset to replace tokens in
replace_map(Dictionary[int,int]): map of token to replace -> replacement token
offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
as many as the number of objects returned by the underlying dataset __getitem__ method.
"""

def __init__(self, dataset, replace_map, offsets):
super().__init__(dataset)
assert len(replace_map) > 0
self.replace_map = replace_map
self.offset = offset
self.offsets = offsets

def __getitem__(self, index):
item = self.dataset[index]
is_tuple = isinstance(item, tuple)
src = item[0] if is_tuple else item
srcs = item if is_tuple else [item]

for k, v in self.replace_map.items():
src_off = src[self.offset:]
src_off.masked_fill_(src_off == k, v)
for offset, src in zip(self.offsets, srcs):
for k, v in self.replace_map.items():
src_off = src[offset:] if offset >= 0 else src[:offset]
src_off.masked_fill_(src_off == k, v)

item = tuple((src,) + item[1:]) if is_tuple else src
item = srcs if is_tuple else srcs[0]
return item
13 changes: 11 additions & 2 deletions fairseq/data/subsample_dataset.py
Expand Up @@ -9,15 +9,24 @@


class SubsampleDataset(BaseWrapperDataset):
"""Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples
Args:
dataset (~torch.utils.data.Dataset): dataset to subsample
size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive)
"""

def __init__(self, dataset, size_ratio):
super().__init__(dataset)
assert size_ratio < 1
self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
self.indices = np.random.choice(
range(len(self.dataset)), self.actual_size, replace=False
list(range(len(self.dataset))), self.actual_size, replace=False
)
print(
"subsampled dataset from {} to {} (ratio={})".format(len(self.dataset), self.actual_size, size_ratio)
"subsampled dataset from {} to {} (ratio={})".format(
len(self.dataset), self.actual_size, size_ratio
)
)

def __getitem__(self, index):
Expand Down
2 changes: 2 additions & 0 deletions train.py
Expand Up @@ -9,6 +9,7 @@

import collections
import math
import numpy as np
import random

import torch
Expand All @@ -28,6 +29,7 @@ def main(args, init_distributed=False):
# Initialize CUDA and distributed training
if torch.cuda.is_available() and not args.cpu:
torch.cuda.set_device(args.device_id)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if init_distributed:
args.distributed_rank = distributed_utils.distributed_init(args)
Expand Down

0 comments on commit 4a7cd58

Please sign in to comment.