In [1]:
#import libraries
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import VGG16_Weights

In [2]:
#Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


In [3]:
#Load the pre-trained VGG16 model
vgg16 = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)

In [4]:
#Modify the classifier layer for 3 output classes
vgg16.classifier[6] = nn.Linear(in_features=4096, out_features=3)  # 3 classes: Keratoconus, Normal, Suspect

#Move the model to the appropriate device (GPU/CPU)
vgg16 = vgg16.to(device)

In [5]:
#Freeze the feature extractor layers
for param in vgg16.features.parameters():
    param.requires_grad = False

for name, param in vgg16.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")


# Display the model structure
print(vgg16)

features.0.weight: requires_grad=False
features.0.bias: requires_grad=False
features.2.weight: requires_grad=False
features.2.bias: requires_grad=False
features.5.weight: requires_grad=False
features.5.bias: requires_grad=False
features.7.weight: requires_grad=False
features.7.bias: requires_grad=False
features.10.weight: requires_grad=False
features.10.bias: requires_grad=False
features.12.weight: requires_grad=False
features.12.bias: requires_grad=False
features.14.weight: requires_grad=False
features.14.bias: requires_grad=False
features.17.weight: requires_grad=False
features.17.bias: requires_grad=False
features.19.weight: requires_grad=False
features.19.bias: requires_grad=False
features.21.weight: requires_grad=False
features.21.bias: requires_grad=False
features.24.weight: requires_grad=False
features.24.bias: requires_grad=False
features.26.weight: requires_grad=False
features.26.bias: requires_grad=False
features.28.weight: requires_grad=False
features.28.bias: requires_grad=

In [6]:
# Save the model setup for use in training
# torch.save(vgg16, "vgg16_setup.pth")

torch.save(vgg16.state_dict(), "vgg16_state_dict.pth")