In [1]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

import warnings
warnings.filterwarnings('ignore')


In [2]:
# Load the training data
train_df = pd.read_csv('train.csv')
print(train_df.head())

# Count of each class
print(train_df['diagnosis'].value_counts())

# Add .png extension to id_code
train_df['id_code'] = train_df['id_code'] + '.png'


        id_code  diagnosis
0  000c1434d8d7          2
1  001639a390f0          4
2  0024cdab0c1e          1
3  002c21358ce6          0
4  005b95c28852          0
diagnosis
0    1805
2     999
1     370
4     295
3     193
Name: count, dtype: int64


In [3]:
# Train/Validation split
train_df, val_df = train_test_split(train_df, test_size=0.2, stratify=train_df['diagnosis'], random_state=42)

print(f"Training size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")


Training size: 2929
Validation size: 733


In [4]:
class DRDataset(Dataset):
    def __init__(self, dataframe, image_dir, transform=None):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        image_path = os.path.join(self.image_dir, img_name)
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


In [5]:
IMAGE_SIZE = 224

transform_train = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

train_dataset = DRDataset(train_df, 'train_images', transform=transform_train)
val_dataset = DRDataset(val_df, 'train_images', transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load pretrained ResNet50
model = models.resnet50(pretrained=True)

# Modify final fully connected layer
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)

model = model.to(device)


Using device: cpu


In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


In [8]:
def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=5):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = 100 * correct / total
        print(f"[Epoch {epoch+1}] Train Loss: {running_loss/len(train_loader):.4f}, Accuracy: {train_acc:.2f}%")

        # Validation
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_acc = 100 * val_correct / val_total
        print(f"             Val Loss: {val_loss/len(val_loader):.4f}, Accuracy: {val_acc:.2f}%\n")


In [10]:
train_model(model, criterion, optimizer, train_loader, val_loader, epochs=5)


[Epoch 1] Train Loss: 0.5582, Accuracy: 78.80%
             Val Loss: 0.5590, Accuracy: 80.22%

[Epoch 2] Train Loss: 0.4165, Accuracy: 84.02%
             Val Loss: 0.5598, Accuracy: 79.95%

[Epoch 3] Train Loss: 0.3391, Accuracy: 87.74%
             Val Loss: 0.5592, Accuracy: 82.40%

[Epoch 4] Train Loss: 0.2486, Accuracy: 91.43%
             Val Loss: 0.6153, Accuracy: 80.49%

[Epoch 5] Train Loss: 0.2064, Accuracy: 92.97%
             Val Loss: 0.5903, Accuracy: 82.13%



In [9]:
# Save the trained model
model_path = "diabetic_retinopathy_mmodel.pth"
torch.save(model.state_dict(), model_path)
print(f"✅ Model saved to: {model_path}")


✅ Model saved to: diabetic_retinopathy_mmodel.pth


In [10]:
import torch
model = torch.load("diabetic_retinopathy_model.pth", map_location="cpu")
print(model)


OrderedDict([('conv1.weight', tensor([[[[ 1.3162e-02,  1.4670e-02, -1.5270e-02,  ..., -4.0713e-02,
           -4.2589e-02, -7.0095e-02],
          [ 4.2187e-03,  5.9796e-03,  1.5129e-02,  ...,  2.6248e-03,
           -2.0232e-02, -3.7756e-02],
          [ 2.2479e-02,  2.3977e-02,  1.6425e-02,  ...,  1.0327e-01,
            6.3328e-02,  5.2601e-02],
          ...,
          [-5.9344e-04,  2.8189e-02, -9.6294e-03,  ..., -1.2662e-01,
           -7.5557e-02,  8.8647e-03],
          [ 4.2611e-03,  4.8721e-02,  6.2681e-02,  ...,  2.4877e-02,
           -3.2557e-02, -1.4511e-02],
          [-7.9559e-02, -3.1504e-02, -1.7216e-02,  ...,  3.5891e-02,
            2.3519e-02,  2.8648e-03]],

         [[-1.8526e-02,  1.1547e-02,  2.4272e-02,  ...,  5.4313e-02,
            4.5394e-02, -7.7822e-03],
          [-7.9649e-03,  1.8895e-02,  6.8233e-02,  ...,  1.6044e-01,
            1.4761e-01,  1.2204e-01],
          [-4.6139e-02, -7.5812e-02, -8.9533e-02,  ...,  1.2198e-01,
            1.6868e-01,  1.7