Commit
* creepy gans standalone separate examples: vanilla GAN, WGAN, WGAN-GP
* wasserstein distance callback
* conditional vanilla gans: runners & configs + dagan batch_sampler
* autoencoder runners
* somehow written BaseGANRunner + callbacks to update input batch
* minor refactoring & clean up
* generalized GANRunner to work with GAN, WGAN, WGAN-GP; + configs update
* GANRunner extended to have arbitrary conditions; updated yaml configs; removed trash
* tiny runners refactoring + documentation
* removed autoencoders; wgan & wgan-gp discriminator renamed to critic
* refactoring, beautification and docs
* todos: merged with CriterionCallback
* weighted criterion callback aggregation moved to core
* readme update
* todos upd
* isort known third-party
* partial codestyle =)
* verified batch_sampler correctness and reproducibility
* fixed gradient penalty; changed default runner to GANRunner
* removed WeightedCriterionAggregationCallback & improved CriterionAggregationCallback
* added GANRunner to catalyst.dl.runner
* paper links
* replaced batchtransform callbacks with simple transforms in config
* wgan & wgan-gp upd
* added cgan; all gans now can be run with same interface
* removed trash & too complicated image-conditioned example
* simplified GANRunner (assuming same conditioning of G and D)
* improved gp (now conditional); moved many functionality to core
* codestyle
* codestyle v2
* tiny clear
* naming

Co-authored-by: Andrey Zharkov <andreyzharkov@mail.ru>
Showing 26 changed files with 1,574 additions and 216 deletions.
@@ -0,0 +1,49 @@
import torch
from torch import nn


class MeanOutputLoss(nn.Module):
    """
    Criterion that computes a simple mean of the output,
    completely ignoring the target
    (useful e.g. for WGAN real/fake validity averaging).
    """
    def forward(self, output, target):
        """Compute criterion"""
        return output.mean()


class GradientPenaltyLoss(nn.Module):
    """Criterion to compute gradient penalty.

    WARN: SHOULD NOT BE RUN WITH CriterionCallback,
        use special GradientPenaltyCallback instead
    """
    def forward(self, fake_data, real_data, critic, critic_condition_args):
        """Compute gradient penalty"""
        device = real_data.device
        # Random weight term for interpolation between real and fake samples
        alpha = torch.rand((real_data.size(0), 1, 1, 1), device=device)
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_data + ((1 - alpha) * fake_data)).detach()
        interpolates.requires_grad_(True)
        with torch.set_grad_enabled(True):  # ensure grads even in validation
            d_interpolates = critic(interpolates, *critic_condition_args)

        fake = torch.ones(
            (real_data.size(0), 1), device=device, requires_grad=False
        )
        # Get gradients w.r.t. the interpolates
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=fake,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty


__all__ = ["MeanOutputLoss", "GradientPenaltyLoss"]
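As a sanity check for these two criteria, here is a minimal standalone sketch of a WGAN-GP critic loss; the toy critic, the tensor shapes, and the penalty weight of 10.0 are illustrative assumptions and are not part of this commit:

import torch
from torch import nn

# Toy critic for illustration only (assumed, not from this commit).
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))

real_data = torch.randn(8, 1, 28, 28)
fake_data = torch.randn(8, 1, 28, 28)

mean_loss = MeanOutputLoss()
gp_loss = GradientPenaltyLoss()

# WGAN critic objective (minimized): mean fake validity minus mean real
# validity, plus the gradient penalty; target argument is ignored by
# MeanOutputLoss, and no conditioning is used here.
critic_loss = (
    mean_loss(critic(fake_data), None)
    - mean_loss(critic(real_data), None)
    + 10.0 * gp_loss(fake_data, real_data, critic, critic_condition_args=())
)
critic_loss.backward()

Because the penalty is built with create_graph=True, the final backward() propagates through the gradient norm itself, which is what makes the penalty trainable.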
@@ -0,0 +1,227 @@
from typing import Any, Callable, Dict, List, Optional, Union  # isort:skip

from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from .criterion import CriterionCallback
from .optimizer import OptimizerCallback


"""
MetricCallback alternatives for input/output keys
"""


class MultiKeyMetricCallback(Callback):
    """
    A callback that returns a single metric on `state.on_batch_end`
    """

    # TODO:
    # merge it with MetricCallback in catalyst.core,
    # maybe after the changes to CriterionCallback are finalized
    # in the main repo
    def __init__(
        self,
        prefix: str,
        metric_fn: Callable,
        input_key: Optional[Union[str, List[str]]] = "targets",
        output_key: Optional[Union[str, List[str]]] = "logits",
        **metric_params
    ):
        """
        :param prefix: name under which the metric is logged
        :param metric_fn: function that computes the metric
        :param input_key: key(s) to take targets from state.input;
            if None, the whole state.input dict is passed
        :param output_key: key(s) to take outputs from state.output;
            if None, the whole state.output dict is passed
        :param metric_params: additional kwargs passed to `metric_fn`
        """
        super().__init__(CallbackOrder.Metric)
        self.prefix = prefix
        self.metric_fn = metric_fn
        self.input_key = input_key
        self.output_key = output_key
        self.metric_params = metric_params

    @staticmethod
    def _get(dictionary: dict, keys: Optional[Union[str, List[str]]]) -> Any:
        if keys is None:
            result = dictionary
        elif isinstance(keys, list):
            result = {key: dictionary[key] for key in keys}
        else:
            result = dictionary[keys]
        return result

    def on_batch_end(self, state: RunnerState):
        """On batch end call"""
        outputs = self._get(state.output, self.output_key)
        targets = self._get(state.input, self.input_key)
        metric = self.metric_fn(outputs, targets, **self.metric_params)
        state.metrics.add_batch_value(name=self.prefix, value=metric)


class WassersteinDistanceCallback(MultiKeyMetricCallback):
    """
    Callback to compute Wasserstein distance metric
    """
    def __init__(
        self,
        prefix: str = "wasserstein_distance",
        real_validity_output_key: str = "real_validity",
        fake_validity_output_key: str = "fake_validity"
    ):
        """
        :param prefix: name under which the metric is logged
        :param real_validity_output_key: key of real validity in state.output
        :param fake_validity_output_key: key of fake validity in state.output
        """
        super().__init__(
            prefix,
            metric_fn=self.get_wasserstein_distance,
            input_key=None,
            output_key=[real_validity_output_key, fake_validity_output_key]
        )
        self.real_validity_key = real_validity_output_key
        self.fake_validity_key = fake_validity_output_key

    def get_wasserstein_distance(self, outputs, targets):
        """
        Computes the Wasserstein distance estimate as the difference
        between mean real and mean fake validity
        :param outputs: dict with real and fake validity tensors
        :param targets: unused, kept for the metric_fn interface
        :return: Wasserstein distance estimate
        """
        real_validity = outputs[self.real_validity_key]
        fake_validity = outputs[self.fake_validity_key]
        return real_validity.mean() - fake_validity.mean()


"""
CriterionCallback extended
"""


class GradientPenaltyCallback(CriterionCallback):
    """
    Criterion callback to compute gradient penalty
    """

    def __init__(
        self,
        real_input_key: str = "data",
        fake_output_key: str = "fake_data",
        condition_keys: List[str] = None,
        critic_model_key: str = "critic",
        critic_criterion_key: str = "critic",
        real_data_criterion_key: str = "real_data",
        fake_data_criterion_key: str = "fake_data",
        condition_args_criterion_key: str = "critic_condition_args",
        prefix: str = "loss",
        criterion_key: str = None,
        multiplier: float = 1.0
    ):
        """
        :param real_input_key: real data key in state.input
        :param fake_output_key: fake data key in state.output
        :param condition_keys: all condition keys in state.input for critic
        :param critic_model_key: key for critic model in state.model
        :param critic_criterion_key: key for critic model in criterion
        :param real_data_criterion_key: key for real data in criterion
        :param fake_data_criterion_key: key for fake data in criterion
        :param condition_args_criterion_key: key for all condition args
            in criterion
        :param prefix: name under which the loss is logged
        :param criterion_key: key of the criterion in state.criterion
        :param multiplier: scale factor for the computed loss
        """
        super().__init__(
            input_key=real_input_key,
            output_key=fake_output_key,
            prefix=prefix,
            criterion_key=criterion_key,
            multiplier=multiplier
        )
        self.condition_keys = condition_keys or []
        self.critic_model_key = critic_model_key
        self.critic_criterion_key = critic_criterion_key
        self.real_data_criterion_key = real_data_criterion_key
        self.fake_data_criterion_key = fake_data_criterion_key
        self.condition_args_criterion_key = condition_args_criterion_key

    def _compute_loss(self, state: RunnerState, criterion):
        criterion_kwargs = {
            self.real_data_criterion_key: state.input[self.input_key],
            self.fake_data_criterion_key: state.output[self.output_key],
            self.critic_criterion_key: state.model[self.critic_model_key],
            self.condition_args_criterion_key: [
                state.input[key] for key in self.condition_keys
            ]
        }
        return criterion(**criterion_kwargs)


"""
Optimizer Callback with weights clamp after update
"""


class WeightClampingOptimizerCallback(OptimizerCallback):
    """
    Optimizer callback + weights clamping after the step is finished
    """
    def __init__(
        self,
        grad_clip_params: Dict = None,
        accumulation_steps: int = 1,
        optimizer_key: str = None,
        loss_key: str = "loss",
        decouple_weight_decay: bool = True,
        weight_clamp_value: float = 0.01
    ):
        """
        :param grad_clip_params: params for gradient clipping
        :param accumulation_steps: number of steps to accumulate gradients
        :param optimizer_key: key of the optimizer in state
        :param loss_key: key of the loss in state
        :param decouple_weight_decay: whether to apply weight decay
            separately from the gradient step
        :param weight_clamp_value:
            value to clamp weights after each optimization iteration
            Attention: will clamp WEIGHTS, not GRADIENTS
        """
        super().__init__(
            grad_clip_params=grad_clip_params,
            accumulation_steps=accumulation_steps,
            optimizer_key=optimizer_key,
            loss_key=loss_key,
            decouple_weight_decay=decouple_weight_decay
        )
        self.weight_clamp_value = weight_clamp_value

    def on_batch_end(self, state):
        """On batch end event"""
        super().on_batch_end(state)
        if not state.need_backward:
            return

        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )

        need_gradient_step = (
            self._accumulation_counter % self.accumulation_steps == 0
        )

        if need_gradient_step:
            for group in optimizer.param_groups:
                for param in group["params"]:
                    param.data.clamp_(
                        min=-self.weight_clamp_value,
                        max=self.weight_clamp_value
                    )


__all__ = [
    "WassersteinDistanceCallback",
    "GradientPenaltyCallback",
    "WeightClampingOptimizerCallback"
]
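Outside of a full Catalyst run, the two new non-criterion callbacks reduce to small tensor operations. The sketch below reproduces them standalone; the validity tensors, the toy model, and the clamp value are illustrative assumptions:

import torch

# What WassersteinDistanceCallback logs each batch: mean real validity
# minus mean fake validity (larger means the critic separates them better).
real_validity = torch.randn(64, 1) + 1.0   # assumed critic scores on real
fake_validity = torch.randn(64, 1) - 1.0   # assumed critic scores on fake
wasserstein_distance = real_validity.mean() - fake_validity.mean()

# What WeightClampingOptimizerCallback does after each optimizer step:
# clamp the critic's WEIGHTS (not gradients), enforcing the Lipschitz
# constraint of the original WGAN.
model = torch.nn.Linear(10, 1)   # stand-in for the critic
clamp_value = 0.01               # matches the default weight_clamp_value
for param in model.parameters():
    param.data.clamp_(min=-clamp_value, max=clamp_value)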
@@ -1,6 +1,6 @@
 # flake8: noqa

-from .gan import GanRunner
+from .gan import GanRunner, MultiPhaseRunner
 from .supervised import SupervisedRunner

 from catalyst.contrib.runner import *  # isort:skip
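With this re-export in place, both runners resolve through the package root; since GanRunner's constructor is not shown in this diff, only the import is demonstrated:

from catalyst.dl.runner import GanRunner, MultiPhaseRunner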