In [None]:
# ============================================================================
# 1. SETUP Y CONFIGURACIÓN
# ============================================================================

import os
import sys
import json
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Torchvision
from torchvision import transforms
from PIL import Image

# Transformers (Hugging Face)
from transformers import ViTModel, ViTConfig

# Sklearn
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    f1_score, precision_score, recall_score,
    confusion_matrix, classification_report
)

# Select the compute device and, when a GPU is present, apply memory/perf tweaks.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    # Configure the expandable-segments allocator BEFORE any call that touches
    # the CUDA caching allocator (device queries, empty_cache, allocations).
    # Exporting the variable afterwards, as was done previously, is expected to
    # have no effect. setdefault keeps any value the user exported themselves.
    # NOTE(review): in a kernel where CUDA was already used, a restart is still
    # required for this setting to apply.
    os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Memory / throughput optimisations
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    print("\n✓ Optimizaciones de memoria aplicadas")

In [None]:
# ============================================================================
# 2. CONFIGURACIÓN DE PATHS Y PARÁMETROS
# ============================================================================

# Base paths.
# The root defaults to the original hard-coded directory but can be overridden
# with the MAMMO_BASE_DIR environment variable, so the notebook can run on
# another machine without editing this cell.
BASE_DIR = Path(os.environ.get("MAMMO_BASE_DIR", "/home/merivadeneira"))
MASAS_DIR = BASE_DIR / "Masas"                # input images (per database / class)
OUTPUT_DIR = BASE_DIR / "Outputs" / "ViT"     # model checkpoints and config JSON
METRICS_DIR = BASE_DIR / "Metrics" / "ViT"    # plots and metric tables

# Create the output directories (no error if they already exist)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
METRICS_DIR.mkdir(parents=True, exist_ok=True)

# Configuración del modelo
CONFIG = {
    # Data
    'input_channels': 3,  # the pre-trained ViT expects 3-channel (RGB) input
    'input_size': 224,    # input resolution of the pre-trained checkpoint
    'num_classes': 2,

    # Transfer learning
    'pretrained_model': 'google/vit-base-patch16-224',
    'freeze_backbone': True,     # start with the pre-trained layers frozen
    'unfreeze_after_epoch': 10,  # unfreeze the backbone after 10 epochs

    # Training
    'batch_size': 32,
    'num_epochs': 100,
    'k_folds': 5,
    'early_stopping_patience': 25,

    # Optimizer (lower learning rates for transfer learning)
    'optimizer': 'AdamW',
    'base_lr': 1e-4,      # learning rate of the classifier head
    'backbone_lr': 1e-5,  # learning rate of the backbone once unfrozen
    'weight_decay': 0.01,
    'lr_scheduler': 'ReduceLROnPlateau',
    'scheduler_patience': 5,
    'scheduler_factor': 0.5,

    # Augmentation
    'rotation_degrees': 15,
    'translate': 0.1,
    'gaussian_blur_kernel': 5,          # kernel size for the Gaussian blur
    'gaussian_blur_sigma': (0.1, 0.5),  # sigma range (kept low)

    # Normalization.
    # google/vit-base-patch16-224 was trained with inputs normalised using
    # mean 0.5 / std 0.5 per channel (see its image-processor config), NOT the
    # torchvision ImageNet statistics. Using the checkpoint's own values keeps
    # the frozen backbone's input distribution consistent with pre-training.
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],

    # Model name (used as prefix for every saved artefact)
    'model_name': 'ViT_TL_Base'
}

# Echo the configuration so it is captured in the notebook output
print("\n" + "="*70)
print("CONFIGURACIÓN DEL ENTRENAMIENTO - TRANSFER LEARNING")
print("="*70)
for key, value in CONFIG.items():
    print(f"{key:30s}: {value}")
print("="*70 + "\n")


In [None]:
# ============================================================================
# 3. UTILIDADES PARA PREVENIR DATA LEAKAGE (IGUAL QUE ANTES)
# ============================================================================

def extract_patient_id(filename, database):
    """
    Derive a patient identifier from an image file name.

    The identifier is used to keep all images of one patient in the same
    cross-validation fold (prevents patient-level data leakage).

    Args:
        filename: Image file name; its fields are separated by underscores.
        database: Source database name ('DDSM' or 'INbreast').

    Returns:
        'DDSM_<second field>' for DDSM names with at least two fields,
        'INbreast_<first field>' for INbreast names, and the unchanged
        file name in every other case.
    """
    if database == 'DDSM':
        tokens = filename.split('_')
        return f"DDSM_{tokens[1]}" if len(tokens) > 1 else filename

    if database == 'INbreast':
        # str.split always yields at least one token, so no length check is needed
        return f"INbreast_{filename.split('_')[0]}"

    return filename

def load_image_paths_with_patient_ids():
    """
    Collect every PNG image under MASAS_DIR together with its metadata.

    Scans ``MASAS_DIR/<database>/<class>/Resized_512/*.png`` for the databases
    'DDSM' and 'INbreast' and the class folders 'Benignas' (label 0) and
    'Malignas' (label 1). Missing folders are skipped.

    Returns:
        pandas.DataFrame with one row per image and the columns
        'image_path', 'patient_id', 'label' and 'database'.

    Raises:
        FileNotFoundError: if no image is found at all (previously this
            surfaced as an opaque KeyError while printing the summary).
    """
    class_folders = [('Benignas', 0), ('Malignas', 1)]
    records = []

    for database in ('DDSM', 'INbreast'):
        for label_name, label_value in class_folders:
            image_dir = MASAS_DIR / database / label_name / "Resized_512"
            if not image_dir.exists():
                continue

            # Path.glob yields files in arbitrary order; sorting makes the row
            # order (and therefore the fold index lists) stable between runs.
            for img_file in sorted(image_dir.glob("*.png")):
                records.append({
                    'image_path': str(img_file),
                    'patient_id': extract_patient_id(img_file.name, database),
                    'label': label_value,
                    'database': database
                })

    if not records:
        raise FileNotFoundError(
            f"No PNG images found under {MASAS_DIR} "
            "(expected <database>/<class>/Resized_512/*.png)"
        )

    data_df = pd.DataFrame(records)

    print(f"\nTotal de imágenes cargadas: {len(data_df)}")
    print(f"  - DDSM: {len(data_df[data_df['database']=='DDSM'])}")
    print(f"  - INbreast: {len(data_df[data_df['database']=='INbreast'])}")
    print(f"\nDistribución de clases:")
    print(f"  - Benignas (0): {len(data_df[data_df['label']==0])}")
    print(f"  - Malignas (1): {len(data_df[data_df['label']==1])}")
    print(f"\nTotal de pacientes únicos: {data_df['patient_id'].nunique()}")

    return data_df

def create_patient_level_splits(data_df, k_folds=10, random_state=42):
    """
    Build stratified K-Fold splits at patient level (not image level).

    Each patient is assigned a single label (the most frequent label among
    that patient's images) which is used only for stratification; every image
    keeps its own label. All images of a patient land in the same fold.

    Args:
        data_df: DataFrame with at least 'patient_id' and 'label' columns.
        k_folds: Number of folds.
        random_state: Seed for the shuffled StratifiedKFold.

    Returns:
        List of ``(train_indices, val_indices)`` tuples. The indices are
        POSITIONAL row numbers of ``data_df`` (suitable for ``.iloc``), so the
        result is correct even when ``data_df`` does not carry a default
        0..n-1 index. For a default index the values are the same as before.
    """
    def _majority_label(labels):
        # Most frequent label; fall back to the first value if mode() is empty
        modes = labels.mode()
        return modes.iloc[0] if len(modes) > 0 else labels.iloc[0]

    patient_labels = (
        data_df.groupby('patient_id')['label']
        .agg(_majority_label)
        .reset_index()
    )
    patient_labels.columns = ['patient_id', 'label']

    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_state)

    fold_splits = []
    for fold_idx, (train_patient_idx, val_patient_idx) in enumerate(
        skf.split(patient_labels['patient_id'], patient_labels['label'])
    ):
        train_patients = patient_labels.iloc[train_patient_idx]['patient_id'].values
        val_patients = patient_labels.iloc[val_patient_idx]['patient_id'].values

        # Translate patient membership into positional image indices
        in_train = data_df['patient_id'].isin(train_patients).tolist()
        in_val = data_df['patient_id'].isin(val_patients).tolist()
        train_indices = [pos for pos, hit in enumerate(in_train) if hit]
        val_indices = [pos for pos, hit in enumerate(in_val) if hit]

        fold_splits.append((train_indices, val_indices))

        print(f"\nFold {fold_idx + 1}:")
        print(f"  Train: {len(train_indices)} images from {len(train_patients)} patients")
        print(f"  Val:   {len(val_indices)} images from {len(val_patients)} patients")

    return fold_splits


In [None]:
# ============================================================================
# 4. DATASET Y TRANSFORMACIONES
# ============================================================================

class MammographyDataset(Dataset):
    """
    Torch dataset over a subset of the image table.

    Each item is ``(image, label)`` where the image is loaded from
    'image_path', converted to grayscale and then replicated to RGB, and
    passed through ``transform`` when one is given.
    """

    def __init__(self, data_df, indices, transform=None):
        # Keep only the requested rows (positional) and renumber them 0..n-1
        subset = data_df.iloc[indices]
        self.data = subset.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]

        # Load as single-channel grayscale, then expand to the 3 channels
        # expected by the pre-trained ViT
        grayscale = Image.open(record['image_path']).convert('L')
        image = grayscale.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, record['label']

# Transforms
# Training pipeline: resize to the model input size, apply the random
# geometric augmentations, convert to a tensor, blur, then normalise with the
# mean/std taken from CONFIG.
train_transform = transforms.Compose([
    transforms.Resize((CONFIG['input_size'], CONFIG['input_size'])),

    # Random rotation
    transforms.RandomRotation(degrees=CONFIG['rotation_degrees']),

    # Random translation only (no scale, no shear)
    transforms.RandomAffine(
        degrees=0,
        translate=(CONFIG['translate'], CONFIG['translate']),
        scale=None,
        shear=None
    ),

    transforms.ToTensor(),

    # Gaussian blur (applied after ToTensor)
    transforms.GaussianBlur(
        kernel_size=CONFIG['gaussian_blur_kernel'],
        sigma=CONFIG['gaussian_blur_sigma']
    ),

    transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std'])
])

# Validation pipeline: no augmentation — resize, convert to tensor, normalise.
val_transform = transforms.Compose([
    transforms.Resize((CONFIG['input_size'], CONFIG['input_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=CONFIG['mean'], std=CONFIG['std'])
])


In [None]:
# ============================================================================
# 5. MODELO CON TRANSFER LEARNING
# ============================================================================

class ViTForMammography(nn.Module):
    """
    Pre-trained ViT backbone with a small MLP head for mammography
    classification.

    The head consumes the first token (CLS) of the backbone's last hidden
    state and produces ``num_classes`` logits.
    """

    def __init__(
        self,
        pretrained_model_name='google/vit-base-patch16-224',
        num_classes=2,
        freeze_backbone=True
    ):
        super().__init__()

        # Pre-trained backbone
        print(f"Cargando modelo pre-entrenado: {pretrained_model_name}")
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

        # Optionally start with the backbone frozen
        if freeze_backbone:
            self.freeze_backbone()
            print("✓ Backbone congelado")

        # Classification head: LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear
        embed_dim = self.vit.config.hidden_size
        head_dim = embed_dim // 2
        head_layers = [
            nn.LayerNorm(embed_dim),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, head_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(head_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # The counts below include trainable parameters only, so a frozen
        # backbone is reported as 0.
        total_trainable = self.count_parameters()
        backbone_trainable = self.count_parameters(self.vit)
        head_trainable = self.count_parameters(self.classifier)
        print(f"✓ Modelo cargado con {total_trainable:,} parámetros")
        print(f"  - Backbone: {backbone_trainable:,} parámetros")
        print(f"  - Classifier: {head_trainable:,} parámetros")

    def freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        self.vit.requires_grad_(False)

    def unfreeze_backbone(self):
        """Re-enable gradients for every backbone parameter (fine-tuning)."""
        self.vit.requires_grad_(True)
        print("✓ Backbone descongelado para fine-tuning")

    def count_parameters(self, model=None):
        """Return the number of trainable parameters of ``model`` (default: self)."""
        target = self if model is None else model
        trainable = (p.numel() for p in target.parameters() if p.requires_grad)
        return sum(trainable)

    def forward(self, pixel_values):
        """Return class logits for a batch of images."""
        hidden_states = self.vit(pixel_values=pixel_values).last_hidden_state

        # First token of the sequence is the CLS token
        cls_token = hidden_states[:, 0]

        return self.classifier(cls_token)


In [None]:
# ============================================================================
# 6. ENTRENAMIENTO Y EVALUACIÓN
# ============================================================================

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """
    Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train (put into train mode here).
        train_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: sample-weighted mean loss and the
        fraction of correctly classified samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(train_loader, desc="Training")
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)

        # Backward + parameter update
        batch_loss.backward()
        optimizer.step()

        # Running statistics (loss weighted by batch size)
        loss_value = batch_loss.item()
        loss_sum += loss_value * inputs.size(0)
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        # Live loss / accuracy on the progress bar
        progress.set_postfix({
            'loss': f'{loss_value:.4f}',
            'acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_sum / n_seen, n_correct / n_seen


def validate(model, val_loader, criterion, device):
    """
    Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Network to evaluate (put into eval mode here).
        val_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        dict with 'loss' (sample-weighted mean), 'accuracy', 'f1',
        'precision', 'recall', 'specificity' and 'confusion_matrix'
        (2x2 array, rows = true class, columns = predicted class, class
        order [0, 1]). The binary metrics treat label 1 as the positive class.
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    epoch_loss = running_loss / len(all_labels)

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    accuracy = (all_preds == all_labels).mean()
    f1 = f1_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')

    # Confusion matrix for specificity.
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in
    # both the targets and the predictions; without it the matrix collapses
    # to 1x1 and the 4-way unpacking below raises ValueError.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    metrics = {
        'loss': epoch_loss,
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'confusion_matrix': cm
    }

    return metrics


class EarlyStopping:
    """
    Patience-based early stopping on a monitored score.

    Call the instance once per epoch with the latest score. It returns True
    once the score has failed to improve for ``patience`` consecutive calls,
    and keeps returning True afterwards.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum change that counts as an improvement.
        mode: 'min' if lower is better; any other value means higher is better.
    """

    def __init__(self, patience=25, min_delta=0, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

        # The comparison closes over the constructor's min_delta value
        def improved_when_lower(candidate, best):
            return candidate < best - min_delta

        def improved_when_higher(candidate, best):
            return candidate > best + min_delta

        self.monitor_op = improved_when_lower if mode == 'min' else improved_when_higher

    def __call__(self, score):
        # First observation just initialises the reference score
        if self.best_score is None:
            self.best_score = score
            return self.early_stop

        if self.monitor_op(score, self.best_score):
            self.best_score = score
            self.counter = 0
            return self.early_stop

        # No improvement: count it and trip the flag once patience is used up
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop


def train_fold(
    fold_idx,
    train_loader,
    val_loader,
    model,
    criterion,
    optimizer,
    scheduler,
    device,
    num_epochs,
    early_stopping_patience,
    unfreeze_after_epoch
):
    """
    Train one cross-validation fold, unfreezing the backbone part-way through.

    Args:
        fold_idx: Zero-based fold number (used for logging only).
        train_loader / val_loader: Batch iterables for training / validation.
        model: Model exposing ``vit``, ``classifier`` and ``unfreeze_backbone()``.
        criterion: Loss function.
        optimizer: Initial optimizer (used until the backbone is unfrozen).
        scheduler: Plateau-style scheduler stepped with the validation loss.
        device: Device used for training.
        num_epochs: Maximum number of epochs.
        early_stopping_patience: Epochs without validation-loss improvement
            before stopping.
        unfreeze_after_epoch: Zero-based epoch index at which the backbone is
            unfrozen and a two-group optimizer is created.

    Returns:
        ``(model, history, final_metrics)``: the model with the best
        (lowest validation loss) weights loaded, the per-epoch history dict,
        and the validation metrics of that best model.
    """
    print(f"\n{'='*70}")
    print(f"FOLD {fold_idx + 1}")
    print(f"{'='*70}")

    early_stopping = EarlyStopping(patience=early_stopping_patience, mode='min')
    best_val_loss = float('inf')
    best_model_state = None
    backbone_unfrozen = False

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_f1': [],
        'val_precision': [],
        'val_recall': [],
        'val_specificity': []
    }

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("-" * 70)

        # Unfreeze the backbone after N epochs
        if epoch == unfreeze_after_epoch and not backbone_unfrozen:
            print(f"\n{'='*70}")
            print(f"DESCONGELANDO BACKBONE PARA FINE-TUNING")
            print(f"{'='*70}")
            model.unfreeze_backbone()

            # New optimizer with separate learning rates:
            # group 0 = backbone, group 1 = classifier
            optimizer = AdamW([
                {'params': model.vit.parameters(), 'lr': CONFIG['backbone_lr']},
                {'params': model.classifier.parameters(), 'lr': CONFIG['base_lr']}
            ], weight_decay=CONFIG['weight_decay'])

            # The scheduler passed in is bound to the optimizer that was just
            # replaced; stepping it would keep adjusting the discarded optimizer
            # and the LR of the live one would never drop. Rebuild it for the
            # new optimizer with the same CONFIG settings.
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode='min',
                factor=CONFIG['scheduler_factor'],
                patience=CONFIG['scheduler_patience']
            )

            backbone_unfrozen = True
            print(f"Backbone LR: {CONFIG['backbone_lr']:.2e}")
            print(f"Classifier LR: {CONFIG['base_lr']:.2e}\n")

        # Train
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validate
        val_metrics = validate(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step(val_metrics['loss'])

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_metrics['loss'])
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])
        history['val_precision'].append(val_metrics['precision'])
        history['val_recall'].append(val_metrics['recall'])
        history['val_specificity'].append(val_metrics['specificity'])

        # Print metrics
        print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.4f}")
        print(f"Val F1: {val_metrics['f1']:.4f} | Val Precision: {val_metrics['precision']:.4f}")
        print(f"Val Recall: {val_metrics['recall']:.4f} | Val Specificity: {val_metrics['specificity']:.4f}")

        if backbone_unfrozen:
            print(f"Backbone LR: {optimizer.param_groups[0]['lr']:.2e} | Classifier LR: {optimizer.param_groups[1]['lr']:.2e}")
        else:
            print(f"Classifier LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Save best model.
        # state_dict() returns references to the live tensors and dict.copy()
        # is shallow, so the snapshot must clone every tensor; otherwise later
        # optimizer steps overwrite the "best" weights in place. The snapshot
        # is kept on the CPU to avoid holding a second model copy on the GPU.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"✓ Best model updated (Val Loss: {best_val_loss:.4f})")

        # Early stopping
        if early_stopping(val_metrics['loss']):
            print(f"\n✓ Early stopping triggered at epoch {epoch + 1}")
            break

    # Load best model (load_state_dict copies the values onto the model's device)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Final validation with best model
    final_metrics = validate(model, val_loader, criterion, device)

    return model, history, final_metrics


def plot_training_history(history, fold_idx, save_path):
    """
    Save a 2x3 figure with the per-epoch curves of one fold.

    Top row: loss (train/val), accuracy (train/val), validation F1.
    Bottom row: validation precision, recall and specificity.

    Args:
        history: Dict of per-epoch lists as produced by ``train_fold``.
        fold_idx: Zero-based fold number (shown 1-based in the title).
        save_path: Destination of the PNG file.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'Training History - Fold {fold_idx + 1}', fontsize=16, fontweight='bold')

    epochs = range(1, len(history['train_loss']) + 1)

    # Panels comparing train vs. validation:
    # (axis, train key, val key, legend suffix, axis label / title)
    paired_panels = [
        (axes[0, 0], 'train_loss', 'val_loss', 'Loss', 'Loss'),
        (axes[0, 1], 'train_acc', 'val_acc', 'Acc', 'Accuracy'),
    ]
    for ax, train_key, val_key, suffix, name in paired_panels:
        ax.plot(epochs, history[train_key], 'b-', label=f'Train {suffix}', linewidth=2)
        ax.plot(epochs, history[val_key], 'r-', label=f'Val {suffix}', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)
        ax.set_title(name)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Validation-only panels: (axis, history key, line style, metric name)
    validation_panels = [
        (axes[0, 2], 'val_f1', 'g-', 'F1 Score'),
        (axes[1, 0], 'val_precision', 'c-', 'Precision'),
        (axes[1, 1], 'val_recall', 'm-', 'Recall'),
        (axes[1, 2], 'val_specificity', 'y-', 'Specificity'),
    ]
    for ax, key, line_style, name in validation_panels:
        ax.plot(epochs, history[key], line_style, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)
        ax.set_title(f'Validation {name}')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_confusion_matrix(cm, fold_idx, save_path):
    """
    Save a heatmap of a 2x2 confusion matrix with count and percentage labels.

    Args:
        cm: 2x2 array of counts (rows = true class, columns = predicted class).
        fold_idx: Zero-based fold number (shown 1-based in the title).
        save_path: Destination of the PNG file.
    """
    class_names = ['Benigna', 'Maligna']

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Count'}
    )

    plt.title(f'Confusion Matrix - Fold {fold_idx + 1}', fontsize=14, fontweight='bold')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    # Percentage of all samples under each count: green on the diagonal
    # (correct predictions), red off the diagonal (errors)
    grand_total = cm.sum()
    cells = [(row, col) for row in range(2) for col in range(2)]
    for row, col in cells:
        share = cm[row, col] / grand_total * 100
        text_color = 'green' if row == col else 'red'
        plt.text(
            col + 0.5, row + 0.7,
            f'({share:.1f}%)',
            ha='center',
            va='center',
            fontsize=10,
            color=text_color
        )

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()


In [None]:
# ============================================================================
# 7. FUNCIÓN PRINCIPAL DE ENTRENAMIENTO
# ============================================================================

def main():
    """
    Run the complete K-Fold transfer-learning experiment.

    Steps:
        1. Load the image table.
        2. Build patient-level stratified folds.
        3. Train / evaluate one model per fold, saving the checkpoint and plots.
        4. Aggregate the per-fold metrics into a CSV table.
        5. Save the summary figures, the configuration JSON and, when the
           metrics of the from-scratch ViT are available, a comparison table.

    All artefacts are written under OUTPUT_DIR and METRICS_DIR.
    """

    print("\n" + "="*70)
    print("INICIANDO ENTRENAMIENTO - VISION TRANSFORMER CON TRANSFER LEARNING")
    print("="*70)

    # Seed the RNGs so augmentation, shuffling and head initialisation are
    # repeatable. An optional CONFIG['seed'] overrides the default.
    # NOTE(review): full determinism is not guaranteed on GPU while
    # cudnn.benchmark is enabled.
    seed = CONFIG.get('seed', 42)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # 1. Load data
    print("\n[1/5] Cargando datos...")
    data_df = load_image_paths_with_patient_ids()

    # 2. Create patient-level splits
    print("\n[2/5] Creando splits K-Fold a nivel de paciente...")
    fold_splits = create_patient_level_splits(data_df, k_folds=CONFIG['k_folds'])

    # 3. K-Fold training
    print("\n[3/5] Iniciando entrenamiento K-Fold...")

    all_fold_metrics = []
    all_confusion_matrices = []

    for fold_idx, (train_indices, val_indices) in enumerate(fold_splits):

        # Free cached GPU memory left over from the previous fold
        torch.cuda.empty_cache()

        # Datasets and dataloaders
        train_dataset = MammographyDataset(data_df, train_indices, train_transform)
        val_dataset = MammographyDataset(data_df, val_indices, val_transform)

        train_loader = DataLoader(
            train_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Fresh model for every fold
        model = ViTForMammography(
            pretrained_model_name=CONFIG['pretrained_model'],
            num_classes=CONFIG['num_classes'],
            freeze_backbone=CONFIG['freeze_backbone']
        ).to(device)

        # Optimizer (classifier head only at the start)
        optimizer = AdamW(
            model.classifier.parameters(),
            lr=CONFIG['base_lr'],
            weight_decay=CONFIG['weight_decay']
        )

        # `verbose` is deprecated in recent PyTorch releases and is therefore
        # not passed; the current learning rates are printed every epoch by
        # train_fold anyway.
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=CONFIG['scheduler_factor'],
            patience=CONFIG['scheduler_patience']
        )

        criterion = nn.CrossEntropyLoss()

        # Train the fold
        model, history, final_metrics = train_fold(
            fold_idx=fold_idx,
            train_loader=train_loader,
            val_loader=val_loader,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            num_epochs=CONFIG['num_epochs'],
            early_stopping_patience=CONFIG['early_stopping_patience'],
            unfreeze_after_epoch=CONFIG['unfreeze_after_epoch']
        )

        # Save the model checkpoint
        model_path = OUTPUT_DIR / f"{CONFIG['model_name']}_fold{fold_idx}.pth"
        torch.save({
            'fold': fold_idx,
            'model_state_dict': model.state_dict(),
            'config': CONFIG,
            'final_metrics': final_metrics
        }, model_path)
        print(f"\n✓ Model saved: {model_path}")

        # Save the plots
        history_plot_path = METRICS_DIR / f"{CONFIG['model_name']}_fold{fold_idx}_history.png"
        plot_training_history(history, fold_idx, history_plot_path)
        print(f"✓ History plot saved: {history_plot_path}")

        cm_plot_path = METRICS_DIR / f"{CONFIG['model_name']}_fold{fold_idx}_confusion_matrix.png"
        plot_confusion_matrix(final_metrics['confusion_matrix'], fold_idx, cm_plot_path)
        print(f"✓ Confusion matrix saved: {cm_plot_path}")

        # Keep the metrics for aggregation
        all_fold_metrics.append(final_metrics)
        all_confusion_matrices.append(final_metrics['confusion_matrix'])

        # Release per-fold objects before the next fold
        del model, optimizer, scheduler, train_loader, val_loader
        torch.cuda.empty_cache()

    # 4. Aggregate metrics
    print("\n[4/5] Calculando métricas promedio...")

    # Key order matches the display names used in the table below
    metric_keys = ['f1', 'precision', 'recall', 'specificity', 'accuracy']
    metrics_summary = {
        key: [m[key] for m in all_fold_metrics] for key in metric_keys
    }

    # Table with mean and standard deviation across folds
    metrics_df = pd.DataFrame({
        'Metric': ['F1-Score', 'Precision', 'Recall', 'Specificity', 'Accuracy'],
        'Mean': [np.mean(metrics_summary[key]) for key in metric_keys],
        'Std': [np.std(metrics_summary[key]) for key in metric_keys]
    })

    # Save metrics
    metrics_csv_path = METRICS_DIR / f"{CONFIG['model_name']}_metrics.csv"
    metrics_df.to_csv(metrics_csv_path, index=False)
    print(f"\n✓ Metrics saved: {metrics_csv_path}")

    # Print metrics (the fold count comes from CONFIG instead of a hard-coded
    # "10", which did not match the configured number of folds)
    print("\n" + "="*70)
    print(f"MÉTRICAS PROMEDIO ({CONFIG['k_folds']}-FOLD CROSS VALIDATION)")
    print("="*70)
    print(metrics_df.to_string(index=False))
    print("="*70)

    # 5. Final visualisations
    print("\n[5/5] Creando visualizaciones finales...")

    # Average confusion matrix, rounded to the nearest integer (a plain int
    # cast would truncate, e.g. 12.8 -> 12) because the heatmap uses fmt='d'
    mean_cm = np.rint(np.mean(all_confusion_matrices, axis=0)).astype(int)

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        mean_cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Benigna', 'Maligna'],
        yticklabels=['Benigna', 'Maligna'],
        cbar_kws={'label': 'Average Count'}
    )
    plt.title(f'Average Confusion Matrix - {CONFIG["model_name"]}',
              fontsize=16, fontweight='bold')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')

    # Add metrics text
    metrics_text = f"""
    Mean Metrics:
    F1-Score: {np.mean(metrics_summary['f1']):.4f} ± {np.std(metrics_summary['f1']):.4f}
    Precision: {np.mean(metrics_summary['precision']):.4f} ± {np.std(metrics_summary['precision']):.4f}
    Recall: {np.mean(metrics_summary['recall']):.4f} ± {np.std(metrics_summary['recall']):.4f}
    Specificity: {np.mean(metrics_summary['specificity']):.4f} ± {np.std(metrics_summary['specificity']):.4f}
    """

    plt.text(
        2.5, 0.5, metrics_text,
        fontsize=10,
        verticalalignment='center',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    )

    mean_cm_path = METRICS_DIR / f"{CONFIG['model_name']}_mean_confusion_matrix.png"
    plt.tight_layout()
    plt.savefig(mean_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Mean confusion matrix saved: {mean_cm_path}")

    # Full visual summary
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'{CONFIG["k_folds"]}-Fold Cross Validation Summary - {CONFIG["model_name"]}',
                 fontsize=16, fontweight='bold')

    # Panel 1: F1-score per fold
    folds = range(1, CONFIG['k_folds'] + 1)
    axes[0, 0].bar(folds, metrics_summary['f1'], color='skyblue', alpha=0.7)
    axes[0, 0].axhline(
        np.mean(metrics_summary['f1']),
        color='red',
        linestyle='--',
        label=f"Mean: {np.mean(metrics_summary['f1']):.4f}"
    )
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('F1-Score')
    axes[0, 0].set_title('F1-Score per Fold')
    axes[0, 0].legend()
    axes[0, 0].grid(axis='y', alpha=0.3)

    # Panel 2: every metric per fold (grouped bars)
    x = np.arange(CONFIG['k_folds'])
    width = 0.15

    axes[0, 1].bar(x - 2*width, metrics_summary['f1'], width, label='F1', alpha=0.8)
    axes[0, 1].bar(x - width, metrics_summary['precision'], width, label='Precision', alpha=0.8)
    axes[0, 1].bar(x, metrics_summary['recall'], width, label='Recall', alpha=0.8)
    axes[0, 1].bar(x + width, metrics_summary['specificity'], width, label='Specificity', alpha=0.8)
    axes[0, 1].bar(x + 2*width, metrics_summary['accuracy'], width, label='Accuracy', alpha=0.8)

    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('Score')
    axes[0, 1].set_title('All Metrics per Fold')
    axes[0, 1].set_xticks(x)
    axes[0, 1].set_xticklabels([f'{i+1}' for i in range(CONFIG['k_folds'])])
    axes[0, 1].legend()
    axes[0, 1].grid(axis='y', alpha=0.3)

    # Panel 3: box plot of the metric distributions
    metrics_data = [metrics_summary[key] for key in metric_keys]

    bp = axes[1, 0].boxplot(
        metrics_data,
        labels=['F1', 'Precision', 'Recall', 'Specificity', 'Accuracy'],
        patch_artist=True
    )

    for patch in bp['boxes']:
        patch.set_facecolor('lightblue')

    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_title('Metrics Distribution')
    axes[1, 0].grid(axis='y', alpha=0.3)

    # Panel 4: text summary
    axes[1, 1].axis('tight')
    axes[1, 1].axis('off')

    summary_text = f"""
    MODEL: {CONFIG['model_name']}

    TRANSFER LEARNING CONFIGURATION:
    • Base Model: {CONFIG['pretrained_model']}
    • Input Size: {CONFIG['input_size']}×{CONFIG['input_size']}
    • Initial Backbone: Frozen
    • Unfreeze After: Epoch {CONFIG['unfreeze_after_epoch']}

    TRAINING:
    • Folds: {CONFIG['k_folds']}
    • Max Epochs: {CONFIG['num_epochs']}
    • Batch Size: {CONFIG['batch_size']}
    • Base LR: {CONFIG['base_lr']:.2e}
    • Backbone LR: {CONFIG['backbone_lr']:.2e}
    • Weight Decay: {CONFIG['weight_decay']}
    • Early Stop Patience: {CONFIG['early_stopping_patience']}

    RESULTS (Mean ± Std):
    • F1-Score:     {np.mean(metrics_summary['f1']):.4f} ± {np.std(metrics_summary['f1']):.4f}
    • Precision:    {np.mean(metrics_summary['precision']):.4f} ± {np.std(metrics_summary['precision']):.4f}
    • Recall:       {np.mean(metrics_summary['recall']):.4f} ± {np.std(metrics_summary['recall']):.4f}
    • Specificity:  {np.mean(metrics_summary['specificity']):.4f} ± {np.std(metrics_summary['specificity']):.4f}
    • Accuracy:     {np.mean(metrics_summary['accuracy']):.4f} ± {np.std(metrics_summary['accuracy']):.4f}

    ADVANTAGES OF TRANSFER LEARNING:
    ✓ Pre-trained on ImageNet-21k
    ✓ Faster convergence
    ✓ Better generalization
    ✓ Requires less data
    """

    axes[1, 1].text(
        0.1, 0.5, summary_text,
        transform=axes[1, 1].transAxes,
        fontsize=9,
        verticalalignment='center',
        fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3)
    )

    summary_plot_path = METRICS_DIR / f"{CONFIG['model_name']}_summary.png"
    plt.tight_layout()
    plt.savefig(summary_plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ Summary plot saved: {summary_plot_path}")

    # Save the full configuration
    config_path = OUTPUT_DIR / f"{CONFIG['model_name']}_config.json"
    with open(config_path, 'w') as f:
        json.dump(CONFIG, f, indent=4)
    print(f"✓ Configuration saved: {config_path}")

    print("\n" + "="*70)
    print("✓ ENTRENAMIENTO COMPLETADO - TRANSFER LEARNING")
    print("="*70)
    print(f"\nModelos guardados en: {OUTPUT_DIR}")
    print(f"Métricas guardadas en: {METRICS_DIR}")
    print("\nArchivos generados:")
    print(f"  • {CONFIG['k_folds']} modelos (.pth)")
    print(f"  • {CONFIG['k_folds']} gráficos de historial")
    print(f"  • {CONFIG['k_folds']} matrices de confusión")
    print(f"  • 1 tabla de métricas (CSV)")
    print(f"  • 1 matriz de confusión promedio")
    print(f"  • 1 resumen visual completo")
    print(f"  • 1 configuración (JSON)")
    print("="*70)

    # Comparison with the from-scratch model (only if its metrics file exists)
    vit_base_metrics_path = METRICS_DIR / "ViT_0_Base_metrics.csv"
    if vit_base_metrics_path.exists():
        print("\n" + "="*70)
        print("COMPARACIÓN: ViT desde cero vs Transfer Learning")
        print("="*70)

        vit_base_df = pd.read_csv(vit_base_metrics_path)

        # NOTE(review): rows are aligned by position; this assumes the other
        # CSV lists the same metrics in the same order — confirm.
        comparison_df = pd.DataFrame({
            'Metric': metrics_df['Metric'],
            'ViT_0_Base (Mean)': vit_base_df['Mean'],
            'ViT_TL_Base (Mean)': metrics_df['Mean'],
            'Improvement': metrics_df['Mean'] - vit_base_df['Mean']
        })

        print(comparison_df.to_string(index=False))
        print("="*70)

        # Save the comparison
        comparison_path = METRICS_DIR / "comparison_ViT_vs_ViT_TL.csv"
        comparison_df.to_csv(comparison_path, index=False)
        print(f"\n✓ Comparison saved: {comparison_path}")

In [None]:
# ============================================================================
# 8. EJECUTAR ENTRENAMIENTO
# ============================================================================

# Entry point: start the full K-Fold training run when this code is executed
# directly (it is skipped when the code is imported as a module).
if __name__ == "__main__":
    main()