In [14]:
#import library
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torchvision import models

In [15]:
#Load the DataLoader objects
with open(r"C:\Users\user\Documents\!TA\!TA\all trial\train_loader.pkl", "rb") as f:
    train_loader = pickle.load(f)
with open(r"C:\Users\user\Documents\!TA\!TA\all trial\valid_loader.pkl", "rb") as f:
    valid_loader = pickle.load(f)

In [16]:
#Confirm the DataLoaders are loaded
print(f"Train Loader: {len(train_loader)} batches")
print(f"Valid Loader: {len(valid_loader)} batches")

Train Loader: 74 batches
Valid Loader: 19 batches


In [17]:
#Check device availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [18]:
#load the pre-trained VGG16 model architecture
vgg16 = models.vgg16(weights=None).to(device)

#Modify the classifier layer for 3 output classes
vgg16.classifier[6] = nn.Linear(in_features=4096, out_features=3).to(device)

#Load the saved state dictionary
vgg16.load_state_dict(torch.load("vgg16_state_dict.pth", map_location=device))

  vgg16.load_state_dict(torch.load("vgg16_state_dict.pth", map_location=device))


<All keys matched successfully>

In [19]:
#Ensure the classifier has the correct number of output classes (3 in this case)
vgg16.classifier[6] = nn.Linear(in_features=4096, out_features=3).to(device)

In [20]:
#Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = optim.Adam(vgg16.classifier.parameters(), lr=0.001)

In [21]:
#Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # Reduce lr by 0.5 every 5 epochs

In [22]:
# # Early stopping criteria
# best_accuracy = 0.0
# patience = 3  # Stop after 3 epochs with no improvement
# epochs_without_improvement = 0

In [23]:
#Training function
def train_model(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=10):
    # global best_accuracy, epochs_without_improvement
    best_accuracy = 0.0 

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        train_loss = 0.0
        train_correct = 0
        total_train = 0

        # Training loop
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == labels).sum().item()
            total_train += labels.size(0)

        # Calculate train accuracy
        train_accuracy = 100 * train_correct / total_train

        # Validation loop
        model.eval()
        valid_correct = 0
        total_valid = 0

        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                valid_correct += (preds == labels).sum().item()
                total_valid += labels.size(0)
                
        # Calculate validatin accuracy
        valid_accuracy = 100 * valid_correct / total_valid

        # Print stats for the epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Accuracy: {valid_accuracy:.2f}%")

        # Save the model with the best validation accuracy
        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            torch.save(model.state_dict(), "best_vgg16_model.pth")
        #     epochs_without_improvement = 0
        # else:
        #     epochs_without_improvement += 1

        # # Save checkpoints after every 5 epochs
        # if (epoch + 1) % 5 == 0:
        #     torch.save(model.state_dict(), f"vgg16_epoch_{epoch+1}.pth")

        # # Stop training early if no improvement
        # if epochs_without_improvement >= patience:
        #     print("Early stopping triggered!")
        #     break
        
        # Step the scheduler
        scheduler.step()

    print(f"Best Validation Accuracy: {best_accuracy:.2f}%")

In [24]:
#Train the model
train_model(vgg16, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs=15)

Epoch [1/15], Train Loss: 1.2791, Train Accuracy: 50.55%, Validation Accuracy: 66.27%
Epoch [2/15], Train Loss: 0.8921, Train Accuracy: 61.91%, Validation Accuracy: 70.49%
Epoch [3/15], Train Loss: 0.8497, Train Accuracy: 64.48%, Validation Accuracy: 69.31%
Epoch [4/15], Train Loss: 0.8465, Train Accuracy: 63.98%, Validation Accuracy: 66.95%
Epoch [5/15], Train Loss: 0.7829, Train Accuracy: 65.71%, Validation Accuracy: 68.97%
Epoch [6/15], Train Loss: 0.6009, Train Accuracy: 72.42%, Validation Accuracy: 72.18%
Epoch [7/15], Train Loss: 0.5206, Train Accuracy: 74.92%, Validation Accuracy: 76.22%
Epoch [8/15], Train Loss: 0.4810, Train Accuracy: 76.86%, Validation Accuracy: 74.20%
Epoch [9/15], Train Loss: 0.4959, Train Accuracy: 75.63%, Validation Accuracy: 73.86%
Epoch [10/15], Train Loss: 0.4652, Train Accuracy: 78.12%, Validation Accuracy: 75.72%
Epoch [11/15], Train Loss: 0.3909, Train Accuracy: 81.25%, Validation Accuracy: 75.89%
Epoch [12/15], Train Loss: 0.3488, Train Accuracy: 8