In [None]:
#default_exp tta
from nbdev.showdoc import show_doc

# Test Time Augmentation

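> Test-time augmentation (TTA): each image batch is augmented, and the resulting masks are de-augmented and merged. Code adapted from [ttach](https://github.com/qubvel/ttach).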

In [None]:
#hide
from fastcore.test import *
from fastai.torch_core import TensorImage, TensorMask

In [None]:
#export
import torch
import itertools
from functools import partial
from typing import List, Optional, Union
from fastcore.foundation import store_attr

## Functional

In [None]:
#export
def rot90(x, k=1):
    "Rotate a batch of images by `k` quarter turns within dims 2 and 3 (assumed NCHW layout)"
    spatial_dims = (2, 3)
    return torch.rot90(x, k=k, dims=spatial_dims)

def hflip(x):
    "Mirror a batch of images along dim 3 (the width axis for an NCHW batch)"
    width_axis = 3
    return x.flip(width_axis)

def vflip(x):
    "Mirror a batch of images along dim 2 (the height axis for an NCHW batch)"
    height_axis = 2
    return x.flip(height_axis)

## Base Classes

In [None]:
#export
class BaseTransform:
    "Parent of the TTA transforms: `params` are the values swept for the keyword argument named `pname`."
    # value of `pname` that leaves the input unchanged; subclasses override it
    identity_param = None

    def __init__(self, pname: str, params: Union[list, tuple]):
        # fastcore's `store_attr` copies `pname` and `params` onto `self`
        store_attr()

class Chain:
    "Apply `functions` left to right, so `Chain([f, g])(x) == g(f(x))`; a falsy `functions` acts as the identity."
    def __init__(self, functions: List[callable]):
        self.functions = functions if functions else []

    def __call__(self, x):
        out = x
        for fn in self.functions:
            out = fn(out)
        return out

class Transformer:
    "Pair an image augmentation pipeline with the mask pipeline meant to undo it."
    def __init__(self, image_pipeline: Chain, mask_pipeline: Chain):
        # fastcore's `store_attr` saves both pipelines as attributes of `self`
        store_attr()

    def augment_image(self, image):
        "Run `image` through `image_pipeline`."
        pipeline = self.image_pipeline
        return pipeline(image)

    def deaugment_mask(self, mask):
        "Run `mask` through `mask_pipeline`."
        pipeline = self.mask_pipeline
        return pipeline(mask)

class Compose:
    "Cartesian product of the parameter sweeps of `aug_transforms`; iterating yields one `Transformer` per combination."
    def __init__(self, aug_transforms: List[BaseTransform]):
        store_attr()
        param_sweeps = [tfm.params for tfm in self.aug_transforms]
        self.aug_transform_parameters = list(itertools.product(*param_sweeps))
        # De-augmentation undoes the transforms in reverse order, so the transform list
        # and every parameter tuple are both reversed.
        self.deaug_transforms = aug_transforms[::-1]
        self.deaug_transform_parameters = [combo[::-1] for combo in self.aug_transform_parameters]

    def __iter__(self) -> Transformer:
        pairs = zip(self.aug_transform_parameters, self.deaug_transform_parameters)
        for aug_combo, deaug_combo in pairs:
            # bind each transform's parameter value by keyword (`pname`) to get a one-arg step
            aug_steps = [partial(tfm.apply_aug_image, **{tfm.pname: value})
                         for tfm, value in zip(self.aug_transforms, aug_combo)]
            deaug_steps = [partial(tfm.apply_deaug_mask, **{tfm.pname: value})
                           for tfm, value in zip(self.deaug_transforms, deaug_combo)]
            yield Transformer(image_pipeline=Chain(aug_steps), mask_pipeline=Chain(deaug_steps))

    def __len__(self) -> int:
        return len(self.aug_transform_parameters)

In [None]:
#export
class Merger:
    "Accumulate de-augmented predictions and reduce them element-wise over the TTA runs."
    def __init__(self):
        # one tensor per `append` call; stacked along a new leading axis in `result`
        self.output = []

    def append(self, x):
        "Add one prediction `x`, converted with `torch.as_tensor`."
        self.output.append(torch.as_tensor(x))

    def result(self, type='mean'):
        """Reduce all appended predictions along the run axis.

        `type` (name kept for backward compatibility; it shadows the builtin) is one of
        'max', 'mean' or 'std'. 'std' uses torch's default unbiased estimator, so it is
        NaN when only one prediction was appended.
        Raises `ValueError` for any other `type`.
        """
        # Validate first so an invalid `type` is reported as ValueError even before stacking.
        if type not in ('max', 'mean', 'std'):
            raise ValueError('Not correct merge type `{}`.'.format(type))
        s = torch.stack(self.output)
        if type == 'max':
            return torch.max(s, dim=0)[0]
        if type == 'mean':
            return torch.mean(s, dim=0)
        return torch.std(s, dim=0)

In [None]:
imgs = TensorImage(torch.randn(4, 1, 356, 356))
# Merging identical predictions: mean and max must return the input, std must be zero.
expected = {'mean': imgs, 'max': imgs, 'std': torch.zeros_like(imgs)}
for merge_type, target in expected.items():
    m = Merger()
    for _ in range(10): m.append(imgs)
    merged = m.result(merge_type)
    test_eq(imgs.shape, merged.shape)
    test_close(merged, target)

## Transform Classes

In [None]:
#export
class HorizontalFlip(BaseTransform):
    "Flip images horizontally (left->right) when `apply` is True"
    identity_param = False
    def __init__(self):
        # sweep both states: untouched (False) and flipped (True)
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        return hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # flipping twice restores the original, so de-augmentation reuses `hflip`
        return hflip(mask) if apply else mask

In [None]:
t = HorizontalFlip()
test_eq(t.params, [False, True])
for apply in t.params:
    aug = t.apply_aug_image(imgs, apply=apply)
    # `apply=True` must mirror the last axis; `apply=False` must be a no-op
    test_eq(aug, imgs.flip(3) if apply else imgs)
    deaug = t.apply_deaug_mask(aug, apply=apply)
    test_eq(imgs, deaug)

In [None]:
#export
class VerticalFlip(BaseTransform):
    "Flip images vertically (up->down) when `apply` is True"
    identity_param = False
    def __init__(self):
        # sweep both states: untouched (False) and flipped (True)
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        return vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        # flipping twice restores the original, so de-augmentation reuses `vflip`
        return vflip(mask) if apply else mask

In [None]:
t = VerticalFlip()
test_eq(t.params, [False, True])
for apply in t.params:
    aug = t.apply_aug_image(imgs, apply=apply)
    # `apply=True` must mirror dim 2; `apply=False` must be a no-op
    test_eq(aug, imgs.flip(2) if apply else imgs)
    deaug = t.apply_deaug_mask(aug, apply=apply)
    test_eq(imgs, deaug)

In [None]:
#export
class Rotate90(BaseTransform):
    "Rotate images by multiples of 90 degrees (e.g. 0/90/180/270, given by `angles`)"
    identity_param = 0
    def __init__(self, angles: List[int]):
        angles = list(angles)
        # Only whole quarter turns can be undone exactly. For other angles, the `//` in
        # `apply_aug_image` rounds `angle` and `-angle` differently (45 -> 0 turns but
        # -45 -> 3 turns), so de-augmented masks would come back rotated.
        invalid = [a for a in angles if a % 90 != 0]
        if invalid:
            raise ValueError(f'Rotate90 only supports multiples of 90 degrees, got {invalid}.')
        # the identity angle is always part of the sweep
        if self.identity_param not in angles:
            angles = [self.identity_param] + angles
        super().__init__("angle", angles)

    def apply_aug_image(self, image, angle=0, **kwargs):
        "Rotate `image` by `angle` degrees, expressed as a number of 90-degree turns for `rot90`."
        # negative angles are shifted by one full turn before dividing
        k = angle // 90 if angle >= 0 else (angle + 360) // 90
        return rot90(image, k)

    def apply_deaug_mask(self, mask, angle=0, **kwargs):
        "Undo `apply_aug_image` by rotating `mask` by `-angle`."
        return self.apply_aug_image(mask, -angle)

In [None]:
t = Rotate90([180])
# the identity angle 0 is prepended automatically
test_eq(t.params, [0, 180])
for angle in t.params:
    aug = t.apply_aug_image(imgs, angle=angle)
    deaug = t.apply_deaug_mask(aug, angle=angle)
    test_eq(imgs, deaug)
# a 180-degree rotation equals flipping both spatial axes
test_eq(t.apply_aug_image(imgs, angle=180), imgs.flip(2).flip(3))

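Pipeline test: every augmentation yielded by `Compose` must be undone exactly by its paired de-augmentation, so merging the de-augmented batches recovers the input.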

In [None]:
tfms = [HorizontalFlip(), VerticalFlip(), Rotate90(angles=[90, 180, 270])]
c = Compose(tfms)
# 2 horizontal-flip states x 2 vertical-flip states x 4 angles (0 is added automatically)
test_eq(len(c), 2 * 2 * 4)
m = Merger()
for t in c:
    aug = t.augment_image(imgs)
    deaug = t.deaugment_mask(aug)
    test_eq(imgs, deaug)
    m.append(deaug)
# every combination was yielded, and merging exact reconstructions gives back the input
test_eq(len(m.output), len(c))
test_close(imgs, m.result())

## Export -

In [None]:
#hide
from nbdev.export import *
# nbdev: export the `#export` cells of the project's notebooks into the library modules
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 04_callbacks.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted add_information.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train-Copy1.ipynb.
Converted train.ipynb.
