In [None]:
# CELL 1: Environment Setup & GPU Check
import os
import sys
import gc
import time
import json
import logging
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

# Pick the compute device once; later cells move the model and batches onto it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()
    # Let cuDNN autotune kernels; this trades bit-exact reproducibility for speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

import random
import numpy as np
# Seed every RNG the notebook draws from (Python, NumPy, torch CPU, torch CUDA)
# with the same value, in that order.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(42)

print("Environment setup complete!")

Using device: cuda
GPU: Tesla T4
GPU Memory: 15.8 GB
Environment setup complete!


In [None]:
# =============================================================================
# CELL 2: Import Required Libraries
# =============================================================================

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from PIL import Image
from tqdm.auto import tqdm
import warnings
# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# imported above — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')

import matplotlib
# Switch matplotlib to the non-interactive Agg backend.
# NOTE(review): presumably figures are meant to be written to files rather than
# rendered inline — confirm against the plotting cells.
matplotlib.use('Agg')

print("All libraries imported successfully!")

All libraries imported successfully!


In [None]:
# =============================================================================
# CELL 3: Google Drive Connection & Setup (BULLETPROOF VERSION)
# =============================================================================
import os
import sys
import logging
from datetime import datetime
from google.colab import drive

# Mounts Google Drive, verifies/creates the project directories, and configures
# root logging to both a timestamped file on Drive and stdout.
# Defines the globals DRIVE_ROOT, DATASET_PATH, LOG_PATH, MODEL_SAVE_PATH,
# log_filename and logger. Any failure is printed and re-raised so the notebook
# stops at this cell.
try:
    # Mount first: every path below lives under the mount point.
    print("STEP 1: Mounting Google Drive...")
    drive.mount('/content/drive', force_remount=True)
    print("[SUCCESS] Google Drive mounted at /content/drive.")

    DRIVE_ROOT = "/content/drive/MyDrive/Diabetic_Retinopathy_Project"
    print(f"STEP 2: Checking if the root project directory exists at: {DRIVE_ROOT}")

    if not os.path.exists(DRIVE_ROOT):
        print("\n[FATAL ERROR] Python cannot find the directory, even though 'ls' can see it.")
        print("This might be a Google Drive sync issue. Please try 'Runtime -> Restart runtime' and run again.")
        raise FileNotFoundError(f"Directory not found: {DRIVE_ROOT}")

    print("[SUCCESS] Root project directory found by Python.")

    DATASET_PATH = f"{DRIVE_ROOT}/diabetic-retinopathy-224x224-gaussian-filtered"
    LOG_PATH = f"{DRIVE_ROOT}/training_logs"
    MODEL_SAVE_PATH = f"{DRIVE_ROOT}/models"

    print(f"STEP 3: Attempting to create log directory at: {LOG_PATH}")
    os.makedirs(LOG_PATH, exist_ok=True)
    print("[SUCCESS] Log directory created or already exists.")

    print(f"STEP 4: Attempting to create model save directory at: {MODEL_SAVE_PATH}")
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    print("[SUCCESS] Model save directory created or already exists.")

    print("STEP 5: Setting up the logging system...")
    log_filename = f"{LOG_PATH}/training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"

    # basicConfig is a no-op while the root logger already has handlers, so
    # drop them first. Close each one as well: re-running this cell would
    # otherwise leave the previous run's log file open.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename),
            logging.StreamHandler(sys.stdout)
        ]
    )
    logger = logging.getLogger(__name__)
    print("[SUCCESS] Logging system initialized.")

    logger.info("Google Drive connected and all paths/logging are set up successfully!")
    logger.info(f"Dataset path is set to: {DATASET_PATH}")

    # The dataset folder is checked last so that the failure is also written
    # to the log file configured above.
    if not os.path.exists(DATASET_PATH):
        logger.error(f"Dataset sub-folder not found at {DATASET_PATH}")
        raise FileNotFoundError(f"Please make sure the dataset is at the correct sub-path: {DATASET_PATH}")
    else:
        logger.info("Dataset path verified successfully!")

except Exception as e:
    print(f"\n\n[CRITICAL ERROR] The cell failed with an exception: {e}")
    raise

STEP 1: Mounting Google Drive...
Mounted at /content/drive
[SUCCESS] Google Drive mounted at /content/drive.
STEP 2: Checking if the root project directory exists at: /content/drive/MyDrive/Diabetic_Retinopathy_Project
[SUCCESS] Root project directory found by Python.
STEP 3: Attempting to create log directory at: /content/drive/MyDrive/Diabetic_Retinopathy_Project/training_logs
[SUCCESS] Log directory created or already exists.
STEP 4: Attempting to create model save directory at: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models
[SUCCESS] Model save directory created or already exists.
STEP 5: Setting up the logging system...
[SUCCESS] Logging system initialized.
2025-08-19 06:32:20,676 - INFO - Google Drive connected and all paths/logging are set up successfully!
2025-08-19 06:32:20,678 - INFO - Dataset path is set to: /content/drive/MyDrive/Diabetic_Retinopathy_Project/diabetic-retinopathy-224x224-gaussian-filtered
2025-08-19 06:32:20,680 - INFO - Dataset path verified suc

In [None]:
# =============================================================================
# CELL 4: Configuration & Hyperparameters (V2 UPGRADE)
# =============================================================================

# Central hyperparameter table. Comments note where each key is consumed in the
# cells visible in this review; keys marked "not referenced" may still be used
# by cells outside that view.
CONFIG = {
    # Input pipeline: IMG_SIZE is the square side used by A.Resize and by the
    # dataset's fallback dummy image; BATCH_SIZE / NUM_WORKERS feed both DataLoaders.
    # PREFETCH_FACTOR is not referenced in the visible cells.
    'IMG_SIZE': 512,
    'BATCH_SIZE': 32,
    'NUM_WORKERS': 4,
    'PREFETCH_FACTOR': 2,

    # Model: read by MultiTaskDRModel (timm backbone name, head size, dropout rates).
    # NOTE(review): PRETRAINED is not read by the visible model code — the
    # constructor's own `pretrained=True` default is what takes effect.
    'MODEL_NAME': 'efficientnet_b3',
    'NUM_CLASSES': 5,
    'PRETRAINED': True,
    'DROP_RATE': 0.4,
    'DROP_PATH_RATE': 0.2,

    # Optimisation: LEARNING_RATE is applied to the heads and scaled by 0.1 for
    # the backbone; MIN_LR is the scheduler's eta_min; T_MAX is passed as T_0 to
    # CosineAnnealingWarmRestarts; PATIENCE configures EarlyStopping.
    'EPOCHS': 30,
    'LEARNING_RATE': 1e-4,
    'WEIGHT_DECAY': 1e-5,
    'PATIENCE': 7,
    'MIN_LR': 1e-7,
    'T_MAX': 10,

    # Loss mixing weights for MultiTaskLoss: severity (ALPHA), lesion (BETA),
    # region (GAMMA). With BETA = GAMMA = 0 only the severity loss contributes
    # to total_loss.
    'ALPHA': 1.0,
    'BETA': 0.0,
    'GAMMA': 0.0,

    # Training loop: USE_AMP toggles autocast/GradScaler, GRADIENT_CLIP is the
    # max grad norm. ACCUMULATION_STEPS is not referenced in the visible cells.
    'USE_AMP': True,
    'GRADIENT_CLIP': 1.0,
    'ACCUMULATION_STEPS': 1,

    # Not referenced in the visible cells (the augmentation pipeline hard-codes
    # its own probabilities).
    'AUGMENT_PROB': 0.8,
    'MIXUP_ALPHA': 0.2,
    'CUTMIX_ALPHA': 1.0,

    # Placeholder: the data-loading cell overwrites this with inverse-frequency
    # weights computed from the training split.
    'CLASS_WEIGHTS': [1.0, 1.0, 1.0, 1.0, 1.0],

    # The list lengths size the lesion and region heads of the model and the
    # zero label vectors of the dataset's fallback sample.
    'LESION_TYPES': ['microaneurysms', 'hemorrhages', 'hard_exudates', 'soft_exudates', 'neovascularization'],
    'ANATOMICAL_REGIONS': ['superior', 'inferior', 'nasal', 'temporal', 'macular']
}

# Snapshot the configuration next to the training logs.
# NOTE(review): this is written before CLASS_WEIGHTS is recomputed, so the saved
# JSON contains the placeholder weights, not the ones actually used.
config_path = f"{LOG_PATH}/config_v2_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

logger.info("V2 Configuration saved!")
logger.info(f"Training with: {CONFIG['MODEL_NAME']}, IMG_SIZE: {CONFIG['IMG_SIZE']}, Batch size: {CONFIG['BATCH_SIZE']}")

2025-08-19 06:32:20,695 - INFO - V2 Configuration saved!
2025-08-19 06:32:20,697 - INFO - Training with: efficientnet_b3, IMG_SIZE: 512, Batch size: 32


In [None]:
# =============================================================================
# CELL 5: Medical Knowledge Base & Pseudo-Labeling Rules
# =============================================================================

class MedicalKnowledgeBase:
    """Lookup tables and text templates for the five DR severity grades (0-4).

    Holds, per grade: a human-readable name, a canonical 0/1 lesion pattern,
    and a follow-up recommendation. Also produces randomised pseudo-labels
    for the auxiliary lesion/region tasks and renders a diagnostic sentence.
    """

    def __init__(self):
        self.severity_descriptions = dict(enumerate((
            "No Diabetic Retinopathy",
            "Mild Non-Proliferative Diabetic Retinopathy",
            "Moderate Non-Proliferative Diabetic Retinopathy",
            "Severe Non-Proliferative Diabetic Retinopathy",
            "Proliferative Diabetic Retinopathy",
        )))

        # Lesion types in fixed order; at grade g the first `lesions_present[g]`
        # of them are flagged 1 and the rest 0.
        lesion_order = ('microaneurysms', 'hemorrhages', 'hard_exudates', 'soft_exudates', 'neovascularization')
        lesions_present = (0, 1, 3, 4, 5)
        self.lesion_patterns = {
            grade: {name: int(pos < n_present) for pos, name in enumerate(lesion_order)}
            for grade, n_present in enumerate(lesions_present)
        }

        self.recommendations = dict(enumerate((
            "Continue routine diabetic care. Annual eye examination recommended.",
            "Ophthalmology referral within 12 months. Optimize diabetic control.",
            "Ophthalmology referral within 6-12 months. Consider more frequent monitoring.",
            "Urgent ophthalmology referral within 2-4 months. Intensive diabetic management.",
            "Immediate ophthalmology referral. May require urgent intervention.",
        )))

    def get_pseudo_labels(self, severity_level, image_features=None):
        """Draw synthetic lesion and region labels for one sample.

        Args:
            severity_level: grade key into ``lesion_patterns`` (0-4).
            image_features: accepted but not used.

        Returns:
            ``(lesion_labels, region_labels)`` — two dicts of 0/1 ints keyed by
            lesion type and anatomical region. Draws come from the global
            NumPy RNG, so results depend on its current state.
        """
        lesion_labels = dict(self.lesion_patterns[severity_level])

        if severity_level > 0:
            jitter = np.random.normal(0, 0.1, len(lesion_labels))
            for pos, lesion in enumerate(list(lesion_labels)):
                base_prob = lesion_labels[lesion]
                if not base_prob > 0:
                    continue
                # Clamp the jittered probability to [0.1, 0.9], then flip a coin.
                prob = max(0.1, min(0.9, base_prob + jitter[pos]))
                lesion_labels[lesion] = int(np.random.random() < prob)

        # (region, minimum grade, involvement probability). A random number is
        # drawn only when the grade qualifies, keeping RNG consumption minimal.
        region_rules = (
            ('superior', 2, 0.6),
            ('inferior', 2, 0.6),
            ('nasal', 1, 0.4),
            ('temporal', 1, 0.5),
            ('macular', 3, 0.7),
        )
        region_labels = {}
        for region, min_grade, threshold in region_rules:
            region_labels[region] = int(severity_level >= min_grade and np.random.random() < threshold)

        return lesion_labels, region_labels

    def generate_description_template(self, severity_level, lesion_findings, region_findings, confidence=0.9):
        """Render a one-paragraph diagnostic description.

        Args:
            severity_level: grade key (0-4).
            lesion_findings: mapping lesion name -> truthy if present.
            region_findings: mapping region name -> truthy if involved.
            confidence: fraction rendered as a percentage with one decimal.

        Returns:
            The description string; grade 0 short-circuits to a fixed
            "no signs" sentence and ignores both findings mappings.
        """
        base_desc = self.severity_descriptions[severity_level]
        advice = self.recommendations[severity_level]

        if severity_level == 0:
            return f"No signs of diabetic retinopathy detected (confidence: {confidence:.1%}). {advice}"

        findings = []

        active_lesions = [name for name, flag in lesion_findings.items() if flag]
        if active_lesions:
            findings.append(f"Detected lesions: {', '.join(active_lesions)}")

        active_regions = [name for name, flag in region_findings.items() if flag]
        if active_regions:
            findings.append(f"Affected regions: {', '.join(active_regions)}")

        pieces = [f"{base_desc} (confidence: {confidence:.1%})"]
        if findings:
            pieces.append(f". {'. '.join(findings)}.")
        pieces.append(f" {advice}")

        return ''.join(pieces)

# Shared module-level instance; the data-loading cell passes it to OptimizedDRDataset.
medical_kb = MedicalKnowledgeBase()
logger.info("Medical knowledge base initialized!")

2025-08-19 06:32:20,711 - INFO - Medical knowledge base initialized!


In [None]:
# =============================================================================
# CELL 6: Optimized Dataset Class (PRACTICAL V2 VERSION)
# =============================================================================

class OptimizedDRDataset(Dataset):
    """Diabetic-retinopathy dataset that reads each image from disk on access.

    Nothing is cached by this class; images are decoded with OpenCV inside
    ``__getitem__``. Lesion and region pseudo-labels are drawn once, at
    construction time, via ``medical_kb.get_pseudo_labels`` and therefore stay
    fixed for the lifetime of the dataset object.

    Args:
        image_paths: indexable collection of image file paths.
        labels: indexable collection of integer severity grades, aligned with
            ``image_paths``.
        medical_kb: object exposing ``get_pseudo_labels(label)``.
        transforms: optional callable invoked as ``transforms(image=...)`` and
            expected to return a mapping with an ``'image'`` entry.
        is_train: stored on the instance; not otherwise used by this class.
    """
    def __init__(self, image_paths, labels, medical_kb, transforms=None, is_train=True):
        self.image_paths = image_paths
        self.labels = labels
        self.medical_kb = medical_kb
        self.transforms = transforms
        self.is_train = is_train

        self.lesion_labels = []
        self.region_labels = []
        for label in labels:
            lesion_dict, region_dict = medical_kb.get_pseudo_labels(label)
            self.lesion_labels.append(list(lesion_dict.values()))
            self.region_labels.append(list(region_dict.values()))

        logger.info(f"Dataset initialized with {len(self.image_paths)} images. Images will be loaded from disk.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return one sample as a dict with 'image', 'severity', 'lesions', 'regions'.

        If the image file cannot be read or decoded, a warning is logged and an
        all-zero placeholder sample is returned instead. Errors raised by the
        transform pipeline or by label conversion are NOT swallowed: they
        indicate a configuration or programming bug that would otherwise
        silently replace every sample with a black image.
        """
        img_path = self.image_paths[idx]

        # Only the disk read / decode is best-effort.
        try:
            image = cv2.imread(img_path)
            if image is None:
                raise IOError(f"cv2.imread returned None for image {img_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except (IOError, cv2.error) as e:
            logger.warning(f"Error loading image {img_path}: {e}. Returning a dummy image.")
            # NOTE(review): the placeholder is labelled as grade 0 with no
            # lesions/regions regardless of the sample's real label.
            dummy_image = torch.zeros(3, CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE'])
            return {
                'image': dummy_image,
                'severity': torch.tensor(0, dtype=torch.long),
                'lesions': torch.zeros(len(CONFIG['LESION_TYPES'])),
                'regions': torch.zeros(len(CONFIG['ANATOMICAL_REGIONS']))
            }

        if self.transforms:
            augmented = self.transforms(image=image)
            image = augmented['image']

        severity_label = torch.tensor(self.labels[idx], dtype=torch.long)
        lesion_label = torch.tensor(self.lesion_labels[idx], dtype=torch.float)
        region_label = torch.tensor(self.region_labels[idx], dtype=torch.float)

        return {
            'image': image,
            'severity': severity_label,
            'lesions': lesion_label,
            'regions': region_label
        }

In [None]:
# =============================================================================
# CELL 7: Advanced Data Augmentation Pipeline (V2 UPGRADE)
# =============================================================================

def get_train_transforms():
    """Build the training-time albumentations pipeline.

    Order: resize to the configured square size, one photometric jitter, one
    rotation-style transform, horizontal flip, one spatial warp, one
    blur/noise degradation, ImageNet-statistics normalisation, tensor
    conversion. Vertical flip is deliberately not part of the pipeline.
    """
    side = CONFIG['IMG_SIZE']
    resize = A.Resize(side, side)

    photometric = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.8),
            A.CLAHE(clip_limit=2.0, p=0.8),
            A.RandomGamma(gamma_limit=(80, 120), p=0.8),
        ],
        p=0.9,
    )
    rotation = A.OneOf(
        [
            A.Rotate(limit=15, p=0.8),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.8),
        ],
        p=0.7,
    )
    h_flip = A.HorizontalFlip(p=0.5)
    warp = A.OneOf(
        [
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
        ],
        p=0.3,
    )
    degrade = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5), p=0.3),
            A.GaussNoise(var_limit=(10.0, 30.0), p=0.3),
        ],
        p=0.2,
    )
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    return A.Compose([resize, photometric, rotation, h_flip, warp, degrade, normalize, ToTensorV2()])

def get_val_transforms():
    """Build the deterministic validation pipeline: resize, normalise, to tensor."""
    side = CONFIG['IMG_SIZE']
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Build both pipelines once; the data-loading cell hands them to the datasets.
train_transforms = get_train_transforms()
val_transforms = get_val_transforms()

logger.info("V2 Data augmentation pipeline created (VerticalFlip removed)!")

2025-08-19 06:32:20,739 - INFO - V2 Data augmentation pipeline created (VerticalFlip removed)!


In [None]:
# =============================================================================
# CELL 8: Multi-Task Model Architecture
# =============================================================================

class MultiTaskDRModel(nn.Module):
    """timm backbone with a channel gate and three task heads.

    The backbone's feature map is globally average-pooled, rescaled per
    channel by a small sigmoid-gated bottleneck MLP, batch-normalised and
    passed through dropout. The resulting vector feeds three MLP heads:
    severity (``num_classes`` logits), lesions (``num_lesion_types`` logits)
    and regions (``num_regions`` logits).
    """

    def __init__(self, model_name=CONFIG['MODEL_NAME'], num_classes=CONFIG['NUM_CLASSES'],
                 num_lesion_types=len(CONFIG['LESION_TYPES']),
                 num_regions=len(CONFIG['ANATOMICAL_REGIONS']),
                 pretrained=True):
        super().__init__()

        # num_classes=0 asks timm for the backbone without its classifier.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            drop_rate=CONFIG['DROP_RATE'],
            drop_path_rate=CONFIG['DROP_PATH_RATE'],
        )

        width = self.backbone.num_features
        self.feature_dim = width
        logger.info(f"Backbone feature dimension: {self.feature_dim}")

        # Channel gate: pool -> flatten -> bottleneck MLP -> sigmoid weights.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(width, width // 8),
            nn.ReLU(inplace=True),
            nn.Linear(width // 8, width),
            nn.Sigmoid(),
        )

        self.feature_norm = nn.BatchNorm1d(width)
        self.dropout = nn.Dropout(CONFIG['DROP_RATE'])

        # Heads are created in this order so parameter registration is stable.
        self.severity_classifier = self._make_head(width, width // 2, num_classes)
        self.lesion_detector = self._make_head(width, width // 4, num_lesion_types)
        self.region_predictor = self._make_head(width, width // 4, num_regions)

        self._init_weights()

    @staticmethod
    def _make_head(in_dim, hidden_dim, out_dim):
        """Two-layer MLP head: Linear -> ReLU -> Dropout(0.2) -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, out_dim),
        )

    def _init_weights(self):
        """Xavier-initialise the Linear layers of the three heads; zero their biases."""
        heads = (self.severity_classifier, self.lesion_detector, self.region_predictor)
        linear_layers = (layer for head in heads for layer in head if isinstance(layer, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.constant_(linear.bias, 0)

    def forward(self, x):
        """Run the network on an image batch.

        Returns:
            dict with 'severity', 'lesions' and 'regions' logits plus
            'features', the shared vector after normalisation and dropout.
        """
        feature_map = self.backbone.forward_features(x)
        pooled = F.adaptive_avg_pool2d(feature_map, 1).flatten(1)

        # The gate module starts with its own pooling layer, so give the
        # already-pooled vector two trailing singleton spatial dims.
        gate = self.attention(pooled[:, :, None, None])
        shared = self.dropout(self.feature_norm(pooled * gate))

        return {
            'severity': self.severity_classifier(shared),
            'lesions': self.lesion_detector(shared),
            'regions': self.region_predictor(shared),
            'features': shared,
        }

# Instantiate with the constructor defaults and move to the selected device.
model = MultiTaskDRModel().to(device)

# Parameter counts: all parameters vs. those with requires_grad set.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

logger.info(f"Model created: {CONFIG['MODEL_NAME']}")
logger.info(f"Total parameters: {total_params:,}")
logger.info(f"Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter and ignores buffers.
logger.info(f"Model size: {total_params * 4 / 1024**2:.1f} MB")

2025-08-19 06:32:20,950 - INFO - Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]

2025-08-19 06:32:22,027 - INFO - [timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-08-19 06:32:22,069 - INFO - Backbone feature dimension: 1536
2025-08-19 06:32:22,385 - INFO - Model created: efficientnet_b3
2025-08-19 06:32:22,387 - INFO - Total parameters: 13,659,383
2025-08-19 06:32:22,388 - INFO - Trainable parameters: 13,659,383
2025-08-19 06:32:22,389 - INFO - Model size: 52.1 MB


In [None]:
# =============================================================================
# CELL 9: Loss Functions & Metrics (V2 FINAL FIX)
# =============================================================================
from sklearn.metrics import cohen_kappa_score

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``w[y] * (1 - p_y) ** gamma * CE(logits, y)``.

    Args:
        weight: optional 1-D tensor of per-class weights, indexed by target.
        gamma: focusing exponent; 0 reduces to (weighted) cross-entropy.
        reduction: ``'mean'`` (plain mean over samples), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the three supported values.
    """
    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Unsupported reduction: {reduction!r}")
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        """Compute the loss for ``logits`` of shape (N, C) and integer ``target`` of shape (N,)."""
        # p_y must come from the UNWEIGHTED cross-entropy. Feeding class
        # weights into this call would make exp(-ce) equal p_y ** w instead of
        # p_y and distort the focusing term.
        ce_loss = F.cross_entropy(logits, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        # Class weights scale each sample's focal term afterwards.
        if self.weight is not None:
            class_weight = self.weight.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = focal_loss * class_weight[target]

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class MultiTaskLoss(nn.Module):
    """Weighted sum of the severity, lesion and region losses.

    Severity uses ``FocalLoss`` (gamma 2.0, optionally class-weighted); the
    lesion and region tasks use ``BCEWithLogitsLoss``.

    Args:
        class_weights: optional sequence of per-class severity weights; it is
            converted to a float tensor and moved to the global ``device``.
        alpha, beta, gamma: multipliers for the severity, lesion and region
            losses in ``total_loss``.
    """
    def __init__(self, class_weights=None, alpha=1.0, beta=0.0, gamma=0.0):
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

        focal_kwargs = {'gamma': 2.0}
        if class_weights is not None:
            focal_kwargs['weight'] = torch.tensor(class_weights, dtype=torch.float).to(device)
        self.severity_loss = FocalLoss(**focal_kwargs)

        self.lesion_loss = nn.BCEWithLogitsLoss()
        self.region_loss = nn.BCEWithLogitsLoss()

    def forward(self, predictions, targets):
        """Return a dict with 'total_loss' and the three unweighted per-task losses.

        Both arguments are mappings with 'severity', 'lesions' and 'regions' entries.
        """
        per_task = {
            'severity_loss': self.severity_loss(predictions['severity'], targets['severity']),
            'lesion_loss': self.lesion_loss(predictions['lesions'], targets['lesions']),
            'region_loss': self.region_loss(predictions['regions'], targets['regions']),
        }
        combined = (
            self.alpha * per_task['severity_loss']
            + self.beta * per_task['lesion_loss']
            + self.gamma * per_task['region_loss']
        )
        return {'total_loss': combined, **per_task}


def calculate_metrics(predictions, targets, task='severity'):
    """Compute evaluation metrics for one task.

    Args:
        predictions: logits tensor for the task.
        targets: ground-truth tensor for the task.
        task: 'severity' (multi-class) or 'lesions' / 'regions' (multi-label).

    Returns:
        'severity': dict with 'accuracy', weighted 'f1_score' and quadratic
        weighted kappa 'qwk'. 'lesions' / 'regions': dict with element-wise
        'accuracy' and weighted 'f1_score' at a 0.5 sigmoid threshold.
        Any other task name: an empty dict.
    """
    if task == 'severity':
        y_pred = torch.argmax(predictions, dim=1).cpu()
        y_true = targets.cpu()
        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1_score': f1_score(y_true, y_pred, average='weighted', zero_division=0),
            'qwk': cohen_kappa_score(y_true, y_pred, weights='quadratic'),
        }

    if task in ('lesions', 'regions'):
        y_pred = (torch.sigmoid(predictions) > 0.5).float().cpu()
        y_true = targets.cpu()
        # Fraction of individual label slots predicted correctly.
        hit_rate = ((y_pred == y_true).sum() / y_true.numel()).item()
        weighted_f1 = f1_score(y_true.numpy(), y_pred.numpy(), average='weighted', zero_division=0)
        return {'accuracy': hit_rate, 'f1_score': weighted_f1}

    return {}


# Placeholder only: the data-loading cell builds the real MultiTaskLoss once
# the class weights have been computed from the training split.
criterion = None
logger.info("V2 Loss functions (FocalLoss) and metrics (QWK) defined and fixed!")

2025-08-19 06:32:22,402 - INFO - V2 Loss functions (FocalLoss) and metrics (QWK) defined and fixed!


In [None]:
# =============================================================================
# CELL 10: Data Loading & Preprocessing (FINAL CORRECTED VERSION)
# =============================================================================
from torch.utils.data import WeightedRandomSampler

def load_and_prepare_data():
    """Collect image paths and severity labels from per-class subdirectories.

    Scans ``DATASET_PATH`` for the five class folders listed in
    ``severity_mapping`` and gathers every ``.png`` / ``.jpg`` / ``.jpeg``
    file (case-insensitive). Missing class folders are skipped with a warning.

    Directory listings are sorted so the returned order — and therefore any
    seeded split computed from it — is the same on every run;
    ``os.listdir`` alone returns entries in arbitrary order.

    Returns:
        ``(image_paths, labels)``: two parallel lists of file paths and
        integer grades.

    Raises:
        FileNotFoundError: if no image is found in any class folder.
    """
    logger.info("Loading dataset from class subdirectories...")
    severity_mapping = {
        'No_DR': 0, 'Mild': 1, 'Moderate': 2, 'Severe': 3, 'Proliferate_DR': 4
    }
    # Inverse lookup for reporting, instead of relying on dict position == label.
    class_names = {level: dir_name for dir_name, level in severity_mapping.items()}

    image_paths = []
    labels = []
    available_dirs = sorted(
        d for d in os.listdir(DATASET_PATH) if os.path.isdir(os.path.join(DATASET_PATH, d))
    )
    logger.info(f"Found image subdirectories: {available_dirs}")

    for dir_name, severity_level in severity_mapping.items():
        severity_dir = os.path.join(DATASET_PATH, dir_name)
        if not os.path.exists(severity_dir):
            logger.warning(f"Directory '{dir_name}' not found, skipping.")
            continue
        logger.info(f"Loading images from '{dir_name}' (Label: {severity_level})...")
        image_files = sorted(
            f for f in os.listdir(severity_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        image_paths.extend(os.path.join(severity_dir, img_file) for img_file in image_files)
        labels.extend([severity_level] * len(image_files))
        logger.info(f"  -> Found {len(image_files)} images.")

    if len(image_paths) == 0:
        raise FileNotFoundError("Could not find any images in the severity subdirectories.")

    logger.info(f"Total loaded: {len(image_paths)} images from directory structure.")
    unique, counts = np.unique(labels, return_counts=True)
    logger.info("Class distribution:")
    for cls, count in zip(unique, counts):
        logger.info(f"  Class {cls} ({class_names[int(cls)]}): {count} samples ({count/len(labels)*100:.1f}%)")
    return image_paths, labels

def create_stratified_split(image_paths, labels, test_size=0.2, random_state=42):
    """Split paths and labels into train/validation sets, stratified by label.

    Args:
        image_paths: sequence of image paths.
        labels: sequence of class labels aligned with ``image_paths``.
        test_size: fraction (or count) forwarded to ``train_test_split``.
        random_state: seed forwarded to ``train_test_split``.

    Returns:
        ``(train_paths, val_paths, train_labels, val_labels)`` as NumPy arrays.
    """
    from sklearn.model_selection import train_test_split

    paths_arr = np.array(image_paths)
    labels_arr = np.array(labels)
    split = train_test_split(
        paths_arr,
        labels_arr,
        test_size=test_size,
        random_state=random_state,
        stratify=labels_arr,
    )
    train_paths, val_paths, train_labels, val_labels = split

    logger.info(f"Train samples: {len(train_paths)}")
    logger.info(f"Validation samples: {len(val_paths)}")
    return train_paths, val_paths, train_labels, val_labels

# =============================================================================
# --- Main Data Loading Execution with Embedded Fix ---
# =============================================================================
# End-to-end data preparation: load paths, split, derive class weights, build the
# loss, the balanced sampler, both datasets and both loaders. Statement order
# matters: CONFIG['CLASS_WEIGHTS'] must be updated before the criterion is built.
# Any failure is logged and re-raised.
try:
    all_image_paths, all_labels = load_and_prepare_data()
    train_paths, val_paths, train_labels, val_labels = create_stratified_split(
        all_image_paths, all_labels, test_size=0.2
    )

    # --- 1. DYNAMICALLY COMPUTE CLASS WEIGHTS ---
    logger.info("Computing dynamic class weights for Focal Loss...")
    # Counts come from the training split only; np.maximum(..., 1) avoids
    # division by zero for a class with no training samples.
    class_counts = np.bincount(train_labels, minlength=CONFIG['NUM_CLASSES'])
    inv_freq = class_counts.max() / np.maximum(class_counts, 1)
    # Normalize the weights so their mean is 1
    dynamic_weights = inv_freq / inv_freq.mean()
    # This replaces the placeholder weights in CONFIG for the rest of the session.
    CONFIG['CLASS_WEIGHTS'] = dynamic_weights.tolist()
    logger.info(f"Computed Class Weights: {CONFIG['CLASS_WEIGHTS']}")

    # --- 2. NOW INITIALIZE THE CRITERION WITH DYNAMIC WEIGHTS ---
    criterion = MultiTaskLoss(
        class_weights=CONFIG['CLASS_WEIGHTS'],
        alpha=CONFIG['ALPHA'],
        beta=CONFIG['BETA'],
        gamma=CONFIG['GAMMA']
    )
    logger.info("Criterion initialized with FocalLoss and dynamic weights.")

    # --- 3. CREATE WEIGHTED RANDOM SAMPLER FOR THE TRAINING SET ---
    logger.info("Creating WeightedRandomSampler for the training loader...")
    # Each sample's draw weight is the normalized inverse-frequency weight of its class.
    # NOTE(review): the same weights are also applied inside the loss above, so
    # minority classes are up-weighted twice — confirm this is intended.
    sample_weights = dynamic_weights[train_labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    # OptimizedDRDataset reads every image from disk on access (no RAM cache).
    # The train dataset is built before the validation dataset; both draw their
    # pseudo-labels at construction time.
    train_dataset = OptimizedDRDataset(
        train_paths, train_labels, medical_kb, transforms=train_transforms, is_train=True
    )
    val_dataset = OptimizedDRDataset(
        val_paths, val_labels, medical_kb, transforms=val_transforms, is_train=False
    )

    # --- 4. CREATE DATALOADERS (TRAIN LOADER USES THE SAMPLER) ---
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['BATCH_SIZE'],
        sampler=sampler, # the sampler takes the place of shuffle=True
        num_workers=CONFIG['NUM_WORKERS'],
        pin_memory=True,
        persistent_workers=True if CONFIG['NUM_WORKERS'] > 0 else False,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['BATCH_SIZE'],
        shuffle=False, # validation keeps a fixed order; no sampler is used here
        num_workers=CONFIG['NUM_WORKERS'],
        pin_memory=True,
        persistent_workers=True if CONFIG['NUM_WORKERS'] > 0 else False
    )

    logger.info("V2 Data loading complete with WeightedRandomSampler!")
    logger.info(f"Number of training batches: {len(train_loader)}")
    logger.info(f"Number of validation batches: {len(val_loader)}")

except Exception as e:
    logger.error(f"A critical error occurred during data loading: {e}")
    raise

2025-08-19 06:32:22,419 - INFO - Loading dataset from class subdirectories...
2025-08-19 06:32:22,422 - INFO - Found image subdirectories: ['Severe', 'Proliferate_DR', 'No_DR', 'Mild', 'Moderate']
2025-08-19 06:32:22,423 - INFO - Loading images from 'No_DR' (Label: 0)...
2025-08-19 06:32:28,132 - INFO -   -> Found 1815 images.
2025-08-19 06:32:28,137 - INFO - Loading images from 'Mild' (Label: 1)...
2025-08-19 06:32:28,147 - INFO -   -> Found 370 images.
2025-08-19 06:32:28,149 - INFO - Loading images from 'Moderate' (Label: 2)...
2025-08-19 06:32:28,163 - INFO -   -> Found 999 images.
2025-08-19 06:32:28,166 - INFO - Loading images from 'Severe' (Label: 3)...
2025-08-19 06:32:28,170 - INFO -   -> Found 193 images.
2025-08-19 06:32:28,171 - INFO - Loading images from 'Proliferate_DR' (Label: 4)...
2025-08-19 06:32:28,177 - INFO -   -> Found 295 images.
2025-08-19 06:32:28,186 - INFO - Total loaded: 3672 images from directory structure.
2025-08-19 06:32:28,194 - INFO - Class distributio

In [None]:
# =============================================================================
# CELL 11: Optimizer & Scheduler Setup
# =============================================================================

# Optimizer with different learning rates for different parts
# Split parameters by ownership: anything whose qualified name contains
# 'backbone' belongs to the pretrained encoder, everything else to the new layers.
named_params = list(model.named_parameters())
backbone_params = [p for n, p in named_params if 'backbone' in n]
classifier_params = [p for n, p in named_params if 'backbone' not in n]

# AdamW with two parameter groups: the pretrained backbone trains at a tenth
# of the base learning rate, the freshly initialised layers at the full rate.
base_lr = CONFIG['LEARNING_RATE']
optimizer = optim.AdamW(
    [
        {'params': backbone_params, 'lr': base_lr * 0.1},
        {'params': classifier_params, 'lr': base_lr},
    ],
    weight_decay=CONFIG['WEIGHT_DECAY'],
)

# Cosine schedule with warm restarts; the cycle length doubles after each restart.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=CONFIG['T_MAX'],
    T_mult=2,
    eta_min=CONFIG['MIN_LR'],
)

# Gradient scaler exists only when mixed precision is enabled.
scaler = None
if CONFIG['USE_AMP']:
    scaler = GradScaler()

# Early stopping
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            stopping is signalled.
        min_delta: minimum decrease in loss that counts as an improvement.
        restore_best_weights: if True, keep a snapshot of the best model state
            and load it back into the model when stopping is signalled.
    """
    def __init__(self, patience=7, min_delta=1e-4, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss.

        Returns:
            True if training should stop (after optionally restoring the best
            weights into ``model``), otherwise False.
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # state_dict() hands back tensors that share storage with the
                # live parameters, and dict.copy() is shallow. Clone every
                # tensor, or the "snapshot" keeps following later updates.
                self.best_weights = {
                    key: value.detach().clone()
                    for key, value in model.state_dict().items()
                }
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False

# min_delta and restore_best_weights keep their EarlyStopping defaults.
early_stopping = EarlyStopping(patience=CONFIG['PATIENCE'])

logger.info("Optimizer and scheduler initialized!")
# The logged rates mirror the two optimizer parameter groups defined above.
logger.info(f"Backbone LR: {CONFIG['LEARNING_RATE'] * 0.1}")
logger.info(f"Classifier LR: {CONFIG['LEARNING_RATE']}")

2025-08-19 06:32:28,277 - INFO - Optimizer and scheduler initialized!
2025-08-19 06:32:28,277 - INFO - Backbone LR: 1e-05
2025-08-19 06:32:28,278 - INFO - Classifier LR: 0.0001


In [None]:
# =============================================================================
# CELL 12: Training & Validation Functions (V2 FINAL FIX)
# =============================================================================

def train_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch):
    """Train the model for one epoch and return the epoch-averaged metrics.

    Each batch is a dict with the keys 'image', 'severity', 'lesions' and
    'regions'. The criterion is called as ``criterion(predictions, targets)``
    and must return a dict with 'total_loss', 'severity_loss', 'lesion_loss'
    and 'region_loss'. Gradients are clipped to CONFIG['GRADIENT_CLIP'] and
    the scheduler is stepped after every batch with a fractional epoch value.

    Args:
        model: Network returning a dict that contains 'severity' logits.
        train_loader: Iterable of training batches (must support ``len``).
        criterion: Multi-task loss callable described above.
        optimizer: Optimizer updating ``model``.
        scheduler: LR scheduler accepting a fractional epoch in ``step``.
        scaler: GradScaler used when CONFIG['USE_AMP'] is true.
        epoch: Zero-based index of the current epoch.

    Returns:
        Dict with the per-batch averages 'total_loss', 'severity_loss',
        'lesion_loss', 'region_loss', the severity 'accuracy' over all samples
        seen, and 'lr' (learning rate of the first parameter group).
    """
    model.train()
    use_amp = CONFIG['USE_AMP']
    num_batches = len(train_loader)
    loss_sums = {'total_loss': 0, 'severity_loss': 0, 'lesion_loss': 0, 'region_loss': 0}
    n_correct = 0
    n_seen = 0

    bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']} Training", unit="batch")
    for step, batch in enumerate(bar):
        images = batch['image'].to(device, non_blocking=True)
        targets = {
            task: batch[task].to(device, non_blocking=True)
            for task in ('severity', 'lesions', 'regions')
        }

        # Forward pass (under autocast when mixed precision is enabled).
        with autocast(enabled=use_amp):
            predictions = model(images)
            losses = criterion(predictions, targets)
        batch_loss = losses['total_loss']

        # Backward pass, gradient clipping and parameter update.
        optimizer.zero_grad()
        if use_amp:
            scaler.scale(batch_loss).backward()
            # Gradients must be unscaled before their norm is clipped.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['GRADIENT_CLIP'])
            scaler.step(optimizer)
            scaler.update()
        else:
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['GRADIENT_CLIP'])
            optimizer.step()
        scheduler.step(epoch + step / num_batches)

        # Running statistics.
        for key in loss_sums:
            loss_sums[key] += losses[key].item()
        predicted_severity = predictions['severity'].argmax(dim=1)
        n_correct += (predicted_severity == targets['severity']).sum().item()
        n_seen += targets['severity'].size(0)
        bar.set_postfix(
            loss=batch_loss.item(),
            acc=f"{n_correct / n_seen:.3f}",
            lr=f"{scheduler.get_last_lr()[0]:.2e}",
        )

    epoch_metrics = {key: total / num_batches for key, total in loss_sums.items()}
    epoch_metrics['accuracy'] = n_correct / n_seen
    epoch_metrics['lr'] = optimizer.param_groups[0]['lr']
    return epoch_metrics

def validate_epoch(model, val_loader, criterion):
    """Evaluate the model on the validation loader without gradient tracking.

    Raw model outputs and targets of the three tasks ('severity', 'lesions',
    'regions') are gathered on the CPU over all batches and handed to
    ``calculate_metrics`` once per task.

    Args:
        model: Network returning a dict keyed by task name.
        val_loader: Iterable of validation batches (must support ``len``).
        criterion: Loss callable returning a dict containing 'total_loss'.

    Returns:
        Dict with the mean 'total_loss' per batch, the severity accuracy / F1 /
        QWK, and accuracy / F1 for lesions and regions (0 when the metric is
        missing). The per-task loss entries are constant 0 placeholders; they
        are not measured here.
    """
    model.eval()
    tasks = ('severity', 'lesions', 'regions')
    loss_total = 0
    gathered_preds = {task: [] for task in tasks}
    gathered_targets = {task: [] for task in tasks}

    bar = tqdm(val_loader, desc="Validating", unit="batch")
    with torch.no_grad():
        for batch in bar:
            images = batch['image'].to(device, non_blocking=True)
            targets = {task: batch[task].to(device, non_blocking=True) for task in tasks}
            with autocast(enabled=CONFIG['USE_AMP']):
                predictions = model(images)
                losses = criterion(predictions, targets)
            loss_total += losses['total_loss'].item()
            for task in tasks:
                gathered_preds[task].append(predictions[task].cpu())
                gathered_targets[task].append(targets[task].cpu())
            bar.set_postfix(loss=losses['total_loss'].item())

    # One metrics computation per task over the concatenated epoch outputs.
    per_task = {
        task: calculate_metrics(
            torch.cat(gathered_preds[task], dim=0),
            torch.cat(gathered_targets[task], dim=0),
            task,
        )
        for task in tasks
    }
    severity, lesions, regions = (per_task[task] for task in tasks)

    results = {'total_loss': loss_total / len(val_loader)}
    # Placeholders only: individual validation losses are not accumulated.
    results.update(severity_loss=0, lesion_loss=0, region_loss=0)
    results['severity_accuracy'] = severity['accuracy']
    results['severity_f1'] = severity['f1_score']
    results['qwk'] = severity['qwk']
    # .get() tolerates metric dicts that lack these keys.
    results['lesion_accuracy'] = lesions.get('accuracy', 0)
    results['lesion_f1'] = lesions.get('f1_score', 0)
    results['region_accuracy'] = regions.get('accuracy', 0)
    results['region_f1'] = regions.get('f1_score', 0)
    return results

# Marker in the log that this cell ran and both epoch functions are defined.
logger.info("Training and validation functions (V2) are defined.")

2025-08-19 06:32:28,300 - INFO - Training and validation functions (V2) are defined.


In [None]:
# =============================================================================
# CELL 13: Main Training Loop with Comprehensive Logging (V2 UPGRADE)
# =============================================================================

def save_checkpoint(model, optimizer, scheduler, epoch, metrics, is_best=False):
    """Write a per-epoch checkpoint and, optionally, refresh the best-model file.

    The checkpoint bundles the model, optimizer and scheduler state dicts with
    the epoch number, the supplied metrics and the global CONFIG.

    Args:
        model, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        epoch: Epoch number used in the checkpoint file name (zero-padded to 2).
        metrics: Arbitrary metrics object stored alongside the weights.
        is_best: When True the same checkpoint is also written to
            ``best_model_v2.pth`` in MODEL_SAVE_PATH.

    Returns:
        Path of the per-epoch checkpoint file.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        metrics=metrics,
        config=CONFIG,
    )

    epoch_path = f"{MODEL_SAVE_PATH}/checkpoint_epoch_{epoch:02d}.pth"
    torch.save(state, epoch_path)

    if not is_best:
        return epoch_path

    # The V2 best model lives in its own file, overwritten on every improvement.
    best_path = f"{MODEL_SAVE_PATH}/best_model_v2.pth"
    torch.save(state, best_path)
    logger.info(f"Best model saved: {best_path}")
    return epoch_path

def log_epoch_metrics(epoch, train_metrics, val_metrics, epoch_time):
    """Log a formatted epoch summary and append the raw metrics to a JSON file.

    Args:
        epoch: Epoch number shown in the summary and stored in the record.
        train_metrics: Dict providing at least 'lr', 'total_loss',
            'severity_loss' and 'accuracy'.
        val_metrics: Dict providing at least 'total_loss', 'severity_accuracy',
            'severity_f1' and 'qwk'.
        epoch_time: Duration of the epoch in seconds.

    Side effects:
        Emits the summary through ``logger`` and rewrites
        ``{LOG_PATH}/metrics_log_v2.json`` with the new record appended to
        whatever list the file already contained.
    """
    summary = f"""
{'='*80}
EPOCH {epoch:02d} SUMMARY
{'='*80}
Time: {epoch_time:.2f}s | LR: {train_metrics['lr']:.2e}

LOSSES:
  Train - Total: {train_metrics['total_loss']:.4f} | Severity: {train_metrics['severity_loss']:.4f}
  Val   - Total: {val_metrics['total_loss']:.4f}

ACCURACY:
  Train - Severity: {train_metrics['accuracy']:.4f}
  Val   - Severity: {val_metrics['severity_accuracy']:.4f}

METRICS (Val):
  Severity - F1: {val_metrics['severity_f1']:.4f} | QWK: {val_metrics['qwk']:.4f}

MEMORY: {torch.cuda.memory_allocated()/1024**3:.2f}GB / {torch.cuda.max_memory_allocated()/1024**3:.2f}GB
{'='*80}
    """
    logger.info(summary)

    # Machine-readable record of the same epoch, kept in a V2-specific file.
    record = {
        'epoch': epoch,
        'timestamp': datetime.now().isoformat(),
        'train': train_metrics,
        'val': val_metrics,
        'epoch_time': epoch_time,
    }
    if torch.cuda.is_available():
        record['memory_usage'] = torch.cuda.memory_allocated()/1024**3
    else:
        record['memory_usage'] = 0

    metrics_file = f"{LOG_PATH}/metrics_log_v2.json"
    history = []
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as handle:
            history = json.load(handle)
    history.append(record)
    with open(metrics_file, 'w') as handle:
        json.dump(history, handle, indent=2)

def train_model():
    """Main training function.

    Re-creates the model, optimizer, scheduler, AMP scaler and early-stopping
    monitor as module-level globals, then runs up to CONFIG['EPOCHS'] epochs of
    ``train_epoch`` / ``validate_epoch``. After every epoch it logs the metrics,
    saves a checkpoint (flagged as best when the validation loss strictly
    improves) and consults early stopping.

    The accumulated per-epoch history is written to
    ``{LOG_PATH}/training_history_v2.json`` in the ``finally`` block, so it is
    saved on normal completion, on KeyboardInterrupt, and on errors (which are
    logged and re-raised).

    Reads module-level names defined outside this function: CONFIG, device,
    train_loader, val_loader, criterion, logger, LOG_PATH, MultiTaskDRModel.
    Returns None.
    """
    logger.info("Starting V2 training...")
    logger.info(f"Training for {CONFIG['EPOCHS']} epochs")

    # Rebind the notebook-level `model` to a freshly constructed network so the
    # run does not start from weights left in the kernel by earlier cells.
    global model
    model = MultiTaskDRModel().to(device)
    logger.info("Model re-initialized for V2 training.")

    # Re-initialize optimizer and scheduler for the new model's parameters.
    # NOTE(review): this optimizer uses a single learning rate for every
    # parameter, replacing the two-group (backbone 0.1x / heads 1x) optimizer
    # built at module level earlier in the notebook - confirm this is intended.
    global optimizer, scheduler, scaler, early_stopping
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['LEARNING_RATE'], weight_decay=CONFIG['WEIGHT_DECAY'])
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=CONFIG['T_MAX'], T_mult=2, eta_min=CONFIG['MIN_LR']
    )
    scaler = GradScaler() if CONFIG['USE_AMP'] else None
    early_stopping = EarlyStopping(patience=CONFIG['PATIENCE'])
    logger.info("Optimizer, Scheduler, and Early Stopping re-initialized for V2.")

    # Optional: Add torch.compile for a speed boost
    # model = torch.compile(model)

    # Initialised before the try block because the finally block reads them.
    best_val_loss = float('inf')
    training_history = []
    training_start_time = time.time()

    try:
        for epoch in range(CONFIG['EPOCHS']):
            epoch_start_time = time.time()

            # `epoch` is zero-based here; everything reported or saved below
            # uses the one-based `epoch + 1`.
            train_metrics = train_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch)
            val_metrics = validate_epoch(model, val_loader, criterion)

            epoch_time = time.time() - epoch_start_time
            log_epoch_metrics(epoch + 1, train_metrics, val_metrics, epoch_time)

            # "Best" means any strict decrease of the validation loss; this is
            # tracked independently of EarlyStopping, which additionally
            # requires the decrease to exceed its min_delta.
            is_best = val_metrics['total_loss'] < best_val_loss
            if is_best:
                best_val_loss = val_metrics['total_loss']
                logger.info(f"New best validation loss: {best_val_loss:.4f}")

            # A checkpoint is written for every epoch; the best one is also
            # written to the dedicated best-model file.
            save_checkpoint(model, optimizer, scheduler, epoch + 1,
                                            {'train': train_metrics, 'val': val_metrics}, is_best)
            training_history.append({
                'epoch': epoch + 1,
                'train': train_metrics,
                'val': val_metrics,
                'epoch_time': epoch_time
            })

            # When this returns True, EarlyStopping has already called
            # model.load_state_dict() with its stored weights (its
            # restore_best_weights option defaults to True).
            if early_stopping(val_metrics['total_loss'], model):
                logger.info(f"Early stopping triggered after epoch {epoch + 1}")
                break

            # Free cached GPU memory and collect garbage between epochs.
            torch.cuda.empty_cache()
            gc.collect()

    except KeyboardInterrupt:
        # A manual interrupt is treated as a normal stop: not re-raised.
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training error: {e}")
        raise

    finally:
        # Runs on every exit path, so "Training completed!" is also logged
        # after an interrupt or an error.
        total_training_time = time.time() - training_start_time
        logger.info(f"\nTraining completed!")
        logger.info(f"Total training time: {total_training_time/3600:.2f} hours")
        logger.info(f"Best validation loss: {best_val_loss:.4f}")

        history_path = f"{LOG_PATH}/training_history_v2.json"
        with open(history_path, 'w') as f:
            json.dump(training_history, f, indent=2)
        logger.info(f"Training history saved: {history_path}")

# Kick off training as soon as the cell is executed.
train_model()

2025-08-18 10:52:10,645 - INFO - Starting V2 training...
2025-08-18 10:52:10,647 - INFO - Training for 30 epochs
2025-08-18 10:52:10,853 - INFO - Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
2025-08-18 10:52:10,983 - INFO - [timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-08-18 10:52:11,023 - INFO - Backbone feature dimension: 1536
2025-08-18 10:52:11,093 - INFO - Model re-initialized for V2 training.
2025-08-18 10:52:11,102 - INFO - Optimizer, Scheduler, and Early Stopping re-initialized for V2.


Epoch 1/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 10:54:52,059 - INFO - 
EPOCH 01 SUMMARY
Time: 160.96s | LR: 9.76e-05

LOSSES:
  Train - Total: 1.1698 | Severity: 1.1698
  Val   - Total: 0.4015

ACCURACY:
  Train - Severity: 0.3173
  Val   - Severity: 0.4748

METRICS (Val):
  Severity - F1: 0.4764 | QWK: 0.6523

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 10:54:52,067 - INFO - New best validation loss: 0.4015
2025-08-18 10:55:01,593 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 2/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 10:57:12,254 - INFO - 
EPOCH 02 SUMMARY
Time: 130.03s | LR: 9.06e-05

LOSSES:
  Train - Total: 0.8248 | Severity: 0.8248
  Val   - Total: 0.3180

ACCURACY:
  Train - Severity: 0.4427
  Val   - Severity: 0.4599

METRICS (Val):
  Severity - F1: 0.4739 | QWK: 0.6928

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 10:57:12,264 - INFO - New best validation loss: 0.3180
2025-08-18 10:57:22,652 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 3/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 10:59:34,786 - INFO - 
EPOCH 03 SUMMARY
Time: 131.65s | LR: 7.95e-05

LOSSES:
  Train - Total: 0.7457 | Severity: 0.7457
  Val   - Total: 0.3032

ACCURACY:
  Train - Severity: 0.4797
  Val   - Severity: 0.4884

METRICS (Val):
  Severity - F1: 0.4836 | QWK: 0.6788

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 10:59:34,795 - INFO - New best validation loss: 0.3032
2025-08-18 10:59:43,860 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 4/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:01:51,546 - INFO - 
EPOCH 04 SUMMARY
Time: 127.22s | LR: 6.56e-05

LOSSES:
  Train - Total: 0.6340 | Severity: 0.6340
  Val   - Total: 0.3136

ACCURACY:
  Train - Severity: 0.4993
  Val   - Severity: 0.5864

METRICS (Val):
  Severity - F1: 0.5550 | QWK: 0.7536

MEMORY: 0.28GB / 13.82GB
    


Epoch 5/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:03:58,100 - INFO - 
EPOCH 05 SUMMARY
Time: 116.52s | LR: 5.02e-05

LOSSES:
  Train - Total: 0.5493 | Severity: 0.5493
  Val   - Total: 0.2951

ACCURACY:
  Train - Severity: 0.5151
  Val   - Severity: 0.5796

METRICS (Val):
  Severity - F1: 0.5512 | QWK: 0.7672

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 11:03:58,127 - INFO - New best validation loss: 0.2951
2025-08-18 11:04:08,816 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 6/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:06:13,411 - INFO - 
EPOCH 06 SUMMARY
Time: 124.07s | LR: 3.48e-05

LOSSES:
  Train - Total: 0.4453 | Severity: 0.4453
  Val   - Total: 0.2620

ACCURACY:
  Train - Severity: 0.5701
  Val   - Severity: 0.5673

METRICS (Val):
  Severity - F1: 0.5644 | QWK: 0.7573

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 11:06:13,421 - INFO - New best validation loss: 0.2620
2025-08-18 11:06:24,575 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 7/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:08:24,102 - INFO - 
EPOCH 07 SUMMARY
Time: 119.08s | LR: 2.08e-05

LOSSES:
  Train - Total: 0.4297 | Severity: 0.4297
  Val   - Total: 0.2812

ACCURACY:
  Train - Severity: 0.5804
  Val   - Severity: 0.5864

METRICS (Val):
  Severity - F1: 0.5586 | QWK: 0.7801

MEMORY: 0.27GB / 13.82GB
    


Epoch 8/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:10:32,512 - INFO - 
EPOCH 08 SUMMARY
Time: 118.88s | LR: 9.74e-06

LOSSES:
  Train - Total: 0.4217 | Severity: 0.4217
  Val   - Total: 0.2631

ACCURACY:
  Train - Severity: 0.5903
  Val   - Severity: 0.5850

METRICS (Val):
  Severity - F1: 0.5709 | QWK: 0.7739

MEMORY: 0.27GB / 13.82GB
    


Epoch 9/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:12:45,646 - INFO - 
EPOCH 09 SUMMARY
Time: 121.63s | LR: 2.60e-06

LOSSES:
  Train - Total: 0.4080 | Severity: 0.4080
  Val   - Total: 0.2869

ACCURACY:
  Train - Severity: 0.5982
  Val   - Severity: 0.6054

METRICS (Val):
  Severity - F1: 0.5811 | QWK: 0.7920

MEMORY: 0.27GB / 13.82GB
    


Epoch 10/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:14:50,508 - INFO - 
EPOCH 10 SUMMARY
Time: 113.88s | LR: 1.00e-07

LOSSES:
  Train - Total: 0.3966 | Severity: 0.3966
  Val   - Total: 0.2604

ACCURACY:
  Train - Severity: 0.5927
  Val   - Severity: 0.5673

METRICS (Val):
  Severity - F1: 0.5571 | QWK: 0.7702

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 11:14:50,521 - INFO - New best validation loss: 0.2604
2025-08-18 11:15:01,628 - INFO - Best model saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model_v2.pth


Epoch 11/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:17:04,103 - INFO - 
EPOCH 11 SUMMARY
Time: 122.03s | LR: 9.94e-05

LOSSES:
  Train - Total: 0.3962 | Severity: 0.3962
  Val   - Total: 0.3015

ACCURACY:
  Train - Severity: 0.6044
  Val   - Severity: 0.6136

METRICS (Val):
  Severity - F1: 0.5999 | QWK: 0.7863

MEMORY: 0.28GB / 13.82GB
    


Epoch 12/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:19:12,043 - INFO - 
EPOCH 12 SUMMARY
Time: 119.10s | LR: 9.76e-05

LOSSES:
  Train - Total: 0.3731 | Severity: 0.3731
  Val   - Total: 0.2638

ACCURACY:
  Train - Severity: 0.6233
  Val   - Severity: 0.6136

METRICS (Val):
  Severity - F1: 0.6099 | QWK: 0.7629

MEMORY: 0.28GB / 13.82GB
    


Epoch 13/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:21:17,919 - INFO - 
EPOCH 13 SUMMARY
Time: 116.10s | LR: 9.46e-05

LOSSES:
  Train - Total: 0.3192 | Severity: 0.3192
  Val   - Total: 0.3023

ACCURACY:
  Train - Severity: 0.6411
  Val   - Severity: 0.6150

METRICS (Val):
  Severity - F1: 0.6219 | QWK: 0.7668

MEMORY: 0.27GB / 13.82GB
    


Epoch 14/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:23:26,590 - INFO - 
EPOCH 14 SUMMARY
Time: 120.66s | LR: 9.05e-05

LOSSES:
  Train - Total: 0.3170 | Severity: 0.3170
  Val   - Total: 0.2900

ACCURACY:
  Train - Severity: 0.6669
  Val   - Severity: 0.5959

METRICS (Val):
  Severity - F1: 0.6131 | QWK: 0.7675

MEMORY: 0.28GB / 13.82GB
    


Epoch 15/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:25:34,513 - INFO - 
EPOCH 15 SUMMARY
Time: 118.33s | LR: 8.54e-05

LOSSES:
  Train - Total: 0.2554 | Severity: 0.2554
  Val   - Total: 0.2873

ACCURACY:
  Train - Severity: 0.6820
  Val   - Severity: 0.6503

METRICS (Val):
  Severity - F1: 0.6633 | QWK: 0.7957

MEMORY: 0.28GB / 13.82GB
    


Epoch 16/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:27:42,610 - INFO - 
EPOCH 16 SUMMARY
Time: 116.15s | LR: 7.95e-05

LOSSES:
  Train - Total: 0.2575 | Severity: 0.2575
  Val   - Total: 0.2909

ACCURACY:
  Train - Severity: 0.6889
  Val   - Severity: 0.6000

METRICS (Val):
  Severity - F1: 0.6197 | QWK: 0.7669

MEMORY: 0.28GB / 13.82GB
    


Epoch 17/30 Training:   0%|          | 0/91 [00:00<?, ?batch/s]

Validating:   0%|          | 0/23 [00:00<?, ?batch/s]

2025-08-18 11:29:50,621 - INFO - 
EPOCH 17 SUMMARY
Time: 117.12s | LR: 7.28e-05

LOSSES:
  Train - Total: 0.2023 | Severity: 0.2023
  Val   - Total: 0.3071

ACCURACY:
  Train - Severity: 0.7414
  Val   - Severity: 0.6245

METRICS (Val):
  Severity - F1: 0.6451 | QWK: 0.7785

MEMORY: 0.28GB / 13.82GB
    
2025-08-18 11:29:59,353 - INFO - Early stopping triggered after epoch 17
2025-08-18 11:29:59,354 - INFO - 
Training completed!
2025-08-18 11:29:59,357 - INFO - Total training time: 0.63 hours
2025-08-18 11:29:59,358 - INFO - Best validation loss: 0.2604
2025-08-18 11:29:59,367 - INFO - Training history saved: /content/drive/MyDrive/Diabetic_Retinopathy_Project/training_logs/training_history_v2.json


In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# before the evaluation cell below runs.
torch.cuda.empty_cache()

In [None]:
# =============================================================================
# CELL 14: Model Evaluation & Grad-CAM Visualization (V2 FINAL FIX)
# =============================================================================

def evaluate_model(model, val_loader):
    """Score the severity head of ``model`` on every batch of ``val_loader``.

    Predictions are the argmax over the class dimension of the 'severity'
    output. Metrics are computed with scikit-learn on the collected labels.

    NOTE(review): ``cohen_kappa_score`` is not among the sklearn.metrics names
    imported in CELL 2; confirm it is imported by an earlier cell.

    Args:
        model: Network returning a dict with 'severity' logits.
        val_loader: Iterable of batches with 'image' and 'severity' entries.

    Returns:
        Dict with 'accuracy', 'f1_weighted', 'f1_macro', 'qwk' (quadratic
        weighted kappa), 'classification_report' (dict form) and
        'confusion_matrix' (array as returned by scikit-learn).
    """
    logger.info("Starting comprehensive evaluation...")

    model.eval()
    y_pred, y_true = [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Final Evaluation"):
            images = batch['image'].to(device)
            labels = batch['severity'].to(device)

            with autocast(enabled=CONFIG['USE_AMP']):
                outputs = model(images)

            predicted = outputs['severity'].argmax(dim=1)
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    results = {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'f1_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'qwk': cohen_kappa_score(y_true, y_pred, weights='quadratic'),
    }

    # Per-class breakdown, labelled with the five severity grades.
    severity_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']
    results['classification_report'] = classification_report(
        y_true,
        y_pred,
        target_names=severity_names,
        output_dict=True,
        zero_division=0,
    )
    results['confusion_matrix'] = confusion_matrix(y_true, y_pred)

    logger.info(f"Final Evaluation Results:")
    logger.info(f"Accuracy: {results['accuracy']:.4f}")
    logger.info(f"F1-Score (Weighted): {results['f1_weighted']:.4f}")
    logger.info(f"QWK Score: {results['qwk']:.4f}")

    return results

def create_gradcam_visualizations(model, dataset, num_samples=20):
    """Create Grad-CAM visualizations using a memory-safe batch size of 1.

    For up to ``num_samples`` randomly drawn items of ``dataset`` the function
    computes a Grad-CAM heat map for the class the model predicts on the
    severity head (hooked on ``model.backbone.conv_head``) and saves a
    side-by-side figure (de-normalised image | overlay) as
    ``{LOG_PATH}/gradcam_visualizations_v2/gradcam_sample_NNN.png``.

    The ``grad-cam`` / ``ttach`` packages are installed with pip only when
    ``pytorch_grad_cam`` cannot be imported, using the interpreter that runs
    this kernel, so the function is plain Python (no ``!`` shell escape) and
    does not reinstall on every call.

    Args:
        model: Network returning a dict with 'severity' logits and exposing
            ``backbone.conv_head``.
        dataset: Dataset yielding dicts with 'image' and 'severity'.
        num_samples: Maximum number of figures to produce.

    Raises:
        subprocess.CalledProcessError: If the pip installation fails.
    """
    try:
        import pytorch_grad_cam  # noqa: F401  (availability probe only)
    except ImportError:
        import importlib
        import subprocess
        import sys

        print("Installing Grad-CAM for visualization...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "-q", "grad-cam", "ttach"]
        )
        # Make the freshly installed packages visible to the import system.
        importlib.invalidate_caches()
        print("Grad-CAM installed successfully.")

    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image

    logger.info("Creating Grad-CAM visualizations...")

    class ModelWrapper(torch.nn.Module):
        """Expose only the severity logits, since Grad-CAM expects a tensor output."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            return self.model(x)['severity']

    wrapped_model = ModelWrapper(model)
    target_layers = [model.backbone.conv_head]

    model.eval()
    samples_processed = 0
    vis_dir = f"{LOG_PATH}/gradcam_visualizations_v2"
    os.makedirs(vis_dir, exist_ok=True)

    # batch_size=1 keeps Grad-CAM (with smoothing) within GPU memory.
    vis_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Constants used to undo input normalisation for display.
    # NOTE(review): assumes the dataset transform normalised with these values.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # The context manager removes the hooks GradCAM registers on the target
    # layer, so the model is left unmodified once the visualizations are done.
    with GradCAM(model=wrapped_model, target_layers=target_layers) as cam:
        for batch in vis_loader:
            if samples_processed >= num_samples:
                break
            images = batch['image'].to(device)
            targets = batch['severity']
            with torch.no_grad():
                outputs = model(images)
                predictions = torch.argmax(outputs['severity'], dim=1)

            input_tensor = images[0:1]
            # Explain the class the model actually predicted.
            target_class = predictions[0].item()

            targets_for_gradcam = [ClassifierOutputTarget(target_class)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets_for_gradcam,
                                aug_smooth=True, eigen_smooth=True)
            grayscale_cam = grayscale_cam[0, :]

            # CHW tensor -> HWC array in [0, 1] for plotting.
            img_np = input_tensor[0].cpu().numpy().transpose(1, 2, 0)
            img_np = (img_np * std + mean).clip(0, 1)
            visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            axs[0].imshow(img_np)
            axs[0].set_title(f'Original Image\nTrue: {targets[0].item()}, Pred: {target_class}')
            axs[0].axis('off')
            axs[1].imshow(visualization)
            axs[1].set_title('Grad-CAM Overlay')
            axs[1].axis('off')

            plt.tight_layout()
            save_path = f"{vis_dir}/gradcam_sample_{samples_processed:03d}.png"
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            # Close each figure so the loop does not accumulate open figures.
            plt.close(fig)
            samples_processed += 1

    logger.info(f"Created {samples_processed} Grad-CAM visualizations in {vis_dir}")

# =============================================================================
# --- Main Evaluation Execution ---
# =============================================================================
# Same directory create_gradcam_visualizations writes to; defined at module
# level as well because the summary report in CELL 15 references `vis_dir`.
vis_dir = f"{LOG_PATH}/gradcam_visualizations_v2"

# The checkpoint to evaluate is selected by a hard-coded epoch number, not via
# best_model_v2.pth (which save_checkpoint updates on lowest validation loss).
# NOTE(review): presumably epoch 15 was picked for its validation QWK in the
# logged run - this value must be updated by hand after every new training run.
best_epoch_num = 15
best_model_path = f"{MODEL_SAVE_PATH}/checkpoint_epoch_{best_epoch_num:02d}.pth"
logger.info(f"Loading best model for final evaluation from Epoch {best_epoch_num}...")

# weights_only=False performs a full unpickle (the checkpoint also stores the
# metrics and CONFIG objects); only load checkpoint files you trust.
checkpoint = torch.load(best_model_path, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

logger.info("Running final model evaluation on validation set...")
evaluation_results = evaluate_model(model, val_loader)

create_gradcam_visualizations(model, val_dataset, num_samples=20)

# Persist the evaluation results as JSON. The confusion matrix is converted
# with .tolist() because json.dump cannot serialise a NumPy array.
eval_path = f"{LOG_PATH}/final_evaluation_v2.json"
with open(eval_path, 'w') as f:
    eval_results_serializable = {
        'accuracy': evaluation_results['accuracy'],
        'f1_weighted': evaluation_results['f1_weighted'],
        'f1_macro': evaluation_results['f1_macro'],
        'qwk': evaluation_results['qwk'],
        'classification_report': evaluation_results['classification_report'],
        'confusion_matrix': evaluation_results['confusion_matrix'].tolist()
    }
    json.dump(eval_results_serializable, f, indent=2)

logger.info(f"V2 Evaluation results saved: {eval_path}")

2025-08-19 06:35:13,882 - INFO - Loading best model for final evaluation from Epoch 15...
2025-08-19 06:35:14,837 - INFO - Running final model evaluation on validation set...
2025-08-19 06:35:14,854 - INFO - Starting comprehensive evaluation...


Final Evaluation:   0%|          | 0/23 [00:00<?, ?it/s]

2025-08-19 06:37:59,773 - INFO - Final Evaluation Results:
2025-08-19 06:37:59,778 - INFO - Accuracy: 0.6503
2025-08-19 06:37:59,779 - INFO - F1-Score (Weighted): 0.6633
2025-08-19 06:37:59,779 - INFO - QWK Score: 0.7957
Installing Grad-CAM for visualization...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m102.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m75.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [

In [None]:
 # =============================================================================
# CELL 15: Save Final Model and Prepare for Deployment
# =============================================================================

# Save the final model for deployment
# Destination of the self-contained deployment file.
final_model_path = f"{MODEL_SAVE_PATH}/final_model_for_deployment.pth"

# Everything an inference application needs in one dict: the weights currently
# held by `model`, the settings required to rebuild the architecture, label
# vocabularies, the full training CONFIG, headline metrics and the
# normalisation statistics.
deployment_package = {
    'model_state_dict': model.state_dict(),
    'model_config': {
        'model_name': CONFIG['MODEL_NAME'],
        'num_classes': CONFIG['NUM_CLASSES'],
        'num_lesion_types': len(CONFIG['LESION_TYPES']),
        'num_regions': len(CONFIG['ANATOMICAL_REGIONS']),
        'img_size': CONFIG['IMG_SIZE'],
        'drop_rate': CONFIG['DROP_RATE'],
        'drop_path_rate': CONFIG['DROP_PATH_RATE']
    },
    'medical_knowledge': {
        'lesion_types': CONFIG['LESION_TYPES'],
        'anatomical_regions': CONFIG['ANATOMICAL_REGIONS'],
        # Display names for the five severity grades.
        'class_names': ['No DR', 'Mild NPDR', 'Moderate NPDR', 'Severe NPDR', 'Proliferative DR']
    },
    'training_config': CONFIG,
    # Only accuracy and the two F1 variants are copied from evaluation_results;
    # any other entries it holds are not included here.
    'evaluation_metrics': {
        'accuracy': evaluation_results['accuracy'],
        'f1_weighted': evaluation_results['f1_weighted'],
        'f1_macro': evaluation_results['f1_macro']
    },
    # NOTE(review): assumed to match the normalisation applied by the dataset
    # transforms - confirm against the augmentation cell.
    'preprocessing': {
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225]
    }
}

torch.save(deployment_package, final_model_path)
logger.info(f"Final model saved for deployment: {final_model_path}")

# Build a plain-text summary of the run from CONFIG, the dataset sizes and the
# evaluation results. Note that "Total Epochs" prints the configured maximum
# (CONFIG['EPOCHS']), not the number of epochs actually executed, and that the
# closing line is a fixed string rather than a checked status.
summary_report = f"""
DIABETIC RETINOPATHY DETECTION MODEL - TRAINING SUMMARY
{'='*60}

MODEL CONFIGURATION:
- Architecture: {CONFIG['MODEL_NAME']}
- Input Size: {CONFIG['IMG_SIZE']}x{CONFIG['IMG_SIZE']}
- Number of Classes: {CONFIG['NUM_CLASSES']}
- Multi-task Learning: Severity + Lesion Detection + Anatomical Regions

TRAINING DETAILS:
- Total Epochs: {CONFIG['EPOCHS']}
- Batch Size: {CONFIG['BATCH_SIZE']}
- Learning Rate: {CONFIG['LEARNING_RATE']}
- Optimizer: AdamW with Cosine Annealing
- Mixed Precision: {CONFIG['USE_AMP']}

DATASET:
- Training Samples: {len(train_dataset)}
- Validation Samples: {len(val_dataset)}
- Data Augmentation: Advanced medical image augmentation

PERFORMANCE METRICS:
- Final Accuracy: {evaluation_results['accuracy']:.4f}
- Weighted F1-Score: {evaluation_results['f1_weighted']:.4f}
- Macro F1-Score: {evaluation_results['f1_macro']:.4f}

FILES GENERATED:
- Model Weights: {final_model_path}
- Training Logs: {log_filename}
- Evaluation Results: {eval_path}
- Grad-CAM Visualizations: {vis_dir}

NEXT STEPS:
1. Download the final model file: {final_model_path}
2. Use this model in your Streamlit application
3. Review the Grad-CAM visualizations to understand model focus areas
4. Check training logs for detailed epoch-by-epoch performance

{'='*60}
Training completed successfully!
"""

# The report is shown in the cell output and also written to the log.
print(summary_report)
logger.info(summary_report)

# Save summary report as a standalone text file next to the other logs.
report_path = f"{LOG_PATH}/training_summary_report.txt"
with open(report_path, 'w') as f:
    f.write(summary_report)

logger.info(f"Training summary saved: {report_path}")
logger.info("All files saved to Google Drive. Training session complete!")

# Final memory cleanup. gc.collect() is the last expression of the cell, so
# its integer return value is displayed as the cell's output.
torch.cuda.empty_cache()
gc.collect()

2025-08-19 06:55:33,310 - INFO - Final model saved for deployment: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/final_model_for_deployment.pth

DIABETIC RETINOPATHY DETECTION MODEL - TRAINING SUMMARY

MODEL CONFIGURATION:
- Architecture: efficientnet_b3
- Input Size: 512x512
- Number of Classes: 5
- Multi-task Learning: Severity + Lesion Detection + Anatomical Regions

TRAINING DETAILS:
- Total Epochs: 30
- Batch Size: 32
- Learning Rate: 0.0001
- Optimizer: AdamW with Cosine Annealing
- Mixed Precision: True

DATASET:
- Training Samples: 2937
- Validation Samples: 735
- Data Augmentation: Advanced medical image augmentation

PERFORMANCE METRICS:
- Final Accuracy: 0.6503
- Weighted F1-Score: 0.6633
- Macro F1-Score: 0.5346

FILES GENERATED:
- Model Weights: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/final_model_for_deployment.pth
- Training Logs: /content/drive/MyDrive/Diabetic_Retinopathy_Project/training_logs/training_log_20250819_063219.txt
- Evaluation

74593

In [None]:
# =============================================================================
# CELL 16: Generate Professional Visualizations for Report
# =============================================================================
import json
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Load the evaluation results that were saved to the log directory as JSON
eval_path = f"{LOG_PATH}/final_evaluation_v2.json"
with open(eval_path, 'r') as f:
    evaluation_results = json.load(f)

# --- Create a High-Quality Confusion Matrix Plot ---
# The JSON stores the matrix as nested lists; turn it back into an array
cm = np.array(evaluation_results['confusion_matrix'])
# Display names, applied to matrix rows/columns in index order
class_names = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative']

# Use an explicit figure/axes pair instead of the implicit pyplot "current
# figure" so the figure can be saved and closed deterministically.
cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=cm_ax)

cm_ax.set_title('Confusion Matrix for Diabetic Retinopathy Classification', fontsize=16)
cm_ax.set_ylabel('True Label', fontsize=12)
cm_ax.set_xlabel('Predicted Label', fontsize=12)

# Save the plot to your Google Drive
cm_plot_path = f"{LOG_PATH}/confusion_matrix.png"
cm_fig.savefig(cm_plot_path, dpi=300, bbox_inches='tight')
plt.show()
# Close explicitly: with a non-interactive backend plt.show() does not release
# the figure, so it would otherwise stay in memory for the rest of the session.
plt.close(cm_fig)

print(f"Confusion matrix plot saved to: {cm_plot_path}")

# --- (Optional) Create a Plot for Per-Class F1 Scores ---
report = evaluation_results['classification_report']
# Keep only the per-class entries; aggregate rows such as 'accuracy' or
# 'macro avg' are filtered out because their keys are not class names.
class_f1_scores = {cls: data['f1-score'] for cls, data in report.items() if cls in class_names}

# Fallback: if the report was generated without display names, its per-class
# keys are the stringified label indices ('0', '1', ...). Map them to
# class_names by index, the same ordering used to label the confusion matrix.
if not class_f1_scores:
    class_f1_scores = {
        name: report[str(idx)]['f1-score']
        for idx, name in enumerate(class_names)
        if str(idx) in report
    }

if not class_f1_scores:
    # Make the otherwise silent "blank chart" case visible to the reader
    print("Warning: no per-class entries found in the classification report; F1 plot will be empty.")

f1_fig, f1_ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=list(class_f1_scores.keys()), y=list(class_f1_scores.values()), ax=f1_ax)
f1_ax.set_title('Per-Class F1-Scores', fontsize=16)
f1_ax.set_ylabel('F1-Score', fontsize=12)
f1_ax.set_xlabel('Severity Class', fontsize=12)
f1_ax.set_ylim(0, 1)  # F1 is bounded to [0, 1]; a fixed axis keeps plots comparable

# Save the plot
f1_plot_path = f"{LOG_PATH}/f1_scores.png"
f1_fig.savefig(f1_plot_path, dpi=300, bbox_inches='tight')
plt.show()
# Release the figure; it is not freed automatically on a non-interactive backend
plt.close(f1_fig)

print(f"F1-scores plot saved to: {f1_plot_path}")

Confusion matrix plot saved to: /content/drive/MyDrive/Diabetic_Retinopathy_Project/training_logs/confusion_matrix.png
F1-scores plot saved to: /content/drive/MyDrive/Diabetic_Retinopathy_Project/training_logs/f1_scores.png


In [None]:
# =============================================================================
# CELL 17: Natural Language Diagnosis Generation
# =============================================================================
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Ensure the model is on the correct device and in evaluation mode.
# NOTE(review): `model` is not created in this cell; it must already exist from
# an earlier cell, so this cell fails on a fresh kernel unless that cell ran first.
model.to(device)
# presumably an nn.Module, in which case eval() switches layers such as
# dropout / batch-norm to their inference behaviour; confirm against the model class
model.eval()

# Preprocessing applied to new images before they are fed to the model
def get_val_transforms():
    """Build the preprocessing pipeline used for single-image inference.

    Returns:
        An albumentations ``Compose`` that resizes the image to a square of
        side ``CONFIG['IMG_SIZE']``, normalises each channel with the fixed
        mean/std values below (these match the commonly used ImageNet channel
        statistics), and converts the result with ``ToTensorV2``.
    """
    side = CONFIG['IMG_SIZE']
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Built once at cell execution and reused by generate_diagnosis
val_transforms = get_val_transforms()

def generate_diagnosis(image_path, confidence_threshold=0.5):
    """
    Takes an image path, runs it through the model, and generates a
    human-readable diagnostic report.

    Relies on the notebook-level objects `model`, `device`, `val_transforms`,
    `CONFIG` and `medical_kb` being defined before it is called.

    Args:
        image_path: Path to the image file to analyse (converted with str()
            before being handed to OpenCV).
        confidence_threshold: Sigmoid probability above which a lesion type or
            anatomical region is reported as present.

    Returns:
        A tuple ``(diagnostic_report, image)`` where ``diagnostic_report`` is
        whatever ``medical_kb.generate_description_template`` returns and
        ``image`` is the loaded image in RGB channel order, before resizing
        and normalisation.

    Raises:
        OSError: If the file cannot be read or decoded as an image.
    """
    # 1. Load and Preprocess the Image
    # cv2.imread reports failure by returning None rather than raising; check
    # here so the caller gets a clear message instead of an opaque cvtColor error.
    image = cv2.imread(str(image_path))
    if image is None:
        raise OSError(f"Could not read image (missing file or unsupported format): {image_path}")
    # OpenCV loads images as BGR; convert to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Apply the same transformations as the validation set
    transformed = val_transforms(image=image)
    # Add a leading batch axis of size 1 before moving to the model's device
    image_tensor = transformed['image'].unsqueeze(0).to(device)

    # 2. Get Model Predictions
    with torch.no_grad():
        outputs = model(image_tensor)

    # 3. Post-process the Outputs

    # --- Severity Prediction ---
    severity_probs = torch.softmax(outputs['severity'], dim=1)
    confidence, predicted_class = torch.max(severity_probs, dim=1)
    severity_level = predicted_class.item()
    confidence = confidence.item()

    # --- Lesion Predictions ---
    # Flatten with reshape(-1) instead of squeeze(): for a head with a single
    # output, squeeze() would also drop that axis and yield a 0-d array that
    # zip() cannot iterate.
    lesion_probs = torch.sigmoid(outputs['lesions']).reshape(-1).cpu().numpy()
    lesion_findings = {
        # bool(...) stores a plain Python bool instead of numpy.bool_
        name: bool(prob > confidence_threshold)
        for name, prob in zip(CONFIG['LESION_TYPES'], lesion_probs)
    }

    # --- Region Predictions ---
    region_probs = torch.sigmoid(outputs['regions']).reshape(-1).cpu().numpy()
    region_findings = {
        name: bool(prob > confidence_threshold)
        for name, prob in zip(CONFIG['ANATOMICAL_REGIONS'], region_probs)
    }

    # 4. Generate the Natural Language Report
    # We use the medical knowledge base from CELL 5!
    diagnostic_report = medical_kb.generate_description_template(
        severity_level,
        lesion_findings,
        region_findings,
        confidence
    )

    return diagnostic_report, image

# =============================================================================
# --- EXAMPLE USAGE ---
# =============================================================================

# Demo: run the diagnosis pipeline on one image drawn at random from the
# validation set (any other image path can be substituted).
try:
    # 'val_paths' was created in CELL 10. When it is missing, fall back to a
    # fixed sample image; edit this path by hand if needed.
    if 'val_paths' not in locals():
        test_image_path = "/content/drive/MyDrive/Diabetic_Retinopathy_Project/diabetic-retinopathy-224x224-gaussian-filtered/Moderate/0024cdab0c1e.png"
    else:
        test_image_path = np.random.choice(val_paths)

    print(f"Generating diagnosis for image: {test_image_path}\n")

    # Run the model and build the textual report
    report, original_image = generate_diagnosis(test_image_path)

    # Show the input image without axes
    plt.imshow(original_image)
    plt.title("Input Image")
    plt.axis('off')
    plt.show()

    # Print the report framed by 80-character rules
    rule = "=" * 80
    print(rule, "                DIAGNOSTIC REPORT", rule, report, rule, sep="\n")

except Exception as e:
    # Best-effort demo: report the problem instead of aborting the notebook
    print(f"An error occurred during example usage: {e}")
    print("Please ensure 'val_paths' is available from CELL 10 or provide a manual image path.")

Generating diagnosis for image: /content/drive/MyDrive/Diabetic_Retinopathy_Project/diabetic-retinopathy-224x224-gaussian-filtered/Mild/274f5029189b.png

                DIAGNOSTIC REPORT
Mild Non-Proliferative Diabetic Retinopathy (confidence: 86.8%). Detected lesions: soft_exudates, neovascularization. Affected regions: macular. Ophthalmology referral within 12 months. Optimize diabetic control.


In [None]:
# =============================================================================
# CELL 18: Post-Hoc QWK Evaluation for V1 Model
# =============================================================================
from sklearn.metrics import cohen_kappa_score, accuracy_score, f1_score, confusion_matrix, classification_report
from tqdm.auto import tqdm

def evaluate_model_with_qwk(model, data_loader):
    """Score the severity predictions of ``model`` over ``data_loader``.

    The model is put in eval mode and run without gradient tracking. The
    forward pass is wrapped in autocast, enabled according to
    ``CONFIG['USE_AMP']``. Uses the notebook-level ``device`` and ``CONFIG``.

    Args:
        model: Callable whose output is indexable by ``'severity'``; the
            predicted class is the argmax of that entry along dim 1.
        data_loader: Iterable of batches indexable by ``'image'`` and
            ``'severity'``.

    Returns:
        dict with keys ``'qwk'`` (quadratic-weighted Cohen's kappa),
        ``'accuracy'`` and ``'f1_weighted'``.
    """
    model.eval()
    predicted_labels, true_labels = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating Model"):
            images = batch['image'].to(device)
            targets = batch['severity'].to(device)

            # Only the forward pass runs under autocast
            with autocast(enabled=CONFIG['USE_AMP']):
                outputs = model(images)

            batch_preds = outputs['severity'].argmax(dim=1)
            predicted_labels += batch_preds.cpu().tolist()
            true_labels += targets.cpu().tolist()

    return {
        'qwk': cohen_kappa_score(true_labels, predicted_labels, weights='quadratic'),
        'accuracy': accuracy_score(true_labels, predicted_labels),
        'f1_weighted': f1_score(true_labels, predicted_labels, average='weighted', zero_division=0),
    }

# =============================================================================
# --- Main V1 Evaluation Execution ---
# =============================================================================

# 1. Define the path to your V1 model
v1_model_path = f"{MODEL_SAVE_PATH}/best_model.pth"
logger.info(f"Loading V1 model from: {v1_model_path}")

try:
    # 2. Load the V1 checkpoint
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU runtime can still be loaded on a CPU-only one.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoint files you created yourself.
    checkpoint_v1 = torch.load(v1_model_path, map_location=device, weights_only=False)

    # 3. Create a fresh model instance and load the V1 weights into it
    # We must use the V1 config for this model!
    # Falls back to the current CONFIG when the checkpoint carries no 'config' entry.
    v1_config = checkpoint_v1.get('config', CONFIG)
    model_v1 = MultiTaskDRModel(
        model_name=v1_config['MODEL_NAME'],
        num_classes=v1_config['NUM_CLASSES']
    ).to(device)

    model_v1.load_state_dict(checkpoint_v1['model_state_dict'])
    logger.info("Successfully loaded V1 model weights.")

    # 4. Create a temporary DataLoader with the V1 image size
    # This is crucial for a fair and correct evaluation
    v1_val_transforms = A.Compose([
        A.Resize(v1_config['IMG_SIZE'], v1_config['IMG_SIZE']),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    v1_val_dataset = OptimizedDRDataset(val_paths, val_labels, medical_kb, transforms=v1_val_transforms)
    # shuffle=False keeps the evaluation order fixed
    v1_val_loader = DataLoader(v1_val_dataset, batch_size=32, shuffle=False, num_workers=2) # Use V1 batch size

    # 5. Run the evaluation
    logger.info("Running evaluation on V1 model to calculate QWK...")
    v1_metrics = evaluate_model_with_qwk(model_v1, v1_val_loader)

    # 6. Print the results
    print("\n" + "="*50)
    print("      V1 Model Final Performance Review")
    print("="*50)
    print(f"  Accuracy:         {v1_metrics['accuracy']:.4f}")
    print(f"  F1-Score (W):     {v1_metrics['f1_weighted']:.4f}")
    print(f"  Quadratic Kappa:  {v1_metrics['qwk']:.4f}  <--- V1 QWK Score")
    print("="*50)

except FileNotFoundError:
    logger.error(f"Could not find the V1 model at {v1_model_path}. Please ensure it is saved.")
except Exception as e:
    # exc_info=True keeps the traceback in the log; the message alone rarely
    # identifies which step (load, model build, data, evaluation) failed.
    logger.error(f"An error occurred during V1 evaluation: {e}", exc_info=True)

2025-08-19 06:56:57,381 - INFO - Loading V1 model from: /content/drive/MyDrive/Diabetic_Retinopathy_Project/models/best_model.pth
2025-08-19 06:57:15,700 - INFO - Loading pretrained weights from Hugging Face hub (timm/efficientnet_b3.ra2_in1k)
2025-08-19 06:57:15,980 - INFO - [timm/efficientnet_b3.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-08-19 06:57:16,307 - INFO - Backbone feature dimension: 1536
2025-08-19 06:57:16,399 - INFO - Successfully loaded V1 model weights.
2025-08-19 06:57:16,408 - INFO - Dataset initialized with 735 images. Images will be loaded from disk.
2025-08-19 06:57:16,409 - INFO - Running evaluation on V1 model to calculate QWK...


Evaluating Model:   0%|          | 0/23 [00:00<?, ?it/s]


      V1 Model Final Performance Review
  Accuracy:         0.7306
  F1-Score (W):     0.7394
  Quadratic Kappa:  0.7608  <--- V1 QWK Score
