# Implementation of Data Augmentation Techniques

In [211]:
import os
import cv2
import time
import torch
import psutil
import shutil
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as F
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
from PIL import Image, ImageDraw, ImageFont


## Implementation of Photometric Distortions

### OpenCV

In [212]:
def adjust_brightness(image, factor):
    image = image.astype(np.float32)
    image = image * factor
    return np.clip(image, 0, 255).astype(np.uint8)

def adjust_contrast(image, factor):
    return cv2.convertScaleAbs(image, alpha=factor, beta=1.0)

def shift_color_channels(image, r_shift, g_shift, b_shift):
    shifted_image = image.copy().astype(np.float32)
    shifted_image[:, :, 0] = np.clip(shifted_image[:, :, 0] + b_shift, 0, 255)
    shifted_image[:, :, 1] = np.clip(shifted_image[:, :, 1] + g_shift, 0, 255)
    shifted_image[:, :, 2] = np.clip(shifted_image[:, :, 2] + r_shift, 0, 255)
    return shifted_image.astype(np.uint8)

def adjust_saturation(image, factor):
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    (h, s, v) = cv2.split(hsv_image)
    s = s * factor
    s = np.clip(s, 0, 255).astype(np.uint8)
    print(s)
    hsv_image = cv2.merge([h, s, v])
    return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)

def add_gaussian_noise(image, mean=0, std=25):
    noise = np.random.normal(mean, std, image.shape)
    return np.clip(image + noise, 0, 255).astype(np.uint8)


### Torch

In [213]:
def torch_adjust_brightness(image, factor):
    image = torch.from_numpy(image).permute(2, 1, 0).float() / 255
    image = F.adjust_brightness(image, brightness_factor=factor)
    image = (image.permute(2, 1, 0) * 255).byte().numpy()
    return image

def torch_adjust_contrast(image, factor):
    image = torch.from_numpy(image).permute(2, 1, 0).float() / 255
    image = F.adjust_contrast(image, contrast_factor=factor)
    image = (image.permute(2, 1, 0) * 255).byte().numpy()
    return image

def torch_shift_color_channels(image, r_shift, g_shift, b_shift):
    image = torch.from_numpy(image).permute(2, 1, 0).float() 
    r, g, b = image[0], image[1], image[2]
    r = torch.clamp(r + b_shift, 0, 255)
    g = torch.clamp(g + g_shift, 0, 255)
    b = torch.clamp(b + r_shift, 0, 255)
    shifted_image = torch.stack([r, g, b], dim=0).permute(2, 1, 0).numpy().astype(np.uint8)
    return shifted_image

def torch_adjust_saturation(image, factor):
    image = torch.from_numpy(image).permute(2, 1, 0).float() / 255
    image = F.adjust_saturation(image, saturation_factor=factor)
    image = (image.permute(2, 1, 0) * 255).byte().numpy()
    return image

def torch_add_gaussian_noise(image, mean=0, std=25):
    noise = torch.randn_like(torch.from_numpy(image).float()) * std + mean
    image = torch.from_numpy(image).float() + noise
    image = torch.clamp(image, 0, 255)
    return image.numpy()

## Implementation of Geometric Distortions

### OpenCV

In [214]:
def rotate_image(image, angle):
    rows, cols = image.shape[:2]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    return cv2.warpAffine(image, M, (cols, rows)).astype(np.uint8)

def scale_image(image, scale_factor):
    return cv2.resize(image, None, fx=scale_factor, fy=scale_factor).astype(np.uint8)

def translate_image(image, tx, ty):
    M = np.float32([[1, 0, tx], [0, 1, ty]])
    return cv2.warpAffine(image, M, (image.shape[1], image.shape[0])).astype(np.uint8)

def shear_image(image, shear_factor):
    M = np.float32([[1, shear_factor, 0], [0, 1, 0]])
    return cv2.warpAffine(image, M, (image.shape[1], image.shape[0])).astype(np.uint8)

def flip_image(image, flip_code):
    return cv2.flip(image, flip_code).astype(np.uint8)

### Torch

In [215]:
def torch_rotate_image(image, angle):
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    image = F.rotate(image, angle)
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return image

def torch_scale_image(image, scale_factor):
    height, width = image.shape[:2]
    new_size = (int(height * scale_factor), int(width * scale_factor))
    transform = T.Resize(new_size)
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    image = transform(image)
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return image

def torch_translate_image(image, tx, ty):
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    image = F.affine(image, angle=0, translate=(tx, ty), scale=1, shear=0)
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return image

def torch_shear_image(image, shear_factor):
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    image = F.affine(image, angle=0, translate=(0, 0), scale=1, shear=[-np.degrees(np.arctan(shear_factor)), 0])
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return image

def torch_flip_image(image, flip_code):
    image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
    if flip_code == 1:
        image = F.hflip(image)
    elif flip_code == 0:
        image = F.vflip(image)
    elif flip_code == -1:
        image = F.hflip(F.vflip(image))
    else:
        raise ValueError("flip_code must be 0 (vertical) or 1 (horizontal)")
    image = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return image

In [216]:
def load_images(path: str):
    images = []
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".jpeg"):
                image = cv2.imread(os.path.join(root, file), cv2.IMREAD_COLOR)
                images.append({"image": image, "name": file.split(".")[0]})
    return images

def convert_PIL_to_cv2(pil_image):
    open_cv_image = (np.array(pil_image) * 255).astype(np.uint8)
    open_cv_image = open_cv_image[:, :, ::-1].copy()
    return open_cv_image

def generate_random_value(value1, value2, excluded_range):
    while True:
        random_value = np.random.uniform(value1, value2)
        if not (excluded_range[0] <= random_value <= excluded_range[1]):
            return random_value
        
def draw_text(draw, text, position, font, color=(0, 0, 0)):
    draw.text(position, text, font=font, fill=color)

def print_augmentation_metrics(augmented_image_info_list):
    # Print header
    print(  f"{'Name':<20} {'Type':<15} {'CV Memory (MB)':<15} {'CV Time (s)':<15}"
            f"{'Torch Memory (MB)':<18} {'Torch Time (s)':<15}"
            f"{'SSIM CV-Torch':<15} {'PSNR CV-Torch':<15}"
        )

    for augmented_image_info in augmented_image_info_list:
        name = augmented_image_info["name"]
        type = augmented_image_info["type"]
        cv_print_var = torch_print_var = " "
        if augmented_image_info["cv_memory_usage"] is not None:
            # CV metrics
            cv_memory = augmented_image_info["cv_memory_usage"] / 1024 / 1024 if augmented_image_info["cv_memory_usage"] is not None else "-"
            cv_time = augmented_image_info["cv_elapsed_time"] if augmented_image_info["cv_elapsed_time"] is not None else "-"
            cv_print_var = f"{cv_memory:<18.2f} {cv_time:<15.4f}"

        if augmented_image_info["torch_memory_usage"] is not None:
            # Torch metrics
            torch_memory = augmented_image_info["torch_memory_usage"] / 1024 / 1024 if augmented_image_info["torch_memory_usage"] is not None else "-"
            torch_time = augmented_image_info["torch_elapsed_time"] if augmented_image_info["torch_elapsed_time"] is not None else "-"
            torch_print_var = f"{torch_memory:<18.2f} {torch_time:<15.4f}"
            
        cv_image_gray = cv2.cvtColor(augmented_image_info["cv_image"], cv2.COLOR_BGR2GRAY)

        torch_image = np.array(augmented_image_info["torch_image"])
        if torch_image.shape[0] == 3:
            torch_image_bgr = np.transpose(torch_image, (1, 2, 0))
        else:
            torch_image_bgr = torch_image
        torch_image_gray  = cv2.cvtColor(torch_image_bgr.astype(np.uint8), cv2.COLOR_BGR2GRAY)
        
        if cv_image_gray.shape != torch_image_gray.shape:
            torch_image_gray = cv2.resize(torch_image_gray, (cv_image_gray.shape[1], cv_image_gray.shape[0]))
        image_ssim = ssim(cv_image_gray, torch_image_gray, data_range=1.0)
        image_psnr = psnr(cv_image_gray, torch_image_gray, data_range=1.0)

        print(f"{name:<20} {type:<15} {cv_print_var}"
              f"{torch_print_var} "
              f"{image_ssim:<15.4f} {image_psnr:<15.4f}")

def print_augmentation_metrics_on_image(augmented_image_info_list, output_image_path):
    # Define font and image size
    font_path = "arial.ttf"  # Path to a TTF font file
    font_size = 14
    font = ImageFont.truetype(font_path, font_size)
    image_width = 1200
    image_height = 30 * (len(augmented_image_info_list) + 1)
    header_height = 30
    
    # Create an image with white background
    image = Image.new('RGB', (image_width, image_height), 'white')
    draw = ImageDraw.Draw(image)
    
    # Print header
    headers = [
        "Name", "Type", "CV Memory (MB)", "CV Time (s)", 
        "Torch Memory (MB)", "Torch Time (s)", 
        "SSIM CV-Torch", "PSNR CV-Torch"
    ]
    x_positions = [10, 150, 300, 450, 600, 750, 900, 1050]
    
    for i, header in enumerate(headers):
        draw_text(draw, header, (x_positions[i], 10), font)
    
    # Print data rows
    for idx, augmented_image_info in enumerate(augmented_image_info_list):
        y_position = header_height + 10 + idx * 30
        name = augmented_image_info["name"]
        type = augmented_image_info["type"] + " " + augmented_image_info["parameters"]
        
        if augmented_image_info["cv_memory_usage"] is not None:
            # CV metrics
            cv_memory = augmented_image_info["cv_memory_usage"] / 1024 / 1024
            cv_time = augmented_image_info["cv_elapsed_time"]
            cv_memory_text = f"{cv_memory:.2f}"
            cv_time_text = f"{cv_time:.4f}"
        else:
            cv_memory_text = cv_time_text = "-"

        if augmented_image_info["torch_memory_usage"] is not None:
            # Torch metrics
            torch_memory = augmented_image_info["torch_memory_usage"] / 1024 / 1024
            torch_time = augmented_image_info["torch_elapsed_time"]
            torch_memory_text = f"{torch_memory:.2f}"
            torch_time_text = f"{torch_time:.4f}"
        else:
            torch_memory_text = torch_time_text = "-"
        
        cv_image_gray = cv2.cvtColor(augmented_image_info["cv_image"], cv2.COLOR_BGR2GRAY)
        torch_image = np.array(augmented_image_info["torch_image"])
        if torch_image.shape[0] == 3:
            torch_image_bgr = np.transpose(torch_image, (1, 2, 0))
        else:
            torch_image_bgr = torch_image
        torch_image_gray = cv2.cvtColor(torch_image_bgr.astype(np.uint8), cv2.COLOR_BGR2GRAY)
        
        if cv_image_gray.shape != torch_image_gray.shape:
            torch_image_gray = cv2.resize(torch_image_gray, (cv_image_gray.shape[1], cv_image_gray.shape[0]))
        image_ssim = ssim(cv_image_gray, torch_image_gray, data_range=1.0)
        image_psnr = psnr(cv_image_gray, torch_image_gray, data_range=1.0)
        
        draw_text(draw, name, (x_positions[0], y_position), font)
        draw_text(draw, type, (x_positions[1], y_position), font)
        draw_text(draw, cv_memory_text, (x_positions[2], y_position), font)
        draw_text(draw, cv_time_text, (x_positions[3], y_position), font)
        draw_text(draw, torch_memory_text, (x_positions[4], y_position), font)
        draw_text(draw, torch_time_text, (x_positions[5], y_position), font)
        draw_text(draw, f"{image_ssim:.4f}", (x_positions[6], y_position), font)
        draw_text(draw, f"{image_psnr:.4f}", (x_positions[7], y_position), font)
    
    # Save the image
    image.save(output_image_path)

In [217]:
from time import sleep


print("Loading images")
images = load_images("images")
print("Images loaded!")

def generate_augmented_images(image: dict[str, np.ndarray], num_augmentations: int = 5, type: str = "all"):
    image_data: np.ndarray = image["image"]
    image_name: str = image["name"]
    already_augmentations: list = []
    augmented_images: list = []
    augmentation_functions = {
        "cv": {
            "brightness": adjust_brightness,
            "contrast": adjust_contrast,
            "color_shift": shift_color_channels,
            "saturation": adjust_saturation,
            "gaussian_noise": add_gaussian_noise,
            "rotate": rotate_image,
            "scale": scale_image,
            "translate": translate_image,
            "shear": shear_image,
            "flip": flip_image
        },
        "torch": {
            "brightness": torch_adjust_brightness,
            "contrast": torch_adjust_contrast,
            "color_shift": torch_shift_color_channels,
            "saturation": torch_adjust_saturation,
            "gaussian_noise": torch_add_gaussian_noise,
            "rotate": torch_rotate_image,
            "scale": torch_scale_image,
            "translate": torch_translate_image,
            "shear": torch_shear_image,
            "flip": torch_flip_image
        }
    }

    sleep(1)

    for _ in range(num_augmentations):
        # print(f"Augmenting image {image_name} with type {type}")
        choices: list[str]
        if type == "photometric":
            choices = ["brightness", "contrast", "color_shift", "saturation"]
            if "gaussian_noise" not in already_augmentations:
                choices.append("gaussian_noise")
        elif type == "geometric":
            choices = ["rotate", "scale", "translate", "shear", "flip"]
        else:
            choices = ["brightness", "contrast", "color_shift", "saturation", "rotate", "scale", "translate", "shear", "flip"]
            if "gaussian_noise" not in already_augmentations:
                choices.append("gaussian_noise")
        _tmp_choices = [choice for choice in choices if choice not in already_augmentations]
        if len(_tmp_choices) > 0:
            choices = _tmp_choices
        choice = np.random.choice(choices)
        already_augmentations.append(choice)
        if choice == "brightness":
            brightness_factor = generate_random_value(0.1, 2, (0.8, 1.2))
        elif choice == "contrast":
            contrast_factor = generate_random_value(0.1, 2, (0.8, 1.2))
        elif choice == "color_shift":
            r_shift, g_shift, b_shift = np.random.randint(-50, 50, size=3)
        elif choice == "saturation":
            saturation_factor = generate_random_value(0.1, 2, (0.8, 1.2))
        elif choice == "rotate":
            angle = generate_random_value(-180, 180, (-10, 10))
        elif choice == "scale":
            scale_factor = generate_random_value(0.1, 2, (0.8, 1.2))
        elif choice == "translate":
            tx, ty = np.random.randint(-100, 100, size=2)
        elif choice == "shear":
            shear_factor = generate_random_value(-0.5, 0.5, (-0.1, 0.1))
        elif choice == "flip":
            flip_code = np.random.choice([-1, 0, 1])
        
        augmented_image_info = {
            "image": image_data,
            "cv_image": None,
            "cv_memory_usage": None,
            "cv_elapsed_time": None,
            "torch_image": None,
            "torch_memory_usage": None,
            "torch_elapsed_time": None,
            "name": image_name,
            "type": choice,
            "parameters": f"{brightness_factor:.2f}" if choice == "brightness" else
                        f"{contrast_factor:.2f}" if choice == "contrast" else 
                        f"{saturation_factor:.2f}" if choice == "saturation" else 
                        f"{r_shift}, {g_shift}, {b_shift}" if choice == "color_shift" else 
                        f"{angle:.2f}°" if choice == "rotate" else 
                        f"{scale_factor:.2f}" if choice == "scale" else 
                        f"{tx}px, {ty}px" if choice == "translate" else 
                        f"{shear_factor:.2f}" if choice == "shear" else 
                        f"{flip_code}" if choice == "flip" else ""
        }

        for aug_type in augmentation_functions.keys():
            func = augmentation_functions[aug_type][choice]
            _image_data = image_data
            
            start_time = time.time()
            process = psutil.Process()
            initial_memory = process.memory_info().rss

            if choice == "brightness":
                augmented_image = func(_image_data, brightness_factor)
            elif choice == "contrast":
                augmented_image = func(_image_data, contrast_factor)
            elif choice == "color_shift":
                augmented_image = func(_image_data, r_shift, g_shift, b_shift)
            elif choice == "saturation":
                augmented_image = func(_image_data, saturation_factor)
            elif choice == "gaussian_noise":
                augmented_image = func(_image_data)
            elif choice == "rotate":
                augmented_image = func(_image_data, angle)
            elif choice == "scale":
                augmented_image = func(_image_data, scale_factor)
            elif choice == "translate":
                augmented_image = func(_image_data, tx, ty)
            elif choice == "shear":
                augmented_image = func(_image_data, shear_factor)
            elif choice == "flip":
                augmented_image = func(_image_data, flip_code)

            end_time = time.time()
            final_memory = process.memory_info().rss
            elapsed_time = end_time - start_time
            memory_usage = final_memory - initial_memory

            augmented_image_info[f"{aug_type}_image"] = augmented_image
            augmented_image_info[f"{aug_type}_memory_usage"] = memory_usage
            augmented_image_info[f"{aug_type}_elapsed_time"] = elapsed_time

        augmented_images.append(augmented_image_info)

    return augmented_images

output_geometric_path = "./augmented_geometric_images/"
if os.path.exists(output_geometric_path):
    shutil.rmtree(output_geometric_path)
os.makedirs(output_geometric_path, exist_ok=True)
output_photometric_path = "./augmented_photometric_images/"
if os.path.exists(output_photometric_path):
    shutil.rmtree(output_photometric_path)
os.makedirs(output_photometric_path, exist_ok=True)

num_augmentations = 5

for i, input_image in enumerate(images):
    augmented_photometric_images = generate_augmented_images(input_image, type="photometric", num_augmentations=num_augmentations)
    augmented_geometric_images = generate_augmented_images(input_image, type="geometric", num_augmentations=num_augmentations)
    print_augmentation_metrics_on_image(augmented_photometric_images + augmented_geometric_images, f"metrics/augmented_metrics_{input_image['name']}.jpeg")
    print_augmentation_metrics(augmented_photometric_images + augmented_geometric_images)
    for j, augmented_image in enumerate(augmented_photometric_images):
        name = augmented_image["name"]
        type = augmented_image["type"]
        for aug_type in ["cv", "torch"]:
            image = augmented_image[f"{aug_type}_image"]
            filename = f"augmented_{name}_{j}_{type}_{aug_type}.jpeg"
            print(f"Saving {filename}")
            print(image)
            cv2.imwrite(os.path.join(output_photometric_path, filename), image)

    for j, augmented_image in enumerate(augmented_geometric_images):
        name = augmented_image["name"]
        type = augmented_image["type"]
        for aug_type in ["cv", "torch"]:
            filename = f"augmented_{name}_{j}_{type}_{aug_type}.jpeg"
            image = augmented_image[f"{aug_type}_image"]
            print(f"Saving {filename}")
            print(image)
            cv2.imwrite(os.path.join(output_geometric_path, filename), image)

print(f"Saved {len(images) * num_augmentations * 2} augmented images to {output_geometric_path} and {output_photometric_path}!")

Loading images
Images loaded!
[[255 255 255 ... 255 255 255]
 [255 255 255 ... 255 255 255]
 [255 255 255 ... 255 255 255]
 ...
 [217 223 227 ... 253 255 255]
 [233 228 223 ... 255 255 255]
 [246 236 223 ... 255 255 255]]


  return 10 * np.log10((data_range**2) / err)


Name                 Type            CV Memory (MB)  CV Time (s)    Torch Memory (MB)  Torch Time (s) SSIM CV-Torch   PSNR CV-Torch  
3_colour             brightness      2.64               0.0116         12.14              0.0124          1.0000          inf            
3_colour             saturation      2.64               0.0071         0.97               0.0115          0.9102          -19.9397       
3_colour             gaussian_noise  2.64               0.0735         2.81               0.0205          0.1024          -27.1211       
3_colour             color_shift     2.64               0.0182         4.09               0.0076          1.0000          inf            
3_colour             contrast        2.64               0.0010         0.00               0.0099          0.6829          -29.4200       
3_colour             flip            2.64               0.0020         -34.95             0.0122          1.0000          inf            
3_colour             rotate          2

Memory: It's the difference between memory usage before and after a calculation, revealing how much memory was consumed.

Time: It's the difference between start and end times of a calculation, indicating how long the operation took to complete.

SSIM: A value close to 1 indicates that the two images are very similar in terms of their structural content. Lower values indicate differences in structure, luminance, or contrast.

PSNR: A higher PSNR value (in dB) indicates that the reconstructed image closely matches the original in terms of pixel values. Higher PSNR values generally indicate higher image fidelity.