# In [None]:  (Jupyter cell marker)
import os
import sys
import math
import json
import time
import random
import argparse
import warnings
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Union

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import timm
except ImportError:
    timm = None

try:
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
    from sklearn.calibration import calibration_curve
except Exception:
    train_test_split = None
    accuracy_score = None
    f1_score = None
    confusion_matrix = None
    classification_report = None
    calibration_curve = None

# Grad-CAM (on Kaggle, install with: pip install grad-cam; provides the pytorch_grad_cam module)
try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.image import show_cam_on_image
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
except Exception:
    GradCAM = None
    show_cam_on_image = None
    ClassifierOutputTarget = None


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Alzheimer MRI dataset: normalized, lower-case class-folder names -> integer label,
# listed in order of increasing cognitive decline.
ALZHEIMER_LABEL_MAP = {
    class_name: label
    for label, class_name in enumerate(
        (
            'nondemented',        # Cognitively normal
            'verymilddemented',   # Very mild cognitive decline (CDR 0.5)
            'milddemented',       # Mild dementia (CDR 1)
            'moderatedemented',   # Moderate dementia (CDR 2)
        )
    )
}

# Per-label clinical context used in reports: display name, Clinical Dementia
# Rating (CDR) and a short description.
ALZHEIMER_CLINICAL_INFO = {
    0: dict(name='Cognitively Normal', cdr=0, description='No cognitive impairment detected'),
    1: dict(name='Very Mild Dementia', cdr=0.5, description='Questionable dementia, very mild cognitive decline'),
    2: dict(name='Mild Dementia', cdr=1, description='Mild dementia with clear functional impairment'),
    3: dict(name='Moderate Dementia', cdr=2, description='Moderate dementia requiring substantial care'),
}


# ------------------------------
# Utilities & Setup
# ------------------------------

def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``.

    Also forces cuDNN onto deterministic kernels and disables its benchmark
    autotuner, so repeated runs select the same convolution algorithms.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def worker_init_fn(worker_id: int) -> None:
    """Seed NumPy and Python ``random`` inside a DataLoader worker.

    Inside a worker process, ``torch.initial_seed()`` is the DataLoader's
    per-epoch base seed plus ``worker_id``, and the base seed is drawn from the
    (seeded) global torch generator. Deriving the NumPy / ``random`` seeds from
    it keeps runs reproducible while giving every worker, and every epoch, its
    own stream. A fixed ``SEED + worker_id`` would replay identical NumPy
    streams each epoch whenever workers are re-spawned.

    Args:
        worker_id: Index of the worker process. It is kept for the
            ``worker_init_fn`` signature and is already folded into
            ``torch.initial_seed()`` by the DataLoader.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in uint32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def exists(path: str) -> bool:
    """Return True when ``path`` is not None and names an existing filesystem entry."""
    if path is None:
        return False
    return os.path.exists(path)


def ensure_dir(path: str) -> None:
    """Create directory ``path`` (and any missing parents) unless it already exists."""
    already_present = path is not None and os.path.exists(path)
    if already_present:
        return
    os.makedirs(path, exist_ok=True)


def device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


# ------------------------------
# Enhanced Dataset Implementation
# ------------------------------

@dataclass
class ImageRecord:
    """A single labelled image on disk.

    ``metadata`` is an optional free-form dict; ``load_alzheimer_records``
    stores ``{'class_name': <class key>}`` there.
    """
    image_path: str                  # filesystem path of the image file
    label: int                       # integer class index
    metadata: Optional[Dict] = None  # optional extra information, None when absent


class AlzheimerDataset(Dataset):
    """
    Map-style dataset of MRI images described by ``ImageRecord`` entries.

    ``__getitem__`` does the following:

    1. Loads the image as RGB and resizes it to ``image_size`` x ``image_size``
       with Lanczos resampling.
    2. Optionally applies light augmentation (``augment=True``).
    3. Normalises each channel with ``mean``/``std`` (ImageNet statistics by
       default).
    4. Returns ``(x, y)``, where ``x`` is a float32 ``(3, image_size, image_size)``
       tensor and ``y`` is a long scalar label.

    An image that cannot be opened is replaced by a uniform grey image (a
    warning is printed). It keeps its record's label, so one corrupt file does
    not abort an epoch.
    """
    def __init__(
        self,
        records: List[ImageRecord],
        image_size: int = 224,
        augment: bool = False,
        mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
        std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
    ) -> None:
        self.records = records
        self.image_size = image_size
        self.augment = augment
        self.mean = mean
        self.std = std

    def __len__(self) -> int:
        return len(self.records)

    def _apply_medical_augmentation(self, img: Image.Image) -> Image.Image:
        """Apply mild augmentations (flip, small rotation, brightness/contrast) using Python's ``random``."""
        # Horizontal flip (brain is roughly left/right symmetric)
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

        # Small rotation (±5 degrees) to account for patient positioning;
        # uncovered corners are filled with black (fillcolor=0).
        if random.random() < 0.3:
            angle = random.uniform(-5, 5)
            img = img.rotate(angle, fillcolor=0)

        # Slight brightness/contrast adjustment (acquisition variations)
        if random.random() < 0.3:
            from PIL import ImageEnhance
            enhancer = ImageEnhance.Brightness(img)
            img = enhancer.enhance(random.uniform(0.9, 1.1))

            enhancer = ImageEnhance.Contrast(img)
            img = enhancer.enhance(random.uniform(0.9, 1.1))

        return img

    def __getitem__(self, idx: int):
        rec = self.records[idx]

        try:
            # The context manager closes the underlying file even for formats
            # Pillow keeps open after loading (e.g. multi-frame TIFF);
            # convert() returns an independent, fully loaded image.
            with Image.open(rec.image_path) as src:
                img = src.convert('RGB')
        except Exception as e:
            # Fallback for corrupted / unreadable images (label is kept)
            print(f"Warning: Could not load image {rec.image_path}: {e}")
            img = Image.new('RGB', (self.image_size, self.image_size), color=(128, 128, 128))

        # Resize with high-quality resampling
        img = img.resize((self.image_size, self.image_size), Image.LANCZOS)

        if self.augment:
            img = self._apply_medical_augmentation(img)

        # Scale to [0, 1], normalise per channel, then HWC -> CHW
        img_np = np.array(img).astype(np.float32) / 255.0
        img_np = (img_np - np.array(self.mean, dtype=np.float32)) / np.array(self.std, dtype=np.float32)
        img_np = np.transpose(img_np, (2, 0, 1))

        x = torch.from_numpy(img_np).float()  # Ensure float32
        y = torch.tensor(rec.label, dtype=torch.long)

        return x, y


def load_alzheimer_records(dataset_root: str) -> List[ImageRecord]:
    """
    Collect labelled image records for the Alzheimer MRI dataset.

    A directory anywhere under ``dataset_root`` is treated as a class folder
    when its basename matches a ``ALZHEIMER_LABEL_MAP`` key. Matching is
    case-insensitive and ignores ``_``, ``-`` and spaces, so
    ``Very_Mild_Demented`` matches ``verymilddemented``. Image files directly
    inside a class folder become records for that class.

    Directories and files are visited in sorted order. ``os.walk`` and
    ``os.listdir`` order is filesystem-dependent, so sorting keeps the record
    list identical across machines. Any seeded split built from it is then
    reproducible too.

    Args:
        dataset_root: Root directory to search recursively.

    Returns:
        List of ``ImageRecord`` with ``metadata={'class_name': <class key>}``.

    Raises:
        RuntimeError: If no images are found for any class.
    """
    label_map = ALZHEIMER_LABEL_MAP
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')

    def _normalize(name: str) -> str:
        """Lower-case ``name`` and strip ``_``, ``-`` and space separators."""
        return name.lower().replace('_', '').replace('-', '').replace(' ', '')

    def _find_class_dirs(root: str, class_names_lower: List[str]) -> Dict[str, List[str]]:
        """Map each class key to the directories (in sorted walk order) whose basename matches it."""
        normalized_to_class = {_normalize(cn): cn for cn in class_names_lower}
        mapping: Dict[str, List[str]] = {cn: [] for cn in class_names_lower}

        for dirpath, dirnames, _filenames in os.walk(root):
            dirnames.sort()  # in-place sort makes os.walk descend in a deterministic order
            match = normalized_to_class.get(_normalize(os.path.basename(dirpath)))
            if match is not None:
                mapping[match].append(dirpath)

        return mapping

    def _collect_images_from_dirs(dirs: List[str]) -> List[str]:
        """Return paths of image files directly inside each of ``dirs``, sorted per directory."""
        images: List[str] = []
        for d in dirs:
            if not os.path.exists(d):
                continue
            for fname in sorted(os.listdir(d)):
                fpath = os.path.join(d, fname)
                if os.path.isfile(fpath) and fname.lower().endswith(valid_extensions):
                    images.append(fpath)
        return images

    class_names_lower = list(label_map.keys())
    dir_map = _find_class_dirs(dataset_root, class_names_lower)

    records: List[ImageRecord] = []
    class_counts: Dict[str, int] = {}

    for class_name_lower, label_idx in label_map.items():
        dirs = dir_map.get(class_name_lower, [])
        if not dirs:
            print(f"Warning: No directories found for class '{class_name_lower}'")
            continue

        imgs = _collect_images_from_dirs(dirs)
        class_counts[class_name_lower] = len(imgs)
        records.extend(
            ImageRecord(
                image_path=img_path,
                label=label_idx,
                metadata={'class_name': class_name_lower},
            )
            for img_path in imgs
        )

    if not records:
        raise RuntimeError(
            f"No images found under {dataset_root}. "
            f"Expected class folders: {list(label_map.keys())}"
        )

    print("Dataset loaded successfully:")
    for class_name, count in class_counts.items():
        clinical_info = ALZHEIMER_CLINICAL_INFO[label_map[class_name]]
        print(f"  {class_name}: {count} images (CDR {clinical_info['cdr']} - {clinical_info['name']})")

    return records


# ------------------------------
# Advanced Model Architecture
# ------------------------------

class MedicalDropout(nn.Module):
    """
    Dropout that drops whole channels on 4-D inputs.

    In training mode the layer works as follows:

    - With ``medical_structured=True`` and a 4-D input (e.g. conv feature
      maps), it uses ``F.dropout2d``, which zeroes entire channels.
    - Every other input gets ordinary element-wise ``F.dropout``.

    In eval mode the layer is the identity.
    """
    def __init__(self, p: float = 0.2, medical_structured: bool = True):
        super().__init__()
        self.p = p
        self.medical_structured = medical_structured

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            channel_wise = self.medical_structured and x.dim() == 4
            drop = F.dropout2d if channel_wise else F.dropout
            return drop(x, p=self.p, training=True)
        return x


def build_simple_model(
    model_name: str = 'resnet18',
    num_classes: int = 4,
    dropout: float = 0.2,
    pretrained: bool = True
) -> nn.Module:
    """
    Create a timm classifier and insert dropout before its final linear layer.

    The attributes ``fc``, ``classifier`` and ``head`` are checked in that
    order. The first one that is a plain ``nn.Linear`` is replaced by
    ``Dropout(dropout) -> Linear(in_features, num_classes)``. When none of
    them is an ``nn.Linear``, the timm model is returned as created.

    Raises:
        ImportError: If ``timm`` is not installed.
    """
    if timm is None:
        raise ImportError("timm is not installed. Please install with: pip install timm")

    net = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    for head_attr in ('fc', 'classifier', 'head'):
        final_layer = getattr(net, head_attr, None)
        if not isinstance(final_layer, nn.Linear):
            continue
        setattr(
            net,
            head_attr,
            nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(final_layer.in_features, num_classes),
            ),
        )
        break

    return net


class _PoolIfSpatial(nn.Module):
    """Global-average-pool 4-D ``(N, C, H, W)`` input to ``(N, C, 1, 1)``; pass other inputs through unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 4:
            return F.adaptive_avg_pool2d(x, 1)
        return x


def _medical_head(in_features: int, num_classes: int, dropout: float) -> nn.Sequential:
    """Build the LayerNorm + dropout MLP head used by ``build_medical_model``."""
    hidden = in_features // 2
    return nn.Sequential(
        # Indices 0/1 (pool, flatten) are parameter-free and kept in place so the
        # state_dict keys of the layers below are unchanged.
        _PoolIfSpatial(),
        nn.Flatten(),
        nn.LayerNorm(in_features),  # LayerNorm handles small batches better than BatchNorm
        MedicalDropout(p=dropout),
        nn.Linear(in_features, hidden),
        nn.ReLU(inplace=True),
        nn.LayerNorm(hidden),
        MedicalDropout(p=dropout / 2),
        nn.Linear(hidden, num_classes),
    )


def build_medical_model(
    model_name: str = 'tf_efficientnet_b0_ns',
    num_classes: int = 4,
    dropout: float = 0.25,
    pretrained: bool = True
) -> nn.Module:
    """
    Create a timm backbone and replace its classifier with a regularised MLP head.

    The head is ``[pool if 4-D] -> Flatten -> LayerNorm -> MedicalDropout ->
    Linear(C, C//2) -> ReLU -> LayerNorm -> MedicalDropout(p/2) ->
    Linear(C//2, num_classes)``.

    timm calls ``classifier`` on globally pooled ``(N, C)`` features. An
    unconditional ``nn.AdaptiveAvgPool2d`` rejects 2-D input, so pooling is
    applied only when the head actually receives ``(N, C, H, W)`` maps.

    Args:
        model_name: timm model identifier.
        num_classes: Number of output classes.
        dropout: Dropout probability of the first dropout layer (the second uses half).
        pretrained: Load pretrained timm weights.

    Returns:
        The timm model with its ``classifier`` (if it is an ``nn.Linear``) or
        ``head`` replaced.

    Raises:
        ImportError: If ``timm`` is not installed.
        ValueError: If the model exposes neither a linear ``classifier`` nor a
            ``head`` whose input width can be determined.
    """
    if timm is None:
        raise ImportError("timm is not installed. Please install with: pip install timm")

    model = timm.create_model(model_name, pretrained=pretrained)

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Linear):
        model.classifier = _medical_head(model.classifier.in_features, num_classes, dropout)
    elif hasattr(model, 'head'):
        # Some timm heads are composite modules without ``in_features``; fall back
        # to the backbone's reported feature width.
        in_features = getattr(model.head, 'in_features', None)
        if in_features is None:
            in_features = getattr(model, 'num_features', None)
        if in_features is None:
            raise ValueError(f"Unsupported model architecture: {model_name}")
        model.head = _medical_head(in_features, num_classes, dropout)
    else:
        raise ValueError(f"Unsupported model architecture: {model_name}")

    return model


def enable_mc_dropout(model: nn.Module) -> None:
    """Put only the dropout layers of ``model`` back into training mode.

    This is meant to be called after ``model.eval()``. ``nn.Dropout``,
    ``nn.Dropout2d`` and ``MedicalDropout`` layers then keep sampling masks at
    inference (Monte Carlo dropout), while every other layer stays in eval mode.
    """
    dropout_types = (nn.Dropout, nn.Dropout2d, MedicalDropout)
    for layer in filter(lambda mod: isinstance(mod, dropout_types), model.modules()):
        layer.train()


# ------------------------------
# Advanced Uncertainty Quantification
# ------------------------------

class AdvancedTemperatureScaling(nn.Module):
    """
    Learnable temperature scaling of logits, per class or shared.

    With ``class_specific=True``, logit column ``j`` is divided by
    ``temperatures[j]`` (a learnable vector of length ``num_classes``,
    initialised to 1). Otherwise every logit is divided by the single learnable
    scalar ``temperature``.
    """
    def __init__(self, num_classes: int, class_specific: bool = True):
        super().__init__()
        self.num_classes = num_classes
        self.class_specific = class_specific

        # The parameter name depends on the mode; both names are part of the state_dict.
        if not class_specific:
            self.temperature = nn.Parameter(torch.ones(1))
        else:
            self.temperatures = nn.Parameter(torch.ones(num_classes))

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.class_specific:
            return logits / self.temperature
        # Repeat the per-class temperatures for every row along dim 0.
        per_row = self.temperatures.view(1, -1).expand(logits.size(0), -1)
        return logits / per_row


def fit_advanced_temperature(
    model: nn.Module,
    loader: DataLoader,
    device_: torch.device,
    num_classes: int,
    max_iter: int = 100,
    lr: float = 0.01
) -> AdvancedTemperatureScaling:
    """
    Fit class-specific temperatures on held-out data by minimising cross-entropy.

    The model runs once over ``loader`` in eval mode without gradients, and its
    logits are cached. Only the temperature parameters are then optimised with
    L-BFGS.

    Args:
        model: Trained classifier; its output is used as logits for
            ``F.cross_entropy`` against the loader's targets.
        loader: Held-out loader yielding ``(images, targets)``.
        device_: Device for the forward passes and the optimisation.
        num_classes: Length of the temperature vector.
        max_iter: Maximum number of L-BFGS iterations.
        lr: L-BFGS learning rate.

    Returns:
        The fitted class-specific ``AdvancedTemperatureScaling`` module.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    temp_scaler = AdvancedTemperatureScaling(num_classes, class_specific=True).to(device_)

    # Cache logits once; the model itself is not optimised.
    all_logits = []
    all_labels = []
    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device_)
            targets = targets.to(device_)
            all_logits.append(model(images))
            all_labels.append(targets)

    if not all_logits:
        raise ValueError("fit_advanced_temperature: loader yielded no batches")

    logits_tensor = torch.cat(all_logits, dim=0)
    labels_tensor = torch.cat(all_labels, dim=0)

    # Honour max_iter; it was previously ignored in favour of a fixed 50.
    optimizer = torch.optim.LBFGS(temp_scaler.parameters(), lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        loss = F.cross_entropy(temp_scaler(logits_tensor), labels_tensor)
        loss.backward()
        return loss

    optimizer.step(closure)
    return temp_scaler


# ------------------------------
# RAPS Conformal Prediction (Enhanced)
# ------------------------------

def compute_raps_scores_enhanced(
    probs: np.ndarray,
    labels: np.ndarray,
    lambda_reg: float,
    k_reg: int,
    class_weights: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    RAPS-style non-conformity score of the true label for every row.

    Each probability row is computed as follows:

    1. Multiply the row by ``class_weights`` (all ones when omitted) and
       renormalise it.
    2. Rank the classes by descending weighted probability.
    3. Score the row as the cumulative weighted probability up to and including
       the true class, plus ``lambda_reg * max(rank - k_reg, 0)``.

    ``rank`` is 0-based here, so a label at 1-based position ``<= k_reg + 1``
    gets no penalty. The RAPS paper writes the penalty as ``(o - k_reg)^+``
    with a 1-based ``o``. NOTE(review): the code that builds prediction sets
    from these scores must use the same rank convention and class weighting.

    Args:
        probs: ``(n, num_classes)`` class probabilities.
        labels: Length-``n`` integer true labels.
        lambda_reg: Rank-penalty weight.
        k_reg: Number of extra top ranks exempt from the penalty (see above).
        class_weights: Optional per-class multipliers applied before ranking.

    Returns:
        float32 array of ``n`` scores.
    """
    num_rows, num_classes = probs.shape
    out = np.zeros(num_rows, dtype=np.float32)
    weights = np.ones(num_classes) if class_weights is None else class_weights
    positions = np.arange(num_classes)

    for row_idx, row in enumerate(probs):
        label = labels[row_idx]

        reweighted = row * weights
        reweighted = reweighted / reweighted.sum()

        desc_order = np.argsort(reweighted)[::-1]
        rank_of = np.empty_like(desc_order)
        rank_of[desc_order] = positions  # rank_of[c] = 0-based position of class c
        label_rank = rank_of[label]

        cumulative = np.cumsum(reweighted[desc_order])
        out[row_idx] = cumulative[label_rank] + lambda_reg * max(label_rank - k_reg, 0)

    return out


def adaptive_raps_quantile(scores: np.ndarray, alpha: float, n_bootstrap: int = 1000) -> Tuple[float, float]:
    """
    Split-conformal score threshold with a bootstrap confidence interval.

    ``qhat`` is the ``k``-th smallest calibration score, with
    ``k = ceil((n + 1)(1 - alpha))`` and a lower bound of 1. When ``k > n``,
    the calibration set is too small for the requested ``alpha`` (this includes
    an empty set). The only finite-sample-valid threshold is then ``+inf``,
    meaning every class is included. ``inf`` is returned in that case instead
    of clamping to the largest score, which would silently under-cover.

    The interval is the 2.5th/97.5th percentile of ``qhat`` recomputed on
    ``n_bootstrap`` resamples (with replacement) drawn from NumPy's global RNG.

    Args:
        scores: 1-D array of calibration non-conformity scores.
        alpha: Target miscoverage level.
        n_bootstrap: Number of bootstrap resamples; ``<= 0`` skips the bootstrap.

    Returns:
        ``(qhat, (ci_lower, ci_upper))``. The interval is ``(qhat, qhat)`` when
        the bootstrap is skipped and ``(inf, inf)`` when ``qhat`` is infinite.
    """
    scores = np.asarray(scores)
    n = len(scores)
    k = max(math.ceil((n + 1) * (1 - alpha)), 1)

    if k > n:
        return math.inf, (math.inf, math.inf)

    qhat = float(np.partition(scores, k - 1)[k - 1])

    if n_bootstrap <= 0:
        return qhat, (qhat, qhat)

    bootstrap_qhats = []
    for _ in range(n_bootstrap):
        boot_scores = scores[np.random.choice(n, size=n, replace=True)]
        bootstrap_qhats.append(np.partition(boot_scores, k - 1)[k - 1])

    ci_lower = np.percentile(bootstrap_qhats, 2.5)
    ci_upper = np.percentile(bootstrap_qhats, 97.5)

    return qhat, (float(ci_lower), float(ci_upper))


# ------------------------------
# Batch Ensemble (Optimized)
# ------------------------------

class OptimizedBatchEnsemble(nn.Module):
    """
    Batch Ensemble head with rank-1 perturbations over a shared backbone.

    Every member uses the shared weight ``shared_weight`` (``num_classes x
    feature_dim``). Member ``i`` computes
    ``(linear(features * r_vectors[i], shared_weight) + ensemble_bias[i]) * s_vectors[i]``.

    NOTE(review): this assumes ``backbone(x)`` returns features of width
    ``feature_dim``, either pooled ``(N, F)`` or spatial ``(N, F, H, W)``,
    which is average-pooled here. A backbone that still ends in its own
    classifier would return logits instead; confirm how callers build it.
    """
    def __init__(self, backbone: nn.Module, num_classes: int, ensemble_size: int = 4):
        super().__init__()
        self.backbone = backbone
        self.ensemble_size = ensemble_size
        self.num_classes = num_classes
        self.feature_dim = self._infer_feature_dim(backbone)

        # Ensemble parameters. The randn values are overwritten by
        # _init_parameters (the draws still advance the global torch RNG).
        self.r_vectors = nn.Parameter(torch.randn(ensemble_size, self.feature_dim))
        self.s_vectors = nn.Parameter(torch.randn(ensemble_size, num_classes))
        self.shared_weight = nn.Parameter(torch.randn(num_classes, self.feature_dim))
        self.ensemble_bias = nn.Parameter(torch.randn(ensemble_size, num_classes))

        self._init_parameters()

    @staticmethod
    def _infer_feature_dim(backbone: nn.Module) -> int:
        """Return the feature width ``backbone`` is expected to output."""
        if hasattr(backbone, 'num_features'):
            return backbone.num_features
        if hasattr(backbone, 'classifier') and hasattr(backbone.classifier, 'in_features'):
            return backbone.classifier.in_features
        if hasattr(backbone, 'head') and hasattr(backbone.head, 'in_features'):
            return backbone.head.in_features

        # Fallback: probe with one dummy 224x224 RGB image on the backbone's
        # device/dtype, then restore the caller's train/eval mode.
        first_param = next(backbone.parameters())
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 224, 224, device=first_param.device, dtype=first_param.dtype)
                try:
                    features = backbone(dummy_input)
                    if features.dim() > 2:
                        features = F.adaptive_avg_pool2d(features, 1).flatten(1)
                    return features.shape[1]
                except Exception as e:
                    print(f"Warning: Could not determine feature dimension, using default 1280: {e}")
                    return 1280  # Default for EfficientNet-B0
        finally:
            backbone.train(was_training)

    def _init_parameters(self):
        """Initialise the perturbation vectors, shared weight and per-member biases."""
        nn.init.normal_(self.r_vectors, std=1.0)
        nn.init.normal_(self.s_vectors, std=1.0)
        nn.init.kaiming_uniform_(self.shared_weight, a=math.sqrt(5))
        nn.init.zeros_(self.ensemble_bias)

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and global-average-pool spatial outputs to ``(N, F)``."""
        features = self.backbone(x)
        if features.dim() > 2:
            features = F.adaptive_avg_pool2d(features, 1).flatten(1)
        return features

    def _member_logits(self, features: torch.Tensor, member: int) -> torch.Tensor:
        """Logits ``(N, num_classes)`` of ensemble member ``member`` for pooled ``features``."""
        perturbed = features * self.r_vectors[member].unsqueeze(0)
        logits = F.linear(perturbed, self.shared_weight) + self.ensemble_bias[member].unsqueeze(0)
        return logits * self.s_vectors[member].unsqueeze(0)

    def get_ensemble_predictions(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get predictions from all ensemble members separately.
        Returns: (ensemble_size, batch_size, num_classes)
        """
        features = self._features(x)
        return torch.stack(
            [self._member_logits(features, i) for i in range(self.ensemble_size)], dim=0
        )

    def forward(self, x: torch.Tensor, ensemble_member: Optional[int] = None) -> torch.Tensor:
        """Logits of one member, or the mean logits over all members when ``ensemble_member`` is None."""
        if ensemble_member is not None:
            return self._member_logits(self._features(x), ensemble_member)
        return self.get_ensemble_predictions(x).mean(dim=0)


# ------------------------------
# Training & Evaluation
# ------------------------------

def train_epoch_with_mixup(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device_: torch.device,
    mixup_alpha: float = 0.2
) -> Tuple[float, float]:
    """
    Run one training epoch, applying mixup to about half of the batches.

    When ``mixup_alpha > 0``, each batch is mixed with probability 0.5. A mixed
    batch blends its images with a shuffled copy of itself, using
    ``lam ~ Beta(mixup_alpha, mixup_alpha)``. Its loss is the ``lam``-weighted
    cross-entropy against both label sets.

    Returns:
        ``(mean_loss, accuracy)``.

        - ``mean_loss`` is averaged over all samples in ``loader.dataset``.
        - ``accuracy`` is measured only on batches that were not mixed, because
          argmax accuracy against a single label is ill-defined for mixed
          inputs. It is 0.0 if every batch was mixed.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0

    for images, targets in loader:
        images = images.to(device_)
        targets = targets.to(device_)

        # Decide once per batch. The same flag gates both the loss and the
        # accuracy bookkeeping (previously a second, independent random draw
        # decided which batches were counted).
        use_mixup = mixup_alpha > 0 and random.random() < 0.5

        optimizer.zero_grad(set_to_none=True)
        if use_mixup:
            lam = np.random.beta(mixup_alpha, mixup_alpha)
            index = torch.randperm(images.size(0)).to(device_)
            logits = model(lam * images + (1 - lam) * images[index])
            loss = lam * F.cross_entropy(logits, targets) + (1 - lam) * F.cross_entropy(logits, targets[index])
        else:
            logits = model(images)
            loss = F.cross_entropy(logits, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        if not use_mixup:
            preds = torch.argmax(logits, dim=1)
            total_correct += (preds == targets).sum().item()
            total += images.size(0)

    return total_loss / len(loader.dataset), total_correct / max(total, 1)


def comprehensive_evaluation(
    model: nn.Module,
    loader: DataLoader,
    device_: torch.device,
    class_names: List[str]
) -> Dict:
    """
    Evaluate ``model`` on ``loader``: loss, accuracy, macro/weighted F1 and a per-class report.

    The sklearn metrics are 0.0 (and the report is omitted) when sklearn is not
    available. ``class_names[i]`` is assumed to name integer label ``i``. The
    report is computed over all ``len(class_names)`` labels, so it still lines
    up when a class is absent from both the labels and the predictions of this
    split.

    Returns:
        Dict with ``loss``, ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``logits`` (``np.ndarray``), ``predictions``, ``labels`` and, when
        available, ``classification_report`` (``output_dict=True`` form).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_logits = []
    total_loss = 0.0

    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device_)
            targets = targets.to(device_)

            logits = model(images)
            loss = F.cross_entropy(logits, targets)

            total_loss += loss.item() * images.size(0)

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
            all_logits.append(logits.cpu().numpy())

    if not all_logits:
        raise ValueError("comprehensive_evaluation: loader yielded no batches")

    all_logits = np.concatenate(all_logits, axis=0)

    accuracy = accuracy_score(all_labels, all_preds) if accuracy_score is not None else 0.0
    f1_macro = f1_score(all_labels, all_preds, average='macro') if f1_score is not None else 0.0
    f1_weighted = f1_score(all_labels, all_preds, average='weighted') if f1_score is not None else 0.0

    results = {
        'loss': total_loss / len(loader.dataset),
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'f1_weighted': f1_weighted,
        'logits': all_logits,
        'predictions': all_preds,
        'labels': all_labels
    }

    if classification_report is not None:
        # Without an explicit label list, sklearn infers labels from the data
        # and raises when that count differs from len(target_names).
        results['classification_report'] = classification_report(
            all_labels,
            all_preds,
            labels=list(range(len(class_names))),
            target_names=class_names,
            output_dict=True,
            zero_division=0,
        )

    return results


# ------------------------------
# Advanced Uncertainty & Ensemble Methods
# ------------------------------

def batch_ensemble_predictions(
    model: Union[OptimizedBatchEnsemble, nn.Module],
    loader: DataLoader,
    device_: torch.device,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
    """
    Per-member probabilities and uncertainty metrics from a Batch Ensemble model.

    For an ``OptimizedBatchEnsemble``, every member's logits are collected
    over the whole loader and converted to probabilities. Any other model is
    delegated to ``monte_carlo_predictions`` with 15 dropout samples.

    Returns:
        ``(mean_probs, member_probs, uncertainty_metrics)``. ``member_probs``
        has shape ``(ensemble_size, num_data, num_classes)``, or
        ``(15, num_data, num_classes)`` on the MC-dropout path.

    Raises:
        ValueError: If ``loader`` yields no batches for an ensemble model.
    """
    if not isinstance(model, OptimizedBatchEnsemble):
        # Dispatch up front rather than after pulling the first batch mid-loop.
        return monte_carlo_predictions(model, loader, device_, num_samples=15)

    model.eval()
    all_ensemble_logits = []

    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device_)
            # (ensemble_size, batch_size, num_classes)
            all_ensemble_logits.append(model.get_ensemble_predictions(images).cpu().numpy())

    if not all_ensemble_logits:
        raise ValueError("batch_ensemble_predictions: loader yielded no batches")

    # Concatenate across batches: (ensemble_size, total_samples, num_classes)
    ensemble_logits = np.concatenate(all_ensemble_logits, axis=1)
    ensemble_probs = softmax_np(ensemble_logits, axis=2)
    mean_predictions = np.mean(ensemble_probs, axis=0)
    uncertainty_metrics = compute_uncertainty_metrics(ensemble_probs)

    return mean_predictions, ensemble_probs, uncertainty_metrics


def monte_carlo_predictions(
    model: nn.Module,
    loader: DataLoader,
    device_: torch.device,
    num_samples: int = 20
) -> Tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]:
    """
    Monte Carlo dropout: repeated stochastic forward passes with dropout active.

    The model is put in eval mode with only its dropout layers re-enabled. The
    loader is iterated once, and each batch is passed through the model
    ``num_samples`` times. Iterating once guarantees that all samples of a data
    point come from the same input, even if the loader shuffles or augments,
    and avoids reloading the data ``num_samples`` times. The model is returned
    to plain eval mode afterwards.

    NOTE(review): if the loader shuffles, the returned row order still follows
    this single pass and need not match labels collected from another pass.

    Returns:
        ``(mean_probs, mc_probs, uncertainty_metrics)``. ``mc_probs`` has shape
        ``(num_samples, num_data, num_classes)`` and ``mean_probs`` is its mean
        over the first axis.

    Raises:
        ValueError: If ``num_samples < 1`` or ``loader`` yields no batches.
    """
    if num_samples < 1:
        raise ValueError(f"monte_carlo_predictions: num_samples must be >= 1, got {num_samples}")

    model.eval()
    enable_mc_dropout(model)

    per_batch = []  # one (num_samples, batch_size, num_classes) array per batch
    try:
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(device_)
                samples = [softmax_np(model(images).cpu().numpy()) for _ in range(num_samples)]
                per_batch.append(np.stack(samples, axis=0))
    finally:
        model.eval()  # turn the MC-enabled dropout layers back off

    if not per_batch:
        raise ValueError("monte_carlo_predictions: loader yielded no batches")

    # (num_samples, num_data, num_classes)
    mc_predictions = np.concatenate(per_batch, axis=1)
    mean_predictions = np.mean(mc_predictions, axis=0)
    uncertainty_metrics = compute_uncertainty_metrics(mc_predictions)

    return mean_predictions, mc_predictions, uncertainty_metrics


def compute_uncertainty_metrics(predictions: np.ndarray, eps: float = 1e-12) -> Dict[str, np.ndarray]:
    """
    Decompose predictive uncertainty from stacked probabilistic predictions.

    Args:
        predictions: ``(num_samples, num_data, num_classes)`` class
            probabilities, one slice per ensemble member / MC sample.
        eps: Added inside every ``log`` to avoid ``log(0)``.

    Returns:
        Dict of length-``num_data`` arrays:

        - ``predictive_entropy``: entropy of the mean prediction (total uncertainty).
        - ``aleatoric_uncertainty``: mean of the per-sample entropies.
        - ``epistemic_uncertainty``: their difference (mutual information).
        - ``confidence``: max mean probability.
        - ``prediction_variance``: per-class variance across samples, summed over classes.
        - ``bald_score``: the same array object as ``epistemic_uncertainty``.
    """
    def _entropy(p: np.ndarray, axis: int) -> np.ndarray:
        return -np.sum(p * np.log(p + eps), axis=axis)

    mean_probs = np.mean(predictions, axis=0)
    total_entropy = _entropy(mean_probs, axis=1)
    expected_entropy = np.mean(_entropy(predictions, axis=2), axis=0)
    mutual_information = total_entropy - expected_entropy

    return {
        'predictive_entropy': total_entropy,
        'aleatoric_uncertainty': expected_entropy,
        'epistemic_uncertainty': mutual_information,
        'confidence': np.max(mean_probs, axis=1),
        'prediction_variance': np.sum(np.var(predictions, axis=0), axis=1),
        'bald_score': mutual_information,
    }


def softmax_np(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax of ``x`` along ``axis``; the per-slice max is subtracted first for numerical stability."""
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    normaliser = np.sum(exps, axis=axis, keepdims=True)
    return exps / normaliser


# ------------------------------
# Advanced Grad-CAM with Uncertainty
# ------------------------------

def _build_gradcam(model: nn.Module, target_layers: List[nn.Module]):
    """Instantiate ``GradCAM`` across pytorch-grad-cam versions.

    Older releases accept a ``use_cuda`` flag. Newer ones removed it (the
    device is taken from the model) and reject the keyword with ``TypeError``.
    """
    try:
        return GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())
    except TypeError:
        return GradCAM(model=model, target_layers=target_layers)


def generate_uncertainty_aware_gradcam(
    model: nn.Module,
    images: torch.Tensor,
    target_class: int,
    num_mc_samples: int = 10
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Grad-CAM of the first image in ``images`` with a per-pixel MC-dropout spread.

    The CAM for ``target_class`` is computed ``num_mc_samples`` times with
    dropout layers active (MC dropout) on the last ``nn.Conv2d`` in module
    registration order. Failed stochastic passes are skipped. If every pass
    fails, a single deterministic (eval-mode) CAM is returned with an all-zero
    spread.

    The model is left in eval mode and GradCAM's hooks are released before
    returning, including when an error occurs.

    Returns:
        ``(mean_cam, std_cam)`` arrays of the CAM's spatial shape.

    Raises:
        ImportError: If pytorch-grad-cam is not installed.
        ValueError: If the model has no ``nn.Conv2d`` layer.
    """
    if GradCAM is None:
        raise ImportError("pytorch-grad-cam not available")

    # Last convolution in registration order (typically the final conv stage)
    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
    if not conv_layers:
        raise ValueError("No convolutional layers found for Grad-CAM")
    target_layers = [conv_layers[-1]]
    targets = [ClassifierOutputTarget(target_class)]

    cam = _build_gradcam(model, target_layers)
    try:
        model.eval()
        enable_mc_dropout(model)

        cam_samples = []
        for _ in range(num_mc_samples):
            try:
                grayscale_cam = cam(input_tensor=images, targets=targets)
            except Exception:
                # Best effort: skip a failed stochastic pass (narrowed from a bare except).
                continue
            cam_samples.append(grayscale_cam[0])

        if not cam_samples:
            # Fallback to deterministic Grad-CAM
            model.eval()
            grayscale_cam = cam(input_tensor=images, targets=targets)
            return grayscale_cam[0], np.zeros_like(grayscale_cam[0])

        cam_array = np.stack(cam_samples, axis=0)
        return np.mean(cam_array, axis=0), np.std(cam_array, axis=0)
    finally:
        model.eval()  # do not leave dropout layers in MC (train) mode
        # Release GradCAM's forward/backward hooks now rather than at garbage collection.
        hooks = getattr(cam, 'activations_and_grads', None)
        if hooks is not None:
            hooks.release()


# ------------------------------
# Main Training Pipeline
# ------------------------------

def _split_records_stratified(
    records: list,
    val_fraction: float,
    calib_fraction: float,
    test_fraction: float,
    seed: int,
) -> Tuple[list, list, list, list]:
    """Return stratified ``(train, val, calib, test)`` partitions of ``records``.

    All fractions are relative to the full record list; training receives the
    remainder ``1 - (val_fraction + calib_fraction + test_fraction)``.
    Stratification uses each record's ``label`` attribute.

    Raises:
        ImportError: if scikit-learn is unavailable.
        ValueError: if a fraction is non-positive or the held-out total
            leaves no room for a training split.
    """
    if train_test_split is None:
        raise ImportError("scikit-learn is required for stratified splitting")
    held_out = val_fraction + calib_fraction + test_fraction
    if min(val_fraction, calib_fraction, test_fraction) <= 0 or held_out >= 1:
        raise ValueError(
            "val/calib/test fractions must be positive and sum to < 1, got "
            f"{val_fraction}, {calib_fraction}, {test_fraction}"
        )

    labels = np.array([r.label for r in records])
    train_records, held_records = train_test_split(
        records, test_size=held_out, stratify=labels, random_state=seed
    )

    # Carve calibration out of the held-out pool, then split the rest into
    # val/test; the ratios are rescaled so each gets its requested share.
    held_labels = np.array([r.label for r in held_records])
    calib_records, rest_records = train_test_split(
        held_records, test_size=(1 - calib_fraction / held_out),
        stratify=held_labels, random_state=seed,
    )
    rest_labels = np.array([r.label for r in rest_records])
    val_records, test_records = train_test_split(
        rest_records, test_size=(1 - val_fraction / (val_fraction + test_fraction)),
        stratify=rest_labels, random_state=seed,
    )
    return train_records, val_records, calib_records, test_records


def _train_epoch_standard(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    dev,
) -> Tuple[float, float]:
    """Run one plain cross-entropy training epoch; return ``(loss, accuracy)``.

    Assumes ``model(images)`` returns one logit row per input sample in train
    mode (TODO confirm for OptimizedBatchEnsemble).

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. ``drop_last=True``
            with fewer samples than ``batch_size``).
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for images, targets in loader:
        images, targets = images.to(dev), targets.to(dev)
        optimizer.zero_grad()
        logits = model(images)
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        total_correct += (logits.argmax(dim=1) == targets).sum().item()
        total += images.size(0)

    if total == 0:
        raise ValueError("Training loader produced no batches; reduce batch_size")
    return total_loss / total, total_correct / total


def _raps_prediction_sets(
    probs: np.ndarray,
    q: float,
    class_weights: np.ndarray,
    lambda_reg: float,
    k_reg: int,
) -> List[List[int]]:
    """Build class-weighted RAPS prediction sets for each row of ``probs``.

    Each row is reweighted by ``class_weights`` and renormalised. Classes are
    then added in descending probability while cumulative mass plus the rank
    penalty ``lambda_reg * max(rank - k_reg, 0)`` (0-based rank) stays within
    ``q``. A set is never empty: the top class is always included.
    """
    sets = []
    for row in probs:
        p = row * class_weights
        p = p / p.sum()
        order = np.argsort(p)[::-1]

        cum_prob = 0.0
        pred_set = []
        for rank, cls_idx in enumerate(order):
            reg = lambda_reg * max(rank - k_reg, 0)
            # The score is monotone in rank, so stop at the first rejection.
            if cum_prob + p[cls_idx] + reg > q + 1e-12:
                break
            cum_prob += p[cls_idx]
            pred_set.append(int(cls_idx))

        sets.append(pred_set or [int(order[0])])
    return sets


def run_alzheimer_pipeline(
    data_root: str,
    output_dir: str,
    model_name: str = 'tf_efficientnet_b0_ns',
    image_size: int = 224,
    batch_size: int = 16,
    epochs: int = 15,
    lr: float = 2e-4,
    weight_decay: float = 1e-4,
    alpha: float = 0.1,
    lambda_reg: float = 0.05,
    k_reg: int = 1,
    dropout: float = 0.25,
    num_workers: int = 2,
    calib_fraction: float = 0.2,
    val_fraction: float = 0.2,
    mc_samples: int = 20,
    seed: int = 42,
    ensemble_method: str = 'batch_ensemble',
    ensemble_size: int = 4,
    use_mixup: bool = True,
    test_fraction: float = 0.15,
) -> None:
    """
    Complete Alzheimer MRI analysis pipeline with uncertainty quantification.

    Steps:
    1. Make a stratified train/val/calib/test split.
    2. Train, keeping the checkpoint with the best validation accuracy.
    3. Evaluate on the test split.
    4. Estimate uncertainty with MC dropout or a batch ensemble.
    5. Fit temperature scaling on the calibration split.
    6. Build RAPS conformal prediction sets.
    7. Render Grad-CAM visualizations (best effort).
    8. Write JSON reports to ``output_dir``.

    Args:
        data_root: Dataset root passed to ``load_alzheimer_records``.
        output_dir: Directory for JSON reports and PNG visualizations.
        model_name: Backbone name ('resnet18' uses ``build_simple_model``).
        alpha: Target miscoverage for the conformal sets.
        lambda_reg, k_reg: RAPS rank-penalty strength and free ranks.
        mc_samples: Stochastic passes for MC-dropout uncertainty.
        ensemble_method: 'batch_ensemble' builds an OptimizedBatchEnsemble on
            a timm backbone. Any other value trains a single model and uses
            MC dropout for uncertainty.
        use_mixup: Train with ``train_epoch_with_mixup``. Otherwise a plain
            cross-entropy epoch is used.
        val_fraction, calib_fraction, test_fraction: Shares of the full
            dataset held out for each split; training gets the remainder.

    Writes ``comprehensive_results.json``, ``prediction_details.json`` and
    per-case ``*_gradcam.png`` / ``*_uncertainty.png`` files.

    Raises:
        ImportError: if scikit-learn (or timm for batch_ensemble) is missing.
        ValueError: if the split fractions are invalid.
    """
    from collections import Counter

    set_seed(seed)
    ensure_dir(output_dir)
    dev = device()

    print("=" * 80)
    print("ALZHEIMER MRI UNCERTAINTY QUANTIFICATION & EXPLAINABILITY PIPELINE")
    print("=" * 80)
    print("🧠 Focus: Alzheimer's Disease Progression Analysis")
    print(f"📊 Ensemble Method: {ensemble_method}")
    print("🎯 Target: CDR-based cognitive decline classification")
    print(f"⚙️  Device: {dev}")
    print("=" * 80)

    # Load dataset
    print("\n📁 Loading Alzheimer MRI dataset...")
    records = load_alzheimer_records(data_root)

    label_counts = Counter(record.label for record in records)
    print("\n📊 Dataset Statistics:")
    for label, count in sorted(label_counts.items()):
        clinical_info = ALZHEIMER_CLINICAL_INFO[label]
        print(f"   {clinical_info['name']} (CDR {clinical_info['cdr']}): {count} images")

    train_records, val_records, calib_records, test_records = _split_records_stratified(
        records, val_fraction, calib_fraction, test_fraction, seed
    )

    print("\n🔄 Data Split:")
    print(f"   Training: {len(train_records)} images")
    print(f"   Validation: {len(val_records)} images")
    print(f"   Calibration: {len(calib_records)} images")
    print(f"   Test: {len(test_records)} images")

    # Create datasets
    train_ds = AlzheimerDataset(train_records, image_size=image_size, augment=True)
    val_ds = AlzheimerDataset(val_records, image_size=image_size, augment=False)
    calib_ds = AlzheimerDataset(calib_records, image_size=image_size, augment=False)
    test_ds = AlzheimerDataset(test_records, image_size=image_size, augment=False)

    # Create data loaders
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, worker_init_fn=worker_init_fn,
        drop_last=True  # Drop incomplete last batch to avoid BatchNorm issues
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    calib_loader = DataLoader(
        calib_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    num_classes = len(ALZHEIMER_LABEL_MAP)
    class_names = [ALZHEIMER_CLINICAL_INFO[i]['name'] for i in range(num_classes)]

    # Build model according to the ensemble method
    print(f"\n🏗️  Building {ensemble_method} model...")
    is_batch_ensemble = ensemble_method == 'batch_ensemble'
    if is_batch_ensemble:
        if timm is None:
            raise ImportError("timm is required for the batch_ensemble method")
        # Backbone without classifier; the ensemble adds its own heads.
        backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        model = OptimizedBatchEnsemble(backbone, num_classes, ensemble_size).to(dev)
    elif model_name == 'resnet18':  # Use simple model for resnet18
        model = build_simple_model(model_name, num_classes, dropout, pretrained=True).to(dev)
    else:
        model = build_medical_model(model_name, num_classes, dropout, pretrained=True).to(dev)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    if is_batch_ensemble:
        print(f"\n🚀 Training Batch Ensemble ({ensemble_size} members)...")
    else:
        print("\n🚀 Training Medical Model...")

    best_val_acc = -1.0
    best_state = None
    log_every = max(1, epochs // 5)
    for epoch in range(1, epochs + 1):
        if use_mixup:
            _, tr_acc = train_epoch_with_mixup(model, train_loader, optimizer, dev)
        else:
            # Both methods get a real training epoch when mixup is off.
            _, tr_acc = _train_epoch_standard(model, train_loader, optimizer, dev)

        val_results = comprehensive_evaluation(model, val_loader, dev, class_names)
        val_acc = val_results['accuracy']

        scheduler.step()

        if epoch % log_every == 0:
            if is_batch_ensemble:
                print(f"   Epoch {epoch:02d}/{epochs} | Val Acc: {val_acc:.4f} | Val F1: {val_results['f1_macro']:.4f}")
            else:
                print(f"   Epoch {epoch:02d}/{epochs} | Train: {tr_acc:.3f} | Val: {val_acc:.3f} | F1: {val_results['f1_macro']:.3f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)

    # Comprehensive evaluation
    print("\n📊 Final Model Evaluation...")
    test_results = comprehensive_evaluation(model, test_loader, dev, class_names)
    print(f"   Test Accuracy: {test_results['accuracy']:.4f}")
    print(f"   Test F1 (Macro): {test_results['f1_macro']:.4f}")
    print(f"   Test F1 (Weighted): {test_results['f1_weighted']:.4f}")

    # Uncertainty quantification: any non-batch-ensemble method uses MC dropout.
    print(f"\n🎲 Uncertainty Quantification ({ensemble_method})...")
    if is_batch_ensemble:
        mean_preds, mc_preds, uncertainty_metrics = batch_ensemble_predictions(
            model, test_loader, dev
        )
    else:
        mean_preds, mc_preds, uncertainty_metrics = monte_carlo_predictions(
            model, test_loader, dev, mc_samples
        )

    print(f"   Mean Predictive Entropy: {np.mean(uncertainty_metrics['predictive_entropy']):.4f}")
    print(f"   Mean Epistemic Uncertainty: {np.mean(uncertainty_metrics['epistemic_uncertainty']):.4f}")
    print(f"   Mean Aleatoric Uncertainty: {np.mean(uncertainty_metrics['aleatoric_uncertainty']):.4f}")

    # Temperature scaling fitted on the calibration split
    print("\n🌡️  Advanced Temperature Scaling...")
    temp_scaler = fit_advanced_temperature(model, calib_loader, dev, num_classes)

    with torch.no_grad():
        test_logits_tensor = torch.from_numpy(test_results['logits']).to(dev)
        scaled_logits = temp_scaler(test_logits_tensor).cpu().numpy()
    scaled_probs = softmax_np(scaled_logits)

    # RAPS conformal prediction
    print("\n🎯 Enhanced RAPS Conformal Prediction...")
    calib_results = comprehensive_evaluation(model, calib_loader, dev, class_names)
    calib_labels = np.array(calib_results['labels'])
    with torch.no_grad():
        calib_logits_tensor = torch.from_numpy(calib_results['logits']).to(dev)
        calib_scaled_logits = temp_scaler(calib_logits_tensor).cpu().numpy()
    calib_scaled_probs = softmax_np(calib_scaled_logits)

    # Inverse-frequency class weights from the TRAINING split only, so no
    # calibration or test labels influence the conformal score function.
    train_counts = Counter(r.label for r in train_records)
    class_counts_array = np.array([train_counts.get(i, 1) for i in range(num_classes)])
    class_weights = 1.0 / (class_counts_array / class_counts_array.sum())
    class_weights = class_weights / class_weights.sum() * num_classes

    calib_scores = compute_raps_scores_enhanced(
        calib_scaled_probs, calib_labels, lambda_reg, k_reg, class_weights
    )
    qhat, (qhat_ci_low, qhat_ci_high) = adaptive_raps_quantile(calib_scores, alpha)

    test_labels = np.array(test_results['labels'])
    pred_sets = _raps_prediction_sets(scaled_probs, qhat, class_weights, lambda_reg, k_reg)

    coverage = np.mean([label in pred_set for label, pred_set in zip(test_labels, pred_sets)])
    avg_set_size = np.mean([len(s) for s in pred_sets])

    print(f"   RAPS Coverage (α={alpha}): {coverage:.3f} (target: {1-alpha:.3f})")
    print(f"   Average Set Size: {avg_set_size:.3f}")
    print(f"   Quantile: {qhat:.4f} (CI: {qhat_ci_low:.4f} - {qhat_ci_high:.4f})")

    # Visualizations (best effort: failures are reported, never fatal)
    print("\n🎨 Generating Advanced Visualizations...")
    try:
        # One case per class: the test image with the highest epistemic uncertainty.
        epistemic = np.asarray(uncertainty_metrics['epistemic_uncertainty'])
        vis_indices = []
        for label in range(num_classes):
            label_indices = np.flatnonzero(test_labels == label)
            if label_indices.size:
                vis_indices.append(int(label_indices[np.argmax(epistemic[label_indices])]))
        vis_indices = vis_indices[:6]  # Limit to 6 visualizations

        # Undo input normalization for display; assumes AlzheimerDataset
        # normalizes with ImageNet mean/std (TODO confirm).
        imagenet_mean = np.array([0.485, 0.456, 0.406])
        imagenet_std = np.array([0.229, 0.224, 0.225])

        for idx in vis_indices:
            try:
                x, y = test_ds[idx]
                true_clinical = ALZHEIMER_CLINICAL_INFO[int(y)]
                images_tensor = x.unsqueeze(0).to(dev)
                img_vis = np.clip(x.numpy().transpose(1, 2, 0) * imagenet_std + imagenet_mean, 0, 1)

                for class_idx in pred_sets[idx]:
                    try:
                        mean_cam, std_cam = generate_uncertainty_aware_gradcam(
                            model, images_tensor, class_idx, num_mc_samples=10
                        )
                        clinical_info = ALZHEIMER_CLINICAL_INFO[class_idx]

                        mean_overlay = show_cam_on_image(img_vis, mean_cam, use_rgb=True)

                        # Red-channel overlay of the Grad-CAM std; np.ptp is used
                        # because ndarray.ptp was removed in NumPy 2.0.
                        uncertainty_norm = (std_cam - std_cam.min()) / (np.ptp(std_cam) + 1e-8)
                        uncertainty_color = np.zeros_like(img_vis)
                        uncertainty_color[:, :, 0] = uncertainty_norm
                        uncertainty_overlay = ((0.7 * img_vis + 0.3 * uncertainty_color) * 255).astype(np.uint8)

                        base_name = f"case_{idx}_true_{true_clinical['name'].replace(' ', '')}_pred_{clinical_info['name'].replace(' ', '')}"
                        Image.fromarray(mean_overlay).save(
                            os.path.join(output_dir, f"{base_name}_gradcam.png")
                        )
                        Image.fromarray(uncertainty_overlay).save(
                            os.path.join(output_dir, f"{base_name}_uncertainty.png")
                        )
                    except Exception as e:
                        print(f"Warning: Visualization failed for case {idx}, class {class_idx}: {e}")
            except Exception as e:
                print(f"Warning: Failed to process visualization case {idx}: {e}")
    except Exception as e:
        print(f"Warning: Visualization generation failed: {e}")

    # Save comprehensive results
    print("\n💾 Saving Results...")
    results = {
        # Model and training info
        'model_name': model_name,
        'ensemble_method': ensemble_method,
        'ensemble_size': ensemble_size,
        'epochs': epochs,
        'learning_rate': lr,
        'dropout': dropout,
        'use_mixup': use_mixup,

        # Dataset info
        'dataset_size': len(records),
        'train_size': len(train_records),
        'val_size': len(val_records),
        'calib_size': len(calib_records),
        'test_size': len(test_records),
        'test_fraction': test_fraction,
        # Plain ints so json never meets numpy integer keys/values.
        'class_distribution': {int(k): int(v) for k, v in label_counts.items()},

        # Performance metrics
        'test_accuracy': float(test_results['accuracy']),
        'test_f1_macro': float(test_results['f1_macro']),
        'test_f1_weighted': float(test_results['f1_weighted']),
        'test_loss': float(test_results['loss']),

        # Uncertainty metrics
        'uncertainty_metrics': {
            k: {
                'mean': float(np.mean(v)),
                'std': float(np.std(v)),
                'min': float(np.min(v)),
                'max': float(np.max(v))
            } for k, v in uncertainty_metrics.items()
        },

        # Conformal prediction
        'conformal_prediction': {
            'alpha': alpha,
            'lambda_reg': lambda_reg,
            'k_reg': k_reg,
            'coverage': float(coverage),
            'average_set_size': float(avg_set_size),
            'quantile': float(qhat),
            'quantile_ci_lower': float(qhat_ci_low),
            'quantile_ci_upper': float(qhat_ci_high)
        },

        # Clinical interpretation
        'clinical_info': ALZHEIMER_CLINICAL_INFO
    }

    if 'classification_report' in test_results:
        results['classification_report'] = test_results['classification_report']

    with open(os.path.join(output_dir, 'comprehensive_results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    prediction_details = {
        'test_predictions': [int(p) for p in test_results['predictions']],
        'test_labels': [int(l) for l in test_labels],
        'prediction_sets': [[int(c) for c in ps] for ps in pred_sets],
        'uncertainty_scores': {k: v.tolist() for k, v in uncertainty_metrics.items()},
        'scaled_probabilities': scaled_probs.tolist()
    }

    with open(os.path.join(output_dir, 'prediction_details.json'), 'w', encoding='utf-8') as f:
        json.dump(prediction_details, f, indent=2)

    print("\n✅ Pipeline Complete!")
    print(f"📁 Results saved to: {output_dir}")
    print(f"🏆 Test Accuracy: {test_results['accuracy']:.4f}")
    print(f"🎯 Conformal Coverage: {coverage:.3f} (target: {1-alpha:.3f})")
    print(f"📊 Average Prediction Set Size: {avg_set_size:.3f}")
    print("=" * 80)


# ------------------------------
# CLI
# ------------------------------

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """
    Parse command-line options for the Alzheimer MRI pipeline.

    Most defaults can be overridden through same-named upper-case environment
    variables (e.g. ``EPOCHS``, ``USE_MIXUP``).

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        The parsed namespace. Mixup is on by default; ``--no_mixup`` (or
        ``USE_MIXUP=0``) turns it off.
    """
    parser = argparse.ArgumentParser(
        description="Advanced Alzheimer MRI Analysis with Uncertainty Quantification & Explainability",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Dataset and I/O
    parser.add_argument('--data_root', type=str,
                        default=os.environ.get('DATA_ROOT', '/kaggle/input/augmented-alzheimer-mri-dataset'),
                        help='Root directory containing Alzheimer MRI dataset')
    parser.add_argument('--output_dir', type=str,
                        default=os.environ.get('OUTPUT_DIR', './alzheimer_analysis_outputs'),
                        help='Output directory for results and visualizations')

    # Model configuration
    parser.add_argument('--model_name', type=str,
                        default=os.environ.get('MODEL_NAME', 'tf_efficientnet_b0_ns'),
                        help='Backbone model architecture')
    parser.add_argument('--ensemble_method', type=str,
                        default=os.environ.get('ENSEMBLE_METHOD', 'batch_ensemble'),
                        choices=['mc_dropout', 'batch_ensemble'],
                        help='Ensemble method for uncertainty quantification')
    parser.add_argument('--ensemble_size', type=int,
                        default=int(os.environ.get('ENSEMBLE_SIZE', 4)),
                        help='Number of ensemble members')

    # Training parameters
    parser.add_argument('--image_size', type=int, default=int(os.environ.get('IMAGE_SIZE', 224)))
    parser.add_argument('--batch_size', type=int, default=int(os.environ.get('BATCH_SIZE', 16)))
    parser.add_argument('--epochs', type=int, default=int(os.environ.get('EPOCHS', 15)))
    parser.add_argument('--lr', type=float, default=float(os.environ.get('LR', 2e-4)))
    parser.add_argument('--weight_decay', type=float, default=float(os.environ.get('WEIGHT_DECAY', 1e-4)))
    parser.add_argument('--dropout', type=float, default=float(os.environ.get('DROPOUT', 0.25)))
    # A store_true flag with default=True could never be switched off, so a
    # paired --no_mixup writes False into the same destination.
    parser.add_argument('--use_mixup', dest='use_mixup', action='store_true',
                        help='Use mixup augmentation for training')
    parser.add_argument('--no_mixup', dest='use_mixup', action='store_false',
                        help='Disable mixup augmentation for training')
    mixup_env = os.environ.get('USE_MIXUP', '1').strip().lower()
    parser.set_defaults(use_mixup=mixup_env not in ('0', 'false', 'no', 'off'))

    # Uncertainty and conformal prediction
    parser.add_argument('--alpha', type=float, default=float(os.environ.get('ALPHA', 0.1)),
                        help='Miscoverage rate for conformal prediction sets')
    parser.add_argument('--lambda_reg', type=float, default=float(os.environ.get('LAMBDA_REG', 0.05)),
                        help='RAPS regularization parameter')
    parser.add_argument('--k_reg', type=int, default=int(os.environ.get('K_REG', 1)),
                        help='RAPS rank threshold for regularization')
    parser.add_argument('--mc_samples', type=int, default=int(os.environ.get('MC_SAMPLES', 20)),
                        help='Number of Monte Carlo samples for uncertainty estimation')

    # Data splitting
    parser.add_argument('--val_fraction', type=float, default=float(os.environ.get('VAL_FRACTION', 0.2)))
    parser.add_argument('--calib_fraction', type=float, default=float(os.environ.get('CALIB_FRACTION', 0.2)))

    # System
    parser.add_argument('--num_workers', type=int, default=int(os.environ.get('NUM_WORKERS', 2)))
    parser.add_argument('--seed', type=int, default=int(os.environ.get('SEED', 42)))

    return parser.parse_args(argv)


def run_simple_and_safe(
    data_root: str = '/kaggle/input/augmented-alzheimer-mri-dataset',
    output_dir: str = '/kaggle/working/alzheimer_results'
) -> None:
    """
    Run a conservative, low-resource configuration of the pipeline.

    Uses a ResNet-18 backbone with MC dropout, 5 epochs, batch size 64, no
    mixup and no DataLoader worker processes. It still produces RAPS
    conformal sets, temperature scaling and Grad-CAM outputs.

    Args:
        data_root: Dataset root directory.
        output_dir: Directory where results and visualizations are written.
    """
    print("🛡️  Running SIMPLE & SAFE version (guaranteed to work)")
    print("📚 Still includes: RAPS, MC Dropout, Temperature Scaling, Grad-CAM")

    run_alzheimer_pipeline(
        data_root=data_root,
        output_dir=output_dir,
        model_name='resnet18',  # Simpler model
        image_size=224,
        batch_size=64,          # Larger batch to avoid small batch issues
        epochs=5,               # Fewer epochs
        lr=1e-3,                # Standard learning rate
        weight_decay=1e-4,
        alpha=0.1,
        lambda_reg=0.0,         # No RAPS rank penalty
        k_reg=0,
        dropout=0.2,            # Lower dropout
        num_workers=0,          # No multiprocessing
        calib_fraction=0.15,    # Smaller calibration set
        val_fraction=0.15,
        mc_samples=10,          # Fewer MC samples
        seed=42,
        ensemble_method='mc_dropout',
        ensemble_size=1,
        use_mixup=False,        # No mixup for simplicity
    )


def quick_test_run(
    data_root: str = '/kaggle/input/augmented-alzheimer-mri-dataset',
    output_dir: str = '/kaggle/working/alzheimer_results'
) -> None:
    """
    Smoke-test the pipeline end to end with minimal settings.

    Trains an EfficientNet-B0 for 2 epochs with MC dropout and no mixup, and
    uses 10 MC samples. It is meant to check that every stage runs, not to
    produce a useful model. Run time depends on hardware and dataset size.

    Args:
        data_root: Dataset root directory.
        output_dir: Directory where results and visualizations are written.
    """
    print("🧪 Running Quick Test (2 epochs, MC Dropout)")
    print("This will complete in ~5 minutes to verify everything works")

    run_alzheimer_pipeline(
        data_root=data_root,
        output_dir=output_dir,
        model_name='tf_efficientnet_b0_ns',
        image_size=224,
        batch_size=32,  # Larger batch for speed
        epochs=2,       # Just 2 epochs for testing
        lr=3e-4,
        weight_decay=1e-4,
        alpha=0.1,
        lambda_reg=0.05,
        k_reg=1,
        dropout=0.25,
        num_workers=2,
        calib_fraction=0.2,
        val_fraction=0.2,
        mc_samples=10,  # Fewer samples for speed
        seed=42,
        ensemble_method='mc_dropout',
        ensemble_size=4,  # Not used by the MC-dropout path
        use_mixup=False,  # Disable mixup for speed
    )


def run_with_fallback(
    data_root: str = '/kaggle/input/augmented-alzheimer-mri-dataset',
    output_dir: str = '/kaggle/working/alzheimer_results',
    model_name: str = 'tf_efficientnet_b0_ns',
    epochs: int = 10,
    batch_size: int = 16,
    lr: float = 2e-4,
    alpha: float = 0.1,
    mc_samples: int = 15
) -> None:
    """
    Run the pipeline with batch_ensemble, retrying with mc_dropout on failure.

    Both attempts share every setting except ``ensemble_method`` and write
    into the same ``output_dir``, so files from a failed first attempt may
    remain there.

    Raises:
        Exception: whatever the mc_dropout attempt raised if both attempts
            fail; the batch_ensemble error is kept as its ``__context__``.
    """
    print("🔥 Running Alzheimer MRI Analysis with Smart Fallback")
    print(f"📁 Data: {data_root}")
    print(f"💾 Output: {output_dir}")

    common_kwargs = dict(
        data_root=data_root,
        output_dir=output_dir,
        model_name=model_name,
        image_size=224,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        weight_decay=1e-4,
        alpha=alpha,
        lambda_reg=0.05,
        k_reg=1,
        dropout=0.25,
        num_workers=2,
        calib_fraction=0.2,
        val_fraction=0.2,
        mc_samples=mc_samples,
        seed=42,
        ensemble_size=4,
        use_mixup=True,
    )

    try:
        print("🤖 Attempting Batch Ensemble...")
        run_alzheimer_pipeline(ensemble_method='batch_ensemble', **common_kwargs)
    except Exception as e:
        print(f"⚠️  Batch Ensemble failed: {e}")
        print("🔄 Switching to MC Dropout fallback...")
        try:
            run_alzheimer_pipeline(ensemble_method='mc_dropout', **common_kwargs)
        except Exception as e2:
            print(f"❌ Both methods failed. Error: {e2}")
            raise  # keep the original traceback
        print("✅ MC Dropout completed successfully!")
    else:
        print("✅ Batch Ensemble completed successfully!")


def run_with_defaults(
    data_root: str = '/kaggle/input/augmented-alzheimer-mri-dataset',
    output_dir: str = '/kaggle/working/alzheimer_results',
    model_name: str = 'tf_efficientnet_b0_ns',
    ensemble_method: str = 'mc_dropout',  # mc_dropout chosen for reliability
    ensemble_size: int = 4,
    epochs: int = 10,  # Reduced for faster completion
    batch_size: int = 16,
    lr: float = 2e-4,
    alpha: float = 0.1,
    mc_samples: int = 15
) -> None:
    """
    Run the pipeline with sensible defaults for Kaggle/Colab environments.

    Call this directly from a notebook instead of going through ``main()``;
    it bypasses argparse entirely. Settings not exposed as parameters are
    fixed: image size 224, mixup on, RAPS lambda 0.05 / k 1, 20%
    validation and 20% calibration.
    """
    print("🔥 Running Alzheimer MRI Uncertainty Analysis with Default Parameters")
    print(f"📁 Data: {data_root}")
    print(f"💾 Output: {output_dir}")
    print(f"🤖 Method: {ensemble_method}")
    print(f"⏱️  Epochs: {epochs}")

    run_alzheimer_pipeline(
        data_root=data_root,
        output_dir=output_dir,
        model_name=model_name,
        image_size=224,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        weight_decay=1e-4,
        alpha=alpha,
        lambda_reg=0.05,
        k_reg=1,
        dropout=0.25,
        num_workers=2,
        calib_fraction=0.2,
        val_fraction=0.2,
        mc_samples=mc_samples,
        seed=42,
        ensemble_method=ensemble_method,
        ensemble_size=ensemble_size,
        use_mixup=True,
    )


def is_notebook_environment() -> bool:
    """Return True when an IPython shell is active (Jupyter, Colab, Kaggle).

    Any active IPython shell counts, including a terminal IPython session.
    Returns False when IPython is not installed or no shell is running.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    return shell is not None


def main() -> None:
    """
    Entry point for both notebook and command-line use.

    In an IPython/notebook session, argparse is skipped and
    ``run_simple_and_safe()`` runs with its defaults. Otherwise:
    1. Parse CLI arguments.
    2. Check that the data root exists and that timm, pytorch-grad-cam and
       scikit-learn are importable, calling ``sys.exit(1)`` with guidance
       if not.
    3. Run the full pipeline.
    """
    if is_notebook_environment():
        print("🔍 Detected notebook environment - using SIMPLE & SAFE version")
        print("📚 This includes all key methods: RAPS, MC Dropout, Temperature Scaling, Grad-CAM")
        run_simple_and_safe()
        return

    args = parse_args()

    if not exists(args.data_root):
        print(f"❌ Data root not found: {args.data_root}")
        print()
        print("For Kaggle:")
        print("  1. Add dataset: 'uraninjo/augmented-alzheimer-mri-dataset'")
        print("  2. Set data_root: '/kaggle/input/augmented-alzheimer-mri-dataset'")
        print()
        print("Expected folder structure:")
        print(f"  {args.data_root}/")
        print("  ├── NonDemented/")
        print("  ├── VeryMildDemented/")
        print("  ├── MildDemented/")
        print("  └── ModerateDemented/")
        sys.exit(1)

    # Map each optional import to its pip distribution name. Grad-CAM is
    # published on PyPI as "grad-cam" (imported as pytorch_grad_cam).
    missing = [
        pip_name
        for available, pip_name in (
            (timm is not None, 'timm'),
            (GradCAM is not None, 'grad-cam'),
            (train_test_split is not None, 'scikit-learn'),
        )
        if not available
    ]
    if missing:
        print(f"❌ Missing packages: {missing}")
        print(f"Install with: pip install {' '.join(missing)}")
        sys.exit(1)

    run_alzheimer_pipeline(
        data_root=args.data_root,
        output_dir=args.output_dir,
        model_name=args.model_name,
        image_size=args.image_size,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        alpha=args.alpha,
        lambda_reg=args.lambda_reg,
        k_reg=args.k_reg,
        dropout=args.dropout,
        num_workers=args.num_workers,
        calib_fraction=args.calib_fraction,
        val_fraction=args.val_fraction,
        mc_samples=args.mc_samples,
        seed=args.seed,
        ensemble_method=args.ensemble_method,
        ensemble_size=args.ensemble_size,
        use_mixup=args.use_mixup,
    )


# Script entry point. Executing this file as a notebook cell also has
# __name__ == '__main__', so main() runs there too and takes its
# is_notebook_environment() branch instead of parsing CLI arguments.
if __name__ == '__main__':
    main()
