In [None]:
# %%
# 1) Imports y setup
import os, json, math, random, warnings
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from torchvision import transforms
import torchvision.transforms.functional as TF

from transformers import ViTModel
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

# NOTE(review): this silences *all* warnings (incl. sklearn ill-defined-metric and
# torch deprecation warnings); consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Reproducibility
def set_seed(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs.

    Args:
        seed: value used for every RNG.
        deterministic: when True, force deterministic cuDNN kernels and turn off
            cuDNN autotuning (slower, but GPU runs become repeatable). The default
            False keeps cuDNN free to choose non-deterministic kernels, i.e. only
            the RNGs are seeded.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    if deterministic:
        # Autotuning may pick different kernels run-to-run, defeating determinism.
        torch.backends.cudnn.benchmark = False

set_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
if device.type == 'cuda':
    # The CUDA caching allocator reads this variable when it is initialised, so set it
    # before the first call that touches the GPU (get_device_name below initialises
    # CUDA); a value set afterwards may be ignored. Setting it before `import torch`
    # is the safest option.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    print("GPU:", torch.cuda.get_device_name(0))
    # cuDNN autotuning picks the fastest (possibly non-deterministic) kernels; do not
    # enable it if deterministic cuDNN was requested when seeding.
    torch.backends.cudnn.benchmark = not torch.backends.cudnn.deterministic
    # Allow TF32 tensor-core math for matmuls/convolutions (faster, slightly less precise).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


In [None]:
# %%
# 2) Paths y CONFIG
# Data/output root. Override with the MAMMO_BASE_DIR environment variable to run on
# another machine; the default is the original author's location.
BASE_DIR    = Path(os.environ.get("MAMMO_BASE_DIR", "/home/merivadeneira"))
MASAS_DIR   = BASE_DIR / "Masas"
OUTPUT_DIR  = BASE_DIR / "Outputs" / "ViT"
METRICS_DIR = BASE_DIR / "Metrics" / "ViT"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
METRICS_DIR.mkdir(parents=True, exist_ok=True)

CONFIG = {
    'model_name': 'ViT_TL_1',
    'pretrained_model': 'google/vit-base-patch16-224',
    'input_size': 224,
    'num_classes': 2,
    # training
    'k_folds': 5,
    'batch_size': 32,
    'max_epochs': 60,
    'early_stopping_patience': 15,      # monitored metric: validation F1
    # 'freeze_epochs' is informational; the training loop reads 'unfreeze_after_epoch'
    # (0-based epoch index at which the backbone is unfrozen). Keep them equal.
    'freeze_epochs': 5,                 # epochs during which only the head trains
    'unfreeze_after_epoch': 5,          # alias of freeze_epochs
    # optimisation
    'base_lr': 1e-4,                    # head learning rate
    'backbone_lr': 1e-5,                # backbone learning rate once unfrozen
    'weight_decay': 0.05,
    'grad_clip_norm': 1.0,
    # augmentation
    'rotation_degrees': 7,              # random rotation range, +/- degrees
    'use_flip': False,                  # set True if laterality does not matter
    # evaluation
    'use_tta': False,                   # optional test-time augmentation (flipped views)
    # normalisation (ImageNet statistics)
    'mean': [0.485, 0.456, 0.406],
    'std' : [0.229, 0.224, 0.225],
}
assert CONFIG['freeze_epochs'] == CONFIG['unfreeze_after_epoch'], \
    "freeze_epochs and unfreeze_after_epoch must agree"
print(json.dumps(CONFIG, indent=2))


In [None]:
# %%
# 3) Utilidades: carga de rutas y splits por paciente + dominio
def extract_patient_id(filename, database):
    """Derive a patient identifier from an image filename.

    The filename is split on '_': DDSM uses the second token, INbreast the
    first. The database name is prefixed so IDs cannot collide across
    databases. Anything else (unknown database, or a DDSM name without a
    second token) falls back to the filename itself.
    """
    tokens = filename.split('_')
    if database == 'DDSM' and len(tokens) > 1:
        return f"DDSM_{tokens[1]}"
    if database == 'INbreast':
        # str.split always yields at least one token.
        return f"INbreast_{tokens[0]}"
    return filename

def load_image_paths_with_patient_ids(masas_dir=None):
    """Index every mass image under ``<root>/<db>/<Benignas|Malignas>/Resized_512``.

    Args:
        masas_dir: dataset root; defaults to the module-level ``MASAS_DIR``.

    Returns:
        DataFrame (default RangeIndex) with one row per ``*.png`` and columns
        ``image_path`` (str), ``patient_id``, ``label`` (0 = Benignas,
        1 = Malignas) and ``database`` ('DDSM' or 'INbreast'). Missing
        sub-folders are skipped.

    Raises:
        FileNotFoundError: if no image is found at all.
    """
    root = Path(masas_dir) if masas_dir is not None else MASAS_DIR
    rows = []
    for db in ['DDSM', 'INbreast']:
        for label_name, label_val in [('Benignas', 0), ('Malignas', 1)]:
            p = root / db / label_name / "Resized_512"
            if not p.exists():
                continue
            # glob() order is filesystem-dependent; sorting makes row order (and so
            # everything indexed by row position, e.g. the seeded sampler) reproducible.
            for img_file in sorted(p.glob("*.png")):
                pid = extract_patient_id(img_file.name, db)
                rows.append({'image_path': str(img_file),
                             'patient_id': pid,
                             'label': label_val,
                             'database': db})
    if not rows:
        # An empty frame has no columns, which would otherwise fail with an
        # unrelated AttributeError on `df.patient_id` below.
        raise FileNotFoundError(f"No se encontraron imágenes .png bajo {root}")
    df = pd.DataFrame(rows)
    print(f"Total imágenes: {len(df)} | Pacientes únicos: {df.patient_id.nunique()}")
    print(df.groupby(['database','label']).size().unstack(fill_value=0))
    return df

def create_patient_level_splits_stratified(data_df, k_folds=5, random_state=42):
    """Patient-level K-fold splits stratified by (label, database).

    Each patient gets its majority image label and majority database (ties go
    to the smallest value, since ``Series.mode`` returns sorted modes), and
    ``StratifiedKFold`` is applied to patients, so no patient appears in both
    the train and the val part of a fold.

    Args:
        data_df: table with ``patient_id``, ``label`` and ``database`` columns.
        k_folds: number of folds.
        random_state: seed for the fold shuffle.

    Returns:
        list of ``(train_idx, val_idx)`` pairs; each is a list of *positional*
        row indices into ``data_df`` (suitable for ``.iloc``).
    """
    def _majority(values):
        # mode() is empty only for an empty group; keep the original fallback.
        modes = values.mode()
        return modes.iloc[0] if len(modes) else values.iloc[0]

    grp = data_df.groupby('patient_id').agg(
        label=('label', lambda x: int(_majority(x))),
        database=('database', _majority),
    ).reset_index()
    grp['strata'] = grp['label'].astype(str) + "_" + grp['database']
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=random_state)

    patient_ids = data_df['patient_id']
    splits = []
    for tr_idx, va_idx in skf.split(grp['patient_id'], grp['strata']):
        tr_pat = set(grp['patient_id'].iloc[tr_idx])
        va_pat = set(grp['patient_id'].iloc[va_idx])
        assert tr_pat.isdisjoint(va_pat), "patient leakage between train and val"
        # Positions, not index labels: downstream code selects rows with .iloc, and
        # labels only coincide with positions when data_df has a default RangeIndex.
        train_idx = np.flatnonzero(patient_ids.isin(tr_pat).to_numpy()).tolist()
        val_idx   = np.flatnonzero(patient_ids.isin(va_pat).to_numpy()).tolist()
        splits.append((train_idx, val_idx))
        print(f"Fold -> train imgs: {len(train_idx)} | val imgs: {len(val_idx)}")
    return splits


In [None]:
# %%
# 4) Dataset y transforms
class MammographyDataset(Dataset):
    """Map-style dataset over a subset of rows of an image-index DataFrame.

    Item ``i`` is ``(transform(image), label)``; the PNG is read as grayscale
    and replicated to 3-channel RGB before ``transform`` is applied.
    """
    def __init__(self, df, indices, transform):
        # Positional row selection; reset_index so __getitem__ can use 0..len-1.
        self.df = df.iloc[indices].reset_index(drop=True)
        self.t  = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        p = self.df.loc[i, 'image_path']
        y = int(self.df.loc[i, 'label'])
        # The context manager closes the file handle deterministically; convert()
        # returns an independent in-memory image, so it is safe after closing.
        with Image.open(p) as im:
            # Grayscale -> 3 channels (the pretrained backbone expects RGB input).
            img = im.convert('L').convert('RGB')
        x = self.t(img)
        return x, y

def make_transforms(cfg, train=True):
    """Build the torchvision preprocessing pipeline.

    Both modes start with grayscale -> 3 channels and end with ToTensor plus
    ``Normalize(cfg['mean'], cfg['std'])``.
    - train: RandomResizedCrop to ``cfg['input_size']`` (90-100% of the area,
      aspect ratio 0.95-1.05), RandomRotation of ``cfg['rotation_degrees']``,
      and a horizontal flip when ``cfg['use_flip']`` is set.
    - eval: deterministic resize to a ``cfg['input_size']`` square.
    """
    size = cfg['input_size']
    prefix = [transforms.Grayscale(num_output_channels=3)]
    suffix = [
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg['mean'], std=cfg['std']),
    ]

    if not train:
        return transforms.Compose(prefix + [transforms.Resize((size, size))] + suffix)

    augment = [
        transforms.RandomResizedCrop(size=size, scale=(0.90, 1.0), ratio=(0.95, 1.05)),
        transforms.RandomRotation(cfg['rotation_degrees']),
    ]
    if cfg['use_flip']:
        augment.append(transforms.RandomHorizontalFlip(p=0.5))
    return transforms.Compose(prefix + augment + suffix)


In [None]:
# %%
# 5) Modelo: ViT + pooling híbrido (CLS + mean)
class ViTForMammography(nn.Module):
    """HuggingFace ViT backbone with a hybrid CLS + mean-token pooling head.

    The [CLS] token and the mean of the remaining tokens from the last hidden
    layer are concatenated, projected back to the hidden size (``proj``) and
    classified by a LayerNorm/Dropout/MLP head (``classifier``).
    """
    def __init__(self, pretrained_model_name, num_classes=2, freeze_backbone=True):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)
        if freeze_backbone:
            # Disable gradients for every backbone parameter (head-only training).
            self.vit.requires_grad_(False)
        hidden = self.vit.config.hidden_size
        bottleneck = int(hidden * 0.75)
        self.proj = nn.Linear(2 * hidden, hidden)  # fuses [CLS ; mean] -> hidden
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Dropout(0.2),
            nn.Linear(hidden, bottleneck),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, num_classes),
        )

    def unfreeze_backbone(self):
        """Re-enable gradients for every backbone parameter."""
        self.vit.requires_grad_(True)

    def forward(self, x):
        """Return class logits for pixel batch ``x``."""
        tokens = self.vit(pixel_values=x).last_hidden_state   # [B, tokens, hidden]
        pooled = torch.cat((tokens[:, 0], tokens[:, 1:].mean(1)), dim=1)
        return self.classifier(self.proj(pooled))


In [None]:
# %%
# 6) Métricas, pérdida con pesos, TTA opcional, validación y utilidades
def compute_metrics(y_true, y_prob, thr=0.5):
    """Binary classification metrics at a probability threshold.

    Args:
        y_true: ground-truth labels in {0, 1} (any array-like).
        y_prob: predicted probability of class 1 (any array-like, same length).
        thr: a sample is predicted positive when ``y_prob >= thr``.

    Returns:
        dict with ``accuracy``, ``f1``, ``precision``, ``recall`` (sensitivity),
        ``specificity``, the 2x2 confusion matrix ``cm`` laid out as
        [[tn, fp], [fn, tp]] (rows = true label, columns = prediction) and the
        four cells ``tn``, ``fp``, ``fn``, ``tp``. Ratios with a zero
        denominator are reported as 0.0 (sklearn's ``zero_division=0``).
    """
    # asarray: lets callers pass lists as well as NumPy arrays.
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= thr).astype(int)

    # One confusion count drives every metric below (sklearn confusion_matrix layout).
    cm = np.array([
        [np.sum((y_true == 0) & (y_pred == 0)), np.sum((y_true == 0) & (y_pred == 1))],
        [np.sum((y_true == 1) & (y_pred == 0)), np.sum((y_true == 1) & (y_pred == 1))],
    ], dtype=np.int64)
    tn, fp, fn, tp = cm.ravel()

    def _safe_div(num, den):
        return float(num) / float(den) if den > 0 else 0.0

    acc  = float(np.mean(y_pred == y_true))
    prec = _safe_div(tp, tp + fp)
    rec  = _safe_div(tp, tp + fn)
    f1   = _safe_div(2 * tp, 2 * tp + fp + fn)   # == 2PR/(P+R)
    spec = _safe_div(tn, tn + fp)
    return dict(accuracy=acc, f1=f1, precision=prec, recall=rec,
                specificity=spec, cm=cm, tn=tn, fp=fp, fn=fn, tp=tp)

@torch.no_grad()
def infer_logits(model, loader, use_tta=False):
    """Run ``model`` over ``loader`` and return class-1 probabilities.

    Args:
        model: module returning logits of shape [B, num_classes]; column 1 is
            taken as the positive class.
        loader: iterable of ``(x, y)`` batches, ``y`` a CPU label tensor.
        use_tta: if True, average logits over the original batch, its flip
            along the last axis and its flip along the second-to-last axis.

    Returns:
        ``(probs, labels)``: float32 NumPy array of P(class 1) per sample and
        the concatenated NumPy labels, both in loader order (empty arrays for
        an empty loader).
    """
    model.eval()
    # Run wherever the model lives; fall back to the notebook-wide `device` for a
    # parameter-less model.
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        run_device = device
    use_amp = run_device.type == 'cuda'
    # NOTE(review): TTA always includes a horizontal and a vertical flip, even though
    # CONFIG['use_flip'] disables flips in training (laterality) — confirm intent.
    flip_dims = [None, [-1], [-2]] if use_tta else [None]

    all_logits, all_labels = [], []
    for x, y in loader:
        x = x.to(run_device)
        views = []
        for dims in flip_dims:
            x_view = x if dims is None else torch.flip(x, dims=dims)
            with autocast(enabled=use_amp):
                out = model(x_view)
            # AMP produces float16 logits; average and softmax them in float32
            # (CPU half softmax is unsupported on older torch and loses precision).
            views.append(out.float())
        logits = views[0] if len(views) == 1 else torch.stack(views, dim=0).mean(0)
        all_logits.append(logits.cpu())
        all_labels.append(y.numpy())

    if not all_logits:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.int64)
    probs = torch.cat(all_logits, dim=0).softmax(1)[:, 1].numpy()
    return probs, np.concatenate(all_labels, axis=0)

def find_best_threshold(y_true, y_prob):
    """Return ``(threshold, f1)`` maximising F1 over a fixed threshold grid.

    The grid has 37 evenly spaced values from 0.05 to 0.95 (step 0.025); a
    sample is predicted positive when ``y_prob >= threshold``. On ties the
    lowest threshold wins.
    NOTE(review): if the same split is later evaluated at this threshold, the
    resulting metrics are optimistically biased.
    """
    best_thr, best_f1 = None, None
    for thr in np.linspace(0.05, 0.95, 37):
        score = f1_score(y_true, (y_prob >= thr).astype(int), zero_division=0)
        # Strict '>' keeps the first (lowest) threshold among equal scores.
        if best_f1 is None or score > best_f1:
            best_thr, best_f1 = thr, score
    return float(best_thr), float(best_f1)

def make_class_weights(df_indices, data_df):
    """Per-sample weights that balance classes (for a WeightedRandomSampler).

    Each row of ``data_df.iloc[df_indices]`` is weighted by 1 / (number of
    selected rows sharing its label); the weights are then rescaled to sum to
    the number of samples (mean weight 1).

    Returns:
        1-D float32 tensor aligned with ``df_indices``.
    """
    labels = data_df['label'].iloc[df_indices].to_numpy()
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    inv_freq = 1.0 / np.maximum(counts, 1)
    sample_w = inv_freq[inverse].astype(np.float32)
    sample_w = sample_w / sample_w.sum() * len(sample_w)
    return torch.as_tensor(sample_w, dtype=torch.float32)


In [None]:
# %%
# 7) Entrenamiento por fold (AMP + clipping + early stop por F1)
class EarlyStopping:
    """Patience-based early stopping on a monitored scalar.

    With ``mode='max'`` a strictly larger value is an improvement; any other
    mode treats a strictly smaller value as an improvement. After
    ``patience`` consecutive non-improving calls, ``stop`` latches to True
    (an improvement resets the counter but does not clear ``stop``).
    """
    def __init__(self, patience=15, mode='max'):
        self.patience = patience
        self.mode = mode
        self.best = -np.inf if mode == 'max' else np.inf
        self.count = 0
        self.stop = False

    def _is_improvement(self, value):
        """True when ``value`` beats the best value seen so far."""
        if self.mode == 'max':
            return value > self.best
        return value < self.best

    def __call__(self, value):
        """Record ``value`` and return whether training should stop."""
        if self._is_improvement(value):
            self.best, self.count = value, 0
            return self.stop
        self.count += 1
        if self.count >= self.patience:
            self.stop = True
        return self.stop

def train_one_epoch(model, loader, criterion, optimizer, scaler, grad_clip):
    """Train ``model`` for one pass over ``loader`` with AMP and optional clipping.

    Args:
        model: module in the global ``device``; put in train mode here.
        loader: iterable of ``(x, y)`` batches.
        criterion: loss taking ``(logits, y)``.
        optimizer: optimizer over the trainable parameters.
        scaler: ``GradScaler`` (a disabled scaler makes AMP a no-op on CPU).
        grad_clip: max global gradient norm, or None to disable clipping.

    Returns:
        ``(mean_loss, accuracy)`` over the samples seen; each batch's loss is
        weighted by its batch size.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x, y in tqdm(loader, leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast(enabled=(device.type=='cuda')):
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        if grad_clip is not None:
            # Unscale first so the clipping threshold applies to the true gradients;
            # scaler.step() then skips its own unscale.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()*x.size(0)
        total += y.size(0)
        total_correct += (logits.argmax(1)==y).sum().item()
    if total == 0:
        raise ValueError("train_one_epoch: el loader no produjo ninguna muestra")
    return total_loss/total, total_correct/total


In [None]:
# %%
# 8) Bucle de entrenamiento de un fold (incluye mejor modelo por F1)
def _make_fold_loaders(train_idx, val_idx, data_df, cfg):
    """Build the (train, val, train-eval) DataLoaders for one fold.

    The train loader applies augmentations and draws class-balanced samples
    with replacement; the val and train-eval loaders use the deterministic
    eval transforms in fixed order.
    """
    t_train = make_transforms(cfg, train=True)
    t_val   = make_transforms(cfg, train=False)

    weights = make_class_weights(train_idx, data_df)
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    common = dict(batch_size=cfg['batch_size'], num_workers=4, pin_memory=True, drop_last=False)
    dl_tr = DataLoader(MammographyDataset(data_df, train_idx, t_train), sampler=sampler, **common)
    dl_va = DataLoader(MammographyDataset(data_df, val_idx, t_val), shuffle=False, **common)
    dl_tr_eval = DataLoader(MammographyDataset(data_df, train_idx, t_val), shuffle=False, **common)
    return dl_tr, dl_va, dl_tr_eval


def _make_criterion(train_idx, data_df, cfg):
    """Cross-entropy loss, optionally weighted by inverse class frequency."""
    # The WeightedRandomSampler already balances the classes the model sees; also
    # weighting the loss by inverse frequency corrects the imbalance a second time and
    # biases the model towards the minority class. Loss weighting is therefore opt-in
    # via cfg['class_weighted_loss'].
    if not cfg.get('class_weighted_loss', False):
        return nn.CrossEntropyLoss()
    n_cls = cfg.get('num_classes', 2)
    y_tr = data_df['label'].iloc[train_idx].to_numpy()
    cls_counts = np.bincount(y_tr, minlength=n_cls).astype(float)
    inv = 1.0 / np.maximum(cls_counts, 1.0)
    weight = torch.tensor(inv / inv.sum() * n_cls, dtype=torch.float32, device=device)
    return nn.CrossEntropyLoss(weight=weight)


def _make_finetune_optimizer(model, cfg):
    """AdamW over the whole model, with a smaller learning rate for the ViT backbone."""
    return AdamW([
        {'params': model.vit.parameters(), 'lr': cfg['backbone_lr']},
        {'params': model.proj.parameters(), 'lr': cfg['base_lr']},
        {'params': model.classifier.parameters(), 'lr': cfg['base_lr']},
    ], weight_decay=cfg['weight_decay'])


def train_fold(fold_idx, train_idx, val_idx, data_df, cfg):
    """Train and evaluate one cross-validation fold.

    The schedule has two phases. Until epoch ``cfg['unfreeze_after_epoch']``
    (0-based) only the fusion projection and classifier head are trained.
    From then on the ViT backbone is unfrozen and trained with
    ``cfg['backbone_lr']``. Model selection and early stopping both use
    validation F1 at a 0.5 threshold. Afterwards the best weights are
    restored, a decision threshold is tuned on the validation split, and the
    checkpoint is saved to ``OUTPUT_DIR``.

    Args:
        fold_idx: 0-based fold number (logs and checkpoint name).
        train_idx, val_idx: positional row indices into ``data_df``.
        data_df: image index with at least ``image_path`` and ``label``.
        cfg: configuration dict (see CONFIG). Optional key
            ``class_weighted_loss`` (default False) additionally weights the
            cross-entropy by inverse class frequency.

    Returns:
        ``(model, history, train_metrics, val_metrics)``. Both metric dicts
        come from ``compute_metrics`` at the tuned threshold, stored under
        ``'threshold'``.
    """
    print(f"\n========== FOLD {fold_idx+1} ==========")
    dl_tr, dl_va, dl_tr_eval = _make_fold_loaders(train_idx, val_idx, data_df, cfg)

    model = ViTForMammography(pretrained_model_name=cfg['pretrained_model'],
                              num_classes=cfg.get('num_classes', 2),
                              freeze_backbone=True).to(device)

    # Phase 1: only the fusion projection + classifier head are optimised.
    head_params = list(model.proj.parameters()) + list(model.classifier.parameters())
    optimizer = AdamW(head_params, lr=cfg['base_lr'], weight_decay=cfg['weight_decay'])
    scaler = GradScaler(enabled=(device.type=='cuda'))
    criterion = _make_criterion(train_idx, data_df, cfg)

    history = {'train_loss':[], 'train_acc':[], 'val_f1':[], 'val_acc':[]}
    es = EarlyStopping(patience=cfg['early_stopping_patience'], mode='max')
    best_state, best_f1, best_epoch = None, -1.0, 0

    for epoch in range(cfg['max_epochs']):
        print(f"\nEpoch {epoch+1}/{cfg['max_epochs']}")
        if epoch == cfg['unfreeze_after_epoch']:
            # Phase 2: fine-tune everything with a fresh, param-grouped optimizer.
            print(">> Descongelando backbone para fine-tuning")
            model.unfreeze_backbone()
            optimizer = _make_finetune_optimizer(model, cfg)

        tr_loss, tr_acc = train_one_epoch(model, dl_tr, criterion, optimizer, scaler, cfg['grad_clip_norm'])
        history['train_loss'].append(tr_loss); history['train_acc'].append(tr_acc)

        # Validation at a fixed 0.5 threshold drives model selection and early stopping.
        probs, y_true = infer_logits(model, dl_va, use_tta=cfg['use_tta'])
        mets = compute_metrics(y_true, probs, thr=0.5)
        history['val_f1'].append(mets['f1']); history['val_acc'].append(mets['accuracy'])

        print(f"Train: loss {tr_loss:.4f} | acc {tr_acc:.4f}")
        print(f"Val  : F1 {mets['f1']:.4f} | acc {mets['accuracy']:.4f} | prec {mets['precision']:.4f} | rec {mets['recall']:.4f} | spec {mets['specificity']:.4f}")

        if mets['f1'] > best_f1:
            best_f1, best_epoch = mets['f1'], epoch + 1
            # CPU copies so later epochs cannot overwrite the snapshot.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            print("✓ Nuevo mejor modelo (por F1)")

        if es(mets['f1']):
            print(f"✓ Early stopping (paciencia agotada). Último mejor F1 en epoch {best_epoch}")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # NOTE(review): the threshold is tuned on the same validation split whose metrics
    # are reported, so these val metrics are optimistically biased.
    probs_val, y_val = infer_logits(model, dl_va, use_tta=cfg['use_tta'])
    thr_opt, _ = find_best_threshold(y_val, probs_val)
    mets_val = compute_metrics(y_val, probs_val, thr=thr_opt)
    mets_val['threshold'] = thr_opt

    # Train metrics on the actual training images with eval transforms and no
    # resampling; the sampler-driven training loader repeats/skips images and applies
    # random augmentations, so it does not describe the training set.
    probs_tr, y_tr_full = infer_logits(model, dl_tr_eval, use_tta=cfg['use_tta'])
    mets_tr = compute_metrics(y_tr_full, probs_tr, thr=thr_opt)
    mets_tr['threshold'] = thr_opt

    model_name = cfg.get('model_name', CONFIG['model_name'])
    model_path = OUTPUT_DIR / f"{model_name}_fold{fold_idx}.pth"
    torch.save({'fold': fold_idx, 'state_dict': model.state_dict(), 'config': cfg,
                'threshold': thr_opt, 'best_epoch': best_epoch}, model_path)
    print("Modelo guardado:", model_path)

    return model, history, mets_tr, mets_val


In [None]:
# %%
# 9) Orquestador K-Fold + escritura del CSV ViT_TL_1_metrics
def _metrics_row(fold, split, m, epochs):
    """Build one row of the per-fold metrics CSV from a compute_metrics() dict."""
    return {
        'fold': fold, 'split': split,
        'loss': np.nan,                  # placeholder column; no loss is recorded here
        'accuracy': m['accuracy'],
        'precision': m['precision'],
        'recall': m['recall'],
        'specificity': m['specificity'],
        'f1': m['f1'],
        'threshold_used': m.get('threshold', 0.5),
        'tn': m['tn'], 'fp': m['fp'], 'fn': m['fn'], 'tp': m['tp'],
        'samples': (m['tn'] + m['fp'] + m['fn'] + m['tp']),
        'epochs_trained': epochs,
        'backbone_unfreeze_epoch': CONFIG['unfreeze_after_epoch'],
        'tta': CONFIG['use_tta'],
    }


def main():
    """Run patient-level stratified K-fold training and write all artefacts.

    Writes the following:
    - to METRICS_DIR: the per-fold metrics CSV with one MEAN row per split
      appended, a mean/std summary CSV, and the mean validation
      confusion-matrix heatmap;
    - to OUTPUT_DIR: CONFIG as JSON (checkpoints are saved by ``train_fold``).
    """
    print(f"\n=== INICIO ENTRENAMIENTO: {CONFIG['model_name']} ===")
    data_df = load_image_paths_with_patient_ids()
    splits = create_patient_level_splits_stratified(data_df, k_folds=CONFIG['k_folds'])

    rows, all_val_cms = [], []
    for fold_idx, (tr_idx, va_idx) in enumerate(splits):
        torch.cuda.empty_cache()
        model, hist, mets_tr, mets_va = train_fold(fold_idx, tr_idx, va_idx, data_df, CONFIG)
        # The checkpoint is already on disk. Dropping the reference lets the
        # empty_cache() call at the start of the next iteration release this fold's
        # weights instead of keeping them allocated while the next fold trains.
        del model
        epochs = len(hist['train_loss'])
        rows.append(_metrics_row(fold_idx + 1, 'train', mets_tr, epochs))
        rows.append(_metrics_row(fold_idx + 1, 'val',   mets_va, epochs))
        all_val_cms.append(mets_va['cm'])

    fold_df = pd.DataFrame(rows)
    metric_cols = ['accuracy', 'precision', 'recall', 'specificity', 'f1']

    # Mean and std per split.
    summary = fold_df.groupby('split')[metric_cols].agg(['mean', 'std'])
    summary_path = METRICS_DIR / f"{CONFIG['model_name']}_metrics_summary.csv"
    summary.to_csv(summary_path)
    print("Resumen (mean ± std) guardado en:", summary_path)

    # Per-fold rows followed by one MEAN row per split, written in a single pass.
    averaged_cols = metric_cols + ['threshold_used', 'tn', 'fp', 'fn', 'tp',
                                   'samples', 'epochs_trained']
    mean_rows = [
        {'fold': 'MEAN', 'split': split, 'loss': np.nan,
         **fold_df.loc[fold_df['split'] == split, averaged_cols].mean().to_dict(),
         'backbone_unfreeze_epoch': CONFIG['unfreeze_after_epoch'],
         'tta': CONFIG['use_tta']}
        for split in ['train', 'val']
    ]
    csv_path = METRICS_DIR / f"{CONFIG['model_name']}_metrics.csv"
    pd.concat([fold_df, pd.DataFrame(mean_rows)], ignore_index=True).to_csv(csv_path, index=False)
    print("CSV de métricas por fold guardado en:", csv_path)

    # Mean validation confusion matrix across folds (rows = true, cols = predicted).
    mean_cm = np.mean(np.stack(all_val_cms, axis=0), axis=0)
    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(mean_cm, annot=True, fmt='.1f', cmap='Blues', ax=ax,
                xticklabels=['Benigna','Maligna'], yticklabels=['Benigna','Maligna'])
    ax.set(title=f'Average Confusion Matrix (VAL) — {CONFIG["model_name"]}',
           xlabel='Predicted label', ylabel='True label')
    out_cm = METRICS_DIR / f"{CONFIG['model_name']}_mean_confusion_matrix.png"
    fig.tight_layout(); fig.savefig(out_cm, dpi=300); plt.close(fig)
    print("Matriz de confusión promedio guardada en:", out_cm)

    with open(OUTPUT_DIR / f"{CONFIG['model_name']}_config.json", 'w') as fh:
        json.dump(CONFIG, fh, indent=2)

    print("\n=== ENTRENAMIENTO COMPLETADO ===")

if __name__ == "__main__":
    main()
