Update numpy and pytorch seeding for dataloader and multiple processes per machine. #299

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions tools/perf_measurement/benchmark_data.py
@@ -30,13 +30,16 @@ def benchmark_data(cfg: AttrDict, split: str = "train"):
except AttributeError:
device = torch.device("cuda")

+ # Give the sampler the same seed for the entire distributed group, as per the pytorch documentation.
+ sampler_seed = cfg.SEED_VALUE
dataloader = get_loader(
dataset=dataset,
dataset_config=cfg["DATA"][split],
num_dataloader_workers=cfg.DATA.NUM_DATALOADER_WORKERS,
pin_memory=False,
multi_processing_method=cfg.MULTI_PROCESSING_METHOD,
device=device,
+ sampler_seed=sampler_seed,
)

# Fairstore data sampler would require setting the start iter before it can start.
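The comment above relies on PyTorch's rule that every rank in a distributed group must construct its DistributedSampler with the same seed: each rank builds the same global permutation and then keeps only its own slice. A minimal sketch outside VISSL (the dataset, group size, and seed value are made up for illustration):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Simulate two ranks of a 2-process group that share one sampler seed.
dataset = TensorDataset(torch.arange(8))
samplers = [
    DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True, seed=42)
    for rank in range(2)
]
for sampler in samplers:
    sampler.set_epoch(0)  # same epoch => same global permutation on every rank

per_rank_indices = [list(sampler) for sampler in samplers]
# Because both ranks used seed=42, their slices partition the dataset exactly.
assert sorted(per_rank_indices[0] + per_rank_indices[1]) == list(range(8))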
18 changes: 12 additions & 6 deletions vissl/data/__init__.py
@@ -23,7 +23,7 @@
from vissl.data.ssl_dataset import GenericSSLDataset
from vissl.data.synthetic_dataset import SyntheticImageDataset
from vissl.data.torchvision_dataset import TorchvisionDataset
- from vissl.utils.misc import setup_multiprocessing_method
+ from vissl.utils.misc import setup_multiprocessing_method, set_dataloader_seeds


__all__ = [
@@ -98,7 +98,7 @@ def print_sampler_config(data_sampler):
logging.info("Distributed Sampler config:\n{}".format(sampler_cfg))


- def get_sampler(dataset, dataset_config):
+ def get_sampler(dataset, dataset_config, sampler_seed=0):
"""
Given the dataset object and the dataset config, get the data sampler to use
Supports 2 types of samplers:
@@ -114,10 +114,14 @@ def get_sampler(dataset, dataset_config):
)
elif dataset_config["USE_STATEFUL_DISTRIBUTED_SAMPLER"]:
data_sampler = StatefulDistributedSampler(
- dataset, batch_size=dataset_config["BATCHSIZE_PER_REPLICA"]
+ dataset,
+ batch_size=dataset_config["BATCHSIZE_PER_REPLICA"],
+ seed=sampler_seed,
)
else:
- data_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+ data_sampler = torch.utils.data.distributed.DistributedSampler(
+ dataset, seed=sampler_seed
+ )
logging.info("Created the Distributed Sampler....")
print_sampler_config(data_sampler)
else:
@@ -140,8 +144,9 @@ def get_loader(
pin_memory: bool,
multi_processing_method: str,
device: torch.device,
+ sampler_seed=0,
get_sampler=get_sampler,
- worker_init_fn=None,
+ worker_init_fn=set_dataloader_seeds,
):
"""
Get the dataloader for the given datasets and data split
@@ -153,6 +158,7 @@
num_dataloader_workers (int): number of workers per gpu (or cpu) training
pin_memory (bool): whether to pin memory or not
multi_processing_method (str): method to use. options: forkserver | fork | spawn
+ sampler_seed (int): seed for the sampler. Should be identical across all processes
device (torch.device): training on cuda or cpu
get_sampler (get_sampler): function that is used to get the sampler
worker_init_fn (None): any function that should be executed during
@@ -169,7 +175,7 @@

# we don't need to set the rank, replicas as the Sampler already does so in
# its init function
- data_sampler = get_sampler(dataset, dataset_config)
+ data_sampler = get_sampler(dataset, dataset_config, sampler_seed)
collate_function = get_collator(
dataset_config["COLLATE_FUNCTION"], dataset_config["COLLATE_FUNCTION_PARAMS"]
)
7 changes: 4 additions & 3 deletions vissl/data/data_helper.py
@@ -94,7 +94,7 @@ class StatefulDistributedSampler(DistributedSampler):
we want to resume the data sampler from the training iteration.
"""

- def __init__(self, dataset, batch_size=None):
+ def __init__(self, dataset, batch_size=None, seed: int = 0):
"""
Initializes the instance of StatefulDistributedSampler. Random seed is set
for the epoch set and data is shuffled. For starting the sampling, use
@@ -104,8 +104,9 @@ def __init__(self, dataset, batch_size=None):
Args:
dataset (Dataset): Pytorch dataset that sampler will shuffle
batch_size (int): batch size we want the sampler to sample
+ seed (int): Seed for the torch generator.
"""
- super().__init__(dataset, shuffle=False)
+ super().__init__(dataset, shuffle=False, seed=seed)

self.start_iter = 0
self.batch_size = batch_size
@@ -116,7 +117,7 @@ def __init__(self, dataset, batch_size=None):
def __iter__(self):
# partition data into num_replicas and optionally shuffle within a rank
g = torch.Generator()
- g.manual_seed(self.epoch)
+ g.manual_seed(self.epoch + self.seed)
shuffling = torch.randperm(self.num_samples, generator=g).tolist()
indices = np.array(
list(
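For reference, a small standalone sketch of the generator pattern __iter__ now uses: seeding a fresh torch.Generator with epoch + seed keeps the shuffle reproducible for a fixed (seed, epoch) pair while still changing between epochs. The sample count and seed values below are arbitrary.

import torch

def shuffled_indices(num_samples: int, epoch: int, seed: int) -> list:
    # Mirrors StatefulDistributedSampler.__iter__: a fresh generator per epoch.
    g = torch.Generator()
    g.manual_seed(epoch + seed)
    return torch.randperm(num_samples, generator=g).tolist()

# Identical (seed, epoch) pairs reproduce the same order ...
assert shuffled_indices(100, epoch=0, seed=7) == shuffled_indices(100, epoch=0, seed=7)
# ... while a new epoch reshuffles (a collision here is practically impossible).
assert shuffled_indices(100, epoch=1, seed=7) != shuffled_indices(100, epoch=0, seed=7)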
2 changes: 1 addition & 1 deletion vissl/engines/train.py
@@ -76,7 +76,7 @@ def train_main(

# set seeds
logging.info("Setting seed....")
- set_seeds(cfg, node_id)
+ set_seeds(cfg, dist_rank)

# We set the CUDA device here as well as a safe solution for all downstream
# `torch.cuda.current_device()` calls to return correct device.
4 changes: 4 additions & 0 deletions vissl/trainer/train_task.py
@@ -294,6 +294,9 @@ def build_dataloaders(self, pin_memory: bool) -> torch.utils.data.DataLoader:
"""
self.datasets, self.data_and_label_keys = self.build_datasets()

+ # Give the sampler the same seed for the entire distributed group, as per the pytorch documentation.
+ sampler_seed = self.config["SEED_VALUE"]
+
loaders = {
split.lower(): get_loader(
dataset=self.datasets[split],
@@ -302,6 +305,7 @@ def build_dataloaders(self, pin_memory: bool) -> torch.utils.data.DataLoader:
pin_memory=pin_memory,
multi_processing_method=self.config.MULTI_PROCESSING_METHOD,
device=self.device,
+ sampler_seed=sampler_seed,
)
for split in self.available_splits
}
32 changes: 23 additions & 9 deletions vissl/utils/misc.py
@@ -137,20 +137,34 @@ def setup_multiprocessing_method(method_name: str):
pass


- def set_seeds(cfg, node_id=0):
+ def set_seeds(cfg, dist_rank):
"""
Set the python random, numpy and torch seed for each gpu. Also set the CUDA
seeds if the CUDA is available. This ensures deterministic nature of the training.
"""
- node_seed = cfg.SEED_VALUE
- if cfg.DISTRIBUTED.NUM_NODES > 1:
- node_seed = node_seed * 2 * node_id
- logging.info(f"MACHINE SEED: {node_seed}")
- random.seed(node_seed)
- np.random.seed(node_seed)
- torch.manual_seed(node_seed)
+ # Since the pytorch sampler increments the seed by 1 every epoch, space per-rank seeds num_epochs apart.
+ seed_value = (cfg.SEED_VALUE + dist_rank) * cfg.OPTIMIZER.num_epochs
+ logging.info(f"MACHINE SEED: {seed_value}")
+ random.seed(seed_value)
+ np.random.seed(seed_value)
+ torch.manual_seed(seed_value)
if cfg["MACHINE"]["DEVICE"] == "gpu" and torch.cuda.is_available():
- torch.cuda.manual_seed_all(node_seed)
+ torch.cuda.manual_seed_all(seed_value)
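A worked illustration of the arithmetic above (the seed and epoch-count values are made up, not from the config): multiplying by num_epochs spaces the per-rank seeds far enough apart that seed + epoch can never collide between two ranks, whichever epoch each is in.

NUM_EPOCHS = 100  # stand-in for cfg.OPTIMIZER.num_epochs
SEED_VALUE = 0    # stand-in for cfg.SEED_VALUE

def process_seed(dist_rank: int) -> int:
    return (SEED_VALUE + dist_rank) * NUM_EPOCHS

# Rank 0 occupies seeds 0..99, rank 1 occupies 100..199, and so on, so a
# per-epoch increment of +1 never makes two ranks share an effective seed.
streams = {
    rank: {process_seed(rank) + epoch for epoch in range(NUM_EPOCHS)}
    for rank in range(4)
}
assert streams[0].isdisjoint(streams[1]) and streams[1].isdisjoint(streams[2])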


+ def set_dataloader_seeds(_worker_id: int):
+ """
+ See: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
+ When using "Fork" process spawning, the dataloader workers inherit the seeds of the
+ parent process for numpy. While torch seeds are handled correctly across dataloaders and
+ across epochs, numpy seeds are not. Therefore in order to ensure each worker has a
+ different and deterministic seed, we must explicitly set the numpy seed to the torch seed.
+ Also see https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading
+ """
+ # numpy and random seed must be between 0 and 2 ** 32 - 1.
+ torch_seed = torch.utils.data.get_worker_info().seed % (2 ** 32)
+ random.seed(torch_seed)
+ np.random.seed(torch_seed)


def get_indices_sparse(data):
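To see what set_dataloader_seeds guards against, here is a minimal repro sketch outside VISSL (dataset contents and worker count are arbitrary): with the fork start method, every dataloader worker inherits the parent's numpy RNG state, so workers can draw identical "random" numbers unless a worker_init_fn re-seeds numpy from the per-worker torch seed.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class NoisyDataset(Dataset):
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> float:
        # numpy-based augmentation noise; vulnerable to inherited RNG state.
        return float(np.random.rand())

def reseed_from_torch(_worker_id: int) -> None:
    # Same recipe as set_dataloader_seeds: derive numpy's seed from torch's
    # per-worker seed so each worker gets a distinct, deterministic stream.
    np.random.seed(torch.utils.data.get_worker_info().seed % (2 ** 32))

if __name__ == "__main__":
    loader = DataLoader(NoisyDataset(), num_workers=2, worker_init_fn=reseed_from_torch)
    print([batch.item() for batch in loader])  # distinct noise per worker, reproducible per base seed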