In [35]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

### Loading the dataset

In [36]:
# Define the transformation (in this case, just convert to tensor)
transform = transforms.Compose([
    transforms.ToTensor()
])

# Path to your dataset folder (with subfolders for each class)
data_dir = '/Users/jeff/code/dataset/dr/dr-2019-3662/filtered'

# Load the dataset with ImageFolder
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)

# Check the classes (order is determined by alphabetical order of folder names)
print("Classes:", full_dataset.classes)

# Calculate the sizes for the training and validation sets (80:20 split)
dataset_size = len(full_dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

# Split the dataset
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create DataLoaders
train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_dl = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

print(f"Total images: {dataset_size}, Training: {train_size}, Validation: {val_size}")


Classes: ['Mild', 'Moderate', 'No_DR', 'Proliferate_DR', 'Severe']
Total images: 965, Training: 772, Validation: 193


### Building The CNN Architecture

#### Four Convolutions

In [37]:
import torch.nn as nn

model = nn.Sequential()

model.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1))
model.add_module('relu1', nn.ReLU())        
model.add_module('pool1', nn.MaxPool2d(kernel_size=2))  
model.add_module('dropout1', nn.Dropout(p=0.5)) 

model.add_module('conv2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1))
model.add_module('relu2', nn.ReLU())        
model.add_module('pool2', nn.MaxPool2d(kernel_size=2))   
model.add_module('dropout2', nn.Dropout(p=0.5)) 

model.add_module('conv3', nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1))
model.add_module('relu3', nn.ReLU())        
model.add_module('pool3', nn.MaxPool2d(kernel_size=2))   

model.add_module('conv4', nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1))
model.add_module('relu4', nn.ReLU())  

#### Pooling and Flattening

In [38]:
model.add_module('pool4', nn.AvgPool2d(kernel_size=8)) 
model.add_module('flatten', nn.Flatten()) 

#### Fully Connected (FC) and Softmax

In [39]:
model.add_module('fc1', nn.Linear(2304, 256))
model.add_module('relu_fc1', nn.ReLU())
model.add_module('fc2', nn.Linear(256, 5))
model.add_module('softmax', nn.Softmax(dim=1))

In [40]:
model

Sequential(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout(p=0.5, inplace=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu4): ReLU()
  (pool4): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=2304, out_features=256, bias=True)
  (relu_fc1): ReLU()
  (fc2): Linear(in_features=256, out_features=5, bias=True)
  (softmax): 

### Setting the model to MPS device

In [41]:
device = torch.device("mps:0")
# device = torch.device("cpu")
model = model.to(device) 

### Training the Model With Adam optimizer

In [42]:
loss_fn = nn.CrossEntropyLoss()  # for multi-class
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, num_epochs, train_dl, valid_dl):
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_valid = [0] * num_epochs
    accuracy_hist_valid = [0] * num_epochs

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for x_batch, y_batch in train_dl:
            x_batch = x_batch.to(device) 
            y_batch = y_batch.to(device)  # y_batch should be integer labels in {0,1,2,3,4}

            outputs = model(x_batch)  # Forward pass (output shape: (batch, 5))
            loss = loss_fn(outputs, y_batch)  # Compute loss
            
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            running_loss += loss.item() * y_batch.size(0)
            preds = torch.argmax(outputs, dim=1)
            running_corrects += (preds == y_batch).float().sum().cpu().item()

        loss_hist_train[epoch] = running_loss / len(train_dl.dataset)
        accuracy_hist_train[epoch] = running_corrects / len(train_dl.dataset)
        
        model.eval()
        running_loss_val = 0.0
        running_corrects_val = 0
        with torch.no_grad():
            for x_batch, y_batch in valid_dl:
                x_batch = x_batch.to(device) 
                y_batch = y_batch.to(device)
                outputs = model(x_batch)
                loss = loss_fn(outputs, y_batch)
                running_loss_val += loss.item() * y_batch.size(0)
                preds = torch.argmax(outputs, dim=1)
                running_corrects_val += (preds == y_batch).float().sum().cpu().item()

        loss_hist_valid[epoch] = running_loss_val / len(valid_dl.dataset)
        accuracy_hist_valid[epoch] = running_corrects_val / len(valid_dl.dataset)
        
        print(f'Epoch {epoch+1} - Train accuracy: {accuracy_hist_train[epoch]:.4f}, Val accuracy: {accuracy_hist_valid[epoch]:.4f}')

        # Early stopping: if after 5 epochs training accuracy is still below 80%, stop training.
        if epoch >= 4 and accuracy_hist_train[epoch] < 0.8:
            print(f"Early stopping at epoch {epoch+1} as training accuracy did not reach 80%.")
            break

    return loss_hist_train, loss_hist_valid, accuracy_hist_train, accuracy_hist_valid

torch.manual_seed(1)
num_epochs = 20
hist = train(model, num_epochs, train_dl, valid_dl)

Epoch 1 - Train accuracy: 0.1904, Val accuracy: 0.2176
Epoch 2 - Train accuracy: 0.2021, Val accuracy: 0.2176
Epoch 3 - Train accuracy: 0.3070, Val accuracy: 0.3472
Epoch 4 - Train accuracy: 0.3782, Val accuracy: 0.3472
Epoch 5 - Train accuracy: 0.4184, Val accuracy: 0.3990
Early stopping at epoch 5 as training accuracy did not reach 80%.
