In [None]:
# %%
# ============================================================================
# ViT_TL_2 - VERSIÓN MEJORADA CON OPTIMIZACIONES AVANZADAS
# ============================================================================

import os
import sys
import json
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Torchvision
from torchvision import transforms
from PIL import Image

# Transformers (Hugging Face)
from transformers import ViTModel, ViTConfig

# Sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    confusion_matrix, classification_report
)

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    # Report the name and total memory (in GB) of GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Memory / throughput tweaks: release cached blocks, let cuDNN autotune
    # kernels, and allow TF32 for matmul and cuDNN operations.
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Request expandable segments from the CUDA caching allocator.
    # NOTE(review): this is set after torch was imported and after CUDA was
    # already queried above; the allocator may have read its configuration by
    # then, in which case this assignment has no effect — confirm, or export the
    # variable before the kernel starts.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    print("\n✓ Optimizaciones de memoria aplicadas")


In [None]:
# %%
# ============================================================================
# 2. CONFIGURACIÓN DE PATHS Y PARÁMETROS MEJORADOS
# ============================================================================

# Base paths. The root defaults to the original workstation location, but can
# be redirected without editing the notebook by exporting VIT_BASE_DIR before
# starting the kernel (keeps the notebook runnable on other machines).
BASE_DIR = Path(os.environ.get("VIT_BASE_DIR", "/home/merivadeneira"))
MASAS_DIR = BASE_DIR / "Masas"               # input images (read only)
OUTPUT_DIR = BASE_DIR / "Outputs" / "ViT"    # checkpoints and saved config
METRICS_DIR = BASE_DIR / "Metrics" / "ViT"   # CSV metrics and figures

# Create the output directories up front so later saves cannot fail on a
# missing folder.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
METRICS_DIR.mkdir(parents=True, exist_ok=True)

# Improved model/training configuration. Every tunable of the experiment lives
# in this single dictionary; "CHANGED"/"NEW" notes are relative to the previous
# version of the experiment.
CONFIG = {
    # --- Data ---------------------------------------------------------------
    'input_channels': 3,
    'input_size': 224,
    'num_classes': 2,

    # --- Transfer learning --------------------------------------------------
    'pretrained_model': 'google/vit-base-patch16-224',
    'freeze_backbone': True,
    'unfreeze_after_epoch': 5,       # CHANGED: was 10
    'gradual_unfreeze': True,        # NEW: unfreeze only part of the encoder
    'unfreeze_layers_per_epoch': 2,  # NEW

    # --- Training -----------------------------------------------------------
    'batch_size': 16,                    # CHANGED: was 32
    'gradient_accumulation_steps': 4,    # NEW: effective batch = 64
    'num_epochs': 100,
    'k_folds': 5,
    'early_stopping_patience': 35,       # CHANGED: was 25
    'early_stopping_min_delta': 0.001,   # NEW

    # --- Optimizer ----------------------------------------------------------
    'optimizer': 'AdamW',
    'base_lr': 5e-5,       # CHANGED: was 1e-4
    'backbone_lr': 5e-6,   # CHANGED: was 1e-5
    'weight_decay': 0.05,  # CHANGED: was 0.01
    'warmup_epochs': 5,    # NEW
    'min_lr': 1e-7,        # NEW

    # --- Scheduler ----------------------------------------------------------
    'lr_scheduler': 'CosineAnnealingWarmRestarts',  # CHANGED
    'scheduler_T_0': 10,        # NEW
    'scheduler_T_mult': 2,      # NEW
    'scheduler_eta_min': 1e-7,  # NEW

    # --- Regularization -----------------------------------------------------
    'dropout_rate': 0.3,     # NEW
    'label_smoothing': 0.1,  # NEW
    'mixup_alpha': 0.2,      # NEW

    # --- Loss function ------------------------------------------------------
    'use_focal_loss': True,     # NEW
    'focal_loss_alpha': 0.25,   # NEW
    'focal_loss_gamma': 2.0,    # NEW
    'use_class_weights': True,  # NEW

    # --- Augmentation (more aggressive) -------------------------------------
    'rotation_degrees': 20,       # CHANGED: was 15
    'translate': 0.15,            # CHANGED: was 0.1
    'horizontal_flip_prob': 0.5,  # NEW
    'vertical_flip_prob': 0.3,    # NEW
    'gaussian_blur_prob': 0.3,    # NEW
    'gaussian_blur_kernel': 5,
    'gaussian_blur_sigma': (0.1, 1.0),  # CHANGED: was (0.1, 0.5)

    # --- Test-time augmentation ---------------------------------------------
    'use_tta': True,      # NEW
    'tta_transforms': 5,  # NEW

    # --- Normalization ------------------------------------------------------
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],

    # --- Model name (used as prefix of every generated file) ----------------
    'model_name': 'ViT_TL_2'
}

# Echo the configuration as an aligned "key: value" table between two rulers.
print("\n" + "=" * 70, "CONFIGURACIÓN MEJORADA - ViT_TL_2", "=" * 70, sep="\n")
for key, value in CONFIG.items():
    print("{:35s}: {}".format(key, value))
print("=" * 70 + "\n")

In [None]:
# %%
# ============================================================================
# 3. UTILIDADES PARA PREVENIR DATA LEAKAGE
# ============================================================================

def extract_patient_id(filename, database):
    """Derive a database-prefixed patient ID from an image file name.

    The file name is split on underscores:
      * 'DDSM'     -> "DDSM_<second token>" (needs at least two tokens)
      * 'INbreast' -> "INbreast_<first token>"

    Any other database, or a DDSM name with fewer than two tokens, returns
    the file name unchanged.
    """
    tokens = filename.split('_')

    if database == 'DDSM' and len(tokens) >= 2:
        return f"DDSM_{tokens[1]}"

    # str.split always yields at least one token, so INbreast never falls through.
    if database == 'INbreast':
        return f"INbreast_{tokens[0]}"

    return filename

def load_image_paths_with_patient_ids():
    """Collect every PNG under MASAS_DIR with its patient ID, label and database.

    Scans ``MASAS_DIR/<database>/<class>/Resized_512/*.png`` for the databases
    'DDSM' and 'INbreast' and the classes 'Benignas' (label 0) and 'Malignas'
    (label 1), then prints a short summary.

    Returns:
        pandas.DataFrame with columns 'image_path', 'patient_id', 'label' and
        'database', one row per image. The columns are present even when no
        image is found, so the summary below never fails on an empty dataset.
    """
    columns = ['image_path', 'patient_id', 'label', 'database']
    data_list = []

    # Same traversal order as before: DDSM first, then INbreast; within each,
    # benign before malignant.
    for database in ('DDSM', 'INbreast'):
        for label_name, label_value in [('Benignas', 0), ('Malignas', 1)]:
            class_dir = MASAS_DIR / database / label_name / "Resized_512"

            if not class_dir.exists():
                # Make a missing folder visible instead of silently training
                # on a partial dataset.
                print(f"⚠ Directorio no encontrado: {class_dir}")
                continue

            # glob() order depends on the filesystem; sorting keeps the row
            # order (and therefore the folds) reproducible across machines.
            for img_file in sorted(class_dir.glob("*.png")):
                data_list.append({
                    'image_path': str(img_file),
                    'patient_id': extract_patient_id(img_file.name, database),
                    'label': label_value,
                    'database': database
                })

    # Explicit columns keep the frame well-formed when data_list is empty.
    data_df = pd.DataFrame(data_list, columns=columns)

    print(f"\nTotal de imágenes cargadas: {len(data_df)}")
    print(f"  - DDSM: {len(data_df[data_df['database']=='DDSM'])}")
    print(f"  - INbreast: {len(data_df[data_df['database']=='INbreast'])}")
    print(f"\nDistribución de clases:")
    print(f"  - Benignas (0): {len(data_df[data_df['label']==0])}")
    print(f"  - Malignas (1): {len(data_df[data_df['label']==1])}")
    print(f"\nTotal de pacientes únicos: {data_df['patient_id'].nunique()}")

    return data_df

def create_patient_level_splits(data_df, k_folds=5, random_state=42):
    """Build stratified K-fold splits in which no patient spans train and val.

    Each patient gets one label (the most frequent label among their images),
    patients are split with StratifiedKFold, and every image follows its
    patient into the train or validation side.

    Args:
        data_df: DataFrame with at least 'patient_id' and 'label' columns.
        k_folds: number of folds.
        random_state: seed for the shuffled stratified split.

    Returns:
        List of (train_indices, val_indices) tuples, one per fold. Indices are
        *positional* row numbers into data_df, matching how MammographyDataset
        consumes them with ``.iloc``. The previous version returned index
        labels, which are only equivalent for a default RangeIndex.
    """
    patient_labels = data_df.groupby('patient_id')['label'].agg(
        lambda x: x.mode()[0] if len(x.mode()) > 0 else x.iloc[0]
    ).reset_index()

    patient_labels.columns = ['patient_id', 'label']

    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_state)

    image_patients = data_df['patient_id']

    fold_splits = []
    for fold_idx, (train_patient_idx, val_patient_idx) in enumerate(
        skf.split(patient_labels['patient_id'], patient_labels['label'])
    ):
        train_patients = patient_labels.iloc[train_patient_idx]['patient_id'].values
        val_patients = patient_labels.iloc[val_patient_idx]['patient_id'].values

        # Row positions (not index labels) of the images owned by each side.
        train_indices = np.flatnonzero(image_patients.isin(train_patients).to_numpy()).tolist()
        val_indices = np.flatnonzero(image_patients.isin(val_patients).to_numpy()).tolist()

        fold_splits.append((train_indices, val_indices))

        print(f"\nFold {fold_idx + 1}:")
        print(f"  Train: {len(train_indices)} images from {len(train_patients)} patients")
        print(f"  Val:   {len(val_indices)} images from {len(val_patients)} patients")

    return fold_splits

In [None]:
# %%
# ============================================================================
# 4. DATASET CON AUGMENTACIONES MEJORADAS
# ============================================================================

class MammographyDataset(Dataset):
    """Image dataset backed by a subset of rows of a DataFrame.

    The rows selected by `indices` (positional) are copied into `self.data`
    with a fresh 0..n-1 index; each item is an (image, label) pair read from
    the 'image_path' and 'label' columns.
    """

    def __init__(self, data_df, indices, transform=None):
        subset = data_df.iloc[indices]
        # Re-index so that dataset position == row label.
        self.data = subset.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        label = self.data.loc[idx, 'label']
        img_path = self.data.loc[idx, 'image_path']

        # Load the file and force three channels.
        image = Image.open(img_path).convert('RGB')

        # Apply the optional transform pipeline.
        if self.transform:
            image = self.transform(image)

        return image, label

def get_transforms(config, is_training=True):
    """Build the torchvision preprocessing pipeline.

    Validation: resize -> tensor -> normalize.
    Training:   resize -> random affine (rotation + translation) -> random
                horizontal/vertical flips -> optional random Gaussian blur ->
                tensor -> normalize.

    Args:
        config: dict providing 'input_size', 'mean', 'std' and, for training,
            the rotation / translate / flip / blur settings.
        is_training: whether to include the random augmentations.

    Returns:
        transforms.Compose with the selected steps.
    """
    side = config['input_size']
    resize_step = transforms.Resize((side, side))

    if not is_training:
        # Deterministic evaluation pipeline.
        return transforms.Compose([
            resize_step,
            transforms.ToTensor(),
            transforms.Normalize(mean=config['mean'], std=config['std'])
        ])

    shift = config['translate']
    steps = [
        resize_step,
        transforms.RandomAffine(
            degrees=config['rotation_degrees'],
            translate=(shift, shift)
        ),
        transforms.RandomHorizontalFlip(p=config['horizontal_flip_prob']),
        transforms.RandomVerticalFlip(p=config['vertical_flip_prob']),
    ]

    # Gaussian blur is applied only with the configured probability.
    blur_prob = config['gaussian_blur_prob']
    if blur_prob > 0:
        blur = transforms.GaussianBlur(
            kernel_size=config['gaussian_blur_kernel'],
            sigma=config['gaussian_blur_sigma']
        )
        steps.append(transforms.RandomApply([blur], p=blur_prob))

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=config['mean'], std=config['std']))

    return transforms.Compose(steps)

In [None]:

# %%
# ============================================================================
# 5. FOCAL LOSS Y LABEL SMOOTHING
# ============================================================================

class FocalLoss(nn.Module):
    """Focal loss on top of cross-entropy.

    Per sample: ``alpha * (1 - p_t) ** gamma * CE``, where ``p_t = exp(-CE)``.
    `alpha` is a single scalar applied to every sample. `reduction` selects
    'mean', 'sum', or (any other value) the unreduced per-sample tensor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # Recover p_t from the per-sample cross-entropy.
        prob_true = torch.exp(-per_sample_ce)
        per_sample_focal = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        if self.reduction == 'mean':
            return per_sample_focal.mean()
        if self.reduction == 'sum':
            return per_sample_focal.sum()
        return per_sample_focal

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with label smoothing.

    The loss mixes the usual negative log-likelihood of the target class
    (weight ``1 - smoothing``) with the batch mean of the negative
    log-probabilities averaged over all classes (weight ``smoothing``).
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        log_probs = F.log_softmax(pred, dim=-1)
        num_classes = pred.size(-1)

        # Target-class term.
        nll_term = F.nll_loss(log_probs, target, reduction='mean')
        # All-classes term: summed over classes here, divided by their count below.
        all_classes_term = -log_probs.sum(dim=-1).mean()

        return self.smoothing * all_classes_term / num_classes + (1 - self.smoothing) * nll_term

In [None]:
# %%
# ============================================================================
# 6. MIXUP DATA AUGMENTATION
# ============================================================================

def mixup_data(x, y, alpha=0.2):
    """Blend each sample of a batch with a randomly chosen partner (MixUp).

    Args:
        x: batch tensor; the first dimension is the batch.
        y: labels, indexable along the same first dimension.
        alpha: Beta(alpha, alpha) parameter; with alpha <= 0 the mixing
            coefficient is fixed to 1 (inputs are returned unmixed).

    Returns:
        (mixed_x, y_a, y_b, lam): blended inputs, original labels, partner
        labels and the mixing coefficient.
    """
    # Draw the coefficient first, then the permutation (same RNG order as before).
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    partner = torch.randperm(x.size(0)).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[partner]

    return mixed_x, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for a MixUp batch: the criterion against both label sets,
    weighted by `lam` and `1 - lam` respectively."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
# %%
# ============================================================================
# 7. MODELO ViT CON DROPOUT
# ============================================================================

class ViTClassifier(nn.Module):
    """Pretrained ViT backbone with a dropout + linear classification head.

    Args (via `config` dict):
        pretrained_model: identifier passed to ``ViTModel.from_pretrained``.
        dropout_rate: dropout probability applied before the linear head.
        num_classes: number of output logits.
        freeze_backbone: when truthy, every backbone parameter starts frozen.
    """

    def __init__(self, config):
        super().__init__()

        self.vit = ViTModel.from_pretrained(config['pretrained_model'])
        hidden_size = self.vit.config.hidden_size

        # Classification head: dropout for regularization, then a linear layer.
        self.classifier = nn.Sequential(
            nn.Dropout(config['dropout_rate']),
            nn.Linear(hidden_size, config['num_classes'])
        )

        # Optionally freeze the whole backbone; the head always stays trainable.
        if config['freeze_backbone']:
            for param in self.vit.parameters():
                param.requires_grad = False

    def forward(self, pixel_values):
        """Return class logits computed from the first token of the last hidden state."""
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled_output)
        return logits

    def unfreeze_backbone(self):
        """Make every backbone parameter trainable."""
        for param in self.vit.parameters():
            param.requires_grad = True

    def gradual_unfreeze(self, num_layers=2):
        """Make the last `num_layers` encoder layers trainable.

        A non-positive `num_layers` is a no-op. Without this guard,
        ``layers[-0:]`` is the *whole* list, so asking for zero layers would
        unfreeze every encoder layer. Values larger than the number of layers
        unfreeze all of them.
        """
        if num_layers <= 0:
            return

        encoder_layers = list(self.vit.encoder.layer)

        for layer in encoder_layers[-num_layers:]:
            for param in layer.parameters():
                param.requires_grad = True


In [None]:
# %%
# ============================================================================
# 8. WARMUP LEARNING RATE SCHEDULER
# ============================================================================

class WarmupScheduler:
    """Linear learning-rate warmup that preserves per-group LR ratios.

    During the first `warmup_epochs` calls to `step()`, the group with the
    largest configured LR is ramped linearly up to `base_lr`; every other
    group is ramped to the same fraction of `base_lr` that its configured LR
    represents. Previously every group received the same LR, which erased any
    differential LR (e.g. a smaller backbone LR) for the whole warmup.
    Optimizers whose groups all share one LR behave exactly as before.

    Args:
        optimizer: object exposing `param_groups` (list of dicts with 'lr').
        warmup_epochs: number of warmup steps.
        base_lr: LR reached by the largest-LR group at the end of warmup.
    """

    def __init__(self, optimizer, warmup_epochs, base_lr):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.base_lr = base_lr
        self.current_epoch = 0

        # Capture each group's configured LR before warmup overwrites it.
        # 'initial_lr' is preferred when present, since another scheduler
        # attached earlier may already have modified 'lr'.
        configured = [
            group.get('initial_lr', group['lr'])
            for group in optimizer.param_groups
        ]
        reference = max(configured, default=0)
        if reference > 0:
            self.group_scales = [lr / reference for lr in configured]
        else:
            # No usable reference: fall back to the uniform behaviour.
            self.group_scales = [1.0] * len(configured)

    def step(self):
        """Advance one epoch; while still warming up, update each group's LR."""
        if self.current_epoch < self.warmup_epochs:
            lr = self.base_lr * (self.current_epoch + 1) / self.warmup_epochs
            for position, param_group in enumerate(self.optimizer.param_groups):
                # Groups added after construction have no recorded scale.
                scale = self.group_scales[position] if position < len(self.group_scales) else 1.0
                param_group['lr'] = lr * scale
        self.current_epoch += 1

In [None]:
# %%
# ============================================================================
# 9. TRAINING CON TODAS LAS MEJORAS
# ============================================================================

def train_epoch(model, train_loader, criterion, optimizer, device, config, epoch, warmup_scheduler=None):
    """Train the model for one epoch with warmup, MixUp and gradient accumulation.

    Args:
        model: network to train (switched to train mode here).
        train_loader: iterable of (images, labels) batches supporting len().
        criterion: loss callable taking (logits, labels).
        optimizer: optimizer stepped every `gradient_accumulation_steps`
            batches and once more for a trailing partial window.
        device: device the batches are moved to.
        config: dict providing 'warmup_epochs', 'mixup_alpha' and
            'gradient_accumulation_steps'.
        epoch: zero-based epoch index.
        warmup_scheduler: optional object with step(), called once per epoch
            while ``epoch < config['warmup_epochs']``.

    Returns:
        (epoch_loss, epoch_acc, epoch_f1): mean batch loss, accuracy in
        percent and weighted F1, all measured against the original
        (un-mixed) labels.
    """
    model.train()

    accum_steps = config['gradient_accumulation_steps']
    num_batches = len(train_loader)

    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    # Learning-rate warmup during the first epochs.
    if warmup_scheduler and epoch < config['warmup_epochs']:
        warmup_scheduler.step()

    optimizer.zero_grad()

    pbar = tqdm(enumerate(train_loader), total=num_batches, desc=f"Training Epoch {epoch+1}")

    for batch_idx, (images, labels) in pbar:
        images, labels = images.to(device), labels.to(device)

        # MixUp on roughly half of the batches.
        if config['mixup_alpha'] > 0 and np.random.random() > 0.5:
            images, labels_a, labels_b, lam = mixup_data(images, labels, config['mixup_alpha'])
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Gradient accumulation: scale so the summed gradients average over the window.
        loss = loss / accum_steps
        loss.backward()

        # Step at the end of every full window AND after the last batch.
        # Otherwise the gradients of a trailing partial window were dropped
        # by the zero_grad() at the start of the next epoch.
        is_last_batch = (batch_idx + 1) == num_batches
        if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        # Metrics (undo the accumulation scaling for reporting).
        running_loss += loss.item() * accum_steps
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({
            'loss': running_loss / (batch_idx + 1),
            'acc': 100. * correct / total
        })

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(num_batches, 1)
    epoch_acc = 100. * correct / max(total, 1)
    epoch_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return epoch_loss, epoch_acc, epoch_f1


def validate_epoch(model, val_loader, criterion, device, use_tta=False, tta_transforms=5):
    """Evaluate the model on a loader, optionally with test-time augmentation.

    With `use_tta`, each batch is forwarded `tta_transforms` times, each time
    with a random horizontal flip (p = 0.5) and a random vertical flip
    (p = 0.3); the softmax outputs are averaged and the loss is computed on
    their logarithm.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_loader: iterable of (images, labels) batches supporting len().
        criterion: loss callable taking (logits, labels).
        device: device the batches are moved to.
        use_tta: enable test-time augmentation.
        tta_transforms: number of augmented forward passes per batch.

    Returns:
        dict with 'loss', weighted 'f1' / 'precision' / 'recall', binary
        'specificity', 'accuracy' (percent), the 2x2 'confusion_matrix', and
        the raw 'predictions' and 'labels' lists.
    """
    model.eval()

    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)

            if use_tta:
                # Test-time augmentation: average probabilities over random flips.
                tta_outputs = []
                for _ in range(tta_transforms):
                    augmented = images.clone()
                    if torch.rand(1) > 0.5:
                        augmented = torch.flip(augmented, dims=[3])  # Horizontal flip
                    if torch.rand(1) > 0.7:
                        augmented = torch.flip(augmented, dims=[2])  # Vertical flip

                    outputs = model(augmented)
                    tta_outputs.append(F.softmax(outputs, dim=1))

                # Averaged probabilities; their log is handed to the criterion
                # (epsilon avoids log(0)).
                outputs = torch.stack(tta_outputs).mean(dim=0)
                loss = criterion(torch.log(outputs + 1e-8), labels)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(len(val_loader), 1)

    # Weighted-average classification metrics.
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Force a 2x2 matrix. Without `labels`, a run where only one class appears
    # yields a 1x1 matrix, and specificity/accuracy were silently reported as 0
    # even for perfect predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    accuracy = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) > 0 else 0

    metrics = {
        'loss': epoch_loss,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'accuracy': accuracy * 100,
        'confusion_matrix': cm,
        'predictions': all_preds,
        'labels': all_labels
    }

    return metrics

In [None]:
# %%
# ============================================================================
# 10. MAIN TRAINING LOOP
# ============================================================================

def main():
    """Train and evaluate ViT_TL_2 with patient-level K-fold cross-validation.

    Steps:
      1. Load image paths and build patient-level stratified folds.
      2. Compute inverse-frequency class weights (only consumed by
         nn.CrossEntropyLoss; a notice is printed when the selected loss
         ignores them).
      3. Per fold: build loaders, model, loss, optimizer and schedulers; train
         with warmup, partial unfreezing and early stopping on validation F1;
         reload the best checkpoint, evaluate it and save its confusion matrix.
      4. Write per-fold metrics, per-epoch metrics, the mean confusion matrix
         and the configuration to METRICS_DIR / OUTPUT_DIR.

    Everything is driven by the module-level CONFIG dict; nothing is returned.
    """
    print("\n" + "="*70)
    print("INICIANDO ENTRENAMIENTO - ViT_TL_2")
    print("="*70 + "\n")

    # 1. Load data
    print("[1/5] Cargando datos...")
    data_df = load_image_paths_with_patient_ids()
    fold_splits = create_patient_level_splits(data_df, k_folds=CONFIG['k_folds'])

    # 2. Class weights: total / (n_classes * count_of_class)
    if CONFIG['use_class_weights']:
        class_counts = data_df['label'].value_counts().sort_index()
        total_samples = len(data_df)
        class_weights = torch.FloatTensor([
            total_samples / (len(class_counts) * class_counts[i])
            for i in range(CONFIG['num_classes'])
        ]).to(device)
        print(f"\nClass weights: {class_weights.cpu().numpy()}")

        # The loss selection below only passes the weights to CrossEntropyLoss.
        # Say so explicitly instead of silently dropping them.
        if CONFIG['use_focal_loss'] or CONFIG['label_smoothing'] > 0:
            print("⚠ Aviso: los class weights solo se aplican con nn.CrossEntropyLoss; "
                  "la loss seleccionada los ignora")
    else:
        class_weights = None

    # Result accumulators
    all_fold_results = []
    all_confusion_matrices = []
    metrics_summary = defaultdict(list)
    epoch_metrics_all_folds = []  # per-epoch metrics of every fold

    # 3. K-Fold Cross Validation
    print(f"\n[2/5] Iniciando {CONFIG['k_folds']}-Fold Cross Validation...")

    for fold_idx, (train_indices, val_indices) in enumerate(fold_splits):
        print(f"\n{'='*70}")
        print(f"FOLD {fold_idx + 1}/{CONFIG['k_folds']}")
        print(f"{'='*70}")

        # Datasets
        train_transform = get_transforms(CONFIG, is_training=True)
        val_transform = get_transforms(CONFIG, is_training=False)

        train_dataset = MammographyDataset(data_df, train_indices, train_transform)
        val_dataset = MammographyDataset(data_df, val_indices, val_transform)

        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Model (fresh pretrained weights for every fold)
        model = ViTClassifier(CONFIG).to(device)

        # Loss function: focal loss takes precedence, then label smoothing,
        # then (optionally weighted) cross-entropy.
        if CONFIG['use_focal_loss']:
            criterion = FocalLoss(
                alpha=CONFIG['focal_loss_alpha'],
                gamma=CONFIG['focal_loss_gamma']
            )
        elif CONFIG['label_smoothing'] > 0:
            criterion = LabelSmoothingCrossEntropy(smoothing=CONFIG['label_smoothing'])
        else:
            criterion = nn.CrossEntropyLoss(weight=class_weights)

        # Optimizer with different learning rates for head and backbone
        optimizer = AdamW([
            {'params': model.classifier.parameters(), 'lr': CONFIG['base_lr']},
            {'params': model.vit.parameters(), 'lr': CONFIG['backbone_lr']}
        ], weight_decay=CONFIG['weight_decay'])

        # Scheduler (stepped only once warmup is over, see below)
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=CONFIG['scheduler_T_0'],
            T_mult=CONFIG['scheduler_T_mult'],
            eta_min=CONFIG['scheduler_eta_min']
        )

        # Warmup scheduler
        warmup_scheduler = WarmupScheduler(
            optimizer,
            CONFIG['warmup_epochs'],
            CONFIG['base_lr']
        )

        # Training loop.
        # Start from -inf so the first epoch always writes a checkpoint.
        # Starting from 0, a fold whose F1 never exceeded min_delta saved
        # nothing, and the reload below failed or picked up a stale file from
        # a previous run.
        best_val_f1 = float('-inf')
        patience_counter = 0
        history = {
            'train_loss': [], 'train_acc': [], 'train_f1': [],
            'val_loss': [], 'val_acc': [], 'val_f1': [],
            'val_precision': [], 'val_recall': [], 'val_specificity': []
        }

        fold_epoch_metrics = []  # per-epoch metrics of this fold

        for epoch in range(CONFIG['num_epochs']):
            # Unfreeze (part of) the backbone once, at the configured epoch.
            # NOTE(review): despite its name, 'unfreeze_layers_per_epoch' is
            # applied a single time here, not on every subsequent epoch —
            # confirm this is the intended schedule.
            if epoch == CONFIG['unfreeze_after_epoch']:
                if CONFIG['gradual_unfreeze']:
                    print(f"\n🔓 Descongelando {CONFIG['unfreeze_layers_per_epoch']} capas del backbone...")
                    model.gradual_unfreeze(CONFIG['unfreeze_layers_per_epoch'])
                else:
                    print(f"\n🔓 Descongelando todo el backbone...")
                    model.unfreeze_backbone()

            # Train
            train_loss, train_acc, train_f1 = train_epoch(
                model, train_loader, criterion, optimizer, device, CONFIG, epoch, warmup_scheduler
            )

            # Validate
            val_metrics = validate_epoch(
                model, val_loader, criterion, device,
                use_tta=CONFIG['use_tta'],
                tta_transforms=CONFIG['tta_transforms']
            )

            # Scheduler step (cosine schedule takes over after warmup)
            if epoch >= CONFIG['warmup_epochs']:
                scheduler.step()

            # Record history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_f1'].append(train_f1)
            history['val_loss'].append(val_metrics['loss'])
            history['val_acc'].append(val_metrics['accuracy'])
            history['val_f1'].append(val_metrics['f1'])
            history['val_precision'].append(val_metrics['precision'])
            history['val_recall'].append(val_metrics['recall'])
            history['val_specificity'].append(val_metrics['specificity'])

            # Record per-epoch metrics
            fold_epoch_metrics.append({
                'fold': fold_idx + 1,
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_acc': train_acc,
                'train_f1': train_f1,
                'val_loss': val_metrics['loss'],
                'val_acc': val_metrics['accuracy'],
                'val_f1': val_metrics['f1'],
                'val_precision': val_metrics['precision'],
                'val_recall': val_metrics['recall'],
                'val_specificity': val_metrics['specificity']
            })

            # Print progress
            print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
            print(f"  Train - Loss: {train_loss:.4f} | Acc: {train_acc:.2f}% | F1: {train_f1:.4f}")
            print(f"  Val   - Loss: {val_metrics['loss']:.4f} | Acc: {val_metrics['accuracy']:.2f}% | F1: {val_metrics['f1']:.4f}")
            print(f"  Val   - Precision: {val_metrics['precision']:.4f} | Recall: {val_metrics['recall']:.4f} | Spec: {val_metrics['specificity']:.4f}")

            # Early stopping on validation F1 (must improve by more than min_delta)
            if val_metrics['f1'] > best_val_f1 + CONFIG['early_stopping_min_delta']:
                best_val_f1 = val_metrics['f1']
                patience_counter = 0

                # Save best model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_f1': best_val_f1,
                    'config': CONFIG,
                    'history': history
                }, OUTPUT_DIR / f"{CONFIG['model_name']}_fold{fold_idx+1}_best.pth")

                print(f"  ✓ Mejor modelo guardado (F1: {best_val_f1:.4f})")
            else:
                patience_counter += 1
                print(f"  Patience: {patience_counter}/{CONFIG['early_stopping_patience']}")

            if patience_counter >= CONFIG['early_stopping_patience']:
                print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
                break

        # Keep this fold's per-epoch metrics
        epoch_metrics_all_folds.extend(fold_epoch_metrics)

        # Reload the best checkpoint of this fold.
        # weights_only=False because the checkpoint also stores plain Python
        # objects (config, history); it is a file this run has just written.
        checkpoint = torch.load(
            OUTPUT_DIR / f"{CONFIG['model_name']}_fold{fold_idx+1}_best.pth",
            map_location=device,
            weights_only=False
        )
        model.load_state_dict(checkpoint['model_state_dict'])

        # Final validation with the best weights
        final_val_metrics = validate_epoch(
            model, val_loader, criterion, device,
            use_tta=CONFIG['use_tta'],
            tta_transforms=CONFIG['tta_transforms']
        )

        # Store fold results
        fold_results = {
            'fold': fold_idx + 1,
            'best_epoch': checkpoint['epoch'] + 1,
            'f1': final_val_metrics['f1'],
            'precision': final_val_metrics['precision'],
            'recall': final_val_metrics['recall'],
            'specificity': final_val_metrics['specificity'],
            'accuracy': final_val_metrics['accuracy'],
            'confusion_matrix': final_val_metrics['confusion_matrix']
        }

        all_fold_results.append(fold_results)
        all_confusion_matrices.append(final_val_metrics['confusion_matrix'])

        for key in ['f1', 'precision', 'recall', 'specificity', 'accuracy']:
            metrics_summary[key].append(fold_results[key])

        print(f"\nFold {fold_idx + 1} completado:")
        print(f"  F1: {fold_results['f1']:.4f} | Precision: {fold_results['precision']:.4f} | Recall: {fold_results['recall']:.4f}")
        print(f"  Specificity: {fold_results['specificity']:.4f} | Accuracy: {fold_results['accuracy']:.2f}%")

        # Per-fold confusion matrix figure
        fold_cm = final_val_metrics['confusion_matrix']
        plt.figure(figsize=(8, 6))
        # Row-wise percentages; np.maximum(..., 1) avoids dividing by an empty row.
        row_totals = fold_cm.sum(axis=1, keepdims=True)
        cm_percent = fold_cm.astype('float') / np.maximum(row_totals, 1) * 100

        sns.heatmap(
            fold_cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=['Benigna', 'Maligna'],
            yticklabels=['Benigna', 'Maligna'],
            cbar_kws={'label': 'Count'}
        )

        # Overlay the percentages just below the counts
        for i in range(fold_cm.shape[0]):
            for j in range(fold_cm.shape[1]):
                plt.text(j + 0.5, i + 0.7, f'({cm_percent[i, j]:.1f}%)',
                        ha='center', va='center', color='red', fontsize=9)

        plt.title(f'Confusion Matrix - Fold {fold_idx + 1}', fontsize=14, fontweight='bold')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(METRICS_DIR / f"{CONFIG['model_name']}_fold{fold_idx+1}_confusion_matrix.png", dpi=300)
        plt.close()

    # 4. Save final metrics
    print("\n[4/5] Guardando métricas finales...")

    # FILE 1: <model_name>_metrics.csv (per-fold metrics plus a mean ± std row)
    metrics_data = []
    for result in all_fold_results:
        metrics_data.append({
            'Fold': result['fold'],
            'F1': result['f1'],
            'Precision': result['precision'],
            'Recall': result['recall'],
            'Specificity': result['specificity'],
            'Accuracy': result['accuracy']
        })

    # Append the averages row
    metrics_data.append({
        'Fold': 'Mean ± Std',
        'F1': f"{np.mean(metrics_summary['f1']):.4f} ± {np.std(metrics_summary['f1']):.4f}",
        'Precision': f"{np.mean(metrics_summary['precision']):.4f} ± {np.std(metrics_summary['precision']):.4f}",
        'Recall': f"{np.mean(metrics_summary['recall']):.4f} ± {np.std(metrics_summary['recall']):.4f}",
        'Specificity': f"{np.mean(metrics_summary['specificity']):.4f} ± {np.std(metrics_summary['specificity']):.4f}",
        'Accuracy': f"{np.mean(metrics_summary['accuracy']):.2f} ± {np.std(metrics_summary['accuracy']):.2f}"
    })

    metrics_df = pd.DataFrame(metrics_data)
    metrics_csv_path = METRICS_DIR / f"{CONFIG['model_name']}_metrics.csv"
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"✓ Métricas guardadas: {metrics_csv_path}")

    # FILE 2: <model_name>_metrics_summary.csv (per-epoch metrics of all folds)
    summary_df = pd.DataFrame(epoch_metrics_all_folds)
    summary_csv_path = METRICS_DIR / f"{CONFIG['model_name']}_metrics_summary.csv"
    summary_df.to_csv(summary_csv_path, index=False)
    print(f"✓ Resumen de métricas guardado: {summary_csv_path}")

    # FILE 3: <model_name>_mean_confusion_matrix.png
    print("\n[5/5] Creando matriz de confusión promedio...")
    # Round to the nearest integer; a bare astype(int) truncated toward zero
    # and systematically under-reported every cell.
    mean_cm = np.rint(np.mean(all_confusion_matrices, axis=0)).astype(int)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        mean_cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Benigna', 'Maligna'],
        yticklabels=['Benigna', 'Maligna'],
        cbar_kws={'label': 'Average Count'}
    )
    plt.title(f'Average Confusion Matrix - {CONFIG["model_name"]}',
              fontsize=16, fontweight='bold')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    # Metrics text box drawn to the right of the heatmap
    metrics_text = f"""Mean Metrics:
F1-Score: {np.mean(metrics_summary['f1']):.4f} ± {np.std(metrics_summary['f1']):.4f}
Precision: {np.mean(metrics_summary['precision']):.4f} ± {np.std(metrics_summary['precision']):.4f}
Recall: {np.mean(metrics_summary['recall']):.4f} ± {np.std(metrics_summary['recall']):.4f}
Specificity: {np.mean(metrics_summary['specificity']):.4f} ± {np.std(metrics_summary['specificity']):.4f}
Accuracy: {np.mean(metrics_summary['accuracy']):.2f} ± {np.std(metrics_summary['accuracy']):.2f}"""

    plt.text(
        2.5, 0.5, metrics_text,
        fontsize=10,
        verticalalignment='center',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    )

    mean_cm_path = METRICS_DIR / f"{CONFIG['model_name']}_mean_confusion_matrix.png"
    plt.tight_layout()
    plt.savefig(mean_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Matriz de confusión promedio guardada: {mean_cm_path}")

    # Save configuration
    config_path = OUTPUT_DIR / f"{CONFIG['model_name']}_config.json"
    with open(config_path, 'w') as f:
        json.dump(CONFIG, f, indent=4)
    print(f"✓ Configuración guardada: {config_path}")

    # Final summary
    print("\n" + "="*70)
    print("✓ ENTRENAMIENTO COMPLETADO - ViT_TL_2")
    print("="*70)
    print(f"\nRESULTADOS FINALES ({CONFIG['k_folds']}-Fold Cross Validation):")
    print(f"  F1-Score:     {np.mean(metrics_summary['f1']):.4f} ± {np.std(metrics_summary['f1']):.4f}")
    print(f"  Precision:    {np.mean(metrics_summary['precision']):.4f} ± {np.std(metrics_summary['precision']):.4f}")
    print(f"  Recall:       {np.mean(metrics_summary['recall']):.4f} ± {np.std(metrics_summary['recall']):.4f}")
    print(f"  Specificity:  {np.mean(metrics_summary['specificity']):.4f} ± {np.std(metrics_summary['specificity']):.4f}")
    print(f"  Accuracy:     {np.mean(metrics_summary['accuracy']):.2f}% ± {np.std(metrics_summary['accuracy']):.2f}%")
    print("\nARCHIVOS GENERADOS:")
    print(f"  ✓ {CONFIG['model_name']}_metrics.csv")
    print(f"  ✓ {CONFIG['model_name']}_metrics_summary.csv")
    print(f"  ✓ {CONFIG['model_name']}_mean_confusion_matrix.png")
    print(f"  ✓ {CONFIG['k_folds']} modelos (.pth)")
    print(f"  ✓ {CONFIG['k_folds']} matrices de confusión individuales")
    print(f"  ✓ 1 archivo de configuración (.json)")
    print("="*70)


In [None]:
# %%
# ============================================================================
# 11. EJECUTAR
# ============================================================================

# Entry point: run the full cross-validation pipeline when this module is
# executed as the main program.
if __name__ == "__main__":
    main()