Prerelease/20.04.2.1 (#761)
* 20.04.2 version refactoring: plotly to contrib; gans to Catalyst.GAN

* extra methods for inference and tracing

* codestyle

* load_on_stage_end feature, tracing update

* refactoring and hotfixes
Scitator committed Apr 20, 2020
1 parent fbd7d0e commit e0369b8
Showing 15 changed files with 284 additions and 214 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -64,9 +64,13 @@ runner.train(
logdir="./logdir",
num_epochs=8,
verbose=True,
load_best_on_end=True,
)
# model inference
loader_logits = runner.predict_loader(model=model, loader=loader, verbose=True)
for prediction in runner.predict_loader(loader=loader):
do_something()
# model tracing
traced_model = runner.trace(loader=loader)
```
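The README change above switches inference from collecting logits via `predict_loader` to iterating over it, and adds `runner.trace` for TorchScript export. Below is a compact, self-contained sketch of that flow; the toy model, data, and one-epoch run are placeholders rather than part of the commit, and only the names shown in the diff (`SupervisedRunner`, `load_best_on_end`, `predict_loader`, `trace`) are taken from it.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import SupervisedRunner

# toy setup (placeholders, not from the commit)
model = torch.nn.Linear(16, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,))), batch_size=8
)

runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders={"train": loader, "valid": loader},
    logdir="./logdir",
    num_epochs=1,
    verbose=True,
    load_best_on_end=True,  # new flag from this release
)

# model inference: predict_loader now yields per-batch predictions
for prediction in runner.predict_loader(loader=loader):
    ...  # per-batch model output; its exact structure depends on the runner

# model tracing: returns a TorchScript module traced on a batch from the loader
traced_model = runner.trace(loader=loader)
```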

### Minimal Examples
68 changes: 48 additions & 20 deletions catalyst/core/callbacks/checkpoint.py
@@ -27,35 +27,42 @@ def _pack_state(state: State):
return checkpoint


def _load_checkpoint(*, filename, state: State):
def _load_checkpoint(*, filename, state: State, load_full: bool = True):
if not os.path.isfile(filename):
raise Exception(f"No checkpoint found at {filename}")

print(f"=> loading checkpoint {filename}")
checkpoint = utils.load_checkpoint(filename)

if not state.stage_name.startswith("infer"):
if not state.stage_name.startswith("infer") and load_full:
state.stage_name = checkpoint["stage_name"]
state.epoch = checkpoint["epoch"]
state.global_epoch = checkpoint["global_epoch"]
# @TODO: should we also load,
# checkpoint_data, main_metric, minimize_metric, valid_loader ?
# epoch_metrics, valid_metrics ?

utils.unpack_checkpoint(
checkpoint,
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
)
if load_full:
utils.unpack_checkpoint(
checkpoint,
model=state.model,
criterion=state.criterion,
optimizer=state.optimizer,
scheduler=state.scheduler,
)

print(
f"loaded checkpoint {filename} "
f"(global epoch {checkpoint['global_epoch']}, "
f"epoch {checkpoint['epoch']}, "
f"stage {checkpoint['stage_name']})"
)
print(
f"loaded state checkpoint {filename} "
f"(global epoch {checkpoint['global_epoch']}, "
f"epoch {checkpoint['epoch']}, "
f"stage {checkpoint['stage_name']})"
)
else:
utils.unpack_checkpoint(
checkpoint, model=state.model,
)

print(f"loaded model checkpoint {filename}")
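In short, the new `load_full` flag decides whether the whole training state or only the model weights are restored. The sketch below is not part of the diff; it just spells out the two branches with the public `catalyst.utils` helpers, using placeholder objects in place of `state.model`, `state.optimizer`, and so on, plus a placeholder checkpoint path.

```python
import torch
from catalyst import utils

# placeholders standing in for the attributes held by the runner State
model = torch.nn.Linear(16, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

checkpoint = utils.load_checkpoint("./logdir/checkpoints/best_full.pth")  # placeholder path

# load_full=True: restore model, criterion, optimizer and scheduler state
utils.unpack_checkpoint(
    checkpoint,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

# load_full=False: restore the model weights only
utils.unpack_checkpoint(checkpoint, model=model)
```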


class BaseCheckpointCallback(Callback):
@@ -67,9 +74,7 @@ def __init__(self, metrics_filename: str = "_metrics.json"):
metrics_filename (str): filename to save metrics
in checkpoint folder. Must end with ``.json`` or ``.yml``
"""
super().__init__(
order=CallbackOrder.External, node=CallbackNode.Master
)
super().__init__(order=CallbackOrder.External, node=CallbackNode.All)
self.metrics_filename = metrics_filename
self.metrics: dict = {}

@@ -115,6 +120,7 @@ def __init__(
resume: str = None,
resume_dir: str = None,
metrics_filename: str = "_metrics.json",
load_on_stage_end: str = None,
):
"""
Args:
@@ -123,11 +129,24 @@
and initialize runner state
metrics_filename (str): filename to save metrics
in checkpoint folder. Must end with ``.json`` or ``.yml``
load_on_stage_end (str): name of the model to load
at the end of the stage.
You can use ``best`` to load the best model according
to validation metrics, or ``last`` to use just the last one
(default behaviour).
"""
super().__init__(metrics_filename)
assert load_on_stage_end in [
None,
"best",
"last",
"best_full",
"last_full",
]
self.save_n_best = save_n_best
self.resume = resume
self.resume_dir = resume_dir
self.load_on_stage_end = load_on_stage_end

self.top_best_metrics = []
self.metrics_history = []
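A hedged usage sketch for the new `load_on_stage_end` option follows; the import path assumes the usual `catalyst.dl.callbacks` re-export, and the surrounding experiment setup is omitted. With the plain `"best"`/`"last"` values only the weights are reloaded, while the `"*_full"` variants also restore optimizer and scheduler state via the `endswith("full")` check in `on_stage_end` below.

```python
from catalyst.dl.callbacks import CheckpointCallback

checkpoint_callback = CheckpointCallback(
    save_n_best=3,
    # reload the best checkpoint (weights only) when the stage finishes;
    # "best_full" / "last_full" would also restore optimizer/scheduler state
    load_on_stage_end="best",
)
# then pass it to the experiment, e.g. runner.train(..., callbacks=[checkpoint_callback])
```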
@@ -234,7 +253,7 @@ def on_stage_start(self, state: State):

def on_epoch_end(self, state: State):
"""@TODO: Docs. Contribution is welcome."""
if state.stage_name.startswith("infer"):
if state.stage_name.startswith("infer") or state.is_distributed_worker:
return

checkpoint = _pack_state(state)
@@ -248,7 +267,7 @@ def on_epoch_end(self, state: State):

def on_stage_end(self, state: State):
"""@TODO: Docs. Contribution is welcome."""
if state.stage_name.startswith("infer"):
if state.stage_name.startswith("infer") or state.is_distributed_worker:
return

print("Top best models:")
@@ -262,6 +281,15 @@ def on_stage_end(self, state: State):
)
print(top_best_metrics_str)

if self.load_on_stage_end in ["best", "best_full"]:
resume = f"{state.logdir}/checkpoints/{self.load_on_stage_end}.pth"
print(f"Loading {self.load_on_stage_end} model from {resume}")
_load_checkpoint(
filename=resume,
state=state,
load_full=self.load_on_stage_end.endswith("full"),
)


class IterationCheckpointCallback(BaseCheckpointCallback):
"""Iteration checkpoint callback to save your model/criterion/optimizer."""
2 changes: 1 addition & 1 deletion catalyst/core/runner.py
@@ -348,7 +348,7 @@ def _batch2device(
def _handle_batch(self, batch: Mapping[str, Any]) -> None:
"""
Inner method to handle specified data batch.
Used to make a train/valid/infer step during Experiment run.
Used to make a train/valid/infer stage during Experiment run.
Args:
batch (Mapping[str, Any]): dictionary with data batches
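For context on where `_handle_batch` sits in the loop, here is a rough custom-Runner sketch; it is not part of this diff, and the `state.batch_metrics`, `state.is_train_loader`, and `state.optimizer` attributes are assumed from the State API of this release. The dict-style batch keys are likewise an assumption about the user's loader.

```python
import torch.nn.functional as F
from catalyst import dl

class CustomRunner(dl.Runner):
    def _handle_batch(self, batch):
        # assumes a loader that yields {"features": ..., "targets": ...} dicts
        x, y = batch["features"], batch["targets"]
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.state.batch_metrics = {"loss": loss}
        if self.state.is_train_loader:
            loss.backward()
            self.state.optimizer.step()
            self.state.optimizer.zero_grad()
```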
8 changes: 7 additions & 1 deletion catalyst/core/state.py
@@ -169,10 +169,15 @@ class State(FrozenClass):
**state.distributed_rank** - distributed rank of current worker
**state.is_distributed_master** - bool, indicator flag
- ``True`` if is master node (state.distributed_rank == 0)
- ``False`` if is worker node (state.distributed_rank != 0)
**state.is_distributed_worker** - bool, indicator flag
- ``True`` if is worker node (state.distributed_rank > 0)
- ``False`` if is master node (state.distributed_rank == 0)
- ``False`` if is master node (state.distributed_rank <= 0)
**state.stage_name** - string, current stage name,\
@@ -335,6 +340,7 @@ def __init__(

# pipeline info
self.distributed_rank = utils.get_rank()
self.is_distributed_master = ~(self.distributed_rank > 0)
self.is_distributed_worker = self.distributed_rank > 0

self.stage_name: str = stage
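A short sketch of a callback that uses the new flag to keep side effects on the master process, mirroring the `is_distributed_worker` guards added to the checkpoint callbacks above; the import path and the `state.epoch_metrics` usage are assumptions about this release, not part of the diff. As an aside, `~` on a Python `bool` is bitwise (`~True == -2`, `~False == -1`, both truthy), so `not (self.distributed_rank > 0)` is the conventional spelling for a pure boolean flag; the sketch therefore relies on `is_distributed_worker` only.

```python
from catalyst.core import Callback, CallbackNode, CallbackOrder, State

class MasterOnlyLogger(Callback):
    """Hypothetical callback: print epoch metrics only on the master process."""

    def __init__(self):
        super().__init__(order=CallbackOrder.External, node=CallbackNode.All)

    def on_epoch_end(self, state: State):
        if state.is_distributed_worker:
            return  # same guard used by the checkpoint callbacks above
        print(f"epoch {state.epoch}: {dict(state.epoch_metrics)}")
```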
1 change: 1 addition & 0 deletions catalyst/dl/experiment/config.py
@@ -161,6 +161,7 @@ def get_state_params(self, stage: str) -> Mapping[str, Any]:

def _preprocess_model_for_stage(self, stage: str, model: Model):
stage_index = self.stages.index(stage)
# @TODO: remove to callbacks
if stage_index > 0:
checkpoint_path = f"{self.logdir}/checkpoints/best.pth"
checkpoint = utils.load_checkpoint(checkpoint_path)
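Both this stage-transition logic and the `load_on_stage_end` handling above rely on the same on-disk layout under `{logdir}/checkpoints/`. A tiny sketch of inspecting it follows; the directory contents listed in the comment are typical, not guaranteed.

```python
import os

logdir = "./logdir"  # placeholder
# same path pattern used in CheckpointCallback.on_stage_end and here
print(sorted(os.listdir(os.path.join(logdir, "checkpoints"))))
# usually includes best.pth / last.pth and their *_full.pth counterparts,
# which is what "best", "last", "best_full", "last_full" resolve to
```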
