In [None]:
import json
from pathlib import Path

import numpy as np
import SimpleITK as sitk
from PIL import Image
from tqdm.auto import tqdm

from ..inference import setup_sam2
from ..model import run_algorithm

DATASET = Path(
    "/rodata/mnradonc_dev/m299164/trackrad/datasets/duke-liver-v2/Segmentation"
)
CHECKPOINT = Path("./resources/sam2.1_hiera_small_trackrad_07_21.pt")
OUTPUT = Path(
    "/rodata/mnradonc_dev/m299164/trackrad/datasets/nnUNet/nnUNet_raw/Dataset400_DukeLiver"
)
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0  \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00  \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80  @\xa0 @ \xa0@\xa0\xa0@  \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"

predictor = setup_sam2(CHECKPOINT)

count = 0
for subject in tqdm(list(DATASET.iterdir())):
    for case in subject.iterdir():
        count += 1

        images = sorted((case / "images").glob("*.dicom"))
        masks = sorted((case / "masks").glob("*.dicom"))

        frames = np.stack(
            [
                sitk.GetArrayFromImage(sitk.ReadImage(str(img))).squeeze()
                for img in images
            ],
            axis=-1,
        )
        labels = np.stack(
            [
                sitk.GetArrayFromImage(sitk.ReadImage(str(mask))).squeeze()
                for mask in masks
            ],
            axis=-1,
        )

        # Find the first frame in labels that is not all zeroes
        first_nonzero_idx = next(
            (i for i in range(labels.shape[-1]) if np.any(labels[:, :, i] != 0)), None
        )
        if first_nonzero_idx is None:
            print("All label frames are zero.")
            continue

        # Skip some frames bc the liver doesn't appear in the first slice
        frames = frames[:, :, first_nonzero_idx:]
        labels = labels[:, :, first_nonzero_idx:]
        label_first = labels[:, :, first_nonzero_idx : first_nonzero_idx + 1]

        predicted_masks = run_algorithm(
            case_id=case.name,
            predictor=predictor,
            refiner=None,
            frames=frames,
            target=label_first,
            frame_rate=1.0,
            magnetic_field_strength=1.0,
            scanned_region="test",
            do_refinement=False,
        )

        assert frames.shape == predicted_masks.shape == labels.shape

        (OUTPUT / "imagesTr").mkdir(parents=True, exist_ok=True)
        (OUTPUT / "labelsTr").mkdir(parents=True, exist_ok=True)

        for i in range(frames.shape[-1]):
            fname = OUTPUT / "imagesTr" / f"{subject.name}-{case.name}-{i:04d}_0000.png"
            frame = frames[:, :, i]
            Image.fromarray(frame).save(fname)

        for i in range(predicted_masks.shape[-1]):
            fname = OUTPUT / "imagesTr" / f"{subject.name}-{case.name}-{i:04d}_0001.png"
            mask = predicted_masks[:, :, i].astype(np.uint8)
            png = Image.fromarray(mask)
            png.putpalette(DAVIS_PALETTE)
            png.save(fname)

        for i in range(labels.shape[-1]):
            fname = OUTPUT / "labelsTr" / f"{subject.name}-{case.name}-{i:04d}.png"
            label = labels[:, :, i].astype(np.uint8)
            png = Image.fromarray(label)
            png.putpalette(DAVIS_PALETTE)
            png.save(fname)

metadata = {
    "channel_names": {
        "0": "raw_image",
        "1": "suggested_mask",
    },
    "labels": {
        "background": 0,
        "lesion": 1,
    },
    "numTraining": count,
    "file_ending": ".png",
}
dataset_json = OUTPUT / "dataset.json"
dataset_json.write_text(json.dumps(metadata, indent=4))

nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
nnUNet_results is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.


  0%|          | 0/95 [00:00<?, ?it/s]

frame loading (JPEG): 100%|██████████| 56/56 [00:01<00:00, 34.11it/s]

Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  pred_masks_gpu = fill_holes_in_mask_scores(
propagate in video: 100%|██████████| 56/56 [00:01<00:00, 47.02it/s]
frame loading (JPEG): 100%|██████████| 37/37 [00:00<00:00, 37.02it/s]
propagate in video: 100%|██████████| 37/37 [00:00<00:00, 44.59it/s]
frame loading (JPEG): 100%|██████████| 40/40 [00:01<00:00, 38.94it/s]
propagate in video: 100%|██████████| 40/40 [00:00<00:00, 47.50it/s]
frame loading (JPEG): 100%|██████████| 56/56 [00:01<00:00, 35.19it/s]
propagate in video: 100%|██████████| 56/56 [00:01<00:00, 47.53it/s]
frame loading (JPEG): 100%|██████████| 60/60 [00:01<00:00, 45.76it/s]
propagate in video: 100%|█████████

KeyboardInterrupt: 