In [None]:
import os
import torch
import gdown
import numpy as np
import SimpleITK as sitk
from skimage.measure import regionprops
from annotatepropwizard3d.mask_propagation import MaskPropagation

## Mask Propagation
* Initialize the MaskPropagation class.
  * Args:
    * sam_checkpoint (str): Path to the SAM checkpoint file. (You can download it from the official website.)
    * cutie_yaml (str): Path to the CUTIE configuration file. (You can find it in the `yaml` directory.)

  * Attributes:
    * model_size (int): Size of the input image.
    * encoder_adapter (bool): Whether to use the encoder adapter.
    * image_size (int): Size of the input image after resizing.
    * device (torch.device): Device to use for tensor operations.
    * sam_predictor (SammedPredictor): SAM predictor object.
    * vos_processor (InferenceCoreWithLogits): VOS processor object.
    * trace_num (int): Number of traces to use for mask propagation.

  * Methods:
    * predict(img, initial_mask, input_z): Predict the mask propagation.
    * predict_by_volume(img, initial_mask): Predict the mask propagation from an initial 3D mask.

In [None]:
# Paths to the SAM checkpoint and the CUTIE config, relative to this notebook.
sam_checkpoint = "../weights/sam_vit_b_01ec64.pth"
cutie_yaml = "../yaml/eval_config.yaml"

# Build the propagation model used by the inference cell below.
model = MaskPropagation(sam_checkpoint, cutie_yaml)

## Download testing data
* NIfTI (`.nii.gz`) data from the AMOS dataset.
* AMOS is a multi-organ segmentation dataset.

In [None]:
os.makedirs("data", exist_ok=True)

# Local destination -> Google Drive file id for each sample file
# (image volume first, then its label volume).
sample_files = {
    "data/amos_0001_data.nii.gz": "1TtQyLI0X6nq90n0dw3zJFFZZzgrZj8br",
    "data/amos_0001_label.nii.gz": "1MgfbufE3802ZsNwQO8Xloce1ezXMszpM",
}

for output, file_id in sample_files.items():
    # A file that is already on disk is not fetched again, so re-running this cell is cheap.
    if os.path.exists(output):
        continue
    url = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(url, output)

## Load Nii data
* Load the NIfTI data.
* Choose class 1 as the target mask.
* Mask out all other classes.
* Find the central slice of each mask region.
* Add that slice to the initial 3D mask.

In [None]:
# Read the image volume and its label volume from disk.
sitk_img = sitk.ReadImage("data/amos_0001_data.nii.gz")
sitk_mask = sitk.ReadImage("data/amos_0001_label.nii.gz")

img = sitk.GetArrayFromImage(sitk_img)

# Binary target mask: voxels labelled 1 become 1, every other label becomes 0.
mask = (sitk.GetArrayFromImage(sitk_mask).astype(int) == 1).astype(np.uint8)

# Initial 3D mask: all zeros except the slice through each region's centroid.
# Axis 0 is treated as the slice (z) axis here.
labeled_mask = np.zeros_like(mask)
for region in regionprops(mask):
    # Truncate the centroid's first coordinate to an integer slice index.
    centroid_z = int(region.centroid[0])
    labeled_mask[centroid_z] = mask[centroid_z]

## Inference
* Predict the mask propagation for a whole volume.
  * Args:
    * img (numpy.ndarray): Input image.
    * labeled_mask (numpy.ndarray): Labeled mask containing some labeled slices.

  * Returns:
    * numpy.ndarray: Predicted mask for the whole volume.

In [None]:
# Run the propagation inside torch.inference_mode(); this cell only performs inference.
with torch.inference_mode():
    # Propagate from the seed slices in labeled_mask; `result` is reused by the
    # scoring and visualization cells.
    result = model.predict_by_volume(img, labeled_mask)

## Calculate F1 score

In [None]:
# Overlap score: twice the summed element-wise product of prediction and
# ground truth, divided by the sum of both volumes.
overlap = (result * mask).sum()
total = result.sum() + mask.sum()
f1 = 2 * overlap / total
print(f"f1 score is: {f1}")

## Visualization
* mask_propagation_visualize
  * Visualize the results for the slices that have predicted masks.
    * Args:
        * img: [np.ndarray] original image
        * predict_masks: [dict<int, np.ndarray>] predicted masks dictionary
        * output_folder: output folder path

In [None]:
from annotatepropwizard3d.utils.visualize import mask_propagation_visualize

In [None]:
output_dir = "output/mask_propagation"
os.makedirs(output_dir, exist_ok=True)

# Keep only the slices whose per-slice sum is non-zero, keyed by slice index.
slice_sums = result.sum((1, 2))
nonempty_slices = slice_sums.nonzero()[0].flatten()
result_dict = {i: result[i] for i in nonempty_slices}

mask_propagation_visualize(img, result_dict, output_dir)