# Baseline UNet-3D – fMRI denoising visualisation

This notebook crops a 32³ spatial patch from a chosen fMRI run,
runs the trained **UNet-3D** checkpoint on it, and plots *noisy vs. denoised* slices.

> **Tip** – set `nii_path` to any 4-D (X, Y, Z, T) NIfTI run you’ve uploaded to Colab or your workspace, and `ckpt` to the trained checkpoint file.

In [None]:
# Install dependencies (Colab) – skip this cell if they are already installed in your local env.
# NOTE(review): zarr, torchmetrics and pillow are not used directly in this notebook;
# presumably they are required by the fmri_project package imported below — confirm.
!pip -q install nibabel zarr torchmetrics matplotlib pillow


In [None]:
# clone the repo & add to PYTHONPATH (Colab)
!git clone --depth 1 https://github.com/nishxnt/fmri_project.git || true
import sys, pathlib
sys.path.append('/content/fmri_project')  # adjust if path differs

In [None]:
import torch, numpy as np, matplotlib.pyplot as plt, nibabel as nib
from fmri_project.models.unet3d import UNet3D
from fmri_project.kim_dataset.mask import compute_mask
from fmri_project.kim_dataset.normalize import zscore

### ---- user paths -------------------------------------------------
nii_path = "/content/sub-01_ses-1_task-motor_run-1_bold.nii.gz"  # change!
ckpt  = "/content/unet_baseline_best.pt"                          # change if elsewhere
#####################################################################

# Network geometry: the first N_FRAMES time frames are fed as input channels,
# and the spatial input is a cubic PATCH³ crop.
N_FRAMES = 16   # must match the UNet3D in_ch/out_ch used for training
PATCH = 32

img    = nib.load(nii_path)
vol_np = img.get_fdata(dtype="float32")          # (X,Y,Z,T)
vol_np = np.moveaxis(vol_np, -1, 0)               # (T,X,Y,Z)
mask_np = compute_mask(img)
vol_np  = zscore(vol_np, mask_np)                 # assumes zscore preserves the (T,X,Y,Z) shape

n_t, *spatial = vol_np.shape
if n_t < N_FRAMES:
    raise ValueError(f"need at least {N_FRAMES} time frames, got {n_t}")
if any(s < PATCH for s in spatial):
    raise ValueError(f"every spatial axis must be >= {PATCH} voxels, got {tuple(spatial)}")

# PATCH³ crop centred on the volume. A fixed 40:72 window only fits axes of
# >= 72 voxels (it silently returns empty/undersized crops otherwise) and is
# centred only for 112-voxel axes.
xs, ys, zs = (slice(s // 2 - PATCH // 2, s // 2 - PATCH // 2 + PATCH) for s in spatial)
crop_np = vol_np[:N_FRAMES, xs, ys, zs]           # (16,32,32,32)

# (N,C,D,H,W) for UNet: time frames act as channels -> (1,16,32,32,32)
patch = torch.from_numpy(crop_np).unsqueeze(0).float()

net = UNet3D(in_ch=N_FRAMES, out_ch=N_FRAMES, features=16).cpu()
net.load_state_dict(torch.load(ckpt, map_location="cpu"))
net.eval()
with torch.no_grad():
    # squeeze(0) drops only the batch dim -> (16,32,32,32),
    # assuming UNet3D preserves the spatial size (TODO confirm).
    den = net(patch).cpu().squeeze(0)

# Choose frame t=0 and slice z=16 (crop coordinates, third spatial axis).
t, z = 0, 16
noisy = crop_np[t, :, :, z]
den2d = den[t, :, :, z].numpy()
# The same slice read from the full volume: the crop-local z must be offset by
# zs.start. NOTE(review): no separate clean reference is loaded, so this panel
# equals "Noisy"; point it at a ground-truth volume if one exists.
clean = vol_np[t, xs, ys, zs.start + z]

fig, axs = plt.subplots(1, 3, figsize=(9, 3))
# `panel` (not `img`) so the loop does not overwrite the nibabel image above.
for ax, panel, title in zip(axs, [noisy, den2d, clean], ["Noisy", "Denoised", "Clean"]):
    ax.imshow(panel, cmap="gray", vmin=-2, vmax=2)   # ±2 window, assuming z-scored intensities
    ax.set_title(title); ax.axis("off")
plt.tight_layout(); plt.show()
