Skip to content

Commit

Permalink
Squash Apex support into single commit
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed May 20, 2019
1 parent 9c3d932 commit aab3902
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 133 deletions.
57 changes: 19 additions & 38 deletions catalyst/dl/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os
from typing import Dict

import safitty
import torch
from typing import Dict

from catalyst.contrib.scheduler import OneCycleLR, BatchScheduler
from catalyst.dl.fp16 import Fp16Wrap, copy_params, copy_grads
from catalyst.dl.state import RunnerState
from catalyst.dl.utils import UtilsFactory, get_optimizer_momentum
from catalyst.rl.registry import GRAD_CLIPPERS
Expand Down Expand Up @@ -144,7 +142,6 @@ class OptimizerCallback(Callback):
def __init__(
self,
grad_clip_params: Dict = None,
fp16_grad_scale: float = 128.0,
accumulation_steps: int = 1,
optimizer_key: str = None,
loss_key: str = None,
Expand All @@ -157,8 +154,6 @@ def __init__(
grad_clip_params = grad_clip_params or {}
self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)

self.fp16 = False
self.fp16_grad_scale = fp16_grad_scale
self.accumulation_steps = accumulation_steps
self.optimizer_key = optimizer_key
self.loss_key = loss_key
Expand All @@ -167,7 +162,6 @@ def __init__(
self._accumulation_counter = 0

def on_stage_start(self, state: RunnerState):
self.fp16 = isinstance(state.model, Fp16Wrap)
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
Expand Down Expand Up @@ -201,6 +195,8 @@ def on_batch_start(self, state):

def on_batch_end(self, state):
loss = state.get_key(key="loss", inner_key=self.loss_key)
if isinstance(loss, dict):
loss = list(loss.values())
if isinstance(loss, list):
loss = torch.mean(torch.stack(loss))

Expand All @@ -213,45 +209,30 @@ def on_batch_end(self, state):
return

self._accumulation_counter += 1
if not self.fp16:
model = state.model
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
loss.backward()
model = state.model
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)

if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
self.grad_step(
optimizer=optimizer,
optimizer_wd=self._optimizer_wd,
grad_clip_fn=self.grad_clip_fn
)
model.zero_grad()
self._accumulation_counter = 0
# This is a very hacky check for whether we have an AMP optimizer, and
# it may change in the future.
# An alternative solution would be to have an AmpOptimizerCallback,
# or to expose another constructor argument.
if hasattr(optimizer, '_amp_stash'):
from apex import amp
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
model = state.model
model.zero_grad()
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
loss = state.get_key(key="loss", inner_key=self.optimizer_key)
scaled_loss = self.fp16_grad_scale * loss.float()
scaled_loss.backward()
loss.backward()

master_params = list(optimizer.param_groups[0]["params"])
model_params = list(
filter(lambda p: p.requires_grad, model.parameters())
)
copy_grads(source=model_params, target=master_params)
for param in master_params:
param.grad.data.mul_(1. / self.fp16_grad_scale)
if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
self.grad_step(
optimizer=optimizer,
optimizer_wd=self._optimizer_wd,
grad_clip_fn=self.grad_clip_fn
)
copy_params(source=master_params, target=model_params)
torch.cuda.synchronize()
model.zero_grad()
self._accumulation_counter = 0

def on_epoch_end(self, state):
optimizer = state.get_key(
Expand Down
61 changes: 38 additions & 23 deletions catalyst/dl/experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from abc import abstractmethod, ABC
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset # noqa F401
from typing import Iterable, Any, Mapping, Dict, List
from typing import Iterable, Any, Mapping, Dict, List, Tuple

from catalyst.dl.registry import \
MODELS, CRITERIONS, OPTIMIZERS, SCHEDULERS, CALLBACKS
from catalyst.dl import utils
from catalyst.dl.callbacks import Callback # noqa F401
from catalyst.dl.callbacks import \
LossCallback, OptimizerCallback, SchedulerCallback, CheckpointCallback
from catalyst.dl.fp16 import Fp16Wrap
from catalyst.dl.utils import UtilsFactory
from catalyst.utils.misc import merge_dicts

Expand Down Expand Up @@ -53,7 +52,8 @@ def get_criterion(self, stage: str) -> _Criterion:
pass

@abstractmethod
def get_optimizer(self, stage: str, model) -> _Optimizer:
def get_optimizer_and_model(self, stage: str,
model) -> Tuple[_Optimizer, _Model]:
pass

@abstractmethod
Expand All @@ -62,18 +62,18 @@ def get_scheduler(self, stage: str, optimizer) -> _Scheduler:

def get_model_stuff(self, model, stage: str):
criterion = self.get_criterion(stage)
optimizer = self.get_optimizer(stage, model)
optimizer, model = self.get_optimizer_and_model(stage, model)
scheduler = self.get_scheduler(stage, optimizer)
return criterion, optimizer, scheduler
return criterion, optimizer, scheduler, model

@abstractmethod
def get_callbacks(self, stage: str) -> "List[Callback]":
pass

def get_datasets(
self,
stage: str,
**kwargs,
self,
stage: str,
**kwargs,
) -> "OrderedDict[str, Dataset]":
raise NotImplementedError

Expand Down Expand Up @@ -158,8 +158,9 @@ def get_model(self, stage: str) -> _Model:
def get_criterion(self, stage: str) -> _Criterion:
return self._criterion

def get_optimizer(self, stage: str, model=None) -> _Optimizer:
return self._optimizer
def get_optimizer_and_model(self, stage: str,
model=None) -> Tuple[_Optimizer, _Model]:
return self._optimizer, model

def get_scheduler(self, stage: str, optimizer=None) -> _Scheduler:
return self._scheduler
Expand Down Expand Up @@ -269,14 +270,8 @@ def _postprocess_model_for_stage(self, stage: str, model: _Model):

def get_model(self, stage: str) -> _Model:
model_params = self._config["model_params"]
fp16 = model_params.pop("fp16", False)

model = MODELS.get_from_params(**model_params)

if fp16:
utils.assert_fp16_available()
model = Fp16Wrap(model)

model = self._preprocess_model_for_stage(stage, model)
model = self._postprocess_model_for_stage(stage, model)
return model
Expand Down Expand Up @@ -328,15 +323,35 @@ def _get_optimizer(self, *, model_params, **params):

return optimizer

def get_optimizer(self, stage: str, model: nn.Module) -> _Optimizer:
fp16 = isinstance(model, Fp16Wrap)
model_params = utils.prepare_optimizable_params(
model.parameters(), fp16)
def get_optimizer_and_model(self, stage: str,
model: nn.Module) -> [_Optimizer, _Model]:

model_params = utils.prepare_optimizable_params(model.parameters())
optimizer_params = \
self.stages_config[stage].get("optimizer_params", {})

fp16 = optimizer_params.get("fp16", False)
fp16_opt_level = optimizer_params.get("fp16_opt_level", "O1")

# Prevent leaking fp16-related params to optimizer factory
optimizer_params = dict((k, v) for k, v in optimizer_params.items()
if k not in {'fp16', 'fp16_opt_level'})

optimizer = self._get_optimizer(
model_params=model_params, **optimizer_params)
return optimizer

if fp16:
utils.assert_fp16_available()
from apex import amp
if fp16_opt_level not in {"O1", "O2", "O3"}:
raise ValueError("fp16 mode must be one of O1, O2, O3")

model, optimizer = amp.initialize(model, optimizer,
opt_level=fp16_opt_level)
elif torch.cuda.device_count() > 1:
model = torch.nn.DataParallel(model)

return optimizer, model

@staticmethod
def _get_scheduler(*, optimizer, **params):
Expand Down Expand Up @@ -391,8 +406,8 @@ def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
assert "dataset" in ds_, \
"You need to specify dataset for dataloader"
loader_params["shuffle"] = (
name.startswith("train")
and ds_.get("sampler") is None)
name.startswith("train")
and ds_.get("sampler") is None)
loader_params = merge_dicts(ds_, loader_params)
else:
raise NotImplementedError
Expand Down
13 changes: 10 additions & 3 deletions catalyst/dl/experiments/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def _prepare_state(self, stage: str):
})

self._prepare_model(stage)
criterion, optimizer, scheduler = \
criterion, optimizer, scheduler, model = \
self.experiment.get_model_stuff(self.model, stage)

self.state = RunnerState(
stage=stage,
model=self.model,
model=model,
device=self.device,
criterion=criterion,
optimizer=optimizer,
Expand Down Expand Up @@ -234,7 +234,14 @@ def _batch2device(self, batch: Mapping[str, Any], device):

def predict_batch(self, batch: Mapping[str, Any]):
output = self.model(batch[self.input_key])
output = {self.output_key: output}
if isinstance(output, dict):
pass
elif isinstance(output, (list, tuple)) \
and isinstance(self.output_key, list):
output = dict((key, value) for key, value in
zip(self.output_key, output))
else:
output = {self.output_key: output}
return output

def train(
Expand Down
43 changes: 0 additions & 43 deletions catalyst/dl/fp16.py

This file was deleted.

0 comments on commit aab3902

Please sign in to comment.