In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import time
import copy
import os

In [2]:
# --- Run configuration ---------------------------------------------------
# Root folder of the training images. It can be overridden per machine by
# setting the GOF_TRAIN_DIR environment variable; when that variable is unset
# the original location is used, so existing runs behave as before.
train_dir = os.environ.get(
    "GOF_TRAIN_DIR",
    "/Users/hwarden/Documents/PhD/GOFMLPipeline/image_patches/train",
)
num_classes = 2      # output size of the replacement classifier head
batch_size = 32      # samples per DataLoader batch
lr = 1e-4            # learning rate handed to the optimizer
num_epochs = 30      # upper bound on training epochs
patience = 5         # epochs without validation-loss improvement before stopping
val_ratio = 0.2      # fraction of the dataset held out for validation

# Seed torch's global RNG so a Restart & Run All starts from the same state.
seed = 42
torch.manual_seed(seed)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Pick one of four rotations per sample. Each RandomRotation is given a
# degenerate (angle, angle) range, i.e. identical lower and upper bounds.
right_angle_rotation = transforms.RandomChoice(
    [transforms.RandomRotation((angle, angle)) for angle in (0, 90, 180, 270)]
)

# Training pipeline: random right-angle rotation, then conversion to a tensor.
train_transforms = transforms.Compose([right_angle_rotation, transforms.ToTensor()])

# Validation pipeline: currently the same steps (and the same rotation object)
# as the training pipeline.
val_transforms = transforms.Compose([right_angle_rotation, transforms.ToTensor()])

In [4]:
# Build two ImageFolder views over the same directory so that the training and
# validation subsets can each carry their own transform. Assigning to
# `subset.dataset.transform` after a random_split would instead change the one
# underlying dataset shared by BOTH subsets.
train_source = datasets.ImageFolder(train_dir, transform=train_transforms)
val_source = datasets.ImageFolder(train_dir, transform=val_transforms)
full_dataset = train_source  # name kept for any later cell that refers to it

# Both views must enumerate the same files in the same order, otherwise the
# index split below would not be a true partition of the data.
assert train_source.samples == val_source.samples, "ImageFolder views disagree"

train_size = int((1 - val_ratio) * len(full_dataset))
val_size   = len(full_dataset) - train_size
assert train_size > 0 and val_size > 0, "train/val split produced an empty subset"

# Seeded permutation: the train/val membership is identical on every run.
split_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()

train_dataset = torch.utils.data.Subset(train_source, shuffled_indices[:train_size])
val_dataset   = torch.utils.data.Subset(val_source, shuffled_indices[train_size:])

# Only the training loader reshuffles each epoch; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

dataloaders = {"train": train_loader, "val": val_loader}

In [5]:
# ResNet-101 built with weights=None, i.e. no pretrained weights are loaded.
model = models.resnet101(weights=None)

# Swap the final fully connected layer for one that emits `num_classes` scores.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features=in_features, out_features=num_classes)

# Move the parameters to the target device before the optimizer is created.
model = model.to(device)

# Adam over every model parameter, paired with cross-entropy loss.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [6]:
# Early-stopping bookkeeping read and updated by the training loop:
#   epochs_no_improve - consecutive epochs without a new best validation loss
#   best_loss         - lowest validation loss seen so far (starts at +inf)
#   best_model_wts    - deep copy of the weights that achieved best_loss
epochs_no_improve = 0
best_loss = float("inf")
best_model_wts = copy.deepcopy(model.state_dict())

In [7]:
# Train for at most `num_epochs` epochs. Every epoch runs a training pass and
# then a validation pass; the weights with the lowest validation loss are kept,
# and training stops early after `patience` epochs without improvement.
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 20)

    for phase in ["train", "val"]:
        is_train = phase == "train"
        # train(True) is training mode, train(False) is evaluation mode.
        model.train(is_train)

        running_loss = 0.0
        running_corrects = 0  # plain Python int, accumulated via .item()

        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Gradients only need clearing when a backward pass will follow.
            if is_train:
                optimizer.zero_grad()

            # Autograd is enabled for the training phase only.
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                if is_train:
                    loss.backward()
                    optimizer.step()

            # The criterion returns a batch mean, so weight it by batch size
            # to obtain a per-sample average over the whole epoch below.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        num_samples = len(dataloaders[phase].dataset)
        epoch_loss = running_loss / num_samples
        epoch_acc = running_corrects / num_samples

        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # ----- EARLY STOPPING CHECK -----
        if phase == "val":
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                # Deep copy: state_dict() returns references to live tensors.
                best_model_wts = copy.deepcopy(model.state_dict())
                epochs_no_improve = 0
                print("→ Validation loss improved, saving model.")
            else:
                epochs_no_improve += 1
                print(f"→ No improvement for {epochs_no_improve} epochs.")

    if epochs_no_improve >= patience:
        print("Early stopping triggered!")
        break

print("Training complete.")
# Restore the best-scoring weights before the model is used or saved.
model.load_state_dict(best_model_wts)

Epoch 1/30
--------------------
train Loss: 0.6516 Acc: 0.6431
val Loss: 0.7925 Acc: 0.5137
→ Validation loss improved, saving model.
Epoch 2/30
--------------------
train Loss: 0.5555 Acc: 0.7328
val Loss: 1.1129 Acc: 0.5137
→ No improvement for 1 epochs.
Epoch 3/30
--------------------
train Loss: 0.4489 Acc: 0.7793
val Loss: 1.0834 Acc: 0.5137
→ No improvement for 2 epochs.
Epoch 4/30
--------------------
train Loss: 0.4112 Acc: 0.8121
val Loss: 0.7238 Acc: 0.6781
→ Validation loss improved, saving model.
Epoch 5/30
--------------------
train Loss: 0.3619 Acc: 0.8345
val Loss: 0.3947 Acc: 0.8288
→ Validation loss improved, saving model.
Epoch 6/30
--------------------
train Loss: 0.4174 Acc: 0.8259
val Loss: 1.2426 Acc: 0.6918
→ No improvement for 1 epochs.
Epoch 7/30
--------------------
train Loss: 0.3748 Acc: 0.8276
val Loss: 0.4827 Acc: 0.8288
→ No improvement for 2 epochs.
Epoch 8/30
--------------------
train Loss: 0.3698 Acc: 0.8534
val Loss: 0.8109 Acc: 0.7260
→ No improveme

<All keys matched successfully>

In [8]:
# Write the model's current state_dict to disk, creating the target folder
# first if it does not exist yet.
checkpoint_path = "saved_models/resnet101_best.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print("Best model saved as resnet101_best.pth")

Best model saved as resnet101_best.pth
