# In [None]:
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import os
import cv2
# Define input and output directories.
# input_folder is scanned for one sub-directory per class; augmented images
# for a class are written to a sub-directory of the same name under
# output_folder.
input_folder = '/content/drive/MyDrive/JackFruit/DataSet/PreprocessedJackfruit'
output_folder = '/content/drive/MyDrive/JackFruit/DataSet/AugmentedJackfruit'
# Cap used by the augmentation functions below: a class's output folder is
# treated as full once it holds this many files.
target_count = 1500  # Total images needed per class

# CLAHE function
def apply_clahe(image):
    """Return a copy of ``image`` with CLAHE applied to its lightness channel.

    The image is converted to a numpy array. Arrays with exactly three
    channels are treated as RGB: they are converted to LAB, the L channel is
    equalised with CLAHE (clip limit 2.0, 8x8 tiles), and the result is
    converted back to RGB. Any other array shape is wrapped back into a PIL
    image without contrast changes.
    """
    pixels = np.array(image)

    # Guard clause: only 3-channel arrays go through the LAB round trip.
    has_three_channels = pixels.ndim == 3 and pixels.shape[2] == 3
    if not has_three_channels:
        return Image.fromarray(pixels)

    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB))
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # Only lightness is equalised; the two colour channels are merged back as-is.
    merged = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return Image.fromarray(cv2.cvtColor(merged, cv2.COLOR_LAB2RGB))

# Brightness and Contrast adjustments
def adjust_brightness(image, level='medium'):
    """Return ``image`` with its brightness scaled according to ``level``.

    ``'medium'`` applies a factor of 1.2; any other value applies 1.0, which
    leaves the brightness at its standard level.
    """
    if level == 'medium':
        factor = 1.2
    else:
        factor = 1.0
    return ImageEnhance.Brightness(image).enhance(factor)

def adjust_contrast(image, level='high'):
    """Return ``image`` with its contrast scaled according to ``level``.

    ``'high'`` applies a factor of 1.5; any other value applies 1.0, which
    leaves the contrast at its standard level.
    """
    if level == 'high':
        factor = 1.5
    else:
        factor = 1.0
    return ImageEnhance.Contrast(image).enhance(factor)

# Saturation adjustment
def adjust_saturation(image, factor):
    """Return ``image`` with its colour saturation scaled by ``factor``.

    For example, 0.5 lowers the saturation and 1.5 raises it.
    """
    color_enhancer = ImageEnhance.Color(image)
    adjusted = color_enhancer.enhance(factor)
    return adjusted

# Skew transformation
def add_skew(image, skew_factor):
    """Return a horizontally skewed copy of ``image``.

    The output canvas keeps the original height and is widened by
    ``abs(skew_factor) * width`` (rounded). The image is resampled with an
    affine transform whose coefficients are
    ``(1, skew_factor, x_offset, 0, 1, 0)`` using bicubic interpolation.
    """
    src_width, src_height = image.size
    shift = abs(skew_factor) * src_width
    canvas_size = (src_width + int(round(shift)), src_height)

    # Only a positive factor gets a compensating x offset; zero or negative
    # factors use an offset of 0.
    # NOTE(review): the shift is derived from the width although the skew term
    # multiplies the y coordinate -- confirm this is intended for non-square
    # images.
    x_offset = -shift if skew_factor > 0 else 0
    coefficients = (1, skew_factor, x_offset, 0, 1, 0)
    return image.transform(canvas_size, Image.AFFINE, coefficients, Image.BICUBIC)

# Shear transformation
def add_shear(image, shear_factor):
    """Return a horizontally sheared copy of ``image``.

    ``shear_factor`` is divided by 100 to obtain the shear coefficient ``m``.
    The output canvas keeps the original height and is widened by
    ``abs(m) * width`` (rounded). The image is resampled with an affine
    transform whose coefficients are ``(1, m, x_offset, 0, 1, 0)`` using
    bicubic interpolation.
    """
    src_width, src_height = image.size
    slope = shear_factor / 100.0
    shift = abs(slope) * src_width
    canvas_size = (src_width + int(round(shift)), src_height)

    # A positive slope gets a compensating x offset; otherwise the offset is 0.
    # NOTE(review): the shift is derived from the width although the shear term
    # multiplies the y coordinate -- confirm this is intended for non-square
    # images.
    x_offset = -shift if slope > 0 else 0
    coefficients = (1, slope, x_offset, 0, 1, 0)
    return image.transform(canvas_size, Image.AFFINE, coefficients, Image.BICUBIC)

def augment_and_save_image(input_image, output_path, image_filename, rotation_angles, flip_horizontal, shear_factors, skew_factors, saturation_factors):
    """Write augmented JPEG variants of one image into ``output_path``.

    Variants are produced in a fixed order -- rotations, horizontal flip,
    shears, skews, saturation changes, CLAHE, brightness, contrast -- and each
    one is saved as ``<base>_<suffix>.jpg``, where ``<base>`` is
    ``image_filename`` without its final extension.

    The number of variants written is limited to the free slots left in
    ``output_path`` (``target_count`` minus the files already present), so
    this call does not push the folder past ``target_count``.

    Args:
        input_image: PIL image to augment.
        output_path: Existing directory that receives the JPEG files.
        image_filename: Name of the source file; used to build output names.
        rotation_angles: Iterable of angles passed to ``Image.rotate``.
        flip_horizontal: When true, a mirrored variant is written.
        shear_factors: Iterable of factors passed to ``add_shear``.
        skew_factors: Iterable of factors passed to ``add_skew``.
        saturation_factors: Iterable of factors passed to
            ``adjust_saturation``.

    Returns:
        The number of files written for this image (0 if the folder was
        already full).
    """
    # The budget is per folder, not per image: a single image only ever
    # yields a handful of variants, so comparing the per-image counter with
    # target_count would never stop anything.
    remaining = target_count - len(os.listdir(output_path))
    if remaining <= 0:
        return 0  # Skip processing if the limit is reached

    # JPEG cannot store alpha or palette data; without this conversion PNG/GIF
    # inputs in modes such as RGBA or P fail on the first save and produce no
    # variants at all.
    if input_image.mode not in ('RGB', 'L', 'CMYK'):
        input_image = input_image.convert('RGB')

    base_filename = image_filename.rsplit('.', 1)[0]

    def _variants():
        """Lazily yield (filename suffix, augmented image) pairs in save order."""
        for angle in rotation_angles:
            yield f'rotated_{angle}', input_image.rotate(angle, expand=True)
        if flip_horizontal:
            yield 'hflip', ImageOps.mirror(input_image)
        for shear_factor in shear_factors:
            yield f'shear_{shear_factor}', add_shear(input_image, shear_factor)
        for skew_factor in skew_factors:
            yield f'skew_{skew_factor}', add_skew(input_image, skew_factor)
        for saturation_factor in saturation_factors:
            yield (f'saturation_{saturation_factor}',
                   adjust_saturation(input_image, saturation_factor))
        yield 'clahe', apply_clahe(input_image)
        yield 'brightness_medium', adjust_brightness(input_image, level='medium')
        yield 'contrast_high', adjust_contrast(input_image, level='high')

    images_saved = 0
    for suffix, variant in _variants():
        variant.save(os.path.join(output_path, f'{base_filename}_{suffix}.jpg'), 'JPEG')
        images_saved += 1
        # Checking after the save (remaining is at least 1 here) means the
        # generator is never advanced to build a variant that would be dropped.
        if images_saved >= remaining:
            break

    return images_saved

def augment_images_in_folder(input_path, output_path, rotation_angles, flip_horizontal, shear_factors, skew_factors, saturation_factors):
    """Augment every image file in ``input_path`` into ``output_path``.

    ``output_path`` is created if needed. Files whose names end in .png,
    .jpg, .jpeg, .bmp or .gif (case-insensitive) are opened and handed to
    ``augment_and_save_image``; other entries are ignored. Processing stops
    once this call has written at least ``target_count`` images. A failure on
    one file is reported with ``print`` and does not stop the remaining files.

    Args:
        input_path: Directory containing the source images of one class.
        output_path: Directory that receives the augmented images.
        rotation_angles: Forwarded to ``augment_and_save_image``.
        flip_horizontal: Forwarded to ``augment_and_save_image``.
        shear_factors: Forwarded to ``augment_and_save_image``.
        skew_factors: Forwarded to ``augment_and_save_image``.
        saturation_factors: Forwarded to ``augment_and_save_image``.
    """
    total_images_saved = 0
    os.makedirs(output_path, exist_ok=True)

    for image_filename in os.listdir(input_path):
        if not image_filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            continue
        if total_images_saved >= target_count:
            break  # Stop if the folder's limit is reached
        try:
            input_image_path = os.path.join(input_path, image_filename)
            # The context manager closes the underlying file once all variants
            # are written, instead of leaving one open handle per source image.
            with Image.open(input_image_path) as input_image:
                images_saved = augment_and_save_image(
                    input_image, output_path, image_filename,
                    rotation_angles, flip_horizontal,
                    shear_factors, skew_factors, saturation_factors
                )
            total_images_saved += images_saved
        except FileNotFoundError:
            print(f"File not found: {image_filename}")
        except Exception as e:
            # Best-effort batch job: report the bad file and keep going.
            print(f"An error occurred with {image_filename}: {e}")

# Walk the dataset root; every sub-directory is treated as one class
for class_name in os.listdir(input_folder):
    class_folder = os.path.join(input_folder, class_name)
    output_class_folder = os.path.join(output_folder, class_name)

    # Entries that are not directories are skipped
    if not os.path.isdir(class_folder):
        continue

    augment_images_in_folder(
        class_folder,
        output_class_folder,
        rotation_angles=[45, 60],
        flip_horizontal=True,
        shear_factors=[5, -5],
        skew_factors=[0.1, -0.1],
        saturation_factors=[0.5, 1.5],
    )

print("Data augmentation complete.")