# Animals10 Dataset Processing
---

## 1. Download Dataset from Kaggle

In [None]:
import kagglehub

# Kaggle handle of the Animals10 dataset, in "<owner>/<dataset>" form.
dataset_handle = "alessiocorrado99/animals10"

# Fetch the latest version and remember where the files ended up;
# the split cell further down reads from src_path.
src_path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", src_path)

---

## 2. Configure Data Directory

In [None]:
import os

# Base directory that will hold the processed dataset.
# Change this to your local path, or to your Google Drive mount path when
# running in Colab. Leaving DATA_DIR = "" keeps the data relative to the
# current working directory.
DATA_DIR = ""

# os.path.join ignores an empty base, so DATA_DIR = "" yields the relative
# path "animals10_dataset". Plain string formatting with a "/" separator
# would instead produce "/animals10_dataset" at the filesystem root.
dataset_path = os.path.join(DATA_DIR, "animals10_dataset")

---

## 3. Split Data into Train/Val/Test Sets

In [None]:
import splitfolders

# Share of each class that goes to train (70%), validation (20%), test (10%).
split_ratio = (0.7, 0.2, 0.1)

# Copy the downloaded images into train/val/test folders under dataset_path.
# A fixed seed keeps the split reproducible between runs.
# NOTE(review): splitfolders presumably expects one sub-folder per class
# directly under the input directory -- confirm src_path points at that level.
splitfolders.ratio(
    src_path,
    output=dataset_path,
    seed=42,
    ratio=split_ratio,
    group_prefix=None,
)

---

## 4. Preview Sample Images

In [None]:
from utils import display_sample_images 

# Visual sanity check of the split dataset rooted at dataset_path.
# display_sample_images is a project helper from utils (not shown here);
# per the original note it displays sample images from the train set.
display_sample_images(dataset_path) # Display sample images from the train set

---

## 5. Analyze Class Distribution

In [None]:
import os
from utils import display_distribution

# The train split is the "train" sub-folder of the dataset directory.
train_dir = os.path.join(dataset_path, "train")

# display_distribution is a project helper from utils (not shown here);
# per the original note it reports how many images each class holds.
display_distribution(train_dir) # Display the distribution of images across classes of the train set

---

## 6. Balance Classes with Data Augmentation

In [None]:
import os
import numpy as np
from tensorflow.keras.preprocessing.image import (
    ImageDataGenerator,
    array_to_img,
    img_to_array,
    load_img,
)
from collections import Counter
import shutil

# Paths: originals are read from train/, the balanced copy is written to
# train_augmented/ (one sub-folder per class in both).
train_dir = os.path.join(dataset_path, "train")
train_augmented_dir = os.path.join(dataset_path, "train_augmented")

# Create augmentation generator (random geometric transforms only)
augmenter = ImageDataGenerator(
    rotation_range=20,
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    horizontal_flip=True,
    fill_mode="nearest"
)

# Step 1: Count images per class (every entry of a class folder is counted)
class_counts = {}
for cls in os.listdir(train_dir):
    cls_path = os.path.join(train_dir, cls)
    if os.path.isdir(cls_path):
        class_counts[cls] = len(os.listdir(cls_path))

# Fail with a clear message instead of max()'s "empty sequence" error.
if not class_counts:
    raise ValueError(f"No class folders found in {train_dir!r}")

# Every class is topped up to the size of the largest one.
max_count = max(class_counts.values())
print(f"Target count per class: {max_count}")

# Step 2: Balance classes
for cls, count in class_counts.items():
    src = os.path.join(train_dir, cls)
    dst = os.path.join(train_augmented_dir, cls)
    os.makedirs(dst, exist_ok=True)

    # Copy originals
    img_files = os.listdir(src)
    for img_name in img_files:
        shutil.copy(os.path.join(src, img_name), os.path.join(dst, img_name))

    # How many augmented images needed
    needed = max_count - count
    if needed <= 0:
        continue  # Class already balanced

    # Nothing to sample from: np.random.choice would raise on an empty list.
    if not img_files:
        print(f"{cls}: no source images found, skipping augmentation")
        continue

    # Generate augmented images.
    # Each output comes from a freshly sampled source image, so the new
    # images are spread over the whole class rather than derived from a
    # single picture. Sequential file names guarantee that no output
    # overwrites another, so exactly `needed` files are written.
    gen_count = 0
    while gen_count < needed:
        img_name = np.random.choice(img_files)
        x = img_to_array(load_img(os.path.join(src, img_name)))
        augmented = augmenter.random_transform(x)

        out_name = f"aug_{gen_count:06d}.jpeg"
        array_to_img(augmented).save(os.path.join(dst, out_name))
        gen_count += 1

    print(f"{cls}: augmented {gen_count} images to reach {max_count}")

---

## 7. Verify Augmented Dataset

In [None]:
import os
from utils import display_distribution

# Balanced copy of the train split, written by the augmentation step.
train_augmented_dir = os.path.join(dataset_path, "train_augmented")

# Re-run the same distribution report on the augmented folder to check
# the result of the balancing step (helper comes from utils, not shown here).
display_distribution(train_augmented_dir) # Display the distribution of images across classes of the train_augmented set

---