diff --git a/catalyst/core/runner.py b/catalyst/core/runner.py
index 4416d65586..61206e591d 100644
--- a/catalyst/core/runner.py
+++ b/catalyst/core/runner.py
@@ -601,13 +601,13 @@ def close_log(self, *args, **kwargs) -> None:
         logger.close_log(*args, **kwargs)
 
     def _setup_loaders(self) -> None:
-        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
+        set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
         loaders = self.get_loaders(stage=self.stage_key)
         loaders = validate_loaders(loaders)
         self.loaders = loaders
 
     def _setup_components(self) -> None:
-        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
+        set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
         self.model, self.criterion, self.optimizer, self.scheduler = self.engine.init_components(
             model_fn=self._get_model,
             criterion_fn=self._get_criterion,
@@ -641,7 +641,7 @@ def _check_callbacks(self):
         )
 
     def _setup_callbacks(self):
-        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
+        set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
         callbacks = self.get_callbacks(self.stage_key)
         callbacks = filter_callbacks_by_node(callbacks)
         callbacks = sort_callbacks_by_order(callbacks)
@@ -698,7 +698,7 @@ def on_epoch_start(self, runner: "IRunner"):
         for loader_key, loader in self.loaders.items():
             if len(loader) == 0:
                 raise RunnerError(f"DataLoader with name {loader_key} is empty.")
-        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
+        set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
 
     def on_loader_start(self, runner: "IRunner"):
         """Event handler."""
@@ -716,7 +716,7 @@ def on_loader_start(self, runner: "IRunner"):
         if self.loader_batch_len == 0:
             raise NotImplementedError(f"DataLoader with name {self.loader_key} is empty.")
 
-        set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
+        set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
 
         maybe_recursive_call(self.model, "train", mode=self.is_train_loader)
         if isinstance(self.loader.sampler, DistributedSampler):