# CSE 455 WEAPON CLASSIFICATION - DATA AUGMENTATION

# Necessary Libraries

# In [3]:
import tensorflow
# import numpy as np
# import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
# import cv2
import torch
import torchvision.transforms as transforms
from torchvision.transforms import ColorJitter, RandomGrayscale, Lambda, GaussianBlur, RandomPosterize
from PIL import Image
import matplotlib.pyplot as plt
import os

# Data Augmentation (Transformations)

# In [5]:
# Please change the location to the dataset we are using from the "Weapon-Classification/Dataset/images"
# Geometric augmentation: every image in image_dir is expanded into a fixed
# number of randomly rotated / shifted / sheared / zoomed / flipped variants,
# which Keras writes straight into output_dir.
generate_data = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Input and output directories
image_dir = "Dataset/small augment data"
output_dir = "Dataset/transformed augmented data"

# Number of augmented images written per input image
# (can be altered based on what we will need)
augmentations_per_image = 20

# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# Loop through images in the input directory
for f in os.listdir(image_dir):
    # Compare against the real extension (dot included) so that a name such as
    # "notajpg" is not mistaken for an image.
    if not f.lower().endswith(('.jpg', '.png', '.jpeg')):
        continue

    img_path = os.path.join(image_dir, f)  # Combine directory and filename
    try:
        # Load the image and preprocess
        img = load_img(img_path)
        x = img_to_array(img)          # Convert image to numpy array
        x = x.reshape((1,) + x.shape)  # Add batch dimension

        # flow() keeps yielding batches for as long as it is iterated and saves
        # one file to output_dir per batch drawn, so the loop must stop right
        # after the Nth draw (the previous `i > 20` check let a 21st image be
        # written before breaking).
        augmented_batches = generate_data.flow(
            x,
            batch_size=1,
            save_to_dir=output_dir,
            save_prefix='aug',
            save_format='jpeg',
        )
        for i, _ in enumerate(augmented_batches, start=1):
            if i >= augmentations_per_image:
                break
    except Exception as e:
        # Best effort: report the failing file and carry on with the rest.
        print(f"Error processing file {f}: {e}")


# Random Erasing based Augmentation

# In [7]:
# Please change the location to the dataset we are using from the "Weapon-Classification/Dataset/images"
# Random-erasing augmentation: each input image yields four copies, each with
# one randomly placed rectangle blanked out.
image_dir = "Dataset/small augment data"
output_dir = "Dataset/random erased augmented data"
os.makedirs(output_dir, exist_ok=True)

# transformation pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    # p=1 -> always erase; value=0 fills the erased rectangle with zeros
    transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
    transforms.ToPILImage()
])

# Augment images and save to output directory
for f in os.listdir(image_dir):
    # Compare against the real extension (dot included) so that a name such as
    # "notajpg" is not mistaken for an image.
    if not f.lower().endswith(('.jpg', '.png', '.jpeg')):
        continue

    img_path = os.path.join(image_dir, f)
    try:
        # The context manager closes the underlying file handle once the pixel
        # data has been copied out by convert().
        with Image.open(img_path) as source_img:
            # Always work in RGB so the result can be saved as JPEG.
            # (PS: Don't delete this or we will lose datasets that are png.)
            # Converting every mode, not only RGBA, also covers palette ('P')
            # and grey+alpha ('LA') PNGs, which otherwise lose their colours
            # or cannot be written as JPEG.
            img = source_img.convert('RGB')

        stem = os.path.splitext(f)[0]
        # save four different variation of images
        for i in range(4):
            random_erased_augmented_image = transform(img)
            output_path = os.path.join(output_dir, f"{stem}_{i+1}.jpeg")  # save in the output directory
            random_erased_augmented_image.save(output_path)
    except Exception as e:
        # Best effort: report the failing file and carry on with the rest.
        print(f"Error processing file {f}: {e}")

# Color Transformation based Augmentation

# In [9]:
# Please change the location to the dataset we are using from the "Weapon-Classification/Dataset/images"

# For Classification Tasks (Pre-trained Models)
# transform = transforms.Compose([
# transforms.RandomHorizontalFlip(),
# transforms.PILToTensor(),
# transforms.ConvertImageDtype(torch.float),
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
# transforms.RandomErasing(),
# ])

"""
PS: The pipeline above is for flipping and also removing parts of images from the augmented images, but for now it's just a normalized, color-transformed image.
Below are more ways to perform color transformation.
"""
# Color Jittering: It randomly changes the brightness, contrast, saturation, or hue of the image to simulate varying lighting conditions and camera settings.

# transform = transforms.Compose([
#     transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),  # Randomly adjust color properties
#     transforms.ToTensor(),
# ])

# Random Grayscale Conversion: It converts the image to grayscale with a given probability which makes the model more robust to images with reduced or missing color information.

# transform = transforms.Compose([
#     transforms.RandomGrayscale(p=0.2),  # 20% chance to convert to grayscale
#     transforms.ToTensor(),
# ])

# Gamma Correction: It adjusts the gamma of an image to make it appear brighter or darker to simulate overexposed or underexposed images.

# transform = transforms.Compose([
#     Lambda(lambda img: img.point(lambda x: x ** 0.8)),  # Apply gamma correction
#     transforms.ToTensor(),
# ])

# Hue Rotation: It rotates the hue channel of the image which simulates images taken under different light sources (e.g., daylight vs. fluorescent lighting).

# transform = transforms.Compose([
#     transforms.ColorJitter(hue=0.3),  # Rotate hue randomly
#     transforms.ToTensor(),
# ])

# Gaussian Blur: It applies a Gaussian blur filter to the image which simulates out-of-focus images or motion blur.

# transform = transforms.Compose([
#     transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0)),  # Randomly blur the image
#     transforms.ToTensor(),
# ])

# Posterization: It reduces the number of bits used for each color channel, creating a "posterized" effect that simulates compression artifacts or low-quality images.

# transform = transforms.Compose([
#     transforms.RandomPosterize(bits=4, p=0.5),  # Reduce to 4 bits with a 50% chance
#     transforms.ToTensor(),
# ])


import os
from PIL import Image
from torchvision import transforms

# Colour-normalisation augmentation: each input image is converted to a float
# tensor, normalised per channel, converted back to a PIL image and saved as
# one JPEG in output_dir.
image_dir = "Dataset/small augment data"
output_dir = "Dataset/augmented data"
os.makedirs(output_dir, exist_ok=True)

# transformation pipeline
transform = transforms.Compose([
    transforms.PILToTensor(),               # Convert image to tensor
    transforms.ConvertImageDtype(torch.float),  # Scale to [0, 1]
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Normalize
])

# Built once and reused for every image instead of being recreated per file.
to_pil_image = transforms.ToPILImage()

# Augment images and save to output directory
for f in os.listdir(image_dir):
    # Compare against the real extension (dot included) so that a name such as
    # "notajpg" is not mistaken for an image.
    if not f.lower().endswith(('.jpg', '.png', '.jpeg')):
        continue

    img_path = os.path.join(image_dir, f)
    try:
        # The context manager closes the underlying file handle once the pixel
        # data has been copied out by convert().
        with Image.open(img_path) as source_img:
            # Normalize above is given three-channel statistics and the result
            # is saved as JPEG, so every input is brought to RGB first (this
            # covers RGBA as before, plus greyscale and palette images).
            img = source_img.convert('RGB')

        augmented_img_tensor = transform(img)
        # NOTE(review): Normalize produces values outside [0, 1], while
        # ToPILImage scales float tensors by 255 and casts them to 8-bit, so
        # the saved colours are presumably wrapped/distorted. Confirm that
        # storing the normalised image (rather than normalising at training
        # time) is really what we want.
        augmented_img = to_pil_image(augmented_img_tensor)

        # Save the augmented image. Exactly one variant is produced per input,
        # so the suffix is a fixed "_aug_1"; the previous code read an index
        # `i` that this cell never defines.
        output_path = os.path.join(output_dir, f"{os.path.splitext(f)[0]}_aug_1.jpeg")
        augmented_img.save(output_path)
    except Exception as e:
        # Best effort: report the failing file and carry on with the rest.
        print(f"Error processing file {f}: {e}")
