In [None]:
# %%
# Install the third-party packages used by this notebook (quiet mode).
# NOTE(review): versions are not pinned, so a rerun may pick up different releases.
!pip install -q transformers datasets pillow timm

print("✓ Dependencias instaladas")


In [None]:
# %%
import os, gc, json, time, warnings, math, random
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from torchvision import transforms
from PIL import Image

from transformers import CvtModel, CvtConfig
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             confusion_matrix, roc_curve, auc, precision_recall_curve)
warnings.filterwarnings("ignore")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
if device.type == 'cuda':
    print("GPU:", torch.cuda.get_device_name(0))
    # Release cached CUDA memory and run the garbage collector before any model is built.
    torch.cuda.empty_cache(); gc.collect()

# Main configuration (adjust the paths if needed)
CONFIG = {
    # Paths
    'ddsm_benign_path': '/home/merivadeneira/Masas/DDSM/Benignas/Resized_512',
    'ddsm_malign_path': '/home/merivadeneira/Masas/DDSM/Malignas/Resized_512',
    'inbreast_benign_path': '/home/merivadeneira/Masas/INbreast/Benignas/Resized_512',
    'inbreast_malign_path': '/home/merivadeneira/Masas/INbreast/Malignas/Resized_512',
    'output_dir': '/home/merivadeneira/Outputs/CvT',
    'metrics_dir': '/home/merivadeneira/Metrics/CvT',

    # Model
    'model_name': 'CvT_TL_1_HF',
    'pretrained_model': 'microsoft/cvt-13',
    'input_size': 512,
    'num_classes': 2,

    # Training
    'batch_size': 16,            # raise to 32 if the GPU allows it
    'num_epochs': 100,
    'num_folds': 5,
    'early_stopping_patience': 12,
    'min_delta': 1e-4,
    'use_amp': True,

    # Learning rates per parameter group
    'lr_head': 3e-4,
    'lr_last': 1e-4,
    'lr_rest': 3e-5,

    # Warmup + unfreezing
    'warmup_epochs': 3,

    # Scheduler (cosine restarts)
    'eta_min': 1e-6,
    'T_0': 10,
    'T_mult': 1,

    # Optimizer
    'weight_decay': 0.01,
    'betas': (0.9, 0.999),

    # Label smoothing and focal-loss toggle
    'label_smoothing': 0.1,
    'use_focal': False,
    'focal_gamma': 1.5,

    # Augmentations
    'hflip_p': 0.5,
    'vflip_p': 0.0,          # disabled (NOTE(review): not read by get_transforms in the visible code)
    'rot_deg': 10,
    'brightness': 0.1,
    'contrast': 0.1,
    'random_erasing_p': 0.1,
    'mixup_p': 0.2,
    'cutmix_p': 0.2,

    # Test-time augmentation
    'tta_do': True,

    # Reproducibility
    'seed': 42,
    'num_workers': 4,
    'pin_memory': True,
}

# Create the checkpoint and metrics directories (no error if they already exist).
os.makedirs(CONFIG['output_dir'], exist_ok=True)
os.makedirs(CONFIG['metrics_dir'], exist_ok=True)
# Save the configuration used for this run next to the metrics, as "<model_name>_config.json".
with open(os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_config.json"), "w") as f:
    json.dump(CONFIG, f, indent=2)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Args:
        seed: Integer seed handed to each random number generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
# Seed every RNG once, using the seed stored in the configuration.
set_seed(CONFIG['seed'])
print("✓ Config listo")


In [None]:
# %%
def extract_patient_id(filename, dataset='ddsm'):
    """Derive a patient identifier from an image file name.

    Args:
        filename: File name (not a full path) of the image.
        dataset: 'ddsm' joins the first two underscore-separated tokens;
            'inbreast' keeps only the first token.

    Returns:
        The derived identifier, or ``filename`` unchanged when the dataset is
        unknown or a DDSM name has fewer than two tokens.
    """
    if dataset == 'ddsm':
        leading = filename.split('_', 2)[:2]
        if len(leading) == 2:
            return '_'.join(leading)
    elif dataset == 'inbreast':
        return filename.partition('_')[0]
    return filename

def load_dataset():
    """Collect image paths, labels and patient ids from the configured folders.

    Each of the four folders named in ``CONFIG`` (DDSM/INbreast x benign/malignant)
    is scanned for ``.png``/``.jpg``/``.jpeg`` files. Benign images get label 0 and
    malignant images label 1. A missing folder is reported and skipped.

    Returns:
        Three parallel lists: ``image_paths``, ``labels``, ``patient_ids``.
    """
    image_paths, labels, patient_ids = [], [], []
    datasets = [
        (CONFIG['ddsm_benign_path'], 0, 'ddsm'),
        (CONFIG['ddsm_malign_path'], 1, 'ddsm'),
        (CONFIG['inbreast_benign_path'], 0, 'inbreast'),
        (CONFIG['inbreast_malign_path'], 1, 'inbreast'),
    ]
    image_exts = ('.png', '.jpg', '.jpeg')
    for path, label, name in datasets:
        if not os.path.exists(path):
            print("WARNING path not found:", path); continue
        # os.listdir returns entries in arbitrary order; sorting keeps the sample
        # order (and therefore the seeded fold split built from it) reproducible.
        files = sorted(f for f in os.listdir(path) if f.lower().endswith(image_exts))
        for fn in files:
            image_paths.append(os.path.join(path, fn))
            labels.append(label)
            patient_ids.append(extract_patient_id(fn, name))
    return image_paths, labels, patient_ids

class MammographyDataset(Dataset):
    """Dataset yielding ``(image, label)`` pairs read from image files.

    Args:
        paths: Sequence of image file paths.
        labels: Sequence of labels aligned with ``paths``.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, labels, transform=None):
        self.paths, self.labels, self.transform = paths, labels, transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # Force single-channel grayscale first, then replicate it to three channels.
        grayscale = Image.open(self.paths[i]).convert('L')
        img = grayscale.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]

# Per-channel normalisation statistics (presumably the ImageNet values expected
# by the pretrained backbone; confirm against the checkpoint's preprocessing).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def get_transforms(train=True):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: When true, add the random augmentations configured in ``CONFIG``
            (horizontal flip, rotation, brightness/contrast jitter, random erasing).

    Returns:
        A ``transforms.Compose`` that ends with a normalised tensor.
    """
    to_normalised_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if not train:
        return transforms.Compose(to_normalised_tensor)

    # These augmentations run on the PIL image, before the tensor conversion.
    pil_augmentations = [
        transforms.RandomHorizontalFlip(p=CONFIG['hflip_p']),
        transforms.RandomRotation(CONFIG['rot_deg'], fill=0),
        transforms.ColorJitter(brightness=CONFIG['brightness'], contrast=CONFIG['contrast']),
    ]
    # Random erasing is applied last, on the normalised tensor.
    tensor_augmentations = [transforms.RandomErasing(p=CONFIG['random_erasing_p'])]
    return transforms.Compose(pil_augmentations + to_normalised_tensor + tensor_augmentations)

# Scan the configured folders once; the three lists are parallel (one entry per image).
image_paths, labels, patient_ids = load_dataset()
# Summary: image count, per-class counts (0 = benign, 1 = malignant) and number of distinct patient ids.
print(f"Total imágenes: {len(image_paths)} | Benign: {labels.count(0)} | Malignant: {labels.count(1)} | Pacientes: {len(set(patient_ids))}")


In [None]:
# %%
class GeM(nn.Module):
    """Generalised-mean pooling with a learnable exponent ``p``.

    Args:
        p: Initial value of the exponent.
        eps: Lower clamp applied to the input before the power is taken.
    """

    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool ``x`` to shape (B, C).

        Accepts either a (B, C, H, W) feature map or a (B, L, C) token sequence.
        """
        if x.dim() == 3:
            # (B, L, C) -> (B, C, L, 1) so the same reduction covers both layouts.
            x = x.transpose(1, 2).unsqueeze(-1)
        clamped = x.clamp(min=self.eps)
        powered_mean = clamped.pow(self.p).mean(dim=(-1, -2))
        return powered_mean.pow(1. / self.p)

class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Args:
        weight: Optional per-class weight tensor, forwarded to ``F.cross_entropy``.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        reduction: 'mean', 'sum' or 'none'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, weight=None, gamma=1.5, reduction='mean'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss for ``logits`` (N, C) and class-index ``targets`` (N,)."""
        # pt must be the model's probability for the true class, so it is taken
        # from the *unweighted* cross-entropy. Using the weighted value would
        # turn pt into p ** w and distort the modulating factor.
        plain_ce = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-plain_ce)
        if self.weight is None:
            weighted_ce = plain_ce
        else:
            weighted_ce = F.cross_entropy(logits, targets, weight=self.weight, reduction='none')
        loss = (1 - pt) ** self.gamma * weighted_ce
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

class ModelEmaV2:
    """Exponential moving average (EMA) copy of a model's state.

    Args:
        model: Source module; a deep copy of it becomes ``self.ema``.
        decay: Weight kept from the previous EMA value on each ``update`` call.
    """

    def __init__(self, model, decay=0.999):
        import copy  # local import: only needed to clone the source model

        # A deep copy keeps the architecture, weights and device of ``model``
        # whatever its constructor arguments were. Rebuilding it with
        # ``type(model)()`` only works for modules whose defaults happen to match.
        self.ema = copy.deepcopy(model)
        self.ema.eval()
        self.decay = decay
        for p in self.ema.parameters(): p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy, in place."""
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.copy_(v * self.decay + msd[k] * (1. - self.decay))
            else:
                # Integer buffers (e.g. step counters) cannot be averaged; mirror them.
                v.copy_(msd[k])

def rand_bbox(W, H, lam):
    """Sample a random box covering roughly ``1 - lam`` of a ``W`` x ``H`` area.

    Args:
        W: Width of the area.
        H: Height of the area.
        lam: Mixing coefficient; the box sides scale with ``sqrt(1 - lam)``.

    Returns:
        ``(x1, y1, x2, y2)`` clipped to the area bounds.
    """
    side_ratio = np.sqrt(1. - lam)
    half_w = int(W * side_ratio) // 2
    half_h = int(H * side_ratio) // 2
    # The centre is drawn uniformly: x first, then y.
    center_x = np.random.randint(W)
    center_y = np.random.randint(H)
    left = np.clip(center_x - half_w, 0, W)
    top = np.clip(center_y - half_h, 0, H)
    right = np.clip(center_x + half_w, 0, W)
    bottom = np.clip(center_y + half_h, 0, H)
    return left, top, right, bottom

def apply_mixup_cutmix(inputs, targets):
    """Randomly apply MixUp, CutMix or nothing to a batch.

    A single uniform draw picks the branch: MixUp with probability
    ``CONFIG['mixup_p']``, CutMix with probability ``CONFIG['cutmix_p']``,
    otherwise the batch is returned unchanged.

    Args:
        inputs: Batch tensor of shape (B, C, H, W).
        targets: Label tensor aligned with ``inputs``.

    Returns:
        ``(mixed_inputs, targets_a, targets_b, lam, kind)`` where ``kind`` is
        'mixup', 'cutmix' or ``None`` and ``lam`` is the weight of ``targets_a``.
    """
    batch_size = inputs.size(0)
    height, width = inputs.shape[2], inputs.shape[3]
    draw = random.random()
    mixup_limit = CONFIG['mixup_p']

    if draw < mixup_limit:
        lam = np.random.beta(0.4, 0.4)
        shuffle = torch.randperm(batch_size, device=inputs.device)
        blended = lam * inputs + (1-lam) * inputs[shuffle]
        return blended, targets, targets[shuffle], lam, 'mixup'

    if draw < mixup_limit + CONFIG['cutmix_p']:
        lam = np.random.beta(1.0, 1.0)
        shuffle = torch.randperm(batch_size, device=inputs.device)
        left, top, right, bottom = rand_bbox(width, height, lam)
        patched = inputs.clone()
        patched[:, :, top:bottom, left:right] = inputs[shuffle, :, top:bottom, left:right]
        # Recompute lam from the clipped box so it reflects the pasted area exactly.
        lam = 1 - ((right-left)*(bottom-top) / (width*height))
        return patched, targets, targets[shuffle], lam, 'cutmix'

    return inputs, targets, targets, 1.0, None


In [None]:
# %%
class CvTForImageClassification(nn.Module):
    """CvT backbone followed by GeM pooling, LayerNorm and a two-layer MLP head.

    Args:
        model_name: Hugging Face checkpoint/config name of the CvT backbone.
        num_classes: Number of output logits.
        image_size: Value written to ``cfg.image_size`` before the backbone is built.
        pretrained: Load pretrained backbone weights when true, otherwise
            initialise the backbone from the config alone.
    """

    def __init__(self, model_name='microsoft/cvt-13', num_classes=2, image_size=512, pretrained=True):
        super().__init__()
        cfg = CvtConfig.from_pretrained(model_name)
        cfg.image_size = image_size
        if pretrained:
            self.cvt = CvtModel.from_pretrained(model_name, config=cfg, ignore_mismatched_sizes=True)
        else:
            self.cvt = CvtModel(cfg)

        # Channel width of the last stage (384 for cvt-13 per the original note).
        width = self.cvt.config.embed_dim[-1]
        self.pool = GeM()
        self.norm = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, width//2),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(width//2, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        backbone_out = self.cvt(pixel_values=x, return_dict=True)
        # NOTE(review): the original comment assumed (B, L, C) here; the HF docs
        # describe this tensor as (B, C, H, W). GeM handles both - confirm.
        features = backbone_out.last_hidden_state
        pooled = self.norm(self.pool(features))  # (B, C)
        return self.mlp(pooled)


In [None]:
# %%
def build_param_groups(model):
    """Split the trainable parameters into three learning-rate groups.

    Groups, in the returned order:
        * rest - every other trainable parameter, ``CONFIG['lr_rest']``
        * last - parameters of backbone stage 2, ``CONFIG['lr_last']``
        * head - the MLP plus the top-level ``norm`` and ``pool`` modules,
          ``CONFIG['lr_head']``

    Frozen parameters (``requires_grad == False``) are skipped.

    Args:
        model: Module whose ``named_parameters()`` are grouped by name.

    Returns:
        List of three optimizer param-group dicts (a group may be empty).
    """
    head, last, rest = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # The original test `'mlp' in n or 'norm' in n and 'mlp' in n` reduces to
        # `'mlp' in n`, which left the head LayerNorm and the GeM exponent at the
        # backbone rate. The prefix match deliberately excludes the backbone's
        # own layer norms, whose names start with 'cvt.'.
        if 'mlp' in name or name.startswith(('norm.', 'pool.')):
            head.append(param)
        # Match the stage index itself; a bare '.2.' would also hit layer
        # index 2 inside an earlier stage on deeper variants.
        elif 'stages.2.' in name:
            last.append(param)
        else:
            rest.append(param)
    return [
        {'params': rest, 'lr': CONFIG['lr_rest']},
        {'params': last, 'lr': CONFIG['lr_last']},
        {'params': head, 'lr': CONFIG['lr_head']},
    ]

def compute_class_weights(y):
    """Compute balanced weights for classes 0 and 1.

    Each weight is ``n / (2 * count)``; a class with no samples gets 1.0.

    Args:
        y: Sequence of 0/1 labels.

    Returns:
        Float32 tensor ``[w0, w1]`` placed on the global ``device``.
    """
    label_arr = np.array(y)
    n_samples = len(label_arr)
    class_weights = []
    for cls in (0, 1):
        count = (label_arr == cls).sum()
        class_weights.append(n_samples / (2 * count) if count > 0 else 1.0)
    return torch.tensor(class_weights, dtype=torch.float32, device=device)

def metrics_from_preds(y_true, y_prob, thr=0.5):
    """Compute binary classification metrics at a probability threshold.

    Args:
        y_true: Ground-truth 0/1 labels.
        y_prob: Predicted probability of class 1 (array-like).
        thr: Samples with ``y_prob >= thr`` are predicted as class 1.

    Returns:
        ``(accuracy, precision, recall, f1, specificity, confusion_matrix)``.
        The confusion matrix is always 2x2, with rows as the true class and
        columns as the predicted class, both in order [0, 1].
    """
    y_pred = (np.asarray(y_prob) >= thr).astype(int)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # Pin the label set: otherwise a split where only one class appears yields
    # a 1x1 matrix, which zeroes specificity and breaks stacking and plotting.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    negatives = cm[0, 0] + cm[0, 1]
    spec = (cm[0, 0] / negatives) if negatives > 0 else 0.0
    return acc, prec, rec, f1, spec, cm

@torch.no_grad()
def predict_loader(model, loader, tta=False):
    """Run inference over a loader and collect class-1 probabilities.

    With ``tta`` set and ``CONFIG['tta_do']`` enabled, the logits of four views
    (identity, horizontal flip, +5 and -5 degree rotations) are averaged before
    the softmax.

    Args:
        model: Module mapping an image batch to logits; switched to eval mode.
        loader: Iterable yielding ``(images, targets)`` batches.
        tta: Request test-time augmentation.

    Returns:
        ``(probabilities, targets)`` as concatenated NumPy arrays.
    """
    model.eval()
    use_tta = tta and CONFIG['tta_do']
    # Extra views, evaluated lazily one at a time after the identity pass.
    extra_views = (
        lambda batch: torch.flip(batch, dims=[3]),
        lambda batch: transforms.functional.rotate(batch, 5, fill=0),
        lambda batch: transforms.functional.rotate(batch, -5, fill=0),
    )
    prob_chunks, target_chunks = [], []
    for images, target in loader:
        images = images.to(device)
        target = torch.as_tensor(target, device=device)
        logits = model(images)
        if use_tta:
            for make_view in extra_views:
                logits = logits + model(make_view(images))
            logits = logits / 4.0
        positive_prob = torch.softmax(logits, dim=1)[:, 1]
        prob_chunks.append(positive_prob.detach().cpu().numpy())
        target_chunks.append(target.cpu().numpy())
    return np.concatenate(prob_chunks), np.concatenate(target_chunks)

def train_one_epoch(model, loader, optimizer, scaler, criterion, use_mix=True):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module to train; switched to train mode.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: ``GradScaler`` used for loss scaling (may be disabled).
        criterion: Loss callable taking ``(logits, targets)``.
        use_mix: When true, each batch goes through ``apply_mixup_cutmix``.

    Returns:
        Sample-weighted mean training loss over the epoch.
    """
    model.train()
    running = 0.0; total = 0
    for x, y in loader:
        x = x.to(device); y = torch.as_tensor(y, device=device)

        optimizer.zero_grad(set_to_none=True)

        if use_mix:
            x, y1, y2, lam, kind = apply_mixup_cutmix(x, y)
        else:
            y1, y2, lam, kind = y, y, 1.0, None

        with autocast(enabled=CONFIG['use_amp']):
            logits = model(x)
            if kind is None:
                loss = criterion(logits, y1)
            else:
                # Mixed batch: interpolate the loss between the two label sets.
                loss = lam*criterion(logits, y1) + (1-lam)*criterion(logits, y2)

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale at this point; unscale
        # them first so the max-norm of 1.0 applies to the true gradients.
        # This is a no-op when the scaler is disabled.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer); scaler.update()

        running += loss.item() * x.size(0); total += x.size(0)
    return running/total


In [None]:
# %%
class EarlyStopping:
    """Track a value to minimise and flag when it stops improving.

    Args:
        patience: Number of consecutive non-improving steps before ``stop`` is set.
        min_delta: Minimum decrease below ``best`` that counts as an improvement.
    """

    def __init__(self, patience=12, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.count = 0
        self.stop = False

    def step(self, val):
        """Record a new value, updating ``best``, ``count`` and ``stop``."""
        improved = self.best is None or val < self.best - self.min_delta
        if improved:
            self.best, self.count = val, 0
            return
        self.count += 1
        if self.count >= self.patience:
            self.stop = True

def plot_cm(cm, title, path):
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        cm: Confusion matrix with integer counts.
        title: Figure title.
        path: Destination image path (saved at 150 dpi).
    """
    class_names = ['Benign','Malignant']
    fig, ax = plt.subplots(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(title)
    ax.set_ylabel('True')
    ax.set_xlabel('Pred')
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)

def train_cv():
    """Run patient-grouped, stratified K-fold training and evaluation.

    For every fold the function:
      1. splits patients (not images) into train/validation sets,
      2. trains only the MLP head for ``CONFIG['warmup_epochs']`` epochs, then
         unfreezes the whole network with per-group learning rates,
      3. evaluates an EMA copy after every epoch, picks the F1-maximising
         threshold and checkpoints the best EMA weights,
      4. reloads the best checkpoint and records validation/train metrics.

    The per-epoch history is written to ``<metrics_dir>/<model_name>_history_per_epoch.csv``.

    Returns:
        ``(fold_metrics_train, fold_metrics_val, cms)``: two lists of per-fold
        metric dicts and the list of per-fold validation confusion matrices.
    """
    # Group image indices by patient so no patient ends up in both splits.
    # A patient's stratification label is the label of their last listed image.
    p2idx, p2lab = defaultdict(list), {}
    for i,(pid,lab) in enumerate(zip(patient_ids, labels)):
        p2idx[pid].append(i); p2lab[pid]=lab
    up = list(p2idx.keys())
    y_up = [p2lab[p] for p in up]

    skf = StratifiedKFold(n_splits=CONFIG['num_folds'], shuffle=True, random_state=CONFIG['seed'])

    fold_metrics_val, fold_metrics_train, cms = [], [], []

    # One row per (fold, epoch), written to CSV at the end.
    history_rows = []

    loader_kwargs = dict(batch_size=CONFIG['batch_size'],
                         num_workers=CONFIG['num_workers'], pin_memory=CONFIG['pin_memory'])

    for fold,(tr_idx, va_idx) in enumerate(skf.split(up, y_up), start=1):
        # Expand patient indices back into image indices (linear-time flatten).
        tr_ids = [i for k in tr_idx for i in p2idx[up[k]]]
        va_ids = [i for k in va_idx for i in p2idx[up[k]]]

        tr_paths = [image_paths[i] for i in tr_ids]
        va_paths = [image_paths[i] for i in va_ids]
        tr_y = [labels[i] for i in tr_ids]
        va_y = [labels[i] for i in va_ids]

        ds_tr = MammographyDataset(tr_paths, tr_y, get_transforms(True))
        ds_va = MammographyDataset(va_paths, va_y, get_transforms(False))
        # Augmentation-free view of the training split, used only for the final
        # train metrics so they are measured on the same footing as validation.
        ds_tr_eval = MammographyDataset(tr_paths, tr_y, get_transforms(False))
        ld_tr = DataLoader(ds_tr, shuffle=True, **loader_kwargs)
        ld_va = DataLoader(ds_va, shuffle=False, **loader_kwargs)
        ld_tr_eval = DataLoader(ds_tr_eval, shuffle=False, **loader_kwargs)

        model = CvTForImageClassification(CONFIG['pretrained_model'], CONFIG['num_classes'],
                                          CONFIG['input_size'], pretrained=True).to(device)

        # Warmup phase: freeze everything except the MLP head.
        for n,p in model.named_parameters():
            p.requires_grad_('mlp' in n)

        groups = build_param_groups(model)
        optimizer = AdamW(groups, weight_decay=CONFIG['weight_decay'], betas=CONFIG['betas'])
        scaler = GradScaler(enabled=CONFIG['use_amp'])
        # Cosine annealing with warm restarts.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=CONFIG['T_0'], T_mult=CONFIG['T_mult'], eta_min=CONFIG['eta_min'])

        # Loss weighted by this fold's class balance.
        cls_weights = compute_class_weights(tr_y)
        if CONFIG['use_focal']:
            criterion = FocalLoss(weight=cls_weights, gamma=CONFIG['focal_gamma'])
        else:
            criterion = nn.CrossEntropyLoss(weight=cls_weights, label_smoothing=CONFIG['label_smoothing'])

        # The EMA is refreshed once per epoch, so the per-step decay of 0.999 is
        # compounded over the steps in an epoch. Applying 0.999 per epoch would
        # keep the evaluated/checkpointed EMA at roughly 90% initial weights even
        # after 100 epochs.
        ema_decay = 0.999 ** max(1, len(ld_tr))
        ema = ModelEmaV2(model, decay=ema_decay)
        early = EarlyStopping(CONFIG['early_stopping_patience'], CONFIG['min_delta'])

        best_val = float('inf'); best_path = os.path.join(CONFIG['output_dir'], f"{CONFIG['model_name']}_fold{fold}.pth")

        for epoch in range(1, CONFIG['num_epochs']+1):
            # After the warmup epochs, unfreeze the whole network and rebuild the
            # optimizer/scheduler so the new parameters get their group LRs.
            if epoch == CONFIG['warmup_epochs']+1:
                for p in model.parameters():
                    p.requires_grad_(True)

                groups = build_param_groups(model)
                optimizer = AdamW(groups, weight_decay=CONFIG['weight_decay'], betas=CONFIG['betas'])
                scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer, T_0=CONFIG['T_0'], T_mult=CONFIG['T_mult'], eta_min=CONFIG['eta_min'])

            tr_loss = train_one_epoch(model, ld_tr, optimizer, scaler, criterion, use_mix=True)
            ema.update(model)
            # Position the cosine schedule by epoch only, so every fold follows
            # the same LR curve (no fold-dependent offset).
            scheduler.step(epoch)

            # Evaluate the EMA weights (more stable than the raw weights).
            probs_val, y_val = predict_loader(ema.ema, ld_va, tta=True)
            # Search for the threshold that maximises F1 on the validation split.
            best_thr, best_f1, best_pack = 0.5, -1, None
            for thr in np.linspace(0.2, 0.8, 61):
                acc,prec,rec,f1,spec,cm = metrics_from_preds(y_val, probs_val, thr)
                if f1 > best_f1:
                    best_f1, best_thr, best_pack = f1, thr, (acc,prec,rec,f1,spec,cm)
            acc,prec,rec,f1,spec,cm = best_pack
            val_loss_proxy = 1 - f1  # simple proxy driving checkpointing and early stopping

            history_rows.append({
                'fold': fold, 'epoch': epoch, 'train_loss': tr_loss,
                'val_f1': f1, 'val_acc': acc, 'val_prec': prec, 'val_rec': rec, 'val_spec': spec,
                'best_thr': best_thr
            })

            if val_loss_proxy < best_val:
                best_val = val_loss_proxy
                torch.save({'model': ema.ema.state_dict(), 'thr': best_thr}, best_path)

            early.step(val_loss_proxy)
            print(f"[Fold {fold} | Epoch {epoch}] tr_loss={tr_loss:.4f} | val_f1={f1:.4f} @thr={best_thr:.2f}")
            if early.stop:
                print(f"Early stopping en fold {fold} (epoch {epoch})")
                break

            torch.cuda.empty_cache(); gc.collect()

        # Reload the best (EMA) weights and their threshold.
        # weights_only=False because the checkpoint also stores the NumPy threshold value.
        ckpt = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model'])
        best_thr = ckpt['thr']

        # Final validation metrics (with TTA).
        probs_val, y_val = predict_loader(model, ld_va, tta=True)
        vacc,vprec,vrec,vf1,vspec,vcm = metrics_from_preds(y_val, probs_val, best_thr)
        cms.append(vcm)
        plot_cm(vcm, f'Fold {fold} - Confusion Matrix (Val)', os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_fold{fold}_cm.png"))

        # Train metrics for the record, measured on un-augmented training images.
        probs_tr, y_tr = predict_loader(model, ld_tr_eval, tta=False)
        tacc,tprec,trec,tf1,tspec,tcm = metrics_from_preds(y_tr, probs_tr, best_thr)

        fold_metrics_val.append({'Fold': fold, 'Accuracy': vacc, 'Precision': vprec, 'Recall': vrec, 'F1-Score': vf1, 'Specificity': vspec, 'BestThr': best_thr})
        fold_metrics_train.append({'Fold': fold, 'Accuracy': tacc, 'Precision': tprec, 'Recall': trec, 'F1-Score': tf1, 'Specificity': tspec, 'BestThr': best_thr})

        # Free per-fold objects before the next fold builds a fresh model.
        del model, ema, optimizer, ld_tr, ld_va, ld_tr_eval, ds_tr, ds_va, ds_tr_eval
        torch.cuda.empty_cache(); gc.collect()

    # Save the per-epoch history.
    pd.DataFrame(history_rows).to_csv(os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_history_per_epoch.csv"), index=False)
    return fold_metrics_train, fold_metrics_val, cms

print("✓ Funciones de entrenamiento definidas")

print("✓ Funciones de entrenamiento definidas")


In [None]:
# %%
def save_fold_metrics(train_list, val_list):
    """Write per-fold train/validation metrics to CSV with a summary row.

    Args:
        train_list: List of per-fold metric dicts for the training split.
        val_list: List of per-fold metric dicts for the validation split.

    Returns:
        ``(train_df, val_df)``, each with a trailing 'Mean ± Std' row appended.
    """
    metric_cols = ['Accuracy','Precision','Recall','F1-Score','Specificity','BestThr']

    def append_mean_std(df):
        """Return ``df`` plus one row of 'mean ± std' strings per metric column."""
        summary = {'Fold':'Mean ± Std'}
        for col in metric_cols:
            values = df[col].astype(float).values
            # Thresholds are reported with 3 decimals, all other metrics with 4.
            digits = 3 if col == 'BestThr' else 4
            summary[col] = f"{values.mean():.{digits}f} ± {values.std():.{digits}f}"
        return pd.concat([df, pd.DataFrame([summary])], axis=0)

    df_tr_out = append_mean_std(pd.DataFrame(train_list))
    df_va_out = append_mean_std(pd.DataFrame(val_list))

    metrics_dir = CONFIG['metrics_dir']
    model_name = CONFIG['model_name']
    path_tr = os.path.join(metrics_dir, f"{model_name}_metrics_train.csv")
    path_va = os.path.join(metrics_dir, f"{model_name}_metrics_val.csv")
    df_tr_out.to_csv(path_tr, index=False)
    df_va_out.to_csv(path_va, index=False)

    print("✅ Guardadas métricas:")
    print("  Train  ->", path_tr)
    print("  Val    ->", path_va)
    return df_tr_out, df_va_out

def save_mean_confusion_matrix(cms):
    """Average the per-fold confusion matrices and save the heatmap twice.

    The figure is written under ``CONFIG['metrics_dir']`` as
    ``<model_name>_avg_cm.png`` and as ``CvT_TL_1_mean_confusion_matrix.png``.

    Args:
        cms: Non-empty list of equally shaped confusion matrices, one per fold.

    Raises:
        ValueError: If ``cms`` is empty.
    """
    if len(cms) == 0:
        raise ValueError("save_mean_confusion_matrix needs at least one confusion matrix")
    avg_cm = np.mean(np.stack(cms, axis=0), axis=0)
    plt.figure(figsize=(6,5))
    sns.heatmap(avg_cm, annot=True, fmt='.1f', cmap='Blues',
                xticklabels=['Benign','Malignant'], yticklabels=['Benign','Malignant'])
    # Report the real number of folds rather than a hard-coded 5.
    plt.title(f'Mean Confusion Matrix (Validation, {len(cms)}-Fold)')
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout()
    # Standard file name plus the explicitly requested fixed name.
    p1 = os.path.join(CONFIG['metrics_dir'], f"{CONFIG['model_name']}_avg_cm.png")
    p2 = os.path.join(CONFIG['metrics_dir'], "CvT_TL_1_mean_confusion_matrix.png")
    plt.savefig(p1, dpi=150); plt.savefig(p2, dpi=150); plt.close()
    print("✅ Mean CM guardado en:")
    print(" ", p1)
    print(" ", p2)


In [None]:
# %%
# Banner announcing the run.
print("="*80)
print(">>> ENTRENANDO CvT_TL_1 (512x512) con 5-Fold CV (Stratified por paciente)")
print("="*80)

# Run the full cross-validation, then persist the per-fold metric tables and
# the mean validation confusion matrix.
start = time.time()
train_list, val_list, cms = train_cv()
df_tr, df_va = save_fold_metrics(train_list, val_list)
save_mean_confusion_matrix(cms)
# Wall-clock time of the whole run, in hours.
print(f"\n⏱️ Tiempo total: {(time.time()-start)/3600:.2f} h")
