In [None]:
import os
import time
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch import nn

from monai.data import CacheDataset, DataLoader
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped,
    Orientationd, Spacingd, NormalizeIntensityd, ResizeWithPadOrCropd
)
from monai.utils import set_determinism

from preprocess.utils.RepeatChannel import RepeatChannelsd
from tomultichannel import ConvertToMultiChannel
from model_builder import UNet3D
# from U_Mamba_net import U_Mamba_net   # <- for later if you want

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_determinism(seed=0)

NUM_WORKERS = 1

# ---- CHANGE THESE ----
MODEL_PATH = "/home/luudh/luudh/MyFile/medical_image_lab/monai/going_modular/model/UNet3D.pth"
IMAGES_DIR = "/home/luudh/luudh/MyFile/medical_image_lab/monai/data/Task01_BrainTumour/imagesVal"
LABELS_DIR = "/home/luudh/luudh/MyFile/medical_image_lab/monai/data/Task01_BrainTumour/labelsVal"
RESULTS_CSV = "preliminary_dice_unet.csv"
# ----------------------




In [25]:
def build_eval_dataset():
    val_images = sorted(glob(os.path.join(IMAGES_DIR, "*.nii.gz")))
    val_labels = sorted(glob(os.path.join(LABELS_DIR, "*.nii.gz")))

    assert len(val_images) == len(val_labels), \
        f"Found {len(val_images)} images but {len(val_labels)} labels."

    val_data = [{"image": i, "label": l} for i, l in zip(val_images, val_labels)]

    val_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=(240, 240, 144)),
        RepeatChannelsd(keys=["image"], target_channels=4),   # 4 MRI modalities
        ConvertToMultiChannel(keys="label"),                  # 3 channels (WT/TC/ET etc.)
        EnsureTyped(keys=["image", "label"]),
    ])

    val_ds = CacheDataset(
        data=val_data,
        transform=val_transforms,
        cache_rate=0.1,
        num_workers=NUM_WORKERS,
    )
    return val_ds


In [26]:
def load_unet3d():
    model = UNet3D(in_channels=4, out_channels=3).to(device)

    # IMPORTANT: explicitly set weights_only=False for PyTorch 2.6+
    checkpoint = torch.load(
        MODEL_PATH,
        map_location=device,
        weights_only=False,   # <- add this
    )

    # handle both types: full checkpoint dict or raw state_dict
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # strip "module." if saved with DDP
    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model



In [27]:
from monai.transforms import Activations, AsDiscrete

def evaluate_model(model, val_ds, results_csv=RESULTS_CSV):
    loader = DataLoader(val_ds, batch_size=1, num_workers=NUM_WORKERS)

    # Post-processing: sigmoid + threshold -> binary masks
    post_trans = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    dice_metric = DiceMetric(include_background=False, reduction="mean_batch")   # for overall
    dice_per_case_metric = DiceMetric(include_background=False, reduction="none")  # for per-case

    per_case_scores = []
    model.eval()

    with torch.no_grad():
        for i, batch in enumerate(loader):
            images = batch["image"].to(device)
            labels = batch["label"].to(device)  # already multi-channel (C=3)

            logits = model(images)
            preds = post_trans(logits)

            # accumulate for overall Dice
            dice_metric(y_pred=preds, y=labels)

            # per-case Dice (shape: [B, C])
            d_case = dice_per_case_metric(y_pred=preds, y=labels)
            d_case_np = d_case.cpu().numpy().reshape(-1, d_case.shape[-1])  # (1, C)
            per_case_scores.append(d_case_np)

            print(f"[{i+1}/{len(loader)}] Dice per class (case {i}):", d_case_np[0])

    mean_dice = dice_metric.aggregate().cpu().numpy()  # (C,)
    dice_metric.reset()

    print("\n=== PRELIMINARY UNet RESULTS (Dice, include_background=False) ===")
    print(f"Mean Dice per class: {mean_dice}")
    print(f"Mean Dice (all classes): {mean_dice.mean():.4f}")

    # Save per-case Dice scores to CSV
    if per_case_scores:
        all_scores = np.vstack(per_case_scores)  # shape (N_cases, C)
        header = "class1,class2,class3"
        np.savetxt(results_csv, all_scores, delimiter=",", header=header, comments="")
        print(f"Per-case Dice scores saved to: {results_csv}")


In [28]:
if __name__ == "__main__":
    torch.cuda.empty_cache()
    print("[INFO] Building evaluation dataset...")
    val_ds = build_eval_dataset()
    print(f"[INFO] Validation/Test set size: {len(val_ds)} cases")

    print("[INFO] Loading UNet3D model...")
    model = load_unet3d()

    print("[INFO] Evaluating model...")
    start = time.time()
    evaluate_model(model, val_ds)
    end = time.time()
    print(f"[INFO] Evaluation time: {end - start:.1f} seconds")


[INFO] Building evaluation dataset...


Loading dataset: 100%|██████████| 4/4 [00:04<00:00,  1.10s/it]


[INFO] Validation/Test set size: 48 cases
[INFO] Loading UNet3D model...
[INFO] Evaluating model...
[1/48] Dice per class (case 0): [0.75157905 0.4251383 ]
[2/48] Dice per class (case 1): [0.8002038  0.16205518]
[3/48] Dice per class (case 2): [0.82215244 0.16976556]
[4/48] Dice per class (case 3): [0.5495274  0.09812512]
[5/48] Dice per class (case 4): [0.66676384 0.08962226]
[6/48] Dice per class (case 5): [0.5249975  0.06016559]
[7/48] Dice per class (case 6): [0.59487194 0.01662111]
[8/48] Dice per class (case 7): [0.60108054 0.25122195]
[9/48] Dice per class (case 8): [0.54279864 0.05440509]
[10/48] Dice per class (case 9): [0.43135354 0.0962012 ]
[11/48] Dice per class (case 10): [0.85900724 0.05353356]
[12/48] Dice per class (case 11): [0.38332075 0.16988166]
[13/48] Dice per class (case 12): [0.57058936 0.04916548]
[14/48] Dice per class (case 13): [0.34758857 0.00047921]
[15/48] Dice per class (case 14): [0.6942916  0.06638841]
[16/48] Dice per class (case 15): [0.51774704 0.1

In [None]:
import argparse
import glob
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from typing import List, Tuple

import numpy as np
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torch import nn

from monai.data import Dataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, NormalizeIntensityd, ResizeWithPadOrCropd,
    AsDiscreted, Activationsd
)
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete, Activations

# local imports
from U_Mamba_net import U_Mamba_net
from model_builder import UNet3D
from preprocess.utils.RepeatChannel import RepeatChannelsd


# ----------------------
#  Helpers
# ----------------------

def build_model(name: str, in_channels: int, num_classes: int, device: torch.device) -> nn.Module:
    name = name.lower()
    if name in {"u_mamba", "mamba", "u-mamba"}:
        model = U_Mamba_net(in_channels=in_channels, num_classes=num_classes)
    elif name in {"unet", "unet3d"}:
        model = UNet3D(in_channels=in_channels, out_channels=num_classes)
    else:
        raise ValueError(f"Unknown model '{name}'. Choose from ['u_mamba','unet3d'].")
    return model.to(device)


def load_weights(model: nn.Module, ckpt_path: str) -> None:
    # robust to PyTorch 2.6 weights_only change
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:
        # older PyTorch versions without weights_only argument
        ckpt = torch.load(ckpt_path, map_location="cpu")

    state = ckpt.get("model_state_dict", ckpt)
    new_state = {k.replace("module.", ""): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(new_state, strict=False)
    if missing:
        print(f"[warn] Missing keys: {sorted(missing)[:10]} ...")
    if unexpected:
        print(f"[warn] Unexpected keys: {sorted(unexpected)[:10]} ...")


def make_preprocess(roi: Tuple[int, int, int], target_channels: int):
    """
    Preprocess for BRATS-like images and labels.
    Assumes labels are integer masks (0..C).
    """
    return Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"], track_meta=True),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        Spacingd(keys=["label"], pixdim=(1.0, 1.0, 1.0), mode="nearest"),
        NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
        ResizeWithPadOrCropd(keys=["image", "label"], spatial_size=roi),
        RepeatChannelsd(keys=["image"], target_channels=target_channels),
        EnsureTyped(keys=["image", "label"], track_meta=True),
    ])


def parse_args():
    p = argparse.ArgumentParser(description="3D brain tumor preliminary evaluation (Dice)")
    p.add_argument("--model", default="u_mamba", help="u_mamba or unet3d")
    p.add_argument("--weights", required=True, help="Path to .pth checkpoint")
    p.add_argument("--images", required=True, nargs="+", help="Glob(s) or path(s) to image .nii/.nii.gz")
    p.add_argument("--labels", required=True, nargs="+", help="Glob(s) or path(s) to label .nii/.nii.gz")
    p.add_argument("--roi", type=int, nargs=3, default=[128, 128, 64], help="Sliding-window ROI size")
    p.add_argument("--sw-batch", type=int, default=1, help="Sliding window batch size")
    p.add_argument("--overlap", type=float, default=0.3, help="Sliding window overlap [0-1]")
    p.add_argument("--channels", type=int, default=4, help="Model input channels")
    p.add_argument("--num-classes", type=int, default=3, help="Output channels/classes")
    p.add_argument("--activation", choices=["sigmoid", "softmax"], default="softmax")
    p.add_argument("--multilabel", action="store_true", help="Treat outputs as independent classes (sigmoid)")
    p.add_argument("--threshold", type=float, default=0.5, help="Sigmoid threshold for multilabel")
    p.add_argument("--amp", action="store_true", help="Enable mixed precision")
    p.add_argument("--out-csv", default="prelim_results.csv", help="CSV file to store per-case Dice")
    return p.parse_args()


def resolve_paths(patterns: List[str]) -> List[str]:
    files = []
    for pat in patterns:
        files.extend(glob.glob(pat))
    files = sorted(set(files))
    if not files:
        raise FileNotFoundError(f"No files matched: {patterns}")
    return files


# ----------------------
#  Main eval
# ----------------------

def main():
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    # model
    model = build_model(args.model, in_channels=args.channels, num_classes=args.num_classes, device=device)
    load_weights(model, args.weights)
    model.eval()

    # data
    img_files = resolve_paths(args.images)
    lbl_files = resolve_paths(args.labels)
    if len(img_files) != len(lbl_files):
        raise RuntimeError(f"#images ({len(img_files)}) != #labels ({len(lbl_files)})")

    print(f"[INFO] Found {len(img_files)} image/label pairs for evaluation.")

    data = [
        {"image": i, "label": l, "case_id": os.path.basename(i)}
        for i, l in zip(img_files, lbl_files)
    ]

    pre_tf = make_preprocess(tuple(args.roi), args.channels)
    ds = Dataset(data=data, transform=pre_tf)
    loader = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)

    # metrics
    dice_metric = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=True)

    # post-processing for metrics
    if args.multilabel:
        post_pred = Compose([
            Activations(sigmoid=True),
            AsDiscrete(threshold=args.threshold),
        ])
        post_label = AsDiscrete(threshold=0.5)  # assumes labels already 0/1 per channel if multilabel GT
    else:
        # standard multi-class softmax + argmax + one-hot
        post_pred = Compose([
            Activations(softmax=True),
            AsDiscrete(argmax=True, to_onehot=args.num_classes),
        ])
        post_label = AsDiscrete(to_onehot=args.num_classes)

    case_ids = []
    per_case_dice = []  # [N, C]

    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            case_id = batch["case_id"][0]
            case_ids.append(case_id)

            # sliding-window inference
            def _fwd(inp):
                out = model(inp)
                return out[0] if isinstance(out, (list, tuple)) else out

            if device.type == "cuda" and args.amp:
                with torch.amp.autocast("cuda"):
                    logits = sliding_window_inference(
                        inputs=images,
                        roi_size=tuple(args.roi),
                        sw_batch_size=args.sw_batch,
                        predictor=_fwd,
                        overlap=args.overlap,
                        mode="gaussian",
                    )
            else:
                logits = sliding_window_inference(
                    inputs=images,
                    roi_size=tuple(args.roi),
                    sw_batch_size=args.sw_batch,
                    predictor=_fwd,
                    overlap=args.overlap,
                    mode="gaussian",
                )

            # decollate & post-process
            logits_list = decollate_batch(logits)
            labels_list = decollate_batch(labels)

            preds = [post_pred(p) for p in logits_list]
            labs = [post_label(l) for l in labels_list]

            # update metric
            dice_metric(y_pred=preds, y=labs)
            # get per-case (for this batch size=1)
            per_case = dice_metric.aggregate(reduction="none").cpu().numpy()  # shape [B, C]
            dice_metric.reset()  # reset because we're tracking per-case ourselves

            per_case_dice.append(per_case[0])  # [C]

    per_case_dice = np.stack(per_case_dice, axis=0)  # [N, C]
    mean_per_class = per_case_dice.mean(axis=0)      # [C]
    mean_dice = per_case_dice.mean()

    # print summary
    print("\n=== Preliminary Dice Results ===")
    print(f"Num cases: {len(case_ids)}")
    print(f"Mean Dice (all classes, all cases): {mean_dice:.4f}")
    for c in range(args.num_classes):
        print(f"  Class {c}: mean Dice = {mean_per_class[c]:.4f}")

    # save CSV
    import csv
    print(f"\n[INFO] Saving per-case Dice to {args.out_csv}")
    with open(args.out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        header = ["case_id"] + [f"class_{c}_dice" for c in range(args.num_classes)]
        writer.writerow(header)
        for cid, dice_vec in zip(case_ids, per_case_dice):
            writer.writerow([cid] + [float(x) for x in dice_vec])

    print("[INFO] Done.")




In [None]:
case_ids, per_case_dice, mean_per_class, mean_dice = run_prelim_eval(
    model_name="unet3d",
    weights_path="model/Medical_Image_UNet3D_fresh.pth",
    image_globs=["/home/luudh/.../imagesVal/*.nii.gz"],
    label_globs=["/home/luudh/.../labelsVal/*.nii.gz"],
    roi=(128, 128, 64),
    num_classes=3,
    activation="softmax",
    multilabel=False,
    use_amp=True,
)
