added trace method for AMP FP16 (#497)
* trace in fp16

* update docs

* fix optimizer typing

* fix makefile

* fix with check_trace

* * fix codestyle
* prettify `travis/check-style` output

* fix tracing in SupervisedRunner

* fix style

* how this even worked?

* add tutorial

* rl req

* fix SimpleNet duplications

* weird

* wtf is going on

* isort

* registry?

* data_index

* fixed infer on FP16

* distributed without apex

* fix

* added load_traced_model

* fix docs

* add properties for model and device

* registry check
TezRomacH authored and Scitator committed Nov 18, 2019
1 parent 6745360 commit 8c73121
Showing 22 changed files with 2,114 additions and 1,458 deletions.
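
The headline feature is tracing models trained with AMP/FP16. Below is a rough sketch of the underlying idea in plain PyTorch; the helper name, the file name, and the CUDA assumption are illustrative only and are not Catalyst's API (the commit's own entry points are the runner-level trace support and load_traced_model):

import torch

def trace_model_fp16(model, example_batch):
    # Illustrative helper: cast the model and an example input to half
    # precision, then record a TorchScript graph with torch.jit.trace.
    model = model.eval().half().cuda()
    example_batch = example_batch.half().cuda()
    with torch.no_grad():
        traced = torch.jit.trace(model, example_batch)
    return traced

# traced = trace_model_fp16(model, example_batch)
# torch.jit.save(traced, "traced-fp16.pth")   # persist the traced graph
# model = torch.jit.load("traced-fp16.pth")   # reload it later for inference
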
16 changes: 8 additions & 8 deletions Makefile
@@ -10,27 +10,27 @@ check-docs:
bash ./bin/_check_docs.sh

docker: ./requirements/
echo building $${REPO_NAME:=catalyst-base}:$${TAG:=latest} ...
echo building $${REPO_NAME:-catalyst-base}:$${TAG:-latest} ...
docker build \
-t $${REPO_NAME}:$${TAG} . \
-t $${REPO_NAME:-catalyst-base}:$${TAG:-latest} . \
-f ./docker/Dockerfile --no-cache

docker-fp16: ./requirements/
echo building $${REPO_NAME:=catalyst-base-fp16}:$${TAG:=latest} ...
echo building $${REPO_NAME:-catalyst-base-fp16}:$${TAG:-latest} ...
docker build \
-t $${REPO_NAME}:$${TAG} . \
-t $${REPO_NAME:-catalyst-base-fp16}:$${TAG:-latest} . \
-f ./docker/Dockerfile-fp16 --no-cache

docker-dev: ./requirements/
echo building $${REPO_NAME:=catalyst-dev}:$${TAG:=latest} ...
echo building $${REPO_NAME:-catalyst-dev}:$${TAG:-latest} ...
docker build \
-t $${REPO_NAME}:$${TAG} . \
-t $${REPO_NAME:-catalyst-dev}:$${TAG:-latest} . \
-f ./docker/Dockerfile-dev --no-cache

docker-dev-fp16: ./requirements/
echo building $${REPO_NAME:=catalyst-dev-fp16}:$${TAG:=latest} ...
echo building $${REPO_NAME:-catalyst-dev-fp16}:$${TAG:-latest} ...
docker build \
-t $${REPO_NAME}:$${TAG} . \
-t $${REPO_NAME:-catalyst-dev-fp16}:$${TAG:-latest} . \
-f ./docker/Dockerfile-dev-fp16 --no-cache

install-from-source:
2 changes: 2 additions & 0 deletions bin/_check_codestyle.sh
@@ -16,10 +16,12 @@ bash ./bin/flake8.sh --count \
--show-source --statistics

# exit-zero treats all errors as warnings.
echo '~ ~ ~ ~ ~ ~ ~ flake8 warnings ~ ~ ~ ~ ~ ~ ~' 1>&2
flake8 . --count --exit-zero \
--max-complexity=10 \
--config=./setup.cfg \
--statistics
echo '~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~' 1>&2

# test to make sure the code is yapf compliant
if [[ -f ${skip_inplace} ]]; then
4 changes: 2 additions & 2 deletions bin/tests/check_dl.sh
@@ -32,7 +32,7 @@ fi

python -c """
from safitty import Safict
metrics=Safict.load('$LOGFILE')
metrics = Safict.load('$LOGFILE')
assert metrics.get('stage1.3', 'loss') < metrics.get('stage1.1', 'loss')
assert metrics.get('stage1.3', 'loss') < 2.0
"""
@@ -57,7 +57,7 @@ fi

python -c """
from safitty import Safict
metrics=Safict.load('$LOGFILE')
metrics = Safict.load('$LOGFILE')
assert metrics.get('stage1.3', 'loss') < metrics.get('stage1.1', 'loss')
assert metrics.get('stage1.3', 'loss') < 2.0
"""
34 changes: 29 additions & 5 deletions catalyst/dl/callbacks/optimizer.py
@@ -8,7 +8,7 @@
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.registry import GRAD_CLIPPERS
from catalyst.dl.utils import get_optimizer_momentum, maybe_recursive_call
from catalyst.dl.utils.torch import _Optimizer
from catalyst.utils.typing import Optimizer

logger = logging.getLogger(__name__)

@@ -17,7 +17,6 @@ class OptimizerCallback(Callback):
"""
Optimizer callback, abstraction over optimizer step.
"""

def __init__(
self,
grad_clip_params: Dict = None,
@@ -51,10 +50,19 @@ def __init__(
@staticmethod
def grad_step(
*,
optimizer: _Optimizer,
optimizer: Optimizer,
optimizer_wds: List[float] = 0,
grad_clip_fn: Callable = None
):
"""
Makes a gradient step for a given optimizer
Args:
optimizer (Optimizer): the optimizer
optimizer_wds (List[float]): list of weight decay parameters
for each param group
grad_clip_fn (Callable): function for gradient clipping
"""
for group, wd in zip(optimizer.param_groups, optimizer_wds):
if wd > 0:
for param in group["params"]:
@@ -66,6 +74,7 @@ def grad_step(
optimizer.step()
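
For readers unfamiliar with the decoupled weight decay mentioned in the docstring above, here is a minimal sketch of the general pattern; the hidden lines of the real implementation may differ, so both the parameter update and the placement of the grad_clip_fn call are assumptions, not the exact Catalyst code:

def grad_step_sketch(optimizer, optimizer_wds, grad_clip_fn=None):
    # AdamW-style decoupling: shrink each parameter directly instead of
    # folding weight decay into the gradient.
    for group, wd in zip(optimizer.param_groups, optimizer_wds):
        if wd > 0:
            for param in group["params"]:
                param.data.mul_(1.0 - group["lr"] * wd)
        if grad_clip_fn is not None:
            grad_clip_fn(group["params"])  # e.g. torch.nn.utils.clip_grad_norm_
    optimizer.step()
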

def on_stage_start(self, state: RunnerState):
"""On stage start event"""
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
@@ -76,6 +85,7 @@ def on_stage_start(self, state: RunnerState):
state.set_key(momentum, "momentum", inner_key=self.optimizer_key)

def on_epoch_start(self, state):
"""On epoch start event"""
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
@@ -111,9 +121,11 @@ def _get_loss(self, state) -> torch.Tensor:
return loss

def on_batch_start(self, state):
"""On batch start event"""
state.loss = None

def on_batch_end(self, state):
"""On batch end event"""
if not state.need_backward:
return

@@ -125,18 +137,29 @@ def on_batch_end(self, state):
key="optimizer", inner_key=self.optimizer_key
)

need_gradient_step = \
(self._accumulation_counter + 1) % self.accumulation_steps == 0

# This is a very hacky check for whether we have an AMP optimizer,
# and it may change in the future.
# The alternative would be a dedicated AmpOptimizerCallback
# or another constructor argument.
if hasattr(optimizer, "_amp_stash"):
from apex import amp
with amp.scale_loss(loss, optimizer) as scaled_loss:
# Need to set ``delay_unscale``
# according to
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not need_gradient_step
with amp.scale_loss(
loss,
optimizer,
delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()

if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
if need_gradient_step:
self.grad_step(
optimizer=optimizer,
optimizer_wds=self._optimizer_wd,
@@ -147,6 +170,7 @@ def on_epoch_end(self, state):
self._accumulation_counter = 0

def on_epoch_end(self, state):
"""On epoch end event"""
if self.decouple_weight_decay:
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
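
The hunk above follows the gradient-accumulation recipe from the apex documentation it links to: on accumulation iterations the scaled loss is still backpropagated, but unscaling (and the inf/nan checks) is postponed via delay_unscale=True, and only the iteration that actually steps the optimizer takes the regular path. A minimal training-loop sketch, assuming model and optimizer have already been wrapped by amp.initialize and that loader and criterion are defined elsewhere:

from apex import amp  # requires NVIDIA apex

accumulation_steps = 4  # illustrative value

for i, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets)
    need_gradient_step = (i + 1) % accumulation_steps == 0
    # Postpone unscaling until the iteration that calls optimizer.step().
    with amp.scale_loss(loss, optimizer,
                        delay_unscale=not need_gradient_step) as scaled_loss:
        scaled_loss.backward()
    if need_gradient_step:
        optimizer.step()
        optimizer.zero_grad()
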
61 changes: 54 additions & 7 deletions catalyst/dl/core/experiment.py
@@ -1,11 +1,11 @@
from typing import Any, Dict, Iterable, Mapping, Tuple # isort:skip
from typing import Any, Dict, Iterable, Mapping, Tuple, Union # isort:skip
from abc import ABC, abstractmethod
from collections import OrderedDict

from torch import nn
from torch.utils.data import DataLoader, Dataset

from catalyst.dl.utils.torch import _Criterion, _Model, _Optimizer, _Scheduler
from catalyst.utils.typing import Criterion, Model, Optimizer, Scheduler
from .callback import Callback


@@ -19,74 +19,121 @@ class Experiment(ABC):
@property
@abstractmethod
def initial_seed(self) -> int:
"""Experiment's initial seed value"""
pass

@property
@abstractmethod
def logdir(self) -> str:
"""Path to the directory where the experiment logs"""
pass

@property
@abstractmethod
def stages(self) -> Iterable[str]:
"""Experiment's stage names"""
pass

@property
@abstractmethod
def distributed_params(self) -> Dict:
"""Dict with the parameters for distributed and FP16 methond"""
pass

@property
@abstractmethod
def monitoring_params(self) -> Dict:
"""Dict with the parameters for monitoring services"""
pass

@abstractmethod
def get_state_params(self, stage: str) -> Mapping[str, Any]:
"""Returns the state parameters for a given stage"""
pass

@abstractmethod
def get_model(self, stage: str) -> _Model:
def get_model(self, stage: str) -> Model:
"""Returns the model for a given stage"""
pass

@abstractmethod
def get_criterion(self, stage: str) -> _Criterion:
def get_criterion(self, stage: str) -> Criterion:
"""Returns the criterion for a given stage"""
pass

@abstractmethod
def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
def get_optimizer(self, stage: str, model: Model) -> Optimizer:
"""Returns the optimizer for a given stage"""
pass

@abstractmethod
def get_scheduler(self, stage: str, optimizer) -> _Scheduler:
def get_scheduler(self, stage: str, optimizer: Optimizer) -> Scheduler:
"""Returns the scheduler for a given stage"""
pass

def get_experiment_components(
self, model: nn.Module, stage: str
) -> Tuple[_Criterion, _Optimizer, _Scheduler]:
) -> Tuple[Criterion, Optimizer, Scheduler]:
"""
Returns the tuple of criterion, optimizer and scheduler
for a given model and stage.
"""
criterion = self.get_criterion(stage)
optimizer = self.get_optimizer(stage, model)
scheduler = self.get_scheduler(stage, optimizer)
return criterion, optimizer, scheduler

@abstractmethod
def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
"""Returns the callbacks for a given stage"""
pass

def get_datasets(
self,
stage: str,
**kwargs,
) -> "OrderedDict[str, Dataset]":
"""Returns the datasets for a given stage and kwargs"""
raise NotImplementedError

@abstractmethod
def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
"""Returns the loaders for a given stage"""
raise NotImplementedError

@staticmethod
def get_transforms(stage: str = None, mode: str = None):
"""Returns the data transforms for a given stage and mode"""
raise NotImplementedError

def get_native_batch(
self,
stage: str,
loader: Union[str, int] = 0,
data_index: int = 0
):
"""Returns a batch from experiment loader
Args:
stage (str): stage name
loader (Union[str, int]): loader name or its index,
default is the first loader
data_index (int): index in dataset from the loader
"""
loaders = self.get_loaders(stage)
if isinstance(loader, str):
_loader = loaders[loader]
elif isinstance(loader, int):
_loader = list(loaders.values())[loader]
else:
raise TypeError("Loader parameter must be a string or an integer")

dataset = _loader.dataset
collate_fn = _loader.collate_fn

sample = collate_fn([dataset[data_index]])

return sample


__all__ = ["Experiment"]
