In [1]:
import os
import torch
import torchio as tio
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from anatomix.anatomix.model.network import Unet

from anatomix import anatomix

# Sanity check: report which on-disk anatomix package the import resolved to
# (presumably the local ./anatomix checkout rather than some other install).
anatomix_init_file = anatomix.__file__
print(anatomix_init_file)

  pkg = __import__(module)  # top level module


/Users/ms/cs/ML/MRI2CT/anatomix/anatomix/__init__.py


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cpu


# 1. Load SynthRAD2025 Dataset

In [3]:
def create_synthrad_dataset(root_dir):
    """Build a TorchIO dataset from a SynthRAD2025 task directory.

    Scans ``root_dir/<anatomy>/<patient>/`` folders (both levels sorted by
    name). Each folder must contain ``mr.mha`` and ``ct.mha``; ``mask.mha`` is
    optional. Non-directory entries are ignored, and patients missing either
    ``mr.mha`` or ``ct.mha`` are skipped.

    Args:
        root_dir: Path to the task root (e.g. ``.../SynthRAD2025/Task1``).

    Returns:
        tio.SubjectsDataset. Each subject has ``id`` (patient folder name),
        ``anatomy`` (anatomy folder name) and ``mr``/``ct`` scalar images. It
        also has a ``mask`` label map when ``mask.mha`` exists.
    """
    subjects = []
    for anatomy in sorted(os.listdir(root_dir)):
        anatomy_path = os.path.join(root_dir, anatomy)
        if not os.path.isdir(anatomy_path):
            continue

        for patient in sorted(os.listdir(anatomy_path)):
            patient_path = os.path.join(anatomy_path, patient)
            if not os.path.isdir(patient_path):
                continue

            mr_path = os.path.join(patient_path, "mr.mha")
            ct_path = os.path.join(patient_path, "ct.mha")
            mask_path = os.path.join(patient_path, "mask.mha")

            if not (os.path.exists(mr_path) and os.path.exists(ct_path)):
                continue

            images = {
                "mr": tio.ScalarImage(mr_path),
                "ct": tio.ScalarImage(ct_path),
            }
            # The mask is optional. TorchIO validates image paths at construction
            # time, so building a LabelMap for a missing file would raise and
            # abort the whole scan. NOTE: subjects may therefore differ in keys;
            # filter or add a dummy mask before default DataLoader collation.
            if os.path.exists(mask_path):
                images["mask"] = tio.LabelMap(mask_path)

            subjects.append(tio.Subject(id=patient, anatomy=anatomy, **images))
    return tio.SubjectsDataset(subjects)

# Dataset root. Set the SYNTHRAD_ROOT environment variable to point elsewhere
# instead of editing the notebook; defaults to the original local path.
root = os.environ.get("SYNTHRAD_ROOT", "/Users/ms/cs/ML/MRI2CT/SynthRAD2025/Task1")
dataset = create_synthrad_dataset(root)
print(f"Loaded {len(dataset)} subjects")

# Pick one subject to visualize
subject = dataset[0]
mri = subject["mr"].data[0].numpy()  # first channel -> 3D volume
ct = subject["ct"].data[0].numpy()
print("MRI shape:", mri.shape)
print("CT shape:", ct.shape)

# Min-max normalize MRI intensities to [0, 1]. The floor on the range prevents
# 0/0 (an all-NaN array) if the volume happens to be constant.
mri_min, mri_max = mri.min(), mri.max()
mri = (mri - mri_min) / max(mri_max - mri_min, 1e-8)

# Select one representative slice (middle index along the last axis)
slice_z = mri.shape[-1] // 2

Loaded 513 subjects
MRI shape: (465, 367, 91)
CT shape: (465, 367, 91)


# 2. Load the pretrained anatomix model

In [None]:
# 3D U-Net whose hyper-parameters must match the checkpoint. strict=True makes
# load_state_dict fail loudly on any missing or unexpected parameter key.
model = Unet(
    dimension=3,
    input_nc=1,
    output_nc=16,  # feature channels
    num_downs=4,
    ngf=16,
).to(device)

ckpt_path = "./anatomix/model-weights/anatomix.pth"
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded.
state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict, strict=True)
model.eval()
print("✅ Loaded anatomix pretrained model")

Encoder skip connect id [8, 15, 22, 29]
Decoder skip connect id [37, 44, 51, 58]
✅ Loaded anatomix pretrained model


# 3. Extract 3D feature maps

In [None]:
# The U-Net is built with num_downs=4. Assuming each level downsamples by 2,
# every spatial dim must be a multiple of 2**4 for the decoder's skip-connection
# shapes to line up. This volume (465 x 367 x 91) is not, so zero-pad at the far
# end of each axis and crop the features back afterwards. After min-max scaling,
# 0 is the volume minimum.
UNET_DIVISOR = 2 ** 4
pad_width = [(0, (-size) % UNET_DIVISOR) for size in mri.shape]
mri_padded = np.pad(mri, pad_width, mode="constant", constant_values=0)

with torch.no_grad():
    inp = torch.from_numpy(mri_padded[None, None]).float().to(device)
    feats = model(inp)  # [1, C, H', W', D'] on the padded grid
# Drop the batch axis and crop away the padding so features align voxel-for-voxel with `mri`.
dim0, dim1, dim2 = mri.shape
feats = feats[0, :, :dim0, :dim1, :dim2].cpu().numpy()  # [C, H, W, D]
print("Feature shape:", feats.shape)

# 4. Select one slice for 2D visualization

In [None]:
# Take index slice_z along the last spatial axis from every channel, then move
# channels last ([C, H, W] -> [H, W, C]) so each pixel is one feature row.
feat_slice = np.moveaxis(feats[..., slice_z], 0, -1)  # [H, W, C]
H, W, C = feat_slice.shape
print(f"Feature slice: {H}x{W}, {C} channels")

# 5. PCA Feature Visualization

In [None]:
# Project each pixel's C-dim feature vector onto its top 3 principal components
# and display them as RGB. random_state pins the result when sklearn's 'auto'
# solver picks randomized SVD, as older releases do for this tall, narrow input.
flat_feats = feat_slice.reshape(-1, C)  # [H*W, C]
pca = PCA(n_components=3, random_state=0)
pca_feats = pca.fit_transform(flat_feats)  # [H*W, 3]
# Per-component min-max scaling into [0, 1] so the components are valid RGB.
pca_feats = (pca_feats - pca_feats.min(0)) / (pca_feats.max(0) - pca_feats.min(0) + 1e-8)
pca_img = pca_feats.reshape(H, W, 3)

fig, (ax_mri, ax_pca) = plt.subplots(1, 2, figsize=(10, 5))
ax_mri.imshow(mri[:, :, slice_z], cmap='gray')
ax_mri.set_title("MRI (axial slice)")
ax_mri.axis('off')

ax_pca.imshow(pca_img)
ax_pca.set_title("Anatomix PCA Feature Map")
ax_pca.axis('off')
fig.tight_layout()
plt.show()

# 6. Per-voxel Feature Similarity Maps (cosine similarity to query points)

In [None]:
import torch.nn.functional as F

# Convert the [H, W, C] slice to a [1, C, H, W] tensor and L2-normalize along
# channels so every pixel's feature vector has unit length.
feat_t = torch.from_numpy(feat_slice).permute(2, 0, 1).unsqueeze(0)  # [1, C, H, W]
feat_t = F.normalize(feat_t, dim=1)

# Query points on a grid_size x grid_size lattice. The outermost lattice lines
# are dropped: with the endpoints included, 12 of the 16 points sit on the slice
# border (mostly background), which gives uninformative similarity maps.
grid_size = 4
ys = np.linspace(0, H - 1, grid_size + 2, dtype=int)[1:-1]
xs = np.linspace(0, W - 1, grid_size + 2, dtype=int)[1:-1]
query_points = [(y, x) for y in ys for x in xs]

# squeeze=False keeps `axes` 2D even when grid_size == 1.
fig, axes = plt.subplots(
    grid_size, grid_size + 1,
    figsize=(4 * (grid_size + 1), 4 * grid_size),
    squeeze=False,
)
mri_slice = mri[:, :, slice_z]

for r, qy in enumerate(ys):
    # Left column: MRI with only this row's query points marked, so it matches
    # the similarity maps to its right.
    ax_img = axes[r, 0]
    ax_img.imshow(mri_slice, cmap='gray')
    ax_img.scatter(xs, [qy] * len(xs), color='red', s=20, edgecolors='black', linewidth=0.5)
    ax_img.axis('off')
    ax_img.set_title("MRI")

    # Remaining columns: cosine-similarity map of every pixel to one query point.
    for c, qx in enumerate(xs, start=1):
        ax = axes[r, c]
        q_vec = feat_t[0, :, qy, qx].unsqueeze(0)  # [1, C]
        sim = F.cosine_similarity(q_vec[:, :, None, None], feat_t, dim=1).squeeze()  # [H, W]
        # Min-max rescale for display; eps avoids 0/0 on a constant map.
        sim = (sim - sim.min()) / (sim.max() - sim.min() + 1e-8)
        ax.imshow(sim.cpu().numpy(), cmap='plasma', interpolation='nearest')
        ax.scatter(qx, qy, color='red', s=30, edgecolors='black', linewidth=0.5)
        ax.axis('off')
        ax.set_title(f"({qx},{qy})")

plt.suptitle("Anatomix Patch Similarity Maps", fontsize=16)
plt.tight_layout()
plt.show()