In [12]:
import gradio as gr
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import tempfile

volume_cache = {}

def normalize_to_uint8(arr: np.ndarray) -> np.ndarray:
    """Linearly rescale *arr* to the full uint8 range [0, 255].

    NaNs are replaced by 0 before scaling. The minimum maps to 0 and the
    maximum to 255; intermediate values are truncated, not rounded. A
    constant or empty array yields an all-zero uint8 array of the same shape.

    Args:
        arr: Array of any real numeric (or bool) dtype and any shape.

    Returns:
        A new uint8 array with the same shape as *arr*.
    """
    # Work in float64 so subtracting the minimum cannot overflow/wrap for
    # narrow integer dtypes (e.g. int16 data spanning more than 32767), and
    # so bool input is accepted. nan_to_num copies, so the input is untouched.
    arr = np.nan_to_num(np.asarray(arr, dtype=np.float64))
    if arr.size == 0:
        # .min()/.max() raise on empty arrays.
        return np.zeros(arr.shape, dtype=np.uint8)
    arr_min = arr.min()
    span = arr.max() - arr_min
    if span == 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    norm = (arr - arr_min) / span
    return (norm * 255).astype(np.uint8)

def create_dummy_mask(height, width):
    """Return a placeholder float32 mask of shape (height, width) holding a filled disc.

    The disc is centred on pixel (row=height // 2, col=width // 2) and has
    radius ``min(height, width) // 4``. Pixels inside it, boundary included,
    are 1.0; every other pixel is 0.0.
    """
    canvas = np.zeros((height, width), dtype=np.float32)
    radius = min(height, width) // 4
    rows, cols = np.ogrid[:height, :width]
    # Broadcasting the open grids yields a (height, width) boolean disc.
    canvas[(cols - width // 2) ** 2 + (rows - height // 2) ** 2 <= radius ** 2] = 1.0
    return canvas

def rotate_90_cc(arr: np.ndarray) -> np.ndarray:
    """Return a copy of the 2D array *arr* rotated 90 degrees counter-clockwise.

    Element ``arr[i, j]`` ends up at ``out[w - 1 - j, i]``, where ``w`` is the
    number of columns of *arr*. An ``(h, w)`` input therefore yields a
    ``(w, h)`` output with the same dtype.

    Raises:
        ValueError: if *arr* is not two-dimensional.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(f"rotate_90_cc expects a 2D array, got ndim={arr.ndim}")
    # np.rot90 performs exactly this index mapping in C instead of a Python
    # double loop, but returns a strided view; .copy() (C order) keeps the
    # result independent of the input and contiguous.
    return np.rot90(arr).copy()


def segment(file):
    """Segment an uploaded NIfTI volume with TotalSegmentator and show its middle slice.

    The input volume and the predicted label map are stored in
    ``volume_cache`` so that ``update_slice`` can render other slices without
    re-running the model.

    Args:
        file: Value of the ``gr.File`` component. This is either a path string
            or a tempfile-like object exposing ``.name``, depending on the
            Gradio version. It is ``None`` when the upload is cleared.

    Returns:
        ``(annotated_image, slider_update)``:
        - ``annotated_image`` is the middle slice rendered by
          ``update_slice``, so it matches every later slider view.
        - ``slider_update`` is a ``gr.update`` that shows the slider, with
          its range set to the number of slices along the third axis.

    Raises:
        gr.Error: if the uploaded image does not squeeze to a 3D volume.
    """
    import shutil  # only needed to remove the temporary output directory

    if file is None:
        # The change event also fires when the upload is cleared.
        return None, gr.update(visible=False)

    input_path = file if isinstance(file, str) else file.name

    # Load and validate the input first so a bad upload fails before the
    # slow segmentation run.
    volume = np.squeeze(np.nan_to_num(nib.load(input_path).get_fdata()))
    if volume.ndim != 3:
        raise gr.Error(f"Expected a 3D volume, got shape {volume.shape}.")

    output_dir = tempfile.mkdtemp()
    try:
        seg_img = totalsegmentator(
            input=input_path,
            output=output_dir,
            task="total_mr",   # adjust in future
            quiet=True,
            fast=True,
            ml=False,
            skip_saving=False,
            output_type="nifti",
        )
        # Replace NaNs *before* casting: NaN -> integer is undefined. uint16
        # leaves headroom for label ids above 255.
        mask = np.squeeze(np.nan_to_num(seg_img.get_fdata()).astype(np.uint16))
    finally:
        # The label map is now in memory and nothing reads the saved files,
        # so remove the directory instead of leaking one per upload.
        shutil.rmtree(output_dir, ignore_errors=True)

    volume_cache["volume"] = volume
    volume_cache["mask"] = mask

    mid = volume.shape[2] // 2
    # Delegate to update_slice so the first view uses the same liver label id
    # and the same rotation as the slider-driven views.
    return update_slice(mid), gr.update(visible=True, maximum=volume.shape[2] - 1, value=mid)


def update_slice(index):
    """Render slice *index* of the cached volume with its liver overlay.

    Args:
        index: Position along the volume's third axis, as delivered by the
            ``gr.Slider``. Slider values may arrive as floats, so the value is
            converted to ``int`` before indexing.

    Returns:
        ``(image, [(mask, "liver")])`` as consumed by ``gr.AnnotatedImage``:
        - ``image`` is the slice scaled to uint8.
        - ``mask`` is a float32 0/1 mask of label id 5.
        Both are rotated 90 degrees counter-clockwise for display.

    Raises:
        gr.Error: if no volume has been segmented yet, or *index* is outside
            the cached volume.
    """
    vol = volume_cache.get("volume")
    mask_vol = volume_cache.get("mask")

    if vol is None or mask_vol is None:
        raise gr.Error("No volume loaded.")

    # NumPy rejects float indices, and negative ones would silently wrap.
    index = int(index)
    n_slices = vol.shape[2]
    if not 0 <= index < n_slices:
        raise gr.Error(f"Slice index {index} is out of range (0-{n_slices - 1}).")

    slice_2d = normalize_to_uint8(vol[:, :, index])
    # Label id 5 is liver in TotalSegmentator's total_mr class map.
    mask_2d = (mask_vol[:, :, index] == 5).astype(np.float32)

    return (rotate_90_cc(slice_2d), [(rotate_90_cc(mask_2d), "liver")])


# Build the UI: upload a NIfTI file, view one slice with its mask overlay,
# and scroll through slices with a slider that appears after segmentation.
with gr.Blocks() as demo:
    # Upload widget; its change event runs `segment`.
    file_input = gr.File(label="Upload NIfTI file")
    # Shows a 2D slice with labelled mask overlays; the "liver" label is drawn in #bb3f3f.
    slice_viewer = gr.AnnotatedImage(label="MRI Slice + Mask", color_map={"liver": "#bb3f3f"})
    # Hidden until `segment` returns a gr.update that shows it and sets its maximum.
    slice_slider = gr.Slider(minimum=0, maximum=1, step=1, label="Slice Index", visible=False)

    # Changing the uploaded file runs segmentation and refreshes the viewer and slider.
    file_input.change(fn=segment, inputs=file_input, outputs=[slice_viewer, slice_slider])
    # Moving the slider re-renders the chosen slice from the cached volume.
    slice_slider.change(fn=update_slice, inputs=slice_slider, outputs=slice_viewer)

# show_error=True surfaces exceptions raised in event handlers (e.g. gr.Error) in the browser.
demo.launch(show_error=True)


Running on local URL:  http://127.0.0.1:7868

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB

To create a public link, set `share=True` in `launch()`.




No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.


Python(79609) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79610) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79611) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79612) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79613) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79614) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79615) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79616) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79617) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79618) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(79619) Malloc