In [1]:
import os
import shutil
from sklearn.model_selection import train_test_split
import torch
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF
from PIL import Image
from random import choice
from collections import Counter

# Import required for DataLoader
from torch.utils.data import DataLoader
from PIL import ImageFile

# Import required for mounting Google Drive (specific to Google Colab)
from google.colab import drive

# Mount Google Drive at /content/drive so that files stored there can be
# reached through ordinary file-system paths (specific to Google Colab).
drive.mount('/content/drive')

# Root directory of the image dataset on the mounted drive. The code below
# reads one sub-directory per class from here and creates the train/val/test
# directories inside it.
base_path = '/content/drive/My Drive/Colab Notebooks/public-data/image/dataset'

def analyze_dataset(path):
    """Count the image files in each immediate sub-directory of ``path``.

    Only entries of ``path`` that are directories are considered; inside each
    one, names ending in .png, .jpg or .jpeg (case-insensitive) are counted.

    Returns:
        A ``Counter`` mapping each sub-directory name to its image count.
        Directories without any matching file are present with a count of 0.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    class_counts = Counter()
    for entry in os.listdir(path):
        folder = os.path.join(path, entry)
        if not os.path.isdir(folder):
            # Loose files next to the class directories are not classes.
            continue
        class_counts[entry] = sum(
            1 for name in os.listdir(folder) if name.lower().endswith(image_suffixes)
        )
    return class_counts

def random_transform(image):
    """Return a randomly flipped and rotated RGB copy of an image.

    The image is converted to RGB, mirrored horizontally with probability
    0.5, mirrored vertically with probability 0.5, and finally rotated by a
    uniformly drawn whole number of degrees in the range [-30, 30].

    Args:
        image: An image object providing ``convert`` (in this file, the
            result of ``PIL.Image.open``).

    Returns:
        The transformed image.
    """
    image = image.convert("RGB")
    if torch.rand(1).item() > 0.5:
        image = TF.hflip(image)
    if torch.rand(1).item() > 0.5:
        image = TF.vflip(image)
    # torch.randint excludes its upper bound, so 31 is required for +30
    # degrees to be reachable and the range to be symmetric around zero.
    angle = torch.randint(-30, 31, (1,)).item()
    image = TF.rotate(image, angle)
    return image

def balance_dataset(path, class_counts, max_per_class):
    """Oversample each class directory until it holds ``max_per_class`` images.

    For every class named in ``class_counts`` the images currently in
    ``<path>/<class>`` are listed, and randomly transformed copies (see
    ``random_transform``) are saved next to them as
    ``aug_<index>_<source name>`` until the directory contains
    ``max_per_class`` images. Classes that already have at least that many
    images are left untouched.

    Args:
        path: Directory containing one sub-directory per class.
        class_counts: Mapping whose keys are the class directory names. The
            counts themselves are not read; each directory is listed again.
        max_per_class: Target number of images per class.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    for class_dir in class_counts:
        class_path = os.path.join(path, class_dir)
        # Snapshot of the images present before augmenting. Copies are always
        # derived from this list, never from files created below, so that
        # transformations do not compound and "aug_" prefixes do not stack.
        sources = [img for img in os.listdir(class_path) if img.lower().endswith(image_suffixes)]
        if not sources:
            # Nothing to copy from; choice() would raise IndexError here.
            if max_per_class > 0:
                print(f"No images found in class {class_dir} directory; cannot augment.")
            continue

        existing = set(sources)
        index = len(existing)
        while len(existing) < max_per_class:
            img_to_copy = choice(sources)
            new_img_name = f"aug_{index}_{img_to_copy}"
            index += 1
            if new_img_name in existing:
                # Never overwrite a file left behind by an earlier run; the
                # next iteration tries again with a fresh index.
                continue
            with Image.open(os.path.join(class_path, img_to_copy)) as img:
                new_img = random_transform(img)
                new_img.save(os.path.join(class_path, new_img_name))
            existing.add(new_img_name)

classes = ['0', '1', '2', '3']  # List of class names (sub-directories of base_path)

# Define paths for train, validation, and test sets
train_path = os.path.join(base_path, 'train')
val_path = os.path.join(base_path, 'val')
test_path = os.path.join(base_path, 'test')

# Create directories for train, validation, and test sets
for _class in classes:
    for split_path in (train_path, val_path, test_path):
        os.makedirs(os.path.join(split_path, _class), exist_ok=True)

train_size = 0.8  # 80% for training; the remainder is halved into val and test

# Split and move images.
# The split is done BEFORE any augmentation: if augmented copies were created
# first, transformed versions of the same source image could land in both the
# training set and the validation/test sets and inflate the evaluation scores.
for _class in classes:
    class_dir = os.path.join(base_path, _class)
    if not os.path.isdir(class_dir):
        print(f"No images found in class {_class} directory.")
        continue

    # os.listdir order is arbitrary; sorting makes the seeded split reproducible.
    images = sorted(img for img in os.listdir(class_dir) if img.lower().endswith(('.png', '.jpg', '.jpeg')))

    if not images:
        print(f"No images found in class {_class} directory.")
        continue

    try:
        train_imgs, non_train_imgs = train_test_split(images, test_size=1 - train_size, random_state=42)
        # train_test_split returns (first part, second part); with
        # test_size=0.5 the held-out images are halved into test and val.
        test_imgs, val_imgs = train_test_split(non_train_imgs, test_size=0.5, random_state=42)
    except ValueError as err:
        # Raised by train_test_split when one side of a split would be empty.
        print(f"Skipping class {_class}: too few images to split ({err}).")
        continue

    for split_path, split_imgs in ((train_path, train_imgs), (val_path, val_imgs), (test_path, test_imgs)):
        for img in split_imgs:
            shutil.move(os.path.join(class_dir, img), os.path.join(split_path, _class, img))

# Analyze and balance the training set only; validation and test keep
# nothing but original images.
class_counts = analyze_dataset(train_path)
max_per_class = max(class_counts.values(), default=0)
# Classes without any training image have nothing to augment from, so they
# are not passed on to balance_dataset.
balance_dataset(
    train_path,
    {name: count for name, count in class_counts.items() if count > 0},
    max_per_class,
)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load datasets
train_dataset = datasets.ImageFolder(train_path, transform=transform)
val_dataset = datasets.ImageFolder(val_path, transform=transform)
test_dataset = datasets.ImageFolder(test_path, transform=transform)

# Data loaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=20, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)


Mounted at /content/drive
