# **Data Augmentation**

# In[21]:
import os
import cv2
import numpy as np
import albumentations as A
import random
import shutil
import matplotlib as plt

# In[22]:
original_dir = "dataset/original"
train_dir = "dataset/train"
test_dir = "dataset/test"

# Accept the same image extensions as the augmentation step; a `.png`-only
# filter silently dropped .jpg/.jpeg originals from both splits.
valid_exts = ('.jpg', '.jpeg', '.png')
TRAIN_FRACTION = 0.8
# Fixed seed so re-running the notebook reproduces the same split. With an
# unseeded shuffle a re-run copies a different subset into train/ and test/,
# and the copies left over from the previous run put test images into training.
SPLIT_SEED = 42


def split_train_test(files, train_fraction=TRAIN_FRACTION, seed=SPLIT_SEED):
    """Shuffle *files* reproducibly and split them into ``(train, test)`` lists.

    Args:
        files: File names to split; the input sequence is not modified.
        train_fraction: Fraction of files assigned to the training list
            (rounded down, so the test list receives the remainder).
        seed: Seed for a private ``random.Random`` instance; ``None`` gives a
            non-reproducible split. The global ``random`` state is untouched.

    Returns:
        A ``(train, test)`` tuple of lists that together contain every file once.
    """
    # Sort first: os.listdir order is arbitrary, so the seed alone must fix the order.
    shuffled = sorted(files)
    random.Random(seed).shuffle(shuffled)
    cut = int(train_fraction * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


for directory in (original_dir, train_dir, test_dir):
    os.makedirs(directory, exist_ok=True)  # also creates the parent "dataset" dir

original_files = [f for f in os.listdir(original_dir) if f.lower().endswith(valid_exts)]
train_files, test_files = split_train_test(original_files)
split_idx = len(train_files)

# Copy the original images into the training and test sets.
for f in train_files:
    shutil.copy(os.path.join(original_dir, f), os.path.join(train_dir, f))
for f in test_files:
    shutil.copy(os.path.join(original_dir, f), os.path.join(test_dir, f))

# Copies from earlier runs are never removed, so if the originals or the seed
# changed, a file can sit in both splits; flag that train/test leakage.
leaked = [f for f in test_files if os.path.exists(os.path.join(train_dir, f))]
if leaked:
    print(f"Warning: {len(leaked)} test image(s) also present in {train_dir}: {leaked[:5]}")

# In[24]:

train_dir = "dataset/train"  # or "heatmaps/dataset/train"
valid_exts = ('.jpg', '.jpeg', '.png')

# Augmentations applied to every source image; each one writes
# "<name>_<key><ext>" next to the source image in train_dir.
transformations = {
    # Basic
    "rotated": A.Rotate(limit=60, p=1.0),
    "shift_scale_rotate": A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=20, p=1.0),
    "flipped": A.HorizontalFlip(p=1.0),
    "brightness_contrast": A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=1.0),
    "hue_shift": A.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=50, val_shift_limit=40, p=1.0),
    "gamma": A.RandomGamma(gamma_limit=(60, 140), p=1.0),
    "rgb_shift": A.RGBShift(r_shift_limit=40, g_shift_limit=40, b_shift_limit=40, p=1.0),
    "clahe": A.CLAHE(clip_limit=8.0, tile_grid_size=(8, 8), p=1.0),

    # Combination
    "flip_contrast": A.Compose([
        A.HorizontalFlip(p=1.0),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
    ]),
    "rotate_hue": A.Compose([
        A.Rotate(limit=25, p=1.0),
        A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=15, p=1.0),
    ]),
    "shift_gamma_rgb": A.Compose([
        A.ShiftScaleRotate(shift_limit=0.03, scale_limit=0.05, rotate_limit=10, p=1.0),
        A.RandomGamma(gamma_limit=(80, 120), p=1.0),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=1.0),
    ]),
}


def _is_augmented_output(stem):
    """Return True if *stem* ends with "_<key>" for some key of `transformations`.

    Used to keep outputs of a previous run of this cell from being augmented
    again. Caveat: an original image whose name happens to end with such a
    suffix (e.g. "cat_flipped") is also skipped.
    """
    return any(stem.endswith(f"_{aug_name}") for aug_name in transformations)


# Built after `transformations` so earlier outputs can be excluded; otherwise
# re-running the cell produces files like "x_rotated_flipped.png".
image_files = [
    f for f in os.listdir(train_dir)
    if f.lower().endswith(valid_exts) and not _is_augmented_output(os.path.splitext(f)[0])
]

for filename in image_files:
    img_path = os.path.join(train_dir, filename)
    image = cv2.imread(img_path)
    if image is None:
        print(f"Skipping unreadable file: {filename}")
        continue

    # cv2.imread returns BGR, while albumentations' colour transforms
    # (HueSaturationValue, CLAHE, ...) assume RGB input.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    name, ext = os.path.splitext(filename)

    for aug_name, aug in transformations.items():
        augmented = aug(image=image)["image"]
        output_path = os.path.join(train_dir, f"{name}_{aug_name}{ext}")
        # Convert back to BGR for cv2.imwrite; it returns False on failure.
        if not cv2.imwrite(output_path, cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR)):
            print(f"Failed to write: {output_path}")