In [33]:
import os
import json
import random
from typing import List, Tuple, Dict, Union

import numpy as np
from PIL import Image, ImageDraw, ImageFont
from glob import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import albumentations as A

In [34]:
# Pick one training image to preview every augmentation on.
# min() over the paths selects the same file as sorted(...)[0] (lexicographically
# first), so the choice is stable across runs, but fails loudly with a clear
# message instead of an opaque IndexError when the glob matches nothing.
TRAIN_IMG_GLOB = '../data/train/*/*/*.png'
test_img_path = min(glob(TRAIN_IMG_GLOB), default=None)
if test_img_path is None:
    raise FileNotFoundError(f"No images match {TRAIN_IMG_GLOB!r}; check the data directory.")

In [35]:
# Load the sample image as 3-channel RGB.
# The context manager releases the file handle held by Image.open; convert()
# returns a new, independent image, so test_img stays valid after the file closes.
with Image.open(test_img_path) as _src_img:
    test_img = _src_img.convert("RGB")

In [44]:
# Catalogue of albumentations transforms to preview, keyed by display name.
# Every entry uses p=1 so the transform is always applied in the preview grid;
# all other parameters are library defaults unless noted. Insertion order is
# the order the panels appear in the plot.
augmentation = {
    # Identity baseline for comparison (a single transform, like every other entry).
    "Original": A.NoOp(p=1),
    # --- pixel-level: colour / intensity / sharpness / noise ---
    "ChannelShuffle": A.ChannelShuffle(p=1),
    "CLAHE": A.CLAHE(p=1),
    "ColorJitter": A.ColorJitter(p=1),
    "Emboss": A.Emboss(p=1),
    "GaussNoise": A.GaussNoise(p=1),
    # RandomBrightness / RandomContrast are deprecated and absent from recent
    # albumentations releases. These are their documented equivalents: the old
    # default limit (0.2) on the varied axis and 0 on the other one.
    "RandomBrightness": A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=1),
    "RandomContrast": A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=1),
    "RandomBrightnessContrast": A.RandomBrightnessContrast(p=1),
    "UnsharpMask": A.UnsharpMask(p=1),
    "Equalize": A.Equalize(p=1),
    # --- geometric ---
    "Rotate": A.Rotate(limit=30, p=1),
    "SafeRotate": A.SafeRotate(limit=30, p=1),
    "Affine": A.Affine(p=1),
    "ElasticTransform": A.ElasticTransform(p=1),
    "GridDistortion": A.GridDistortion(p=1),
    "OpticalDistortion": A.OpticalDistortion(p=1),
    "Perspective": A.Perspective(p=1),
    "PiecewiseAffine": A.PiecewiseAffine(p=1),
    "ShiftScaleRotate": A.ShiftScaleRotate(p=1),
    # --- blur and texture ---
    "AdvancedBlur": A.AdvancedBlur(p=1),
    "Blur": A.Blur(p=1),
    "MedianBlur": A.MedianBlur(p=1),
    "MotionBlur": A.MotionBlur(p=1),
    "GaussianBlur": A.GaussianBlur(p=1),
    "GlassBlur": A.GlassBlur(p=1),
    "Superpixels": A.Superpixels(p=1),
    "ZoomBlur": A.ZoomBlur(p=1),
    "Defocus": A.Defocus(p=1),
    # --- dropout ---
    "ChannelDropout": A.ChannelDropout(p=1),
    "CoarseDropout": A.CoarseDropout(p=1),
    "GridDropout": A.GridDropout(p=1),
}

In [45]:
class CustomAug:
    """Callable that applies a fixed sequence of albumentations transforms to an image.

    The transforms are wrapped in a single ``A.Compose`` pipeline (applied in list
    order). Calling the instance converts the input to a NumPy array, runs the
    pipeline, and returns the augmented image array.
    """

    def __init__(self, augments: List[Union[A.ImageOnlyTransform, A.DualTransform]]):
        """
        Args:
            augments: transforms to apply, in order. The list itself is stored on
                ``self.augments``; ``set_transform`` snapshots it into the pipeline.
        """
        self.augments = augments
        self.set_transform()

    def set_transform(self):
        """(Re)build ``self.transform`` from the current contents of ``self.augments``."""
        # list(...) snapshots the transforms so the pipeline is decoupled from
        # later edits to the caller's list until set_transform is called again.
        self.transform = A.Compose(list(self.augments))

    def __call__(self, image: Image.Image) -> np.ndarray:
        """Augment ``image`` and return the result.

        Args:
            image: a PIL image (anything ``np.array`` accepts, e.g. an ndarray, also works).

        Returns:
            The ``'image'`` entry of the albumentations pipeline output.
        """
        # np.array makes a copy, so the caller's image is never touched by the pipeline.
        pixels = np.array(image)
        return self.transform(image=pixels)['image']

In [None]:
# Preview grid: each augmentation applied once to the same sample image.
length = len(augmentation)
cols = 5
# Ceiling division: round() under-allocates (32 / 5 -> 6 rows = 30 axes), which
# made the loop index past the last row. max(1, ...) keeps subplots valid when empty.
rows = max(1, int(np.ceil(length / cols)))
# squeeze=False keeps `axes` 2-D even when rows == 1; height scales with the row
# count (6 rows -> the original 20x30 figure).
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(4 * cols, 5 * rows), squeeze=False)

for ax, (name, augment) in zip(axes.flat, augmentation.items()):
    ax.imshow(CustomAug([augment])(test_img))
    ax.set_title(name)

# Turn axes off everywhere, including unused trailing cells in the last row,
# so they render as blank space instead of empty frames.
for ax in axes.flat:
    ax.axis('off')

fig.tight_layout()
plt.show()
