Commit: catalyst.gan.v0.1 (#607)

* creepy gans standalone separate examples: vanilla GAN, WGAN, WGAN-GP

* wasserstein distance callback

* conditional vanilla gans: runners & configs + dagan batch_sampler

* autoencoder runners

* somehow written BaseGANRunner + callbacks to update input batch

* minor refactoring & clean up

* generalized GANRunner to work with GAN, WGAN, WGAN-GP; + configs update

* GANRunner extended to have arbitrary conditions; updated yaml configs; removed trash

* tiny runners refactoring + documentation

* removed autoencoders; wgan & wgan-gp discriminator renamed to critic

* refactoring, beautification and docs

* todos: merged with CriterionCallback

* weighted criterion callback aggregation moved to core

* readme update

* todos upd

* isort known third-party

* partial codestyle =)

* verified batch_sampler correctness and reproducibility

* fixed gradient penalty; changed default runner to GANRunner

* removed WeightedCriterionAggregationCallback & improved CriterionAggregationCallback

* added GANRunner to catalyst.dl.runner

* paper links

* replaced batchtransform callbacks with simple transforms in config

* wgan & wgan-gp upd

* added cgan; all gans now can be run with same interface

* removed trash & too complicated image-conditioned example

* simplified GANRunner (assuming same conditioning of G and D)

* improved gp (now conditional); moved many functionality to core

* codestyle

* codestyle v2

* tiny clear

* naming

Co-authored-by: Andrey Zharkov <andreyzharkov@mail.ru>
Scitator and asmekal committed Jan 28, 2020
1 parent e9905e9 commit b22f990
Showing 26 changed files with 1,574 additions and 216 deletions.
1 change: 1 addition & 0 deletions catalyst/contrib/criterion/__init__.py
@@ -11,6 +11,7 @@
 )
 from .dice import BCEDiceLoss, DiceLoss
 from .focal import FocalLossBinary, FocalLossMultiClass
+from .gan import GradientPenaltyLoss, MeanOutputLoss
 from .huber import HuberLoss
 from .iou import BCEIoULoss, IoULoss
 from .lovasz import (
49 changes: 49 additions & 0 deletions catalyst/contrib/criterion/gan.py
@@ -0,0 +1,49 @@
import torch
from torch import nn


class MeanOutputLoss(nn.Module):
"""
Criterion to compute simple mean of the output, completely ignoring target
(maybe useful e.g. for WGAN real/fake validity averaging
"""
def forward(self, output, target):
"""Compute criterion"""
return output.mean()


class GradientPenaltyLoss(nn.Module):
"""Criterion to compute gradient penalty
WARN: SHOULD NOT BE RUN WITH CriterionCallback,
use special GradientPenaltyCallback instead
"""
def forward(self, fake_data, real_data, critic, critic_condition_args):
"""Compute gradient penalty"""
device = real_data.device
# Random weight term for interpolation between real and fake samples
alpha = torch.rand((real_data.size(0), 1, 1, 1), device=device)
# Get random interpolation between real and fake samples
interpolates = (alpha * real_data + ((1 - alpha) * fake_data)).detach()
interpolates.requires_grad_(True)
with torch.set_grad_enabled(True): # to compute in validation mode
d_interpolates = critic(interpolates, *critic_condition_args)

fake = torch.ones(
(real_data.size(0), 1), device=device, requires_grad=False
)
# Get gradient w.r.t. interpolates
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
return gradient_penalty


__all__ = ["MeanOutputLoss", "GradientPenaltyLoss"]
4 changes: 4 additions & 0 deletions catalyst/dl/callbacks/__init__.py
@@ -5,6 +5,10 @@
     CriterionAggregatorCallback, CriterionCallback,
     CriterionOutputOnlyCallback
 )
+from .gan import (
+    GradientPenaltyCallback, WassersteinDistanceCallback,
+    WeightClampingOptimizerCallback
+)
 from .inference import InferCallback, InferMaskCallback
 from .logging import (
     ConsoleLogger, TelegramLogger, TensorboardLogger, VerboseLogger
31 changes: 24 additions & 7 deletions catalyst/dl/callbacks/criterion.py
@@ -67,7 +67,9 @@ def __init__(
         self._get_output = utils.get_dictkey_auto_fn(self.output_key)
         kv_types = (dict, tuple, list, type(None))
         # @TODO: fix to only KV usage
-        if isinstance(self.input_key, str) \
+        if hasattr(self, "_compute_loss"):
+            pass  # overridden in descendants
+        elif isinstance(self.input_key, str) \
                 and isinstance(self.output_key, str):
             self._compute_loss = self._compute_loss_value
         elif isinstance(self.input_key, kv_types) \
@@ -186,20 +188,35 @@ def __init__(
         self.loss_keys = loss_keys
 
         self.multiplier = multiplier
-        if loss_aggregate_fn == "sum":
-            self.loss_fn = lambda x: torch.sum(torch.stack(x)) * multiplier
-        elif loss_aggregate_fn == "weighted_sum":
+
+        if loss_aggregate_fn in ("sum", "mean"):
+            if loss_keys is not None and not isinstance(loss_keys, (str, list)):
+                raise ValueError(
+                    "For `sum` or `mean` mode the loss_keys must be "
+                    "None or list or str (not dict)"
+                )
+        elif loss_aggregate_fn in ("weighted_sum", "weighted_mean"):
             if loss_keys is None or not isinstance(loss_keys, dict):
                 raise ValueError(
-                    "For `weighted_sum` mode the loss_keys must be specified "
+                    "For `weighted_sum` or `weighted_mean` mode "
+                    "the loss_keys must be specified "
                     "and must be a dict"
                 )
+
+        if loss_aggregate_fn in ("sum", "weighted_sum", "weighted_mean"):
+            self.loss_fn = lambda x: torch.sum(torch.stack(x)) * multiplier
+            if loss_aggregate_fn == "weighted_mean":
+                weights_sum = sum(loss_keys.values())
+                self.loss_keys = {
+                    key: weight / weights_sum
+                    for key, weight in loss_keys.items()
+                }
         elif loss_aggregate_fn == "mean":
             self.loss_fn = lambda x: torch.mean(torch.stack(x)) * multiplier
         else:
             raise ValueError(
-                "loss_aggregate_fn must be `sum`, `mean` or weighted_sum`"
+                "loss_aggregate_fn must be `sum`, `mean` "
+                "or `weighted_sum` or `weighted_mean`"
             )
 
         self.loss_aggregate_name = loss_aggregate_fn
@@ -248,5 +265,5 @@ def on_batch_end(self, state: RunnerState) -> None:
 __all__ = [
     "CriterionCallback",
     "CriterionOutputOnlyCallback",
-    "CriterionAggregatorCallback",
+    "CriterionAggregatorCallback"
 ]
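
As a quick illustration of the new aggregation modes, the sketch below configures the callback for `weighted_mean`. The keyword layout (`prefix`, `loss_keys`, `loss_aggregate_fn`, `multiplier`) is inferred from the hunk above, so treat it as an assumption rather than the canonical config.

# Editor's sketch (assumed constructor layout, not taken from the commit):
# aggregate two previously computed losses with normalized weights.
from catalyst.dl.callbacks import CriterionAggregatorCallback

aggregator = CriterionAggregatorCallback(
    prefix="loss",
    # weights are renormalized to sum to 1 in `weighted_mean` mode
    loss_keys={"loss_adversarial": 1.0, "loss_reconstruction": 9.0},
    loss_aggregate_fn="weighted_mean",
    multiplier=1.0,
)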
227 changes: 227 additions & 0 deletions catalyst/dl/callbacks/gan.py
@@ -0,0 +1,227 @@
from typing import Any, Callable, Dict, List, Optional, Union # isort:skip

from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from .criterion import CriterionCallback
from .optimizer import OptimizerCallback


"""
MetricCallbacks alternatives for input/output keys
"""


class MultiKeyMetricCallback(Callback):
"""
    A callback that computes a single metric on `state.on_batch_end`
    and stores it in `state.metrics`
"""

# TODO:
# merge it with MetricCallback in catalyst.core
# maybe after the changes with CriterionCallback will be finalized
# in the main repo
def __init__(
self,
prefix: str,
metric_fn: Callable,
input_key: Optional[Union[str, List[str]]] = "targets",
output_key: Optional[Union[str, List[str]]] = "logits",
**metric_params
):
"""
        :param prefix: metric name used in `state.metrics`
        :param metric_fn: callable computing the metric from (outputs, targets)
        :param input_key: key or list of keys in `state.input` to use as targets
            (`None` means the whole `state.input`)
        :param output_key: key or list of keys in `state.output` to use as outputs
            (`None` means the whole `state.output`)
        :param metric_params: additional kwargs passed to `metric_fn`
"""
super().__init__(CallbackOrder.Metric)
self.prefix = prefix
self.metric_fn = metric_fn
self.input_key = input_key
self.output_key = output_key
self.metric_params = metric_params

@staticmethod
def _get(dictionary: dict, keys: Optional[Union[str, List[str]]]) -> Any:
if keys is None:
result = dictionary
elif isinstance(keys, list):
result = {key: dictionary[key] for key in keys}
else:
result = dictionary[keys]
return result

def on_batch_end(self, state: RunnerState):
"""On batch end call"""
outputs = self._get(state.output, self.output_key)
targets = self._get(state.input, self.input_key)
metric = self.metric_fn(outputs, targets, **self.metric_params)
state.metrics.add_batch_value(name=self.prefix, value=metric)


class WassersteinDistanceCallback(MultiKeyMetricCallback):
"""
Callback to compute Wasserstein distance metric
"""
def __init__(
self,
prefix: str = "wasserstein_distance",
real_validity_output_key: str = "real_validity",
fake_validity_output_key: str = "fake_validity"
):
"""
        :param prefix: metric name used in `state.metrics`
        :param real_validity_output_key: key of real validity scores
            in `state.output`
        :param fake_validity_output_key: key of fake validity scores
            in `state.output`
"""
super().__init__(
prefix,
metric_fn=self.get_wasserstein_distance,
input_key=None,
output_key=[real_validity_output_key, fake_validity_output_key]
)
self.real_validity_key = real_validity_output_key
self.fake_validity_key = fake_validity_output_key

def get_wasserstein_distance(self, outputs, targets):
"""
        Computes the Wasserstein distance estimate
        :param outputs: dict with real and fake validity tensors
        :param targets: unused (kept for the metric callback interface)
        :return: mean real validity minus mean fake validity
"""
real_validity = outputs[self.real_validity_key]
fake_validity = outputs[self.fake_validity_key]
return real_validity.mean() - fake_validity.mean()


"""
CriterionCallback extended
"""


class GradientPenaltyCallback(CriterionCallback):
"""
Criterion Callback to compute Gradient Penalty
"""

def __init__(self,
real_input_key: str = "data",
fake_output_key: str = "fake_data",
condition_keys: List[str] = None,
critic_model_key: str = "critic",
critic_criterion_key: str = "critic",
real_data_criterion_key: str = "real_data",
fake_data_criterion_key: str = "fake_data",
condition_args_criterion_key: str = "critic_condition_args",
prefix: str = "loss",
criterion_key: str = None,
multiplier: float = 1.0):
"""
:param real_input_key: real data key in state.input
:param fake_output_key: fake data key in state.output
:param condition_keys: all condition keys in state.input for critic
:param critic_model_key: key for critic model in state.model
:param critic_criterion_key: key for critic model in criterion
:param real_data_criterion_key: key for real data in criterion
:param fake_data_criterion_key: key for fake data in criterion
:param condition_args_criterion_key: key for all condition args
in criterion
        :param prefix: key for the computed loss in state metrics
        :param criterion_key: key for the criterion in state.criterion
        :param multiplier: scale factor for the computed loss
"""
super().__init__(
input_key=real_input_key,
output_key=fake_output_key,
prefix=prefix,
criterion_key=criterion_key,
multiplier=multiplier
)
self.condition_keys = condition_keys or []
self.critic_model_key = critic_model_key
self.critic_criterion_key = critic_criterion_key
self.real_data_criterion_key = real_data_criterion_key
self.fake_data_criterion_key = fake_data_criterion_key
self.condition_args_criterion_key = condition_args_criterion_key

def _compute_loss(self, state: RunnerState, criterion):
criterion_kwargs = {
self.real_data_criterion_key: state.input[self.input_key],
self.fake_data_criterion_key: state.output[self.output_key],
self.critic_criterion_key: state.model[self.critic_model_key],
self.condition_args_criterion_key: [
state.input[key] for key in self.condition_keys
]
}
return criterion(**criterion_kwargs)


"""
Optimizer Callback with weights clamp after update
"""


class WeightClampingOptimizerCallback(OptimizerCallback):
"""
    Optimizer callback + weight clamping after the optimization step is finished
"""
def __init__(
self,
grad_clip_params: Dict = None,
accumulation_steps: int = 1,
optimizer_key: str = None,
loss_key: str = "loss",
decouple_weight_decay: bool = True,
weight_clamp_value: float = 0.01
):
"""
        :param grad_clip_params: params for gradient clipping
        :param accumulation_steps: number of steps over which to accumulate
            gradients before an optimizer step
        :param optimizer_key: key for the optimizer in `state.optimizer`
        :param loss_key: key for the loss in `state.loss`
        :param decouple_weight_decay: whether to apply weight decay decoupled
            from the gradient-based update
:param weight_clamp_value:
value to clamp weights after each optimization iteration
Attention: will clamp WEIGHTS, not GRADIENTS
"""
super().__init__(
grad_clip_params=grad_clip_params,
accumulation_steps=accumulation_steps,
optimizer_key=optimizer_key,
loss_key=loss_key,
decouple_weight_decay=decouple_weight_decay
)
self.weight_clamp_value = weight_clamp_value

def on_batch_end(self, state):
"""On batch end event"""
super().on_batch_end(state)
if not state.need_backward:
return

optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)

need_gradient_step = \
self._accumulation_counter % self.accumulation_steps == 0

if need_gradient_step:
for group in optimizer.param_groups:
for param in group["params"]:
param.data.clamp_(
min=-self.weight_clamp_value,
max=self.weight_clamp_value
)


__all__ = [
"WassersteinDistanceCallback",
"GradientPenaltyCallback",
"WeightClampingOptimizerCallback"
]
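
To show how these pieces fit together, here is a hedged sketch of a callbacks list for WGAN-style training. The "gradient_penalty" criterion key and "critic" optimizer key are illustrative assumptions, and in practice weight clamping (WGAN) and the gradient penalty (WGAN-GP) would be alternatives rather than used together.

# Editor's sketch (not from the commit): wiring the new GAN callbacks.
# The "gradient_penalty" criterion key and "critic" optimizer key are assumptions.
from catalyst.dl.callbacks import (
    GradientPenaltyCallback,
    WassersteinDistanceCallback,
    WeightClampingOptimizerCallback,
)

callbacks = [
    # logs E[D(real)] - E[D(fake)] from the critic outputs every batch
    WassersteinDistanceCallback(
        real_validity_output_key="real_validity",
        fake_validity_output_key="fake_validity",
    ),
    # WGAN-GP: gradient penalty computed through GradientPenaltyLoss
    GradientPenaltyCallback(
        real_input_key="data",
        fake_output_key="fake_data",
        critic_model_key="critic",
        prefix="loss_gp",
        criterion_key="gradient_penalty",
    ),
    # original WGAN: clamp critic weights into [-0.01, 0.01] after each step
    WeightClampingOptimizerCallback(
        optimizer_key="critic",
        weight_clamp_value=0.01,
    ),
]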
2 changes: 1 addition & 1 deletion catalyst/dl/core/callback.py
@@ -175,7 +175,7 @@ def __init__(
             Make sure that they are in the same order that metrics
             are outputted by the meters in `meter_list`
             meter_list (list-like): List of meters.meter.Meter instances
-                len(meter_list) == n_classes
+                len(meter_list) == num_classes
             input_key (str): input key to use for metric calculation
                 specifies our ``y_true``.
             output_key (str): output key to use for metric calculation;
2 changes: 1 addition & 1 deletion catalyst/dl/runner/__init__.py
@@ -1,6 +1,6 @@
 # flake8: noqa
 
-from .gan import GanRunner
+from .gan import GanRunner, MultiPhaseRunner
 from .supervised import SupervisedRunner
 
 from catalyst.contrib.runner import * # isort:skip
