# PyTorch Data Augmentation 

# <u>Authors:</u>
## 1. Matthias Bartolo ID: 0436103L
## 2. Luke Cardona ID: 0011803H
## 3. Jerome Agius ID: 0353803L
## 4. Isaac Muscat ID: 0265203L

## <u>Installed Packages</u>

In [1]:
#!pip install torchvision 

## <u>Packages</u>

In [2]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.utils import save_image
from IPython.display import Image
import matplotlib.pyplot as plt
import numpy as np
import random
%matplotlib inline
from PIL import Image
import os 

## <u>Loading required images</u>

In [3]:
#Specifying the image directory
directory = "./images/original_images/"

#Function to load the .jpeg images in the specified directory
def loadImages(directory, extensions=(".jpeg",)):
    """Load every image file in *directory* whose name ends with one of *extensions*.

    The suffix match is case-insensitive (``photo.JPEG`` is loaded as well as
    ``photo.jpeg``) and requires the leading dot, so a file named ``xjpeg`` is
    ignored. Sub-directories are skipped even if their name matches. Files are
    visited in sorted order because ``os.listdir`` returns entries in arbitrary
    order; sorting keeps the position of each image in the returned list (and
    hence the ``Item_N`` numbering used when saving) reproducible.

    Args:
        directory: Folder to scan. A trailing path separator is optional.
        extensions: A suffix or tuple of suffixes (including the dot) to accept.
            Defaults to ``(".jpeg",)``, i.e. only ``.jpeg`` files as before.

    Returns:
        A list of PIL images whose pixel data is already in memory.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    wanted = tuple(ext.lower() for ext in extensions)

    images = []
    for file_name in sorted(os.listdir(directory)):
        path = os.path.join(directory, file_name)
        if not file_name.lower().endswith(wanted) or not os.path.isfile(path):
            continue
        # Image.open is lazy and keeps the file handle open until the image is
        # closed; copying inside a `with` block reads the pixels into memory and
        # then closes the file, so loading many images does not leak handles.
        with Image.open(path) as img:
            images.append(img.copy())
    return images

# Read every source image from `directory` once; each is augmented further below.
original_imgs = loadImages(directory=directory)

## <u>Data Augmentation Functions</u> 

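**RandomRotation** - This function rotates the passed image counter-clockwise by the given angle (in degrees). Despite its name, the rotation itself is deterministic: the caller chooses the angle.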

In [4]:
def RandomRotation(img, RotationValue):
    """Rotate *img* by ``RotationValue`` degrees using torchvision's ``rotate``.

    Nothing random happens here; the angle is supplied by the caller.
    Returns a ``(rotated_image, title)`` pair.
    """
    title = "Rotation Image"
    rotated = T.functional.rotate(img, angle=RotationValue)
    return rotated, title

**Brightness** - This function adjusts the brightness of the passed image by the given factor (0 gives a black image, 1 the original image, 2 doubles the brightness).

In [5]:
def Brightness(img, BrightnessValue):
    """Scale the brightness of *img* by ``BrightnessValue``.

    Returns a ``(brightened_image, title)`` pair.
    """
    title = "Brightness Image"
    adjusted = T.functional.adjust_brightness(img, brightness_factor=BrightnessValue)
    return adjusted, title

**Contrast** - This function adjusts the contrast of the passed image by the given factor (0 gives a solid grey image, 1 the original image, 2 doubles the contrast).

In [6]:
def Contrast(img, ContrastValue):
    """Scale the contrast of *img* by ``ContrastValue``.

    Returns a ``(contrast_adjusted_image, title)`` pair.
    """
    title = "Contrast Image"
    adjusted = T.functional.adjust_contrast(img, contrast_factor=ContrastValue)
    return adjusted, title

**Saturation** - This function adjusts the colour saturation of the passed image by the given factor (0 gives a black-and-white image, 1 the original image, 2 doubles the saturation).

In [7]:
def Saturation(img, SaturationValue):
    """Scale the colour saturation of *img* by ``SaturationValue``.

    Returns a ``(saturation_adjusted_image, title)`` pair.
    """
    title = "Saturation Image"
    adjusted = T.functional.adjust_saturation(img, saturation_factor=SaturationValue)
    return adjusted, title

**Hue** - This function shifts the hue of the passed image by the given factor, which must lie in [-0.5, 0.5] (0 means no shift; ±0.5 fully reverses the hue).

In [8]:
def Hue(img, HueValue):
    """Shift the hue of *img* by ``HueValue``.

    torchvision only accepts hue factors in [-0.5, 0.5]; the value is passed
    through unchecked. Returns a ``(hue_shifted_image, title)`` pair.
    """
    title = "Hue Image"
    shifted = T.functional.adjust_hue(img, hue_factor=HueValue)
    return shifted, title

**CenterCrop** - This function crops the central region of the passed image (half its original width and height) and returns the cropped section.

In [9]:
def CenterCrop(img, fraction=0.5):
    """Crop the central region of *img*.

    Args:
        img: PIL image to crop (``img.size`` is read as ``(width, height)``).
        fraction: Size of the crop relative to the input, applied to both
            height and width. The default of 0.5 keeps the central half of
            each dimension, as before.

    Returns:
        A ``(cropped_image, "Centre Crop Image")`` pair.

    Raises:
        ValueError: if *fraction* is not positive.
    """
    if not fraction > 0:
        raise ValueError("fraction must be positive, got " + repr(fraction))
    width, height = img.size
    # torchvision's CenterCrop takes (height, width). Pass integers: the
    # previous `size / 2` produced floats for odd dimensions, which left the
    # exact crop size to rounding inside the crop instead of a plain floor.
    # max(1, ...) avoids an empty crop on tiny images.
    crop_size = [max(1, int(height * fraction)), max(1, int(width * fraction))]
    centre_crop = T.CenterCrop(crop_size)
    return centre_crop(img), "Centre Crop Image"

**HorizontalFlip** - This function flips the passed image horizontally, mirroring it left-to-right about its vertical centre line.

In [10]:
def HorizontalFlip(img):
    """Mirror *img* left-to-right.

    Returns a ``(flipped_image, title)`` pair.
    """
    title = "Horizontal Flip Image"
    mirrored = T.functional.hflip(img)
    return mirrored, title

**VerticalFlip** - This function flips the passed image vertically, mirroring it top-to-bottom about its horizontal centre line.

In [11]:
def VerticalFlip(img):
    """Mirror *img* top-to-bottom.

    Returns a ``(flipped_image, title)`` pair.
    """
    title = "Vertical Flip Image"
    mirrored = T.functional.vflip(img)
    return mirrored, title

**Shear** - This function shears the passed image by the given angle (in degrees) along both the x and y axes.

In [12]:
def Shear(img, ShearValue):
    """Shear *img* by ``ShearValue`` degrees on both axes.

    The affine transform has no rotation, no translation and unit scale, so
    shearing is the only geometric change. Returns a ``(sheared_image, title)``
    pair.
    """
    title = "Shear Image"
    sheared = T.functional.affine(
        img,
        angle=0,
        translate=[0, 0],
        scale=1,
        shear=[ShearValue, ShearValue],
    )
    return sheared, title

**Gamma** - This function applies gamma correction to the passed image with the given gamma value (gain 1); gamma above 1 darkens the shadows, while gamma below 1 lightens dark regions.

In [13]:
def Gamma(img, GammaValue):
    """Apply gamma correction with exponent ``GammaValue`` and a gain of 1.

    Returns a ``(gamma_corrected_image, title)`` pair.
    """
    title = "Gamma Image"
    corrected = T.functional.adjust_gamma(img, gamma=GammaValue, gain=1)
    return corrected, title

**GaussianBlur** - This function blurs the passed image with a Gaussian kernel of the given size (which must be odd and positive).

In [14]:
def GaussianBlur(img, GBlurValue):
    """Blur *img* with a Gaussian kernel of size ``GBlurValue``.

    Sigma is left to torchvision's default for that kernel size. Returns a
    ``(blurred_image, title)`` pair.
    """
    title = "Gaussian Blur Image"
    blurred = T.functional.gaussian_blur(img, kernel_size=GBlurValue)
    return blurred, title

**Translation** - This function translates (shifts) the passed image by the given number of pixels along both axes.

In [15]:
def Translation(img, TranslationValue):
    """Shift *img* by ``TranslationValue`` pixels along both axes.

    Uses a pure-translation affine transform (no rotation, unit scale, no
    shear). Returns a ``(translated_image, title)`` pair.
    """
    # The original also passed positional `interpolation=0, fill=0`. Those
    # equal torchvision's defaults (nearest-neighbour interpolation and a fill
    # of 0), and passing an int for `interpolation` is deprecated in favour of
    # the InterpolationMode enum. Omitting both keeps the output identical and
    # works across torchvision versions.
    translation_img = T.functional.affine(
        img,
        angle=0,
        translate=[TranslationValue, TranslationValue],
        scale=1,
        shear=0,
    )
    return translation_img, "Translation Image"

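**AugmentImg** - This function applies all the image augmentation functions defined above to the passed image, five times each (varying the parameters where the augmentation has any), and returns a dictionary that maps each augmentation name to a list of (augmented image, title) pairs.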

In [16]:
def AugmentImg(img):
    """Apply every augmentation defined above to *img*, five variants each.

    Returns:
        dict: maps each augmentation name to a list of ``(image, title)``
        pairs, in a fixed insertion order (the display cell draws one row per
        key in this order, and ``SaveImages`` uses the keys as directory
        names). Every list holds five entries except ``"Flipped"``, which holds
        ten alternating horizontal and vertical flips.

    Parameters used for variant ``i`` (0-based):
        rotation and shear ``30 * (i + 1)`` degrees; brightness
        ``0.6 * (i + 1)``; contrast, saturation and gamma ``i + 1``;
        hue ``-0.5 + 0.2 * i`` (stays within torchvision's [-0.5, 0.5]);
        Gaussian-blur kernel ``31 + 30 * i`` (always odd, as torchvision
        requires); translation ``int((i - 1.5) * 100)`` pixels.
    """
    num_variants = 5

    Augmented_Img = {"Rotations": [], "Brightness": [], "Contrast": [], "Saturation": [],
                     "Hue": [], "Flipped": [], "Shearing": [],
                     "CenterCrop": [], "Gamma": [], "GaussianBlur": [], "Translation": []}

    # The flips and the centre crop take no parameter, so every "variant" is the
    # same image. Compute each one once instead of on every loop iteration.
    centre_crop = CenterCrop(img)
    horizontal_flip = HorizontalFlip(img)
    vertical_flip = VerticalFlip(img)

    #Executing every augmentation five times per image and storing the results
    for i in range(num_variants):
        step = i + 1
        # Derive each parameter from the variant index rather than accumulating
        # it across iterations, so every schedule is visible in one place.
        RotationValue = 30 * step
        BrightnessValue = 0.6 * step
        ContrastValue = step
        SaturationValue = step
        HueValue = -0.5 + 0.2 * i
        ShearValue = 30 * step
        GammaValue = step
        GBlurValue = 31 + 30 * i
        TranslationValue = int((i - 1.5) * 100)

        #Photometric Augmentation
        Augmented_Img["Brightness"].append(Brightness(img, BrightnessValue))
        Augmented_Img["Contrast"].append(Contrast(img, ContrastValue))
        Augmented_Img["Saturation"].append(Saturation(img, SaturationValue))
        Augmented_Img["Hue"].append(Hue(img, HueValue))
        Augmented_Img["Gamma"].append(Gamma(img, GammaValue))
        Augmented_Img["GaussianBlur"].append(GaussianBlur(img, GBlurValue))

        #Geometric Augmentation
        Augmented_Img["Rotations"].append(RandomRotation(img, RotationValue))
        Augmented_Img["CenterCrop"].append(centre_crop)
        Augmented_Img["Flipped"].append(horizontal_flip)
        Augmented_Img["Flipped"].append(vertical_flip)
        Augmented_Img["Shearing"].append(Shear(img, ShearValue))
        Augmented_Img["Translation"].append(Translation(img, TranslationValue))

    return Augmented_Img

In [17]:
# One augmentation dictionary per original image, in the same order.
AugmentedImgs = [AugmentImg(source_img) for source_img in original_imgs]

# <u>Output Display</u>

This cell displays all the augmented images in an orderly grid (one row per augmentation type, one column per variant) for presentation purposes.

In [18]:
#Grid layout: one row per augmentation type produced by AugmentImg, one column per variant
rows, cols = 11, 5

#Looping through all the augmented images
for imgSet in AugmentedImgs:
    # Use the declared grid size instead of repeating the literals; squeeze=False
    # keeps `ax` two-dimensional even if rows or cols is ever set to 1.
    fig, ax = plt.subplots(rows, cols, figsize=(10, 20), squeeze=False)
    for row, (key, entries) in enumerate(imgSet.items()):
        # Only the first `cols` variants fit in a row ("Flipped" holds ten).
        for col, (image, title) in enumerate(entries[:cols]):
            cell = ax[row][col]
            cell.imshow(image)
            cell.set_title(f"{title} {col + 1}")
            cell.axis('off')
    # Lay the figure out once, after every subplot is filled, rather than
    # recomputing the layout after each of the rows * cols subplots.
    fig.tight_layout()

**SaveImages** - This function saves the passed list of augmented images as JPEG files under the specified directory, grouped into photometric and geometric augmentation folders.

In [19]:
def SaveImages(listOfImages, directory, max_per_key=5):
    """Write augmented images to disk as JPEG files.

    Files are laid out as::

        <directory>/PhotometricAugmentation/<key>/Item_<n>_<title>_<i>.jpeg
        <directory>/GeometricAugmentation/<key>/Item_<n>_<title>_<i>.jpeg

    where ``n`` is the 1-based position of the image set in *listOfImages*,
    ``key`` is the augmentation name, ``title`` is the label stored with each
    image, and ``i`` is the 1-based variant index.

    Args:
        listOfImages: Sequence of dicts mapping an augmentation name to a list
            of ``(image, title)`` pairs; each image must provide
            ``save(path, format)``.
        directory: Output root. It is created, with any missing parents, even
            when *listOfImages* is empty.
        max_per_key: Maximum number of variants written per augmentation. The
            default of 5 matches the previous behaviour (so only the first five
            of the ten "Flipped" images are written). Shorter lists are saved in
            full instead of raising IndexError.
    """
    photometric = {"Brightness", "Contrast", "Saturation", "Hue", "Gamma", "GaussianBlur"}

    # exist_ok=True replaces the race-prone "check exists, then create" pattern.
    os.makedirs(directory, exist_ok=True)

    for count, imgSet in enumerate(listOfImages, start=1):
        for key, entries in imgSet.items():
            subDir = "PhotometricAugmentation" if key in photometric else "GeometricAugmentation"
            # os.path.join avoids doubled separators when `directory` ends in "/".
            keyDir = os.path.join(directory, subDir, key)
            os.makedirs(keyDir, exist_ok=True)

            for x, (image, title) in enumerate(entries[:max_per_key], start=1):
                file_name = f"Item_{count}_{title}_{x}.jpeg"
                image.save(os.path.join(keyDir, file_name), 'JPEG')

In [20]:
# Write every augmented image set under the PyTorch output folder.
SaveImages(listOfImages=AugmentedImgs, directory="./images/AugmentedImagesPyTorch/")