In [19]:
import gradio as gr
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import tempfile
import matplotlib.pyplot as plt
import os
import time
from image_utils import normalize_to_uint8, rotate_90_cc, rotate_90_c, resize_image, overlay_mask_on_slice

import io
import base64
from PIL import Image

volume_cache = {}

def array_to_base64_png(arr):
    if arr.ndim == 2:
        arr = np.stack([arr]*3, axis=-1)
    img = Image.fromarray(arr.astype(np.uint8))
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode()

def segment(file, progress=gr.Progress()):
    progress(0, desc="Starting segmentation...")

    input_path = file.name
    output_dir = tempfile.mkdtemp()

    volume_cache["volume_nii"] = nib.load(input_path)

    progress(0.05, desc="Running TotalSegmentator...")
    seg_img = totalsegmentator(
        input=input_path,
        output=output_dir,
        task="total_mr",
        quiet=True,
        fast=True,
        ml=False,
        skip_saving=False,
        output_type="nifti"
    )

    time.sleep(1)
    progress(0.75, desc="Processing results...")
    seg_data = np.nan_to_num(seg_img.get_fdata().astype(np.uint8))
    print(seg_data.shape) # (256, 256, 22)
    labels_data = np.nan_to_num(nib.load(input_path).get_fdata())
    volume_cache["volume"] = np.squeeze(labels_data)
    volume_cache["mask"] = np.squeeze(seg_data)
    mid = volume_cache["volume"].shape[2] // 2
    slice_2d = normalize_to_uint8(volume_cache["volume"][:, :, mid])
    mask_2d = (volume_cache["mask"][:, :, mid] == 5).astype(np.float32)

    rotated_slice = resize_image(rotate_90_cc(slice_2d))
    rotated_mask = resize_image(rotate_90_cc(mask_2d))

    overlay = overlay_mask_on_slice(rotated_slice, rotated_mask)

    time.sleep(1)
    progress(1.0, desc="Done.")

    # return gr.update(value={
    #     "background": overlay,
    #     "layers": None,
    #     "composite": overlay
    # }, visible=True), gr.update(visible=True), gr.update(visible=True, maximum=volume_cache["volume"].shape[2] - 1, value=mid), "Segmentation complete.", gr.update(visible=True)

    return gr.update(value=overlay, visible=True), gr.update(visible=True), gr.update(visible=True, maximum=volume_cache["volume"].shape[2] - 1, value=mid), "Segmentation complete.", gr.update(visible=True)

def update_slice(index):
    vol = volume_cache.get("volume")
    mask_vol = volume_cache.get("mask")
    if vol is None or mask_vol is None:
        raise gr.Error("No volume loaded.")
    slice_2d = normalize_to_uint8(vol[:, :, index])
    mask_2d = (mask_vol[:, :, index] == 5).astype(np.uint8)
    rotated_slice = resize_image(rotate_90_cc(slice_2d))
    rotated_mask = resize_image(rotate_90_cc(mask_2d))
    overlay = overlay_mask_on_slice(rotated_slice, rotated_mask)

    # return {
    #     "background": overlay,
    #     "layers": None,
    #     "composite": overlay
    # }

    return overlay

def apply_edited_mask(brush_type, editor_value, index):
    print(editor_value)
    if editor_value is None or "layers" not in editor_value or len(editor_value["layers"]) == 0:
        return "No edited mask provided."

    edited_layer = np.array(editor_value["layers"][0])

    if edited_layer.ndim >= 3:
        edited_layer = edited_layer[..., 0]

    edited_layer_resized = resize_image(edited_layer, (volume_cache["mask"].shape[0], volume_cache["mask"].shape[1]))
    edited_mask = rotate_90_c((edited_layer_resized > 0).astype(np.uint8))

    original_mask_slice = volume_cache["mask"][:, :, index]

    if brush_type == "Remove":
        new_mask_slice = np.where(edited_mask == 1, 0, original_mask_slice)
    else:
        new_mask_slice = np.where(edited_mask == 1, 5, original_mask_slice)
        

    volume_cache["mask"][:, :, index] = new_mask_slice
    updated_slice = update_slice(index)

    return updated_slice, "Mask updated."

def draw(index, evt: gr.SelectData):
    x, y = evt.index
    radius = 15

    original_mask_slice = volume_cache["mask"][:, :, index]
    mask_edited = rotate_90_cc(resize_image(original_mask_slice))

    h, w = mask_edited.shape
    Y, X = np.ogrid[:h, :w]
    dist_from_center = (X - x)**2 + (Y - y)**2
    circular_mask = dist_from_center <= radius**2

    mask_edited[circular_mask] = 5
    mask_restored = rotate_90_c(resize_image(mask_edited, original_mask_slice.shape))

    volume_cache["mask"][:, :, index] = mask_restored
    updated_slice = update_slice(index)

    return updated_slice, f"Drawn ({x}, {y}) with radius {radius}."


def render_vol():
    if "mask" not in volume_cache:
        raise gr.Error("No mask available for rendering.")
    
    stl_path = render_vol_from_mask(volume_cache["mask"], label_value=5)
    return gr.update(value=stl_path, visible=True)


def render_vol_from_mask(mask_3d, label_value=5):
    from skimage import measure
    import trimesh
    
    binary_mask = (mask_3d == label_value).astype(np.uint8)

    if "volume_nii" not in volume_cache:
        raise gr.Error("Original NIfTI volume info not found.")
    spacing = volume_cache["volume_nii"].header.get_zooms()[:3]

    verts, faces, normals, _ = measure.marching_cubes(binary_mask, level=0.5, spacing=spacing)
    mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)

    tmp_dir = tempfile.mkdtemp()
    stl_path = os.path.join(tmp_dir, "mask_mesh.stl")
    mesh.export(stl_path)

    return stl_path


with gr.Blocks() as demo:
    status = gr.Textbox(label="Status", interactive=False, value="Pending file input.")

    file_input = gr.File(label="Upload NIfTI")

    # slice viewer and mask editor
    brush_type = gr.Dropdown(["Add", "Remove"], label="Brush mode", visible=False, interactive=True)
    image_editor = gr.ImageEditor(label="Brush-Editable Liver Mask", type="numpy", height=800, visible=False)
    image_viewer = gr.Image(label="Image viewer", type="numpy", visible=False)
    slice_slider = gr.Slider(minimum=0, maximum=1, step=1, label="Slice Index", visible=False)
    
    # 3d model
    render_button = gr.Button(value="Show 3d", visible=False)
    volume = gr.Model3D(label="3d model", visible=False)

    file_input.change(
        fn=segment, 
        inputs=file_input, 
        outputs=[image_viewer, brush_type, slice_slider, status, render_button]
    )

    slice_slider.change(
        fn=update_slice, 
        inputs=slice_slider, 
        outputs=image_viewer
    )

    image_viewer.select(
        fn=draw,
        inputs=[slice_slider],
        outputs=[image_viewer, status]
    )

    # image_editor.apply(
    #     fn=apply_edited_mask, 
    #     inputs=[brush_type, image_editor, slice_slider], 
    #     outputs=[image_editor,status]
    # )
    render_button.click(fn=render_vol, outputs=[volume])

app = demo.launch(show_error=True, server_port=7693)



Running on local URL:  http://127.0.0.1:7693

To create a public link, set `share=True` in `launch()`.


No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.
(256, 256, 22)
