In [2]:
!pip install medmnist



In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import medmnist
from medmnist import INFO
from torchvision.utils import save_image

In [4]:
# Which MedMNIST dataset to use; `download` is forwarded to the dataset constructor.
data_flag = 'octmnist'
download = True

# Look up the dataset metadata, then resolve its dataset class and label count.
info = INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])
n_classes = len(info['label'])

# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# One dataset object per split, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    DataClass(split=split_name, transform=transform, download=download)
    for split_name in ('train', 'val', 'test')
)

# DataLoaders: only the training split is shuffled.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(held_out, batch_size=batch_size, shuffle=False)
    for held_out in (val_dataset, test_dataset)
)

Using downloaded and verified file: /root/.medmnist/octmnist.npz
Using downloaded and verified file: /root/.medmnist/octmnist.npz
Using downloaded and verified file: /root/.medmnist/octmnist.npz


In [5]:
# Human-readable label names, in the order stored in the dataset metadata.
class_names = [*info['label'].values()]
print(class_names)

['choroidal neovascularization', 'diabetic macular edema', 'drusen', 'normal']


In [6]:
# Define CNN Model
class CNN(nn.Module):
    """Small convolutional classifier: three conv/batch-norm/pool stages plus an MLP head.

    Each stage applies a 3x3 convolution (padding 1, so the spatial size is
    preserved), batch normalization, ReLU, and a 2x2 max-pool that halves the
    spatial size. The flattened features then pass through three linear
    layers, with dropout after the first one.

    Args:
        num_classes: Number of output logits. When ``None`` (the default) the
            notebook-level ``n_classes`` is used, so ``CNN()`` behaves exactly
            as before.
        in_channels: Number of channels in the input images (default 1).
        image_size: Height/width of the square input images (default 28).
            Three 2x2 poolings reduce it to ``image_size // 8``, which sizes
            the first linear layer.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, which would leave no
            spatial positions after the three pooling stages.
    """

    def __init__(self, num_classes=None, in_channels=1, image_size=28):
        super().__init__()
        if num_classes is None:
            # Backward-compatible default: resolve the global at call time,
            # as the original constructor did.
            num_classes = n_classes
        if image_size < 8:
            raise ValueError(
                f"image_size must be at least 8 to survive three 2x2 poolings, got {image_size}"
            )

        # Layers are registered in the original order and under the original
        # names so that printed summaries and saved state dicts stay compatible.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)

        # Three floor-halvings are equivalent to one floor division by 8
        # (28 -> 14 -> 7 -> 3 for the default size).
        pooled_size = image_size // 8
        self.fc1 = nn.Linear(128 * pooled_size * pooled_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw logits; no
            softmax is applied here.
        """
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = self.pool(torch.relu(self.bn3(self.conv3(x))))
        # Flatten everything except the batch axis for the linear head.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Setup: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the model, move its parameters to the chosen device, and show its layers.
cnn = CNN()
cnn = cnn.to(device)
print(cnn)

CNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=1152, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=4, bias=True)
)


In [15]:
# Train `cnn` with cross-entropy loss and Adam, printing the mean batch loss per epoch.
# NOTE(review): `cnn` is created in a different cell, so re-running this cell
# continues training from the current weights (with a fresh optimizer) instead
# of starting over. Re-run the model-setup cell first for a clean run.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=0.001)

num_epochs = 20
for epoch in range(num_epochs):
    cnn.train()  # training mode: dropout active, batch-norm statistics updated
    total_loss = 0

    for images, labels in train_loader:
        # reshape(-1) flattens the labels to one class index per sample. A bare
        # squeeze() would also drop the batch axis when a batch holds a single
        # sample, leaving a 0-d target that no longer matches the (1, C) logits.
        images, labels = images.to(device), labels.reshape(-1).long().to(device)

        optimizer.zero_grad()
        outputs = cnn(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}")

Epoch 1/20, Loss: 0.0823
Epoch 2/20, Loss: 0.0723
Epoch 3/20, Loss: 0.0715
Epoch 4/20, Loss: 0.0684
Epoch 5/20, Loss: 0.0656
Epoch 6/20, Loss: 0.0649
Epoch 7/20, Loss: 0.0619
Epoch 8/20, Loss: 0.0588
Epoch 9/20, Loss: 0.0577
Epoch 10/20, Loss: 0.0565
Epoch 11/20, Loss: 0.0545
Epoch 12/20, Loss: 0.0547
Epoch 13/20, Loss: 0.0522
Epoch 14/20, Loss: 0.0523
Epoch 15/20, Loss: 0.0482
Epoch 16/20, Loss: 0.0489
Epoch 17/20, Loss: 0.0490
Epoch 18/20, Loss: 0.0449
Epoch 19/20, Loss: 0.0462
Epoch 20/20, Loss: 0.0447


In [16]:
# Persist only the parameters and buffers (the state dict), not the whole module object.
torch.save(
    cnn.state_dict(),
    "octmnist.pth",
)
print("Model weights saved successfully!")

Model weights saved successfully!


In [17]:
# Measure top-1 accuracy of `cnn` on the test split.
cnn.eval()  # inference mode: dropout disabled, batch-norm uses running statistics
correct = 0
total = 0

with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        # reshape(-1) keeps a 1-D label vector even for a single-sample batch.
        # A bare squeeze() would produce a 0-d tensor there, which breaks both
        # the comparison semantics and labels.size(0) below.
        images, labels = images.to(device), labels.reshape(-1).long().to(device)
        outputs = cnn(images)
        # The predicted class is the index of the largest logit in each row.
        _, predicted = torch.max(outputs, 1)

        correct += (predicted == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 67.90%
