In [None]:
import os
import random

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

# Reproducibility: give Python's, NumPy's and PyTorch's global random number
# generators the same fixed seed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Choose the compute backend in priority order: CUDA, then MPS, then CPU.
# torch.backends may not expose an `mps` attribute, so look it up defensively.
mpsBackend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    deviceName = "cuda"
elif mpsBackend is not None and mpsBackend.is_available():
    deviceName = "mps"
else:
    deviceName = "cpu"
device = torch.device(deviceName)

print("Using device:", device)

In [None]:
# Ask for the dataset location; an empty answer selects the default "data" folder.
dataDir = input("Enter data folder path or press enter to use default 'data': ").strip() or "data"

# input() returns the typed text verbatim, so a leading "~" is not expanded by
# any shell. Expand it here, otherwise "~/datasets/x" fails the isdir check.
dataDir = os.path.expanduser(dataDir)

# Fail early with a clear message rather than later inside the dataset loader.
if not os.path.isdir(dataDir):
    raise FileNotFoundError(f"Data folder not found at: {dataDir}")

print("Data folder:", dataDir)

In [None]:
from torchvision.transforms import Compose, Resize, RandomHorizontalFlip, RandomRotation, ColorJitter, RandomAffine, ToTensor, Normalize

# Values shared by both pipelines. The mean/std triple is the one commonly
# paired with ImageNet-pretrained models (presumably what is used downstream).
imageSize = (224, 224)
channelMean = [0.485, 0.456, 0.406]
channelStd = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation, then tensor + normalize.
trainTransform = Compose([
    Resize(imageSize),
    RandomHorizontalFlip(),
    RandomRotation(10),
    ColorJitter(brightness=0.15, contrast=0.15),
    RandomAffine(degrees=0, translate=(0.05, 0.05)),
    ToTensor(),
    Normalize(mean=channelMean, std=channelStd),
])

# Evaluation pipeline: the same resize and normalization, with no randomness.
valTransform = Compose([
    Resize(imageSize),
    ToTensor(),
    Normalize(mean=channelMean, std=channelStd),
])

print("Transforms defined")

In [None]:
# Two views of the same image folder: one applies the augmenting training
# transform, the other the deterministic evaluation transform. They must be
# separate objects: every Subset returned by random_split points at the SAME
# underlying dataset, so reassigning `subset.dataset.transform` would also
# replace the training transform and silently switch augmentation off.
fullDataset = datasets.ImageFolder(root=dataDir, transform=trainTransform)
evalDataset = datasets.ImageFolder(root=dataDir, transform=valTransform)

# The split indices are shared between the two views, so both scans of the
# folder must have produced the same sample list in the same order.
if evalDataset.samples != fullDataset.samples:
    raise RuntimeError("Data folder changed while it was being scanned; re-run this cell")

datasetSize = len(fullDataset)
print("Total images found:", datasetSize)
print("Class to index mapping:", fullDataset.class_to_idx)

# train val test split 70 15 15 (test takes the rounding remainder)
trainSize = int(0.7 * datasetSize)
valSize = int(0.15 * datasetSize)
testSize = datasetSize - trainSize - valSize

trainDataset, valSplit, testSplit = random_split(fullDataset, [trainSize, valSize, testSize])

# Re-point the validation and test indices at the evaluation view so those
# subsets use valTransform while trainDataset keeps trainTransform.
valDataset = torch.utils.data.Subset(evalDataset, valSplit.indices)
testDataset = torch.utils.data.Subset(evalDataset, testSplit.indices)

print(f"Split sizes  train {len(trainDataset)}  val {len(valDataset)}  test {len(testDataset)}")

In [None]:
# Read the batch size interactively; empty, unparsable, or non-positive answers
# all fall back to the default instead of failing later in DataLoader.
defaultBatchSize = 16
batchInput = input("Enter batch size or press enter to use default 16: ").strip()
try:
    batchSize = int(batchInput) if batchInput else defaultBatchSize
except ValueError:
    print(f"Could not parse batch size {batchInput!r}; using default {defaultBatchSize}")
    batchSize = defaultBatchSize

# "0" or "-4" parse fine as integers, but DataLoader rejects a non-positive
# batch size, so treat them like any other invalid answer.
if batchSize <= 0:
    print(f"Batch size must be positive, got {batchSize}; using default {defaultBatchSize}")
    batchSize = defaultBatchSize

# on personal M3 use 0 workers to avoid multiprocess issues
numWorkers = 0

# Only the training loader shuffles; validation and test keep a fixed order.
trainLoader = DataLoader(trainDataset, batch_size=batchSize, shuffle=True, num_workers=numWorkers)
valLoader = DataLoader(valDataset, batch_size=batchSize, shuffle=False, num_workers=numWorkers)
testLoader = DataLoader(testDataset, batch_size=batchSize, shuffle=False, num_workers=numWorkers)

print(f"DataLoaders ready  batch size {batchSize}  workers {numWorkers}")

In [None]:
def unnormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo channel-wise normalization so an image tensor can be displayed.

    Args:
        tensor: image tensor laid out as (C, H, W). It may live on any device
            and may be part of an autograd graph.
        mean: per-channel mean subtracted during normalization. Defaults to
            the values used by the transforms in this notebook.
        std: per-channel standard deviation divided out during normalization.
            Defaults to the values used by the transforms in this notebook.

    Returns:
        NumPy array of shape (H, W, C) with values clipped to [0, 1].
    """
    channelMean = np.asarray(mean, dtype=np.float64)
    channelStd = np.asarray(std, dtype=np.float64)
    # detach() first: .numpy() raises on tensors that require grad.
    # transpose moves channels last, the layout imshow expects.
    img = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    img = img * channelStd + channelMean
    # Augmentation and rounding can push values slightly outside [0, 1].
    return np.clip(img, 0, 1)

def showBatch(loader, classes, n=6):
    """Display up to ``n`` images from the first batch of ``loader`` in one row.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches; only the first
            batch is consumed.
        classes: sequence mapping a label index to its class name, used for
            the subplot titles.
        n: number of columns in the row and the maximum number of images
            drawn; fewer are drawn when the batch is smaller.
    """
    batchImages, batchLabels = next(iter(loader))
    figure = plt.figure(figsize=(12, 6))
    count = min(n, batchImages.size(0))
    for position in range(1, count + 1):
        # Subplot positions are 1-based, tensor indices 0-based.
        panel = figure.add_subplot(1, n, position)
        panel.imshow(unnormalize(batchImages[position - 1]))
        panel.set_title(classes[batchLabels[position - 1]])
        panel.axis("off")
    plt.show()

# Preview: list the class names, then draw six images from the training loader.
classNames = fullDataset.classes
print("Class list", classNames)
showBatch(trainLoader, classNames, n=6)

In [None]:
# Persist the class names, one per line in class-index order, under models/.
os.makedirs("models", exist_ok=True)
with open("models/class_names.txt", "w") as classFile:
    classFile.writelines(cls + "\n" for cls in fullDataset.classes)

print("Saved class list to models/class_names.txt")