In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Set device
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Define data directories
# The dataset root defaults to the kagglehub cache location used originally,
# but can be redirected without editing the notebook by setting the
# CATARACT_DATA_DIR environment variable before starting the kernel.
data_dir = os.environ.get(
    'CATARACT_DATA_DIR',
    '/home/jovyan/.cache/kagglehub/datasets/nandanp6/cataract-image-dataset/versions/3/processed_images/',
)

In [None]:
# Data transformations with augmentation
# Training pipeline: resize to 224x224, apply random flip / rotation /
# brightness-contrast jitter, then convert to a tensor and normalise.
# NOTE(review): a single mean/std value of 0.5 is used; torchvision's
# pretrained backbones are usually paired with ImageNet statistics --
# confirm this choice is intentional.
transform_train = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Validation pipeline: same resize and normalisation, no random augmentation.
transform_val = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Load dataset
# Two ImageFolder views over the same directory, one per transform.
# Subsets produced by a split all point at the SAME underlying dataset
# object, so overwriting `.dataset.transform` on the validation subset would
# also strip the augmentation from the training subset. Separate views keep
# the two pipelines independent.
train_dir = os.path.join(data_dir, 'train')
full_dataset = datasets.ImageFolder(train_dir, transform=transform_train)
val_view = datasets.ImageFolder(train_dir, transform=transform_val)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Seeded permutation so the 80/20 split is identical on every
# Restart & Run All (no validation samples leaking into training across runs).
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()

# Disjoint index sets: the first 80% train (augmented), the rest validate (plain).
train_dataset = torch.utils.data.Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(val_view, shuffled_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [None]:
# Define a Vanilla CNN
def build_cnn():
    """Build a small two-block convolutional binary classifier.

    The network is two (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by
    a two-layer fully connected head. The first linear layer is sized for
    ``64 * 56 * 56`` features, i.e. 3-channel 224x224 inputs after the two
    halving pools.

    Returns:
        nn.Sequential: model producing one raw logit per sample, shape
        ``(batch, 1)``. No sigmoid is applied here: the notebook trains with
        ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` itself when
        turning outputs into probabilities, so a sigmoid inside the model
        would be applied twice.
    """
    return nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2, 2),
        nn.Flatten(),
        nn.Linear(64 * 56 * 56, 128),
        nn.ReLU(),
        # Final layer emits logits, matching the pretrained heads in this file.
        nn.Linear(128, 1),
    )

def get_pretrained_model(model_name):
    """Load a pretrained torchvision backbone with a single-logit head.

    Args:
        model_name: one of ``"vgg16"``, ``"mobilenetv2"`` or ``"resnet50"``.

    Returns:
        The model with its final classification layer replaced by a
        ``nn.Linear(..., 1)`` layer, moved to the notebook's global ``device``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    supported = ("vgg16", "mobilenetv2", "resnet50")
    # Fail early with a readable message instead of an UnboundLocalError on
    # `model` further down (and before any weights are fetched).
    if model_name not in supported:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {supported}"
        )

    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument -- confirm the installed version before
    # switching.
    if model_name == "vgg16":
        model = models.vgg16(pretrained=True)
        model.classifier[6] = nn.Linear(4096, 1)
    elif model_name == "mobilenetv2":
        model = models.mobilenet_v2(pretrained=True)
        model.classifier[1] = nn.Linear(model.last_channel, 1)
    else:  # "resnet50"
        model = models.resnet50(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 1)
    model = model.to(device)
    return model

In [None]:
def _plot_history(train_values, val_values, metric):
    """Plot one per-epoch training/validation curve pair labelled with ``metric``."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_values, label=f'Training {metric}')
    plt.plot(val_values, label=f'Validation {metric}')
    plt.legend()
    plt.title(f'{metric} per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.show()


def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=10, patience=3):
    """Train a single-logit binary classifier with early stopping.

    After every epoch the model is scored with ``evaluate_model`` on
    ``val_loader``. Whenever the validation loss improves, the weights are
    written to ``best_model.pth``; training stops once the loss has failed to
    improve for ``patience`` consecutive epochs. When training ends, the best
    saved weights are loaded back into ``model`` so that later evaluation uses
    the checkpoint that early stopping selected rather than the last epoch.
    Finally the loss and accuracy curves are plotted.

    Args:
        model: network returning one logit per sample.
        criterion: loss taking ``(logits, float_labels)``; used for the
            training loss only (validation loss comes from ``evaluate_model``).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        epochs: maximum number of epochs.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        None. ``model`` is updated in place.
    """
    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []
    best_val_loss = float('inf')
    patience_counter = 0
    checkpoint_saved = False

    # Move batches to wherever the model lives; fall back to the notebook's
    # global device for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, (images, labels) in enumerate(progress_bar, start=1):
            images, labels = images.to(run_device), labels.to(run_device).float()
            optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() collapses a final
            # batch of size 1 to a 0-d tensor, which no longer matches the
            # (1,)-shaped labels and makes the loss raise.
            outputs = model(images).reshape(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            preds = torch.sigmoid(outputs) > 0.5
            correct_train += (preds == labels).sum().item()
            total_train += labels.size(0)

            # Mean loss over the batches seen so far (count batches directly;
            # deriving the count from sample totals is wrong on a short last batch).
            progress_bar.set_postfix(loss=running_loss / batch_idx)

        train_losses.append(running_loss / len(train_loader))
        train_accuracies.append(correct_train / total_train)

        val_loss, val_accuracy = evaluate_model(model, val_loader, return_metrics=True)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch + 1}, Train Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_accuracies[-1]:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Restore the best checkpoint written during THIS run (the flag avoids
    # loading a stale file left over from an earlier session).
    if checkpoint_saved:
        model.load_state_dict(torch.load('best_model.pth', map_location=run_device))

    # Plot training and validation loss
    _plot_history(train_losses, val_losses, 'Loss')

    # Plot training and validation accuracy
    _plot_history(train_accuracies, val_accuracies, 'Accuracy')

In [None]:
# Evaluation function
def evaluate_model(model, val_loader, return_metrics=False):
    """Evaluate a single-logit binary classifier on a loader.

    The model is put in eval mode and run without gradients. Logits are
    scored with ``nn.BCEWithLogitsLoss`` (averaged over batches) and turned
    into probabilities with a sigmoid; a probability above 0.5 counts as
    class 1.

    Args:
        model: network returning one logit per sample.
        val_loader: sized iterable of ``(images, labels)`` batches.
        return_metrics: when True, return the metrics instead of printing
            the full report.

    Returns:
        ``(val_loss, accuracy)`` when ``return_metrics`` is True; otherwise
        None after printing accuracy, the classification report, the
        confusion matrix and the AUC, and showing the ROC curve.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    # Move batches to wherever the model lives; fall back to the notebook's
    # global device for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    preds, true_labels = [], []
    val_loss = 0.0
    criterion = nn.BCEWithLogitsLoss()
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(run_device), labels.to(run_device).float()
            # reshape(-1) rather than squeeze(): squeeze() collapses a batch of
            # size 1 to a 0-d tensor, which breaks both the loss (shape
            # mismatch with the labels) and the extend() below.
            outputs = model(images).reshape(-1)
            val_loss += criterion(outputs, labels).item()
            preds.extend(torch.sigmoid(outputs).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    if not true_labels:
        raise ValueError("val_loader yielded no samples to evaluate")

    preds_binary = [1 if p > 0.5 else 0 for p in preds]
    val_loss /= len(val_loader)
    accuracy = sum(p == t for p, t in zip(preds_binary, true_labels)) / len(true_labels)

    if return_metrics:
        return val_loss, accuracy

    print("Accuracy:")
    print(accuracy)
    
    print("Classification Report:")
    print(classification_report(true_labels, preds_binary))

    cm = confusion_matrix(true_labels, preds_binary)
    print("Confusion Matrix:")
    print(cm)

    # AUC and ROC use the raw probabilities, not the thresholded predictions.
    auc = roc_auc_score(true_labels, preds)
    print(f"AUC Score: {auc:.4f}")

    RocCurveDisplay.from_predictions(true_labels, preds)
    plt.title('ROC Curve')
    plt.show()

In [None]:


# Instantiate and train
model_choice = "mobilenetv2"  # Choose "cnn", "vgg16", "mobilenetv2", or "resnet50"
if model_choice == "cnn":
    model = build_cnn().to(device)
elif model_choice in ["vgg16", "mobilenetv2", "resnet50"]:
    model = get_pretrained_model(model_choice)
else:
    # Without this, a typo would leave `model` undefined (or silently reuse a
    # stale model from an earlier run of the cell).
    raise ValueError(
        f"Unknown model_choice {model_choice!r}; "
        'expected "cnn", "vgg16", "mobilenetv2", or "resnet50"'
    )



In [None]:
# Adam over all model parameters, paired with binary cross-entropy computed
# directly on logits.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Train for at most 10 epochs using the train/validation loaders built above.
train_model(
    model,
    criterion,
    optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
)

In [None]:
# Test images live in the 'test' sub-folder and use the validation transform.
test_data = datasets.ImageFolder(
    root=os.path.join(data_dir, 'test'),
    transform=transform_val,
)

In [None]:
# Fixed order (no shuffling) for test-set evaluation.
test_loader = DataLoader(dataset=test_data, batch_size=32, shuffle=False)

In [None]:
# Full report on the test split.
evaluate_model(model, val_loader=test_loader)