In [1]:
#import libraries
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet18_Weights

In [2]:
# Pick the compute device: use CUDA when a GPU is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:
# Load ResNet-18 initialised with torchvision's ImageNet-1K (V1) pretrained weights.
# (The previous comment said "VGG16", but this cell builds a ResNet-18.)
resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

In [4]:
# Replace the ImageNet-1K classification head with a fresh linear layer that has one output
# per class, keeping the backbone's feature width (the old head's `in_features`).
NUM_CLASSES = 3  # Keratoconus, Normal, Suspect (index order presumed to match dataset labels — TODO confirm)

# The new head's weights are randomly initialised and are saved later in this notebook,
# so seed torch's RNG first to make them reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
resnet18.fc = nn.Linear(in_features=resnet18.fc.in_features, out_features=NUM_CLASSES)

# Move the model to the appropriate device (GPU/CPU)
resnet18 = resnet18.to(device)

In [5]:
# Choose the fine-tuning regime explicitly.
# Nothing earlier in this notebook freezes any parameter (every parameter reports
# requires_grad=True), so "unfreezing" only `fc` was a no-op. Set FREEZE_BACKBONE = True to
# train only the new classifier head; False keeps full fine-tuning (the previous behaviour).
# NOTE: requires_grad flags are not stored in a state_dict checkpoint, so training code that
# loads saved weights must re-apply this choice itself.
FREEZE_BACKBONE = False

# Set every parameter first so re-running the cell with a different flag gives a clean state.
resnet18.requires_grad_(not FREEZE_BACKBONE)
# The newly initialised classifier head is always trainable.
resnet18.fc.requires_grad_(True)

# Print the names of layers and whether they are trainable
for name, param in resnet18.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# One-line summary so the regime is visible without scanning the list above.
n_trainable = sum(p.numel() for p in resnet18.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in resnet18.parameters())
print(f"Trainable parameters: {n_trainable:,} / {n_total:,}")

# Display the model structure
print(resnet18)

conv1.weight: requires_grad=True
bn1.weight: requires_grad=True
bn1.bias: requires_grad=True
layer1.0.conv1.weight: requires_grad=True
layer1.0.bn1.weight: requires_grad=True
layer1.0.bn1.bias: requires_grad=True
layer1.0.conv2.weight: requires_grad=True
layer1.0.bn2.weight: requires_grad=True
layer1.0.bn2.bias: requires_grad=True
layer1.1.conv1.weight: requires_grad=True
layer1.1.bn1.weight: requires_grad=True
layer1.1.bn1.bias: requires_grad=True
layer1.1.conv2.weight: requires_grad=True
layer1.1.bn2.weight: requires_grad=True
layer1.1.bn2.bias: requires_grad=True
layer2.0.conv1.weight: requires_grad=True
layer2.0.bn1.weight: requires_grad=True
layer2.0.bn1.bias: requires_grad=True
layer2.0.conv2.weight: requires_grad=True
layer2.0.bn2.weight: requires_grad=True
layer2.0.bn2.bias: requires_grad=True
layer2.0.downsample.0.weight: requires_grad=True
layer2.0.downsample.1.weight: requires_grad=True
layer2.0.downsample.1.bias: requires_grad=True
layer2.1.conv1.weight: requires_grad=True


In [6]:
# Save the model weights for use in training.
# A state_dict holds only parameter/buffer tensors. It does not record the architecture
# (ResNet-18 with the replaced 3-way `fc` head) or requires_grad flags. The loader must rebuild
# the same model before calling load_state_dict(), and re-apply any freezing.
STATE_DICT_PATH = "resnet18_state_dict.pth"
torch.save(resnet18.state_dict(), STATE_DICT_PATH)