In [1]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from totalsegmentator.python_api import totalsegmentator
from totalsegmentator.nifti_ext_header import load_multilabel_nifti

In [2]:
# Project root. Override with the MRI2CT_ROOT environment variable instead of
# editing this cell; the default keeps the original location.
ROOT = os.environ.get("MRI2CT_ROOT", "/home/minsukc/MRI2CT")

# Subject whose resampled volumes are segmented.
# Another subject previously used with this notebook: "1THB211".
SUBJECT = "1ABA103"
DATA_DIR = os.path.join(ROOT, "data", f"{SUBJECT}_3x3x3_resampled")
CT_PATH = os.path.join(DATA_DIR, "ct_resampled.nii.gz")
MR_PATH = os.path.join(DATA_DIR, "mr_resampled.nii.gz")

# Volumes fed to TotalSegmentator.
ct_input_path = CT_PATH
mr_input_path = MR_PATH

# Segmentation outputs are written next to the inputs, inside DATA_DIR.
ct_output_path = os.path.join(DATA_DIR, "ct_seg.nii.gz")
mr_output_path = os.path.join(DATA_DIR, "mr_seg.nii.gz")

# TotalSegmentator task names per modality.
ct_task = "total"
mr_task = "total_mr"

In [None]:
def run_totalseg(input_path, output_path, task, name, device="gpu", fast=False, reuse_existing=False):
    """Run TotalSegmentator on one volume and return the multi-label segmentation.

    Args:
        input_path: Path of the input NIfTI volume.
        output_path: Where the segmentation file is written.
        task: TotalSegmentator task name (this notebook uses "total" for CT
            and "total_mr" for MR).
        name: Modality label used in the progress messages (e.g. "CT").
        device: Passed through to ``totalsegmentator``.
        fast: Passed through to ``totalsegmentator``.
        reuse_existing: If True and ``output_path`` already exists, load that
            file with nibabel instead of re-running the (slow) model. Defaults
            to False, i.e. always re-run.

    Returns:
        The image returned by ``totalsegmentator``, or the image loaded from
        ``output_path`` when an existing result is reused.
    """
    if reuse_existing and os.path.exists(output_path):
        print(f"Reusing existing {name} segmentation:", output_path)
        return nib.load(output_path)

    print(f"Running TotalSegmentator on {input_path} (task={task}) ...")
    seg_img = totalsegmentator(
        input=input_path,
        output=output_path,
        task=task,
        device=device,
        ml=True,  # multi-label output; read back later via load_multilabel_nifti
        fast=fast,
    )
    print(f"✅ {name} Segmentation completed!")
    print("Output saved to:", output_path)
    return seg_img


ct_seg_img = run_totalseg(ct_input_path, ct_output_path, ct_task, "CT")
mr_seg_img = run_totalseg(mr_input_path, mr_output_path, mr_task, "MR")

Running TotalSegmentator on /home/minsukc/MRI2CT/data/1ABA103_3x3x3_resampled/ct_resampled.nii.gz (task=total) ...

If you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024

Resampling...
  Resampled in 4.29s
Predicting part 1 of 5 ...


100%|███████████████████████████████████████████████████████████████████████████████| 18/18 [00:02<00:00,  7.89it/s]


Predicting part 2 of 5 ...


100%|███████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00,  9.30it/s]


Predicting part 3 of 5 ...


100%|███████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00,  9.28it/s]


Predicting part 4 of 5 ...


100%|███████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00,  9.32it/s]


Predicting part 5 of 5 ...


100%|███████████████████████████████████████████████████████████████████████████████| 18/18 [00:01<00:00,  9.44it/s]


In [None]:
def summarize_seg(name, seg_img):
    """Print the shape and the non-zero labels of a segmentation image.

    Args:
        name: Modality label used as a prefix in the printed lines (e.g. "CT").
        seg_img: Segmentation image exposing ``get_fdata()`` (here, the image
            returned by ``totalsegmentator``).

    Returns:
        The label volume cast to ``np.uint16``.
    """
    seg_arr = seg_img.get_fdata().astype(np.uint16)
    labels = np.unique(seg_arr)
    print(f"{name} segmentation shape:", seg_arr.shape)
    # Filter label 0 (background) explicitly: slicing off the first unique value
    # would instead drop a real label whenever a volume has no 0 voxels.
    print(f"{name} unique labels:", labels[labels != 0])
    return seg_arr


ct_arr = summarize_seg("CT", ct_seg_img)
mr_arr = summarize_seg("MR", mr_seg_img)

In [None]:
# Show the first few entries of the label mapping that load_multilabel_nifti
# reads back from each segmentation file.
N_LABELS_TO_SHOW = 10

label_dicts = {}  # modality -> label mapping, kept for later cells
for modality, seg_path in (("CT", ct_output_path), ("MR", mr_output_path)):
    # Names kept from the original cell; after the loop they hold the MR result.
    seg_nifti, label_dict = load_multilabel_nifti(seg_path)
    label_dicts[modality] = label_dict
    print(f"{modality} example label mapping:")
    for label_id, label_name in list(label_dict.items())[:N_LABELS_TO_SHOW]:
        print(f"  {label_id:>3} → {label_name}")

In [None]:
%matplotlib inline
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

ct_data = nib.load(CT_PATH).get_fdata()
ct_seg = nib.load(ct_output_path).get_fdata()

mr_data = nib.load(MR_PATH).get_fdata()
mr_seg = nib.load(mr_output_path).get_fdata()

# ----------------------------
# Normalize CT and MRI for display
# ----------------------------
ct_min, ct_max = -450, 450
ct_disp = np.clip(ct_data, ct_min, ct_max)
ct_disp = (ct_disp - ct_disp.min()) / (ct_disp.max() - ct_disp.min())

mr_disp = (mr_data - mr_data.min()) / (mr_data.max() - mr_data.min())

# ----------------------------
# Choose a slice
# ----------------------------
# sliceidx = ct_data.shape[2] // 2  # middle slice
sliceidx = 30

plt.figure(figsize=(18, 6))

# === CT ROW ===
plt.subplot(2, 3, 1)
plt.imshow(np.rot90(ct_disp[:, :, sliceidx]), cmap="gray")
plt.title("CT")
plt.axis("off")

plt.subplot(2, 3, 2)
plt.imshow(np.rot90(ct_disp[:, :, sliceidx]), cmap="gray")
plt.imshow(np.rot90(ct_seg[:, :, sliceidx]), cmap="tab20", alpha=0.25)
plt.title("CT + Segmentation Overlay")
plt.axis("off")

plt.subplot(2, 3, 3)
plt.imshow(np.rot90(ct_seg[:, :, sliceidx]), cmap="tab20")
plt.title("CT Segmentation")
plt.axis("off")

# === MRI ROW ===
plt.subplot(2, 3, 4)
plt.imshow(np.rot90(mr_disp[:, :, sliceidx]), cmap="gray")
plt.title("MRI")
plt.axis("off")

plt.subplot(2, 3, 5)
plt.imshow(np.rot90(mr_disp[:, :, sliceidx]), cmap="gray")
plt.imshow(np.rot90(mr_seg[:, :, sliceidx]), cmap="tab20", alpha=0.25)
plt.title("MRI + Segmentation Overlay")
plt.axis("off")

plt.subplot(2, 3, 6)
plt.imshow(np.rot90(mr_seg[:, :, sliceidx]), cmap="tab20")
plt.title("MRI Segmentation")
plt.axis("off")

plt.tight_layout()
plt.show()