From b443c979a66932c7d5975cef1f91e6c2fa270664 Mon Sep 17 00:00:00 2001 From: Yauheni Kachan <19803638+bagxi@users.noreply.github.com> Date: Sun, 28 Jun 2020 18:18:39 +0300 Subject: [PATCH] feat: add callback to apply kornia augmentations (#861) --- CHANGELOG.md | 5 +- catalyst/contrib/dl/callbacks/__init__.py | 11 ++ .../contrib/dl/callbacks/kornia_transform.py | 110 ++++++++++++++++++ catalyst/contrib/registry.py | 4 + requirements/requirements-cv.txt | 1 + setup.cfg | 2 +- 6 files changed, 130 insertions(+), 3 deletions(-) create mode 100644 catalyst/contrib/dl/callbacks/kornia_transform.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 37d2c3077d..4dbe257114 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- kornia augmentations `BatchTransformCallback` ([#862](https://github.com/catalyst-team/catalyst/issues/862)) - `log` parameter to `WandbLogger` ([#836](https://github.com/catalyst-team/catalyst/pull/836)) - hparams experiment property ([#839](https://github.com/catalyst-team/catalyst/pull/839)) - add docs build on push to master branch ([#844](https://github.com/catalyst-team/catalyst/pull/844)) @@ -34,7 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `global_*` counters in `Runner` ([#858](https://github.com/catalyst-team/catalyst/pull/858)) - EarlyStoppingCallback considers first epoch as bad ([#854](https://github.com/catalyst-team/catalyst/issues/854)) - + ## [20.06] - 2020-06-04 @@ -171,4 +172,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- \ No newline at end of file +- diff --git a/catalyst/contrib/dl/callbacks/__init__.py b/catalyst/contrib/dl/callbacks/__init__.py index 3a52e5ed21..e39a80faf1 100644 --- a/catalyst/contrib/dl/callbacks/__init__.py +++ b/catalyst/contrib/dl/callbacks/__init__.py @@ -24,6 +24,17 @@ ) raise ex +try: + import kornia + from .kornia_transform import BatchTransformCallback +except ImportError as ex: + if settings.cv_required: + logger.warning( + "some of catalyst-cv dependencies not available," + " to install dependencies, run `pip install catalyst[cv]`." + ) + raise ex + try: import alchemy from .alchemy_logger import AlchemyLogger diff --git a/catalyst/contrib/dl/callbacks/kornia_transform.py b/catalyst/contrib/dl/callbacks/kornia_transform.py new file mode 100644 index 0000000000..c256a004c4 --- /dev/null +++ b/catalyst/contrib/dl/callbacks/kornia_transform.py @@ -0,0 +1,110 @@ +from typing import Dict, Optional, Sequence, Tuple + +from kornia import augmentation +import torch +from torch import nn + +from catalyst.contrib.registry import TRANSFORMS +from catalyst.core.callback import Callback, CallbackNode, CallbackOrder +from catalyst.core.runner import IRunner + + +class BatchTransformCallback(Callback): + """Callback to perform data augmentations on GPU using kornia library. + + Look at `Kornia: an Open Source Differentiable Computer Vision + Library for PyTorch`_ for details. + + .. _`Kornia: an Open Source Differentiable Computer Vision Library + for PyTorch`: https://arxiv.org/pdf/1910.02190.pdf + """ + + def __init__( + self, + transform: Sequence[dict], + input_key: str, + output_key: Optional[str] = None, + additional_input_key: Optional[str] = None, + additional_output_key: Optional[str] = None, + loader: Optional[str] = None, + ) -> None: + """Constructor method for the :class:`BatchTransformCallback` callback. + + Args: + transform (Sequence[dict]): A sequence of dits with params + for each transform to apply. Must contain `transform` key + with augmentation name as a value. If augmentation is custom, + then you should add it to the TRANSFORMS registry first. + input_key (str): Key in batch dict mapping to to tranform, + e.g. `'image'`. + output_key (Optional[str]): Key to use to store the result + of transform. Defaults to `input_key` if not provided. + additional_input_key (Optional[str]): Key of additional target + in batch dict mapping to to tranform, e.g. `'mask'`. + additional_output_key (Optional[str]): Key to use to store + the result of additional target transform. + Defaults to `additional_input_key` if not provided. + loader (Optional[str]): Name of the loader on which items + transform should be applied. If `None`, transform going to be + applied for each loader. + """ + super().__init__(order=CallbackOrder.Internal, node=CallbackNode.all) + + self.input_key = input_key + self.additional_input = additional_input_key + self._process_input = ( + self._process_input_tuple + if self.additional_input is not None + else self._process_input_tensor + ) + + self.output_key = output_key or input_key + self.additional_output = additional_output_key or self.additional_input + self._process_output = ( + self._process_output_tuple + if self.additional_output is not None + else self._process_output_tensor + ) + + transforms: Sequence[augmentation.AugmentationBase] = [ + TRANSFORMS.get_from_params(**params) for params in transform + ] + assert all( + isinstance(t, augmentation.AugmentationBase) for t in transforms + ), "`kornia.AugmentationBase` should be a base class for transforms" + + self.transform = nn.Sequential(*transforms) + self.loader = loader + + def _process_input_tensor(self, input_: dict) -> torch.Tensor: + return input_[self.input_key] + + def _process_input_tuple( + self, input_: dict + ) -> Tuple[torch.Tensor, torch.Tensor]: + return input_[self.input_key], input_[self.additional_input] + + def _process_output_tensor( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + return {self.output_key: batch} + + def _process_output_tuple( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + out_t, additional_t = batch + return {self.output_key: out_t, self.additional_output: additional_t} + + def on_batch_start(self, runner: IRunner) -> None: + """Apply transforms. + + Args: + runner (IRunner): Current runner. + """ + if self.loader is None or runner.loader_name == self.loader: + in_batch = self._process_input(runner.input) + out_batch = self.transform(in_batch) + runner.input.update(self._process_output(out_batch)) + + +__all__ = ["BatchTransformCallback"] diff --git a/catalyst/contrib/registry.py b/catalyst/contrib/registry.py index 3fe90bb4bf..431e7bd47c 100644 --- a/catalyst/contrib/registry.py +++ b/catalyst/contrib/registry.py @@ -19,6 +19,10 @@ def _transforms_loader(r: Registry): r.add_from_module(p, prefix=["A.", "albu.", "albumentations."]) + from kornia import augmentation as k + + r.add_from_module(k, prefix=["kornia."]) + from catalyst.contrib.data.cv import transforms as t r.add_from_module(t, prefix=["catalyst.", "C."]) diff --git a/requirements/requirements-cv.txt b/requirements/requirements-cv.txt index 727b40b1ae..c02b9fd93f 100644 --- a/requirements/requirements-cv.txt +++ b/requirements/requirements-cv.txt @@ -1,5 +1,6 @@ albumentations==0.4.3 imageio +kornia==0.3.1 opencv-python-headless scikit-image>=0.14.2 segmentation-models-pytorch==0.1.0 diff --git a/setup.cfg b/setup.cfg index 678d4fe9cb..dd800eb6da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ reverse_relative = true # - python libs (known_third_party) # - dl libs (known_dl) # - catalyst imports -known_dl = albumentations,gym,gym_minigrid,neptune,tensorboard,tensorboardX,tensorflow,torch,torchvision,transformers,wandb +known_dl = albumentations,gym,gym_minigrid,kornia,neptune,tensorboard,tensorboardX,tensorflow,torch,torchvision,transformers,wandb known_first_party = catalyst sections = STDLIB,THIRDPARTY,DL,FIRSTPARTY,LOCALFOLDER