In [None]:
# Retinal Disease Detection - Training Notebook
# This notebook should be run in Google Colab

# ============================================
# CELL 1: Setup and Mount Google Drive
# ============================================
"""
Run this first to mount Google Drive
"""
# Colab-only module; the rest of the notebook reads and writes under
# /content/drive/MyDrive, so the mount has to happen first.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# ============================================
# CELL 2: Install Required Packages
# ============================================
"""
Install all necessary packages
"""
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install pytorch-grad-cam
!pip install albumentations
!pip install kaggle

# ============================================
# CELL 3: Setup Kaggle API (IMPORTANT!)
# ============================================
"""
Before running this:
1. Go to Kaggle.com -> Account -> Create New API Token
2. This downloads kaggle.json
3. Upload kaggle.json to Colab using the file upload button
"""
# Create the per-user Kaggle config directory (no error if it already exists).
!mkdir -p ~/.kaggle
# Copy the uploaded API token from the current working directory; this fails
# if kaggle.json has not been uploaded yet.
!cp kaggle.json ~/.kaggle/
# Make the token readable/writable by the owner only.
!chmod 600 ~/.kaggle/kaggle.json

# ============================================
# CELL 4: Download Dataset from Kaggle
# ============================================
"""
Download APTOS 2019 Blindness Detection dataset
"""
# Create the project folder layout on Google Drive: raw data, saved model
# checkpoints, and output folders for plots and Grad-CAM images.
# `-p` makes each command a no-op when the folder already exists.
!mkdir -p /content/drive/MyDrive/retinal_disease/data/raw
!mkdir -p /content/drive/MyDrive/retinal_disease/models/saved_models
!mkdir -p /content/drive/MyDrive/retinal_disease/outputs/plots
!mkdir -p /content/drive/MyDrive/retinal_disease/outputs/gradcam

# Download the APTOS 2019 competition files into the raw-data folder.
# Requires the Kaggle API token installed in ~/.kaggle/kaggle.json.
!kaggle competitions download -c aptos2019-blindness-detection -p /content/drive/MyDrive/retinal_disease/data/raw

# Unzip dataset
import zipfile
import os

# Despite its name, `zip_path` is the directory holding the downloaded
# archive(s); every .zip found there is unpacked in place.
zip_path = '/content/drive/MyDrive/retinal_disease/data/raw'
archive_names = [name for name in os.listdir(zip_path) if name.endswith('.zip')]
for archive_name in archive_names:
    archive_file = os.path.join(zip_path, archive_name)
    with zipfile.ZipFile(archive_file, 'r') as archive:
        archive.extractall(zip_path)
    print(f"Extracted: {archive_name}")

# ============================================
# CELL 5: Upload Project Files
# ============================================
"""
Upload your project files to Google Drive:
1. Create folder: /content/drive/MyDrive/retinal_disease/src/
2. Upload all .py files from src/ folder
OR you can clone from GitHub if you have it there
"""

# If using GitHub (recommended):
# !git clone https://github.com/yourusername/retinal-disease-detection.git
# %cd retinal-disease-detection

# If uploaded to Drive manually:
import sys

# Put the Drive project root on the import path so the `src.*` imports
# further down can be resolved from there.
project_root = '/content/drive/MyDrive/retinal_disease'
sys.path.append(project_root)

# ============================================
# CELL 6: Import Libraries
# ============================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
from src.config import Config
from src.dataset import RetinalDataset, get_transforms
from src.model import get_model
from src.train import train_model, train_epoch, validate_epoch
from src.evaluate import (
    evaluate_model, plot_confusion_matrix, 
    plot_training_history, plot_roc_curves
)
from src.utils import set_seed, visualize_predictions, plot_gradcam_samples

# ============================================
# CELL 7: Setup Configuration
# ============================================
# Apply the configured seed through the project helper
# (presumably seeds the relevant RNGs; see src.utils.set_seed).
set_seed(Config.SEED)

# Use CUDA when the runtime exposes a GPU, otherwise fall back to CPU.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
print(f"Using device: {device}")
if has_gpu:
    gpu_properties = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_properties.total_memory / 1e9:.2f} GB")

# ============================================
# CELL 8: Load and Prepare Data
# ============================================
# Load the label CSV; the 'diagnosis' column holds the class of each image.
train_csv_path = '/content/drive/MyDrive/retinal_disease/data/raw/train.csv'
train_df = pd.read_csv(train_csv_path)

# Count images per class once and reuse the result for both the printed
# summary and the bar chart.
class_distribution = train_df['diagnosis'].value_counts().sort_index()

print(f"Total images: {len(train_df)}")
print("\nClass distribution:")
print(class_distribution)

# Visualize class distribution
plt.figure(figsize=(10, 6))
class_distribution.plot(kind='bar')
plt.title('Class Distribution')
plt.xlabel('Diagnosis')
plt.ylabel('Count')
# Derive the tick positions from the configured class names instead of a
# hard-coded 5, so ticks and labels always have matching lengths.
plt.xticks(range(len(Config.CLASS_NAMES)), Config.CLASS_NAMES, rotation=45)
plt.tight_layout()
plt.show()

# ============================================
# CELL 9: Split Dataset
# ============================================
# Two-stage stratified split of the labelled images:
#   1) hold out 20% of the rows,
#   2) cut the held-out part in half to get validation and test sets.
# Both stages stratify on 'diagnosis' and reuse the configured seed.
holdout_fraction = 0.2
test_share_of_holdout = 0.5

train_data, temp_data = train_test_split(
    train_df,
    test_size=holdout_fraction,
    random_state=Config.SEED,
    stratify=train_df['diagnosis'],
)
val_data, test_data = train_test_split(
    temp_data,
    test_size=test_share_of_holdout,
    random_state=Config.SEED,
    stratify=temp_data['diagnosis'],
)

print(f"Train size: {len(train_data)}")
print(f"Val size: {len(val_data)}")
print(f"Test size: {len(test_data)}")

# ============================================
# CELL 10: Create Datasets and DataLoaders
# ============================================
# Folder containing the image files referenced by the label CSV.
img_dir = '/content/drive/MyDrive/retinal_disease/data/raw/train_images'

# One dataset per split; only the training split requests the training
# variant of the transforms (is_training=True).
train_transform = get_transforms(Config.IMAGE_SIZE, is_training=True)
train_dataset = RetinalDataset(train_data, img_dir, transform=train_transform)

val_transform = get_transforms(Config.IMAGE_SIZE, is_training=False)
val_dataset = RetinalDataset(val_data, img_dir, transform=val_transform)

test_transform = get_transforms(Config.IMAGE_SIZE, is_training=False)
test_dataset = RetinalDataset(test_data, img_dir, transform=test_transform)

# Loader settings shared by all three splits; only shuffling differs.
loader_options = dict(
    batch_size=Config.BATCH_SIZE,
    num_workers=Config.NUM_WORKERS,
    pin_memory=Config.PIN_MEMORY,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# ============================================
# CELL 11: Visualize Sample Images
# ============================================
# Pull one batch from the training loader for a visual sanity check.
images, labels = next(iter(train_loader))

# 2x4 grid of preview panels.
fig, axes = plt.subplots(2, 4, figsize=(15, 8))
axes = axes.flatten()

# Statistics used to undo input normalisation before display.
# NOTE(review): assumes get_transforms normalises with these ImageNet
# mean/std values - confirm against src/dataset.py.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= len(images):
        # The batch holds fewer images than the grid has panels
        # (small BATCH_SIZE); leave the remaining panels blank.
        continue

    # CHW tensor -> HWC array, then de-normalise and clamp to [0, 1].
    img = images[i].numpy().transpose(1, 2, 0)
    img = img * std + mean
    img = np.clip(img, 0, 1)

    ax.imshow(img)
    # Convert the label tensor to a plain int before using it as an index.
    ax.set_title(f"{Config.CLASS_NAMES[int(labels[i])]}")

plt.tight_layout()
plt.show()

# ============================================
# CELL 12: Initialize Model
# ============================================
# Build the network through the project factory, driven entirely by Config
# and the device selected earlier.
model_options = {
    'model_name': Config.MODEL_NAME,
    'num_classes': Config.NUM_CLASSES,
    'pretrained': Config.PRETRAINED,
    'device': device,
}
model = get_model(**model_options)

# ============================================
# CELL 13: Setup Training Components
# ============================================
# Loss function with inverse-frequency class weights for imbalanced data.
# The counts come from the training split only, so the loss reflects the
# data the model is actually trained on and uses no validation/test labels.
class_counts = train_data['diagnosis'].value_counts().sort_index().values
class_weights = 1.0 / class_counts
# Rescale so the weights sum to the number of classes (mean weight of 1).
class_weights = class_weights / class_weights.sum() * len(class_counts)
class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=Config.LEARNING_RATE,
    weight_decay=Config.WEIGHT_DECAY
)

# Learning rate scheduler: halve the LR after 3 epochs without improvement of
# the monitored (minimised) metric. The `verbose` argument is deprecated and
# rejected by recent PyTorch releases, so it is not passed; read the current
# LR from optimizer.param_groups[0]['lr'] when logging is needed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print("Training setup complete!")

# ============================================
# CELL 14: Train Model
# ============================================
print("Starting training...")
print("=" * 50)

# Everything the project training loop needs, gathered in one place before
# the call; checkpoints go to Config.MODEL_SAVE_DIR.
training_options = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=Config.NUM_EPOCHS,
    device=device,
    save_dir=Config.MODEL_SAVE_DIR,
    patience=Config.PATIENCE,
)
history = train_model(**training_options)

# ============================================
# CELL 15: Plot Training History
# ============================================
# Plot the history returned by train_model and write the figure to Drive.
history_plot_path = '/content/drive/MyDrive/retinal_disease/outputs/plots/training_history.png'
plot_training_history(history, save_path=history_plot_path)

# ============================================
# CELL 16: Load Best Model and Evaluate
# ============================================
# Load the best checkpoint written during training. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved from a GPU
# runtime can still be loaded in a CPU-only session.
# NOTE(review): newer PyTorch versions default torch.load to
# weights_only=True; if the checkpoint stores non-tensor Python objects the
# call may need weights_only=False - confirm against what train_model saves.
checkpoint = torch.load(
    f'{Config.MODEL_SAVE_DIR}/best_model.pth',
    map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']} with Val Acc: {checkpoint['val_acc']:.2f}%")

# Evaluate on test set
results = evaluate_model(model, test_loader, device, Config.CLASS_NAMES)

# ============================================
# CELL 17: Plot Confusion Matrix
# ============================================
# Confusion matrix of test-set labels vs. predictions, saved to Drive.
confusion_matrix_path = '/content/drive/MyDrive/retinal_disease/outputs/plots/confusion_matrix.png'
plot_confusion_matrix(
    results['labels'], results['predictions'], Config.CLASS_NAMES,
    save_path=confusion_matrix_path,
)

# ============================================
# CELL 18: Plot ROC Curves
# ============================================
# ROC curves are drawn from the predicted probabilities; the helper's return
# value is kept in roc_auc_scores.
roc_curves_path = '/content/drive/MyDrive/retinal_disease/outputs/plots/roc_curves.png'
roc_auc_scores = plot_roc_curves(
    results['labels'], results['probabilities'], Config.CLASS_NAMES,
    save_path=roc_curves_path,
)

# ============================================
# CELL 19: Visualize Predictions
# ============================================
# Show model predictions for 16 images drawn from the test loader.
preview_count = 16
visualize_predictions(
    model, test_loader, Config.CLASS_NAMES, device, num_images=preview_count
)

# ============================================
# CELL 20: Generate Grad-CAM Visualizations
# ============================================
# Get target layer for Grad-CAM (last conv layer of backbone)
if Config.MODEL_NAME == 'resnet50':
    target_layer = model.backbone.layer4[-1]
elif Config.MODEL_NAME == 'efficientnet_b0':
    target_layer = model.backbone.features[-1]
else:
    # Fail with a clear message instead of a NameError on `target_layer`
    # further down when a new architecture is configured.
    raise ValueError(
        f"No Grad-CAM target layer configured for model '{Config.MODEL_NAME}'"
    )

# Render Grad-CAM overlays for 8 test samples and store them on Drive.
plot_gradcam_samples(
    model, test_loader, Config.CLASS_NAMES, target_layer, device,
    num_samples=8,
    save_dir='/content/drive/MyDrive/retinal_disease/outputs/gradcam'
)

# ============================================
# CELL 21: Save Final Model
# ============================================
# Bundle the weights with the settings needed to rebuild the model later
# (architecture name, class count, input size), the class names and the
# test accuracy, then write everything to a single file on Drive.
final_checkpoint = {
    'model_state_dict': model.state_dict(),
    'config': {
        'model_name': Config.MODEL_NAME,
        'num_classes': Config.NUM_CLASSES,
        'image_size': Config.IMAGE_SIZE,
    },
    'class_names': Config.CLASS_NAMES,
    'accuracy': results['accuracy'],
}
final_model_path = '/content/drive/MyDrive/retinal_disease/models/final_model.pth'
torch.save(final_checkpoint, final_model_path)

print("Model saved successfully!")
print(f"Final Test Accuracy: {results['accuracy']*100:.2f}%")

# ============================================
# CELL 22: Test on Single Image (Optional)
# ============================================
"""
Test model on a single image
"""
from PIL import Image
import torchvision.transforms as transforms

def predict_single_image(image_path, model, device):
    """Classify one image file and display it with the predicted label.

    Args:
        image_path: Path of the image to open with PIL.
        model: Network whose output is turned into class probabilities via
            softmax over dim 1; it is switched to eval mode by this call.
        device: Torch device the preprocessed input is moved to.

    Returns:
        Tuple ``(pred_class, confidence)``: the index of the most probable
        class and its softmax probability as a Python float.
    """
    model.eval()

    # Read the file as RGB and run the non-training transform pipeline.
    # NOTE(review): the transform is called positionally on a PIL image; if
    # get_transforms returns an Albumentations pipeline it would expect
    # `image=<ndarray>` instead - confirm against src/dataset.py.
    pil_image = Image.open(image_path).convert('RGB')
    preprocess = get_transforms(Config.IMAGE_SIZE, is_training=False)
    batch = preprocess(pil_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        scores = model(batch)
        probs = torch.softmax(scores, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_class].item()

    # Show the original (un-transformed) image with the prediction as title.
    title_text = f"Prediction: {Config.CLASS_NAMES[pred_class]}\nConfidence: {confidence*100:.2f}%"
    plt.figure(figsize=(8, 6))
    plt.imshow(pil_image)
    plt.title(title_text)
    plt.axis('off')
    plt.show()

    return pred_class, confidence

# Example usage:
# test_image_path = '/content/drive/MyDrive/retinal_disease/data/raw/train_images/some_image.png'
# pred_class, confidence = predict_single_image(test_image_path, model, device)

# The check-mark emoji had been corrupted into mojibake (UTF-8 bytes decoded
# as cp1252); restore the intended character.
print("\n✅ Training Complete! All outputs saved to Google Drive.")