Skip to content

Commit

Permalink
set_global_seed hotfix (#1329)
Browse files Browse the repository at this point in the history
  • Loading branch information
asteyo committed Oct 14, 2021
1 parent 93eedf0 commit b263838
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions catalyst/core/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,13 +601,13 @@ def close_log(self, *args, **kwargs) -> None:
logger.close_log(*args, **kwargs)

def _setup_loaders(self) -> None:
set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
loaders = self.get_loaders(stage=self.stage_key)
loaders = validate_loaders(loaders)
self.loaders = loaders

def _setup_components(self) -> None:
set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
self.model, self.criterion, self.optimizer, self.scheduler = self.engine.init_components(
model_fn=self._get_model,
criterion_fn=self._get_criterion,
Expand Down Expand Up @@ -641,7 +641,7 @@ def _check_callbacks(self):
)

def _setup_callbacks(self):
set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)
callbacks = self.get_callbacks(self.stage_key)
callbacks = filter_callbacks_by_node(callbacks)
callbacks = sort_callbacks_by_order(callbacks)
Expand Down Expand Up @@ -698,7 +698,7 @@ def on_epoch_start(self, runner: "IRunner"):
for loader_key, loader in self.loaders.items():
if len(loader) == 0:
raise RunnerError(f"DataLoader with name {loader_key} is empty.")
set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)

def on_loader_start(self, runner: "IRunner"):
"""Event handler."""
Expand All @@ -716,7 +716,7 @@ def on_loader_start(self, runner: "IRunner"):

if self.loader_batch_len == 0:
raise NotImplementedError(f"DataLoader with name {self.loader_key} is empty.")
set_global_seed(self.seed + self.engine.rank + self.global_epoch_step)
set_global_seed(self.seed + max(0, self.engine.rank) + self.global_epoch_step)

maybe_recursive_call(self.model, "train", mode=self.is_train_loader)
if isinstance(self.loader.sampler, DistributedSampler):
Expand Down

0 comments on commit b263838

Please sign in to comment.