# Imports

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.optim as optim
import torch.nn as nn

from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Choose the compute device: prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS device is available.")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA device is available.")
else:
    device = torch.device("cpu")
    print("No MPS or CUDA device available, using CPU.")

# Helper Functions

In [3]:
def crop_and_pad_to_square(image, threshold=25):
    """Crop an image to the bounding box of its bright pixels, then zero-pad it to a square.

    The image is reduced to a channel-mean grayscale map scaled to 0-255. Every pixel
    whose scaled value exceeds ``threshold`` counts as foreground. The image is cropped
    to the tight (inclusive) bounding box of the foreground. The shorter side is then
    zero-padded symmetrically, with any odd extra pixel going to the bottom/right, so
    the result is square.

    Args:
        image: A ``(C, H, W)`` tensor, or anything ``TF.to_tensor`` accepts (e.g. a PIL
            image). Non-tensor inputs are converted with ``TF.to_tensor``.
        threshold: Foreground cut-off on the 0-255 grayscale scale (default 25).

    Returns:
        A ``(C, S, S)`` tensor, where ``S`` is the longer side of the bounding box.
        If no pixel exceeds ``threshold``, the (converted) input is returned unchanged.
    """
    if not isinstance(image, torch.Tensor):
        image = TF.to_tensor(image)

    # Grayscale = per-pixel mean over channels, quantised to 0-255 like an 8-bit image.
    # NOTE(review): assumes a float tensor scaled to [0, 1]; an integer tensor would
    # need converting to float first.
    gray_uint8 = (torch.mean(image, dim=0, keepdim=True).float() * 255.0).byte()
    foreground = gray_uint8[0] > threshold

    coords = torch.nonzero(foreground)  # one (row, col) pair per foreground pixel
    if coords.size(0) == 0:
        return image

    # Slice ends are exclusive, so +1 keeps the last foreground row/column
    # (and a single foreground pixel yields a 1x1 crop rather than an empty one).
    top, left = (int(v) for v in coords.min(dim=0).values)
    bottom, right = (int(v) + 1 for v in coords.max(dim=0).values)
    cropped = image[:, top:bottom, left:right]

    height, width = bottom - top, right - left
    diff = abs(height - width)
    before, after = diff // 2, (diff + 1) // 2
    if width > height:
        # F.pad lists the last dim first: (left, right, top, bottom).
        return F.pad(cropped, (0, 0, before, after))
    if height > width:
        return F.pad(cropped, (before, after, 0, 0))
    return cropped

class CropAndPadToSquare:
    """Callable transform that delegates to ``crop_and_pad_to_square``.

    Wrapping the function in a class lets it be placed inside a ``transforms.Compose``
    pipeline alongside the torchvision transforms.
    """

    def __call__(self, image):
        squared = crop_and_pad_to_square(image)
        return squared

In [4]:
class EarlyStopping:
    """Stop training once validation loss stops improving, checkpointing the best model.

    Call the instance once per epoch with the current validation loss. When the loss
    improves (is at most ``best_loss - delta``), the model's ``state_dict`` is saved to
    ``path`` and the patience counter resets. Otherwise the counter increments, and
    ``early_stop`` becomes ``True`` after ``patience`` consecutive non-improving epochs.

    Args:
        patience: Consecutive non-improving epochs tolerated before stopping.
        delta: Minimum decrease in loss required to count as an improvement.
        verbose: If True, print a message on every counter increment and checkpoint.
        path: File the best model's ``state_dict`` is written to with ``torch.save``.
    """

    def __init__(self, patience=5, delta=0, verbose=False, path='checkpoint.pth'):
        self.patience = patience
        self.delta = delta
        self.verbose = verbose
        self.best_loss = None    # best validation loss so far; None until the first call
        self.counter = 0         # consecutive epochs without improvement
        self.early_stop = False  # set to True once `counter` reaches `patience`
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        # Phrased as "improved" so that a NaN loss (for which every comparison is False)
        # counts as *no* improvement instead of overwriting the best checkpoint.
        improved = self.best_loss is None or val_loss <= self.best_loss - self.delta
        if improved:
            # Save before updating best_loss so the verbose message shows old --> new.
            self.save_checkpoint(val_loss, model)
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Save the model when validation loss decreases.'''
        if self.verbose:
            # No previous best on the first call; report it as infinity.
            previous = float('inf') if self.best_loss is None else self.best_loss
            print(f'Validation loss decreased ({previous:.6f} --> {val_loss:.6f}).  Saving model...')
        torch.save(model.state_dict(), self.path)

## Create Dataset

In [5]:
class OcularDiseaseDataset(Dataset):
    """Retinal images from a label CSV, each paired with a single integer class index.

    Each CSV row provides an ``ID`` column, and the image is read from
    ``<img_dir>/<ID>.png``. Each row also has one column per condition in
    ``label_columns``. The class index is the position of the largest value among those
    columns; on ties the first wins, following ``torch.argmax``.

    Args:
        img_dir: Directory containing the ``<ID>.png`` images.
        csv_file: Path to the label CSV, read with ``pandas.read_csv``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        label_columns: Ordered condition columns that define the class indices.
            Defaults to ``LABEL_COLUMNS``.
    """

    LABEL_COLUMNS = ('DR', 'MH', 'ODC', 'TSLN', 'DN', 'MYA', 'ARMD')

    def __init__(self, img_dir, csv_file, transform=None, label_columns=None):
        self.img_dir = img_dir
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.label_columns = list(self.LABEL_COLUMNS if label_columns is None else label_columns)

    def __len__(self):
        return len(self.df)

    def _label_from_row(self, row):
        """Return the class index (argmax over ``label_columns``) for one CSV row."""
        scores = torch.tensor(row[self.label_columns].to_numpy(dtype='float32'))
        # NOTE(review): a row with no positive condition maps to class 0 (argmax of zeros).
        return torch.argmax(scores).item()

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for row ``idx``, applying ``transform`` if set."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, f"{row['ID']}.png")

        # convert() loads the pixels; the context manager then closes the file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label = self._label_from_row(row)

        if self.transform:
            image = self.transform(image)

        return image, label

# Load Dataset

In [6]:
BATCH_SIZE = 64
DATA_ROOT = "../../data/RFMiD"

# Square-crop each image, then resize to 512x512 so it matches the 3*512*512 flattened
# input size the classifier's first layer is built for.
transform = transforms.Compose([
    CropAndPadToSquare(),
    transforms.Resize((512, 512)),
])


def _make_split(split):
    """Build the dataset for one split name (``Train``, ``Validation`` or ``Test``)."""
    return OcularDiseaseDataset(
        img_dir=f"{DATA_ROOT}/img/{split}",
        csv_file=f"{DATA_ROOT}/labels/Filtered_{split}.csv",
        transform=transform,
    )


train_dataset = _make_split("Train")
validation_dataset = _make_split("Validation")
test_dataset = _make_split("Test")

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Classifier

In [7]:
class SimpleClassifier(nn.Module):
    """Two-layer fully connected classifier over flattened inputs.

    Every dimension after the batch dimension is flattened. The result passes through
    ``Linear(in_features, hidden_features) -> ReLU -> Linear(hidden_features, num_classes)``.
    The output is raw logits.

    Args:
        num_classes: Number of output logits.
        in_features: Size of one flattened input. The default corresponds to a 3x512x512
            image, which gives fc1 about 403M weights.
        hidden_features: Width of the hidden layer.
    """

    def __init__(self, num_classes=7, in_features=3 * 512 * 512, hidden_features=512):
        super().__init__()
        # Attribute names (fc1 / relu / fc2) are unchanged so saved state_dicts still load.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        # torch.flatten, unlike view, also accepts non-contiguous inputs.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

In [8]:
# Classifier on the selected device, cross-entropy loss over its logits, Adam optimizer.
NUM_CLASSES = 7
LEARNING_RATE = 0.001

model = SimpleClassifier(num_classes=NUM_CLASSES)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
EPOCH = 100
best_mdl_path = '../../save_models/best_model.pth'
# torch.save does not create missing directories, so ensure the checkpoint folder exists.
os.makedirs(os.path.dirname(best_mdl_path), exist_ok=True)

early_stopping = EarlyStopping(patience=5, verbose=True, path=best_mdl_path)

for epoch in range(EPOCH):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    train_seen = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion (default reduction) returns a per-batch mean; weight it by batch size
        # so a smaller final batch does not skew the epoch average.
        running_loss += loss.item() * labels.size(0)
        train_seen += labels.size(0)

    print(f"Epoch [{epoch+1}/{EPOCH}], Loss: {running_loss/train_seen}")

    # ---- Validation phase ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in validation_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * labels.size(0)

            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    avg_val_loss = val_loss / total

    print(f'Validation Loss: {avg_val_loss}, Validation Accuracy: {val_accuracy:.2f}%')

    # Checkpoint on improvement; stop after `patience` epochs without one.
    early_stopping(avg_val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break

# Testing 

In [None]:
# Restore the best checkpoint saved during training and evaluate it on the test split.
# map_location lets a checkpoint written on one device (e.g. MPS) load on the active one;
# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs.
model.load_state_dict(torch.load(best_mdl_path, map_location=device, weights_only=True))

model.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:  # Test batch
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Weight the per-batch mean loss by batch size so a smaller final batch counts fairly.
        test_loss += loss.item() * labels.size(0)

        # Predicted class = index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print(f'Test Loss: {test_loss / total}, Test Accuracy: {test_accuracy:.2f}%')