In [1]:
import os
from pathlib import Path
import json
import SimpleITK as sitk


In [2]:
# Configure the nnU-Net environment variables in one step.
# All three locations are given relative to the current working directory.
os.environ.update(
    {
        "nnUNet_raw_data_base": "./nnunet_data/nnUNet_raw_data_base",
        "nnUNet_preprocessed": "./nnunet_data/nnUNet_preprocessed",
        "RESULTS_FOLDER": "./nnunet_data/nnUNet_trained_models",
    }
)


In [3]:
# nnU-Net task folder name; the dataset-setup cell uses it as the task
# directory under <nnUNet_raw_data_base>/nnUNet_raw_data.
TaskNumber = "Task773_Liver"

In [4]:
def flare21_to_liver(input_file, output_file):
    """Write a liver-only mask extracted from a FLARE21 label image.

    The input image is read, compared voxel-wise against label value 1
    (treated as liver), and the comparison result is written compressed.

    Args:
        input_file: Location of the FLARE21 label image to read (string or
            path-like object).
        output_file: Location the liver mask is written to (string or
            path-like object).
    """
    # Hand SimpleITK plain strings: path-like objects (e.g. pathlib.Path) are
    # converted so the call does not depend on the installed SimpleITK release
    # accepting them. Any other argument type is passed through untouched.
    if isinstance(input_file, os.PathLike):
        input_file = os.fspath(input_file)
    if isinstance(output_file, os.PathLike):
        output_file = os.fspath(output_file)

    label_img = sitk.ReadImage(input_file)
    liver_mask = label_img == 1
    sitk.WriteImage(liver_mask, output_file, useCompression=True, compressionLevel=9)

Set up the dataset in the format nnU-Net expects

In [5]:
# Collect (source_name, ct_path, segmentation_path) triples from both sources.
dataset = []

# Get total segmentator dataset
# Expected layout: <case>/ct.nii.gz with the mask at <case>/segmentations/liver.nii.gz
ts_dir = Path("training_data/total_segmentator")
for ct in ts_dir.rglob("ct.nii.gz"):
    seg = ct.parent / "segmentations" / "liver.nii.gz"
    # Explicit error instead of `assert`: asserts are stripped under `python -O`
    # and would not say which file is missing.
    if not seg.exists():
        raise FileNotFoundError(f"Missing liver segmentation for {ct}: {seg}")
    dataset.append(("ts", ct, seg))

# Get flare21 dataset
# The mask is looked up in the sibling TrainingMask directory, named after the
# first 9 characters of the scan file name.
fl_dir = Path("training_data/flare21/TrainingImg")
for ct in fl_dir.rglob("train*.nii.gz"):
    seg = ct.parent.parent / "TrainingMask" / f"{ct.name[:9]}.nii.gz"
    if not seg.exists():
        raise FileNotFoundError(f"Missing training mask for {ct}: {seg}")
    dataset.append(("flare", ct, seg))

# Nothing found usually means the notebook runs from the wrong working
# directory; stop before dataset.json is rewritten with zero training cases.
if not dataset:
    raise RuntimeError(f"No training cases found under {ts_dir} or {fl_dir}")

# Load the dataset.json template before touching the nnunet directories, so a
# missing or invalid template cannot leave a half-populated task folder behind.
with open("dataset.json") as f:
    training_json = json.load(f)

# setup links in nnunet format
base_dir = Path(os.environ["nnUNet_raw_data_base"]) / "nnUNet_raw_data" / TaskNumber
scans_dir = base_dir / "imagesTr"
labels_dir = base_dir / "labelsTr"
scans_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)

# Refuse to mix new links with the leftovers of a previous run.
for nnunet_dir in (scans_dir, labels_dir):
    if any(nnunet_dir.iterdir()):
        raise RuntimeError("nnunet data dir is not empty!")

training_set = []
# Sorting makes the case numbering deterministic across runs.
for i, (ds_name, ct, seg) in enumerate(sorted(dataset)):
    case_id = f"train_{i:04d}"
    ct_link = scans_dir / f"{case_id}_0000.nii.gz"
    seg_link = labels_dir / f"{case_id}.nii.gz"
    ct_link.symlink_to(ct.absolute())
    if ds_name == "flare":
        # flare21 masks are multi-label: write a new liver-only file.
        flare21_to_liver(seg, seg_link)
    else:
        # total segmentator already provides a dedicated liver mask: link it.
        seg_link.symlink_to(seg.absolute())
    training_set.append(
        {
            # The "image" entry intentionally omits the "_0000" suffix carried
            # by the link on disk (nnU-Net v1 dataset.json convention).
            "image": str((scans_dir / f"{case_id}.nii.gz").relative_to(base_dir)),
            "label": str(seg_link.relative_to(base_dir)),
        }
    )

# Save dataset.json with the generated training list
training_json["training"] = training_set
training_json["numTraining"] = len(training_set)

with open("dataset.json", "w") as f:
    json.dump(training_json, f, indent=2)

# Expose dataset.json inside the task folder. A symlink left by an earlier run
# is replaced; a regular file at that location still raises FileExistsError.
dataset_link = base_dir / "dataset.json"
if dataset_link.is_symlink():
    dataset_link.unlink()
dataset_link.symlink_to(Path("dataset.json").absolute())