# Cell 1 - Imports + Config

In [14]:
import os
import glob
import random


import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter


# ---- CONFIG ----

# Folder with the images (files named like binder_0000041.JPEG).
# Set the IMAGENET_SUBSET_DIR environment variable to use a different copy of the
# data; when it is unset, the original local path below is used.
data_root = os.environ.get(
    "IMAGENET_SUBSET_DIR",
    "/Users/dominicschlegel/PycharmProjects/LearningByDoing/data/ImageNetSubset",
)

batch_size    = 32
num_epochs    = 5
learning_rate = 1e-3

# Seed Python's `random` module and PyTorch up front so that a run started from a
# fresh kernel draws the same random numbers from both libraries.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# Device preference: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)
print("Data root:", data_root)



Using device: mps
Data root: /Users/dominicschlegel/PycharmProjects/LearningByDoing/data/ImageNetSubset


# Cell 2 - Transformations, Dataset, DataLoader

In [15]:
# Cell 2 (clean version with train/val split)

import glob
import random

# 1) Collect paths
# Every file matching *.JPEG directly inside data_root, as a sorted list.
jpeg_pattern = os.path.join(data_root, "*.JPEG")
all_image_paths = sorted(glob.glob(jpeg_pattern))
if not all_image_paths:
    raise RuntimeError("No .JPEG images found in your data folder.")

# 2) Extract class names (classname_index.JPEG)
def extract_classname(path):
    """Return the class name encoded in an image path named ``<classname>_<index>.<ext>``.

    The directory part and everything from the first dot of the file name onward
    are dropped. The class name is the text before the LAST underscore, so class
    names that themselves contain underscores (``sea_lion_0000007.JPEG`` ->
    ``sea_lion``) are kept whole.

    Args:
        path: File path or bare file name (str).

    Returns:
        The class name (str). If the file name has no underscore, the whole
        name up to the first dot is returned.
    """
    filename = os.path.basename(path)
    stem = filename.split(".")[0]
    # rpartition splits on the last "_"; `separator` is "" when no underscore exists.
    classname, separator, _ = stem.rpartition("_")
    return classname if separator else stem

# 3) Build class maps
# Distinct class names in sorted order; a name's position in that list is its label.
classnames = sorted(set(map(extract_classname, all_image_paths)))
class_to_idx = dict(zip(classnames, range(len(classnames))))

num_classes = len(classnames)

# 4) Train/val split
# Shuffle with a dedicated fixed-seed generator instead of the global `random`
# state. The list arrives freshly sorted from step 1 of this cell, so re-running
# the cell puts the same files in the validation set every time, which keeps
# validation metrics comparable between runs.
split_seed = 42
random.Random(split_seed).shuffle(all_image_paths)

val_split = 0.2  # fraction of all images held out for validation
num_val = int(len(all_image_paths) * val_split)

# The first num_val shuffled paths form the validation set; the rest are for training.
val_paths = all_image_paths[:num_val]
train_paths = all_image_paths[num_val:]

print("Train:", len(train_paths))
print("Val:", len(val_paths))

# 5) Transforms
# All three pipelines resize to 128x128 and finish with ToTensor().

# Validation: deterministic, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor(),
])

# Training, light variant: only a random left-right flip.
train_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

# Training, augmented variant: flip, small rotation and colour jitter.
train_transform_augmentation = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),      # flip left-right
    transforms.RandomRotation(degrees=10),       # rotate +/- 10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
])

# 6) Dataset class
class SimpleImageDataset(Dataset):
    """Dataset that yields ``(transformed_image, label)`` pairs for a list of image paths.

    The label of an item is looked up from its file name through the module-level
    ``extract_classname`` function and ``class_to_idx`` mapping.

    Args:
        paths: Sequence of image file paths.
        transform: Callable applied to the RGB PIL image of every item.
    """

    def __init__(self, paths, transform):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at ``idx``; return ``(img, label)``."""
        p = self.paths[idx]
        # The context manager closes the file handle even when decoding fails;
        # convert() returns a new image that stays usable after the block exits.
        with Image.open(p) as raw_img:
            img = raw_img.convert("RGB")
        img = self.transform(img)

        classname = extract_classname(p)
        label = class_to_idx[classname]
        return img, label

# 7) Create datasets + loaders
# Training images go through the augmented pipeline, validation images through the plain one.
train_dataset = SimpleImageDataset(train_paths, transform=train_transform_augmentation)
val_dataset = SimpleImageDataset(val_paths, transform=val_transform)

# Both loaders share batch size and worker count; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)



Train: 10800
Val: 2700


# Define Model

In [16]:
class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer linear classifier.

    The classifier's first Linear layer expects 128 * 16 * 16 features, which is
    what the feature extractor produces for [B, 3, 128, 128] inputs.

    Args:
        num_classes: Size of the output (logit) dimension.
    """

    def __init__(self, num_classes):
        super().__init__()

        # 1) Convolutional "feature extractor".
        # Each stage keeps the spatial size in the conv (3x3, padding 1) and halves
        # it in the pool: 128 -> 64 -> 32 -> 16, while channels go 3 -> 32 -> 64 -> 128.
        channel_plan = [3, 32, 64, 128]
        stages = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        self.features = nn.Sequential(*stages)

        # 2) Linear "classifier" head: [B,128,16,16] -> [B,32768] -> [B,256] -> [B,num_classes]
        flat_features = 128 * 16 * 16
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.classifier(self.features(x))

# Build the network with one output per discovered class, then move it to the
# selected device (MPS / CUDA / CPU).
num_classes = len(class_to_idx)
model = SimpleCNN(num_classes)
model = model.to(device)
print(model)

SimpleCNN(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=32768, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=10, bias=True)
  )
)


# Loss function + optimizer


In [17]:
# Cell 4: loss function + optimizer

# Cross-entropy: the standard loss for multi-class classification.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters, using the learning rate from the config cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

print("Loss function:", criterion)
print("Optimizer:", optimizer)


Loss function: CrossEntropyLoss()
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)


# Debug cell

In [18]:
# Debug cell: take one batch and run it through the model

# Fetch a single training batch and report its shapes before moving it to the device.
first_batch = next(iter(train_loader))
imgs, labels = first_batch
print("Batch imgs shape:", imgs.shape)     # expected: [batch_size, 3, 128, 128]
print("Batch labels shape:", labels.shape) # expected: [batch_size]

imgs, labels = imgs.to(device), labels.to(device)

# Forward pass: one row of logits per image.
outputs = model(imgs)
print("Outputs shape:", outputs.shape)     # expected: [batch_size, num_classes]

# Loss on this single batch.
loss = criterion(outputs, labels)
print("Loss:", loss.item())


Batch imgs shape: torch.Size([32, 3, 128, 128])
Batch labels shape: torch.Size([32])
Outputs shape: torch.Size([32, 10])
Loss: 2.29620361328125


# Validation Loop


In [19]:
# Cell 5: validation loop

def validate(model, val_loader, loss_fn=None, eval_device=None):
    """Evaluate ``model`` on every batch of ``val_loader`` without tracking gradients.

    Args:
        model: Module called as ``model(imgs)``. It is switched to eval mode and
            left in that mode when the function returns.
        val_loader: Iterable of ``(imgs, labels)`` tensor batches.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor,
            assumed to be the mean loss over the batch. Defaults to the
            module-level ``criterion``.
        eval_device: Device the batches are moved to. Defaults to the
            module-level ``device``.

    Returns:
        ``(val_loss, val_acc)``: the sample-weighted mean loss and the fraction
        of correctly classified samples.

    Raises:
        ValueError: If ``val_loader`` yields no samples (previously an opaque
            ``ZeroDivisionError``).
    """
    if loss_fn is None:
        loss_fn = criterion
    if eval_device is None:
        eval_device = device

    model.eval()  # layers such as dropout/batchnorm switch to inference behaviour
    total = 0
    correct = 0
    running_loss = 0.0

    with torch.no_grad():  # disable gradients
        for imgs, labels in val_loader:
            imgs = imgs.to(eval_device)
            labels = labels.to(eval_device)

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            # Weight the batch mean by batch size so a smaller final batch is
            # not over-represented in the overall mean.
            running_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("validate() received no samples: val_loader is empty.")

    val_loss = running_loss / total
    val_acc = correct / total
    return val_loss, val_acc


# Training Loop

In [20]:
# Cell 6: training loop with validation + best-model saving

best_val_acc = 0.0
writer = SummaryWriter()

try:
    for epoch in range(num_epochs):
        model.train()  # training mode (validate() leaves the model in eval mode)
        running_loss = 0.0
        running_correct = 0
        total = 0

        for imgs, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean.
            running_loss += loss.item() * imgs.size(0)
            _, preds = torch.max(outputs, 1)
            running_correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total
        train_acc = running_correct / total

        # ---- VALIDATION ----
        val_loss, val_acc = validate(model, val_loader)

        # ---- TENSORBOARD ----
        # One point per epoch for each curve; the step is the 1-based epoch number.
        writer.add_scalar("Loss/train", train_loss, epoch + 1)
        writer.add_scalar("Loss/val", val_loss, epoch + 1)
        writer.add_scalar("Accuracy/train", train_acc, epoch + 1)
        writer.add_scalar("Accuracy/val", val_acc, epoch + 1)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}  |  Train Acc: {train_acc:.4f}")
        print(f"  Val   Loss: {val_loss:.4f}  |  Val   Acc: {val_acc:.4f}")

        # ---- SAVE BEST MODEL ----
        # Only a strictly better validation accuracy overwrites the checkpoint.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"  ðŸ”¥ New best model saved! (Val Acc: {best_val_acc:.4f})")
finally:
    # Flush pending events and release the event file, even if training is interrupted.
    writer.close()


Epoch 1/5
  Train Loss: 2.1406  |  Train Acc: 0.2101
  Val   Loss: 2.0414  |  Val   Acc: 0.2622
  ðŸ”¥ New best model saved! (Val Acc: 0.2622)
Epoch 2/5
  Train Loss: 2.0233  |  Train Acc: 0.2798
  Val   Loss: 2.0117  |  Val   Acc: 0.2741
  ðŸ”¥ New best model saved! (Val Acc: 0.2741)
Epoch 3/5
  Train Loss: 1.8962  |  Train Acc: 0.3319
  Val   Loss: 1.8938  |  Val   Acc: 0.3381
  ðŸ”¥ New best model saved! (Val Acc: 0.3381)
Epoch 4/5
  Train Loss: 1.7982  |  Train Acc: 0.3750
  Val   Loss: 1.7952  |  Val   Acc: 0.3678
  ðŸ”¥ New best model saved! (Val Acc: 0.3678)
Epoch 5/5
  Train Loss: 1.6988  |  Train Acc: 0.4176
  Val   Loss: 1.8251  |  Val   Acc: 0.3796
  ðŸ”¥ New best model saved! (Val Acc: 0.3796)


#