# In [2]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Root folders of the source images; both are read with datasets.ImageFolder
# below, so each is expected to contain one sub-folder per class.
train_dir, val_dir = 'data/train', 'data/val'

# Every pipeline ends by resizing to this fixed (height, width).
_IMAGE_SIZE = (640, 640)

# Deterministic pipeline (no augmentation): resize, then convert the PIL image
# to a float tensor. Used for the plain training copy.
train_transform = transforms.Compose(
    [transforms.Resize(_IMAGE_SIZE), transforms.ToTensor()]
)

# Random augmentation pipeline; the random parameters are re-drawn each time
# the transform is called (i.e. on every dataset access).
train_aug_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),  # default p=0.5
        # NOTE: rotates by an angle drawn uniformly from [-90, 90] degrees,
        # not only by exact quarter turns; uncovered corners are filled black.
        transforms.RandomRotation((-90, 90)),
        # Brightness ("exposure"), saturation and hue jitter.
        transforms.ColorJitter(brightness=0.1, saturation=0.25, hue=0.15),
        transforms.RandomGrayscale(),  # default p=0.1
        # Always applied; sigma drawn uniformly from [0.1, 2.5].
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.5)),
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: same deterministic resize + tensor conversion.
val_transform = transforms.Compose(
    [transforms.Resize(_IMAGE_SIZE), transforms.ToTensor()]
)

# Base datasets. They are NOT untransformed: each sample is resized to 640x640
# and converted to a tensor by train_transform / val_transform.
train_dataset, val_dataset = (
    datasets.ImageFolder(train_dir, transform=train_transform),
    datasets.ImageFolder(val_dir, transform=val_transform),
)

# Create an oversampled augmented training dataset.
# ImageFolder applies its transform lazily inside __getitem__, so repeating a
# single dataset instance N times already yields a freshly drawn random
# augmentation for every index. Building N separate ImageFolder objects would
# re-scan and re-sort the whole training directory N times for no benefit.
num_augmentations_per_image = 5  # Number of augmented copies per source image

augmented_dataset = datasets.ImageFolder(root=train_dir, transform=train_aug_transform)
oversampled_train_dataset = torch.utils.data.ConcatDataset(
    [augmented_dataset] * num_augmentations_per_image
)

# Combine the original (resize-only) training images with the augmented copies.
combined_train_dataset = torch.utils.data.ConcatDataset([train_dataset, oversampled_train_dataset])

# One sample per batch. Both training loaders reshuffle every epoch; the
# validation loader keeps the dataset's own order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=1)
train_loader_aug = DataLoader(dataset=combined_train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=1)

# Make sure the export folders exist; already-existing folders are fine.
for _export_dir in ('./datasets/images/train/barrel', './datasets/images/val/barrel'):
    os.makedirs(_export_dir, exist_ok=True)

# Export the resized (non-augmented) train and validation images as JPEG.
# A single tensor -> PIL converter is created once and reused for every image.
_to_pil = transforms.ToPILImage()

# Training split: file names are image_<batch index>_<position in batch>.jpg.
for batch_idx, (batch, _labels) in enumerate(train_loader):
    for pos, tensor in enumerate(batch):
        _to_pil(tensor).save(f'./datasets/images/train/barrel/image_{batch_idx}_{pos}.jpg')

# Validation split, same naming scheme.
for batch_idx, (batch, _labels) in enumerate(val_loader):
    for pos, tensor in enumerate(batch):
        _to_pil(tensor).save(f'./datasets/images/val/barrel/image_{batch_idx}_{pos}.jpg')

# Save the augmented training images.
# Iterate oversampled_train_dataset (augmented copies only), NOT
# train_loader_aug: that loader draws from combined_train_dataset, which also
# contains the plain train_dataset. Those originals are already written as
# image_*.jpg by the train export loop, so exporting train_loader_aug would
# store every original a second time under an aug_image_* name.
_to_pil_image = transforms.ToPILImage()
for i, (image, _) in enumerate(oversampled_train_dataset):
    # Each access re-runs train_aug_transform, so every index is a new random
    # augmentation of its source image.
    _to_pil_image(image).save(f'./datasets/images/train/barrel/aug_image_{i}.jpg')