In [5]:
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [6]:
# Image pre-processing. Both pipelines resize to IMG_SIZE x IMG_SIZE, normalize with the
# ImageNet channel statistics and produce a float CHW torch tensor.
IMG_SIZE = 224  # common input size for ImageNet-pretrained models
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Basic transform for validation (resize, normalize) -- torchvision, operates on PIL images.
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),  # PIL RGB image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


class PILAlbumentations:
    """Adapt an albumentations pipeline to the torchvision ``transform(img)`` convention.

    torchvision datasets such as ``ImageFolder`` call ``transform(pil_image)`` with a single
    positional argument, whereas ``A.Compose`` only accepts keyword data
    (``pipeline(image=np_array)``) and returns a dict. This wrapper bridges the two.

    Args:
        pipeline: the albumentations pipeline (e.g. an ``A.Compose``) to wrap.

    Calling conventions:
        ``wrapper(img)``        -> converts a PIL image (or array) to a numpy array, runs the
                                   pipeline and returns only the ``"image"`` result.
        ``wrapper(image=arr)``  -> forwarded unchanged to the pipeline; returns its dict.
    """

    def __init__(self, pipeline):
        self.pipeline = pipeline

    def __call__(self, img=None, **data):
        if img is None:
            # Keyword-style call: behave exactly like the wrapped pipeline.
            return self.pipeline(**data)
        # torchvision-style positional call: PIL image -> HWC uint8 array -> transformed image.
        if hasattr(img, "convert"):
            img = img.convert("RGB")
        return self.pipeline(image=np.array(img), **data)["image"]

    def __repr__(self):
        return f"{type(self).__name__}({self.pipeline!r})"


# Augmentation for training (add variety) -- albumentations, operates on numpy arrays.
# NOTE(review): torchvision's Resize (PIL) and A.Resize (OpenCV) are different
# implementations, so train and val images are not resized identically; consider using
# one library for both pipelines.
train_transform = PILAlbumentations(A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),  # Flip horizontally
    A.Rotate(limit=20, p=0.5),  # Rotate slightly (up to +/-20 degrees)
    A.RandomBrightnessContrast(p=0.2),  # Adjust brightness/contrast
    # Default max_pixel_value=255 assumes 0-255 pixel values, which the adapter provides.
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),  # HWC numpy array -> CHW torch tensor
]))

In [4]:
# Build the datasets and loaders. ImageFolder derives labels from the sub-directory names of
# each root, so train/ and valid/ must contain exactly the same class folders.
DATA_DIR = Path("data")
BATCH_SIZE = 32
SEED = 42

# NOTE: the training set currently uses the deterministic val_transform (no augmentation).
# An albumentations A.Compose cannot be passed here as-is: ImageFolder calls
# transform(pil_img) positionally, while A.Compose only accepts keyword data
# (transform(image=np_array)) and returns a dict, so it needs a PIL -> numpy adapter first.
train_dataset = datasets.ImageFolder(DATA_DIR / "train", transform=val_transform)
val_dataset = datasets.ImageFolder(DATA_DIR / "valid", transform=val_transform)

# class_to_idx is built independently for each root; if the folders differ, the same label
# index would silently mean different classes in train vs. validation.
assert train_dataset.class_to_idx == val_dataset.class_to_idx, (
    "train/ and valid/ class folders differ: "
    f"{sorted(set(train_dataset.classes) ^ set(val_dataset.classes))}"
)

# A dedicated seeded generator makes the training shuffle order reproducible across runs.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

num_classes = len(train_dataset.classes)
print("Num classes:", num_classes)

Num classes: 38
