In [41]:
#import library
import torch
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import random
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

In [42]:
#Define test dataset transformations (same as train and validation)
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # VGG16 expects 224x224 input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # VGG16 normalization
])

In [43]:
# Load the test dataset
test_folder = r"C:\Users\user\Documents\!TA\!TA\cornealtopography\Independent Test Set"  # Replace with your actual test folder path
test_dataset = datasets.ImageFolder(test_folder, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

In [44]:
# Check dataset size
print(f"Testing dataset contains {len(test_dataset)} images.")

Testing dataset contains 1051 images.


In [45]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [46]:
# Load the model architecture
vgg16 = models.vgg16(weights=None)  # No weights loaded at this point

In [47]:
vgg16.load_state_dict(torch.load("vgg16_state_dict.pth"))

  vgg16.load_state_dict(torch.load("vgg16_state_dict.pth"))


RuntimeError: Error(s) in loading state_dict for VGG:
	size mismatch for classifier.6.weight: copying a param with shape torch.Size([3, 4096]) from checkpoint, the shape in current model is torch.Size([1000, 4096]).
	size mismatch for classifier.6.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([1000]).

In [None]:
# Move the model to the appropriate device
vgg16 = vgg16.to(device)

In [12]:
# Store true labels and predictions
true_labels = []
predicted_labels = []

In [13]:
# Disable gradient computation for evaluation
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = vgg16(images)
        _, preds = torch.max(outputs, 1)

        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(preds.cpu().numpy())

KeyboardInterrupt: 

In [None]:
# Print classification report
print("Classification Report:\n")
print(classification_report(true_labels, predicted_labels, target_names=test_dataset.classes))

In [None]:
# Confusion Matrix
conf_matrix = confusion_matrix(true_labels, predicted_labels)

In [None]:
# Plot confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", 
            xticklabels=test_dataset.classes, yticklabels=test_dataset.classes)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Helper function to display images
def show_images(images, labels, preds, class_names, title):
    plt.figure(figsize=(12, 8))
    for i in range(len(images)):
        plt.subplot(2, 5, i + 1)
        plt.imshow(images[i].permute(1, 2, 0).cpu().numpy())
        plt.title(f"True: {class_names[labels[i]]}\nPred: {class_names[preds[i]]}")
        plt.axis("off")
    plt.suptitle(title)
    plt.show()

In [None]:
# Randomly select a few test images
random_indices = random.sample(range(len(test_dataset)), 10)
images, labels = zip(*[test_dataset[i] for i in random_indices])
images = torch.stack(images).to(device)
labels = torch.tensor(labels).to(device)

In [None]:
# Get predictions
outputs = vgg16(images)
_, preds = torch.max(outputs, 1)


In [None]:
# Display the images
show_images(images, labels.cpu().numpy(), preds.cpu().numpy(), test_dataset.classes, title="Random Test Predictions")

In [None]:
# Save the classification report as a CSV file
report = classification_report(true_labels, predicted_labels, target_names=test_dataset.classes, output_dict=True)
import pandas as pd
df = pd.DataFrame(report).transpose()
df.to_csv("classification_report.csv", index=True)


In [None]:
# Save the confusion matrix plot
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", 
            xticklabels=test_dataset.classes, yticklabels=test_dataset.classes)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.savefig("confusion_matrix.png")
plt.close()