In [None]:
import os

import numpy as np
import SimpleITK as sitk
import torch
from skimage.measure import regionprops

from annotatepropwizard3d.xyroll_prediction import XYrollPrediction

## XYrollPrediction
Initialize the XYrollPrediction class.

* Args:
  * cutie_yaml (str): Path to the CUTIE configuration file (an example config is provided in the `yaml` directory).

In [None]:
# Path to the CUTIE evaluation config; an example config is provided in the repository's yaml/ directory.
CUTIE_YAML = "../yaml/eval_config.yaml"
model = XYrollPrediction(CUTIE_YAML)

## Download testing data
* NIfTI (`.nii.gz`) image and label volumes for case `amos_0001` from AMOS, downloaded from Google Drive with `gdown`
* AMOS is a multi-organ segmentation dataset

In [None]:
import gdown  # used below but was never imported, which raised NameError on a fresh kernel

# Google Drive file ids of the AMOS case used in this demo, keyed by local destination path.
AMOS_FILES = {
    "data/amos_0001_data.nii.gz": "1TtQyLI0X6nq90n0dw3zJFFZZzgrZj8br",
    "data/amos_0001_label.nii.gz": "1MgfbufE3802ZsNwQO8Xloce1ezXMszpM",
}


def download_if_missing(file_id: str, output: str) -> None:
    """Download the Google Drive file ``file_id`` to ``output`` unless it already exists.

    Args:
        file_id: Google Drive file id.
        output: Local destination path.

    Raises:
        RuntimeError: If the download finished without creating ``output``.
    """
    if os.path.exists(output):
        return
    url = f"https://drive.google.com/uc?id={file_id}"
    gdown.download(url, output)
    # Fail here rather than later with a confusing SimpleITK read error.
    if not os.path.exists(output):
        raise RuntimeError(f"Download of {url} did not produce {output}")


os.makedirs("data", exist_ok=True)
for output, file_id in AMOS_FILES.items():
    download_if_missing(file_id, output)

## Load Nii data
* Load Nii data
* Choose label 2 as the target and binarize it to a 0/1 mask
* Set every other label to background (0)

In [None]:
# Annotation label value used as the propagation target (binarised to 1 below).
TARGET_LABEL = 2

sitk_img = sitk.ReadImage("data/amos_0001_data.nii.gz")
img = sitk.GetArrayFromImage(sitk_img)  # SimpleITK returns NumPy arrays in (z, y, x) axis order

sitk_mask = sitk.ReadImage("data/amos_0001_label.nii.gz")
label_volume = sitk.GetArrayFromImage(sitk_mask)

# Binary target mask: 1 where the annotation equals TARGET_LABEL, 0 for every other label.
mask = (label_volume.astype(int) == TARGET_LABEL).astype(np.uint8)
assert mask.any(), f"label {TARGET_LABEL} does not occur in the annotation"

## Inference
* Group the binary mask into labelled regions with `regionprops`
* Inference
  * Predict a mask volume from the input image and an initial 2D mask on slice `input_z`.
  * Args:
    * img (np.ndarray): Input image as a NumPy array.
    * inital_mask (np.ndarray): Initial mask as a NumPy array.
    * input_z (int): The z-coordinate of the initial mask.

  * Returns:
    * np.ndarray: The predicted mask volume as a NumPy array.
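* Calculate the slice-level F1 score, i.e. how well the set of z-slices containing predicted foreground matches the set of z-slices containing ground-truth foreground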

In [None]:
# The binarised mask holds a single label (1), so regionprops returns one region here;
# iterating keeps the cell valid if more labels are ever kept.
regions = regionprops(mask)
# Without a region, `result` would stay undefined and the visualization cell would fail.
assert regions, "mask contains no foreground to propagate from"

# Ground-truth foreground area per z-slice (axis 0); loop-invariant, so computed once.
gt_slice_area = mask.sum((1, 2))
mask_z = gt_slice_area > 0

for prop in regions:
    # Seed slice: z-coordinate of the region centroid, truncated to an integer index.
    init_center = np.array(prop.centroid, dtype=int)
    index_z = init_center[0]
    if not mask[index_z].any():
        # The centroid slice can be empty for non-convex regions; an empty seed is presumably
        # useless to the model, so fall back to the slice with the most foreground.
        index_z = int(np.argmax(gt_slice_area))

    with torch.inference_mode():
        result = model.predict(img, mask[index_z], index_z)

    # Slice-level F1: agreement between the sets of z-slices that contain predicted and
    # ground-truth foreground; it does not measure voxel overlap within a slice.
    result_z = result.sum((1, 2)) > 0
    f1 = 2 * (result_z * mask_z).sum() / (result_z.sum() + mask_z.sum())
    print(f"f1 score is: {f1}")

## Visualization
* Visualize the XY-roll prediction by rotating the volume around the z axis
  * Args:
    * img_volume: (np.ndarray) original image volume
    * mask_volume: (np.ndarray) original mask volume
    * spacing_yxz: (tuple) spacing in yxz order
    * output_folder: (str) output folder path
    * device: (torch.device) device to use
    * rotate_degree: (int) degree to rotate the image

In [None]:
from annotatepropwizard3d.utils.visualize import visualize_zaxis_rotation

In [None]:
# visualize_zaxis_rotation documents its spacing argument as (y, x, z), whereas SimpleITK's
# GetSpacing() returns (x, y, z); reorder so anisotropic in-plane spacing is not swapped.
spacing_x, spacing_y, spacing_z = sitk_img.GetSpacing()
spacing_yxz = (spacing_y, spacing_x, spacing_z)

# Fall back to CPU so the cell still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

os.makedirs("output/xyroll", exist_ok=True)
# `result` is the prediction from the last region evaluated in the inference loop.
visualize_zaxis_rotation(img, result, spacing_yxz, 'output/xyroll', device)

In [None]:
import glob
import re

import imageio


def _natural_key(path):
    """Sort key that orders embedded integers numerically (e.g. frame_2 before frame_10)."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]


# glob returns files in arbitrary order; sort so the GIF frames follow the rotation sequence.
filenames = sorted(glob.glob('output/xyroll/*.png'), key=_natural_key)
assert filenames, "no frames found in output/xyroll; run the visualization cell first"

images = [imageio.imread(filename) for filename in filenames]
imageio.mimsave('movie.gif', images, loop=2, duration = 0.01)