# Imports

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.optim as optim
import torch.nn as nn

from PIL import Image,ImageOps
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
# Pick the compute device once: the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "CUDA device is available."
    if device.type == "cuda"
    else "CUDA device not available, using CPU."
)


CUDA device is available.


# Helper Functions

## Crop and pad to square

In [None]:
def crop_and_pad_to_square(image):
    """Crop a channel-first image to its bright region and zero-pad it to a square.

    The foreground is every pixel whose channel-mean intensity exceeds 0.1.
    The image is cropped to the tight bounding box of that foreground
    (inclusive of the last foreground row and column), and the shorter side
    is then padded with zeros so that height == width. Padding is split as
    evenly as possible; when the difference is odd the extra pixel goes to
    the bottom / right.

    Args:
        image: Tensor indexed as (channels, height, width).

    Returns:
        A (channels, side, side) tensor with side = max(box height, box width),
        on the same device as ``image``. If no pixel exceeds the threshold the
        input is returned unchanged.
    """
    # Channel-mean "grayscale" view, thresholded into a foreground mask.
    # Everything stays on the image's own device; no host/GPU round trip.
    foreground = image.mean(dim=0) > 0.1

    if not foreground.any():
        return image

    # Rows / columns that contain at least one foreground pixel.
    rows = torch.nonzero(foreground.any(dim=1)).flatten()
    cols = torch.nonzero(foreground.any(dim=0)).flatten()

    # Slice ends are exclusive, so add 1 to keep the last foreground row/col.
    top, bottom = int(rows[0]), int(rows[-1]) + 1
    left, right = int(cols[0]), int(cols[-1]) + 1
    cropped_image = image[:, top:bottom, left:right]

    # Only the shorter side is padded; the longer side gets zero padding.
    height, width = bottom - top, right - left
    side = max(height, width)
    pad_w, pad_h = side - width, side - height
    pad_left, pad_top = pad_w // 2, pad_h // 2

    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(
        cropped_image,
        (pad_left, pad_w - pad_left, pad_top, pad_h - pad_top),
        mode='constant',
        value=0,
    )


class CropAndPadToSquare:
    """Callable adapter so `crop_and_pad_to_square` can sit in a transform pipeline."""

    def __call__(self, image):
        # Delegate to the module-level function; this object holds no state.
        squared = crop_and_pad_to_square(image)
        return squared
    




## Early stopping

In [None]:
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the epoch's validation loss and the
    model. A loss counts as an improvement when it is at most
    ``best_loss - delta``; on improvement the model's ``state_dict`` is saved
    to ``path`` and the patience counter is reset. Otherwise the counter is
    incremented, and ``early_stop`` becomes True once it reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Margin the loss must beat the best loss by to stay "not worse".
        verbose: If True, print counter updates and checkpoint messages.
        path: Destination passed to ``torch.save`` for the best checkpoint.
    """

    def __init__(self, patience=5, delta=0, verbose=False, path='checkpoint.pth'):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = None  # lowest validation loss seen so far
        self.counter = 0  # consecutive calls without improvement
        self.early_stop = False
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        # A NaN loss fails the `<=` comparison, so it is counted as "no
        # improvement" instead of overwriting best_loss and disabling the
        # stopper for the rest of the run.
        improved = self.best_loss is None or val_loss <= self.best_loss - self.delta

        if improved:
            # Save BEFORE updating best_loss so the verbose message can show
            # the previous best next to the new one.
            self.save_checkpoint(val_loss, model)
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Save the model when validation loss decreases.'''
        if self.verbose:
            # No previous best on the very first call; show it as infinity.
            previous = float('inf') if self.best_loss is None else self.best_loss
            print(f'Validation loss decreased ({previous:.6f} --> {val_loss:.6f}).  Saving model...')

        # Create the checkpoint's parent directory on demand so the first save
        # does not fail on a fresh checkout. Skipped for file-like targets.
        if isinstance(self.path, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(self.path))
            if parent:
                os.makedirs(parent, exist_ok=True)

        torch.save(model.state_dict(), self.path)

## Create Dataset

In [None]:
class OcularDiseaseDataset(Dataset):
    """Image dataset driven by a label CSV.

    Each CSV row supplies an ``ID`` (the image is ``<img_dir>/<ID>.png``) and
    the columns DR, MH, ODC, TSLN, DN, MYA and ARMD. A sample's label is the
    position of the largest value among those seven columns.
    NOTE(review): presumably each row flags exactly one condition; confirm
    against the CSVs, since an all-zero row would yield label 0.
    """

    def __init__(self, img_dir, csv_file, transform=None, device= device):
        """Read the label table and remember where the images live.

        Args:
            img_dir: Directory containing the ``<ID>.png`` files.
            csv_file: Path of the CSV read with ``pd.read_csv``.
            transform: Optional callable applied to the loaded PIL image.
            device: Stored on the instance; not used by this class itself.
        """
        self.img_dir = img_dir
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.device = device

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``; ``label`` is a plain int."""
        record = self.df.iloc[idx]

        # Images are stored as <ID>.png and always decoded as 3-channel RGB.
        img_path = os.path.join(self.img_dir, f"{record['ID']}.png")
        image = Image.open(img_path).convert("RGB")

        # Collapse the seven condition columns into a single class index.
        condition_flags = record[['DR', 'MH', 'ODC', 'TSLN', 'DN', 'MYA', 'ARMD']].values
        label_vector = torch.tensor(list(condition_flags), dtype=torch.float32)
        multi_class_label = label_vector.argmax().item()

        if self.transform:
            image = self.transform(image)

        return image, multi_class_label

# Load Dataset

In [None]:
BATCH_SIZE = 64

# Preprocessing shared by all three splits: tensor conversion, crop/pad to a
# square, resize to 512x512, then per-channel normalization.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        CropAndPadToSquare(),
        transforms.Resize((512, 512)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One dataset per split, each pairing an image folder with its label CSV.
train_dataset = OcularDiseaseDataset(
    img_dir="../../data/RFMiD/img/Train",
    csv_file="../../data/RFMiD/labels/Filtered_Train.csv",
    transform=transform,
)
validation_dataset = OcularDiseaseDataset(
    img_dir="../../data/RFMiD/img/Validation",
    csv_file="../../data/RFMiD/labels/Filtered_Validation.csv",
    transform=transform,
)
test_dataset = OcularDiseaseDataset(
    img_dir="../../data/RFMiD/img/Test",
    csv_file="../../data/RFMiD/labels/Filtered_Test.csv",
    transform=transform,
)

# Only the training loader shuffles; validation and test keep file order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Classifier

In [None]:
class SimpleConvClassifier(nn.Module):
    """Two conv/pool stages followed by a two-layer fully connected head.

    The first linear layer is sized for 64 channels at 128x128, i.e. a
    512x512 input after two 2x2 max-pools that each halve height and width.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # Convolutions use 3x3 kernels with padding 1, so they keep H and W;
        # only the shared 2x2 pool changes the spatial size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

        flattened_features = 64 * 128 * 128
        self.fc1 = nn.Linear(flattened_features, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        stage1 = self.pool(self.relu(self.conv1(x)))
        stage2 = self.pool(self.relu(self.conv2(stage1)))

        # Flatten everything except the batch dimension for the linear head.
        flat = stage2.view(stage2.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return logits

In [None]:
# Build the network on the selected device, then pair it with a
# cross-entropy loss and an Adam optimizer over all of its parameters.
model = SimpleConvClassifier(num_classes=7)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
EPOCH = 100
best_mdl_path = '../../save_models/best_model.pth'

# Keeps the best-so-far weights at best_mdl_path and flags when to stop.
early_stopping = EarlyStopping(patience=5, verbose=True, path=best_mdl_path)

for epoch in range(EPOCH):
    # ---------------- training pass ----------------
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Reported loss is the mean of the per-batch losses.
    print(f"Epoch [{epoch+1}/{EPOCH}], Loss: {running_loss/len(train_loader)}")

    # ---------------- validation pass ----------------
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

            # Predicted class is the index of the largest logit per sample.
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / len(validation_loader)

    print(f'Validation Loss: {avg_val_loss}, Validation Accuracy: {val_accuracy:.2f}%')

    # Early stopping is driven by the mean validation loss, not accuracy.
    early_stopping(avg_val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break

# Testing 

In [None]:
# Restore the best checkpoint written during training. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved on a GPU still
# loads on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(best_mdl_path, map_location=device))

model.eval()
test_loss = 0.0
correct = 0
total = 0

# Evaluation only: no gradients are needed, which saves memory and time.
with torch.no_grad():
    for images, labels in test_loader:  # Test batch
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Predicted class is the index of the largest logit per sample.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Loss is averaged over batches; accuracy is a percentage over all samples.
test_accuracy = 100 * correct / total
print(f'Test Loss: {test_loss / len(test_loader)}, Test Accuracy: {test_accuracy:.2f}%')