
Utils callbacks #1127

Merged
18 commits, merged on Mar 28, 2021
Changes from 1 commit
51 changes: 51 additions & 0 deletions catalyst/callbacks/quantization.py
@@ -1,3 +1,54 @@
from typing import Dict, Optional, TYPE_CHECKING, Union
from pathlib import Path

import torch

from catalyst.core import Callback, CallbackNode, CallbackOrder
from catalyst.utils import BestModel, quantize_model

if TYPE_CHECKING:
from catalyst.core import IRunner


class QuantizationCallback(Callback):
    def __init__(
        self,
        metric: str,
        minimize_metric: bool = True,
        qconfig_spec: Dict = None,
        dtype: Union[str, Optional[torch.dtype]] = "qint8",
        logdir: Union[str, Path] = None,
        filename: str = "quantized.pth",
    ):

[pep8] reported by reviewdog 🐶
WPS355 Found an unnecessary blank line before a bracket

"""
@TODO
Args:
metric:
minimize_metric:
qconfig_spec:
dtype:
logdir:
filename:
"""
        super().__init__(order=CallbackOrder.External, node=CallbackNode.master)
        self.best_model = BestModel(metric=metric, minimize_metric=minimize_metric)
        self.qconfig_spec = qconfig_spec
        self.dtype = dtype
        if logdir is not None:
            self.filename = Path(logdir) / filename
        else:
            self.filename = filename

    def on_epoch_end(self, runner: "IRunner") -> None:
        """Stores the current model state dict if this epoch is the best so far."""
        self.best_model.add_result(epoch_metrics=runner.epoch_metrics, model=runner.model)

    def on_stage_end(self, runner: "IRunner") -> None:
        """Quantizes the best model of the stage and saves its state dict."""
        model = runner.model.cpu()
        model.load_state_dict(self.best_model.get_best_model_sd())
        q_model = quantize_model(model.cpu(), qconfig_spec=self.qconfig_spec, dtype=self.dtype)
        torch.save(q_model.state_dict(), self.filename)


# # @TODO: make the same API for tracing/onnx/pruning/quantization
# from typing import Dict, Optional, Set, TYPE_CHECKING, Union
# from pathlib import Path
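For context, here is a minimal usage sketch (not part of this PR's diff) showing how the new callback could be wired into a standard catalyst.dl.SupervisedRunner training loop. The metric key "loss", the toy dataset, and the "./logs" paths are illustrative assumptions; the metric name must match a key in runner.epoch_metrics.

# Hypothetical usage sketch -- not part of the diff above.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from catalyst import dl
from catalyst.callbacks.quantization import QuantizationCallback

model = nn.Linear(16, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

features = torch.randn(128, 16)
targets = torch.randint(0, 2, (128,))
loaders = {
    "train": DataLoader(TensorDataset(features, targets), batch_size=32),
    "valid": DataLoader(TensorDataset(features, targets), batch_size=32),
}

runner = dl.SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=3,
    logdir="./logs",
    callbacks=[
        QuantizationCallback(
            metric="loss",  # assumed key in runner.epoch_metrics
            minimize_metric=True,
            logdir="./logs",  # quantized.pth will be written here
        ),
    ],
)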
2 changes: 2 additions & 0 deletions catalyst/utils/__init__.py
@@ -14,6 +14,8 @@
    all_gather,
)

from catalyst.utils.best_model import BestModel

from catalyst.utils.misc import (
    get_fn_default_params,
    get_fn_argsnames,
55 changes: 55 additions & 0 deletions catalyst/utils/best_model.py
@@ -0,0 +1,55 @@
from typing import Any, Dict, Mapping

from catalyst.typing import Model


class BestModel:
    def __init__(self, metric: str, minimize_metric: bool = False):
        """
        Stores the best model state dict according to a metric.

        Args:
            metric: metric used to choose the best model.
            minimize_metric: whether the metric should be minimized (or maximized).
        """
        self.metric = metric
        self.minimize_metric = minimize_metric
        self._best_model_sd = None
        self._best_metric = None

    def add_result(self, epoch_metrics: Mapping[str, Any], model: Model) -> None:
        """

[pep8] reported by reviewdog 🐶
DAR101 Missing parameter(s) in Docstring: - model

        Adds the result for the current epoch and stores the state dict if it is the best so far.

        Args:
            epoch_metrics: dict of metrics for the epoch
            model: current model

        Raises:
            Exception: if the specified metric is not in the epoch metrics dict.
        """
        if self.metric not in epoch_metrics.keys():
            raise Exception(f"Metric {self.metric} not in runner.epoch_metrics.")
        current_metric = epoch_metrics[self.metric]
        if self._best_metric is None:
            self._best_metric = current_metric
            self._best_model_sd = model.state_dict()
        else:
            is_best_model = (self.minimize_metric and self._best_metric > current_metric) or (
                not self.minimize_metric and self._best_metric < current_metric
            )
            if is_best_model:
                self._best_metric = current_metric
                self._best_model_sd = model.state_dict()

    def get_best_model_sd(self) -> Dict[str, Any]:
        """
        Gets the best model state dict.

        Returns:
            State dict of the best model.

        Raises:
            Exception: if no result has been added yet.
        """
        if self._best_model_sd is None:
            raise Exception("There is no best model.")
        return self._best_model_sd


__all__ = ["BestModel"]