# Data Augmentation

**Imports and Setup**

In [2]:
# SECTION 1: Imports and Setup

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import cv2
import matplotlib.pyplot as plt
from os import listdir
import os
import time

%matplotlib inline


In [3]:
# SECTION 2: Time Formatter Utility

def hms_string(sec_elapsed):
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60
    return f"{h}:{m}:{round(s,1)}"


In [4]:
# SECTION 3: Data Augmentation Function

def augment_data(file_dir, n_generated_samples, save_to_dir):
    data_gen = ImageDataGenerator(
        rotation_range=10, 
        width_shift_range=0.1, 
        height_shift_range=0.1, 
        shear_range=0.1, 
        brightness_range=(0.3, 1.0),
        horizontal_flip=True, 
        vertical_flip=True, 
        fill_mode='nearest'
    )
    
    for filename in listdir(file_dir):
        image_path = os.path.join(file_dir, filename)
        image = cv2.imread(image_path)
        
        if image is None:
            print(f"Warning: Unable to read {filename}")
            continue
        
        image = image.reshape((1,)+image.shape)
        save_prefix = 'aug_' + filename[:-4]
        
        i = 0
        for batch in data_gen.flow(
            x=image, batch_size=1, save_to_dir=save_to_dir, 
            save_prefix=save_prefix, save_format='jpg'):
            i += 1
            if i > n_generated_samples:
                break


In [5]:
# SECTION 4: Dataset Summary Function

def data_summary(main_path):
    classes = os.listdir(main_path)
    total_images = 0

    print(f"Classes found: {len(classes)}")
    
    for class_name in classes:
        class_path = os.path.join(main_path, class_name)
        if os.path.isdir(class_path):
            num_images = len(listdir(class_path))
            total_images += num_images
            print(f"Class: {class_name}, Number of examples: {num_images}")
    
    print(f"Total number of images: {total_images}")
    
    for class_name in classes:
        class_path = os.path.join(main_path, class_name)
        if os.path.isdir(class_path):
            num_images = len(listdir(class_path))
            percentage = (num_images * 100.0) / total_images
            print(f"Percentage of {class_name} examples: {percentage:.2f}%")


In [9]:
# SECTION 5: Paths and Class Augmentation Mapping

original_dataset_path = "./brain_tumor_dataset_multiclass"
augmented_dataset_path = "./Augmented_Dataset_Multiclass"

# Ensure output folder exists
os.makedirs(augmented_dataset_path, exist_ok=True)

# Augmentation factors based on class imbalance
classes = {
    "glioma": 0,
    "meningioma": 0,
    "pituitary": 0,
    "no_tumor": 0
}


In [10]:
# SECTION 6: Dataset Augmentation Process

start_time = time.time()

for class_name, aug_factor in classes.items():
    print(f"\nProcessing {class_name} class...")
    
    class_input_dir = os.path.join(original_dataset_path, class_name)
    class_output_dir = os.path.join(augmented_dataset_path, class_name)
    os.makedirs(class_output_dir, exist_ok=True)
    
    if os.path.exists(class_input_dir) and os.path.isdir(class_input_dir):
        # Copy original files to output dir
        for img_file in listdir(class_input_dir):
            img_path = os.path.join(class_input_dir, img_file)
            if os.path.isfile(img_path):
                img = cv2.imread(img_path)
                if img is not None:
                    cv2.imwrite(os.path.join(class_output_dir, img_file), img)
        
        # Perform augmentation
        augment_data(file_dir=class_input_dir, 
                     n_generated_samples=aug_factor, 
                     save_to_dir=class_output_dir)
    else:
        print(f"Warning: {class_input_dir} not found or not a directory")

end_time = time.time()
execution_time = (end_time - start_time)
print(f"\nElapsed time: {hms_string(execution_time)}")



Processing glioma class...

Processing meningioma class...

Processing pituitary class...

Processing no_tumor class...

Elapsed time: 0:15:13.8
