In [4]:
import os
import torch
import torchvision
from torchvision import transforms, datasets, models
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np


## Define Transforms

In [5]:
input_size = 224

# ImageNet per-channel (RGB) mean and standard deviation. Both pipelines below
# must normalize identically, so the values are defined once here instead of
# being repeated in each Compose.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize -> random flip (augmentation) -> tensor -> normalize.
train_transforms = transforms.Compose([
    # Resize every image to input_size x input_size. A fixed square resize does
    # not preserve the aspect ratio; consider CenterCrop or Pad if it must be
    # preserved.
    transforms.Resize((input_size, input_size)),
    # Randomly flip the image horizontally (data augmentation, improves
    # generalization). Applied to the training split only.
    transforms.RandomHorizontalFlip(),
    # Convert the image to a PyTorch tensor.
    transforms.ToTensor(),
    # Normalize each RGB channel with the ImageNet mean/std.
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Validation/test pipeline: the same steps as training without the random
# flip, so evaluation does not depend on random augmentation.
val_test_transforms = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])


To prepare the images for training a lightweight pretrained CNN (like MobileNetV2), we apply the following transformation pipelines using `torchvision.transforms`:

### 🔁 Training Transforms:
- *Resizes all images to `224x224` pixels*: Matches the expected input size of most pretrained models.  
  *(Note: Resizing to a fixed square distorts the aspect ratio; use `CenterCrop` or `Pad` instead if you want to preserve it.)*
- *Randomly flips images horizontally*: This is a simple data augmentation technique that improves generalization by introducing more variation into the training set.
- *Converts the PIL image to a PyTorch tensor*: Reorders the layout from (H x W x C) to (C x H x W) and scales pixel values to the range [0.0, 1.0].
- *Applies channel-wise normalization using ImageNet's mean and standard deviation*: Ensures pixel distributions align with model expectations.

### 📏 Validation & Test Transforms:
Same as training transforms **without** data augmentation. These transformations ensure evaluation is consistent and reproducible.

## Load Dataset

In [6]:
data_dir = "../dataset/split"
# Single batch size shared by all three loaders (previously repeated inline).
batch_size = 32

# Load datasets: one ImageFolder per split directory under data_dir.
# Only the training split uses the augmenting transform pipeline.
train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_transforms)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_test_transforms)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=val_test_transforms)

# ImageFolder derives the label index of each class from the sub-folder names
# of its own root directory. If val/ or test/ is missing a class folder (or has
# an extra one), its label indices no longer line up with the training labels
# and every metric computed on that split is silently wrong. Fail early instead.
for split_name, split_dataset in (("val", val_dataset), ("test", test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"Class folders in '{split_name}' do not match 'train': "
            f"{split_dataset.classes} vs {train_dataset.classes}"
        )

# Create data loaders. Only the training loader shuffles; validation and test
# keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Class names in label-index order, taken from the training split.
class_names = train_dataset.classes
