In [27]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
DATA_ROOT = '/kaggle/input/retinal-disease-classification'

train_dir = os.path.join(DATA_ROOT, 'Training_Set', 'Training_Set')
val_dir = os.path.join(DATA_ROOT, 'Evaluation_Set', 'Evaluation_Set')
test_dir = os.path.join(DATA_ROOT, 'Test_Set', 'Test_Set')

train_img_paths = glob.glob(os.path.join(train_dir, 'Training', '*.png'))
train_label_path = os.path.join(train_dir, 'RFMiD_Training_Labels.csv')

val_img_paths = glob.glob(os.path.join(val_dir, 'Validation', '*.png'))
val_label_path = os.path.join(val_dir, 'RFMiD_Validation_Labels.csv')

test_img_paths = glob.glob(os.path.join(test_dir, 'Test', '*.png'))
test_label_path = os.path.join(test_dir, 'RFMiD_Testing_Labels.csv')

len(train_img_paths), len(val_img_paths), len(test_img_paths)

(1920, 640, 640)

In [3]:
train_label_df = pd.read_csv(train_label_path)
val_label_df = pd.read_csv(val_label_path)
test_label_df = pd.read_csv(test_label_path)

print(f'Num. train labels: {len(train_label_df)}')
print(f'Num. val labels: {len(val_label_df)}')
print(f'Num. test labels: {len(test_label_df)}')

Num. train labels: 1920
Num. val labels: 640
Num. test labels: 640


In [11]:
from torch.utils.data import Dataset
from PIL import Image

def path2id(img_path):
    return int(os.path.splitext(os.path.basename(img_path))[0])

class RetinaDataset(Dataset):
    
    def __init__(self, img_paths, label_csv_path, transform=None):
        self.img_paths = sorted(img_paths)
        self.label_df = pd.read_csv(label_csv_path).drop(columns=['ID', 'Disease_Risk', 'HR', 'ODPM'])
        self.transform = transform
    
    def __len__(self): 
        return len(self.img_paths)
    
    def __getitem__(self, id):
        img_path = self.img_paths[id]
        img_id = path2id(img_path)
        img = Image.open(img_path)
        label = np.array(self.label_df.iloc[img_id - 1]) # Index to frame is image ID - 1, see the note earlier
        if self.transform is not None:
            img = self.transform(img)
            label = torch.from_numpy(label).float()
        else:
            img = np.array(img)
        return img, label

In [12]:
# test the dataset correctness, without transform
ID = 0

data = RetinaDataset(train_img_paths, train_label_path)

img, label = data[ID]
img.shape, img.dtype, label.shape, label.dtype

((1424, 2144, 3), dtype('uint8'), (43,), dtype('int64'))

In [15]:
import torchvision.transforms.v2 as transforms

img_path = train_img_paths[8]
img = Image.open(img_path)

tf1 = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
])

tf2 = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomAdjustSharpness(2, 1),
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.CenterCrop(224),
])

img1, img2 = tf1(img), tf2(img)

In [16]:
import torchvision.transforms.v2 as transforms

PRET_MEANS = [0.485, 0.456, 0.406]
PRET_STDS = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomAdjustSharpness(2, 0.8),
    transforms.RandomRotation(180),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.CenterCrop(224),
    transforms.ToImageTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=PRET_MEANS, std=PRET_STDS)
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToImageTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=PRET_MEANS, std=PRET_STDS)
])

In [17]:
train_data = RetinaDataset(train_img_paths, train_label_path, train_transform)
val_data = RetinaDataset(val_img_paths, val_label_path, test_transform)
test_data = RetinaDataset(test_img_paths, test_label_path, test_transform)

img, label = train_data[ID]
img.shape, img.dtype, label.shape, label.dtype

(torch.Size([3, 224, 224]), torch.float32, torch.Size([43]), torch.float32)

In [19]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64
# WEIGHT_DECAY = 1e-5
N_WORKERS = os.cpu_count()

train_loader = DataLoader(train_data,
                           batch_size=BATCH_SIZE,
                           shuffle=True,
                           num_workers=N_WORKERS,
                           pin_memory=True)

val_loader = DataLoader(val_data,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=N_WORKERS,
                         pin_memory=True)

test_loader = DataLoader(test_data,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=N_WORKERS,
                         pin_memory=True)

In [20]:
from torchvision import models
from torch import nn

model = models.convnext_tiny(weights='IMAGENET1K_V1')

# replace final layer in classifier
in_final = model.classifier[-1].in_features
OUT_FINAL = 43
model.classifier[-1] = nn.Linear(in_final, OUT_FINAL)
model.classifier

Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to /root/.cache/torch/hub/checkpoints/convnext_tiny-983f1562.pth
100%|██████████| 109M/109M [00:00<00:00, 210MB/s]  


Sequential(
  (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=768, out_features=43, bias=True)
)

# Training model

In [23]:
import torch.optim as optim

LR_FOUND = 2e-3

loss_fn = nn.BCEWithLogitsLoss().to(device)
model = model.to(device)

# using smaller learning rate for feature extractor compared to classifier
lr_params = [
    {'params': model.features.parameters(), 'lr': LR_FOUND / 10},
    {'params': model.classifier.parameters(), 'lr': LR_FOUND}
]

optimizer = optim.AdamW(lr_params)

In [25]:
def accuracy(y_h, y):
    y_h = (y_h > 0.5).float()
    return (y_h == y).float().mean(dim=1).mean()

def train_epoch(model, loader, loss_fn, optimizer, device, epoch):
    model.train()
    lossi, acci = [], []

    for x, y in tqdm(loader, desc=f'Epoch {epoch}'):
        x, y = x.to(device), y.to(device)
        y_h = model(x)
        loss = loss_fn(y_h, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossi.append(loss.item())
        
        y_h, y = y_h.detach().cpu(), y.detach().cpu()
        acc = accuracy(y_h, y)
        acci.append(acc)

    loss = torch.tensor(lossi).mean().item()
    acc = torch.tensor(acci).mean().item()
    return {'loss': loss, 'acc': acc}

def eval(model, loader, loss_fn, device):
    model.eval()
    lossi, acci = [], []

    with torch.no_grad():
        for x, y in tqdm(loader, desc='Evaluation'):
            x, y = x.to(device), y.to(device)
            y_h = model(x)
            loss = loss_fn(y_h, y)
            lossi.append(loss.item())
            
            y_h, y = y_h.detach().cpu(), y.detach().cpu()
            acc = accuracy(y_h, y)
            acci.append(acc)

    loss = torch.tensor(lossi).mean().item()
    acc = torch.tensor(acci).mean().item()
    return {'loss': loss, 'acc': acc}

In [29]:
N_EPOCHS = 10
best_loss = float('inf')

train_resi, val_resi = [], []
BEST_MODEL = ''

no_improvement_count = 0
max_no_improvement = 3  # Set the threshold for consecutive epochs without improvement

for epoch in range(N_EPOCHS):
    train_res = train_epoch(model, train_loader, loss_fn, optimizer, device, epoch+1)
    val_res = eval(model, val_loader, loss_fn, device)
    
    print('- Train')
    print(f"  Loss: {train_res['loss']: 3.4f} | "
          f"Accuracy: {train_res['acc']: .3f}")

    print('- Validation')
    print(f"  Loss: {val_res['loss']: 3.4f} | "
          f"Accuracy: {val_res['acc']: .3f}")

    val_loss = val_res['loss']
    
    if val_loss < best_loss:
        best_loss = val_loss
        BEST_MODEL = f'convnext_cp_{epoch+1}.pth'
        torch.save(model.state_dict(), BEST_MODEL)
        print('* Current best loss. Saved model!')
        no_improvement_count = 0  # Reset the counter since there's an improvement
    else:
        no_improvement_count += 1

    if no_improvement_count >= max_no_improvement:
        print(f'No improvement in validation loss for {max_no_improvement} consecutive epochs. Stopping training.')
        break

    train_resi.append(train_res)
    val_resi.append(val_res)

Epoch 1: 100%|██████████| 30/30 [01:58<00:00,  3.93s/it]
Evaluation: 100%|██████████| 10/10 [00:38<00:00,  3.88s/it]


- Train
  Loss:  0.0830 | Accuracy:  0.976
- Validation
  Loss:  0.0750 | Accuracy:  0.979
* Current best loss. Saved model!


Epoch 2: 100%|██████████| 30/30 [01:52<00:00,  3.74s/it]
Evaluation: 100%|██████████| 10/10 [00:34<00:00,  3.42s/it]


- Train
  Loss:  0.0653 | Accuracy:  0.979
- Validation
  Loss:  0.0656 | Accuracy:  0.980
* Current best loss. Saved model!


Epoch 3: 100%|██████████| 30/30 [01:52<00:00,  3.75s/it]
Evaluation: 100%|██████████| 10/10 [00:35<00:00,  3.55s/it]


- Train
  Loss:  0.0533 | Accuracy:  0.982
- Validation
  Loss:  0.0590 | Accuracy:  0.981
* Current best loss. Saved model!


Epoch 4: 100%|██████████| 30/30 [01:53<00:00,  3.79s/it]
Evaluation: 100%|██████████| 10/10 [00:33<00:00,  3.39s/it]


- Train
  Loss:  0.0448 | Accuracy:  0.984
- Validation
  Loss:  0.0578 | Accuracy:  0.981
* Current best loss. Saved model!


Epoch 5: 100%|██████████| 30/30 [01:51<00:00,  3.73s/it]
Evaluation: 100%|██████████| 10/10 [00:34<00:00,  3.49s/it]


- Train
  Loss:  0.0393 | Accuracy:  0.986
- Validation
  Loss:  0.0509 | Accuracy:  0.984
* Current best loss. Saved model!


Epoch 6: 100%|██████████| 30/30 [01:52<00:00,  3.75s/it]
Evaluation: 100%|██████████| 10/10 [00:34<00:00,  3.49s/it]


- Train
  Loss:  0.0342 | Accuracy:  0.988
- Validation
  Loss:  0.0501 | Accuracy:  0.984
* Current best loss. Saved model!


Epoch 7: 100%|██████████| 30/30 [01:51<00:00,  3.71s/it]
Evaluation: 100%|██████████| 10/10 [00:34<00:00,  3.44s/it]


- Train
  Loss:  0.0287 | Accuracy:  0.990
- Validation
  Loss:  0.0538 | Accuracy:  0.983


Epoch 8: 100%|██████████| 30/30 [01:52<00:00,  3.74s/it]
Evaluation: 100%|██████████| 10/10 [00:34<00:00,  3.46s/it]


- Train
  Loss:  0.0239 | Accuracy:  0.991
- Validation
  Loss:  0.0555 | Accuracy:  0.984


Epoch 9: 100%|██████████| 30/30 [01:50<00:00,  3.68s/it]
Evaluation: 100%|██████████| 10/10 [00:35<00:00,  3.54s/it]

- Train
  Loss:  0.0209 | Accuracy:  0.992
- Validation
  Loss:  0.0561 | Accuracy:  0.983
No improvement in validation loss for 3 consecutive epochs. Stopping training.





In [37]:
test_res = eval(model, test_loader, loss_fn, device)

Evaluation: 100%|██████████| 10/10 [00:39<00:00,  3.95s/it]


In [38]:
print('- Test')
print(f"  Loss: {test_res['loss']: 3.4f} | "
      f"Accuracy: {test_res['acc']: .3f}")

- Test
  Loss:  0.0462 | Accuracy:  0.985
