Skip to content

Commit

Permalink
feat: add callback to apply kornia augmentations (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
bagxi committed Jun 28, 2020
1 parent c3fcb48 commit b443c97
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 3 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- kornia augmentations `BatchTransformCallback` ([#862](https://github.com/catalyst-team/catalyst/issues/862))
- `log` parameter to `WandbLogger` ([#836](https://github.com/catalyst-team/catalyst/pull/836))
- hparams experiment property ([#839](https://github.com/catalyst-team/catalyst/pull/839))
- add docs build on push to master branch ([#844](https://github.com/catalyst-team/catalyst/pull/844))
Expand All @@ -34,7 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `global_*` counters in `Runner` ([#858](https://github.com/catalyst-team/catalyst/pull/858))
- EarlyStoppingCallback considers first epoch as bad
([#854](https://github.com/catalyst-team/catalyst/issues/854))


## [20.06] - 2020-06-04

Expand Down Expand Up @@ -171,4 +172,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
-
11 changes: 11 additions & 0 deletions catalyst/contrib/dl/callbacks/__init__.py
Expand Up @@ -24,6 +24,17 @@
)
raise ex

try:
import kornia
from .kornia_transform import BatchTransformCallback
except ImportError as ex:
if settings.cv_required:
logger.warning(
"some of catalyst-cv dependencies not available,"
" to install dependencies, run `pip install catalyst[cv]`."
)
raise ex

try:
import alchemy
from .alchemy_logger import AlchemyLogger
Expand Down
110 changes: 110 additions & 0 deletions catalyst/contrib/dl/callbacks/kornia_transform.py
@@ -0,0 +1,110 @@
from typing import Dict, Optional, Sequence, Tuple

from kornia import augmentation
import torch
from torch import nn

from catalyst.contrib.registry import TRANSFORMS
from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.runner import IRunner


class BatchTransformCallback(Callback):
"""Callback to perform data augmentations on GPU using kornia library.
Look at `Kornia: an Open Source Differentiable Computer Vision
Library for PyTorch`_ for details.
.. _`Kornia: an Open Source Differentiable Computer Vision Library
for PyTorch`: https://arxiv.org/pdf/1910.02190.pdf
"""

def __init__(
self,
transform: Sequence[dict],
input_key: str,
output_key: Optional[str] = None,
additional_input_key: Optional[str] = None,
additional_output_key: Optional[str] = None,
loader: Optional[str] = None,
) -> None:
"""Constructor method for the :class:`BatchTransformCallback` callback.
Args:
transform (Sequence[dict]): A sequence of dits with params
for each transform to apply. Must contain `transform` key
with augmentation name as a value. If augmentation is custom,
then you should add it to the TRANSFORMS registry first.
input_key (str): Key in batch dict mapping to to tranform,
e.g. `'image'`.
output_key (Optional[str]): Key to use to store the result
of transform. Defaults to `input_key` if not provided.
additional_input_key (Optional[str]): Key of additional target
in batch dict mapping to to tranform, e.g. `'mask'`.
additional_output_key (Optional[str]): Key to use to store
the result of additional target transform.
Defaults to `additional_input_key` if not provided.
loader (Optional[str]): Name of the loader on which items
transform should be applied. If `None`, transform going to be
applied for each loader.
"""
super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all)

self.input_key = input_key
self.additional_input = additional_input_key
self._process_input = (
self._process_input_tuple
if self.additional_input is not None
else self._process_input_tensor
)

self.output_key = output_key or input_key
self.additional_output = additional_output_key or self.additional_input
self._process_output = (
self._process_output_tuple
if self.additional_output is not None
else self._process_output_tensor
)

transforms: Sequence[augmentation.AugmentationBase] = [
TRANSFORMS.get_from_params(**params) for params in transform
]
assert all(
isinstance(t, augmentation.AugmentationBase) for t in transforms
), "`kornia.AugmentationBase` should be a base class for transforms"

self.transform = nn.Sequential(*transforms)
self.loader = loader

def _process_input_tensor(self, input_: dict) -> torch.Tensor:
return input_[self.input_key]

def _process_input_tuple(
self, input_: dict
) -> Tuple[torch.Tensor, torch.Tensor]:
return input_[self.input_key], input_[self.additional_input]

def _process_output_tensor(
self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Dict[str, torch.Tensor]:
return {self.output_key: batch}

def _process_output_tuple(
self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Dict[str, torch.Tensor]:
out_t, additional_t = batch
return {self.output_key: out_t, self.additional_output: additional_t}

def on_batch_start(self, runner: IRunner) -> None:
"""Apply transforms.
Args:
runner (IRunner): Current runner.
"""
if self.loader is None or runner.loader_name == self.loader:
in_batch = self._process_input(runner.input)
out_batch = self.transform(in_batch)
runner.input.update(self._process_output(out_batch))


__all__ = ["BatchTransformCallback"]
4 changes: 4 additions & 0 deletions catalyst/contrib/registry.py
Expand Up @@ -19,6 +19,10 @@ def _transforms_loader(r: Registry):

r.add_from_module(p, prefix=["A.", "albu.", "albumentations."])

from kornia import augmentation as k

r.add_from_module(k, prefix=["kornia."])

from catalyst.contrib.data.cv import transforms as t

r.add_from_module(t, prefix=["catalyst.", "C."])
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-cv.txt
@@ -1,5 +1,6 @@
albumentations==0.4.3
imageio
kornia==0.3.1
opencv-python-headless
scikit-image>=0.14.2
segmentation-models-pytorch==0.1.0
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -43,7 +43,7 @@ reverse_relative = true
# - python libs (known_third_party)
# - dl libs (known_dl)
# - catalyst imports
known_dl = albumentations,gym,gym_minigrid,neptune,tensorboard,tensorboardX,tensorflow,torch,torchvision,transformers,wandb
known_dl = albumentations,gym,gym_minigrid,kornia,neptune,tensorboard,tensorboardX,tensorflow,torch,torchvision,transformers,wandb
known_first_party = catalyst
sections = STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER

Expand Down

0 comments on commit b443c97

Please sign in to comment.