In [None]:
# %%
# Instalar Hugging Face Transformers
!pip install transformers datasets pillow -q

print("✓ Transformers instalado")

In [None]:

# %%
from transformers import CvtConfig, CvtModel

print("Verificando modelos CvT en Hugging Face...\n")

# Candidate CvT checkpoints with approximate parameter counts.
cvt_models = [
    'microsoft/cvt-13',    # ~20M params - RECOMMENDED
    'microsoft/cvt-21',    # ~32M params
    'microsoft/cvt-w24',   # ~277M params - VERY LARGE
]

for model_name in cvt_models:
    try:
        # Only the model configuration is loaded, to confirm the checkpoint is reachable.
        config = CvtConfig.from_pretrained(model_name)
    except Exception as e:
        # Best-effort probe: keep checking the remaining checkpoints, but
        # report why this one failed instead of discarding the exception.
        print(f"✗ {model_name} - Error: {type(e).__name__}: {e}\n")
        continue

    print(f"✓ {model_name}")
    print(f"  Depth: {config.depth}")
    print(f"  Embed dims: {config.embed_dim}")
    print()

print("💡 Recomendado: microsoft/cvt-13")

In [None]:
# %%
# Core imports
import os
import sys
import json
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import gc
import time
from datetime import datetime

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from torch.optim.swa_utils import AveragedModel, SWALR

# Torchvision
from torchvision import transforms
from PIL import Image

# Hugging Face Transformers
from transformers import CvtModel, CvtConfig

# Sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    confusion_matrix, classification_report,
    accuracy_score
)

# Confirm the import cell finished and record the PyTorch version in the output.
print("\n".join([
    "✓ Imports loaded successfully (CvT_TL_2 with Advanced Techniques)",
    f"  PyTorch version: {torch.__version__}",
]))

In [None]:
# %%
# GPU verification: choose the device used by every later cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

if device.type != 'cuda':
    print("⚠️ WARNING: No GPU available, training will be slow!")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")

    # Release cached allocator blocks and collect garbage before training starts.
    torch.cuda.empty_cache()
    gc.collect()
    print("✓ GPU cache cleared")

In [None]:
# %%
# Configuration for CvT_TL_2 (512x512 inputs) - improved version.
# All hyperparameters read by the notebook live in this dict; it is also written
# to <metrics_dir>/<model_name>_config.json below so each run records its settings.
CONFIG = {
    # Paths
    # NOTE(review): absolute, user-specific paths; consider deriving them from a configurable base dir.
    'base_path': '/home/merivadeneira',
    'ddsm_benign_path': '/home/merivadeneira/Masas/DDSM/Benignas/Resized_512',
    'ddsm_malign_path': '/home/merivadeneira/Masas/DDSM/Malignas/Resized_512',
    'inbreast_benign_path': '/home/merivadeneira/Masas/INbreast/Benignas/Resized_512',
    'inbreast_malign_path': '/home/merivadeneira/Masas/INbreast/Malignas/Resized_512',
    'output_dir': '/home/merivadeneira/Outputs/CvT',
    'metrics_dir': '/home/merivadeneira/Metrics/CvT',

    # Model (Hugging Face)
    'model_name': 'CvT_TL_2',
    'pretrained_model': 'microsoft/cvt-13',
    'pretrained': True,
    'input_size': 512,
    'in_channels': 3,
    'num_classes': 2,

    # Fine-tuning strategy
    # NOTE(review): not read anywhere in the visible training code.
    'freeze_strategy': 'none',

    # Training
    'batch_size': 16,
    'num_epochs': 100,
    'num_folds': 5,
    'early_stopping_patience': 25,
    'min_delta': 1e-4,
    'gradient_accumulation_steps': 2,  # NEW: effective batch = 16 * 2 = 32

    # Optimizer
    'optimizer': 'AdamW',
    'lr_initial': None,   # NOTE(review): not read in the visible code
    'lr_min': 1e-7,       # NOTE(review): not read in the visible code (the scheduler uses cosine_eta_min)
    'lr_max': 1e-3,       # LR passed to AdamW
    'weight_decay': 0.01,
    'betas': (0.9, 0.999),

    # Warmup (LinearLR length in epochs and start factor)
    'warmup_epochs': 5,
    'warmup_start_factor': 0.1,

    # Scheduler - NEW: CosineAnnealingWarmRestarts
    'scheduler': 'CosineAnnealingWarmRestarts',
    'cosine_T0': 10,        # first restart after 10 scheduler steps
    'cosine_T_mult': 2,     # each following period is twice as long
    'cosine_eta_min': 1e-7,

    # Focal Loss - NEW
    'use_focal_loss': True,
    # FocalLoss weights samples with label 1 (malignant) by alpha and label 0 (benign) by 1 - alpha.
    # NOTE(review): with 0.25 the malignant class gets the smaller weight; confirm this is intended.
    'focal_alpha': 0.25,
    'focal_gamma': 2.0,     # focusing parameter

    # Class Weights - NEW
    # NOTE(review): class_weights are only passed to nn.CrossEntropyLoss; with use_focal_loss=True they are not applied.
    'use_class_weights': True,

    # Mixup - NEW
    'use_mixup': True,
    'mixup_alpha': 0.2,
    'mixup_prob': 0.5,      # probability that a given training batch is mixed

    # SWA - NEW
    'use_swa': True,
    'swa_start_epoch': 70,  # 0-based epoch index from which weights are averaged (70% of num_epochs)
    'swa_lr': 1e-5,

    # TTA - NEW
    'use_tta': True,
    'tta_augmentations': 5,   # TTA views requested per batch at final evaluation (see validate_epoch)

    # Label smoothing & gradient clipping
    'label_smoothing': 0.1,
    'gradient_clip_norm': 1.0,

    # Data augmentation - improved (no scale augmentation)
    'horizontal_flip': 0.5,
    'vertical_flip': 0.3,
    'rotation_degrees': 30,        # was 20
    'translate': 0.20,             # was 0.15
    'shear': 10,
    'brightness': 0.3,             # was 0.2
    'contrast': 0.3,               # was 0.2
    'random_erasing_p': 0.3,       # was 0.2

    # Normalization (ImageNet)
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],

    # Mixed Precision & Memory
    'use_amp': True,
    'num_workers': 4,
    'pin_memory': True,

    # Reproducibility
    'seed': 42
}

# Create output directories (no-op when they already exist)
os.makedirs(CONFIG['output_dir'], exist_ok=True)
os.makedirs(CONFIG['metrics_dir'], exist_ok=True)

# Save configuration as JSON (tuples such as 'betas' are written as lists)
config_path = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_config.json")
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=4)

# Summary of the run settings
print(f"✓ Configuration saved to: {config_path}")
print(f"\n📊 Model: {CONFIG['model_name']}")
print(f"🏗️ Base model: {CONFIG['pretrained_model']} (Hugging Face)")
print(f"📦 Batch size: {CONFIG['batch_size']} (effective: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']})")
print(f"🔢 Input size: {CONFIG['input_size']}x{CONFIG['input_size']}")
print(f"🎯 Classes: {CONFIG['num_classes']} (Benign vs Malignant)")
print(f"🔄 Folds: {CONFIG['num_folds']}")
print(f"🔥 Warmup epochs: {CONFIG['warmup_epochs']}")
print(f"⚡ Mixed Precision: {CONFIG['use_amp']}")
print(f"🎯 Focal Loss: {CONFIG['use_focal_loss']} (alpha={CONFIG['focal_alpha']}, gamma={CONFIG['focal_gamma']})")
print(f"⚖️ Class Weights: {CONFIG['use_class_weights']}")
print(f"🎨 Mixup: {CONFIG['use_mixup']} (alpha={CONFIG['mixup_alpha']})")
print(f"📊 SWA: {CONFIG['use_swa']} (start epoch={CONFIG['swa_start_epoch']})")
print(f"🔍 TTA: {CONFIG['use_tta']} (augmentations={CONFIG['tta_augmentations']})")
print(f"🔄 Gradient Accumulation: {CONFIG['gradient_accumulation_steps']} steps")
print(f"📈 Scheduler: {CONFIG['scheduler']}")


In [None]:
# %%
# Set seed for reproducibility
def set_seed(seed=42):
    """
    Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG (drawn from by
    mixup), and PyTorch on the CPU and on all CUDA devices. It also puts
    cuDNN in deterministic mode with autotuning disabled, so the convolution
    algorithm choice does not vary between runs.

    Args:
        seed: integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (including device 0).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed once, before any data loading or model creation.
set_seed(seed=CONFIG['seed'])
print(f"✓ Random seed set to {CONFIG['seed']}")

In [None]:
# %%
def extract_patient_id(filename, dataset='ddsm'):
    """
    Extract the patient identifier from an image filename.

    Cross-validation groups images by this ID, so every image of a patient
    stays in one fold. This prevents data leakage between train and validation.

    DDSM format:     P_00041_LEFT_CC_1.png -> P_00041
    INbreast format: 20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png
                     -> 6c613a14b80a8591

    In INbreast names the leading number identifies the image file, while the
    second (hexadecimal) token identifies the patient and is shared by all of
    that patient's views. Grouping by the leading number would let CC and MLO
    views of the same patient fall into different folds.

    Args:
        filename: image file name (no directory).
        dataset: 'ddsm' or 'inbreast'; any other value returns ``filename``.

    Returns:
        The patient ID string, or ``filename`` itself when it does not match
        the expected layout.
    """
    parts = filename.split('_')

    if dataset == 'ddsm':
        if len(parts) >= 2:
            return f"{parts[0]}_{parts[1]}"
    elif dataset == 'inbreast':
        # NOTE(review): assumes the standard INbreast naming
        # <image id>_<patient id>_MG_<side>_<view>_ANON...; verify for renamed files.
        if len(parts) >= 2:
            return parts[1]
        return parts[0]

    return filename

# Test
print("Testing patient ID extraction:")
print(f"DDSM: P_00041_LEFT_CC_1.png -> {extract_patient_id('P_00041_LEFT_CC_1.png', 'ddsm')}")
print(f"INbreast: 20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png -> {extract_patient_id('20586908_6c613a14b80a8591_MG_R_CC_ANON_lesion1_ROI.png', 'inbreast')}")

In [None]:
# %%
# Focal Loss implementation
class FocalLoss(nn.Module):
    """
    Focal loss (Lin et al., 2017) with optional label smoothing.

    Down-weights well-classified samples by ``(1 - p_t) ** gamma`` so training
    concentrates on hard examples and on the class imbalance.

    Args:
        alpha: weight applied to samples whose label is 1; samples with label 0
            get ``1 - alpha``. ``None`` disables alpha weighting. This binary
            form assumes integer labels in {0, 1}.
        gamma: focusing parameter; ``0`` reduces the loss to (smoothed)
            cross-entropy.
        label_smoothing: probability mass spread uniformly over all classes.
    """
    def __init__(self, alpha=0.25, gamma=2.0, label_smoothing=0.1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """
        Args:
            inputs: raw logits with classes on the last dimension, e.g. (N, C).
            targets: integer class indices, e.g. (N,).

        Returns:
            Scalar loss averaged over the batch.
        """
        num_classes = inputs.size(-1)

        # log_softmax is exact even for very negative logits. log(softmax + eps)
        # would cap the loss at about -log(eps) and zero its gradient for
        # confidently wrong predictions.
        log_p = F.log_softmax(inputs, dim=-1)
        p = log_p.exp()

        # (Optionally smoothed) one-hot target distribution
        targets_one_hot = F.one_hot(targets, num_classes=num_classes).to(log_p.dtype)
        if self.label_smoothing > 0:
            targets_one_hot = targets_one_hot * (1 - self.label_smoothing) + \
                             self.label_smoothing / num_classes

        ce_loss = -(targets_one_hot * log_p).sum(dim=-1)

        # Focal modulation: probability assigned to the (smoothed) target
        p_t = (targets_one_hot * p).sum(dim=-1)
        focal_loss = (1 - p_t) ** self.gamma * ce_loss

        # Alpha weighting: label 1 -> alpha, label 0 -> 1 - alpha
        if self.alpha is not None:
            t = targets.to(log_p.dtype)
            alpha_t = t * self.alpha + (1 - t) * (1 - self.alpha)
            focal_loss = alpha_t * focal_loss

        return focal_loss.mean()

print("✓ Focal Loss implemented")

In [None]:
# %%
# Mixup implementation
def mixup_data(x, y, alpha=0.2):
    """
    Mixup: blend every sample with a randomly chosen partner from the same batch.

    Returns:
        (mixed_x, y_a, y_b, lam) where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y`` and ``y_b`` is ``y[perm]``. ``lam`` is drawn from
        Beta(alpha, alpha) when ``alpha > 0``; otherwise it is 1 (no mixing).
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss of a mixed batch: lam-weighted blend of the losses against both label sets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

print("✓ Mixup implemented")

In [None]:
# %%
class MammographyDataset(Dataset):
    """
    Path-based dataset of mammogram images.

    Each image is read as single-channel grayscale and replicated into three
    identical RGB channels, matching the 3-channel input of the backbone.

    Args:
        image_paths: indexable sequence of image file paths.
        labels: indexable sequence of class labels aligned with ``image_paths``.
        transform: optional callable applied to the PIL RGB image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths, self.labels, self.transform = image_paths, labels, transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Grayscale first, then RGB: the three channels are copies of the gray one.
        sample = Image.open(self.image_paths[idx]).convert('L').convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]

print("✓ Dataset class defined (with grayscale → RGB conversion)")

In [None]:
# %%
def get_transforms(train=True):
    """
    Build the torchvision preprocessing pipeline.

    No resizing is done: the source folders already hold 512x512 images. Both
    pipelines end with ToTensor + ImageNet normalization. The training
    pipeline also adds:
      * horizontal and vertical flips;
      * an affine jitter without scaling (rotation, translation, shear);
      * brightness/contrast jitter;
      * RandomErasing after normalization.

    Args:
        train: True for the augmented training pipeline, False for evaluation.

    Returns:
        A ``transforms.Compose`` instance.
    """
    to_tensor_and_normalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std']),
    ]

    if not train:
        return transforms.Compose(to_tensor_and_normalize)

    augmentations = [
        transforms.RandomHorizontalFlip(p=CONFIG['horizontal_flip']),
        transforms.RandomVerticalFlip(p=CONFIG['vertical_flip']),
        # Deliberately no `scale` argument: lesion size is left unchanged.
        transforms.RandomAffine(
            degrees=CONFIG['rotation_degrees'],
            translate=(CONFIG['translate'],) * 2,
            shear=CONFIG['shear'],
        ),
        transforms.ColorJitter(brightness=CONFIG['brightness'], contrast=CONFIG['contrast']),
    ]
    # RandomErasing operates on tensors, so it goes after ToTensor/Normalize.
    erasing = transforms.RandomErasing(
        p=CONFIG['random_erasing_p'],
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
    )
    return transforms.Compose(augmentations + to_tensor_and_normalize + [erasing])

print("✓ Data augmentation transforms defined (IMPROVED - NO SCALE)")

In [None]:
# %%
def load_dataset():
    """
    Collect image paths, labels and patient IDs from the four source folders.

    Only ``.png`` files are used. Labels: 0 = benign, 1 = malignant. Rows are
    ordered DDSM benign, DDSM malignant, INbreast benign, INbreast malignant,
    with file names sorted within each folder.

    Returns:
        (image_paths, labels, patient_ids) as aligned NumPy arrays.
    """
    print("\n📂 Loading dataset...")

    # (CONFIG key of the folder, class label, patient-ID naming scheme)
    sources = (
        ('ddsm_benign_path', 0, 'ddsm'),
        ('ddsm_malign_path', 1, 'ddsm'),
        ('inbreast_benign_path', 0, 'inbreast'),
        ('inbreast_malign_path', 1, 'inbreast'),
    )

    records = []
    for folder_key, class_label, scheme in sources:
        folder = CONFIG[folder_key]
        for fname in sorted(os.listdir(folder)):
            if not fname.endswith('.png'):
                continue
            records.append((os.path.join(folder, fname), class_label, extract_patient_id(fname, scheme)))

    image_paths = np.array([path for path, _, _ in records])
    labels = np.array([lab for _, lab, _ in records])
    patient_ids = np.array([pid for _, _, pid in records])

    n_images = len(labels)
    n_benign = np.sum(labels == 0)
    n_malignant = np.sum(labels == 1)

    print(f"\n✓ Dataset loaded:")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Benign (0): {n_benign} ({n_benign/n_images*100:.1f}%)")
    print(f"  Malignant (1): {n_malignant} ({n_malignant/n_images*100:.1f}%)")
    print(f"  Unique patients: {len(np.unique(patient_ids))}")

    return image_paths, labels, patient_ids

# Load dataset
image_paths, labels, patient_ids = load_dataset()

In [None]:
# %%
# Calculate class weights
def calculate_class_weights(labels):
    """
    Inverse-frequency class weights: n_samples / (n_classes * count_per_class).

    Args:
        labels: 1-D array of non-negative integer class ids.

    Returns:
        float32 tensor with one weight per class, on the global ``device``.
    """
    counts = np.bincount(labels)
    n_classes = len(counts)
    weights = torch.FloatTensor(len(labels) / (n_classes * counts)).to(device)

    print(f"\n⚖️ Class weights calculated:")
    print(f"  Benign (0): {weights[0]:.4f}")
    print(f"  Malignant (1): {weights[1]:.4f}")

    return weights

class_weights = calculate_class_weights(labels) if CONFIG['use_class_weights'] else None

In [None]:
# %%
class CvTClassifier(nn.Module):
    """
    Mammography classifier on top of a Hugging Face CvT backbone.

    The backbone's ``last_hidden_state`` feature map is global-average-pooled
    and passed to an MLP head:
    Linear -> LayerNorm -> GELU -> Dropout(0.5) -> Linear.

    Args:
        model_name: Hugging Face checkpoint id, e.g. 'microsoft/cvt-13'.
        num_classes: number of output logits.
        pretrained: load pretrained backbone weights; otherwise build a
            randomly initialised backbone from the checkpoint's config.
    """

    def __init__(self, model_name='microsoft/cvt-13', num_classes=2, pretrained=True):
        super().__init__()

        if not pretrained:
            self.cvt = CvtModel(CvtConfig.from_pretrained(model_name))
            print(f"✓ Initialized CvT from config (no pretrained weights)")
        else:
            self.cvt = CvtModel.from_pretrained(model_name)
            print(f"✓ Loaded pretrained weights from {model_name}")

        # Width of the final CvT stage. The head assumes last_hidden_state is
        # (B, C, H, W) with C == embed_dim[-1].
        self.feature_dim = self.cvt.config.embed_dim[-1]

        hidden = 512
        # Layer order defines the state_dict keys (classifier.0 ... classifier.6);
        # keep it stable so saved checkpoints still load.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.feature_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes),
        )

        self._init_classifier()

    def _init_classifier(self):
        """Truncated-normal (std 0.02) weights and zero biases for every Linear in the head."""
        linear_layers = (m for m in self.classifier.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.trunc_normal_(layer.weight, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.classifier(self.cvt(x).last_hidden_state)

print("✓ CvTClassifier defined (IMPROVED with higher dropout)")

In [None]:
# %%
class EarlyStopping:
    """
    Signal a stop when a monitored score stops improving (used with val F1).

    Args:
        patience: consecutive non-improving calls tolerated before
            ``early_stop`` becomes True.
        min_delta: minimum change over ``best_score`` counted as improvement.
        mode: 'min' for scores that should decrease; any other value is
            treated as 'max'.

    Attributes:
        best_score, best_epoch: best value seen so far and the epoch passed with it.
        counter: consecutive calls without improvement.
        early_stop: becomes True once ``counter`` reaches ``patience``; never reset.
    """

    def __init__(self, patience=25, min_delta=1e-4, mode='max'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None
        self.best_epoch = 0
        self.counter = 0
        self.early_stop = False

    def __call__(self, score, epoch):
        first_call = self.best_score is None
        if first_call:
            improved = True
        elif self.mode == 'min':
            improved = score < self.best_score - self.min_delta
        else:
            improved = score > self.best_score + self.min_delta

        if not improved:
            self.counter += 1
        else:
            # The very first score only seeds the baseline; it leaves the counter alone.
            if not first_call:
                self.counter = 0
            self.best_score = score
            self.best_epoch = epoch

        if self.counter >= self.patience:
            self.early_stop = True

print("✓ EarlyStopping defined (monitoring F1-Score)")

In [None]:
# %%
def calculate_metrics(y_true, y_pred):
    """
    Compute binary classification metrics with class 1 (malignant) as positive.

    Args:
        y_true: sequence of true labels in {0, 1}.
        y_pred: sequence of predicted labels in {0, 1}.

    Returns:
        dict with accuracy, precision, recall, f1, specificity, the 2x2
        confusion matrix and its tn/fp/fn/tp entries. Precision, recall and
        F1 are 0 when undefined (zero_division=0), and so is specificity
        when there are no true negatives or false positives.
    """
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='binary', zero_division=0)

    # Fix the label set so the matrix is always 2x2. Without it, a split where
    # only one class occurs in both y_true and y_pred yields a 1x1 matrix and
    # the 4-way unpack below raises ValueError.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Specificity = true-negative rate
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'confusion_matrix': cm,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'tp': tp
    }

print("✓ calculate_metrics function defined")

In [None]:
# %%
def train_epoch(model, loader, criterion, optimizer, device, scaler, use_mixup=False, 
                accumulation_steps=1, swa_model=None, use_swa=False):
    """
    Train ``model`` for one epoch with optional mixup, AMP and gradient accumulation.

    Args:
        model, loader, criterion, optimizer, device: usual training objects.
        scaler: GradScaler used for (optionally mixed-precision) backward/step.
        use_mixup: if True, each batch is mixed with probability CONFIG['mixup_prob'].
        accumulation_steps: number of batches whose gradients are accumulated
            per optimizer step. A trailing partial window at the end of the
            epoch is stepped too.
        swa_model, use_swa: accepted for call-site compatibility; not used here.

    Returns:
        (epoch_loss, epoch_acc): mean per-sample training loss over the
        dataset, and accuracy measured only on batches that were NOT mixed
        (mixed batches have no single ground-truth label). ``epoch_acc`` is 0
        when every batch was mixed.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    accumulation_steps = max(1, int(accumulation_steps))
    num_batches = len(loader)

    optimizer.zero_grad()

    pbar = tqdm(enumerate(loader), total=num_batches, desc="Training")

    for batch_idx, (images, labels) in pbar:
        images = images.to(device)
        labels = labels.to(device)

        # Decide once per batch; the accuracy bookkeeping below reuses this decision.
        apply_mixup = use_mixup and np.random.rand() < CONFIG['mixup_prob']

        if apply_mixup:
            images, labels_a, labels_b, lam = mixup_data(images, labels, CONFIG['mixup_alpha'])
            with autocast(enabled=CONFIG['use_amp']):
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
        else:
            with autocast(enabled=CONFIG['use_amp']):
                outputs = model(images)
                loss = criterion(outputs, labels)

        batch_loss = loss.item()

        # Accumulation window this batch belongs to. The last window may be
        # shorter than accumulation_steps when len(loader) is not a multiple.
        window_start = batch_idx - batch_idx % accumulation_steps
        window_size = min(accumulation_steps, num_batches - window_start)

        # Normalize so the accumulated gradient is the mean over the window.
        scaler.scale(loss / window_size).backward()

        # Step at the end of every window, including the final partial one, so
        # its gradients are not discarded by the next epoch's zero_grad().
        if batch_idx + 1 == window_start + window_size:
            scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['gradient_clip_norm'])

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Statistics
        running_loss += batch_loss * images.size(0)

        if not apply_mixup:
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        pbar.set_postfix({
            'loss': f"{batch_loss:.4f}",
            'acc': f"{100. * correct / total if total > 0 else 0:.2f}%"
        })

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = correct / total if total > 0 else 0

    return epoch_loss, epoch_acc

print("✓ train_epoch function defined (with Mixup and Gradient Accumulation)")

In [None]:
# %%
def validate_epoch(model, loader, criterion, device, use_tta=False):
    """
    Evaluate ``model`` on ``loader`` without gradients, optionally with TTA.

    TTA runs only when both ``use_tta`` and ``CONFIG['use_tta']`` are true.
    It averages the logits of the first ``CONFIG['tta_augmentations']`` views
    from a fixed list (capped at its 4 distinct entries):
      1. identity
      2. horizontal flip
      3. vertical flip
      4. horizontal + vertical flip

    Views are deterministic, so repeated evaluations give identical results,
    and no forward pass is spent on an exact duplicate of another view.

    Returns:
        (epoch_loss, epoch_acc, all_preds, all_targets): mean per-sample
        loss, accuracy, and the per-sample predicted / true labels.
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_targets = []

    tta_enabled = use_tta and CONFIG['use_tta']
    if tta_enabled:
        # Batches are N x C x H x W (ToTensor + DataLoader): dim 3 is width
        # (horizontal flip), dim 2 is height (vertical flip).
        all_views = (
            lambda t: t,
            lambda t: torch.flip(t, dims=[3]),
            lambda t: torch.flip(t, dims=[2]),
            lambda t: torch.flip(t, dims=[2, 3]),
        )
        n_views = max(1, min(int(CONFIG['tta_augmentations']), len(all_views)))
        tta_views = all_views[:n_views]

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation")

        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            if tta_enabled:
                outputs = torch.stack([model(view(images)) for view in tta_views]).mean(dim=0)
            else:
                outputs = model(images)

            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)

            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = accuracy_score(all_targets, all_preds)

    return epoch_loss, epoch_acc, all_preds, all_targets

print("✓ validate_epoch function defined (with TTA support)")

In [None]:
# %%
def train_with_cross_validation():
    """
    Train and evaluate CvTClassifier with patient-grouped stratified k-fold CV.

    Folds are built over unique patient IDs, with each patient stratified by
    the label of its first image. Every image of a patient therefore lands in
    the same fold, and no patient appears in both train and validation.

    Per fold:
      * AdamW with a LinearLR warmup (``warmup_epochs``) followed by
        CosineAnnealingWarmRestarts; during SWA epochs SWALR drives the LR.
      * The weights with the best validation F1 are saved to
        ``<output_dir>/<model_name>_fold<k>.pth``. Early stopping on F1
        starts after warmup.
      * Final evaluation (with TTA) uses the SWA-averaged model when SWA
        collected at least one snapshot; its weights are also saved to
        ``..._fold<k>_swa.pth``. Otherwise the best checkpoint is used.

    Reads module-level globals: ``CONFIG``, ``image_paths``, ``labels``,
    ``patient_ids``, ``class_weights`` and ``device``.

    Returns:
        (fold_metrics, fold_histories, fold_cms): per-fold metric dicts from
        ``calculate_metrics``, per-epoch history dicts, and confusion matrices.
    """
    # Patient-based stratified k-fold: split unique patients, not images
    unique_patients = np.unique(patient_ids)
    patient_labels = np.array([labels[patient_ids == p][0] for p in unique_patients])

    skf = StratifiedKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])

    fold_metrics = []
    fold_histories = []
    fold_cms = []

    for fold, (train_patient_idx, val_patient_idx) in enumerate(skf.split(unique_patients, patient_labels)):
        print(f"\n{'='*80}")
        print(f"FOLD {fold + 1}/{CONFIG['num_folds']}")
        print(f"{'='*80}")

        # Map the patient-level split back to image indices
        train_patients = unique_patients[train_patient_idx]
        val_patients = unique_patients[val_patient_idx]
        train_idx = np.isin(patient_ids, train_patients)
        val_idx = np.isin(patient_ids, val_patients)

        train_images = image_paths[train_idx]
        train_labels = labels[train_idx]
        val_images = image_paths[val_idx]
        val_labels = labels[val_idx]

        print(f"\n📊 Fold {fold + 1} split:")
        print(f"  Train: {len(train_images)} images from {len(train_patients)} patients")
        print(f"  Val: {len(val_images)} images from {len(val_patients)} patients")
        print(f"  Train - Benign: {np.sum(train_labels == 0)}, Malignant: {np.sum(train_labels == 1)}")
        print(f"  Val - Benign: {np.sum(val_labels == 0)}, Malignant: {np.sum(val_labels == 1)}")

        # Datasets and loaders
        train_dataset = MammographyDataset(train_images, train_labels, get_transforms(train=True))
        val_dataset = MammographyDataset(val_images, val_labels, get_transforms(train=False))

        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=CONFIG['num_workers'],
            pin_memory=CONFIG['pin_memory']
        )

        # Model
        model = CvTClassifier(
            model_name=CONFIG['pretrained_model'],
            num_classes=CONFIG['num_classes'],
            pretrained=CONFIG['pretrained']
        ).to(device)

        # Loss function
        if CONFIG['use_focal_loss']:
            criterion = FocalLoss(
                alpha=CONFIG['focal_alpha'],
                gamma=CONFIG['focal_gamma'],
                label_smoothing=CONFIG['label_smoothing']
            )
            print(f"\n🎯 Using Focal Loss (alpha={CONFIG['focal_alpha']}, gamma={CONFIG['focal_gamma']})")
            if CONFIG['use_class_weights']:
                # FocalLoss takes no class-weight input, so the computed weights are unused.
                print("  ⚠️ use_class_weights is ignored when Focal Loss is used")
        else:
            criterion = nn.CrossEntropyLoss(
                weight=class_weights if CONFIG['use_class_weights'] else None,
                label_smoothing=CONFIG['label_smoothing']
            )

        # Optimizer
        optimizer = AdamW(
            model.parameters(),
            lr=CONFIG['lr_max'],
            weight_decay=CONFIG['weight_decay'],
            betas=CONFIG['betas']
        )

        # Build the cosine scheduler BEFORE the warmup scheduler. Every LR
        # scheduler applies its epoch-0 LR in its constructor:
        #   * CosineAnnealingWarmRestarts sets the LR back to lr_max;
        #   * LinearLR rescales the *current* LR at each step.
        # Created in the opposite order, the cosine constructor would
        # overwrite the warmup start LR, and LinearLR would then ramp the LR
        # up to lr_max / warmup_start_factor. In this order LinearLR starts at
        # warmup_start_factor * lr_max and ends at exactly lr_max. The cosine
        # scheduler is first stepped after warmup and computes its LR from
        # lr_max, not from the current value.
        main_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=CONFIG['cosine_T0'],
            T_mult=CONFIG['cosine_T_mult'],
            eta_min=CONFIG['cosine_eta_min']
        )

        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=CONFIG['warmup_start_factor'],
            end_factor=1.0,
            total_iters=CONFIG['warmup_epochs']
        )

        # SWA
        if CONFIG['use_swa']:
            swa_model = AveragedModel(model)
            swa_scheduler = SWALR(optimizer, swa_lr=CONFIG['swa_lr'])
            print(f"📊 SWA enabled (start epoch: {CONFIG['swa_start_epoch']})")
        else:
            swa_model = None
            swa_scheduler = None

        # Mixed precision scaler
        scaler = GradScaler(enabled=CONFIG['use_amp'])

        # Early stopping (maximize F1-Score)
        early_stopping = EarlyStopping(
            patience=CONFIG['early_stopping_patience'],
            min_delta=CONFIG['min_delta'],
            mode='max'
        )

        history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'val_f1': [],
            'lr': []
        }

        # Start below any attainable F1 so the first epoch is always
        # checkpointed, even if its F1 is 0. A checkpoint then always exists
        # for the final evaluation.
        best_f1_score = -1.0
        best_epoch = 0
        model_path = os.path.join(CONFIG['output_dir'], f"{CONFIG['model_name']}_fold{fold}.pth")

        print(f"\n🚀 Starting training for fold {fold + 1}...")
        start_time = time.time()

        for epoch in range(CONFIG['num_epochs']):
            print(f"\n{'─'*80}")
            print(f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
            print(f"{'─'*80}")

            use_swa_this_epoch = CONFIG['use_swa'] and epoch >= CONFIG['swa_start_epoch']

            # Train
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, scaler,
                use_mixup=CONFIG['use_mixup'],
                accumulation_steps=CONFIG['gradient_accumulation_steps'],
                swa_model=swa_model if use_swa_this_epoch else None,
                use_swa=use_swa_this_epoch
            )

            # Validate (no TTA during training)
            val_loss, val_acc, val_preds, val_targets = validate_epoch(
                model, val_loader, criterion, device, use_tta=False
            )

            # zero_division=0: an epoch with no predicted positives scores F1 = 0 without a warning
            val_f1 = f1_score(val_targets, val_preds, average='binary', zero_division=0)

            # LR actually used during this epoch (read before any scheduler step)
            current_lr = optimizer.param_groups[0]['lr']

            if use_swa_this_epoch:
                swa_model.update_parameters(model)
                swa_scheduler.step()

            if epoch < CONFIG['warmup_epochs']:
                warmup_scheduler.step()
            elif not use_swa_this_epoch:
                main_scheduler.step()
            # During SWA epochs only swa_scheduler (stepped above) drives the LR.

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_f1'].append(val_f1)
            history['lr'].append(current_lr)

            phase = "Warmup" if epoch < CONFIG['warmup_epochs'] else ("SWA" if use_swa_this_epoch else "Main")
            print(f"\n📊 Epoch {epoch + 1} Summary ({phase}):")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
            print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}% | Val F1: {val_f1:.4f}")
            print(f"  LR: {current_lr:.2e}")

            # Save the best model (by F1). Always store `model`, i.e. the weights
            # whose F1 was just measured. The SWA average is handled after training.
            if val_f1 > best_f1_score:
                best_f1_score = val_f1
                best_epoch = epoch + 1

                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Plain Python floats keep the checkpoint free of NumPy scalar types.
                    'val_loss': float(val_loss),
                    'val_acc': float(val_acc),
                    'val_f1': float(val_f1),
                    'config': CONFIG
                }, model_path)

                print(f"  ✓ Model saved: {model_path} (F1: {val_f1:.4f})")

            # Early stopping (only after warmup)
            if epoch >= CONFIG['warmup_epochs']:
                early_stopping(val_f1, epoch + 1)
                if early_stopping.early_stop:
                    print(f"\n⏹️ Early stopping triggered at epoch {epoch + 1}")
                    print(f"  Best epoch: {best_epoch} with F1-Score: {best_f1_score:.4f}")
                    break

            torch.cuda.empty_cache()

        training_time = time.time() - start_time
        print(f"\n⏱️ Training time for fold {fold + 1}: {training_time/60:.2f} minutes")

        # Reload the best checkpoint (always plain `model` weights)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        eval_model = model

        # Use the SWA average only if it actually collected snapshots; training
        # can stop early, before swa_start_epoch is reached.
        swa_snapshots = int(swa_model.n_averaged) if swa_model is not None else 0
        if swa_snapshots > 0:
            # Recompute BatchNorm running statistics for the averaged weights.
            # This must run after averaging has finished.
            torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
            swa_path = os.path.join(CONFIG['output_dir'], f"{CONFIG['model_name']}_fold{fold}_swa.pth")
            torch.save({
                'swa_snapshots': swa_snapshots,
                # Unwrapped module weights: loadable directly into CvTClassifier
                'model_state_dict': swa_model.module.state_dict(),
                'config': CONFIG
            }, swa_path)
            print(f"📊 Using SWA model for final evaluation ({swa_snapshots} snapshots, saved: {swa_path})")
            eval_model = swa_model
        elif CONFIG['use_swa']:
            print("⚠️ SWA collected no snapshots (training stopped before swa_start_epoch); using best checkpoint")

        # Final evaluation with TTA
        print(f"\n📊 Final evaluation on validation set (with TTA)...")
        _, _, final_preds, final_targets = validate_epoch(
            eval_model, val_loader, criterion, device, use_tta=True
        )

        metrics = calculate_metrics(final_targets, final_preds)

        print(f"\n✅ Fold {fold + 1} Results:")
        print(f"  Accuracy: {metrics['accuracy']*100:.2f}%")
        print(f"  Precision: {metrics['precision']*100:.2f}%")
        print(f"  Recall: {metrics['recall']*100:.2f}%")
        print(f"  F1-Score: {metrics['f1']*100:.2f}%")
        print(f"  Specificity: {metrics['specificity']*100:.2f}%")

        fold_metrics.append(metrics)
        fold_histories.append(history)
        fold_cms.append(metrics['confusion_matrix'])

        plot_training_history(history, fold)
        plot_confusion_matrix(metrics['confusion_matrix'], fold)

        # Release every reference to this fold's model/optimizer so GPU memory
        # is actually freed (eval_model and the schedulers also hold references).
        del model, eval_model, optimizer, warmup_scheduler, main_scheduler
        del swa_model, swa_scheduler, scaler, checkpoint, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()

    return fold_metrics, fold_histories, fold_cms

print("✓ train_with_cross_validation function defined")

In [None]:
# %%
def plot_training_history(history, fold):
    """Plot training curves for a single cross-validation fold.

    Draws a 2x2 grid: loss, accuracy, validation F1-score and learning rate.
    Every panel gets a dashed red line at ``CONFIG['warmup_epochs']``. When
    ``CONFIG['use_swa']`` is set, it also gets a dashed green line at
    ``CONFIG['swa_start_epoch']``.

    The figure is saved as ``<model_name>_fold<fold>_history.png`` under
    ``CONFIG['metrics_dir']``, displayed, and then closed.

    Args:
        history: Dict of per-epoch lists. Reads 'train_loss', 'val_loss',
            'train_acc', 'val_acc', 'val_f1' and 'lr'. Accuracy/F1 entries are
            multiplied by 100 for display, so they are presumably fractions in
            [0, 1] — TODO confirm against the training loop.
        fold: Zero-based fold index (displayed as ``fold + 1``).
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    def _mark_phases(ax):
        """Draw the warmup-end and (optional) SWA-start guide lines on ``ax``."""
        ax.axvline(x=CONFIG['warmup_epochs'], color='r', linestyle='--', alpha=0.5, label='End of Warmup')
        if CONFIG['use_swa']:
            ax.axvline(x=CONFIG['swa_start_epoch'], color='g', linestyle='--', alpha=0.5, label='SWA Start')

    def _decorate(ax, ylabel, title):
        """Apply the shared x-label, y-label, title, legend and grid to ``ax``."""
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Loss
    ax = axes[0, 0]
    ax.plot(history['train_loss'], label='Train Loss', marker='o', markersize=3)
    ax.plot(history['val_loss'], label='Val Loss', marker='s', markersize=3)
    _mark_phases(ax)
    _decorate(ax, 'Loss', f'Fold {fold + 1}: Loss (CvT_TL_2)')

    # Accuracy (scaled to percent for display)
    ax = axes[0, 1]
    ax.plot([acc * 100 for acc in history['train_acc']], label='Train Acc', marker='o', markersize=3)
    ax.plot([acc * 100 for acc in history['val_acc']], label='Val Acc', marker='s', markersize=3)
    _mark_phases(ax)
    _decorate(ax, 'Accuracy (%)', f'Fold {fold + 1}: Accuracy')

    # Validation F1-score (scaled to percent for display)
    ax = axes[1, 0]
    ax.plot([f1 * 100 for f1 in history['val_f1']], label='Val F1-Score', marker='d', markersize=3, color='purple')
    _mark_phases(ax)
    _decorate(ax, 'F1-Score (%)', f'Fold {fold + 1}: F1-Score')

    # Learning rate on a log-scaled y-axis. The curve itself is unlabelled;
    # the legend shows only the phase markers.
    ax = axes[1, 1]
    ax.plot(history['lr'], marker='o', markersize=3)
    _mark_phases(ax)
    ax.set_yscale('log')
    _decorate(ax, 'Learning Rate', f'Fold {fold + 1}: Learning Rate (Cosine Annealing)')

    fig.tight_layout()

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}_history.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure: this function is called once per fold from inside a
    # single training cell, so unclosed figures would accumulate.
    plt.close(fig)

    print(f"  History plot saved: {plot_path}")

def plot_confusion_matrix(cm, fold, normalize=False):
    """Plot and save the Benign/Malignant confusion matrix of one fold.

    Args:
        cm: Confusion matrix (array-like), true labels on rows and predicted
            labels on columns; index 0 is 'Benign' and 1 is 'Malignant'
            (matching the tick labels used below).
        fold: Zero-based fold index (displayed as ``fold + 1``).
        normalize: If True, divide each row by its total so cells show the
            fraction of each true class. A row with no samples is shown as 0
            instead of NaN.

    The raw plot is saved as ``<model_name>_fold<fold>_cm.png``. The
    normalized plot gets a ``_cm_normalized`` suffix so it does not overwrite
    the raw one.
    """
    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True).astype(float)
        # Guard empty rows (class absent in this fold) instead of dividing by 0.
        cm = np.divide(
            cm,
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        fmt, cbar_label, suffix = '.2f', 'Proportion', '_cm_normalized'
    else:
        # The 'd' format only accepts integers; fall back for float counts.
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.1f'
        cbar_label, suffix = 'Count', '_cm'

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': cbar_label},
        ax=ax
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title(f'Fold {fold + 1}: Confusion Matrix')

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_fold{fold}{suffix}.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure; this is called once per fold inside one cell.
    plt.close(fig)

    print(f"  Confusion matrix saved: {plot_path}")

def save_metrics_summary(fold_metrics, output_dir=None, model_name=None):
    """Write per-fold metrics plus a 'Mean ± Std' row to CSV and return them.

    Args:
        fold_metrics: List of per-fold dicts. Each must contain the keys
            'accuracy', 'precision', 'recall', 'f1' and 'specificity' with
            numeric values (Python or NumPy scalars).
        output_dir: Directory for the CSV. Defaults to ``CONFIG['metrics_dir']``.
        model_name: File-name prefix. Defaults to ``CONFIG['model_name']``.

    Returns:
        pandas.DataFrame with columns Fold, Accuracy, Precision, Recall,
        F1-Score and Specificity. It has one row per fold (Fold numbered from
        1). When at least one fold is given, a final row holds
        "mean ± std" strings; std is the population std (``np.std``, ddof=0).
    """
    # CSV column label -> key in each per-fold metrics dict (order matters).
    metric_columns = {
        'Accuracy': 'accuracy',
        'Precision': 'precision',
        'Recall': 'recall',
        'F1-Score': 'f1',
        'Specificity': 'specificity',
    }

    metrics_dict = {'Fold': [fold + 1 for fold in range(len(fold_metrics))]}
    for column, key in metric_columns.items():
        metrics_dict[column] = [metrics[key] for metrics in fold_metrics]

    if fold_metrics:
        metrics_dict['Fold'].append('Mean ± Std')
        for column in metric_columns:
            # Compute from the numeric per-fold values before appending the
            # summary string. Casting to float accepts any numeric scalar type
            # (e.g. np.float32), which an isinstance(float) filter would drop.
            values = np.asarray(metrics_dict[column], dtype=float)
            metrics_dict[column].append(f"{values.mean():.4f} ± {values.std():.4f}")

    df = pd.DataFrame(metrics_dict)

    if output_dir is None:
        output_dir = CONFIG['metrics_dir']
    if model_name is None:
        model_name = CONFIG['model_name']

    csv_path = os.path.join(output_dir, f"{model_name}_metrics.csv")
    df.to_csv(csv_path, index=False)

    print(f"\n✅ Metrics saved to: {csv_path}")
    print(f"\n{df.to_string(index=False)}")

    return df

def plot_average_confusion_matrix(fold_cms):
    """Plot and save the element-wise mean of the per-fold confusion matrices.

    Args:
        fold_cms: Sequence of same-shaped confusion matrices, one per fold,
            true labels on rows and predictions on columns
            (index 0 = 'Benign', 1 = 'Malignant').

    Raises:
        ValueError: If ``fold_cms`` is empty (there is nothing to average).
    """
    if len(fold_cms) == 0:
        raise ValueError("fold_cms is empty; there is nothing to average")

    n_folds = len(fold_cms)
    avg_cm = np.mean(fold_cms, axis=0)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        avg_cm,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': 'Average Count'},
        ax=ax
    )
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    # Take the fold count from the data instead of hard-coding "5".
    ax.set_title(f'Average Confusion Matrix ({n_folds}-Fold CV - CvT_TL_2)')

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_avg_cm.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"\n✅ Average confusion matrix saved: {plot_path}")

def plot_mean_confusion_matrix(fold_cms):
    """Plot the mean confusion matrix as absolute counts and row percentages.

    Averages the per-fold matrices element-wise, then draws two panels:
    - the mean counts;
    - the same matrix normalized by row (the share of each true class).

    Saves the figure and prints TN/FP/FN/TP plus sensitivity and specificity
    taken from the row-normalized matrix.

    Args:
        fold_cms: Sequence of same-shaped 2x2 confusion matrices, one per
            fold, true labels on rows and predictions on columns
            (index 0 = 'Benign', 1 = 'Malignant').

    Raises:
        ValueError: If ``fold_cms`` is empty (there is nothing to average).
    """
    if len(fold_cms) == 0:
        raise ValueError("fold_cms is empty; there is nothing to average")

    mean_cm = np.mean(fold_cms, axis=0)

    # Normalize each row to percentages. A row with no samples would be a
    # 0/0 division, so leave it at 0 instead of NaN.
    row_totals = mean_cm.sum(axis=1, keepdims=True)
    mean_cm_normalized = np.divide(
        mean_cm,
        row_totals,
        out=np.zeros(mean_cm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Absolute values
    sns.heatmap(
        mean_cm,
        annot=True,
        fmt='.1f',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': 'Mean Count'},
        ax=axes[0]
    )
    axes[0].set_ylabel('True Label')
    axes[0].set_xlabel('Predicted Label')
    axes[0].set_title('Mean Confusion Matrix - Absolute Values')

    # Percentages
    sns.heatmap(
        mean_cm_normalized,
        annot=True,
        fmt='.2%',
        cmap='Blues',
        xticklabels=['Benign', 'Malignant'],
        yticklabels=['Benign', 'Malignant'],
        cbar_kws={'label': 'Percentage'},
        ax=axes[1],
        vmin=0,
        vmax=1
    )
    axes[1].set_ylabel('True Label')
    axes[1].set_xlabel('Predicted Label')
    axes[1].set_title('Mean Confusion Matrix - Normalized by Row')

    fig.tight_layout()

    plot_path = os.path.join(
        CONFIG['metrics_dir'],
        f"{CONFIG['model_name']}_mean_confusion_matrix.png"
    )
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    print(f"\n✅ Mean confusion matrix saved: {plot_path}")

    # Print statistics (row 0 = true Benign, row 1 = true Malignant)
    print("\n📊 Mean Confusion Matrix Statistics:")
    print(f"  True Negatives (TN): {mean_cm[0, 0]:.1f}")
    print(f"  False Positives (FP): {mean_cm[0, 1]:.1f}")
    print(f"  False Negatives (FN): {mean_cm[1, 0]:.1f}")
    print(f"  True Positives (TP): {mean_cm[1, 1]:.1f}")
    print(f"\n  Sensitivity (Recall): {mean_cm_normalized[1, 1]:.2%}")
    print(f"  Specificity: {mean_cm_normalized[0, 0]:.2%}")

print("✓ Visualization functions defined")

In [None]:
# %%
# Run the full cross-validation training, then summarize and plot the results.
divider = "=" * 80
new_features = (
    "Focal Loss",
    "Class Weights",
    "Mixup Augmentation",
    "Cosine Annealing Scheduler",
    "Stochastic Weight Averaging (SWA)",
    "Test-Time Augmentation (TTA)",
    "Gradient Accumulation",
    "Improved Augmentation (NO SCALE)",
    "Higher Dropout (0.5)",
    "F1-Score based Early Stopping",
)

print("\n" + divider)
print("STARTING 5-FOLD CROSS VALIDATION TRAINING (CvT_TL_2 - ADVANCED)")
print(divider)
print("\n🚀 New Features:")
for feature in new_features:
    print(f"  ✓ {feature}")

# Time the whole cross-validation run.
start_time = time.time()
fold_metrics, fold_histories, fold_cms = train_with_cross_validation()
total_time = time.time() - start_time

for banner_line in ("\n" + divider, "TRAINING COMPLETED", divider):
    print(banner_line)
print(f"\n⏱️ Total training time: {total_time / 3600:.2f} hours")

# Per-fold metrics plus mean ± std, written to CSV.
metrics_df = save_metrics_summary(fold_metrics)

# Aggregate confusion matrices: single averaged heatmap, then the
# absolute + row-normalized side-by-side view (new).
plot_average_confusion_matrix(fold_cms)
plot_mean_confusion_matrix(fold_cms)

for banner_line in ("\n" + divider, "ALL DONE!", divider):
    print(banner_line)
print("\n📁 Results saved to:")
print(f"  Models: {CONFIG['output_dir']}")
print(f"  Metrics: {CONFIG['metrics_dir']}")