# CADS Differentiable Inference Analysis (Configurable)

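This notebook runs the CADS model for the configured task (default 559, body regions) on a fixed-size, centred patch of a single CT volume and compares the result with the stored CADS segmentation of that volume.
It exercises the same **differentiable inference** path (`predict_differentiable`) used in training, reports per-region hard and soft Dice, and provides a synchronized slice viewer of the CT, the reference labels, and the predictions.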

In [8]:
%matplotlib inline
import os
import sys
import torch
import torch.nn.functional as F
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
import gc
import importlib

# Add CADS to path
sys.path.append('CADS')
import cads.utils.inference
importlib.reload(cads.utils.inference)

from cads.utils.libs import setup_nnunet_env, get_model_weights_dir
from cads.utils.inference import nnUNetv2Predictor
from cads.dataset_utils.bodyparts_labelmaps import labelmap_part_bodyregions

# --- CONFIGURATION ---
PATCH_SIZE = 128  # Side length of the cubic network input, e.g. 64, 96, 160, 192.
TASK_ID = 559     # Body Regions (Saros)
SUBJ_ID = "1ABB169"
# Paths default to the original cluster locations; export CADS_DATA_ROOT and/or
# CADS_WEIGHTS_PATH to run elsewhere without editing this cell.
DATA_ROOT = os.environ.get(
    "CADS_DATA_ROOT",
    "/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/SynthRAD_combined/3.0x3.0x3.0mm",
)
OS_WEIGHTS_PATH = os.environ.get(
    "CADS_WEIGHTS_PATH",
    "/gpfs/accounts/jjparkcv_root/jjparkcv98/minsukc/MRI2CT/cads_weights",
)
# ---------------------

# Export the weights location before setup_nnunet_env() is called.
os.environ["CADS_WEIGHTS_PATH"] = OS_WEIGHTS_PATH
setup_nnunet_env()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {DEVICE}")
print(f"Active Patch Size: {PATCH_SIZE}³")

Device: cuda
Active Patch Size: 128³


In [9]:
# Build the nnU-Net v2 predictor for the configured task from the CADS weights folder.
print("Initializing Predictor...")
weights_dir = get_model_weights_dir()
predictor_obj = nnUNetv2Predictor(
    model_folder=weights_dir,
    task_id=TASK_ID,
    device=DEVICE,
)

Initializing Predictor...
Inference using model nnUNetTrainerNoMirroring__nnUNetResEncUNetLPlans__3d_fullres


In [10]:
def get_centered_padded_patch(data, target_size=128, fill_value=None):
    """Center-crop and/or center-pad an array to a hypercube of side ``target_size``.

    Along every axis, the central ``min(dim, target_size)`` elements of ``data``
    are copied into the centre of the output. Axes longer than ``target_size``
    are therefore cropped, and shorter axes are padded. Everything outside the
    copied region is set to ``fill_value``.

    Args:
        data: Input array of any rank >= 1 (3-D volumes in this notebook).
        target_size: Output side length along every axis.
        fill_value: Value for the padded region. ``None`` (default) uses
            ``np.min(data)``.

    Returns:
        A new array of shape ``(target_size,) * data.ndim`` with ``data.dtype``.

    Raises:
        ValueError: if ``fill_value`` is None and ``data`` is empty, because no
            minimum can be taken.
    """
    data = np.asarray(data)
    if fill_value is None:
        if data.size == 0:
            raise ValueError("cannot infer fill_value from an empty array")
        fill_value = np.min(data)

    out = np.full((target_size,) * data.ndim, fill_value, dtype=data.dtype)

    src_slices, dst_slices = [], []
    for dim in data.shape:
        n = min(dim, target_size)
        # Use the same centring rule on both sides so the crop and the pad line up.
        src_start = dim // 2 - n // 2
        dst_start = target_size // 2 - n // 2
        src_slices.append(slice(src_start, src_start + n))
        dst_slices.append(slice(dst_start, dst_start + n))

    out[tuple(dst_slices)] = data[tuple(src_slices)]
    return out

ct_p = os.path.join(DATA_ROOT, "train", SUBJ_ID, "ct.nii.gz")
seg_p = os.path.join(DATA_ROOT, "train", SUBJ_ID, "cads_ct_seg.nii.gz")

print(f"Loading data for {SUBJ_ID}...")
ct_full = nib.load(ct_p).get_fdata()
gt_full = nib.load(seg_p).get_fdata()

# Both volumes are cropped/padded below with a window centred on their own shape.
# That only keeps them voxel-aligned if the shapes are identical, so fail loudly
# instead of silently comparing predictions against misaligned labels.
if ct_full.shape != gt_full.shape:
    raise ValueError(
        f"CT shape {ct_full.shape} != segmentation shape {gt_full.shape} for {SUBJ_ID}"
    )

print(f"Full Volume Shape (GT): {gt_full.shape}")
ct_3 = get_centered_padded_patch(ct_full, PATCH_SIZE)
gt_3 = get_centered_padded_patch(gt_full, PATCH_SIZE)

# Add batch and channel dims: (1, 1, P, P, P).
t_ct = torch.from_numpy(ct_3).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
t_gt = torch.from_numpy(gt_3).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
print(f"Patch Shape (Input):   {t_ct.shape}")

if any(s > PATCH_SIZE for s in gt_full.shape):
    print("\n--- Context Alert ---")
    print(f"NOTE: Your volume is larger than {PATCH_SIZE} in some dimensions. This is a CROP.")
    print("Predictions at the edges may differ from GT due to InstanceNorm and missing spatial context.")

Loading data for 1ABB169...
Full Volume Shape (GT): (133, 130, 98)
Patch Shape (Input):   torch.Size([1, 1, 128, 128, 128])

--- Context Alert ---
NOTE: Your volume is larger than 128 in some dimensions. This is a CROP.
Predictions at the edges may differ from GT due to InstanceNorm and missing spatial context.


In [11]:
# Voxel spacing (mm) passed to the predictor. It is checked against the CT header
# so that pointing DATA_ROOT at differently resampled data fails loudly rather
# than silently running the network with the wrong spacing.
# NOTE(review): nibabel reports zooms in voxel-axis order; for anisotropic data,
# confirm this matches the order predict_differentiable expects.
SPACING_MM = [3.0, 3.0, 3.0]
header_zooms = nib.load(ct_p).header.get_zooms()[:3]
if not np.allclose(header_zooms, SPACING_MM):
    raise ValueError(f"CT header spacing {header_zooms} does not match SPACING_MM {SPACING_MM}")

props = [{'spacing': SPACING_MM}]
with torch.no_grad():
    probs = predictor_obj.predict_differentiable(t_ct, props)
    # Per-voxel class index from the class (dim=1) probabilities.
    hard_labels = torch.argmax(probs, dim=1)
print(f"Inference complete. Output probs shape: {probs.shape}")

Inference complete. Output probs shape: torch.Size([1, 11, 128, 128, 128])


In [12]:
def dice(p, t, eps=1e-5):
    """Smoothed Dice coefficient between two same-shaped tensors.

    Computes ``(2 * sum(p * t) + eps) / (sum(p) + sum(t) + eps)``. With binary
    masks this is the hard Dice; with a probability map for ``p`` it is the soft
    Dice. When ``eps > 0`` and both inputs are all-zero, the result is 1.

    Args:
        p: Prediction (binary mask or probabilities).
        t: Target mask, same shape as ``p``.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        0-d tensor holding the Dice score.
    """
    intersection = torch.sum(p * t)
    return (2.0 * intersection + eps) / (torch.sum(p) + torch.sum(t) + eps)

print(f"--- Regional Analysis (Task {TASK_ID}) ---")
print(f"{'Region':<25} | {'Hard Dice':<10} | {'Soft Dice':<10}")
print("-" * 55)

for idx, name in labelmap_part_bodyregions.items():
    if idx == 0:  # skip label 0
        continue
    t_curr = (t_gt[0, 0] == idx).float()
    h_curr = (hard_labels[0] == idx).float()
    if torch.sum(t_curr) > 0:
        h_dice = dice(h_curr, t_curr)
        s_dice = dice(probs[0, idx], t_curr)
        print(f"{name:<25} | {h_dice.item():.4f}     | {s_dice.item():.4f}")
    else:
        # The region is absent from the reference patch, so Dice is not meaningful.
        # Any voxels predicted as this region are false positives; show their count
        # rather than hiding them.
        n_fp = int(h_curr.sum().item())
        suffix = f" | {n_fp} predicted voxels (false positives)" if n_fp else ""
        print(f"{name:<25} | (Not in patch){suffix}")

--- Regional Analysis (Task 559) ---
Region                    | Hard Dice  | Soft Dice 
-------------------------------------------------------
subcutaneous_tissue       | 0.8392     | 0.8106
muscle                    | 0.6591     | 0.6255
abdominal_cavity          | 0.7234     | 0.6885
thoracic_cavity           | 0.6658     | 0.6465
bones                     | 0.8533     | 0.8232
glands                    | (Not in patch)
pericardium               | 0.0071     | 0.0185
breast_implant            | (Not in patch)
mediastinum               | 0.0011     | 0.0034
spinal_cord               | 0.0000     | 0.0000


In [13]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Host-side NumPy copies of the patch tensors for matplotlib.
img_v = np.squeeze(t_ct.cpu().numpy())   # patch CT, singleton dims dropped
gt_all = t_gt.cpu().numpy()[0, 0]        # reference label patch
pred_all = hard_labels.cpu().numpy()[0]  # argmax prediction patch

# Display window (vmin, vmax) shared by every grayscale CT panel.
CT_WINDOW = (-200, 400)


def _show_ct(ax, ct_slice, alpha=None):
    """Draw a 2-D CT slice on ``ax`` using the shared window (transposed, origin lower)."""
    ax.imshow(ct_slice.T, cmap='gray', vmin=CT_WINDOW[0], vmax=CT_WINDOW[1],
              origin='lower', alpha=alpha)


def viewer(sl, alpha, region_idx):
    """Render one slice of the patch next to the matching full-volume slice.

    The panels, left to right, are:
      1. full-volume CT at the corresponding slice;
      2. patch CT;
      3. patch reference (GT) mask for ``region_idx``;
      4. patch hard-prediction mask for ``region_idx``;
      5. patch probability map for ``region_idx``.

    Args:
        sl: Slice index along the third axis of the patch (0..PATCH_SIZE-1).
        alpha: Opacity of the probability overlay in the last panel.
        region_idx: Label index to display.
    """
    fig, axes = plt.subplots(1, 5, figsize=(25, 5))

    # Map the patch slice back to the full volume using the same centring rule as
    # get_centered_padded_patch along the third axis.
    depth = ct_full.shape[2]
    d = min(depth, PATCH_SIZE)
    src_sl = (depth // 2 - d // 2) + sl - (PATCH_SIZE // 2 - d // 2)
    global_sl = int(np.clip(src_sl, 0, depth - 1))
    # If the volume is shallower than the patch, slices in the padded margin have no
    # counterpart in the volume. Flag them instead of silently showing the clamped
    # edge slice.
    pad_note = "" if 0 <= src_sl < depth else ", padding"

    patch_ct = img_v[:, :, sl]

    _show_ct(axes[0], ct_full[:, :, global_sl])
    axes[0].set_title(f"Global Context (sl {global_sl}{pad_note})")

    _show_ct(axes[1], patch_ct)
    axes[1].set_title(f"Patch CT ({PATCH_SIZE}³)")

    _show_ct(axes[2], patch_ct, alpha=0.3)
    gt_slice = gt_all[:, :, sl]
    axes[2].imshow(np.ma.masked_where(gt_slice != region_idx, gt_slice).T,
                   cmap='autumn', origin='lower')
    axes[2].set_title(f"Patch GT Region {region_idx}")

    _show_ct(axes[3], patch_ct, alpha=0.3)
    pred_slice = pred_all[:, :, sl]
    axes[3].imshow(np.ma.masked_where(pred_slice != region_idx, pred_slice).T,
                   cmap='winter', origin='lower')
    axes[3].set_title(f"Patch Pred (hard) Region {region_idx}")

    _show_ct(axes[4], patch_ct, alpha=0.3)
    p_slice = probs[0, region_idx, :, :, sl].cpu().numpy()
    im = axes[4].imshow(p_slice.T, cmap='jet', vmin=0, vmax=1, origin='lower', alpha=alpha)
    axes[4].set_title(f"Patch Prob Region {region_idx}")

    # Give the colorbar its own axis so the probability panel keeps its size.
    cax = make_axes_locatable(axes[4]).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)

    for ax in axes:
        ax.axis('off')
    plt.show()
    plt.close(fig)

# Dropdown of (label name, label index) pairs. Default to label 5 when the labelmap
# has it; otherwise use the first option, so widget construction cannot fail.
dropdown_options = [(name, idx) for idx, name in labelmap_part_bodyregions.items()]
default_region = 5 if 5 in labelmap_part_bodyregions else dropdown_options[0][1]
region_select = widgets.Dropdown(options=dropdown_options, value=default_region, description='Region:')

# The trailing ';' suppresses the echoed function repr returned by interact().
widgets.interact(viewer, sl=(0, PATCH_SIZE - 1), alpha=(0.0, 1.0, 0.05), region_idx=region_select);

interactive(children=(IntSlider(value=63, description='sl', max=127), FloatSlider(value=0.5, description='alph…

<function __main__.viewer(sl, alpha, region_idx)>