Skip to content

Commit

Permalink
WIP on correct hadling dor several dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
TezRomacH committed May 12, 2019
1 parent 709de01 commit 3acb59e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion catalyst/contrib/scheduler/onecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def reset(self):
)
self.last_epoch = 0

def recalculate_(
def recalculate(
self,
loader_len: int
) -> None:
Expand Down
14 changes: 11 additions & 3 deletions catalyst/dl/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict

import safitty
import numpy as np
import torch

from catalyst.contrib.scheduler import OneCycleLR
Expand Down Expand Up @@ -301,10 +302,17 @@ def on_loader_start(self, state: RunnerState):
scheduler = state.get_key(
key="scheduler", inner_key=self.scheduler_key
)
if state.loader_name == "train" and \
if state.loader_name.startswith("train") and \
isinstance(scheduler, OneCycleLR) and self.mode == "batch":
if not self._onecycle_recalculated:
scheduler.recalculate_(loader_len=state.loader_len)
loaders = state.loaders

train_loaders_len = [
len(train_loader) for train_loader in loaders
if train_loader.startswith("train")
]
loader_len = np.array(train_loaders_len).sum().item()
scheduler.recalculate(loader_len=loader_len)
self._onecycle_recalculated = True

def on_batch_end(self, state):
Expand All @@ -315,8 +323,8 @@ def on_epoch_end(self, state):
if self.mode == "epoch":
self.step(state=state)

@staticmethod
def _scheduler_step(
self,
scheduler,
valid_metric=None,
):
Expand Down
1 change: 1 addition & 0 deletions catalyst/dl/experiments/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _prepare_state(self, stage: str):
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=self.experiment.get_loaders(stage),
**self.experiment.get_state_params(stage),
**migrating_params
)
Expand Down
2 changes: 2 additions & 0 deletions catalyst/dl/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
minimize_metric=True,
valid_loader="valid",
verbose=False,
loaders=None,
**kwargs
):
# @TODO: refactor
Expand All @@ -45,6 +46,7 @@ def __init__(
self.stage = stage
self.device = device
self.loader_name = None
self.loaders = loaders

# data pipeline
self.input = None
Expand Down

0 comments on commit 3acb59e

Please sign in to comment.