In [None]:
%run src/ConvNext_models.ipynb
from torchvision import transforms
from torch.utils.data import DataLoader, random_split



Output: torch.Size([2, 1, 512, 512])


In [None]:
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class BRATSDataset_2(Dataset):
    def __init__(self, base_path, img_transform=None, mask_transform=None, year=2023):
        self.base_path = base_path

        self.folders = [f for f in os.listdir(base_path)
                        if os.path.isdir(os.path.join(base_path, f))]
        self.transform = img_transform
        self.mask_transform = mask_transform
        self.year = year

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        folder_name = self.folders[idx]
        folder_path = os.path.join(self.base_path, folder_name)
        files = os.listdir(folder_path)

        t2f_file, seg_file = None, None

        for nifty_file in files:
            file_path = os.path.join(folder_path, nifty_file)

            try:

                if nifty_file.endswith('.nii') and 't2f' in nifty_file:
                    t2f_file = nib.load(file_path).get_fdata()
                elif nifty_file.endswith('.nii') and 'seg' in nifty_file:
                    seg_file = nib.load(file_path).get_fdata()
                elif nifty_file.endswith('.nii.gz') and 't2' in nifty_file:
                    t2f_file = nib.load(file_path).get_fdata()
                elif nifty_file.endswith('.nii.gz') and 'seg' in nifty_file:
                    seg_file = nib.load(file_path).get_fdata()
            except Exception as e:
                print(f"Error al cargar {file_path}: {e}")

                return self.__getitem__((idx + 1) % len(self))

        if t2f_file is None or seg_file is None:
            print(f"Advertencia: Archivo t2f o segmentación faltante en {folder_path}.")

            return self.__getitem__((idx + 1) % len(self))


        max_content_slice = np.argmax(np.sum(seg_file, axis=(0, 1)))


        t2f_image = t2f_file[:, :, max_content_slice]
        seg_image = (seg_file[:, :, max_content_slice] > 0).astype(np.uint8)


        t2f_image = torch.tensor(t2f_image, dtype=torch.float32).unsqueeze(0)
        seg_image = torch.tensor(seg_image, dtype=torch.float32).unsqueeze(0)


        if self.transform:
            t2f_image = self.transform(t2f_image)

        if self.mask_transform:
            seg_image = self.mask_transform(seg_image)


        case_id = folder_name


        return t2f_image, seg_image, case_id


In [None]:
base_path ="../datasets/BRATS2021/"


transform = transforms.Compose([
    transforms.Resize((512, 512)),



])

mask_transform = transforms.Compose([
    transforms.Resize((512, 512))
])



dataset = BRATSDataset_2(base_path, img_transform=transform, mask_transform = mask_transform)

train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size


train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


BATCH_SIZE = 6
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size   = BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size  = BATCH_SIZE, shuffle=False)
print(len(train_loader))

146


In [1]:

import os
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.cluster import DBSCAN
from collections import defaultdict





def analyze_slice_with_dbscan(gt_2d,
                              pred_2d,
                              eps=3,
                              min_samples=20):
    gt = gt_2d.astype(bool)
    pred = pred_2d.astype(bool)


    fn_mask = np.logical_and(gt, np.logical_not(pred))

    fp_mask = np.logical_and(np.logical_not(gt), pred)

    fn_coords = np.argwhere(fn_mask)
    fp_coords = np.argwhere(fp_mask)

    results = {
        "fn_clusters": [],
        "fp_clusters": [],
        "n_fn_pixels": int(fn_coords.shape[0]),
        "n_fp_pixels": int(fp_coords.shape[0]),
    }




    if fn_coords.shape[0] > 0:
        db_fn = DBSCAN(eps=eps, min_samples=min_samples).fit(fn_coords)
        labels_fn = db_fn.labels_

        for lab in set(labels_fn):
            if lab == -1:
                continue
            mask_lab = (labels_fn == lab)
            coords_lab = fn_coords[mask_lab]

            size = int(coords_lab.shape[0])
            y_mean, x_mean = coords_lab.mean(axis=0)

            results["fn_clusters"].append({
                "cluster_id": int(lab),
                "size": size,
                "centroid_rowcol": (float(y_mean), float(x_mean)),
            })




    if fp_coords.shape[0] > 0:
        db_fp = DBSCAN(eps=eps, min_samples=min_samples).fit(fp_coords)
        labels_fp = db_fp.labels_

        for lab in set(labels_fp):
            if lab == -1:
                continue
            mask_lab = (labels_fp == lab)
            coords_lab = fp_coords[mask_lab]

            size = int(coords_lab.shape[0])
            y_mean, x_mean = coords_lab.mean(axis=0)

            results["fp_clusters"].append({
                "cluster_id": int(lab),
                "size": size,
                "centroid_rowcol": (float(y_mean), float(x_mean)),
            })

    return results






def dice_score_torch(pred, gt, eps=1e-7):
    if pred.dim() == 4:
        pred = pred[:, 0]
        gt   = gt[:, 0]

    pred = pred.float()
    gt   = gt.float()

    inter = (pred * gt).sum()
    denom = pred.sum() + gt.sum()
    return float((2 * inter + eps) / (denom + eps))






def run_failure_analysis(model,
                         test_loader,
                         device=None,
                         eps=3,
                         min_samples=20,
                         save_json_path=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    model.eval()

    all_slices_results = []

    with torch.no_grad():
        for imgs, gts, names in test_loader:
            imgs = imgs.to(device)
            gts  = gts.to(device)




            logits = model(imgs)


            if logits.shape[1] == 1:
                probs = torch.sigmoid(logits)
            else:

                probs = torch.softmax(logits, dim=1)[:, 1:2, ...]


            preds = (probs > 0.5).long()




            batch_size = imgs.size(0)
            for b in range(batch_size):
                gt_b   = gts[b, 0].cpu().numpy()
                pred_b = preds[b, 0].cpu().numpy()


                dbscan_res = analyze_slice_with_dbscan(
                    gt_2d=gt_b,
                    pred_2d=pred_b,
                    eps=eps,
                    min_samples=min_samples
                )


                dice_b = dice_score_torch(preds[b:b+1], gts[b:b+1])


                tumor_pixels = int(gt_b.sum())

                slice_result = {
                    "name": str(names[b]),
                    "dice": dice_b,
                    "n_fn_pixels": dbscan_res["n_fn_pixels"],
                    "n_fp_pixels": dbscan_res["n_fp_pixels"],
                    "fn_clusters": dbscan_res["fn_clusters"],
                    "fp_clusters": dbscan_res["fp_clusters"],
                    "tumor_pixels_gt": tumor_pixels,
                }
                all_slices_results.append(slice_result)




    if save_json_path is not None:
        os.makedirs(os.path.dirname(save_json_path), exist_ok=True)
        with open(save_json_path, "w") as f:
            json.dump(all_slices_results, f, indent=2)
        print(f"[INFO] Resultados guardados en: {save_json_path}")

    return all_slices_results






def count_big_fn_clusters(r, min_size=100):
    return sum(1 for c in r["fn_clusters"] if c["size"] >= min_size)


def stratify_by_tumor_size(all_results, big_fn_min_size=100):
    sizes = [r["tumor_pixels_gt"] for r in all_results if r["tumor_pixels_gt"] > 0]

    if len(sizes) == 0:
        print("[WARN] No hay slices con tumor (tumor_pixels_gt > 0).")
        return [], None, None

    sizes_arr = np.array(sizes)
    q1, q2 = np.percentile(sizes_arr, [33, 66])
    print(f"[INFO] Quanti


SyntaxError: unterminated f-string literal (detected at line 265) (ipython-input-2604204789.py, line 265)

# Failure analisis

In [None]:
import os
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.cluster import DBSCAN





def analyze_slice_with_dbscan(gt_2d,
                              pred_2d,
                              eps=3,
                              min_samples=20):
    gt = gt_2d.astype(bool)
    pred = pred_2d.astype(bool)


    fn_mask = np.logical_and(gt, np.logical_not(pred))

    fp_mask = np.logical_and(np.logical_not(gt), pred)

    fn_coords = np.argwhere(fn_mask)
    fp_coords = np.argwhere(fp_mask)

    results = {
        "fn_clusters": [],
        "fp_clusters": [],
        "n_fn_pixels": int(fn_coords.shape[0]),
        "n_fp_pixels": int(fp_coords.shape[0]),
    }




    if fn_coords.shape[0] > 0:
        db_fn = DBSCAN(eps=eps, min_samples=min_samples).fit(fn_coords)
        labels_fn = db_fn.labels_

        for lab in set(labels_fn):
            if lab == -1:
                continue
            mask_lab = (labels_fn == lab)
            coords_lab = fn_coords[mask_lab]

            size = int(coords_lab.shape[0])
            y_mean, x_mean = coords_lab.mean(axis=0)

            results["fn_clusters"].append({
                "cluster_id": int(lab),
                "size": size,
                "centroid_rowcol": (float(y_mean), float(x_mean)),
            })




    if fp_coords.shape[0] > 0:
        db_fp = DBSCAN(eps=eps, min_samples=min_samples).fit(fp_coords)
        labels_fp = db_fp.labels_

        for lab in set(labels_fp):
            if lab == -1:
                continue
            mask_lab = (labels_fp == lab)
            coords_lab = fp_coords[mask_lab]

            size = int(coords_lab.shape[0])
            y_mean, x_mean = coords_lab.mean(axis=0)

            results["fp_clusters"].append({
                "cluster_id": int(lab),
                "size": size,
                "centroid_rowcol": (float(y_mean), float(x_mean)),
            })

    return results






def dice_score_torch(pred, gt, eps=1e-7):
    if pred.dim() == 4:
        pred = pred[:, 0]
        gt   = gt[:, 0]

    pred = pred.float()
    gt   = gt.float()

    inter = (pred * gt).sum()
    denom = pred.sum() + gt.sum()
    return float((2 * inter + eps) / (denom + eps))






def run_failure_analysis(model,
                         test_loader,
                         device=None,
                         eps=3,
                         min_samples=20,
                         save_json_path=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    model.eval()

    all_slices_results = []

    with torch.no_grad():
        for imgs, gts, names in test_loader:
            imgs = imgs.to(device)
            gts  = gts.to(device)




            logits = model(imgs)


            if logits.shape[1] == 1:
                probs = torch.sigmoid(logits)
            else:

                probs = torch.softmax(logits, dim=1)[:, 1:2, ...]


            preds = (probs > 0.5).long()




            batch_size = imgs.size(0)
            for b in range(batch_size):
                gt_b   = gts[b, 0].cpu().numpy()
                pred_b = preds[b, 0].cpu().numpy()


                dbscan_res = analyze_slice_with_dbscan(
                    gt_2d=gt_b,
                    pred_2d=pred_b,
                    eps=eps,
                    min_samples=min_samples
                )


                dice_b = dice_score_torch(preds[b:b+1], gts[b:b+1])

                slice_result = {
                    "name": str(names[b]),
                    "dice": dice_b,
                    "n_fn_pixels": dbscan_res["n_fn_pixels"],
                    "n_fp_pixels": dbscan_res["n_fp_pixels"],
                    "fn_clusters": dbscan_res["fn_clusters"],
                    "fp_clusters": dbscan_res["fp_clusters"],
                }
                all_slices_results.append(slice_result)




    if save_json_path is not None:
        os.makedirs(os.path.dirname(save_json_path), exist_ok=True)
        with open(save_json_path, "w") as f:
            json.dump(all_slices_results, f, indent=2)
        print(f"[INFO] Resultados guardados en: {save_json_path}")

    return all_slices_results






if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("[INFO] Using device:", device)






















    model = ConvNeXtSegmentationMF(
            num_classes   = 1,
            backbone_type = 'base')
    model.to(device)





    ckpt_path = "pretrained_model/ConvNeXtSegmentationMF_fold1_2021.pth"

    checkpoint = torch.load(ckpt_path, map_location=device)

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        print("[INFO] Loaded checkpoint['model_state_dict'].")
    else:
        model.load_state_dict(checkpoint)
        print("[INFO] Loaded checkpoint as plain state_dict.")

    model.to(device)














    all_results = run_failure_analysis(
         model=model,
         test_loader=test_loader,
         device=device,
         eps=3,
         min_samples=20,
         save_json_path="./failure_analysis/dbscan_results.json"
     )


    print("\n=== Slices con clústeres FN grandes (>100 píxeles) ===")
    for r in all_results:
         big_fn = [c for c in r["fn_clusters"] if c["size"] > 100]
         if len(big_fn) > 0:
             print(f"Slice: {r['name']}, Dice={r['dice']:.4f}, "
                   f"Big FN clusters={len(big_fn)}")


[INFO] Using device: cuda
[INFO] Loaded checkpoint as plain state_dict.
[INFO] Resultados guardados en: ./failure_analysis/dbscan_results.json

=== Slices con clústeres FN grandes (>100 píxeles) ===
Slice: BraTS2021_01570, Dice=0.9571, Big FN clusters=3
Slice: BraTS2021_00578, Dice=0.9499, Big FN clusters=4
Slice: BraTS2021_00380, Dice=0.8843, Big FN clusters=2
Slice: BraTS2021_00241, Dice=0.9592, Big FN clusters=1
Slice: BraTS2021_00155, Dice=0.8206, Big FN clusters=4
Slice: BraTS2021_00605, Dice=0.9595, Big FN clusters=5
Slice: BraTS2021_00759, Dice=0.9311, Big FN clusters=3
Slice: BraTS2021_01164, Dice=0.9312, Big FN clusters=5
Slice: BraTS2021_00286, Dice=0.5821, Big FN clusters=3
Slice: BraTS2021_01415, Dice=0.7721, Big FN clusters=2
Slice: BraTS2021_00397, Dice=0.8802, Big FN clusters=3
Slice: BraTS2021_00113, Dice=0.8636, Big FN clusters=3
Slice: BraTS2021_00012, Dice=0.9446, Big FN clusters=4
Slice: BraTS2021_01172, Dice=0.9446, Big FN clusters=3
Slice: BraTS2021_01625, Dice=0.

In [None]:
import torch

def dice_coefficient(pred, target, eps=1e-6):
    if pred.dim() == 3:
        pred = pred.squeeze(0)
    if target.dim() == 3:
        target = target.squeeze(0)

    pred = pred.float()
    target = target.float()

    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    dice = (2.0 * intersection + eps) / (union + eps)
    return dice.item()


def run_failure_analysis(model, test_loader, device, dice_threshold=0.70):
    model.eval()
    results = []

    debug_limit = 10
    debug_count = 0

    with torch.no_grad():
        for imgs, gts, case_ids in test_loader:
            imgs = imgs.to(device)
            gts  = gts.to(device)

            logits = model(imgs)

            preds  = (logits > 0.4).float()

            B = imgs.size(0)
            for i in range(B):
                cid = case_ids[i]
                pred_slice = preds[i, 0].cpu()
                gt_slice   = gts[i, 0].cpu()

                dice_wt = dice_coefficient(pred_slice, gt_slice)
                volume_voxels = gt_slice.sum().item()
                pred_voxels   = pred_slice.sum().item()
                intersec      = (pred_slice * gt_slice).sum().item()


                if debug_count < debug_limit:
                    print(
                        f"{cid} | Dice: {dice_wt:.4f} | "
                        f"GT voxels: {volume_voxels:.0f} | "
                        f"Pred voxels: {pred_voxels:.0f} | "
                        f"Intersec: {intersec:.0f}"
                    )
                    debug_count += 1

                results.append({
                    "case_id": cid,
                    "dice_wt": dice_wt,
                    "volume_voxels": volume_voxels,
                    "is_failure": dice_wt < dice_threshold,
                })

    return results


In [None]:
def summarize_by_volume(results):
    if len(results) == 0:
        return {}

    volumes = torch.tensor([r["volume_voxels"] for r in results], dtype=torch.float32)
    dices   = torch.tensor([r["dice_wt"]        for r in results], dtype=torch.float32)


    p33, p66 = volumes.quantile(torch.tensor([0.33, 0.66]))

    def volume_group(v):
        if v <= p33:
            return "Small"
        elif v <= p66:
            return "Medium"
        else:
            return "Large"


    for r in results:
        r["volume_group"] = volume_group(r["volume_voxels"])

    summary = {}
    for group in ["Small", "Medium", "Large"]:
        group_cases = [r for r in results if r["volume_group"] == group]
        if len(group_cases) == 0:
            continue

        dice_vals = torch.tensor([r["dice_wt"] for r in group_cases])
        failures  = sum(r["is_failure"] for r in group_cases)
        n_cases   = len(group_cases)

        summary[group] = {
            "n_cases": n_cases,
            "dice_mean": dice_vals.mean().item(),
            "dice_std": dice_vals.std(unbiased=False).item(),
            "dice_median": dice_vals.median().item(),
            "failure_rate": 100.0 * failures / n_cases
        }

    return summary


In [None]:
%run src/ConvNext_models.ipynb
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.Resize((512, 512)),



])

mask_transform = transforms.Compose([
    transforms.Resize((512, 512))
])

Output: torch.Size([2, 1, 512, 512])


In [None]:
from torch.utils.data import DataLoader


base_path    = "../datasets/BRATS2024"
dataset = BRATSDataset_2(base_path, img_transform=transform, mask_transform = mask_transform)


train_size = int(0.7 * len(dataset))
val_size   = int(0.15 * len(dataset))
test_size  = len(dataset) - train_size - val_size


train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


BATCH_SIZE = 6
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size   = BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size  = BATCH_SIZE, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


num_classes = 1
model = ConvNeXtSegmentationMF(num_classes, backbone_type='base')
model.to(device)

ckpt_path = "pretrained_model/ConvNeXtSegmentationMF_fold1_2024.pth"
checkpoint = torch.load(ckpt_path, map_location=device)







model.eval()


results = run_failure_analysis(model, test_loader, device, dice_threshold=0.60)
summary = summarize_by_volume(results)

print("=== Resumen por volumen ===")
for group, stats in summary.items():
    print(group, stats)

print("\nEjemplo de primeras filas de results:")
for r in results[:5]:
    print(r)


BraTS-GLI-02152-103 | Dice: 0.1824 | GT voxels: 9237 | Pred voxels: 20920 | Intersec: 2750
BraTS-GLI-02092-102 | Dice: 0.1361 | GT voxels: 9990 | Pred voxels: 15212 | Intersec: 1715
BraTS-GLI-02194-105 | Dice: 0.1909 | GT voxels: 14477 | Pred voxels: 17194 | Intersec: 3023
BraTS-GLI-02105-104 | Dice: 0.1023 | GT voxels: 4560 | Pred voxels: 17436 | Intersec: 1126
BraTS-GLI-02111-104 | Dice: 0.1190 | GT voxels: 8973 | Pred voxels: 35477 | Intersec: 2644
BraTS-GLI-02248-101 | Dice: 0.0981 | GT voxels: 7885 | Pred voxels: 16863 | Intersec: 1214
BraTS-GLI-02066-104 | Dice: 0.3226 | GT voxels: 24458 | Pred voxels: 24515 | Intersec: 7899
BraTS-GLI-02063-105 | Dice: 0.2808 | GT voxels: 17722 | Pred voxels: 15898 | Intersec: 4720
BraTS-GLI-02520-104 | Dice: 0.1509 | GT voxels: 13524 | Pred voxels: 13023 | Intersec: 2003
BraTS-GLI-02093-102 | Dice: 0.0892 | GT voxels: 8371 | Pred voxels: 16112 | Intersec: 1092
=== Resumen por volumen ===
Small {'n_cases': 10, 'dice_mean': 0.12164004147052765, 'd

In [None]:
thresholds = [0.2, 0.3, 0.4, 0.5]

for thr in thresholds:

    for r in results:
        r["is_failure"] = (r["dice_wt"] < thr)

    summary = summarize_by_volume(results)
    print(f"\n=== Threshold Dice = {thr} ===")
    for group, stats in summary.items():
        print(group, stats)


=== Threshold Dice = 0.2 ===
Small {'n_cases': 10, 'dice_mean': 0.12164004147052765, 'dice_std': 0.04142368212342262, 'dice_median': 0.09808380156755447, 'failure_rate': 100.0}
Medium {'n_cases': 10, 'dice_mean': 0.12958410382270813, 'dice_std': 0.027641048654913902, 'dice_median': 0.13375408947467804, 'failure_rate': 100.0}
Large {'n_cases': 10, 'dice_mean': 0.20714735984802246, 'dice_std': 0.06369852274656296, 'dice_median': 0.19341248273849487, 'failure_rate': 60.0}

=== Threshold Dice = 0.3 ===
Small {'n_cases': 10, 'dice_mean': 0.12164004147052765, 'dice_std': 0.04142368212342262, 'dice_median': 0.09808380156755447, 'failure_rate': 100.0}
Medium {'n_cases': 10, 'dice_mean': 0.12958410382270813, 'dice_std': 0.027641048654913902, 'dice_median': 0.13375408947467804, 'failure_rate': 100.0}
Large {'n_cases': 10, 'dice_mean': 0.20714735984802246, 'dice_std': 0.06369852274656296, 'dice_median': 0.19341248273849487, 'failure_rate': 90.0}

=== Threshold Dice = 0.4 ===
Small {'n_cases': 10