# ConvNEXT for DR detection

## Setup

### Google Drive Access

In [121]:
from google.colab import drive
import os

# Parameters
DATASET_PATH = '/content/drive/My Drive/University Of Stirling/Dissertation/ConvNEXT/APTOS2019'
PREP_PATH = DATASET_PATH + "/preprocessed/"

# Load Dataset From Drive
drive.mount('/content/drive')

os.chdir(DATASET_PATH)
print("CWD:",os.getcwd())

if not os.path.exists(PREP_PATH):
  os.mkdir(PREP_PATH)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
CWD: /content/drive/My Drive/University Of Stirling/Dissertation/ConvNEXT/APTOS2019


### Environment Setup

In [122]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [123]:
!pip install torchinfo

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
'''%%bash

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
git clone https://github.com/facebookresearch/ConvNeXt
pip install timm==0.3.2 tensorboardX six
pip install submitit
'''

## Imports and global parameters

In [124]:
import pandas as pd
import cv2
import numpy as np
import os
from sklearn.model_selection import train_test_split
import shutil
import torch
import torchvision
from torchvision.models import ConvNeXt_Tiny_Weights
from torch import nn
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader 
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import random_split
import torchvision.transforms as transforms

In [137]:
IMG_SIZE = 224
BATCH_SIZE = 10
NUM_EPOCHS = 1
MODEL_PATH="/content/drive/My Drive/University Of Stirling/Dissertation/ConvNEXT/checkpoints/checkpoint.pth"

## Dataset Preparation

In [6]:
'''
# ref: https://www.kaggle.com/code/ratthachat/aptos-eye-preprocessing-in-diabetic-retinopathy/notebook
# ref: https://www.kaggle.com/code/ratthachat/aptos-eye-preprocessing-in-diabetic-retinopathy/notebook
# ref for circle crop: https://github.com/debayanmitra1993-data/Blindness-Detection-Diabetic-Retinopathy-/blob/master/research_paper_implementation.ipynb
def crop_image_from_gray(img,tol=7):
    if img.ndim ==2:
        mask = img>tol
        return img[np.ix_(mask.any(1),mask.any(0))]
    elif img.ndim==3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = gray_img>tol
        
        check_shape = img[:,:,0][np.ix_(mask.any(1),mask.any(0))].shape[0]
        if (check_shape == 0): # image is too dark so that we crop out everything,
            return img # return original image
        else:
            img1=img[:,:,0][np.ix_(mask.any(1),mask.any(0))]
            img2=img[:,:,1][np.ix_(mask.any(1),mask.any(0))]
            img3=img[:,:,2][np.ix_(mask.any(1),mask.any(0))]
            img = np.stack([img1,img2,img3],axis=-1)
        return img

def circle_crop(img, sigmaX = 30):   
    """
    Create circular crop around image centre    
    """    
    img = crop_image_from_gray(img)    
    
    height, width, depth = img.shape    
    
    x = int(width/2)
    y = int(height/2)
    r = np.amin((x,y))
    
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x,y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    img = crop_image_from_gray(img)
    return img 

def preprocess(id_code):
  path = DATASET_PATH + "/train_images/" + id_code + ".png"

  if(os.path.isfile(path) == False):
    print(id_code + " does not exist!")
    return

  img = cv2.imread(path)


  # Circle crop
  img = circle_crop(img)

  # Resize the image
  img = cv2.resize(img, (224, 224))

  # Extract Green Channel
  img[:,:,0] = 0
  img[:,:,2] = 0

  # Convert to Greyscale
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

  # Apply Gaussian Blur
  img = cv2.addWeighted(img,4, cv2.GaussianBlur( img , (0,0) , 512/10) ,-4 ,128)
  
  # Perform histogram equalization

  clahe = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8,8))
  img = clahe.apply(img)

  cv2.imwrite(PREP_PATH + id_code + ".png", img)


for id_code in dataset["id_code"]:
  preprocess(id_code) 
  '''

In [126]:
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

class DrDetectionDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        path = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        img = read_image(path + ".png")
        lbl = self.labels.iloc[idx, 1]
        if self.transform:
            img = self.transform(img)
        return img, lbl

global_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert('RGB')),
    # Data Augmentation
    #transforms.RandomHorizontalFlip(),
    #transforms.RandomRotation(20),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    #transforms.Normalize([0.5413734, 0.5413734, 0.5413734], [0.17313044, 0.17313044, 0.17313044])
    ])


dataset = DrDetectionDataset(DATASET_PATH + "/train.csv", PREP_PATH, transform = global_transforms)

In [127]:
#Train-test-validation split

# Set manual seed for reproducible results
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# 60% train, 20% test, 20% valid
train_size = int(0.8 * len(dataset))
test_val_size = (len(dataset) - train_size)

train_dataset, test_dataset = random_split(dataset, [train_size, test_val_size])

train_size = int(len(train_dataset) - test_val_size)
train_dataset, valid_dataset = random_split(train_dataset, [train_size, test_val_size])

print("Train size: ", len(train_dataset))
print("Val size: ", len(valid_dataset))
print("Test size: ", len(test_dataset))

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

Train size:  2196
Val size:  733
Test size:  733


In [None]:
'''
# Use a Weighted Sampler to fix class imbalances

class_count = np.zeros(5)
for image, label in train_dataset:
  class_count[label] +=1

class_weights = 1. / class_count
weights = np.array([])

for image, label in train_dataset:
  weights = np.append(weights, class_weights[label])
  
weights = torch.from_numpy(weights)
weighted_sampler = WeightedRandomSampler(weights=weights,num_samples=len(train_dataset),replacement=True)

train_loader = DataLoader(train_dataset, batch_size=4, sampler = weighted_sampler)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)
'''

# Training

## Get Model

In [128]:
#https://www.learnpytorch.io/06_pytorch_transfer_learning/
from torchinfo import summary

# Download pre-trained weights on IMAGENET
weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1

model = torchvision.models.convnext_tiny(weights = weights).to(device)
print("Base Model: ")
print(summary(model=model, 
        input_size=(BATCH_SIZE, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]) )

Base Model: 
Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
ConvNeXt (ConvNeXt)                                     [10, 3, 224, 224]    [10, 1000]           --                   True
├─Sequential (features)                                 [10, 3, 224, 224]    [10, 768, 7, 7]      --                   True
│    └─Conv2dNormActivation (0)                         [10, 3, 224, 224]    [10, 96, 56, 56]     --                   True
│    │    └─Conv2d (0)                                  [10, 3, 224, 224]    [10, 96, 56, 56]     4,704                True
│    │    └─LayerNorm2d (1)                             [10, 96, 56, 56]     [10, 96, 56, 56]     192                  True
│    └─Sequential (1)                                   [10, 96, 56, 56]     [10, 96, 56, 56]     --                   True
│    │    └─CNBlock (0)                                 [10, 96, 56, 56]     [10, 96, 56, 56]     79,296          

In [129]:
from torchinfo import summary

# Tune the model parameters to fit our domain
# Sets Feature extraction layers as not trainable
for param in model.features.parameters():
    param.requires_grad = False
    
model.classifier[2] = torch.nn.Linear(in_features=768, out_features=5, bias=True)

print(summary(model=model, 
        input_size=(32, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]))

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
ConvNeXt (ConvNeXt)                                     [32, 3, 224, 224]    [32, 5]              --                   Partial
├─Sequential (features)                                 [32, 3, 224, 224]    [32, 768, 7, 7]      --                   False
│    └─Conv2dNormActivation (0)                         [32, 3, 224, 224]    [32, 96, 56, 56]     --                   False
│    │    └─Conv2d (0)                                  [32, 3, 224, 224]    [32, 96, 56, 56]     (4,704)              False
│    │    └─LayerNorm2d (1)                             [32, 96, 56, 56]     [32, 96, 56, 56]     (192)                False
│    └─Sequential (1)                                   [32, 96, 56, 56]     [32, 96, 56, 56]     --                   False
│    │    └─CNBlock (0)                                 [32, 96, 56, 56]     [32, 96, 56, 56]     (79,296)             

## Train

In [132]:
#https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loop through each epoch
for epoch in range(NUM_EPOCHS):
  running_loss = 0.0

  for i, data in enumerate(train_loader, 0):
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data 
    inputs, labels = inputs.to(device), labels.to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = model(inputs)
    loss = loss_fn(outputs, labels)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 2000 == 1999:    # print every 2000 mini-batches
        print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
        running_loss = 0.0

torch.save(model.state_dict(), MODEL_PATH)


In [144]:
dataiter = iter(test_loader)
inputs, labels = dataiter.next()
inputs, labels = inputs.to(device), labels.to(device)

print('GroundTruth: ', ' '.join(f'{labels[j]}' for j in range(4))) 
model.load_state_dict(torch.load(MODEL_PATH))
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{predicted[j]}'
                              for j in range(4)))

GroundTruth:  0 4 0 0
Predicted:  0 2 0 0


In [145]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 733 test images: {100 * correct // total} %')

Accuracy of the network on the 733 test images: 73 %
