In [None]:
import copy

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from IPython.display import clear_output

from models import CustomResNet18, device
from dataset_utils import get_dataloaders, pacs_df, pacs_classes, to_pil
from utils import show_label_distribution, plot_loss

In [None]:
# Instantiate the classifier and move it onto the configured device.
resnet18 = CustomResNet18()
resnet18.to(device)

# Keep only the "photo" domain of PACS, then split it into train/test loaders.
# NOTE(review): 0.05 is passed positionally to get_dataloaders; presumably a split
# fraction — confirm against dataset_utils. No random seed is set before this call,
# so the split/initialisation may differ between runs — TODO consider seeding.
is_photo = pacs_df["domain"] == "photo"
photo_df = pacs_df[is_photo]
train_loader, test_loader = get_dataloaders(photo_df, 0.05)

# Sanity-check class balance of the training split.
show_label_distribution(train_loader, pacs_classes)

In [None]:
# Train resnet18 on the photo-domain loaders with early stopping on validation loss.
# Both recorded curves hold the mean per-batch loss of each epoch, so the training
# and validation values are on the same scale and can be compared directly.

criterion = nn.CrossEntropyLoss()
# NOTE(review): Adam's default lr is 1e-3; 0.01 is comparatively high — confirm intended.
optimizer = torch.optim.Adam(resnet18.parameters(), lr=0.01)
num_epochs = 100
patience = 20  # Number of consecutive epochs without improvement before stopping

best_loss = float('inf')
epochs_without_improvement = 0
best_model_weights = None

# Per-epoch mean losses for plotting
train_losses = []
val_losses = []

# Fail fast on empty loaders instead of hitting ZeroDivisionError after a full epoch.
assert len(train_loader) > 0, "train_loader yields no batches"
assert len(test_loader) > 0, "test_loader yields no batches"


def _train_one_epoch(model, loader, loss_fn, opt):
    """Run one optimisation pass of ``model`` over ``loader``.

    Args:
        model: the network being trained (moved to ``device`` beforehand).
        loader: iterable of ``(features, labels)`` batches.
        loss_fn: criterion applied to ``(model(features), labels)``.
        opt: optimizer stepping ``model``'s parameters.

    Returns:
        The mean of the per-batch losses over the epoch.
    """
    # Validation switches the model to eval mode, so train mode has to be
    # re-enabled at the start of every epoch, not only once before the loop.
    model.train()
    running_loss = 0.0
    for features, labels in loader:
        features, labels = features.to(device), labels.to(device)

        opt.zero_grad()
        loss = loss_fn(model(features), labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, loss_fn):
    """Return the mean per-batch loss of ``model`` over ``loader`` without updating it."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            total_loss += loss_fn(model(features), labels).item()
    return total_loss / len(loader)


# Training Loop
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(resnet18, train_loader, criterion, optimizer)
    validation_loss = _evaluate(resnet18, test_loader, criterion)

    # CrossEntropyLoss defaults to reduction='mean', so both values are already
    # per-sample averages per batch; no further division by dataset size.
    train_losses.append(train_loss)
    val_losses.append(validation_loss)

    clear_output(wait=True)
    plot_loss(train_losses, val_losses)

    # Early Stopping Check
    if validation_loss < best_loss:
        best_loss = validation_loss
        epochs_without_improvement = 0
        # state_dict() returns references to the live parameter tensors; without a
        # deep copy the "best" snapshot would keep changing as training continues.
        best_model_weights = copy.deepcopy(resnet18.state_dict())
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"No improvement for {patience} epochs. Stopping training early.")
            break

# Load the best model weights
if best_model_weights is None:
    # Reached only if no validation loss ever compared below inf (e.g. NaN losses).
    print("Validation loss never improved; keeping the final model weights.")
else:
    resnet18.load_state_dict(best_model_weights)
    print("Training complete. Returning the model with the best validation loss.")
