In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, ConcatDataset, Subset
from torchvision import datasets, models
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight

# ---------------------------------------------------------------------------
# Configuration: input locations on the Kaggle filesystem
# ---------------------------------------------------------------------------
train_csv_path = '/kaggle/input/aptos2019/train_1.csv'
train_images_path = '/kaggle/input/aptos2019/train_images/train_images'
# NOTE(review): the validation split is read from test.csv / test_images and
# DRDataset reads a 'diagnosis' column from it — confirm this CSV is labelled.
val_csv_path = '/kaggle/input/aptos2019/test.csv'
val_images_path = '/kaggle/input/aptos2019/test_images/test_images'
# Extra training data, loaded with torchvision ImageFolder.
additional_dataset_1 = '/kaggle/input/diabetic-retinopathy-resized-arranged'
additional_dataset_2 = '/kaggle/input/diabetic-retinopathy-dataset'

# Number of diabetic-retinopathy grades predicted by the classifier head.
num_classes = 5

# Maximum number of samples kept per class label, applied separately to each
# training source dataset in get_data_loaders.
class_limits = dict([
    (0, 1000),  # Healthy (No DR)
    (1, 370),   # Mild DR
    (2, 900),   # Moderate DR
    (3, 190),   # Severe DR
    (4, 290),   # Proliferative DR
])

class DRDataset(Dataset):
    """Labelled fundus images listed in a CSV with ``id_code`` and ``diagnosis`` columns.

    Row ``i`` is loaded from ``<root_dir>/<id_code>.png`` as an RGB image and
    paired with ``int(diagnosis)`` as its label.
    """

    # Defined at module level (not inside get_data_loaders) so DataLoader worker
    # processes can pickle it when the 'spawn' start method is used.

    def __init__(self, csv_file, root_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.root_dir, f"{row['id_code']}.png")
        # convert() materialises the pixels into a new image, so the file
        # handle can be closed immediately instead of leaking until GC.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = int(row['diagnosis'])
        if self.transform:
            image = self.transform(image)
        return image, label


def _dataset_labels(dataset):
    """Return the integer label of every sample, avoiding image decoding where possible."""
    if isinstance(dataset, DRDataset):
        return [int(v) for v in dataset.df['diagnosis']]
    targets = getattr(dataset, 'targets', None)  # populated by torchvision ImageFolder
    if targets is not None:
        return [int(t) for t in targets]
    # Fallback for other dataset types: load each sample just to read its label (slow).
    return [int(dataset[i][1]) for i in range(len(dataset))]


def _filter_by_class_limits(dataset, limits):
    """Keep at most ``limits[c]`` samples of each class ``c``, in dataset order.

    Returns:
        ``(subset, kept_labels)`` where ``kept_labels[i]`` is the label of ``subset[i]``.

    Raises:
        ValueError: if a sample's label has no entry in ``limits``.
    """
    counts = dict.fromkeys(limits, 0)
    indices, kept_labels = [], []
    for idx, label in enumerate(_dataset_labels(dataset)):
        if label not in limits:
            raise ValueError(
                f"label {label!r} at index {idx} has no entry in class_limits"
            )
        if counts[label] < limits[label]:
            counts[label] += 1
            indices.append(idx)
            kept_labels.append(label)
    return Subset(dataset, indices), kept_labels


def get_data_loaders(batch_size=32, num_workers=4):
    """Build the class-balanced training loader and the validation loader.

    The APTOS CSV split and the two ImageFolder datasets are each capped to
    ``class_limits`` samples per class (the first samples in dataset order),
    then concatenated. The training loader draws samples with replacement,
    weighted by sklearn's 'balanced' class weights.

    Args:
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader, class_weights)``; ``class_weights`` is a
        float tensor whose entry ``i`` is the weight of class label ``i``, for
        ``i`` in ``range(num_classes)``.

    Raises:
        ValueError: if a training label is missing from ``class_limits``, or if
            some class in ``range(num_classes)`` has no training samples.
    """
    # Pretrained torchvision backbones expect ImageNet mean/std-normalised inputs.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): ImageFolder assigns labels by sorted folder name; this assumes
    # the folders are named so that index i means DR grade i — confirm.
    sources = [
        DRDataset(train_csv_path, train_images_path, train_transform),
        datasets.ImageFolder(root=additional_dataset_1, transform=train_transform),
        datasets.ImageFolder(root=additional_dataset_2, transform=train_transform),
    ]
    subsets, labels = [], []
    for source in sources:
        subset, kept = _filter_by_class_limits(source, class_limits)
        subsets.append(subset)
        labels.extend(kept)  # same order as ConcatDataset below
    train_dataset = ConcatDataset(subsets)
    val_dataset = DRDataset(val_csv_path, val_images_path, val_transform)

    # Fixing classes to 0..num_classes-1 keeps weight[i] aligned with label i
    # (and with CrossEntropyLoss's weight vector); np.unique(labels) would shift
    # indices silently if a grade were absent.
    class_weights = torch.tensor(
        compute_class_weight('balanced', classes=np.arange(num_classes), y=np.asarray(labels)),
        dtype=torch.float,
    )

    sample_weights = class_weights[torch.as_tensor(labels, dtype=torch.long)]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    return train_loader, val_loader, class_weights

# Choose the compute device (GPU when available), then build the data
# pipeline; get_data_loaders does not read `device`, so the order is free.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
train_loader, val_loader, class_weights = get_data_loaders()

# Model Selection
def get_model(model_name):
    """Build a pretrained backbone with a new ``num_classes``-way head, moved to ``device``.

    Args:
        model_name: ``"resnet"`` (ResNet-50) or ``"mobilenet"`` (MobileNetV2).

    Raises:
        ValueError: for any other name.

    NOTE(review): newer torchvision releases deprecate ``pretrained=`` in favour
    of ``weights=`` — check against the installed version.
    """
    def build_resnet():
        net = models.resnet50(pretrained=True)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net

    def build_mobilenet():
        net = models.mobilenet_v2(pretrained=True)
        head = net.classifier
        head[1] = nn.Linear(head[1].in_features, num_classes)
        return net

    # Ordered (name, builder) pairs compared with ==, matching an if/elif chain.
    for name, build in (("resnet", build_resnet), ("mobilenet", build_mobilenet)):
        if model_name == name:
            return build().to(device)
    raise ValueError("Model not supported")

# Training & Evaluation Functions
def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Batches are moved to the module-level ``device``.

    Returns:
        ``(mean per-batch loss, weighted F1 over every sample seen)``.
    """
    model.train()
    loss_sum = 0.0
    predicted, targets = [], []
    for batch_images, batch_labels in tqdm(dataloader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted.extend(logits.argmax(dim=1).cpu().numpy())
        targets.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(dataloader)
    return mean_loss, f1_score(targets, predicted, average='weighted')

def evaluate(model, dataloader, criterion):
    """Score ``model`` on ``dataloader`` in eval mode with gradients disabled.

    Batches are moved to the module-level ``device``.

    Returns:
        ``(mean per-batch loss, weighted F1 over every sample)``.
    """
    model.eval()
    total_loss = 0.0
    y_pred, y_true = [], []
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, targets).item()
            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(targets.cpu().numpy())
    mean_loss = total_loss / len(dataloader)
    return mean_loss, f1_score(y_true, y_pred, average='weighted')

# Training Loop
def train_model(model_name, num_epochs=20, early_stop_patience=20):
    """Fine-tune ``model_name`` and checkpoint the epoch with the lowest validation loss.

    Uses the module-level ``train_loader``, ``val_loader``, ``class_weights`` and
    ``device``. The best state dict is written to ``best_{model_name}_model.pth``.

    Args:
        model_name: name accepted by ``get_model`` (``"resnet"`` or ``"mobilenet"``).
        num_epochs: maximum number of training epochs.
        early_stop_patience: stop after this many *consecutive* epochs without a
            validation-loss improvement. With the defaults (patience equal to
            ``num_epochs``) early stopping cannot fire; pass a smaller value to enable it.

    Returns:
        The model as of the last epoch run. It is not reloaded from the best
        checkpoint.
    """
    model = get_model(model_name)
    # NOTE(review): train_loader already rebalances classes through a
    # WeightedRandomSampler built from these same class_weights; weighting the
    # loss as well compensates twice and biases predictions toward rare grades.
    # Confirm this is intended.
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
    optimizer = optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs} - {model_name}")
        train_loss, train_f1 = train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_f1 = evaluate(model, val_loader, criterion)
        print(f"Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}")
        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Reset on improvement so patience counts consecutive stale epochs.
            epochs_without_improvement = 0
            torch.save(model.state_dict(), f"best_{model_name}_model.pth")
            print("✅ Saved new best model.")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stop_patience:
                print("🔥 Early stopping triggered.")
                break
    print(f"🎉 Training Complete for {model_name}.")
    return model

# Train each backbone in turn: ResNet-50 first, then MobileNetV2.
for _backbone in ("resnet", "mobilenet"):
    train_model(_backbone)
