# 0. Setup

# 1. Helper functions

In [None]:
def cleanup_gpu():
    """Release unreferenced objects and cached CUDA memory.

    Runs the Python garbage collector first so tensors with no remaining
    references are freed, then asks PyTorch to return cached CUDA blocks and
    clean up CUDA IPC memory. The CUDA calls are skipped when CUDA is not
    available, so the helper is also safe on a CPU-only runtime.
    """
    gc.collect()
    if not torch.cuda.is_available():
        # torch.cuda.ipc_collect() lazily initialises CUDA and raises when no
        # usable GPU / CUDA build is present.
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def minmax(arr, minclip=None, maxclip=None):
    """Rescale *arr* linearly into [0, 1], optionally clipping it first.

    Args:
        arr: NumPy array of intensities.
        minclip: Optional lower clip bound passed to ``np.clip``.
        maxclip: Optional upper clip bound passed to ``np.clip``. Clipping is
            skipped only when both bounds are ``None``.

    Returns:
        ``(arr - min) / (max - min + 1e-8)`` of the (possibly clipped) array.
        The epsilon keeps a constant input from dividing by zero; such input
        maps to all zeros.
    """
    if minclip is not None or maxclip is not None:
        arr = np.clip(arr, minclip, maxclip)
    lo = arr.min()
    span = arr.max() - lo
    return (arr - lo) / (span + 1e-8)

def find_image(root, subj_id, name):
    """Return the path of image *name* for subject *subj_id* under *root*.

    The subject ID is mapped to a sub-directory by turning every ``_`` into a
    path separator, e.g. ``"SynthRAD2023_Task1_pelvis_1PA001"`` becomes
    ``root/SynthRAD2023/Task1/pelvis/1PA001``. A ``{name}.mha`` file is
    preferred; otherwise the first ``{name}.nii*`` match (``.nii``,
    ``.nii.gz``, ...) in sorted order is returned.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    base = os.path.join(root, subj_id.replace("_", "/"))
    # Escape the directory part so glob metacharacters that happen to occur in
    # the path (e.g. "[", "*", "?") are matched literally, not as patterns.
    pattern_base = glob.escape(base)
    # glob() order is arbitrary; sort each group so the choice is deterministic
    # while still preferring .mha over NIfTI.
    matches = sorted(glob.glob(os.path.join(pattern_base, f"{name}.mha")))
    matches += sorted(glob.glob(os.path.join(pattern_base, f"{name}.nii*")))
    if not matches:
        raise FileNotFoundError(f"{name} file not found for {subj_id}")
    return matches[0]

def load_image_pair(root, subj_id, ct_clip=(-450, 450)):
    """Load one subject's MRI and CT volumes and rescale both to [0, 1].

    The files are located with ``find_image`` and read with TorchIO. The
    subject's ``mask`` file is also looked up, so a missing mask raises
    ``FileNotFoundError`` even though the mask itself is not returned.

    Args:
        root: Dataset root directory.
        subj_id: Subject identifier, e.g. ``"SynthRAD2023_Task1_pelvis_1PA001"``.
        ct_clip: ``(min, max)`` window applied to the CT before min-max
            scaling (presumably Hounsfield units — TODO confirm). ``None``
            disables clipping.

    Returns:
        ``(mri, ct)``: NumPy arrays from channel 0 of each image, each
        rescaled by ``minmax``.

    Raises:
        FileNotFoundError: if the MR, CT or mask file is missing.
        ValueError: if the MRI and CT arrays have different shapes; later
            voxel-wise pairing requires them to match exactly.
    """
    mr_path = find_image(root, subj_id, "mr")
    ct_path = find_image(root, subj_id, "ct")
    # Existence check only: the mask is required to be present but unused here.
    find_image(root, subj_id, "mask")

    mri = tio.ScalarImage(mr_path).data[0].numpy()
    ct = tio.ScalarImage(ct_path).data[0].numpy()
    if mri.shape != ct.shape:
        raise ValueError(
            f"MRI shape {mri.shape} does not match CT shape {ct.shape} "
            f"for {subj_id}"
        )

    ct_lo, ct_hi = ct_clip if ct_clip is not None else (None, None)
    mri = minmax(mri)
    ct = minmax(ct, minclip=ct_lo, maxclip=ct_hi)
    print("MRI shape:", mri.shape, "CT shape:", ct.shape)
    return mri, ct


# 2. Load dataset (one subject)

In [None]:
# Dataset root on Google Drive and the single subject processed below.
root = "/content/drive/MyDrive/Colab Notebooks/MRI2CT"
SUBJ_ID = "SynthRAD2023_Task1_pelvis_1PA001"

# Load the subject's min-max normalised MRI and CT volumes.
mri, ct = load_image_pair(root=root, subj_id=SUBJ_ID)

# 3. Load Anatomix model and extract features

In [None]:
# Anatomix 3-D U-Net feature extractor (input_nc=1, output_nc=16).
model = Unet(
    dimension=3,
    input_nc=1,
    output_nc=16,
    num_downs=4,
    ngf=16,
).to(device)

ckpt_path = "/content/anatomix/model-weights/anatomix.pth"
model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
# The network is only used for inference below; put it in eval mode so any
# normalisation/dropout layers use their inference behaviour.
model.eval()
print("✅ Loaded anatomix pretrained model")

@torch.no_grad()
def extract_feats(volume_np, model, device, multiple=16):
    """Run *model* on a single 3-D volume and return its feature map.

    The volume is zero-padded at the far end of each spatial axis up to a
    multiple of *multiple*, passed through the network, and the output is
    cropped back to the original spatial size. The default of 16 presumably
    matches the U-Net's ``num_downs=4`` downsamplings (2**4) — confirm if the
    architecture changes.

    The model is run in eval mode; its previous train/eval mode is restored
    afterwards.

    Args:
        volume_np: 3-D NumPy array ``(H, W, D)``.
        model: Module expected to map ``(1, 1, H', W', D')`` to
            ``(1, C, H', W', D')``.
        device: Device the input tensor is moved to.
        multiple: Each padded spatial size is a multiple of this value.

    Returns:
        NumPy array of shape ``(C, H, W, D)``.
    """
    inp = torch.from_numpy(volume_np[None, None]).float().to(device)
    H, W, D = inp.shape[-3:]
    pad_H, pad_W, pad_D = ((multiple - s % multiple) % multiple for s in (H, W, D))
    # F.pad lists padding pairs from the last dimension backwards: D, W, H.
    inp_padded = F.pad(inp, (0, pad_D, 0, pad_W, 0, pad_H))

    was_training = model.training
    model.eval()
    try:
        feats = model(inp_padded)
    finally:
        model.train(was_training)

    feats = feats[:, :, :H, :W, :D]
    return feats.squeeze(0).cpu().numpy()  # [C,H,W,D]

# Extract anatomix features for both modalities, releasing cached GPU memory
# between and after the two forward passes.
feats_mri = extract_feats(mri, model, device)
cleanup_gpu()
feats_ct = extract_feats(ct, model, device)
cleanup_gpu()
print(f"✅ MRI feats: {feats_mri.shape}, CT feats: {feats_ct.shape}")

# 4. Prepare voxel-wise dataset

In [None]:
# Flatten both feature volumes to one row per voxel: [C,H,W,D] -> [H*W*D, C].
# Row i of X and row i of Y must come from the same voxel, so the feature
# volumes have to share one shape.
if feats_mri.shape != feats_ct.shape:
    raise ValueError(
        f"Feature shapes differ: MRI {feats_mri.shape} vs CT {feats_ct.shape}"
    )
n_feat = feats_mri.shape[0]
X = torch.from_numpy(feats_mri).permute(1, 2, 3, 0).reshape(-1, n_feat)
Y = torch.from_numpy(feats_ct).permute(1, 2, 3, 0).reshape(-1, n_feat)
print(f"Total voxels: {len(X):,}")

# Randomly subsample voxels (same indices for X and Y) to bound memory/time.
max_vox = 500_000
if len(X) > max_vox:
    idx = torch.randperm(len(X))[:max_vox]
    X, Y = X[idx], Y[idx]

dataset = TensorDataset(X, Y)
loader = DataLoader(dataset, batch_size=4096, shuffle=True, num_workers=2)

# 5. Simple translator model (MLP or Conv)

In [None]:
class MLPTranslator(nn.Module):
    """Per-voxel MLP mapping an input feature vector to an output one.

    Layout: Linear(in_dim, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, out_dim), held in ``self.net``. The layer order fixes
    the ``net.<i>.*`` keys of saved ``state_dict`` checkpoints.
    """

    def __init__(self, in_dim=16, hidden=64, out_dim=16):
        super().__init__()
        widths = (in_dim, hidden, hidden, out_dim)
        layers = []
        for d_in, d_out in zip(widths, widths[1:]):
            # A ReLU separates consecutive Linear layers (none before the first).
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of *x* (size ``in_dim``)."""
        out = self.net(x)
        return out

class Conv1x1Translator(nn.Module):
    """Channel-wise linear translator implemented as a 1x1x1 ``Conv3d``.

    Accepts either a volume batch ``(N, C, D, H, W)`` or a batch of per-voxel
    feature vectors ``(N, C)``, such as the voxel-wise batches produced by a
    ``TensorDataset`` of flattened features. A 2-D input is treated as N
    volumes of size 1x1x1, so the same weights (``self.conv``) serve both
    layouts.
    """

    def __init__(self, in_dim=16, out_dim=16):
        super().__init__()
        self.conv = nn.Conv3d(in_dim, out_dim, kernel_size=1)

    def forward(self, x):
        if x.dim() == 2:
            # (N, C) -> (N, C, 1, 1, 1) -> conv -> (N, C_out).
            return self.conv(x[:, :, None, None, None]).flatten(1)
        return self.conv(x)

# Choose the translator: per-voxel MLP (default) or a 1x1x1 Conv3d.
use_conv = False
if use_conv:
    model_t = Conv1x1Translator(16, 16).to(device)
else:
    model_t = MLPTranslator().to(device)

# Mean-squared error between predicted and target features, optimised with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model_t.parameters(), lr=1e-3)


# 6. Training loop

In [None]:
# Train the translator to regress CT features from MRI features, voxel by voxel.
n_epochs = 10
model_t.train()
for epoch in range(n_epochs):
    total_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        pred = model_t(xb)
        loss = criterion(pred, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # MSELoss uses a mean reduction by default; weighting by batch size and
        # dividing by the dataset size below gives an epoch-wide mean even when
        # the final batch is smaller.
        total_loss += loss.item() * xb.size(0)
    avg_loss = total_loss / len(loader.dataset)
    print(f"Epoch {epoch+1:02d}/{n_epochs} - Loss: {avg_loss:.6f}")
print("✅ Training complete!")

# Save the trained weights inside the dataset root defined earlier.
save_path = os.path.join(root, "mri2ct_simple_model.pt")
torch.save(model_t.state_dict(), save_path)
print(f"💾 Saved to {save_path}")


# 7. Evaluate: reconstruct predicted CT features

In [None]:
# Predict CT features for every voxel of the MRI volume.
model_t.eval()
n_feat_in, *spatial_shape = feats_mri.shape  # [C,H,W,D]
# Voxel-major rows: row index enumerates (H, W, D) in C order.
X_full = torch.from_numpy(feats_mri).permute(1, 2, 3, 0).reshape(-1, n_feat_in).to(device)
# Run in chunks so the translator's hidden activations for the whole volume
# are never materialised at once.
infer_chunk = 262_144
with torch.no_grad():
    pred_full = torch.cat([
        model_t(X_full[start:start + infer_chunk]).cpu()
        for start in range(0, len(X_full), infer_chunk)
    ]).numpy()
# pred_full is [H*W*D, C_out] in voxel order: reshape to [H,W,D,C_out], then
# move channels first. Reshaping straight to [C,H,W,D] would interleave
# voxels and channels incorrectly.
pred_feats = pred_full.reshape(*spatial_shape, -1).transpose(3, 0, 1, 2)  # [C,H,W,D]

print("✅ Predicted CT feature volume:", pred_feats.shape)