In [1]:
# Pin the release this notebook was run against and use %pip so the package is
# installed into the environment of the running kernel.
%pip install albumentations==2.0.2

Collecting albumentations
  Downloading albumentations-2.0.2-py3-none-any.whl.metadata (38 kB)
Collecting albucore==0.0.23 (from albumentations)
  Downloading albucore-0.0.23-py3-none-any.whl.metadata (5.3 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.23->albumentations)
  Downloading simsimd-6.2.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (66 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.0/66.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Downloading albumentations-2.0.2-py3-none-any.whl (278 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m278.2/278.2 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading albucore-0.0.23-py3-none-any.whl (14 kB)
Downloading simsimd-6.2.1-cp310-cp310-manylinux_2_28_x86_64.whl (632 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m632.7/632.7 kB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: simsimd, albucore, albumentations
  Att

In [2]:
import os
import shutil
from albumentations import (
    RandomBrightnessContrast,
    GaussianBlur,
    Affine,
    HueSaturationValue,
    Compose
)
from PIL import Image, ImageOps
import numpy as np
from zipfile import ZipFile
import logging

# Logging: send every record at INFO level or above to a file in the working
# directory. filemode='w' truncates the file when the handler is created.
# Note: logging.basicConfig does nothing if the root logger already has
# handlers, so re-running this cell in the same kernel does not reopen or
# truncate the log file.
log_file = "/kaggle/working/dataset_preparation.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename=log_file,
    filemode='w',
)

def log_and_print(message, level=logging.INFO):
    """Send ``message`` to the root logger and echo it on standard output.

    Args:
        message: Text to record and display.
        level: Numeric logging level used for the log record
            (defaults to ``logging.INFO``).
    """
    # Log first, then print, so both sinks receive the same text.
    logging.log(level, message)
    print(message)

# Augmentation pipeline shared by every call to augment_image. Each transform
# carries its own application probability ``p``.
augmentation_pipeline = Compose([
    RandomBrightnessContrast(p=0.5),
    GaussianBlur(
        blur_limit=(3, 5),
        p=0.3,
    ),
    Affine(
        scale=(0.95, 1.05),
        # NOTE(review): both endpoints of this range are 0.05. If the tuple is
        # treated as a (low, high) sampling interval, the shift is always +5%
        # rather than random within +/-5% - confirm against the Affine docs.
        translate_percent=(0.05, 0.05),
        rotate=0,
        p=0.7,
    ),
    HueSaturationValue(
        hue_shift_limit=10,
        sat_shift_limit=15,
        val_shift_limit=10,
        p=0.5,
    ),
])

# Count the images in each class folder and report the target size per class
def analyze_dataset(dataset_dir, target_size=7215):
    """Count the images in every class sub-directory of ``dataset_dir``.

    Only entries of ``dataset_dir`` that are directories are treated as
    classes. Within a class, a file is counted when its name ends in
    ``.jpg``, ``.png`` or ``.jpeg``; the match is case-sensitive, so a name
    such as ``IMG.JPG`` is not counted.

    Args:
        dataset_dir: Directory containing one sub-directory per class.
        target_size: Number of images every class should end up with. It is
            returned unchanged and is NOT derived from the observed counts.
            Defaults to 7215, the value previously hard-coded here.

    Returns:
        A ``(class_counts, target_size)`` tuple, where ``class_counts`` maps
        each class directory name to its image count.
    """
    image_extensions = ('.jpg', '.png', '.jpeg')
    class_counts = {}
    for class_name in os.listdir(dataset_dir):
        class_dir = os.path.join(dataset_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # stray files next to the class folders are not classes
        class_counts[class_name] = sum(
            1 for f in os.listdir(class_dir) if f.endswith(image_extensions)
        )
    return class_counts, target_size

# Perform image augmentation
def augment_image(image_path):
    """Load an image, run it through ``augmentation_pipeline`` and return the result.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A new PIL image built from the pipeline's ``'image'`` output.
    """
    # Apply the EXIF orientation (if any) before converting to an array.
    upright = ImageOps.exif_transpose(Image.open(image_path))
    pixels = np.array(upright)
    result = augmentation_pipeline(image=pixels)
    return Image.fromarray(result['image'])

# Balance dataset by augmenting images
def balance_dataset(input_dir, output_dir, max_class_size):
    """Copy each class folder to ``output_dir`` and pad small classes with augmented images.

    For every sub-directory of ``input_dir`` (one per class), the files whose
    names end in ``.jpg``, ``.png`` or ``.jpeg`` are copied unchanged into a
    directory of the same name under ``output_dir``. When a class has fewer
    than ``max_class_size`` images, augmented variants of its originals are
    generated - cycling through them in listing order - and saved as
    ``aug_<index>_<original name>`` until the output directory holds
    ``max_class_size`` files. Classes that already have ``max_class_size`` or
    more images are only copied; they are not down-sampled.

    Args:
        input_dir: Directory containing one sub-directory per class.
        output_dir: Destination root; class sub-directories are created as
            needed.
        max_class_size: Number of files each class directory should contain
            after augmentation.
    """
    image_extensions = ('.jpg', '.png', '.jpeg')

    for class_name in os.listdir(input_dir):
        class_dir = os.path.join(input_dir, class_name)
        # Check this before touching the output tree, so stray files in
        # input_dir do not produce empty "class" directories.
        if not os.path.isdir(class_dir):
            continue

        output_class_dir = os.path.join(output_dir, class_name)
        os.makedirs(output_class_dir, exist_ok=True)

        images = [f for f in os.listdir(class_dir) if f.endswith(image_extensions)]
        num_images = len(images)

        # Copy existing images
        for img_name in images:
            shutil.copy(os.path.join(class_dir, img_name), os.path.join(output_class_dir, img_name))

        # Augment images to reach the max_class_size
        if num_images < max_class_size:
            if not images:
                # With nothing to augment from, the class can never be filled;
                # report it instead of looping forever.
                log_and_print(
                    f"Class '{class_name}' has no source images; skipping augmentation.",
                    level=logging.WARNING,
                )
            else:
                log_and_print(f"Augmenting class '{class_name}' from {num_images} to {max_class_size} images.")
                # List the output directory once and count from there, rather
                # than re-listing it for every generated image.
                current_size = len(os.listdir(output_class_dir))
                for offset in range(max_class_size - current_size):
                    img_name = images[offset % num_images]  # cycle through the originals
                    augmented_image = augment_image(os.path.join(class_dir, img_name))
                    augmented_img_name = f"aug_{current_size + offset}_{img_name}"
                    augmented_image.save(os.path.join(output_class_dir, augmented_img_name))

        log_and_print(f"Class '{class_name}' balanced with {len(os.listdir(output_class_dir))} images.")

# Compress final dataset
def create_zip(output_dir, zip_path):
    """Pack every file under ``output_dir`` into a ZIP archive at ``zip_path``.

    Archive member names are paths relative to ``output_dir``. The archive is
    opened without a ``compression`` argument, so ``ZipFile``'s default
    applies.

    Args:
        output_dir: Root directory whose files are added recursively.
        zip_path: Path of the archive to create (opened in ``'w'`` mode).
    """
    log_and_print(f"Creating ZIP file at {zip_path}...")
    with ZipFile(zip_path, 'w') as archive:
        for folder, _subdirs, filenames in os.walk(output_dir):
            for filename in filenames:
                absolute_path = os.path.join(folder, filename)
                member_name = os.path.relpath(absolute_path, output_dir)
                archive.write(absolute_path, member_name)
    log_and_print(f"ZIP file created: {zip_path}")

# Main workflow
def main():
    """Run the whole preparation: analyze and balance both splits, then zip the output.

    The ``train`` and ``test`` sub-directories of the input dataset are each
    analyzed and balanced into matching sub-directories of the output
    directory, which is then archived.

    NOTE(review): the test split is augmented up to the same target size as
    the train split - confirm that this is intended.
    """
    INPUT_DIR = "/kaggle/input/fer-data"
    OUTPUT_DIR = "/kaggle/working/final_balanced_dataset"
    ZIP_FILE = "/kaggle/working/final_balanced_dataset.zip"

    train_source = os.path.join(INPUT_DIR, "train")
    train_target = os.path.join(OUTPUT_DIR, "train")
    test_source = os.path.join(INPUT_DIR, "test")
    test_target = os.path.join(OUTPUT_DIR, "test")

    # --- Train split ---
    log_and_print("Analyzing Train Dataset...")
    train_counts, train_max_size = analyze_dataset(train_source)
    log_and_print(f"Train class counts: {train_counts}, Max size: {train_max_size}")
    log_and_print("Balancing Train Dataset...")
    balance_dataset(train_source, train_target, train_max_size)

    # --- Test split ---
    log_and_print("Analyzing Test Dataset...")
    test_counts, test_max_size = analyze_dataset(test_source)
    log_and_print(f"Test class counts: {test_counts}, Max size: {test_max_size}")
    log_and_print("Balancing Test Dataset...")
    balance_dataset(test_source, test_target, test_max_size)

    # --- Archive ---
    log_and_print("Zipping final dataset...")
    create_zip(OUTPUT_DIR, ZIP_FILE)

    log_and_print("Dataset preparation complete.")

# Entry point: run the pipeline when this code is executed directly (for
# example as a notebook cell or a script), but not when it is imported.
if __name__ == "__main__":
    main()


Analyzing Train Dataset...
Train class counts: {'fearful': 2664, 'disgusted': 284, 'angry': 2728, 'neutral': 3597, 'sad': 3079, 'surprised': 2258, 'happy': 5058}, Max size: 7215
Balancing Train Dataset...
Augmenting class 'fearful' from 2664 to 7215 images.
Class 'fearful' balanced with 7215 images.
Augmenting class 'disgusted' from 284 to 7215 images.
Class 'disgusted' balanced with 7215 images.
Augmenting class 'angry' from 2728 to 7215 images.
Class 'angry' balanced with 7215 images.
Augmenting class 'neutral' from 3597 to 7215 images.
Class 'neutral' balanced with 7215 images.
Augmenting class 'sad' from 3079 to 7215 images.
Class 'sad' balanced with 7215 images.
Augmenting class 'surprised' from 2258 to 7215 images.
Class 'surprised' balanced with 7215 images.
Augmenting class 'happy' from 5058 to 7215 images.
Class 'happy' balanced with 7215 images.
Analyzing Test Dataset...
Test class counts: {'fearful': 667, 'disgusted': 71, 'angry': 683, 'neutral': 900, 'sad': 770, 'surprised'