In [8]:
#import library
import pickle
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torchvision import models

In [9]:
#Load the DataLoader objects
# Both pickled loaders live in the same folder; edit DATA_DIR once instead of every path.
DATA_DIR = Path(r"C:\Users\user\Documents\!TA\!TA\all trial")

# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# Only load pickle files that you created yourself or otherwise fully trust.
with open(DATA_DIR / "train_loader.pkl", "rb") as f:
    train_loader = pickle.load(f)
with open(DATA_DIR / "valid_loader.pkl", "rb") as f:
    valid_loader = pickle.load(f)

In [10]:
#Confirm the DataLoaders are loaded
# Report how many batches each loader yields, one line per split.
print(
    f"Train Loader: {len(train_loader)} batches",
    f"Valid Loader: {len(valid_loader)} batches",
    sep="\n",
)

Train Loader: 37 batches
Valid Loader: 10 batches


In [11]:
#Check device availability
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [12]:
# Load the ResNet50 model with its pre-trained weights
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` -- confirm the installed version before switching.
resnet50 = models.resnet50(pretrained=True)

# Replace the classification head with a 3-way linear layer
# (3 classes: Keratoconus, Normal, Suspect); the input width is read from the old head.
resnet50.fc = nn.Linear(resnet50.fc.in_features, 3)

# Place the model on the selected device (GPU/CPU)
resnet50 = resnet50.to(device)

In [13]:
# Freeze every parameter of the network first (feature extractor and head alike)
resnet50.requires_grad_(False)

# Re-enable gradients only for the final classifier layer so it alone is fine-tuned
resnet50.fc.requires_grad_(True)

# List each parameter together with its trainable flag
for name, param in resnet50.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

conv1.weight: requires_grad=False
bn1.weight: requires_grad=False
bn1.bias: requires_grad=False
layer1.0.conv1.weight: requires_grad=False
layer1.0.bn1.weight: requires_grad=False
layer1.0.bn1.bias: requires_grad=False
layer1.0.conv2.weight: requires_grad=False
layer1.0.bn2.weight: requires_grad=False
layer1.0.bn2.bias: requires_grad=False
layer1.0.conv3.weight: requires_grad=False
layer1.0.bn3.weight: requires_grad=False
layer1.0.bn3.bias: requires_grad=False
layer1.0.downsample.0.weight: requires_grad=False
layer1.0.downsample.1.weight: requires_grad=False
layer1.0.downsample.1.bias: requires_grad=False
layer1.1.conv1.weight: requires_grad=False
layer1.1.bn1.weight: requires_grad=False
layer1.1.bn1.bias: requires_grad=False
layer1.1.conv2.weight: requires_grad=False
layer1.1.bn2.weight: requires_grad=False
layer1.1.bn2.bias: requires_grad=False
layer1.1.conv3.weight: requires_grad=False
layer1.1.bn3.weight: requires_grad=False
layer1.1.bn3.bias: requires_grad=False
layer1.2.conv1.wei

In [14]:
#Define the loss function and optimizer
# Cross-entropy for the multi-class problem; Adam is handed only the parameters of the final layer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=resnet50.fc.parameters(), lr=1e-3)

In [15]:
#Learning rate scheduler
# Halve the learning rate (gamma=0.5) after every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5,
)

In [16]:
# Early stopping settings -- currently DISABLED: every line in this cell is commented
# out, so none of these names are defined when the notebook runs.
# best_accuracy = 0.0
# patience = 3  # Stop after 3 epochs with no improvement
# epochs_without_improvement = 0

In [17]:
#Training function
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler,
                num_epochs=100, save_path="best_resnet50_model.pth", patience=None):
    """Train `model`, validate after every epoch, and checkpoint the best weights.

    Args:
        model: network to optimise. Batches are moved to the device that the
            model's parameters live on, so no global `device` is required.
        train_loader: iterable yielding `(images, labels)` training batches.
        valid_loader: iterable yielding `(images, labels)` validation batches.
        criterion: loss callable, invoked as `criterion(outputs, labels)`.
        optimizer: optimizer; `zero_grad()` and `step()` are called once per batch.
        scheduler: learning-rate scheduler; `step()` is called once per epoch.
        num_epochs: maximum number of epochs to run.
        save_path: file that receives `model.state_dict()` every time the
            validation accuracy beats the best value seen so far.
        patience: optional number of consecutive epochs without a validation
            improvement after which training stops early. `None` (the default)
            disables early stopping.

    Returns:
        The best validation accuracy (in percent) observed during training.

    Raises:
        ValueError: if `train_loader` or `valid_loader` yields no samples.
    """
    # Resolve the target device from the model itself instead of a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    best_accuracy = 0.0
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        train_loss = 0.0
        train_correct = 0
        total_train = 0
        num_train_batches = 0

        # Training loop
        for images, labels in train_loader:
            images, labels = images.to(run_device), labels.to(run_device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
            total_train += labels.size(0)
            num_train_batches += 1

        # Fail with a clear message rather than a bare ZeroDivisionError below.
        if total_train == 0:
            raise ValueError("train_loader yielded no samples")

        # Calculate train accuracy
        train_accuracy = 100 * train_correct / total_train

        # Validation loop
        model.eval()
        valid_correct = 0
        total_valid = 0

        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(run_device), labels.to(run_device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                valid_correct += (preds == labels).sum().item()
                total_valid += labels.size(0)

        if total_valid == 0:
            raise ValueError("valid_loader yielded no samples")

        # Calculate validation accuracy
        valid_accuracy = 100 * valid_correct / total_valid

        # Print stats for the epoch (loss is the mean of the per-batch losses)
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss/num_train_batches:.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Accuracy: {valid_accuracy:.2f}%")

        # Save the model with the best validation accuracy.
        # A strict '>' means ties do not overwrite the earlier checkpoint.
        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            torch.save(model.state_dict(), save_path)
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # Stop training early if no improvement (only when patience is given)
        if patience is not None and epochs_without_improvement >= patience:
            print("Early stopping triggered!")
            break

        # Step the scheduler
        scheduler.step()

    print(f"Best Validation Accuracy: {best_accuracy:.2f}%")
    return best_accuracy

In [18]:
#Train the model
# Run the full training schedule on the ResNet50 defined above, using the shared loss, optimizer and scheduler.
train_model(
    model=resnet50,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=100,
)

Epoch [1/100], Train Loss: 0.9979, Train Accuracy: 48.31%, Validation Accuracy: 58.52%
Epoch [2/100], Train Loss: 0.8170, Train Accuracy: 61.78%, Validation Accuracy: 65.26%
Epoch [3/100], Train Loss: 0.7522, Train Accuracy: 66.05%, Validation Accuracy: 68.47%
Epoch [4/100], Train Loss: 0.7399, Train Accuracy: 65.16%, Validation Accuracy: 67.96%
Epoch [5/100], Train Loss: 0.7014, Train Accuracy: 68.54%, Validation Accuracy: 70.32%
Epoch [6/100], Train Loss: 0.6758, Train Accuracy: 70.31%, Validation Accuracy: 71.50%
Epoch [7/100], Train Loss: 0.6782, Train Accuracy: 70.48%, Validation Accuracy: 72.34%
Epoch [8/100], Train Loss: 0.6627, Train Accuracy: 70.48%, Validation Accuracy: 68.13%
Epoch [9/100], Train Loss: 0.6716, Train Accuracy: 70.27%, Validation Accuracy: 72.34%
Epoch [10/100], Train Loss: 0.6655, Train Accuracy: 70.19%, Validation Accuracy: 71.67%
Epoch [11/100], Train Loss: 0.6395, Train Accuracy: 72.13%, Validation Accuracy: 70.66%
Epoch [12/100], Train Loss: 0.6334, Train

KeyboardInterrupt: 