In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from monai.transforms import LoadImage
from src.models.neurosegnet import NeuroSegNet
from src.xai.attention_rollout import AttentionRollout
import yaml
from pathlib import Path

# Load the YAML config and the best checkpoint, then move the model to the
# available device in eval mode for inference-only use below.
CONFIG_PATH = Path("configs/neurosegnet_config.yaml")
CHECKPOINT_PATH = Path("experiments/neurosegnet_v1/checkpoints/best_model.pth")

# Explicit encoding so the config parses identically on every platform.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

model = NeuroSegNet(config)
# NOTE(review): torch.load unpickles the file unless weights_only=True is in
# effect (the default only in newer torch releases) — load trusted checkpoints only.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
# eval() switches off training-only behaviour (e.g. dropout) before inference.
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [None]:
# Pick one preprocessed validation volume, run a forward pass with attention
# rollout enabled, and plot the resulting token-to-token attention map.
VAL_DIR = Path("data/val_preprocessed")
IMAGE_INDEX = 0  # which volume (in sorted filename order) to visualise

val_paths = sorted(VAL_DIR.glob("*.nii.gz"))
if not val_paths:
    raise FileNotFoundError(f"No *.nii.gz files found in {VAL_DIR}")
image_path = val_paths[IMAGE_INDEX]

loader = LoadImage(image_only=True)
# NOTE(review): assumes the volume is channel-last [H,W,D,C] — a single-channel
# 3-D volume would have its depth axis moved to the front instead; confirm.
img_np = np.asarray(loader(str(image_path)))
img_np = np.moveaxis(img_np, -1, 0)  # [C,H,W,D]
input_tensor = torch.tensor(img_np, dtype=torch.float32).unsqueeze(0).to(device)  # [1,C,H,W,D]

# Run attention rollout; reset() runs even if the forward pass fails, so a
# re-run of this cell does not start from leftover rollout state.
rollout = AttentionRollout(model, head_fusion='mean', discard_ratio=0.0)
try:
    with torch.no_grad():
        _ = model(input_tensor)
    attn_map = rollout.compute_rollout_attention()[0]  # (tokens x tokens)
finally:
    rollout.reset()

# The rollout may be produced on the GPU (the model runs on `device`); NumPy
# conversion requires a CPU tensor that is detached from autograd.
attn_np = attn_map.detach().cpu().numpy()
assert attn_np.ndim == 2, f"expected a 2-D (tokens x tokens) map, got shape {attn_np.shape}"

# Show summary attention rollout heatmap
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(attn_np, cmap="viridis")
ax.set_title("Attention Rollout Map")
ax.set_xlabel("Token index")
ax.set_ylabel("Token index")
fig.colorbar(im, ax=ax)
plt.show()
