In [None]:
import torch
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from model import UNet3D
from preprocess import load_decathlon_image, load_decathlon_mask


In [None]:
def visualize_slice(image, gt_mask, pred_mask, slice_idx=None):
    """Show one slice of a volume next to its ground-truth and predicted masks.

    Draws three side-by-side panels (not an overlay), all taken at the same
    index along axis 2: the image in grayscale, the ground truth in red and
    the prediction in blue.

    Args:
        image: volume indexed as ``image[:, :, k]``.
        gt_mask: ground-truth mask indexed the same way as ``image``.
        pred_mask: predicted mask indexed the same way as ``image``.
        slice_idx: index along axis 2. Defaults to the middle slice
            (``image.shape[2] // 2``). Negative values count from the end,
            as in NumPy.

    Raises:
        IndexError: if ``slice_idx`` lies outside ``[-depth, depth)``, where
            ``depth = image.shape[2]``. Checked before any figure is created.
    """
    depth = image.shape[2]
    if slice_idx is None:
        slice_idx = depth // 2
    elif not -depth <= slice_idx < depth:
        raise IndexError(
            f"slice_idx={slice_idx} out of range for {depth} slices along axis 2"
        )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(image[:, :, slice_idx], cmap='gray')
    axs[0].set_title("MRI Slice")
    # Pin the colour range for masks: with auto-scaling, a uniform slice
    # (all background OR all foreground) is drawn in the lowest colour, so a
    # fully-foreground slice would look empty.
    axs[1].imshow(gt_mask[:, :, slice_idx], cmap='Reds', vmin=0, vmax=1)
    axs[1].set_title("Ground Truth")
    axs[2].imshow(pred_mask[:, :, slice_idx], cmap='Blues', vmin=0, vmax=1)
    axs[2].set_title("Prediction")
    for ax in axs:
        ax.axis('off')  # pixel-index ticks carry no information here
    fig.suptitle(f"Slice {slice_idx} along axis 2")
    fig.tight_layout()
    plt.show()


In [None]:
# Paths: trained checkpoint, plus the image and ground-truth mask of one
# validation case (both files share the same name in sibling folders).
model_path = "runs/experiment_1/best_model.pth"
val_dir = "data/processed/val"
case_file = "BraTS_001.nii.gz"
image_path = f"{val_dir}/images/{case_file}"
mask_path = f"{val_dir}/masks/{case_file}"


In [None]:
# Load data: one validation case.
# Later cells use image.shape[0] as the model's channel count (channel-first
# (C, H, W, D) layout) and plot mask slices next to image[0] slices, so check
# those layout assumptions here rather than getting a garbled figure later.
image = load_decathlon_image(image_path)
mask = load_decathlon_mask(mask_path)
assert image.ndim == 4, f"expected image of shape (C, H, W, D), got {tuple(image.shape)}"
assert tuple(mask.shape) == tuple(image.shape[1:]), (
    f"mask shape {tuple(mask.shape)} does not match image spatial shape "
    f"{tuple(image.shape[1:])}"
)


In [None]:
# Prepare model: build the network (on GPU when available) and load the
# trained weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Input channels come from the channel-first image; a single output channel
# holds the logits that the inference cell passes through a sigmoid.
model = UNet3D(in_channels=image.shape[0], out_channels=1).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code (requires torch >= 1.13).
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval();  # trailing ';' suppresses the (very long) module repr in the cell output


In [None]:
# Inference: run the model on the whole volume and binarise the foreground
# probability.
PRED_THRESHOLD = 0.5  # probability cut-off for a voxel to count as foreground
with torch.no_grad():
    input_tensor = torch.tensor(image[None], dtype=torch.float32, device=device)  # shape: (1, C, H, W, D)
    output = model(input_tensor)
    # Take batch 0 / channel 0 explicitly: .squeeze() would also drop any
    # spatial axis of size 1 and break the (H, W, D) indexing used for plotting.
    pred = torch.sigmoid(output)[0, 0].cpu().numpy()
    pred_mask = (pred > PRED_THRESHOLD).astype(np.uint8)


In [None]:
# Visualize: the middle slice may contain no ground-truth foreground, which
# makes the comparison uninformative. Show the slice (along axis 2) with the
# largest ground-truth area instead. If the mask is empty, fall back to
# visualize_slice's default middle slice.
gt_area_per_slice = (np.asarray(mask) > 0).sum(axis=(0, 1))
best_slice = int(np.argmax(gt_area_per_slice)) if gt_area_per_slice.any() else None
visualize_slice(image[0], mask, pred_mask, slice_idx=best_slice)
