In [None]:
import sys
sys.path.insert(1,"..")
from datasets.augmentation.randaugment import RandAugment
import matplotlib.pyplot as plt
import random
from PIL import Image
import os
from torchvision import transforms
import numpy as np

## 1 - Control flags and image paths

In [None]:
# Image that will be augmented.
img_path = "../data/EuroSAT_RGB/Highway/Highway_61.jpg"
# Directory in which the augmented picture is written.
augmented_dir = ".."
# Augmentation pipeline to run. Supported values: "strong", "weak".
augmentation_type = "strong"
# True  -> keep the axes (numerical grid) and a title with the image name in the saved picture.
# False -> hide the axes and save the bare image.
plot_info = False

For strong augmentations, you can use a custom augmentation policy restricted to a subset of the available augmentations.
To do so, edit the `strong_augmentation_wanted` list. To use all the available augmentations, set `strong_augmentation_wanted=["all"]`; otherwise, list the wanted augmentations by name. Each selected augmentation is applied once, in the order listed. Supported augmentations are:

<ul>
    <li>"AutoContrast"</li>
    <li>"Brightness"</li>
    <li>"Color"</li>
    <li>"Contrast"</li>
    <li>"Equalize"</li>
    <li>"Identity"</li>
    <li>"Posterize"</li>
    <li>"Rotate"</li>
    <li>"Sharpness"</li>
    <li>"ShearX"</li>
    <li>"ShearY"</li>
    <li>"Solarize"</li>
    <li>"TranslateX"</li>
    <li>"TranslateY"</li>
</ul>
If you are using weak augmentation, the next cell can be ignored.

In [None]:
# Strong augmentations to apply, in this order; use ["all"] to select every supported one.
strong_augmentation_wanted = [
    "AutoContrast",
    "Solarize",
    "Equalize",
]

## 2 - Opening and plotting the original image

In [None]:
# Load the image fully and release the file handle right away: Image.open is lazy
# and would otherwise keep the file open. convert("RGB") guarantees 3 channels,
# so palette/grayscale/RGBA inputs do not reach ToTensor/RandAugment with an
# unexpected channel layout.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

In [None]:
# Bare file name (no directory, no extension), used in plot titles and in the
# name of the saved augmented picture. os.path handles paths without an
# extension and platform-specific separators, unlike manual rfind slicing.
image_name = os.path.splitext(os.path.basename(img_path))[0]
plt.title("Original image: " + image_name, fontsize=14, fontweight='bold')
plt.imshow(img)
plt.show()

In [None]:
# Position of each strong operation inside RandAugment.augment_list, used to
# pick a subset of operations by name.
# NOTE(review): this hard-coded order must match the list built in
# datasets/augmentation/randaugment.py — confirm whenever that file changes.
randaugment_dict = {"AutoContrast" : 0, "Brightness" : 1, "Color" : 2, "Contrast" : 3,"Equalize" : 4,"Identity" : 5,"Posterize" : 6,"Rotate" : 7,"Sharpness" : 8,"ShearX" : 9,"ShearY" : 10,"Solarize" : 11,"TranslateX" : 12,"TranslateY" : 13}

# Validate the configuration before touching `img`, so a typo fails loudly
# instead of silently running the wrong pipeline or leaving `img` half-processed.
if augmentation_type not in ("strong", "weak"):
    raise ValueError(
        f"Unsupported augmentation_type {augmentation_type!r}; expected 'strong' or 'weak'"
    )
if augmentation_type == "strong":
    if not strong_augmentation_wanted:
        raise ValueError(
            'strong_augmentation_wanted is empty: use ["all"] or list at least one augmentation'
        )
    if strong_augmentation_wanted[0] != "all":
        unknown = [name for name in strong_augmentation_wanted if name not in randaugment_dict]
        if unknown:
            raise ValueError(
                f"Unsupported strong augmentation(s) {unknown}; supported: {sorted(randaugment_dict)}"
            )

# Weak pipeline (also the first stage of the strong pipeline): random horizontal
# flip plus a random affine shift. With translate=(0, 0.125) the horizontal shift
# fraction is 0, so only a vertical shift of up to 12.5% of the height is applied.
# NOTE(review): confirm that translating only vertically is intended.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(0, translate=(0, 0.125)),
    ]
)
transform_to_img = transforms.ToPILImage()
img = transform_to_img(transform(img))

if augmentation_type == "strong":
    randaugment = RandAugment(3, 5, False)

    if strong_augmentation_wanted[0] == "all":
        augment_list = randaugment.augment_list
    else:
        augment_list = [
            randaugment.augment_list[randaugment_dict[name]]
            for name in strong_augmentation_wanted
        ]

    # Every selected operation is applied once, in list order (RandAugment's own
    # random choice of operations is not used here). Each magnitude is drawn
    # uniformly from [min_val, max_val].
    for op, min_val, max_val in augment_list:
        val = min_val + float(max_val - min_val) * random.random()
        print(f"Applying {getattr(op, '__name__', op)} (magnitude {val:.3f})")
        img = op(img, val)
    

## 3 - Plotting and saving the augmented image

In [None]:
# Plot the augmented image and save it as <augmented_dir>/<image_name>_<augmentation_type>.png.
# The output path goes into its own variable: overwriting `augmented_dir` (as
# before) made every re-run of this cell nest the file name one level deeper.
plt.close()
plt.imshow(img)
if plot_info:
    # Keep the axes (numerical grid) and add a title with the image name.
    plt.title("Augmented image: " + str(image_name), fontsize=14, fontweight='bold')
else:
    plt.axis('off')
os.makedirs(augmented_dir, exist_ok=True)
augmented_path = os.path.join(augmented_dir, image_name + "_" + augmentation_type + ".png")
print("Saving augmented image at: ", augmented_path)
plt.savefig(augmented_path)