In [45]:
#import library
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torchvision import models

In [46]:
#Load the DataLoader objects
with open(r"C:\Users\user\Documents\!TA\!TA\all trial\train_loader.pkl", "rb") as f:
    train_loader = pickle.load(f)
with open(r"C:\Users\user\Documents\!TA\!TA\all trial\valid_loader.pkl", "rb") as f:
    valid_loader = pickle.load(f)

In [47]:
#Confirm the DataLoaders are loaded
print(f"Train Loader: {len(train_loader)} batches")
print(f"Valid Loader: {len(valid_loader)} batches")

Train Loader: 74 batches
Valid Loader: 19 batches


In [48]:
#Check device availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [49]:
# Load the pre-trained ResNet18 model
resnet18 = models.resnet18(pretrained=True)

# Modify the final fully connected layer for 3 output classes
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=3)  # 3 classes: Keratoconus, Normal, Suspect

# Move the model to the appropriate device (GPU/CPU)
resnet18 = resnet18.to(device)

In [50]:
# Freeze the feature extractor layers (everything except the final classifier layer)
for param in resnet18.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the final classifier layer to fine-tune it
for param in resnet18.fc.parameters():
    param.requires_grad = True

# Print the names of layers and whether they are trainable
for name, param in resnet18.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

conv1.weight: requires_grad=False
bn1.weight: requires_grad=False
bn1.bias: requires_grad=False
layer1.0.conv1.weight: requires_grad=False
layer1.0.bn1.weight: requires_grad=False
layer1.0.bn1.bias: requires_grad=False
layer1.0.conv2.weight: requires_grad=False
layer1.0.bn2.weight: requires_grad=False
layer1.0.bn2.bias: requires_grad=False
layer1.1.conv1.weight: requires_grad=False
layer1.1.bn1.weight: requires_grad=False
layer1.1.bn1.bias: requires_grad=False
layer1.1.conv2.weight: requires_grad=False
layer1.1.bn2.weight: requires_grad=False
layer1.1.bn2.bias: requires_grad=False
layer2.0.conv1.weight: requires_grad=False
layer2.0.bn1.weight: requires_grad=False
layer2.0.bn1.bias: requires_grad=False
layer2.0.conv2.weight: requires_grad=False
layer2.0.bn2.weight: requires_grad=False
layer2.0.bn2.bias: requires_grad=False
layer2.0.downsample.0.weight: requires_grad=False
layer2.0.downsample.1.weight: requires_grad=False
layer2.0.downsample.1.bias: requires_grad=False
layer2.1.conv1.wei

In [51]:
#Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = optim.Adam(resnet18.fc.parameters(), lr=0.001)

In [52]:
#Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # Reduce lr by 0.5 every 5 epochs

In [53]:
# # Early stopping criteria
# best_accuracy = 0.0
# patience = 3  # Stop after 3 epochs with no improvement
# epochs_without_improvement = 0

In [54]:
#Training function
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=30):
    # global best_accuracy, epochs_without_improvement
    best_accuracy = 0.0 

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        train_loss = 0.0
        train_correct = 0
        total_train = 0

        # Training loop
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
            total_train += labels.size(0)

        # Calculate train accuracy
        train_accuracy = 100 * train_correct / total_train

        # Validation loop
        model.eval()
        valid_correct = 0
        total_valid = 0

        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                valid_correct += (preds == labels).sum().item()
                total_valid += labels.size(0)
                
        # Calculate validatin accuracy
        valid_accuracy = 100 * valid_correct / total_valid

        # Print stats for the epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Accuracy: {valid_accuracy:.2f}%")

        # Save the model with the best validation accuracy
        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            torch.save(model.state_dict(), "best_resnet18_model.pth")
        #     epochs_without_improvement = 0
        # else:
        #     epochs_without_improvement += 1

        # # Save checkpoints after every 5 epochs
        # if (epoch + 1) % 5 == 0:
        #     torch.save(model.state_dict(), f"vgg16_epoch_{epoch+1}.pth")

        # # Stop training early if no improvement
        # if epochs_without_improvement >= patience:
        #     print("Early stopping triggered!")
        #     break
        
        # Step the scheduler
        scheduler.step()

    print(f"Best Validation Accuracy: {best_accuracy:.2f}%")

In [55]:
#Train the model
train_model(resnet18, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=30)

Epoch [1/30], Train Loss: 0.8863, Train Accuracy: 57.47%, Validation Accuracy: 64.59%
Epoch [2/30], Train Loss: 0.7577, Train Accuracy: 65.79%, Validation Accuracy: 63.07%
Epoch [3/30], Train Loss: 0.7196, Train Accuracy: 67.02%, Validation Accuracy: 67.96%
Epoch [4/30], Train Loss: 0.6859, Train Accuracy: 68.03%, Validation Accuracy: 68.47%
Epoch [5/30], Train Loss: 0.6710, Train Accuracy: 69.72%, Validation Accuracy: 68.80%
Epoch [6/30], Train Loss: 0.6464, Train Accuracy: 71.37%, Validation Accuracy: 69.65%
Epoch [7/30], Train Loss: 0.6448, Train Accuracy: 71.07%, Validation Accuracy: 69.98%
Epoch [8/30], Train Loss: 0.6219, Train Accuracy: 72.97%, Validation Accuracy: 70.83%
Epoch [9/30], Train Loss: 0.6420, Train Accuracy: 70.86%, Validation Accuracy: 71.16%
Epoch [10/30], Train Loss: 0.6249, Train Accuracy: 71.62%, Validation Accuracy: 72.18%
Epoch [11/30], Train Loss: 0.6148, Train Accuracy: 72.42%, Validation Accuracy: 71.67%
Epoch [12/30], Train Loss: 0.6192, Train Accuracy: 7