In [1]:
import torch, numpy as np, random, os, sys, matplotlib
# Report library versions so a run can be tied to a concrete environment.
print("Torch:", torch.__version__)
# The old label asserted a "below 2" NumPy without checking it; just report the version.
print("NumPy:", np.__version__)

Torch: 2.9.0
NumPy ok (<2): 2.3.4


In [2]:
from pathlib import Path
import os, re, csv

# Dataset root, resolved to an absolute path so the manifest entries do not
# depend on the notebook's working directory.
root = Path("../data/BraTS-2023").resolve()
# Directories plus loose gzip files directly under the root.
# Path.suffix yields only the last extension (".gz" for "x.nii.gz"), so the
# former comparison against ".nii.gz" could never be true and has been dropped.
cases = sorted(p for p in root.iterdir() if p.is_dir() or p.suffix == ".gz")
# One dict per case, filled by the loop below and written to manifest.csv.
rows = []

def list_files(case_dir):
    """Return the ``*.nii*`` entries directly inside ``case_dir`` as strings.

    Anything that is not a directory produces an empty list.
    """
    if not case_dir.is_dir():
        return []
    return list(map(str, case_dir.glob("*.nii*")))

# Filename-suffix regex per manifest column. Each entry accepts two naming
# variants (e.g. "-t1n" or "-t1"). Built once instead of once per case.
pat = {
    "t1":    r"-t1n\.nii\.gz$|-t1\.nii\.gz$",
    "t1ce":  r"-t1c\.nii\.gz$|-t1ce\.nii\.gz$",
    "t2":    r"-t2w\.nii\.gz$|-t2\.nii\.gz$",
    "flair": r"-t2f\.nii\.gz$|-flair\.nii\.gz$",
    "mask":  r"-seg\.nii\.gz$"
}


def pick(rx_pat):
    """Return the first entry of the current ``files`` list whose basename
    matches ``rx_pat``, or "" when none does.

    Reads the notebook-level ``files`` variable set by the loop below.
    """
    rx = re.compile(rx_pat)
    return next((f for f in files if rx.search(os.path.basename(f))), "")


for case in sorted(p for p in root.iterdir() if p.is_dir()):
    # glob() order is filesystem-dependent; sorting makes the chosen file
    # reproducible when several names match the same pattern.
    files = sorted(list_files(case))
    row = {"id": case.name}
    # pat preserves insertion order, so the columns come out as
    # id, t1, t1ce, t2, flair, mask.
    row.update((key, pick(rx_pat)) for key, rx_pat in pat.items())
    rows.append(row)

# Write one CSV row per case next to the data and report how many were written.
out = root.joinpath("manifest.csv")
with out.open("w", newline="") as fp:
    w = csv.DictWriter(fp, fieldnames=["id", "t1", "t1ce", "t2", "flair", "mask"])
    w.writeheader()
    w.writerows(rows)

print(f"Wrote {len(rows)} rows -> {out}")

Wrote 1251 rows -> /Users/kylelukaszek/Classes/AI/Project/data/BraTS-2023/manifest.csv


In [3]:
import os, random, math, time
from pathlib import Path
import numpy as np
import pandas as pd
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm


# Use the MPS backend when PyTorch reports it as available; otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Device:", device)


cols = ["t1", "t1ce", "t2", "flair", "mask"]


def exists(p):
    """True when ``p`` is a non-empty string naming an existing path.

    Non-string cells (e.g. NaN from an empty CSV field) count as missing.
    """
    return isinstance(p, str) and len(p) > 0 and os.path.exists(p)


# Keep only the cases for which every modality and the mask are on disk.
df = pd.read_csv("../data/BraTS-2023/manifest.csv")
has_all_files = df[cols].map(exists).all(axis=1)
df = df[has_all_files].reset_index(drop=True)
print("Rows:", len(df))


def load_nii(path):
    """Load a NIfTI file with nibabel.

    Returns:
        ``(voxels, zooms)`` where ``voxels`` is the image data cast to float32
        and ``zooms`` is the value of ``img.header.get_zooms()``.
    """
    nii = nib.load(path)
    voxels = nii.get_fdata().astype(np.float32)
    zooms = nii.header.get_zooms()
    return voxels, zooms

def norm_img(x):
    """Clip ``x`` to its 1st-99th percentile range and z-score the result.

    NaNs are replaced via ``np.nan_to_num`` first. Percentiles, mean and
    standard deviation are taken over the whole array; 1e-6 is added to the
    standard deviation to avoid division by zero.
    """
    x = np.nan_to_num(x)
    lo, hi = (np.percentile(x, q) for q in (1, 99))
    clipped = np.clip(x, lo, hi)
    centre = clipped.mean()
    spread = clipped.std() + 1e-6
    return (clipped - centre) / spread


def resize_img(x, size=192, mode="bilinear", is_mask=False):
    """Resize a 2-D slice, optionally multi-channel, to ``size`` x ``size``.

    Args:
        x: NumPy array of shape (H, W), (H, W, C) or (C, H, W). A 3-D array is
            treated as channels-first only when its first axis has length 3 or
            4 and is shorter than its last axis; otherwise as channels-last.
        size: Output height and width.
        mode: Interpolation mode passed to ``F.interpolate`` for images.
            Previously accepted but ignored; now honoured. Not used for masks.
        is_mask: When True, resample with nearest-neighbour so label values
            are preserved, and return int64.

    Returns:
        For masks: an int64 (size, size) array taken from the first channel.
        For images: (size, size, C) when there is more than one channel,
        otherwise (size, size).

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D.
    """
    if x.ndim == 3:  # (H,W,C) or (C,H,W)
        channels_first = x.shape[0] in (3, 4) and x.shape[0] < x.shape[2]
        t = torch.from_numpy(x)
        if not channels_first:
            t = t.permute(2, 0, 1)
        t = t[None]                                   # (1,C,H,W)
    elif x.ndim == 2:
        t = torch.from_numpy(x)[None, None]           # (1,1,H,W)
    else:
        raise ValueError(f"Unexpected shape: {x.shape}")

    if is_mask:
        t = F.interpolate(t.float(), size=(size, size), mode="nearest")
        return t[0, 0].to(torch.int64).numpy()

    # F.interpolate only accepts align_corners for the interpolating modes.
    extra = {"align_corners": False} if mode in ("bilinear", "bicubic") else {}
    t = F.interpolate(t.float(), size=(size, size), mode=mode, **extra)[0]
    if t.shape[0] > 1:
        return t.numpy().transpose(1, 2, 0)
    return t[0].numpy()

Device: mps
Rows: 1251


In [4]:
# Slice-extraction and loader settings.
IMG_SIZE = 192                    # side length slices are resized to
MAX_SLICES_PER_CASE_TRAIN = 12    # slices kept per training case
MAX_SLICES_PER_CASE_VAL   = 6     # slices kept per validation case
BATCH_SIZE = 2
VAL_SPLIT = 0.12  # ~ 12%

# Case-level split: shuffle the ids with a fixed seed, then carve the first
# ``val_n`` off for validation so no case contributes to both sets.
all_ids = df["id"].tolist()
random.seed(42)
random.shuffle(all_ids)
val_n = max(1, int(len(all_ids) * VAL_SPLIT))
val_ids, train_ids = set(all_ids[:val_n]), set(all_ids[val_n:])

class BraTSSliceDataset(Dataset):
    """In-memory dataset of resized 2-D slices taken along the last volume axis.

    Every selected slice is loaded, normalised and resized eagerly in
    ``__init__`` and kept in ``self.records`` as
    ``(case_id, z, image_chw, mask_hw)`` tuples, so memory grows with the
    number of cases times ``max_slices_per_case``.

    Args:
        df: Manifest frame with columns ``id``, ``t1``, ``t1ce``, ``t2``,
            ``flair`` and ``mask`` (file paths).
        id_set: Case ids to include; other rows are skipped.
        max_slices_per_case: Stop collecting from a case once this many slices
            have been kept.
        include_empty: When False, slices whose mask is all zero are skipped.
        img_size: Side length the slices are resized to.
    """

    MODALITIES = ("t1", "t1ce", "t2", "flair")

    def __init__(self, df, id_set, max_slices_per_case, include_empty, img_size=IMG_SIZE):
        self.records = []
        self.img_size = img_size
        for _, row in df.iterrows():
            cid = row["id"]
            if cid not in id_set:
                continue
            x, y = self._load_case(row)

            # Visit slices in random order (global ``random`` state) so the
            # kept subset is spread across the volume.
            zs = list(range(y.shape[2]))
            random.shuffle(zs)
            kept = 0
            for z in zs:
                mask_slice = y[:, :, z]
                if not include_empty and mask_slice.max() == 0:
                    continue
                self.records.append((cid, z) + self._resize_pair(x[:, :, :, z], mask_slice))
                kept += 1
                if kept >= max_slices_per_case:
                    break

    @classmethod
    def _load_case(cls, row):
        """Return ``(x, y)``: the four normalised modalities stacked on a new
        leading axis, and the mask volume as int64."""
        # load_nii already returns float32, so no further cast (and copy) is needed.
        x = np.stack([norm_img(load_nii(row[k])[0]) for k in cls.MODALITIES], axis=0)
        y = load_nii(row["mask"])[0].astype(np.int64)
        return x, y

    def _resize_pair(self, img_slice, mask_slice):
        """Resize one (4,H,W) image slice and its (H,W) mask; return
        ``(image_chw, mask_hw)``."""
        img_hwc = resize_img(img_slice.transpose(1, 2, 0), size=self.img_size, is_mask=False)
        mask_r = resize_img(mask_slice, size=self.img_size, is_mask=True)
        # Back to channels-first, contiguous so torch.from_numpy can wrap it later.
        return np.ascontiguousarray(img_hwc.transpose(2, 0, 1)), mask_r

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        cid, z, x, y = self.records[i]
        return {
            "id": cid, "z": z,
            "image": torch.from_numpy(x).float(),
            "mask":  torch.from_numpy(y).long()
        }

# Training keeps only slices that contain tumour labels; validation also
# admits empty slices.
train_ds = BraTSSliceDataset(df, train_ids, MAX_SLICES_PER_CASE_TRAIN, include_empty=False)
val_ds = BraTSSliceDataset(df, val_ids, MAX_SLICES_PER_CASE_VAL, include_empty=True)

# Both loaders share everything except shuffling.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=False)
train_dl = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_opts)
print("Train samples:", len(train_ds), " Val samples:", len(val_ds))

Train samples: 13211  Val samples: 900


In [5]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolution -> batch-norm -> ReLU stages.

    Padding of 1 keeps the spatial size; the first stage maps ``c_in`` to
    ``c_out`` channels and the second keeps ``c_out``.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        stages = []
        for width_in in (c_in, c_out):
            stages += [
                nn.Conv2d(width_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class UNet2D(nn.Module):
    """2-D U-Net with three down-sampling stages and skip connections.

    Channel widths are ``base``, ``2*base``, ``4*base`` in the encoder and
    ``8*base`` at the bottleneck. Each decoder stage upsamples with a stride-2
    transposed convolution, concatenates the matching encoder output and
    applies a ``DoubleConv``. A final 1x1 convolution emits ``n_classes``
    logits per pixel.
    """

    def __init__(self, in_ch=4, n_classes=4, base=32):
        super().__init__()
        w1, w2, w3, w4 = base, base * 2, base * 4, base * 8
        # Encoder
        self.d1 = DoubleConv(in_ch, w1)
        self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(w1, w2)
        self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(w2, w3)
        self.p3 = nn.MaxPool2d(2)
        # Bottleneck
        self.b = DoubleConv(w3, w4)
        # Decoder: each DoubleConv sees upsampled + skip channels.
        self.u3 = nn.ConvTranspose2d(w4, w3, 2, stride=2)
        self.c3 = DoubleConv(w4, w3)
        self.u2 = nn.ConvTranspose2d(w3, w2, 2, stride=2)
        self.c2 = DoubleConv(w3, w2)
        self.u1 = nn.ConvTranspose2d(w2, w1, 2, stride=2)
        self.c1 = DoubleConv(w2, w1)
        # Per-pixel classifier
        self.out = nn.Conv2d(w1, n_classes, 1)

    def forward(self, x):
        skip1 = self.d1(x)
        skip2 = self.d2(self.p1(skip1))
        skip3 = self.d3(self.p2(skip2))
        feat = self.b(self.p3(skip3))
        feat = self.c3(torch.cat([self.u3(feat), skip3], dim=1))
        feat = self.c2(torch.cat([self.u2(feat), skip2], dim=1))
        feat = self.c1(torch.cat([self.u1(feat), skip1], dim=1))
        return self.out(feat)

# Build the network on the selected device; the cell's value is its parameter
# count in millions.
model = UNet2D(in_ch=4, n_classes=4).to(device)
sum(map(torch.numel, model.parameters())) / 1e6

1.928804

In [6]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss.

    Softmax probabilities are compared with the one-hot target; intersection
    and sums are pooled over the batch and both spatial axes, giving one Dice
    score per class (background included). The loss is one minus their mean.
    ``smooth`` is added to numerator and denominator.
    """

    def __init__(self, smooth=1.0, n_classes=4):
        super().__init__()
        self.smooth, self.n_classes = smooth, n_classes

    def forward(self, logits, target):
        # logits: (B,C,H,W), target: (B,H,W)
        probs = torch.softmax(logits, dim=1)
        onehot = torch.nn.functional.one_hot(target, num_classes=self.n_classes)
        onehot = onehot.permute(0, 3, 1, 2).float()
        pooled = (0, 2, 3)
        overlap = (probs * onehot).sum(pooled)
        mass = probs.sum(pooled) + onehot.sum(pooled)
        per_class = (2 * overlap + self.smooth) / (mass + self.smooth)
        return 1 - per_class.mean()

dice_loss = DiceLoss(n_classes=4)
ce_loss = nn.CrossEntropyLoss()


def combo_loss(logits, target, w_dice=0.7, w_ce=0.3):
    """Weighted sum of the soft Dice loss and cross-entropy on the same logits."""
    dice_term = w_dice * dice_loss(logits, target)
    ce_term = w_ce * ce_loss(logits, target)
    return dice_term + ce_term

In [None]:
# U-Net training hyper-parameters.
EPOCHS = 50
lr = 1e-3
# AdamW over every parameter of ``model``.
opt = torch.optim.AdamW(model.parameters(), lr=lr)

def _batch_dice(logits, y, classes=(1, 2, 3)):
    """Mean hard Dice over ``classes`` for one batch (pixels pooled over the batch).

    NOTE: a class absent from both prediction and target scores 1.0, so
    batches of empty slices push this metric up.
    """
    pred = logits.argmax(1)  # same argmax as softmax(logits); softmax is monotonic
    dices = []
    for cls in classes:
        pred_c, true_c = pred == cls, y == cls
        inter = (pred_c & true_c).sum().item()
        denom = pred_c.sum().item() + true_c.sum().item()
        dices.append((2 * inter) / (denom + 1e-6) if denom > 0 else 1.0)
    return float(np.mean(dices))


def run_epoch(dl, train=True):
    """Run one pass over ``dl`` with the notebook-level ``model`` and ``opt``.

    Args:
        dl: Loader yielding dicts with ``"image"`` (B,4,H,W) and ``"mask"`` (B,H,W).
        train: When True the model is put in training mode and the optimizer
            steps on every batch; otherwise the model is in eval mode and no
            gradients are tracked.

    Returns:
        ``(mean_loss, mean_dice)``, both averaged per sample.
    """
    model.train(train)
    total_loss, total_dice, n = 0.0, 0.0, 0
    pbar = tqdm(dl, leave=False)
    for batch in pbar:
        x = batch["image"].to(device)          # (B,4,H,W)
        y = batch["mask"].to(device)           # (B,H,W)

        # Build the autograd graph only when it will be used; the validation
        # pass used to run the forward with gradients enabled.
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = combo_loss(logits, y)

        if train:
            opt.zero_grad()
            loss.backward()
            opt.step()

        mean_dice = _batch_dice(logits.detach(), y)

        bs = x.size(0)
        total_loss += float(loss.item()) * bs
        total_dice += mean_dice * bs
        n += bs
        pbar.set_description(f"{'Train' if train else 'Val'} loss={total_loss/n:.4f} dice={total_dice/n:.4f}")
    return total_loss/n, total_dice/n

# Best validation Dice seen so far and where its weights are stored.
best = {"val_dice": -1, "path": "./checkpoints/unet2d_braTS.pt"}
os.makedirs("./checkpoints", exist_ok=True)

for ep in range(1, EPOCHS + 1):
    tr_loss, tr_dice = run_epoch(train_dl, train=True)
    va_loss, va_dice = run_epoch(val_dl, train=False)
    print(f"[{ep:02d}] train loss={tr_loss:.4f} dice={tr_dice:.4f} | val loss={va_loss:.4f} dice={va_dice:.4f}")
    # Checkpoint only on a strict improvement of the validation Dice.
    if not (va_dice > best["val_dice"]):
        continue
    best["val_dice"] = va_dice
    torch.save({"model": model.state_dict()}, best["path"])
    print("✓ Saved:", best["path"])
print("Best val dice:", best["val_dice"])

                                                                                  

[01] train loss=0.3039 dice=0.4895 | val loss=0.2113 dice=0.6568
✓ Saved: ./checkpoints/unet2d_braTS.pt


                                                                                  

[02] train loss=0.2663 dice=0.5373 | val loss=0.1912 dice=0.6938
✓ Saved: ./checkpoints/unet2d_braTS.pt


                                                                                  

[03] train loss=0.2131 dice=0.6324 | val loss=0.1797 dice=0.7154
✓ Saved: ./checkpoints/unet2d_braTS.pt


Train loss=0.1884 dice=0.6745:  90%|████████▉ | 5940/6606 [02:45<00:18, 35.56it/s]

In [None]:

# Qualitative check: FLAIR channel, ground truth, prediction and error map
# for a few random validation slices.
model.eval()
samples = 3
# squeeze=False keeps ``axes`` two-dimensional even when samples == 1.
fig, axes = plt.subplots(samples, 4, figsize=(10, 3*samples), squeeze=False)
with torch.no_grad():
    idxs = np.random.choice(len(val_ds), size=samples, replace=False)
    for r, i in enumerate(idxs):
        item = val_ds[i]
        x = item["image"][None].to(device)
        y = item["mask"].cpu().numpy()
        logits = model(x)
        pred = torch.softmax(logits, 1).argmax(1)[0].cpu().numpy()

        # Channels are stacked as t1, t1ce, t2, flair, so FLAIR is index 3.
        # (A stray read of channel 1 that was immediately overwritten is gone.)
        flair = item["image"][3].cpu().numpy()
        diff = (pred != y).astype(np.uint8)

        panels = [
            (flair, "gray", "FLAIR"),
            (y, "nipy_spectral", "GT mask"),
            (pred, "nipy_spectral", "Pred mask"),
            (diff, "hot", "Error map"),
        ]
        for ax, (panel, cmap, title) in zip(axes[r], panels):
            ax.imshow(panel, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
plt.tight_layout(); plt.show()

In [None]:

# Restore the best U-Net weights written by the training loop.
ckpt_path = "./checkpoints/unet2d_braTS.pt"
# The checkpoint holds only a state dict, so restrict unpickling to tensors
# explicitly instead of relying on the installed torch version's default.
state = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state["model"])
model.eval()
print("Loaded:", ckpt_path)

In [None]:

from pathlib import Path
import re, os
import numpy as np
import torch
import matplotlib.pyplot as plt


# Root directory (relative path) searched for MU-Glioma-Post "Timepoint_*" folders.
MU_BASE = Path("../data/MU-Glioma-Post")

def _pick_one(files, patterns):
    for p in files:
        name = p.name.lower()
        for pat in patterns:
            if re.search(pat, name):
                return p
    return None

def find_case_mu(base=MU_BASE):
    """Collect MU-Glioma-Post timepoints that have all four modalities.

    Every directory named ``Timepoint_*`` below ``base`` is scanned for
    ``*.nii*`` files, which are assigned to t1 / t1ce / t2 / flair / mask by
    file-name patterns. Timepoints missing any of the four modalities are
    skipped; a missing mask is allowed.

    Returns:
        A list of dicts with string values under ``dir``, ``t1``, ``t1ce``,
        ``t2``, ``flair`` and ``mask`` (``""`` when no mask file was found).
    """
    cases = []

    for tp in sorted(Path(base).rglob("Timepoint_*")):
        if not tp.is_dir():
            continue
        # Sorted so the chosen file is reproducible across filesystems.
        files = sorted(tp.glob("*.nii*"))
        if not files:
            continue

        t1ce = _pick_one(files, [r"t1c", r"t1ce", r"post.*t1", r"t1.*post"])
        # Never hand the post-contrast file to the plain-T1 slot as well.
        non_contrast = [f for f in files if f != t1ce]
        # The generic t1 / t2 tokens must also be allowed to sit right before
        # the extension ("..._t1.nii.gz"); the old trailing class lacked ".".
        t1   = _pick_one(non_contrast, [r"t1n", r"(?<!c)(?<!ce)(^|[_-])t1([_.-]|$)"])
        t2   = _pick_one(files, [r"t2w", r"(^|[_-])t2([_.-]|$)"])
        flair= _pick_one(files, [r"t2f", r"flair"])
        mask = _pick_one(files, [r"mask", r"seg"])

        if all([t1, t1ce, t2, flair]):
            cases.append({
                "dir":   str(tp),
                "t1":    str(t1),
                "t1ce":  str(t1ce),
                "t2":    str(t2),
                "flair": str(flair),
                "mask":  str(mask) if mask else ""
            })
    return cases

# Scan the MU dataset and show where the first few matched timepoints live.
mu_cases = find_case_mu()
print("MU cases found:", len(mu_cases))
_preview = mu_cases[:3]
for i, c in enumerate(_preview):
    print(f"[{i}] {c['dir']}")

In [None]:
# ---- visualize one MU case with your trained model ----
# Takes the middle slice along the last axis of the first matched case, runs
# the U-Net on it and shows FLAIR, prediction and (when a mask file exists)
# the ground truth.
if not mu_cases:
    print("No MU cases matched. Check path/patterns.")
else:
    samp = mu_cases[0]
    im_t1, im_t1ce, im_t2, im_fl = [load_nii(samp[k])[0] for k in ("t1", "t1ce", "t2", "flair")]

    img = np.stack([norm_img(v) for v in (im_t1, im_t1ce, im_t2, im_fl)], axis=0)  # (4,H,W,Z)
    z = img.shape[-1] // 2
    sl = img[:, :, :, z]                                                  # (4,H,W)
    sl_r = resize_img(sl.transpose(1, 2, 0), size=IMG_SIZE, is_mask=False)  # -> (H,W,4)
    x = torch.from_numpy(sl_r.transpose(2, 0, 1))[None].float().to(device)  # -> (1,4,H,W)

    with torch.no_grad():
        scores = model(x)
        pred = torch.softmax(scores, 1).argmax(1)[0].cpu().numpy()

    plt.figure(figsize=(10, 3))

    plt.subplot(1, 3, 1)
    plt.imshow(sl_r[:, :, 3], cmap="gray")
    plt.title("FLAIR")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(pred, cmap="nipy_spectral")
    plt.title("Pred")
    plt.axis("off")

    if samp["mask"] and os.path.exists(samp["mask"]):
        gt, _ = load_nii(samp["mask"])
        gt_r = resize_img(gt[:, :, z], size=IMG_SIZE, is_mask=True)
        plt.subplot(1, 3, 3)
        plt.imshow(gt_r, cmap="nipy_spectral")
        plt.title("GT")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# ========== Phase 2: Distillation dataset creation ==========
import torch, numpy as np, os
from tqdm import tqdm

# Directory that receives the sampled-voxel archive; created if it is missing.
OUT_PATH = "./distillation_data"
os.makedirs(OUT_PATH, exist_ok=True)

def collect_voxels(df, sample_count=200_000):
    """Sample random voxel positions and their mask labels from each case.

    For every row whose mask contains a non-zero label, ``sample_count``
    positions are drawn uniformly (with replacement, from the global
    ``np.random`` state). Each position is stored as index / axis length and
    paired with the mask value at that voxel. Only coordinates and labels are
    stored; image intensities are not part of the archive.

    The result is written to ``OUT_PATH/samples.npz`` with arrays ``xyz``
    (N, 3) and ``label`` (N,).

    Args:
        df: Manifest frame; only the ``mask`` column is read.
        sample_count: Positions drawn per case.
    """
    coords, labels = [], []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        mask, _ = load_nii(row["mask"])
        if mask.max() == 0:
            continue
        # Only the grid shape is needed, so take it from the mask instead of
        # loading and normalising four unused modality volumes.
        sx, sy, sz = mask.shape[:3]
        xs = np.random.randint(0, sx, sample_count)
        ys = np.random.randint(0, sy, sample_count)
        zs = np.random.randint(0, sz, sample_count)
        # Vectorised replacement for the per-sample Python loop.
        coords.append(np.stack([xs / sx, ys / sy, zs / sz], axis=1))
        labels.append(mask[xs, ys, zs].astype(np.int64))

    xyz_all = np.concatenate(coords) if coords else np.empty((0, 3), dtype=np.float64)
    label_all = np.concatenate(labels) if labels else np.empty((0,), dtype=np.int64)
    np.savez_compressed(os.path.join(OUT_PATH, "samples.npz"), xyz=xyz_all, label=label_all)
    print("Saved:", len(xyz_all), "voxels")

# Sample from a fixed, reproducible subset of cases; min() keeps this working
# on manifests with fewer than 20 rows.
collect_voxels(df.sample(n=min(20, len(df)), random_state=42))

In [None]:
# ========== Phase 2: Train implicit MLP ==========
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Read the sampled voxels back and convert them to tensors for training.
data = np.load(os.path.join(OUT_PATH, "samples.npz"))
xyz = torch.tensor(data["xyz"], dtype=torch.float32)
label = torch.tensor(data["label"], dtype=torch.long)

class VoxelSet(Dataset):
    """Pairs of (coordinate, label) tensors.

    Args:
        coords: Tensor of coordinates, indexed along its first axis. Defaults
            to the notebook-level ``xyz`` so ``VoxelSet()`` keeps working.
        labels: Tensor of labels of the same length. Defaults to the
            notebook-level ``label``.

    Raises:
        ValueError: if the two tensors differ in length.
    """

    def __init__(self, coords=None, labels=None):
        self.coords = xyz if coords is None else coords
        self.labels = label if labels is None else labels
        if len(self.coords) != len(self.labels):
            raise ValueError(
                f"coords and labels differ in length: {len(self.coords)} vs {len(self.labels)}"
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        return self.coords[i], self.labels[i]

# Shuffled mini-batches of 1024 (coordinate, label) pairs.
loader = DataLoader(VoxelSet(), batch_size=1024, shuffle=True)

class ImplicitMLP(nn.Module):
    """Small fully connected network: 3 inputs -> 64 -> 64 -> 4 output logits,
    with ReLU after each hidden layer."""

    def __init__(self):
        super().__init__()
        widths = (3, 64, 64, 4)
        layers = []
        last = len(widths) - 2
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if idx != last:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# Implicit model, its optimizer and the classification loss.
mlp = ImplicitMLP().to(device)
# NOTE(review): this rebinds ``opt``, the name the U-Net training cell uses for
# its AdamW optimizer; re-running that cell's functions afterwards without
# re-running the cell itself would step the MLP optimizer instead.
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Ten passes over the sampled voxels; prints the mean batch loss per epoch.
for epoch in range(10):
    total = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        opt.zero_grad()
        out = mlp(x)
        loss = loss_fn(out, y)
        loss.backward()
        opt.step()

        total += loss.item()
    mean_loss = total / len(loader)
    print(f"Epoch {epoch+1:02d} | Loss {mean_loss:.4f}")

# The directory is otherwise only created by the U-Net training cell; make
# sure it exists so this cell also works when that one was skipped.
os.makedirs("./checkpoints", exist_ok=True)
torch.save(mlp.state_dict(), "./checkpoints/implicit_mlp.pth")
print("Saved implicit model.")