In [4]:
todo completo final

SyntaxError: invalid syntax (1933874046.py, line 1)

In [11]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, auc
from datetime import datetime
import random
import pandas as pd
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# --------------------------------------------------------------------------------
# GLOBAL CONFIGURATION
# --------------------------------------------------------------------------------
# One results tree per run, keyed by the time this cell starts executing.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
RESULTS_ROOT = os.path.join("results", timestamp)

# Sub-folders for model checkpoints, overlay figures and metric plots.
MODEL_DIR, OVERLAYS_DIR, METRICS_DIR = (
    os.path.join(RESULTS_ROOT, leaf)
    for leaf in ("models", "overlays", "metrics_plots")
)

# Create the root first, then every sub-folder (existing folders are accepted).
for d in (RESULTS_ROOT, MODEL_DIR, OVERLAYS_DIR, METRICS_DIR):
    os.makedirs(d, exist_ok=True)

# Plain string device name; the rest of the file hands it to ``.to(device)``.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"[INFO] Using device: {device}")

# --------------------------------------------------------------------------------
# COLORS & CLASSES
# --------------------------------------------------------------------------------
# RGBA colour assigned to each class id.
my_colors = {
    0: (1.0, 0.65, 0.0, 1.0),
    1: (1.0, 0.0, 0.0, 1.0),
    2: (0.0, 1.0, 0.0, 1.0),
}
# Human-readable name of each class id.
class_names = {
    0: "Background",
    1: "Front of Invasion",
    2: "Stroma",
}

# 256-row lookup table: the rows of the class ids carry their colour,
# every other row stays all-zero.
cmap_vals = np.zeros((256, 4))
cmap_vals[list(my_colors)] = list(my_colors.values())
unet_cmap = mcolors.ListedColormap(cmap_vals)

# Colormap for ground-truth display: the class colours plus a fully
# transparent entry registered under key 4.
# NOTE(review): the resulting list has four colours, so the transparent one
# sits at index 3 while prepare_mask_for_viz writes the value 4 - confirm
# that the out-of-range lookup renders as intended.
viz_colors = {**my_colors, 4: (0, 0, 0, 0)}
custom_viz_cmap = mcolors.ListedColormap(
    [rgba for _, rgba in sorted(viz_colors.items())]
)

def prepare_mask_for_viz(m):
    """Return a copy of ``m`` in which every entry equal to -1 is replaced by 4.

    The input array itself is not modified.
    """
    remapped = m.copy()
    np.putmask(remapped, remapped == -1, 4)
    return remapped

# --------------------------------------------------------------------------------
# U-Net + SE Blocks
# --------------------------------------------------------------------------------
class SEBlock(nn.Module):
    """Channel gate: global average pool -> two linear layers -> sigmoid scale.

    Args:
        in_ch: number of channels of the feature map being gated.
        reduction: divisor applied to ``in_ch`` to size the hidden layer.
            The hidden width is clamped to at least 1 so that
            ``in_ch < reduction`` no longer produces a zero-width layer
            (which would make the gate independent of its input).
    """
    def __init__(self,in_ch,reduction=16):
        super().__init__()
        hidden = max(1, in_ch // reduction)
        self.fc1 = nn.Linear(in_ch, hidden)
        self.fc2 = nn.Linear(hidden, in_ch)

    def forward(self,x):
        """Scale each channel of ``x`` (B, C, H, W) by a learned weight in (0, 1)."""
        b, c, _, _ = x.size()
        # Squeeze: one scalar per (sample, channel).
        y = F.adaptive_avg_pool2d(x, 1).view(b, c)
        # Excite: bottleneck MLP followed by a sigmoid gate.
        y = F.relu(self.fc1(y), inplace=True)
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1, 1)
        return x * y

class DoubleConv(nn.Module):
    """Two Conv2d -> BatchNorm2d -> ReLU stages followed by an ``SEBlock``.

    Args:
        in_ch: channels entering the first convolution.
        out_ch: channels produced by both convolutions.
        kernel: convolution kernel size.
        dilation: dilation used by both convolutions.
    """
    def __init__(self,in_ch,out_ch,kernel=3,dilation=1):
        super().__init__()
        # With stride 1 this padding keeps H and W unchanged for odd kernels.
        pad = ((kernel - 1) * dilation) // 2
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.append(
                nn.Conv2d(stage_in, out_ch, kernel, padding=pad, dilation=dilation)
            )
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)
        self.se = SEBlock(out_ch)

    def forward(self,x):
        """Apply the convolution stack, then the channel gate."""
        return self.se(self.block(x))

class UNet(nn.Module):
    """Four-level encoder/decoder with skip connections built from ``DoubleConv``.

    Args:
        in_ch: number of input channels.
        base: channel width of the first level; each deeper level doubles it.
        classes: number of output channels of the final 1x1 convolution.

    NOTE(review): there are four 2x poolings and four stride-2 transposed
    convolutions, so the skip concatenations presumably require H and W to be
    multiples of 16 - confirm against the patch sizes actually used.
    """
    def __init__(self,in_ch=2,base=32,classes=3):
        super().__init__()
        w1, w2, w3, w4, w5 = (base * 2 ** level for level in range(5))

        # Encoder: DoubleConv followed by 2x max-pooling at every level.
        self.d1 = DoubleConv(in_ch, w1)
        self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(w1, w2)
        self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(w2, w3, dilation=2)
        self.p3 = nn.MaxPool2d(2)
        self.d4 = DoubleConv(w3, w4, dilation=4)
        self.p4 = nn.MaxPool2d(2)

        # Bottleneck.
        self.b = DoubleConv(w4, w5, dilation=4)

        # Decoder: 2x transposed-conv upsampling, then DoubleConv on the
        # concatenation with the matching encoder output.
        self.u4 = nn.ConvTranspose2d(w5, w4, 2, 2)
        self.c4 = DoubleConv(w4 * 2, w4, dilation=2)
        self.u3 = nn.ConvTranspose2d(w4, w3, 2, 2)
        self.c3 = DoubleConv(w3 * 2, w3)
        self.u2 = nn.ConvTranspose2d(w3, w2, 2, 2)
        self.c2 = DoubleConv(w2 * 2, w2)
        self.u1 = nn.ConvTranspose2d(w2, w1, 2, 2)
        self.c1 = DoubleConv(w1 * 2, w1)

        # Per-pixel class logits.
        self.final = nn.Conv2d(w1, classes, 1)

    def forward(self,x):
        """Return the per-pixel logits produced by the final 1x1 convolution."""
        enc1 = self.d1(x)
        enc2 = self.d2(self.p1(enc1))
        enc3 = self.d3(self.p2(enc2))
        enc4 = self.d4(self.p3(enc3))
        deep = self.b(self.p4(enc4))

        dec = self.c4(torch.cat([self.u4(deep), enc4], 1))
        dec = self.c3(torch.cat([self.u3(dec), enc3], 1))
        dec = self.c2(torch.cat([self.u2(dec), enc2], 1))
        dec = self.c1(torch.cat([self.u1(dec), enc1], 1))
        return self.final(dec)

# --------------------------------------------------------------------------------
# DATASET
# --------------------------------------------------------------------------------
def get_valid_pairs(d):
    """Collect the patch file groups in directory ``d`` that are usable.

    A group is kept when ``patch_<id>.png`` has its three companions
    (``hematoxilin_<id>.png``, ``eosin_<id>.png``, ``mask_<id>.png``) in the
    same directory and the mask contains at least one pixel labelled 1 or 2.

    Args:
        d: directory to scan.

    Returns:
        List of ``(rgb, hematoxilin, eosin, mask)`` path tuples, ordered by
        file name.
    """
    prefix = "patch_"
    pairs = []
    # os.listdir gives no ordering guarantee; sorting keeps the returned list
    # (and any seeded split computed from it) reproducible between runs.
    for f in sorted(os.listdir(d)):
        if not (f.startswith(prefix) and f.endswith(".png")):
            continue
        # Swap only the leading prefix, never a later "patch_" in the name.
        tail = f[len(prefix):]
        rgb, hema, eos, msk = (
            os.path.join(d, stem + tail)
            for stem in (prefix, "hematoxilin_", "eosin_", "mask_")
        )
        if not all(os.path.exists(p) for p in (rgb, hema, eos, msk)):
            continue
        m = cv2.imread(msk, 0)
        if m is None:
            # cv2.imread signals an unreadable file by returning None.
            continue
        # Only the presence of label 1 or 2 matters here, so test for it
        # directly instead of remapping the uint8 mask to -1 first.
        if np.isin(m, (1, 2)).any():
            pairs.append((rgb, hema, eos, msk))
    return pairs





class PatchDataset(Dataset):
    """Dataset yielding (hematoxylin+eosin tensor, label mask tensor) per patch.

    Args:
        pairs: sequence of ``(rgb, hematoxilin, eosin, mask)`` path tuples;
            the rgb path is not read here.
        augment: when True, apply random flips, 90-degree rotations, a global
            intensity shift and Gaussian noise to each sample.
    """
    def __init__(self, pairs, augment=False):
        self.pairs = pairs
        self.augment = augment

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image, failing loudly if unreadable."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread returns None instead of raising on a bad path/file.
            raise OSError(f"cv2.imread could not read image: {path}")
        return img

    @staticmethod
    def _sanitize_mask(m):
        """Return ``m`` as int64 with every value outside {0, 1, 2} set to -1."""
        # Masks are loaded as uint8. Cast to a signed type first so the -1
        # sentinel does not depend on NumPy's scalar promotion rules, which
        # differ between NumPy versions.
        m = np.asarray(m).astype(np.int64)
        return np.where(np.isin(m, [0, 1, 2]), m, -1)

    def __getitem__(self, idx):
        """Return ``(channels, mask)``: float32 (2, H, W) and int64 (H, W)."""
        rgb_f, hema_f, eosin_f, mask_f = self.pairs[idx]
        h = self._read_gray(hema_f).astype(np.float32) / 255.0
        e = self._read_gray(eosin_f).astype(np.float32) / 255.0
        m = self._read_gray(mask_f)
        if h.shape != m.shape:
            # Nearest-neighbour keeps label values intact while resizing.
            m = cv2.resize(m, (h.shape[1], h.shape[0]), interpolation=cv2.INTER_NEAREST)
        m = self._sanitize_mask(m)

        ch = np.stack([h, e], axis=0)

        if self.augment:
            # Horizontal flip.
            if random.random() > 0.5:
                ch, m = ch[:, :, ::-1], m[:, ::-1]
            # Vertical flip.
            if random.random() > 0.5:
                ch, m = ch[:, ::-1, :], m[::-1, :]
            # Rotation by k * 90 degrees, applied identically to image and mask.
            k = random.randint(0, 3)
            if k:
                ch = np.rot90(ch, k, axes=(1, 2))
                m = np.rot90(m, k)
            # Global intensity shift, then additive Gaussian noise (image only).
            delta = np.float32(random.uniform(-0.1, 0.1))
            ch = np.clip(ch + delta, 0, 1)
            noise = np.random.normal(0, 0.02, size=ch.shape).astype(np.float32)
            ch = np.clip(ch + noise, 0, 1)

        # Flips/rotations produce negative-stride views, which
        # torch.from_numpy rejects; force contiguous copies.
        ch = np.ascontiguousarray(ch, dtype=np.float32)
        m  = np.ascontiguousarray(m)

        return torch.from_numpy(ch), torch.from_numpy(m).long()

# --------------------------------------------------------------------------------
# TRAINING
# --------------------------------------------------------------------------------
def train_model(model,train_ds,val_ds,epochs=5,lr=1e-4,bs=1,accum=4):
    """Train ``model`` with Adam, mixed precision and gradient accumulation.

    Args:
        model: network to optimise; it is modified in place and returned.
        train_ds: dataset iterated in shuffled order each epoch.
        val_ds: dataset evaluated (without gradients) after each epoch.
        epochs: number of passes over ``train_ds``.
        lr: Adam learning rate.
        bs: batch size of both data loaders.
        accum: number of batches whose gradients are summed per optimiser step.

    Returns:
        The same ``model`` object, left in eval mode by the last validation pass.

    Pixels labelled -1 are excluded from the cross-entropy loss. The mean
    train and validation loss per batch are printed after every epoch.
    """
    def fit_target(logits, target):
        # Nearest-neighbour resize of the label map when its H x W differs
        # from that of the logits.
        if logits.shape[2:] == target.shape[1:]:
            return target
        resized = F.interpolate(
            target.unsqueeze(1).float(), size=logits.shape[2:], mode="nearest"
        )
        return resized.squeeze(1).long()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)
    scaler = GradScaler()
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=bs, shuffle=False, pin_memory=True)
    n_train = len(train_loader)

    for ep in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        optimizer.zero_grad(set_to_none=True)
        train_total = 0
        for step, (x, y) in enumerate(tqdm(train_loader, desc=f"Train {ep}"), start=1):
            x = x.to(device)
            y = y.to(device)
            with autocast():
                out = model(x)
                # Divide so the accumulated gradient is an average over `accum` batches.
                loss = criterion(out, fit_target(out, y)) / accum
            scaler.scale(loss).backward()
            train_total += loss.item() * accum
            # Step every `accum` batches and once more on the final batch.
            if step % accum == 0 or step == n_train:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
        print(f"[Ep{ep}] Train Loss: {train_total/n_train:.4f}")

        # ---- validation pass ----
        model.eval()
        val_total = 0
        with torch.no_grad():
            for x, y in tqdm(val_loader, desc=f" Val {ep}"):
                x = x.to(device)
                y = y.to(device)
                with autocast():
                    out = model(x)
                    val_total += criterion(out, fit_target(out, y)).item()
        print(f"[Ep{ep}] Val Loss:   {val_total/len(val_loader):.4f}")

    return model

# --------------------------------------------------------------------------------
# PREDICTION & OVERLAYS
# --------------------------------------------------------------------------------
def apply_ma(p, ks=3):
    """Smooth a label map by replacing each pixel with its neighbourhood mode.

    Args:
        p: array of non-negative integer labels (``np.bincount`` rejects
            negative values).
        ks: side length of the square window handed to ``generic_filter``.

    Returns:
        Array shaped like ``p``; on a tie the smallest label wins because
        ``argmax`` returns the first maximum.
    """
    def window_mode(values):
        # generic_filter passes the window as a flat float buffer.
        return np.bincount(values.astype(int)).argmax()

    return scipy.ndimage.generic_filter(p, window_mode, size=ks)

def save_overlay(rgb, pred, idx, truth, sub):
    """Save a three-panel figure (original, prediction, ground truth) as SVG.

    Args:
        rgb: image shown in the first panel.
        pred: predicted label map, coloured with ``unet_cmap``.
        idx: value used in the output file name ``overlay_<idx>.svg``.
        truth: ground-truth label map; passed through
            ``prepare_mask_for_viz`` before display.
        sub: sub-folder of ``OVERLAYS_DIR`` that receives the file
            (created if missing).
    """
    out_dir = os.path.join(OVERLAYS_DIR, sub)
    os.makedirs(out_dir, exist_ok=True)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # (title, image, extra imshow arguments) for each panel, left to right.
    panels = (
        ("Original", rgb, {}),
        ("Predicted", pred, {"cmap": unet_cmap, "norm": mcolors.NoNorm()}),
        (
            "GroundTruth",
            prepare_mask_for_viz(truth),
            {"cmap": custom_viz_cmap, "norm": mcolors.NoNorm()},
        ),
    )
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    # One legend entry per class, in ascending class-id order.
    class_ids = sorted(class_names)
    handles = [
        mpatches.Patch(color=my_colors[cid], label=class_names[cid])
        for cid in class_ids
    ]
    labels = [class_names[cid] for cid in class_ids]

    # Figure-level legend anchored just outside the right edge.
    fig.legend(
        handles, labels,
        loc="center right",
        bbox_to_anchor=(1.05, 0.8),
        ncol=1,
        frameon=False
    )

    # Axes span 5%-95% of the figure in both directions.
    fig.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, wspace=0.2)

    target = os.path.join(out_dir, f"overlay_{idx}.svg")
    fig.savefig(target, format="svg", dpi=150, bbox_inches="tight")
    plt.close(fig)


def predict_and_collect(model,ds,pairs,ma_kernel=5):
    """Run inference on every sample, save overlays and gather flat results.

    Args:
        model: segmentation network; it is switched to eval mode for the
            duration of the call and restored to its previous mode afterwards.
        ds: dataset yielding ``(input, mask)`` tensors, aligned with ``pairs``.
        pairs: path tuples whose first element is the RGB image shown in the
            overlay.
        ma_kernel: window size passed to ``apply_ma`` for mode smoothing; also
            used in the overlay sub-folder name ``ma_<ma_kernel>``.

    Returns:
        Tuple ``(true, pred, prob)``: flattened ground-truth labels, flattened
        smoothed predictions, and an ``(n_pixels, n_classes)`` array of
        softmax probabilities (taken before smoothing).
    """
    trs, prs, probs = [], [], []

    # Inference must not use batch statistics or update BatchNorm running
    # stats, so do not rely on the caller having called model.eval().
    was_training = model.training
    model.eval()
    try:
        for idx, ((x, y), (rgb, _, _, _)) in enumerate(zip(ds, pairs)):
            img = cv2.cvtColor(cv2.imread(rgb), cv2.COLOR_BGR2RGB)
            with torch.no_grad():
                logits = model(x.unsqueeze(0).to(device))
                P = F.softmax(logits, dim=1).cpu().numpy()[0]  # (C, H, W)
            pr = np.argmax(P, axis=0).astype(np.uint8)
            sm = apply_ma(pr, ks=ma_kernel)

            truth = y.numpy()
            save_overlay(img, sm, idx, truth, f"ma_{ma_kernel}")

            trs.append(truth.flatten())
            prs.append(sm.flatten())
            # (C, H, W) -> (H*W, C) so rows line up with the flattened labels.
            probs.append(P.transpose(1, 2, 0).reshape(-1, P.shape[0]))
    finally:
        model.train(was_training)

    return np.concatenate(trs), np.concatenate(prs), np.vstack(probs)

# --------------------------------------------------------------------------------
# METRICS & PLOTS
# --------------------------------------------------------------------------------

def save_conf(cm, name):
    """Save the raw-count confusion matrix ``cm`` as an annotated SVG heat-map.

    Args:
        cm: square count matrix whose rows/columns follow ``class_names``.
        name: figure title; with spaces replaced by underscores it is also
            the file name inside ``METRICS_DIR``.
    """
    tick_names = list(class_names.values())
    df = pd.DataFrame(cm, index=tick_names, columns=tick_names)

    fig, ax = plt.subplots(figsize=(6, 5))
    heat = ax.imshow(df, cmap="viridis", interpolation="nearest")
    ax.set_title(name)
    fig.colorbar(heat, ax=ax)

    positions = np.arange(len(df))
    ax.set_xticks(positions)
    ax.set_xticklabels(df.columns, rotation=0)
    ax.set_yticks(positions)
    ax.set_yticklabels(df.index)

    # Write the integer count in the centre of every cell.
    n_rows, n_cols = df.shape
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(col, row, int(df.iat[row, col]),
                    ha="center", va="center", color="white")

    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # Extra left margin so the row labels are not clipped.
    fig.tight_layout()
    fig.subplots_adjust(left=0.2)

    file_name = name.replace(' ', '_') + ".svg"
    fig.savefig(os.path.join(METRICS_DIR, file_name), format="svg", bbox_inches="tight")
    plt.close(fig)


def save_norm_conf(cm, name):
    """Save the row-normalised confusion matrix as an annotated SVG heat-map.

    Every row (true class) of ``cm`` is divided by its own total. A row that
    sums to zero - a class with no true pixels - is reported as all zeros
    instead of triggering a divide-by-zero warning and printing ``nan``.

    Args:
        cm: square count matrix whose rows/columns follow ``class_names``.
        name: figure title; with spaces replaced by underscores it is also
            the file name inside ``METRICS_DIR``.
    """
    counts = np.asarray(cm, dtype=float)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Divide only where the row total is positive; other cells keep the 0
    # they were initialised with.
    cmn = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)

    df = pd.DataFrame(cmn, index=class_names.values(), columns=class_names.values())
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(df, cmap="Blues", vmin=0, vmax=1)
    ax.set_title(name)
    fig.colorbar(im, ax=ax)
    ticks = np.arange(len(df))
    ax.set_xticks(ticks)
    ax.set_xticklabels(df.columns, rotation=0)
    ax.set_yticks(ticks)
    ax.set_yticklabels(df.index)
    # Annotate every cell with its fraction, two decimals.
    for i, j in np.ndindex(df.shape):
        ax.text(j, i, f"{df.iat[i, j]:.2f}", ha="center", va="center")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # Extra left margin so the row labels are not clipped.
    fig.tight_layout()
    fig.subplots_adjust(left=0.2)

    fig.savefig(
        os.path.join(METRICS_DIR, f"{name.replace(' ', '_')}.svg"),
        format="svg",
        bbox_inches="tight"
    )
    plt.close(fig)


def save_iou_dice(cm):
    """Save a grouped bar chart of per-class IoU and Dice as ``IoU_vs_Dice.svg``.

    Args:
        cm: square confusion matrix (rows = true, columns = predicted) whose
            order follows ``class_names``.

    A class that never appears in either the truth or the prediction has a
    zero denominator; its score is reported as 0, matching how the other
    metric plots in this file treat undefined scores.
    """
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp

    union = tp + fp + fn
    dice_den = 2 * tp + fp + fn
    # Guarded divisions: cells with a zero denominator keep the initial 0.
    iou = np.divide(tp, union, out=np.zeros_like(tp), where=union > 0)
    dice = np.divide(2 * tp, dice_den, out=np.zeros_like(tp), where=dice_den > 0)

    classes = list(class_names.values())
    x = np.arange(len(classes))
    w = 0.35

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar(x - w / 2, iou, w, label="IoU")
    ax.bar(x + w / 2, dice, w, label="Dice")
    ax.set_xticks(x)
    ax.set_xticklabels(classes)
    ax.set_ylim(0, 1)
    ax.set_ylabel("Score")
    ax.set_title("IoU vs Dice per Class")
    ax.legend()
    fig.savefig(os.path.join(METRICS_DIR, "IoU_vs_Dice.svg"), format="svg")
    plt.close(fig)

import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from matplotlib.lines import Line2D

def save_patch_iou(tr, pr, pairs):
    """
    Compute the per-patch IoU for classes 1 and 2 and save a box plot of it
    as ``Per_Patch_IoU.svg`` inside ``METRICS_DIR``.

    ``tr`` and ``pr`` are the flattened labels of all patches concatenated;
    they are split into ``len(pairs)`` equal chunks, one per patch.

    A patch in which a class is absent from both truth and prediction has an
    undefined IoU (zero denominator). Such values are left out of that
    class's box instead of being counted as 0, so only valid values are drawn.

    Args:
        tr: 1-D array of true labels for all patches.
        pr: 1-D array of predicted labels, aligned with ``tr``.
        pairs: the patch list; only its length is used.
    """
    n_patches = len(pairs)
    valid_ious = {1: [], 2: []}  # class id -> IoU of every patch where it is defined

    for t_patch, p_patch in zip(np.split(tr, n_patches), np.split(pr, n_patches)):
        cm = confusion_matrix(t_patch, p_patch, labels=[0,1,2])
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        denom = tp + fp + fn
        for cls, values in valid_ious.items():
            if denom[cls] > 0:
                values.append(tp[cls] / denom[cls])

    # Draw the box plot.
    fig, ax = plt.subplots(figsize=(6,4))
    ax.boxplot(
        [np.asarray(valid_ious[1], dtype=float), np.asarray(valid_ious[2], dtype=float)],
        showmeans=True,
        meanprops={"marker":"D", "markeredgecolor":"black", "markerfacecolor":"white"}
    )
    # Tick labels are set on the axes rather than through boxplot's `labels`
    # keyword, which newer Matplotlib releases deprecate.
    ax.set_xticklabels(["Front Of Invasion","Stroma"])
    ax.set_ylabel("IoU")
    ax.set_title("Per-Patch IoU")

    # Legend entry explaining the diamond mean marker.
    mean_handle = Line2D(
        [0], [0],
        marker='D',
        color='none',
        markeredgecolor='black',
        markerfacecolor='white',
        markersize=8,
        linestyle='None',
        label='Mean'
    )
    ax.legend(handles=[mean_handle], loc='upper right')

    # Save the figure.
    fig.savefig(
        os.path.join(METRICS_DIR,"Per_Patch_IoU.svg"),
        format="svg"
    )
    plt.close(fig)



import numpy as np
import matplotlib.pyplot as plt
import os

def save_metrics_per_class(cm):
    """
    Plot per-class Accuracy, Precision, Recall and F1-Score as grouped bars
    (in that order) and save the chart as ``Metrics_Per_Class.svg`` inside
    ``METRICS_DIR``.

    Args:
        cm: square confusion matrix (rows = true, columns = predicted) whose
            order follows ``class_names``.

    Scores that are undefined because of a zero denominator are shown as 0.
    """
    # Decompose the confusion matrix.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    total_samples = cm.sum()

    # Per-class Precision / Recall / F1 and one-vs-rest accuracy.
    with np.errstate(divide='ignore', invalid='ignore'):
        precision_pc = tp / (tp + fp)
        recall_pc    = tp / (tp + fn)
        f1_pc        = 2 * precision_pc * recall_pc / (precision_pc + recall_pc)
        # Per-class accuracy: (TP + TN) / total.
        tn = total_samples - (tp + fp + fn)
        accuracy_pc  = (tp + tn) / total_samples
    # The function used to call itself at this point, which recursed without
    # a base case and never reached the plotting code below.

    # Replace NaNs (undefined scores) with zeros.
    precision_pc = np.nan_to_num(precision_pc)
    recall_pc    = np.nan_to_num(recall_pc)
    f1_pc        = np.nan_to_num(f1_pc)
    accuracy_pc  = np.nan_to_num(accuracy_pc)

    classes = list(class_names.values())
    x = np.arange(len(classes))
    width = 0.2

    fig, ax = plt.subplots(figsize=(10, 5))
    # Bar order within each group: Accuracy, Precision, Recall, F1 Score.
    ax.bar(x - 1.5*width, accuracy_pc,  width, label='Accuracy')
    ax.bar(x - 0.5*width, precision_pc, width, label='Precision')
    ax.bar(x + 0.5*width, recall_pc,    width, label='Recall')
    ax.bar(x + 1.5*width, f1_pc,        width, label='F1 Score')

    ax.set_xticks(x)
    ax.set_xticklabels(classes, rotation=0, ha='right')
    ax.set_ylim(0, 1)
    ax.set_ylabel('Score')
    ax.set_title('Per-Class Segmentation Metrics')

    # Legend placed outside the axes, on the right.
    legend = ax.legend(
        loc='upper left',
        bbox_to_anchor=(1.02, 1),
        borderaxespad=0,
        frameon=True
    )
    legend.get_frame().set_edgecolor('black')

    # Annotate every bar with its value.
    for i in range(len(classes)):
        ax.text(x[i] - 1.5*width, accuracy_pc[i]  + 0.02, f"{accuracy_pc[i]:.2f}",  ha='center')
        ax.text(x[i] - 0.5*width, precision_pc[i] + 0.02, f"{precision_pc[i]:.2f}", ha='center')
        ax.text(x[i] + 0.5*width, recall_pc[i]    + 0.02, f"{recall_pc[i]:.2f}",    ha='center')
        ax.text(x[i] + 1.5*width, f1_pc[i]        + 0.02, f"{f1_pc[i]:.2f}",        ha='center')

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()
    fig.savefig(
        os.path.join(METRICS_DIR, "Metrics_Per_Class.svg"),
        format="svg", dpi=150, bbox_inches="tight"
    )
    plt.close(fig)






def save_overall_metrics_from_cm(cm):
    """
    Derive global Accuracy plus macro-averaged Precision, Recall and F1 from
    the confusion matrix ``cm`` and save them as the bar chart
    ``Overall_Metrics.svg`` inside ``METRICS_DIR``.

    Per-class scores with a zero denominator count as 0 in the macro averages.
    """
    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos

    # Global accuracy: correctly labelled samples over all samples.
    accuracy = true_pos.sum() / cm.sum()

    # Per-class scores; divisions by zero are silenced and mapped to 0 below.
    with np.errstate(divide='ignore', invalid='ignore'):
        prec_raw = true_pos / (true_pos + false_pos)
        rec_raw = true_pos / (true_pos + false_neg)
        f1_raw = 2 * prec_raw * rec_raw / (prec_raw + rec_raw)

    # Macro average = unweighted mean over classes.
    names = ('Accuracy', 'Precision', 'Recall', 'F1-Score')
    values = (
        accuracy,
        np.nan_to_num(prec_raw).mean(),
        np.nan_to_num(rec_raw).mean(),
        np.nan_to_num(f1_raw).mean(),
    )

    fig, ax = plt.subplots(figsize=(6,4))
    ax.bar(names, values)
    ax.set_ylim(0,1)
    ax.set_ylabel('Score')
    ax.set_title('Overall Segmentation Metrics')
    # Print each value just above its bar.
    for pos, score in enumerate(values):
        ax.text(pos, score + 0.02, f"{score:.2f}", ha='center')

    out_path = os.path.join(METRICS_DIR, "Overall_Metrics.svg")
    fig.savefig(out_path, format="svg", dpi=150, bbox_inches="tight")
    plt.close(fig)

def save_area_scatter(tr,pr,pairs):
    """Scatter true vs. predicted class-1 pixel counts per patch; save as SVG.

    Args:
        tr: accepted for signature symmetry with the other plots; not used.
        pr: 1-D array of predicted labels for all patches, concatenated in
            the order of ``pairs``.
        pairs: path tuples whose fourth element is the mask file. Each mask
            is re-read from disk; its pixel count determines which slice of
            ``pr`` belongs to that patch.
    """
    true_cnt = []
    pred_cnt = []
    start = 0
    for _rgb, _hema, _eos, mask_path in pairs:
        gt = cv2.imread(mask_path, 0).flatten()
        stop = start + gt.size
        true_cnt.append((gt == 1).sum())
        pred_cnt.append((pr[start:stop] == 1).sum())
        start = stop

    fig, ax = plt.subplots(figsize=(5,5))
    ax.scatter(true_cnt, pred_cnt, alpha=0.7)

    # Dashed identity line from the origin to the largest count on either axis.
    upper = max(true_cnt + pred_cnt)
    ax.plot([0, upper], [0, upper], 'k--')

    ax.set_xlabel("True FOI Pixels")
    ax.set_ylabel("Pred FOI Pixels")
    ax.set_title("Area Scatter")
    fig.savefig(os.path.join(METRICS_DIR,"Area_Scatter.svg"),format="svg")
    plt.close(fig)

def save_roc(tr, prob, max_samples=200_000, seed=42):
    """
    Save one-vs-rest ROC curves (with AUC) per class as ``ROC_Curves.svg``
    inside ``METRICS_DIR``.

    Args:
        tr: 1-D array of true labels; negative labels mark ignored pixels.
        prob: ``(n_pixels, n_classes)`` array of class probabilities aligned
            with ``tr``.
        max_samples: at most this many pixels are drawn (without replacement)
            when ``tr`` is longer, so the full arrays are never copied.
        seed: seed of the private random generator used for the subsample.

    Ignored pixels are removed after subsampling so they are not counted as
    negatives for every class. A class whose one-vs-rest target is constant
    in the sample (no positives or no negatives) has no defined ROC curve and
    is skipped.
    """
    n = tr.shape[0]
    if n > max_samples:
        # A private generator gives the same draw as seeding the global one,
        # without resetting the global `random` state used by augmentation.
        rng = random.Random(seed)
        idx = rng.sample(range(n), k=max_samples)
        tr_sub   = tr[idx]
        prob_sub = prob[idx]
    else:
        tr_sub, prob_sub = tr, prob

    # Drop ignored pixels (label -1) from the sample.
    labelled = tr_sub >= 0
    tr_sub, prob_sub = tr_sub[labelled], prob_sub[labelled]

    fig, ax = plt.subplots(figsize=(6,5))
    for cls in range(prob_sub.shape[1]):
        yt = (tr_sub == cls).astype(int)
        if yt.size == 0 or yt.min() == yt.max():
            continue  # ROC undefined without both positives and negatives
        fpr, tpr, _ = roc_curve(yt, prob_sub[:, cls])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{class_names[cls]} (AUC={roc_auc:.2f})")

    # Chance diagonal for reference.
    ax.plot([0,1], [0,1], 'k--')
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.set_title("ROC Curves")
    ax.legend(loc="lower right")

    fig.savefig(
        os.path.join(METRICS_DIR, "ROC_Curves.svg"),
        format="svg", dpi=150
    )
    plt.close(fig)

# --------------------------------------------------------------------------------
# MAIN
# --------------------------------------------------------------------------------
# Script entry point: collect patches, split, train (or load) the U-Net,
# run inference on the held-out split and write every metric plot.
if __name__=="__main__":
    # NOTE(review): hard-coded absolute Windows path - adjust per machine.
    patch_dir = r"C:\Temp\patches_2048\HC22B0006126_HEmascaras_tumor_oversampled_balanced"
    pairs = get_valid_pairs(patch_dir)
    print(f"[INFO] Found {len(pairs)} patches.")

    # Stratified 80/20 split. The stratification key of a patch is the most
    # frequent label among {0, 1, 2} in its mask (0 if none is present).
    strat=[]
    for *_,m in pairs:
        gt=cv2.imread(m,0); gt=np.where(np.isin(gt,[0,1,2]),gt,-1)
        u,c=np.unique(gt,return_counts=True)
        valid=[(u_,c_) for u_,c_ in zip(u,c) if u_ in [0,1,2]]
        strat.append(max(valid,key=lambda x:x[1])[0] if valid else 0)
    train_p, test_p = train_test_split(pairs, test_size=0.2, stratify=strat, random_state=42)
    ds_tr = PatchDataset(train_p, augment=True)
    ds_te = PatchDataset(test_p, augment=False)

    # ----- OPTION A: Train from scratch -----
    # This block is currently active; comment it out when using Option B.
    # NOTE(review): ds_te is passed as the validation set here and is also
    # the set evaluated below - confirm that no separate validation split
    # is intended.

    model = UNet(in_ch=2, base=64).to(device)
    model = train_model(model, ds_tr, ds_te, epochs=1, lr=1e-4, bs=1, accum=4)
    model_path = os.path.join(MODEL_DIR, f"unet_{timestamp}.pth")
    torch.save(model.state_dict(), model_path)
    print(f"[INFO] Saved trained model to {model_path}")
    trained = model

    # ----- OPTION B: Load existing model -----
    # Uncomment this block (and comment out Option A) to load a previously
    # trained model instead of training.

    #"""
    #model = UNet(in_ch=2, base=64).to(device)
    #model_path = r"\\imgserver.cnio.es\IMAGES\CONFOCAL\IA\CMU\tfm_colab\patrones_invasion\notebooks\results\20250527_130513\models\unet_20250527_130513.pth"
    #state = torch.load(model_path, map_location=device)
    #model.load_state_dict(state)
    #model.eval()
    #print(f"[INFO] Loaded model from {model_path}")
    #trained = model
    #"""

    # ---- end train/load choice ----

    # Predict on the held-out split; overlays are written as a side effect.
    all_t, all_p, all_prob = predict_and_collect(trained, ds_te, test_p, ma_kernel=5)

    # Confusion matrix over labels 0-2, then every metric plot.
    cm = confusion_matrix(all_t, all_p, labels=[0,1,2])
    save_conf(cm,"Confusion Matrix")
    save_norm_conf(cm,"Normalized Confusion Matrix")
    save_iou_dice(cm)
    save_patch_iou(all_t, all_p, test_p)
    save_area_scatter(all_t, all_p, test_p)
    save_roc(all_t, all_prob)
    save_overall_metrics_from_cm(cm)
    save_metrics_per_class(cm)

    print(f"[DONE] Results are in {RESULTS_ROOT}")

[INFO] Using device: cuda
[INFO] Found 2003 patches.


  scaler=GradScaler()
  with autocast():
Train 1: 100%|█████████████████████████████████████████████████████████████████████| 1602/1602 [23:09<00:00,  1.15it/s]


[Ep1] Train Loss: 0.7383


  with autocast():
 Val 1: 100%|████████████████████████████████████████████████████████████████████████| 401/401 [01:58<00:00,  3.38it/s]


[Ep1] Val Loss:   0.6412
[INFO] Saved trained model to results\20250529_171815\models\unet_20250529_171815.pth


  bp = ax.boxplot(


[DONE] Results are in results\20250529_171815
