In [None]:
# Run nvidia-smi through the IPython shell escape; the captured output is a list of lines.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# Any 'failed' substring in the output is treated as "no GPU attached".
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)


Mon Dec  8 04:19:22 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   32C    P0             54W /  400W |       0MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Set the CUDA allocator option for this kernel process (torch is imported in a later cell).
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [None]:
# Install the imaging/utility dependencies used by this notebook.
# NOTE(review): nibabel is the only package here without a version pin.
!pip -q install nibabel SimpleITK==2.3.1 einops==0.8.0 tqdm==4.67

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.6/57.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os, sys, json, glob, random, math, time, csv
from pathlib import Path
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from tqdm import tqdm
from scipy import ndimage as ndi

import pandas as pd
import shutil

In [None]:
# Mount Google Drive at /content/drive; the dataset archive, index files and
# checkpoints referenced below all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the cached-dataset archive from Google Drive to the VM's local disk (/content).
!cp /content/drive/MyDrive/m3a_neuroseg/cache2023_only.tar.gz /content/


In [None]:
# Extract the gzip-compressed tar archive into /content.
!tar -xzf /content/cache2023_only.tar.gz -C /content/


In [None]:
# Delete the local archive copy once it has been extracted.
!rm /content/cache2023_only.tar.gz


In [None]:
# ---------------------------------------------------------------
# Path configuration for the "data copied onto the VM" workflow.
# ---------------------------------------------------------------

# Persistent Google Drive locations.
GDRIVE_ROOT = '/content/drive/MyDrive'
GDRIVE_PROJ_ROOT = f'{GDRIVE_ROOT}/m3a_neuroseg'
GDRIVE_RUNS_DIR = f'{GDRIVE_PROJ_ROOT}/runs'

# Local VM disk: the extracted cache and the project working directory.
DRIVE_ROOT = "/content"
BRATS_ROOT = "/content"          # e.g., /MyDrive/brats
BRATS23 = f"{BRATS_ROOT}/cache2023_only"

PROJ_ROOT = f"{DRIVE_ROOT}/m3a_neuroseg"    # project home
os.makedirs(PROJ_ROOT, exist_ok=True)

IDX_DIR = f"{PROJ_ROOT}/data_index"
os.makedirs(IDX_DIR, exist_ok=True)

# The cache directory comes from the extracted archive, so it is not created here.
CACHE_DIR = f"{BRATS_ROOT}/cache2023_only"

RUNS_DIR = f"{PROJ_ROOT}/runs"
os.makedirs(RUNS_DIR, exist_ok=True)

FIGS_DIR = f"{PROJ_ROOT}/figs"
os.makedirs(FIGS_DIR, exist_ok=True)

# Same allocator option as the earlier %env cell, set again from Python.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Seed the Python and NumPy global RNGs (torch is not seeded in this cell).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

print("Ready. Roots:\n", BRATS23, "\n", "\n", PROJ_ROOT,'\n',CACHE_DIR,'\n',RUNS_DIR)

Ready. Roots:
 /content/cache2023_only 
 
 /content/m3a_neuroseg 
 /content/cache2023_only 
 /content/m3a_neuroseg/runs


In [None]:
# Copy the case index CSV and the 5-fold split JSON from Drive into the
# local data_index directory created by the configuration cell.
!cp /content/drive/MyDrive/m3a_neuroseg/data_index/brats2023_index.csv /content/m3a_neuroseg/data_index

!cp /content/drive/MyDrive/m3a_neuroseg/data_index/splits_5fold_seed42.json /content/m3a_neuroseg/data_index

In [None]:
# Earlier Drive-based location, kept for reference:
#index_path = "/content/drive/MyDrive/m3a_neuroseg/data_index/brats2023_index.csv"
# Read the index from the local copy made in the previous cell.
index_path = "/content/m3a_neuroseg/data_index/brats2023_index.csv"
df = pd.read_csv(index_path)

# First column is case_id
case_ids = df.iloc[:, 0].tolist()
print(f"✅ Found {len(case_ids)} BRATS2023 cases.")

✅ Found 1251 BRATS2023 cases.


In [None]:
# ================================================
# Cell : PyTorch Dataset & tumor-aware patch sampler
# ================================================
#!pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
# (If you have GPU runtime, comment the line above and rely on built-in CUDA torch.)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Enable cuDNN's benchmark mode (algorithm auto-selection for convolutions).
torch.backends.cudnn.benchmark = True
# Edge length of the cubic patch; used as the default `patch` of BratsPatchDataset.
PATCH = 112
# Probability that a sampled patch centre is drawn from a tumour voxel (seg > 0).
TUMOR_CENTER_PROB = 0.6

class BratsPatchDataset(Dataset):
    """Random cubic-patch dataset over per-case ``.npy`` volumes.

    Each case is a directory ``<cache_dir>/<case_id>/`` holding ``t1c.npy``,
    ``t1n.npy``, ``t2w.npy``, ``t2f.npy`` and ``seg.npy``. Every ``__getitem__``
    call draws a fresh random patch, so the same index yields different
    samples across calls.

    Args:
        case_ids: sequence of case directory names.
        cache_dir: root directory containing one sub-directory per case.
        patch: edge length of the cubic patch.
        tumor_center_prob: probability of centring the patch on a voxel with
            ``seg > 0`` (when the case has any).
        augment: apply random flips and modality dropout.

    Returned item: ``dict(image=float32 (4,P,P,P), target=float32 one-hot
    (4,P,P,P), case_id=str)``.
    """

    # File stems of the four modalities, in channel order.
    _MODALITIES = ("t1c", "t1n", "t2w", "t2f")

    def __init__(self, case_ids, cache_dir=CACHE_DIR, patch=PATCH, tumor_center_prob=TUMOR_CENTER_PROB, augment=True):
        self.ids = case_ids
        self.cache = Path(cache_dir)
        self.patch = patch
        self.tumor_center_prob = tumor_center_prob
        self.augment = augment

    def _open_case(self, cid):
        """Open the case lazily: returns ([memmap per modality], seg memmap) without reading voxels."""
        d = self.cache / cid
        mods = [np.load(d / f"{name}.npy", mmap_mode='r') for name in self._MODALITIES]
        seg = np.load(d / "seg.npy", mmap_mode='r')
        return mods, seg

    def _load_case(self, cid):
        """Return the full stacked volume (C,Z,Y,X) and the seg memmap.

        Kept for callers that need the whole case; ``__getitem__`` uses the
        lazy path instead because ``np.stack`` reads every voxel from disk.
        """
        mods, seg = self._open_case(cid)
        vol = np.stack(mods, axis=0) # (C,Z,Y,X)
        return vol, seg

    def _rand_center(self, seg, tumor_bias=True):
        """Pick a patch centre (z, y, x): a random tumour voxel with probability
        ``tumor_center_prob`` (if the case has any), otherwise a uniform voxel."""
        Z, Y, X = seg.shape
        if tumor_bias and np.random.rand() < self.tumor_center_prob:
            # Build the tumour index list once (C order, same ordering as np.where).
            tumor_idx = np.flatnonzero(np.asarray(seg) > 0)
            if tumor_idx.size:
                pick = tumor_idx[np.random.randint(tumor_idx.size)]
                tz, ty, tx = np.unravel_index(pick, seg.shape)
                return int(tz), int(ty), int(tx)
        return np.random.randint(Z), np.random.randint(Y), np.random.randint(X)

    def _patch_slices(self, shape, center):
        """Slices of the patch window around ``center``, shifted to stay inside ``shape``.

        When an axis is shorter than the patch the window starts at 0 and covers
        the whole axis (the shortfall is zero-padded by ``_pad_to_patch``).
        """
        ps = self.patch
        window = []
        for c, n in zip(center, shape):
            # max(n - ps, 0) keeps the upper bound non-negative for short axes.
            start = int(min(max(int(c) - ps // 2, 0), max(n - ps, 0)))
            window.append(slice(start, start + ps))
        return tuple(window)

    def _pad_to_patch(self, v, s):
        """Zero-pad (at the far end of each axis) so spatial dims reach the patch size.

        NOTE(review): pads intensities with 0.0 and labels with 0; this assumes 0
        is the background value in the cached volumes — confirm against the
        preprocessing that wrote the cache.
        """
        ps = self.patch
        deficit = [max(ps - n, 0) for n in s.shape]
        if any(deficit):
            widths = [(0, k) for k in deficit]
            v = np.pad(v, [(0, 0)] + widths, mode="constant")
            s = np.pad(s, widths, mode="constant")
        return v, s

    def _crop_patch(self, vol, seg, cz, cy, cx):
        """Crop a (C,P,P,P) image patch and (P,P,P) label patch around the given centre."""
        # vol: (C,Z,Y,X)
        _, Z, Y, X = vol.shape
        window = self._patch_slices((Z, Y, X), (cz, cy, cx))
        v = vol[(slice(None),) + window]
        s = seg[window]
        return self._pad_to_patch(v, s)

    def _augment(self, v, s):
        """Random flips along X, Y, Z (p=0.5 each) and modality dropout (p=0.15)."""
        if np.random.rand()<0.5:
            v = v[:, :, :, ::-1]; s = s[:, :, ::-1]
        if np.random.rand()<0.5:
            v = v[:, :, ::-1, :]; s = s[:, ::-1, :]
        if np.random.rand()<0.5:
            v = v[:, ::-1, :, :]; s = s[::-1, :, :]
        # modality dropout (15%)
        if np.random.rand()<0.15:
            ch = np.random.randint(v.shape[0])
            # Copy first: `v` may be a (possibly read-only) view of the caller's array.
            v = v.copy()
            v[ch] = 0.0
        return v, s

    def __len__(self): return len(self.ids)

    def __getitem__(self, idx):
        cid = self.ids[idx]
        mods, seg = self._open_case(cid)
        cz, cy, cx = self._rand_center(seg, tumor_bias=True)
        # Slice each memmap before stacking so only the patch is read from disk.
        window = self._patch_slices(seg.shape, (cz, cy, cx))
        v = np.stack([m[window] for m in mods], axis=0)
        v, s = self._pad_to_patch(v, seg[window])
        if self.augment:
            v, s = self._augment(v, s)
        # to tensors (contiguous copies: flips produce negative strides)
        v = torch.from_numpy(np.ascontiguousarray(v)).float()
        s = torch.from_numpy(s.copy()).long()  # (Z,Y,X)
        # one-hot: labels outside 0..num_classes-1 become an all-zero vector
        num_classes = 4
        classes = torch.arange(num_classes).view(-1, 1, 1, 1)
        oh = (s.unsqueeze(0) == classes).float()
        return dict(image=v, target=oh, case_id=cid)


In [None]:
!pip install monai nibabel tqdm

Collecting monai
  Downloading monai-1.5.1-py3-none-any.whl.metadata (13 kB)
Downloading monai-1.5.1-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m54.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.5.1


In [None]:
# =========================
# Baselines: 3D UNet and nnU-Net-like (DynUNet)
# =========================
import torch
import torch.nn as nn
from monai.networks.nets import UNet, DynUNet

class ReturnTuple(nn.Module):
    """Adapter giving a single-output network a two-output signature.

    ``forward(x)`` returns ``(core(x), None)`` so the wrapped model can be used
    wherever a ``(seg_logits, bmap_logits)`` pair is expected.
    """

    def __init__(self, core: nn.Module):
        super().__init__()
        # Registered as a sub-module, so its parameters appear under the "core." prefix.
        self.core = core

    def forward(self, x):
        return self.core(x), None

def make_unet3d(in_ch=4, out_ch=4, base=32):
    """Build the MONAI 3D ``UNet`` baseline wrapped to return ``(logits, None)``.

    Five resolution levels with channel widths ``base * (1, 2, 4, 8, 10)``,
    stride-2 transitions, two residual units per level, instance norm and no
    dropout.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        base: channel width of the first level.
    """
    widths = tuple(base * mult for mult in (1, 2, 4, 8, 10))
    core = UNet(
        spatial_dims=3,
        in_channels=in_ch,
        out_channels=out_ch,
        channels=widths,
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm="INSTANCE",
        dropout=0.0,
    )
    return ReturnTuple(core)

import torch
import torch.nn as nn
from monai.networks.nets import DynUNet

class DynUNetWrapper(nn.Module):
    """MONAI ``DynUNet`` exposed with the ``(seg_logits, bmap_logits)`` signature.

    The network has five levels of 3x3x3 convolutions: the first keeps the
    resolution (stride 1) and each following level downsamples by 2. Filter
    counts double per level starting from ``base``. Residual blocks, affine
    instance norm and LeakyReLU(0.01) are used; deep supervision is off.
    ``forward`` returns ``(logits, None)``.
    """

    def __init__(self, in_ch=4, out_ch=4, base=32):
        super().__init__()
        n_levels = 5
        conv_kernels = [(3, 3, 3)] * n_levels
        level_strides = [(1, 1, 1)] + [(2, 2, 2)] * (n_levels - 1)
        up_kernels = [(2, 2, 2)] * (n_levels - 1)
        widths = [base * (2 ** level) for level in range(n_levels)]

        # Attribute name `model` is part of the checkpoint key prefix.
        self.model = DynUNet(
            spatial_dims=3,
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=conv_kernels,
            strides=level_strides,
            upsample_kernel_size=up_kernels,
            filters=widths,
            res_block=True,
            norm_name=("instance", {"affine": True}),
            act_name=("leakyrelu", {"negative_slope": 0.01, "inplace": True}),
            dropout=0.0,
            deep_supervision=False,
        )

    def forward(self, x):
        return self.model(x), None






In [None]:
import torch
import torch.nn.functional as F

# ---- helper loss components ----
def dice_loss(logits, targets, eps=1e-5):
    """Soft Dice loss averaged over classes.

    Args:
        logits: raw class scores, shape (B,C,D,H,W); softmax is applied over C.
        targets: one-hot targets of the same shape.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor: mean over classes of ``1 - dice``, where sums run over
        the batch and all spatial axes.
    """
    reduce_dims = (0, 2, 3, 4)
    probs = logits.softmax(dim=1)
    overlap = (probs * targets).sum(dim=reduce_dims)
    total = (probs + targets).sum(dim=reduce_dims)
    per_class = 1 - (2 * overlap + eps) / (total + eps)
    return per_class.mean()

def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    """Multi-class focal loss computed from cross-entropy.

    Args:
        logits: raw class scores, shape (B,C,...).
        targets: one-hot targets of the same shape; the class index is taken
            with ``argmax`` over dim 1.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        alpha: constant scale factor applied to every element.

    Returns:
        Scalar tensor: mean of ``alpha * (1 - p_t) ** gamma * ce``.
    """
    ce = F.cross_entropy(logits, targets.argmax(1), reduction='none')
    # exp(-ce) recovers the probability assigned to the true class.
    pt = torch.exp(-ce)
    loss = alpha * (1 - pt) ** gamma * ce
    return loss.mean()

def bce_loss(logits, targets):
    """Cross-entropy between softmax probabilities and one-hot targets.

    Despite the name this is a categorical (softmax) cross-entropy: the log of
    ``softmax(logits) + 1e-8`` is weighted by ``targets``, summed over the
    class axis and averaged over all remaining elements.
    """
    class_probs = logits.softmax(dim=1)
    log_probs = torch.log(class_probs + 1e-8)
    per_voxel = -(targets * log_probs).sum(dim=1)
    return per_voxel.mean()
def dice_metric(y_pred, y_true, eps=1e-5):
    """Per-class Dice score for one-hot predictions and targets.

    Both inputs have shape (B,C,D,H,W). Sums run over the batch and spatial
    axes, so the result is a length-C tensor (no averaging over classes).
    """
    reduce_dims = (0, 2, 3, 4)
    overlap = (y_pred * y_true).sum(dim=reduce_dims)
    total = (y_pred + y_true).sum(dim=reduce_dims)
    return (2 * overlap + eps) / (total + eps)  # per-class
# optional boundary loss (safe dummy version)
def boundary_loss(bmap_logits, target):
    """BCE-with-logits between a boundary map and a mask derived from ``target``.

    Args:
        bmap_logits: boundary logits of shape (B,1,D,H,W), a multi-channel map
            (averaged over channels to one channel), or ``None``.
        target: one-hot target (B,C,D,H,W).

    Returns:
        Scalar tensor; ``0.0`` on ``target.device`` when ``bmap_logits`` is None.

    NOTE(review): the mask is ``target.sum(dim=1) > 0``. With one-hot targets
    that include a background channel this is 1 at every labelled voxel, so the
    loss pushes the map towards "all ones" rather than towards region borders.
    Confirm whether a true boundary/foreground mask was intended.
    """
    if bmap_logits is None:
        return torch.tensor(0.0, device=target.device)

    if bmap_logits.shape[1] == 1:
        single_channel = bmap_logits
    else:
        single_channel = bmap_logits.mean(dim=1, keepdim=True)

    mask = target.sum(dim=1, keepdim=True).gt(0).float()
    return F.binary_cross_entropy_with_logits(single_channel, mask)
def argmax_onehot(logits):
    """Return a one-hot tensor (same shape/dtype as ``logits``) marking the argmax over dim 1."""
    winners = torch.argmax(logits, dim=1, keepdim=True)
    onehot = torch.zeros_like(logits)
    onehot.scatter_(1, winners, 1.0)
    return onehot
# ---- full unified loss ----
def compute_loss(seg_logits, bmap_logits, target, boundary_weight=0.0, gamma=2.0):
    """Combined segmentation loss shared by all three model families.

    The total is ``dice + focal + bce``; ``boundary_weight * boundary`` is added
    only when boundary logits are given and the weight is positive.

    Args:
        seg_logits: class logits (B,C,D,H,W).
        bmap_logits: boundary logits or ``None``.
        target: one-hot target (B,C,D,H,W).
        boundary_weight: weight of the boundary term.
        gamma: focusing exponent forwarded to ``focal_loss``.

    Returns:
        ``(total, parts)`` where ``total`` is a tensor and ``parts`` maps
        "dice", "focal", "bce", "boundary" to Python floats.
    """
    d = dice_loss(seg_logits, target)
    f = focal_loss(seg_logits, target, gamma=gamma)
    b = bce_loss(seg_logits, target)
    total = d + f + b

    use_boundary = bmap_logits is not None and boundary_weight > 0
    if use_boundary:
        bl = boundary_loss(bmap_logits, target)
        total = total + boundary_weight * bl
    else:
        # Placeholder so the log dictionary always has a "boundary" entry.
        bl = torch.tensor(0.0, device=seg_logits.device)

    parts = {"dice": d, "focal": f, "bce": b, "boundary": bl}
    return total, {name: float(value.item()) for name, value in parts.items()}


In [None]:
# ================================================
# Cell: M3A-NeuroSeg core blocks + model
# ================================================
from einops import rearrange
import torch
import torch.nn as nn
from einops import rearrange

def _partition_windows_3d(x, win):  # x: (B,C,D,H,W)
    B,C,D,H,W = x.shape
    wd, wh, ww = win
    assert D%wd==0 and H%wh==0 and W%ww==0, "D/H/W must be multiples of window size"
    x = x.view(B, C, D//wd, wd, H//wh, wh, W//ww, ww)
    x = x.permute(0,2,4,6,3,5,7,1).contiguous()  # (B, nD, nH, nW, wd, wh, ww, C)
    x = x.view(-1, wd*wh*ww, C)                  # (B*nWin, N, C)
    return x

def _reverse_windows_3d(x, win, out_shape):      # x: (B*nWin, N, C)
    B,C,D,H,W = out_shape
    wd, wh, ww = win
    nD, nH, nW = D//wd, H//wh, W//ww
    x = x.view(B, nD, nH, nW, wd, wh, ww, C)
    x = x.permute(0,7,1,4,2,5,3,6).contiguous()  # (B, C, nD, wd, nH, wh, nW, ww)
    x = x.view(B, C, D, H, W)
    return x

class WindowAttention3D(nn.Module):
    """Multi-head self-attention restricted to non-overlapping 3D windows.

    The input volume is zero-padded at the far end of each axis up to a
    multiple of the window size, split into windows, attended within each
    window (with a learned per-head bias over token pairs), reassembled, and
    cropped back to the input size.

    Args:
        dim: channel count of the input; must be divisible by ``heads``.
        heads: number of attention heads.
        window: window edge length (int) or ``(wd, wh, ww)``.
        qkv_bias: whether the qkv projection has a bias.
    """

    def __init__(self, dim, heads=4, window=8, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.heads = heads
        if isinstance(window, int):
            self.win = (window, window, window)
        else:
            self.win = tuple(window)
        self.head_dim = dim // heads
        assert dim % heads == 0, "dim must be divisible by heads"
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # One additive bias per head for every (query token, key token) pair in a window.
        tokens_per_window = self.win[0] * self.win[1] * self.win[2]
        self.rel_bias = nn.Parameter(torch.zeros(heads, tokens_per_window, tokens_per_window))

    def forward(self, x):
        """x: (B,C,D,H,W) -> (B,C,D,H,W)."""
        B, C, D, H, W = x.shape
        wd, wh, ww = self.win

        # Amount needed to reach the next multiple of the window size (0 if aligned).
        pad_d, pad_h, pad_w = (-D) % wd, (-H) % wh, (-W) % ww
        if pad_d or pad_h or pad_w:
            x = nn.functional.pad(x, (0,pad_w, 0,pad_h, 0,pad_d), mode="constant", value=0)
        Dp, Hp, Wp = x.shape[2], x.shape[3], x.shape[4]

        tokens = _partition_windows_3d(x, self.win)            # (B*nWin, N, C)
        q, k, v = (
            rearrange(part, 'b n (h c) -> b h n c', h=self.heads)
            for part in self.qkv(tokens).chunk(3, dim=-1)
        )                                                      # each (B*nWin, heads, N, head_dim)

        q = q * self.scale
        scores = q @ k.transpose(-2, -1)                       # (B*nWin, heads, N, N)
        weights = (scores + self.rel_bias).softmax(dim=-1)     # bias broadcasts over windows

        mixed = rearrange(weights @ v, 'b h n c -> b n (h c)') # (B*nWin, N, C)
        out = _reverse_windows_3d(self.proj(mixed), self.win, (B, C, Dp, Hp, Wp))

        # Drop the padded border, if any.
        if (Dp, Hp, Wp) != (D, H, W):
            out = out[:, :, :D, :H, :W]
        return out

class LayerNorm3d(nn.Module):
    """Normalisation for (B,C,D,H,W) tensors with a per-channel affine.

    Statistics are taken over the three spatial axes separately for each
    sample and channel (biased variance), then scaled by ``weight`` and shifted
    by ``bias`` (both of shape (1,C,1,1,1)).
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        param_shape = (1, num_channels, 1, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))
        self.eps = eps

    def forward(self, x):
        # x: (B,C,D,H,W)
        spatial = (2, 3, 4)
        var = x.var(dim=spatial, keepdim=True, unbiased=False)
        mean = x.mean(dim=spatial, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return normed * self.weight + self.bias

class ConvNeXtBlock3D(nn.Module):
    """ConvNeXt-style residual block for 3D feature maps.

    Pipeline: depthwise 7x7x7 conv -> LayerNorm3d -> 1x1x1 expand (x4) -> GELU
    -> 1x1x1 project -> per-channel scale ``gamma`` -> optional drop-path ->
    residual add. Input and output shapes are identical (B,dim,D,H,W).

    Args:
        dim: number of channels.
        drop_path: probability of dropping the residual branch for a sample
            during training (stochastic depth). ``0.0`` disables it.
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        self.dw = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.ln = LayerNorm3d(dim)
        self.pw1 = nn.Conv3d(dim, 4*dim, kernel_size=1)
        self.act = nn.GELU()
        self.pw2 = nn.Conv3d(4*dim, dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.ones(1,dim,1,1,1))  # LayerScale
        self.drop_path = drop_path

    def forward(self, x):
        shortcut = x
        x = self.dw(x)
        x = self.ln(x)
        x = self.pw1(x)
        x = self.act(x)
        x = self.pw2(x)
        x = self.gamma * x
        if self.drop_path > 0 and self.training:
            # Per-sample stochastic depth: each sample keeps its branch with
            # probability keep_prob, and kept branches are scaled by
            # 1/keep_prob so the expected output matches eval mode.
            keep_prob = 1.0 - self.drop_path
            mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            keep = (torch.rand(mask_shape, device=x.device) < keep_prob).to(x.dtype)
            x = x * keep
            if keep_prob > 0:
                x = x / keep_prob
        return x + shortcut

class ModalityGate(nn.Module):
    """Squeeze-and-excite style gate over the input channels (modalities).

    Each channel is globally average-pooled, passed through a two-layer 1x1x1
    bottleneck (``in_ch -> hidden -> in_ch`` with GELU in between) and a
    sigmoid; the input is multiplied channel-wise by the resulting weights.
    """

    def __init__(self, in_ch=4, hidden=8):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Conv3d(in_ch, hidden, 1)
        self.act = nn.GELU()
        self.fc2 = nn.Conv3d(hidden, in_ch, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        # x: (B,C,D,H,W); the gate has shape (B,C,1,1,1) and broadcasts over space.
        squeezed = self.pool(x)
        hidden = self.act(self.fc1(squeezed))
        gate = self.sig(self.fc2(hidden))
        return x * gate

class AxialAttention3D(nn.Module):
    """Placeholder for factorized axial attention; currently an identity mapping.

    The projection layers (``to_qkv``, ``proj``) are created so parameter names
    stay stable for checkpoints, but neither ``forward`` nor ``_attn_axis``
    applies attention yet: both return the input values unchanged.

    Args:
        dim: channel count.
        heads: number of heads (only used to derive ``scale``).
        window: intended window length along the attended axis (unused so far).
    """

    def __init__(self, dim, heads=4, window=16):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.window = window
        self.to_qkv = nn.Conv3d(dim, dim*3, 1)
        self.proj = nn.Conv3d(dim, dim, 1)

    def _attn_axis(self, x, axis):
        """Identity fallback for single-axis attention.

        The previous body projected q/k/v and built an attention matrix whose
        result was discarded before returning ``x``; that dead work is removed.
        """
        # x: (B,C,D,H,W)
        return x

    def forward(self, x):
        # Identity skip. A copy is returned so callers always get a new tensor,
        # for any dtype/device, without building a zero convolution kernel.
        return x.clone()

class Down(nn.Module):
    """Halve the spatial resolution (2x max-pool), then change channels with a 1x1x1 conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool3d(2)
        self.proj = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        pooled = self.pool(x)
        return self.proj(pooled)

class Up(nn.Module):
    """Double the spatial resolution (trilinear upsampling), then change channels with a 1x1x1 conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
        self.proj = nn.Conv3d(in_ch, out_ch, 1)

    def forward(self, x):
        upsampled = self.up(x)
        return self.proj(upsampled)

class XScaleAttentionBridge(nn.Module):
    """Gated fusion of encoder features into decoder features.

    Decoder and encoder maps (same spatial size) are projected with 1x1x1
    convolutions to a query, key and value. The gate is ``sigmoid(q * k)``
    (element-wise, no softmax over positions); the output is
    ``proj(dec + gate * v)`` with ``ch_dec`` channels.
    """

    def __init__(self, ch_dec, ch_enc):
        super().__init__()
        self.q = nn.Conv3d(ch_dec, ch_dec, 1)
        self.k = nn.Conv3d(ch_enc, ch_dec, 1)
        self.v = nn.Conv3d(ch_enc, ch_dec, 1)
        self.proj = nn.Conv3d(ch_dec, ch_dec, 1)

    def forward(self, dec, enc):
        # dec, enc: (B,C,D,H,W) at the same scale
        query = self.q(dec)
        key = self.k(enc)
        value = self.v(enc)
        gate = torch.sigmoid(query * key)
        fused = dec + gate * value
        return self.proj(fused)


class BoundaryHead(nn.Module):
    """Single 3x3x3 convolution mapping class logits to one-channel boundary logits.

    No activation is applied; the output is raw logits of shape (B,1,D,H,W).
    """

    def __init__(self, in_ch=4):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, 1, 3, padding=1)

    def forward(self, logits):
        boundary_logits = self.conv(logits)
        return boundary_logits

class M3ANeuroSeg(nn.Module):
    """Four-stage encoder/decoder segmentation network with a boundary head.

    Structure:
        * input ``ModalityGate`` followed by a 3x3x3 stem convolution;
        * four encoder stages, each the sum of a ``ConvNeXtBlock3D`` branch and
          a ``WindowAttention3D`` branch, separated by ``Down`` blocks;
        * a 1x1x1 neck convolution;
        * three decoder stages, each ``Up`` -> ``XScaleAttentionBridge`` with
          the matching encoder output -> ``ConvNeXtBlock3D``;
        * a 1x1x1 classification head and a ``BoundaryHead`` on its logits.

    Args:
        in_ch: number of input channels.
        num_classes: number of output classes.
        dims: channel widths of the four stages.

    ``forward`` returns ``(logits, bmap_logits)`` with shapes
    (B,num_classes,D,H,W) and (B,1,D,H,W).
    """

    def __init__(self, in_ch=4, num_classes=4, dims=(32, 64, 128, 256)):
        super().__init__()
        c1, c2, c3, c4 = dims[0], dims[1], dims[2], dims[3]

        self.gate = ModalityGate(in_ch, hidden=8)
        self.stem = nn.Conv3d(in_ch, c1, kernel_size=3, padding=1)

        # Encoder: local (ConvNeXt) and windowed-attention branches per stage.
        self.enc1_local = ConvNeXtBlock3D(c1)
        self.enc1_global = WindowAttention3D(c1, heads=4, window=8)
        self.down1 = Down(c1, c2)

        self.enc2_local = ConvNeXtBlock3D(c2)
        self.enc2_global = WindowAttention3D(c2, heads=4, window=8)
        self.down2 = Down(c2, c3)

        self.enc3_local = ConvNeXtBlock3D(c3)
        self.enc3_global = WindowAttention3D(c3, heads=4, window=8)
        self.down3 = Down(c3, c4)

        self.enc4_local = ConvNeXtBlock3D(c4)
        self.enc4_global = WindowAttention3D(c4, heads=4, window=8)

        # Neck: plain 1x1x1 fusion.
        self.neck = nn.Conv3d(c4, c4, 1)

        # Decoder: encoder features enter through gated bridges instead of raw skips.
        self.up3 = Up(c4, c3)
        self.bridge3 = XScaleAttentionBridge(c3, c3)
        self.dec3 = ConvNeXtBlock3D(c3)

        self.up2 = Up(c3, c2)
        self.bridge2 = XScaleAttentionBridge(c2, c2)
        self.dec2 = ConvNeXtBlock3D(c2)

        self.up1 = Up(c2, c1)
        self.bridge1 = XScaleAttentionBridge(c1, c1)
        self.dec1 = ConvNeXtBlock3D(c1)

        self.head = nn.Conv3d(c1, num_classes, 1)
        self.boundary = BoundaryHead(num_classes)

    def forward(self, x):
        feat = self.stem(self.gate(x))

        # Encoder
        skip1 = self.enc1_local(feat) + self.enc1_global(feat)
        feat = self.down1(skip1)
        skip2 = self.enc2_local(feat) + self.enc2_global(feat)
        feat = self.down2(skip2)
        skip3 = self.enc3_local(feat) + self.enc3_global(feat)
        feat = self.down3(skip3)
        bottom = self.enc4_local(feat) + self.enc4_global(feat)

        dec = self.neck(bottom)

        # Decoder
        dec = self.dec3(self.bridge3(self.up3(dec), skip3))
        dec = self.dec2(self.bridge2(self.up2(dec), skip2))
        dec = self.dec1(self.bridge1(self.up1(dec), skip1))

        logits = self.head(dec)              # class logits
        bmap_logits = self.boundary(logits)  # boundary logits (no sigmoid)
        return logits, bmap_logits


# Sanity dry-run
# Builds a randomly initialised model on the CPU and runs one random
# 4-channel 128^3 volume through it to confirm the two output shapes.
# NOTE(review): `model` and `x` remain as notebook globals after this cell.
model = M3ANeuroSeg(in_ch=4, num_classes=4).cpu()
x = torch.randn(1,4,128,128,128)
with torch.no_grad():
    lo, bm = model(x)
print("Model OK. logits:", lo.shape, "bmap:", bm.shape)


Model OK. logits: torch.Size([1, 4, 128, 128, 128]) bmap: torch.Size([1, 1, 128, 128, 128])


In [None]:
# ================================================
# Tri-Model Ensemble (UNet3D + DynUNet + M3A) — Mode A (per-fold, out-of-fold)
# ================================================
import os, json, gc, time
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, HausdorffDistanceMetric

# Inference device: CUDA when available, otherwise CPU.
device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse PATCH / OVERLAP if an earlier cell defined them; otherwise fall back to 112 / 0.5.
PATCH   = globals().get("PATCH", 112)
OVERLAP = globals().get("OVERLAP", 0.5)

# --------- helpers ---------
def load_case_4ch_seg(case_id):
    """Load one cached case fully into memory.

    Reads ``t1c``, ``t1n``, ``t2w``, ``t2f`` and ``seg`` ``.npy`` files from
    ``CACHE_DIR/<case_id>/``.

    Returns:
        ``(vols, seg)``: ``vols`` is the four modalities stacked on a new first
        axis, shape (4, Z, Y, X); ``seg`` is the label volume as int64.

    NOTE(review): downstream one-hot encoding only recognises labels 0..3, so
    the enhancing-tumour label is presumably stored as 3 in the cache (not 4) —
    confirm against the preprocessing step.
    """
    case_dir = Path(CACHE_DIR) / case_id
    modality_names = ("t1c", "t1n", "t2w", "t2f")
    channels = [np.load(case_dir / f"{name}.npy") for name in modality_names]
    vols = np.stack(channels, axis=0)  # (4, Z, Y, X)
    seg = np.load(case_dir / "seg.npy").astype(np.int64)  # (Z, Y, X)
    return vols, seg

def one_hot_from_labels(lbl, num_classes):
    """One-hot encode an integer label map.

    Args:
        lbl: label tensor (B,1,*spatial); only channel 0 is read.
        num_classes: number of output channels ``K``.

    Returns:
        float32 tensor (B,K,*spatial) on ``lbl.device``. Channel ``k`` is 1
        where the label equals ``k``; labels outside ``0..K-1`` give all zeros.
    """
    b, c, *sp = lbl.shape
    # Class ids shaped (1,K,1,...,1) so the comparison broadcasts over batch and space.
    class_ids = torch.arange(num_classes, device=lbl.device).view(1, num_classes, *([1] * len(sp)))
    return (lbl[:, :1] == class_ids).to(torch.float32)

def compute_composites_onehot(pred_bool, gt_bool):
    """Build the composite tumour regions from boolean one-hot maps.

    Both inputs are boolean tensors (B,C,Z,Y,X) whose channels are read as
    ``1 -> NCR``, ``2 -> ED``, ``3 -> ET`` (channel 0 is not used).

    Returns:
        Dict with keys "ET", "TC", "WT"; each value is a ``(pred, true)`` pair
        of float tensors of shape (B,1,Z,Y,X):
            ET = channel 3
            TC = NCR | ET
            WT = NCR | ED | ET
    """
    def _regions(onehot):
        ncr, ed, et = onehot[:, 1], onehot[:, 2], onehot[:, 3]
        return {
            "ET": et.unsqueeze(1).float(),
            "TC": (ncr | et).unsqueeze(1).float(),
            "WT": (ncr | ed | et).unsqueeze(1).float(),
        }

    pred_regions = _regions(pred_bool)
    true_regions = _regions(gt_bool)
    return {name: (pred_regions[name], true_regions[name]) for name in ("ET", "TC", "WT")}

@torch.no_grad()
def sw_logits(model, vol4, roi=112, overlap=0.5):
    """Run Gaussian-weighted sliding-window inference on one volume.

    Args:
        model: network returning either logits or a tuple/list whose first
            element is the segmentation logits.
        vol4: numpy array (4, Z, Y, X); it is copied, so strided/flipped views
            are accepted.
        roi: edge length of the cubic inference window.
        overlap: window overlap fraction passed to MONAI.

    Returns:
        float32 numpy array (C, Z, Y, X) of logits on the CPU.
    """
    def _seg_only(batch):
        # Models in this notebook return (seg_logits, bmap_logits_or_None).
        result = model(batch)
        return result[0] if isinstance(result, (tuple, list)) else result

    inputs = torch.from_numpy(vol4[None].copy()).float().to(device)
    inputs = inputs.to(memory_format=torch.channels_last_3d)

    # Mixed precision is enabled only when CUDA is available.
    with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
        logits = sliding_window_inference(
            inputs=inputs,
            roi_size=(roi, roi, roi),
            overlap=overlap,
            sw_batch_size=1,
            predictor=_seg_only,
            mode="gaussian",
        )
    return logits[0].float().cpu().numpy()  # (C, Z, Y, X)

def make_m3a():
    """Construct the M3A model with 4 input channels, 4 classes and widths (32,64,128,256)."""
    return M3ANeuroSeg(in_ch=4, num_classes=4, dims=(32,64,128,256))


def make_unet():
    """Construct the 3D U-Net baseline (4 in, 4 out, base width 32)."""
    return make_unet3d(in_ch=4, out_ch=4, base=32)


def make_dynunet():
    """Construct the DynUNet baseline (4 in, 4 out, base width 32)."""
    return DynUNetWrapper(in_ch=4, out_ch=4, base=32)

def load_model(ctor, ckpt_path):
    """Instantiate a model, load its weights from a checkpoint and switch to eval mode.

    Args:
        ctor: zero-argument callable returning an ``nn.Module``.
        ckpt_path: path to a checkpoint whose ``"model"`` entry is a state dict.

    Returns:
        The model on ``device`` in channels-last-3d memory format, in eval mode.
    """
    net = ctor().to(device).to(memory_format=torch.channels_last_3d)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(checkpoint["model"])
    net.eval()
    return net

def get_val_ids_for_fold(fold):
    """Return the case IDs stored under ``splits -> fold_<fold>`` in the split JSON.

    The file is ``<IDX_DIR>/splits_5fold_seed<SEED>.json``.
    NOTE(review): the entry is presumably the held-out (validation) case list
    of that fold — confirm against the script that wrote the file.
    """
    split_file = f"{IDX_DIR}/splits_5fold_seed{SEED}.json"
    with open(split_file, "r") as fh:
        all_splits = json.load(fh)["splits"]
    return list(all_splits[f"fold_{fold}"])

# --------- main tri-ensemble eval for one fold ---------
def eval_fold_tri_ensemble(fold, tta=False, save_csv=True):
    """Evaluate the UNet3D + DynUNet + M3A logit-averaging ensemble on one fold.

    For every case returned by ``get_val_ids_for_fold(fold)`` the three models'
    sliding-window logits are averaged (optionally also averaged with a
    Z-flipped pass), converted to a label map by argmax, and scored with Dice
    and 95th-percentile Hausdorff distance, both per native class and for the
    composite regions ET / TC / WT.

    The checkpoints are resolved from the ``fold`` argument
    (``<runs>/<model>_brats2023_fold{fold}_cycle1/best.pt``), so each fold is
    scored with the models belonging to that fold. Previously the paths were
    fixed to ``fold4`` for every fold.

    Args:
        fold: fold index used for both the case list and the checkpoint paths.
        tta: if True, add a Z-flip test-time-augmentation pass.
        save_csv: if True, write the per-case table to Drive.

    Returns:
        ``(out_df, d, h)``: per-case DataFrame of composite metrics, per-class
        Dice array and per-class HD95 array (background excluded).

    Raises:
        AssertionError: if any of the three checkpoints is missing.

    NOTE(review): no voxel spacing is passed to the Hausdorff metric, so the
    values printed as "mm" are in voxel units; they equal millimetres only for
    1 mm isotropic data.
    """
    # resolve checkpoints for THIS fold (adjust names if your folders differ)
    runs_root = "/content/drive/MyDrive/m3a_neuroseg/runs"
    ckpt_unet = f"{runs_root}/unet3d_brats2023_fold{fold}_cycle1/best.pt"
    ckpt_dyn  = f"{runs_root}/nnunet_like_brats2023_fold{fold}_cycle1/best.pt"
    ckpt_m3a  = f"{runs_root}/m3a_brats2023_fold{fold}_cycle1/best.pt"
    missing = [p for p in (ckpt_unet, ckpt_dyn, ckpt_m3a) if not os.path.exists(p)]
    assert not missing, f"Missing one or more checkpoints. {missing}"

    # load once; order matters only for reproducing the same summation order
    models = (
        load_model(make_unet, ckpt_unet),
        load_model(make_dynunet, ckpt_dyn),
        load_model(make_m3a, ckpt_m3a),
    )

    def _mean_logits(volume):
        # Running sum keeps at most two logit volumes alive at a time.
        total = None
        for net in models:
            out = sw_logits(net, volume, roi=PATCH, overlap=OVERLAP)
            total = out if total is None else total + out
        return total / float(len(models))

    # Single-channel region metrics; a fresh metric object per call avoids
    # accumulating state between cases.
    def _dice(y_pred, y_true):
        m = DiceMetric(include_background=False, reduction="mean")
        return float(m(y_pred=y_pred.float(), y=y_true.float()).item())

    def _hd95(y_pred, y_true):
        m = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)
        return float(m(y_pred=y_pred.float(), y=y_true.float()).item())

    # native-class metrics accumulated over the whole fold
    dice_fn = DiceMetric(include_background=False, reduction="mean_batch")
    hd95_fn = HausdorffDistanceMetric(include_background=False, reduction="mean_batch", percentile=95)
    dice_fn.reset(); hd95_fn.reset()

    rows = []
    val_ids = get_val_ids_for_fold(fold)

    for cid in tqdm(val_ids, desc=f"Fold {fold} tri-ensemble", ncols=110):
        vol, gt = load_case_4ch_seg(cid)

        logits = _mean_logits(vol)

        # (optional) simple flip-TTA on Z axis: predict on the flipped volume,
        # flip the logits back, and average with the un-flipped prediction.
        if tta:
            flipped = _mean_logits(vol[:, ::-1])[:, ::-1]
            logits = 0.5 * logits + 0.5 * flipped
            del flipped

        # to tensors for metrics
        pred_lbl = torch.from_numpy(logits).unsqueeze(0).argmax(1, keepdim=True)  # (1,1,Z,Y,X)
        pred_oh  = one_hot_from_labels(pred_lbl.to(torch.int64), 4).float().to(device)
        gt_oh    = one_hot_from_labels(torch.from_numpy(gt)[None,None].to(device), 4).float()

        dice_fn(y_pred=pred_oh, y=gt_oh)
        hd95_fn(y_pred=pred_oh, y=gt_oh)

        # composites
        comps = compute_composites_onehot(pred_oh.bool(), gt_oh.bool())
        dET = _dice(*comps["ET"]); dTC = _dice(*comps["TC"]); dWT = _dice(*comps["WT"])
        hET = _hd95(*comps["ET"]); hTC = _hd95(*comps["TC"]); hWT = _hd95(*comps["WT"])

        rows.append(dict(case_id=cid, dice_ET=dET, dice_TC=dTC, dice_WT=dWT,
                         hd95_ET=hET, hd95_TC=hTC, hd95_WT=hWT))

        del logits, pred_lbl, pred_oh, gt_oh, comps
        gc.collect()

    d = np.squeeze(dice_fn.aggregate().cpu().numpy()).astype(np.float32)  # [NCR,ED,ET]
    h = np.squeeze(hd95_fn.aggregate().cpu().numpy()).astype(np.float32)

    print(f"\nFold {fold} (tri-ensemble {'+TTA' if tta else 'no-TTA'})")
    print(f"  Dice_mean (NCR/ED/ET): {d.mean():.4f} | {np.round(d,4)}")
    print(f"  HD95_mean (mm): {h.mean():.2f} | {np.round(h,2)}")

    import pandas as pd
    out_df = pd.DataFrame(rows)
    print(f"  Composites (mean): ET={out_df['dice_ET'].mean():.4f}, TC={out_df['dice_TC'].mean():.4f}, WT={out_df['dice_WT'].mean():.4f}")
    print(f"  HD95 comps (mm):  ET={out_df['hd95_ET'].mean():.2f}, TC={out_df['hd95_TC'].mean():.2f}, WT={out_df['hd95_WT'].mean():.2f}")

    if save_csv:
        out_dir = f"/content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results"
        os.makedirs(out_dir, exist_ok=True)
        # Fold 4 keeps its historical "newtri" file name so existing result
        # files and downstream readers are unaffected.
        tag = "newtri" if fold == 4 else "tri"
        csv_path = f"{out_dir}/fold{fold}_{tag}_ensemble_cases.csv"
        out_df.to_csv(csv_path, index=False)
        print("  ✓ Saved:", csv_path)
    return out_df, d, h

# --------- RUN: pick the fold and go ---------
#fold = 4   # <-- set 0..4
#_ = eval_fold_tri_ensemble(fold, tta=False, save_csv=True)
def eval_all_folds(tta=False):
    """Run ``eval_fold_tri_ensemble`` for folds 0..4 and summarise the results.

    Per-fold native-class Dice/HD95 vectors are averaged across folds
    (unweighted), and composite-region metrics are averaged over all cases.
    The concatenated per-case table is written to Drive.

    Args:
        tta: forwarded to ``eval_fold_tri_ensemble``.

    Returns:
        ``(final_df, mean_dice, mean_hd95)``.
    """
    import pandas as pd

    case_frames = []
    dice_per_fold = []
    hd95_per_fold = []

    for fold in range(5):
        print(f"\n==============================")
        print(f" Evaluating Fold {fold}")
        print(f"==============================")

        fold_df, fold_dice, fold_hd95 = eval_fold_tri_ensemble(fold, tta=tta, save_csv=True)
        case_frames.append(fold_df)        # case-level composite rows
        dice_per_fold.append(fold_dice)    # class-wise native Dice
        hd95_per_fold.append(fold_hd95)    # class-wise native HD95

    final_df = pd.concat(case_frames, ignore_index=True)

    # One row per fold, one column per native class.
    mean_dice = np.vstack(dice_per_fold).mean(axis=0)
    mean_hd95 = np.vstack(hd95_per_fold).mean(axis=0)

    print("\n==============================")
    print(" Final Tri-Ensemble Metrics Across ALL Folds")
    print("==============================")

    print(f"Native Dice (NCR/ED/ET): {mean_dice.mean():.4f} | {np.round(mean_dice,4)}")
    print(f"Native HD95 (mm): {mean_hd95.mean():.2f} | {np.round(mean_hd95,2)}")

    regions = ("ET", "TC", "WT")

    print("\nComposite Region Means:")
    for region in regions:
        print(f"{region} Dice: {final_df['dice_' + region].mean():.4f}")

    print("\nComposite Region HD95 (mm):")
    for region in regions:
        print(f"{region} HD95: {final_df['hd95_' + region].mean():.2f}")

    # save full summary
    summary_path = "/content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results/all_folds_tri_ensemble_cases.csv"
    final_df.to_csv(summary_path, index=False)

    print("\n✓ Saved full-case evaluation for all 5 folds.")
    return final_df, mean_dice, mean_hd95
# Run the full 5-fold evaluation with Z-flip TTA; the return value is discarded
# (results are printed and written to CSV by the called functions).
_ = eval_all_folds(tta=True)



 Evaluating Fold 0


  win_data = inputs[unravel_slice[0]].to(sw_device)
  out[idx_zm] += p
Fold 0 tri-ensemble: 100%|██████████████████████████████████████████████████| 252/252 [34:42<00:00,  8.27s/it]



Fold 0 (tri-ensemble +TTA)
  Dice_mean (NCR/ED/ET): 0.8843 | [0.8469 0.8995 0.9066]
  HD95_mean (mm): 3.18 | [3.66 3.06 2.81]
  Composites (mean): ET=0.9066, TC=0.9434, WT=0.9519
  HD95 comps (mm):  ET=2.81, TC=2.80, WT=2.93
  ✓ Saved: /content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results/fold0_tri_ensemble_cases.csv

 Evaluating Fold 1


Fold 1 tri-ensemble: 100%|██████████████████████████████████████████████████| 252/252 [34:13<00:00,  8.15s/it]



Fold 1 (tri-ensemble +TTA)
  Dice_mean (NCR/ED/ET): 0.8689 | [0.8375 0.8898 0.8794]
  HD95_mean (mm): 3.71 | [3.74 4.04 3.34]
  Composites (mean): ET=0.8794, TC=0.9292, WT=0.9439
  HD95 comps (mm):  ET=3.34, TC=3.26, WT=3.74
  ✓ Saved: /content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results/fold1_tri_ensemble_cases.csv

 Evaluating Fold 2


Fold 2 tri-ensemble: 100%|██████████████████████████████████████████████████| 249/249 [35:59<00:00,  8.67s/it]



Fold 2 (tri-ensemble +TTA)
  Dice_mean (NCR/ED/ET): 0.8738 | [0.834  0.88   0.9074]
  HD95_mean (mm): 3.57 | [3.89 4.74 2.09]
  Composites (mean): ET=0.9074, TC=0.9382, WT=0.9448
  HD95 comps (mm):  ET=2.09, TC=2.85, WT=4.51
  ✓ Saved: /content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results/fold2_tri_ensemble_cases.csv

 Evaluating Fold 3


Fold 3 tri-ensemble: 100%|██████████████████████████████████████████████████| 249/249 [33:09<00:00,  7.99s/it]



Fold 3 (tri-ensemble +TTA)
  Dice_mean (NCR/ED/ET): 0.8829 | [0.8432 0.8895 0.916 ]
  HD95_mean (mm): 3.65 | [3.4  5.04 2.51]
  Composites (mean): ET=0.9160, TC=0.9488, WT=0.9466
  HD95 comps (mm):  ET=2.51, TC=2.96, WT=4.60
  ✓ Saved: /content/drive/MyDrive/m3a_neuroseg/runs/tri_ensemble_results/fold3_tri_ensemble_cases.csv

 Evaluating Fold 4


Fold 4 tri-ensemble:  76%|█████████████████████████████████████▊            | 188/249 [28:55<06:53,  6.78s/it]