# Notebook to visualize the augmentation stacks (custom, MoCo v2, and adapted MoCo v2 / PAWS / SimCLR) applied to a sample image

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from cbmi_utils.pytorch.datasets.kather import Kather224x224, Kather96x96
from PIL import Image
from pathlib import Path

In [None]:
# plt.rcParams["savefig.bbox"] = 'tight'
import random

# If you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
SEED = 5
# torch drives the torchvision random transforms, but the PAWS solarize threshold uses numpy's
# global RNG and RotateTransform / Cutout use Python's `random`; seed all three so every row of
# the figure is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

px = 2  # figure scale: inches per grid cell in plot(), also multiplies its font sizes
dataset = 'kather224'  # key handed to get_norm_values()

# Image.open is lazy and keeps the file handle open until the pixels are read; load the
# pixels inside a context manager so the handle is closed right away.
with Image.open(Path('./plots/augmentations/TUM-YVGHNMGQ.tif')) as _tif:
    orig_img = _tif.copy()

In [None]:
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
    """
    Show augmented views in a grid, one row per augmentation stack.

    :param imgs: A list of rows (each a list of images), or a single flat list that is drawn as one row.
                 Every image must be accepted by ``np.asarray`` and ``Axes.imshow``.
    :param with_orig: If True, the global ``orig_img`` is prepended to every row and the first row gets the
                      column titles 'Image', 'View 1', 'View 2', ...
    :param row_title: Optional y-axis labels, one per row (extra labels are ignored).
    :param imshow_kwargs: Forwarded to every ``Axes.imshow`` call.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    font_size = 9 * px
    # squeeze=False keeps `axs` a 2-D array even for a single row or column. With squeeze=True a
    # one-row grid came back 1-D (or as a bare Axes), and the axs[row, col] indexing below failed.
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False,
                          figsize=(num_cols * px, num_rows * px))

    for row_idx, row in enumerate(imgs):
        if with_orig:
            row = [orig_img] + row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        column_titles = ['Image'] + [f'View {i}' for i in range(1, num_cols)]
        for ax, title in zip(axs[0], column_titles):
            ax.set_title(title, fontsize=font_size)

    if row_title is not None:
        for ax, title in zip(axs[:, 0], row_title):
            ax.set_ylabel(title, fontsize=font_size)

    plt.tight_layout()

In [None]:
def get_norm_values(ds: str):
    """
    Look up the (mean, std) normalization constants for a Kather dataset identifier.

    Matching is case-insensitive and substring-based; 224-px identifiers are checked before 96-px ones.

    :param ds: Dataset identifier, e.g. 'kather224' or 'kather_h5_96'.
    :return: The ``(mean, std)`` pair reported by the matching Kather dataset class.
    :raises NotImplementedError: If no known identifier is contained in ``ds``.
    """
    key = ds.lower()
    # Lambdas defer touching the dataset classes until a marker actually matches.
    candidates = (
        (("kather_h5_224", "kather224", "kather_if_224"), lambda: Kather224x224.normalization_values()),
        (("kather_h5_96", "kather96"), lambda: Kather96x96.normalization_values()),
    )
    for markers, fetch in candidates:
        if any(marker in key for marker in markers):
            mean, std = fetch()
            return mean, std
    raise NotImplementedError(f'Request of normalization constants of {key} is not implemented!')
    

def light_stack(dataset, normalize: bool = False, img_size: int = 224):
    """
    Build the custom 'light' augmentation stack: mild colour jitter followed by geometric transforms.

    :param dataset: Dataset identifier for get_norm_values (only read when ``normalize`` is True).
    :param normalize: If True, a ``transforms.Normalize`` is inserted between the colour and geometric stages.
                      NOTE(review): the final ToPILImage then receives values outside [0, 1] — confirm this
                      path is only meant for debugging.
    :param img_size: Side length of the produced image.
    :return: A ``transforms.Compose`` that ends with ``ToPILImage``.
    """
    upscale_side = int(img_size * (2 ** 0.5) + 0.9999)

    colour_stage = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=[0.9, 1.1], contrast=0, saturation=[0.7, 1.8], hue=0)],  # not strengthened
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
    ])
    geometry_stage = transforms.Compose([
        # 50 %: zoom in by cropping 90 % of the side and resizing back to img_size.
        transforms.RandomApply([transforms.RandomCrop(img_size * .9), transforms.Resize(img_size)], p=0.5),
        # 80 %: enlarge by ~sqrt(2), then rotate; for square inputs the center crop below then
        # stays inside the rotated content.
        transforms.RandomApply([transforms.Resize(upscale_side), transforms.RandomRotation(180)], p=0.8),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToPILImage(),
    ])

    stages = [colour_stage]
    if normalize:
        mean, std = get_norm_values(dataset)
        stages.append(transforms.Normalize(mean, std))
    stages.append(geometry_stage)
    return transforms.Compose(stages)


def medium_stack(dataset, normalize: bool = False, img_size: int = 224):
    """
    Build the medium augmentation stack: MoCo-v2-strength colour jitter, geometric transforms and a random blur.

    :param dataset: Dataset identifier for get_norm_values (only read when ``normalize`` is True).
    :param normalize: If True, a ``transforms.Normalize`` is inserted between the colour and geometric stages.
                      NOTE(review): the final ToPILImage then receives values outside [0, 1] — confirm this
                      path is only meant for debugging.
    :param img_size: Side length of the produced image.
    :return: A ``transforms.Compose`` that ends with ``ToPILImage``.
    """
    upscale_side = int(img_size * (2 ** 0.5) + 0.9999)
    blur_kernel = img_size // 20 * 2 + 1  # odd kernel, roughly 10 % of the side length

    colour_stage = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
    ])
    geometry_stage = transforms.Compose([
        # 50 %: zoom in by cropping 90 % of the side and resizing back to img_size.
        transforms.RandomApply([transforms.RandomCrop(img_size * .9), transforms.Resize(img_size)], p=0.5),
        # 80 %: enlarge by ~sqrt(2), then rotate; for square inputs the center crop below then
        # stays inside the rotated content.
        transforms.RandomApply([transforms.Resize(upscale_side), transforms.RandomRotation(180)], p=0.8),
        transforms.CenterCrop(img_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))], p=0.5),
        transforms.ToPILImage(),
    ])

    stages = [colour_stage]
    if normalize:
        mean, std = get_norm_values(dataset)
        stages.append(transforms.Normalize(mean, std))
    stages.append(geometry_stage)
    return transforms.Compose(stages)


def moco_v2(dataset, normalize: bool = False, img_size: int = 224):
    """
    Build the MoCo v2 augmentation pipeline, ending with a conversion back to a PIL image.

    :param dataset: Dataset identifier for get_norm_values (only read when ``normalize`` is True).
    :param normalize: If True, append ``transforms.Normalize`` before the final ``ToPILImage``.
    :param img_size: Output size of the random resized crop.
    :return: A ``transforms.Compose``.
    """
    # We blur the image 50% of the time using a Gaussian kernel. We randomly sample σ ∈ [0.1, 2.0],
    # and the kernel size is set to be 10% of the image height/width.
    blur_kernel = img_size // 20 * 2 + 1

    pipeline = [
        transforms.ToTensor(),
        transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))], p=0.5),
        transforms.RandomHorizontalFlip(),
    ]

    if normalize:
        mean, std = get_norm_values(dataset)
        pipeline += [transforms.Normalize(mean, std)]

    pipeline += [transforms.ToPILImage()]
    return transforms.Compose(pipeline)

import PIL

def paws(img_size=224, size=18, scale=(0.3, 0.75), normalize=False, color_distortion=0.5):
    """
    Build the adapted PAWS augmentation pipeline for PIL images.

    Steps: resize -> enlarge by ~sqrt(2) -> random rotation -> RandomResizedCrop ->
    random horizontal flip -> colour distortion (jitter, solarize, equalize).

    :param img_size: Side length for the first Resize; the image is then enlarged to
                     ceil(img_size * sqrt(2)) before rotating.
    :param size: Output side length of the RandomResizedCrop.
    :param scale: (min, max) area fraction passed to RandomResizedCrop.
    :param normalize: If True, convert to a tensor and normalize with hard-coded mean/std.
    :param color_distortion: Strength ``s`` of the colour distortion.
    :return: A ``transforms.Compose`` mapping a PIL image to a PIL image, or to a normalized
             tensor when ``normalize`` is True.
    """
    # `import PIL` alone does not guarantee the ImageOps submodule is loaded; import it explicitly.
    from PIL import ImageOps

    def get_color_distortion(s=1.0):
        """Random colour jitter (p=0.8), solarize (p=0.2) and equalize (p=0.2); ``s`` scales the jitter."""
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)

        def solarize(img):
            # The threshold comes from numpy's global RNG, not torch's.
            threshold = np.random.uniform(0, 256)
            return ImageOps.solarize(img, threshold)

        def equalize(img):
            return ImageOps.equalize(img)

        rnd_solarize = transforms.RandomApply([transforms.Lambda(solarize)], p=0.2)
        rnd_equalize = transforms.RandomApply([transforms.Lambda(equalize)], p=0.2)
        return transforms.Compose([rnd_color_jitter, rnd_solarize, rnd_equalize])

    transform = transforms.Compose([
        transforms.Resize(int(img_size)),
        transforms.Resize(int(img_size * (2 ** 0.5) + 0.9999)),
        transforms.RandomRotation(180),
        # NOTE(review): light_stack / medium_stack center-crop after the rotation; here there is no
        # CenterCrop, so the crop below can include the zero-filled rotation corners — confirm intended.
        transforms.RandomResizedCrop(size=size, scale=scale),
        transforms.RandomHorizontalFlip(),
        get_color_distortion(s=color_distortion),
    ])

    if normalize:
        transform = transforms.Compose([
            transform,
            # Normalize only accepts tensors, while every step above takes and returns PIL images.
            transforms.ToTensor(),
            # Presumably per-channel dataset mean/std — TODO confirm their source.
            transforms.Normalize(
                (0.7455422, 0.52883655, 0.70516384),
                (0.15424888, 0.20247863, 0.14497302),
            ),
        ])
    return transform


In [None]:
import torch
# noinspection PyPackageRequirements
import torchvision.transforms.functional as functional
# noinspection PyPackageRequirements
from torchvision import transforms as torch_transforms
from typing import Sequence, Tuple, Optional
import random


# copied from: https://github.com/pytorch/vision/issues/566#issuecomment-535854734
class RotateTransform:
    """
    Rotate by an angle picked uniformly from a fixed list, then optionally flip horizontally.

    The angle is drawn with Python's global ``random`` module, not torch's RNG.
    """

    def __init__(self, angles: Sequence[int], use_flip=True):
        """
        :param angles: Candidate rotation angles in degrees; an angle of 0 skips the rotation.
        :param use_flip: If True, a horizontal flip with p=0.5 follows the rotation.
        """
        self.angles = angles
        self.flip = None
        if use_flip:
            self.flip = torch_transforms.RandomHorizontalFlip(p=0.5)

    def __call__(self, x):
        angle = random.choice(self.angles)
        if angle:
            x = functional.rotate(x, angle)
        # The flip runs independently of the rotation. Previously the method returned right after
        # rotating, so only angle-0 samples could ever be flipped.
        if self.flip is not None:
            x = self.flip(x)
        return x


class RandomGaussianBlur:
    """Gaussian-blur the input with probability ``p``, sampling sigma from the ``sigma`` range."""

    def __init__(self, kernel_size, p=0.5, sigma=(0.1, 2.0)):
        gaussian = torch_transforms.GaussianBlur(kernel_size, sigma)
        self.blur = torch_transforms.RandomApply([gaussian], p=p)

    def __call__(self, x):
        maybe_blur = self.blur
        return maybe_blur(x)


class Noise:
    """Add zero-mean Gaussian noise, scaled by ``strength``, to a tensor."""

    def __init__(self, strength):
        # Multiplier applied to standard-normal samples (i.e. the noise standard deviation).
        self.strength = strength

    def __call__(self, x):
        gaussian = torch.randn_like(x)
        return x + gaussian * self.strength


class ColorDistort:
    """
    SimCLR-style colour distortion: a random colour jitter followed by a random grayscale conversion.

    Brightness, saturation and contrast jitter are ``0.8 * jitter_strength``; hue jitter is ``0.2 * jitter_strength``.
    """

    def __init__(self, jitter_strength, jitter_p=0.8, grayscale_p=0.2):
        amount = 0.8 * jitter_strength
        jitter = torch_transforms.ColorJitter(
            brightness=amount,
            saturation=amount,
            contrast=amount,
            hue=0.2 * jitter_strength,
        )
        self.color_jitter = torch_transforms.RandomApply([jitter], p=jitter_p)
        self.grayscale = torch_transforms.RandomGrayscale(p=grayscale_p)

    def __call__(self, x):
        jittered = self.color_jitter(x)
        return self.grayscale(jittered)


class Clamp:
    """Clip every element of a tensor into the closed range [0, 1]."""

    def __call__(self, x):
        lower, upper = 0.0, 1.0
        return x.clamp(min=lower, max=upper)


class Cutout:
    def __init__(self, size: Optional[Tuple[int, int]] = None, color: float = 0.5, quadratic: bool = True):
        """
        Creates a new Cutout augmentation.

        :param size: A tuple (min, max) defining the size of the cutout. The actual size of the cutout is then randomly
                     sampled for each image (min <= size <= max). If not given, it defaults to
                     (0, min(height, width)). Bounds larger than the image are clipped to the image's height/width
                     (for quadratic cutouts, to the smaller of the two) instead of raising.
        :param color: The brightness of the gray color that is used for the cutout. 0.0 means black and 1.0 means white.
                      Defaults to 0.5.
        :param quadratic: Whether the cutout should be quadratic. Defaults to True.
        """
        self.size = size
        self.color = color
        self.quadratic = quadratic

    def __call__(self, img):
        """
        Return a copy of ``img`` with one randomly placed rectangle filled with ``self.color``.

        :param img: Image tensor indexed as (channels, height, width), as required by the slicing below.
        :return: A detached copy; the input tensor is left unchanged.
        """
        height, width = img.size()[1], img.size()[2]
        low, high = self.size if self.size is not None else (0, min(height, width))
        # Clip the upper bound so the position ranges below are never empty (randint would raise).
        if self.quadratic:
            high_y = high_x = min(high, height, width)
        else:
            high_y, high_x = min(high, height), min(high, width)

        # The draw order (size_y, pos_y, [size_x], pos_x) is unchanged, so seeded runs stay reproducible.
        size_y = random.randint(min(low, high_y), high_y)
        pos_y = random.randint(0, height - size_y)
        if self.quadratic:
            size_x = size_y
        else:
            size_x = random.randint(min(low, high_x), high_x)
        pos_x = random.randint(0, width - size_x)

        img = img.detach().clone()
        img[:, pos_y:pos_y + size_y, pos_x:pos_x + size_x] = self.color
        return img



# Default colour-jitter magnitudes (not referenced by the cells shown here).
DEFAULT_COLOR_JITTER_BRIGHTNESS = DEFAULT_COLOR_JITTER_SATURATION = DEFAULT_COLOR_JITTER_HUE = 0.2
# (min, max) bounds on the crop area fraction passed to RandomResizedCrop's `scale`.
DEFAULT_RANDOM_RESIZED_CROP_SCALE = (0.08, 1.0)


def _entry_to_transform(entry, slide_size):
    """
    Construct the transform named by one augmentation entry.

    :param entry: Augmentation name: 'color_distort', 'random_crop', 'cutout', 'blur', 'rotate', 'noise' or 'affine'.
    :param slide_size: Output size for RandomResizedCrop (only used by 'random_crop').
    :return: A freshly constructed transform instance.
    :raises ValueError: For the retired names 'color_jitter' / 'color_drop' and for unknown names.
    """
    # Builders are checked in order with `entry == name`; lambdas keep construction lazy.
    builders = (
        ('color_distort', lambda: ColorDistort(jitter_strength=1.0, jitter_p=0.8, grayscale_p=0.2)),
        ('random_crop', lambda: torch_transforms.RandomResizedCrop(
            size=slide_size, scale=DEFAULT_RANDOM_RESIZED_CROP_SCALE)),
        ('cutout', lambda: Cutout(size=(40, 50))),
        ('blur', lambda: RandomGaussianBlur(kernel_size=23, p=0.5)),
        ('rotate', lambda: RotateTransform([0, 90, 180, 270])),
        ('noise', lambda: Noise(strength=0.2)),
        ('affine', lambda: torch_transforms.RandomAffine(
            degrees=(-180, 180), translate=None, scale=(0.7, 1.3), shear=(-10, 10, -10, 10))),
    )
    for name, build in builders:
        if entry == name:
            return build()

    retired = (
        ('color_jitter', 'color jitter was replaced by color_distort'),
        ('color_drop', 'color drop was replaced by color_distort'),
    )
    for name, message in retired:
        if entry == name:
            raise ValueError(message)

    raise ValueError('Could not load transform "{}"'.format(entry))


def description_to_transform(augmentations, slide_size, use_rotation, use_clamp, image_rescale_size):
    """
    Assemble a Compose pipeline from a list of augmentation names.

    Order: ToTensor -> [RotateTransform(0/90/180/270)] -> named augmentations -> [Clamp] -> [Resize] -> ToPILImage.

    :param augmentations: Iterable of names understood by _entry_to_transform.
    :param slide_size: Forwarded to _entry_to_transform.
    :param use_rotation: Prepend a RotateTransform; then 'rotate' may not appear in ``augmentations``.
    :param use_clamp: Append a Clamp after the augmentations.
    :param image_rescale_size: If not None, append a Resize to this size.
    :raises AssertionError: If ``use_rotation`` is set and 'rotate' is also requested.
    """
    pipeline = [torch_transforms.ToTensor()]
    if use_rotation:
        pipeline.append(RotateTransform([0, 90, 180, 270]))

    for name in augmentations:
        # Checked per entry (not up front) so an earlier unknown name still raises first.
        if use_rotation and name == 'rotate':
            raise AssertionError('got rotate as augmentation and entry')
        pipeline.append(_entry_to_transform(name, slide_size))

    if use_clamp:
        pipeline.append(Clamp())
    if image_rescale_size is not None:
        pipeline.append(torch_transforms.Resize(image_rescale_size))
    pipeline.append(torch_transforms.ToPILImage())

    return torch_transforms.Compose(pipeline)


def make_unique(l):
    """
    Return the items of ``l`` as a list, keeping the first occurrence of each and preserving order.

    Equality (``==``) decides duplicates, so unhashable items such as lists are supported.
    """
    kept = []
    for candidate in l:
        if candidate in kept:
            continue
        kept.append(candidate)
    return kept


def sim_clr(augmentations, slide_size, use_rotation=False, use_clamp=True, image_rescale_size=None):
    """
    Build a SimCLR-style pipeline from augmentation names, keeping only the first occurrence of each name.

    All other arguments are forwarded unchanged to description_to_transform.
    """
    deduplicated = make_unique(augmentations)
    return description_to_transform(
        augmentations=deduplicated,
        slide_size=slide_size,
        use_rotation=use_rotation,
        use_clamp=use_clamp,
        image_rescale_size=image_rescale_size,
    )
        


In [None]:

# Each stack paired with the row label it gets in the figure (built in the same order as before).
named_stacks = [
    ('Custom', light_stack(dataset)),
    ('Moco_v2', moco_v2(dataset)),
    ('Ad. Moco_v2', medium_stack(dataset)),
    ('Ad. Paws', paws(size=96, scale=(0.6, 0.9), color_distortion=0.5)),
    ('Ad. SimCLR', sim_clr(
        augmentations=('color_distort', 'random_crop'),
        slide_size=(224, 224),
        use_rotation=True,
        image_rescale_size=(224, 224),
    )),
]
stack_title = [title for title, _ in named_stacks]
aug_stacks = [stack for _, stack in named_stacks]

views_per_stack = 8
# One row per stack, each holding `views_per_stack` independent augmentations of the original image.
imgs = [[stack(orig_img) for _ in range(views_per_stack)] for stack in aug_stacks]

plot(imgs, row_title=stack_title)