In [None]:
import gradio as gr
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import tempfile
import matplotlib.pyplot as plt
import os

# Store for the currently loaded study, filled by segment():
#   "volume" -> squeezed image array loaded from the uploaded NIfTI file
#   "mask"   -> squeezed uint8 label array returned by TotalSegmentator;
#               apply_edited_mask writes brush edits into it in place.
# NOTE(review): this is module-level, so every Gradio session served by this
# process shares (and overwrites) it; consider gr.State if several users are expected.
volume_cache = {}

def normalize_to_uint8(arr: np.ndarray) -> np.ndarray:
    """Linearly rescale ``arr`` onto the full 0-255 range as ``uint8``.

    NaNs are replaced via ``np.nan_to_num`` before scaling. The minimum value
    maps to 0 and the maximum to 255; intermediate values are truncated toward
    zero by the final integer cast.

    Args:
        arr: Array of any real numeric dtype and any shape.

    Returns:
        A ``uint8`` array with the same shape as ``arr``. Constant or empty
        input yields all zeros.
    """
    # Work in float64: computing max - min (and arr - min) in a narrow integer
    # dtype such as int8 spanning -100..100 would silently wrap around.
    values = np.nan_to_num(np.asarray(arr, dtype=np.float64))
    if values.size == 0:
        # .min()/.max() raise on empty arrays.
        return np.zeros(values.shape, dtype=np.uint8)
    lo = values.min()
    span = values.max() - lo
    if span == 0:
        return np.zeros(values.shape, dtype=np.uint8)
    norm = (values - lo) / span
    return (norm * 255).astype(np.uint8)

def rotate_90_cc(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` rotated 90 degrees counter-clockwise.

    Element ``arr[i, j]`` ends up at ``out[w - 1 - j, i]``, where ``w`` is the
    number of columns of ``arr``; an ``(h, w)`` input yields a ``(w, h)`` output.
    For arrays with more than two dimensions, the first two axes are rotated
    and any trailing axes (e.g. colour channels) are carried along unchanged.

    Returns a new C-contiguous array (never a view of ``arr``), so callers may
    mutate it or hand it to libraries that expect contiguous buffers.
    """
    # np.rot90 returns a strided view; .copy() materialises it in C order.
    return np.rot90(arr, k=1, axes=(0, 1)).copy()

def rotate_90_c(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` rotated 90 degrees clockwise (the inverse of ``rotate_90_cc``).

    Element ``arr[h - 1 - j, i]`` ends up at ``out[i, j]``, where ``h`` is the
    number of rows of ``arr``; an ``(h, w)`` input yields a ``(w, h)`` output.
    Unlike a loop indexed by the wrong dimension, this is correct for
    non-square arrays too. Trailing axes beyond the first two are carried along.

    Returns a new C-contiguous array (never a view of ``arr``).
    """
    # k=-1 rotates clockwise; .copy() detaches the result from the input buffer.
    return np.rot90(arr, k=-1, axes=(0, 1)).copy()

def resize_image(arr: np.ndarray, target_size=(1024, 1024)) -> np.ndarray:
    """Resize an image array with Pillow and return the result as a numpy array.

    ``target_size`` is handed straight to ``PIL.Image.Image.resize``, so it is
    ``(width, height)``. Two-dimensional (single-channel) arrays use
    nearest-neighbour resampling, which keeps discrete values such as 0/1
    masks discrete; arrays with a channel axis use bilinear resampling.
    """
    from PIL import Image

    picture = Image.fromarray(arr)
    if arr.ndim == 2:
        filt = Image.NEAREST
    else:
        filt = Image.BILINEAR
    scaled = picture.resize(target_size, resample=filt)
    return np.array(scaled)

def overlay_mask_on_slice(slice_img: np.ndarray, mask: np.ndarray, color=(187, 63, 63), alpha=0.5) -> np.ndarray:
    """Alpha-blend ``color`` onto the pixels of ``slice_img`` selected by ``mask``.

    Args:
        slice_img: Grayscale ``(H, W)`` image, or ``(H, W, C)`` image with
            ``C >= 3``. Grayscale input is replicated into three channels.
        mask: ``(H, W)`` array; any non-zero value marks a pixel to tint.
        color: RGB tint; only its first three entries are used.
        alpha: Weight of ``color`` in the blend (0 = untouched, 1 = solid
            colour). Defaults to 0.5, an even mix.

    Returns:
        A ``uint8`` image of the same height/width. It has 3 channels for
        grayscale input, or the original ``C`` channels otherwise. Only the
        first three channels are blended. Blended values are truncated toward
        zero when cast back to the image dtype.
    """
    if slice_img.ndim == 2:
        background = np.stack([slice_img] * 3, axis=-1)
    else:
        background = slice_img

    selected = mask.astype(bool)[..., None]  # broadcast over the colour axis
    tint = np.asarray(color, dtype=np.float64)[:3]
    rgb = background[..., :3]
    blended = (1 - alpha) * rgb + alpha * tint

    overlay = background.copy()
    # Assigning back into the image's own dtype truncates fractional values.
    overlay[..., :3] = np.where(selected, blended, rgb)
    return overlay.astype(np.uint8)

def segment(file):
    """Segment an uploaded NIfTI with TotalSegmentator and show its middle slice.

    Side effects: stores the squeezed image volume in ``volume_cache["volume"]``
    and the squeezed uint8 label volume in ``volume_cache["mask"]``.

    Args:
        file: Value of the ``gr.File`` component. This is either an object
            exposing the uploaded path as ``.name`` or a plain path string.
            It is ``None`` when the upload is cleared.

    Returns:
        ``(editor_value, brush_dropdown_update, slider_update)``, matching the
        outputs wired to ``file_input.change``.

    Raises:
        gr.Error: if the squeezed image is not 3-D or its shape differs from
            the segmentation's.
    """
    if file is None:
        # Upload was cleared: keep the editor as is and hide the slice controls.
        return gr.update(), gr.update(visible=False), gr.update(visible=False)

    input_path = getattr(file, "name", file)

    # TotalSegmentator only needs the output directory while it runs; a
    # TemporaryDirectory is cleaned up instead of leaking one dir per upload.
    with tempfile.TemporaryDirectory() as output_dir:
        seg_img = totalsegmentator(
            input=input_path,
            output=output_dir,
            task="total_mr",
            quiet=True,
            fast=True,
            ml=False,
            skip_saving=False,
            output_type="nifti"
        )
        # Load the labels while the directory still exists, and replace NaNs
        # *before* the integer cast (casting NaN to uint8 is undefined).
        seg_data = np.nan_to_num(seg_img.get_fdata()).astype(np.uint8)

    image_data = np.nan_to_num(nib.load(input_path).get_fdata())
    volume = np.squeeze(image_data)
    mask = np.squeeze(seg_data)
    if volume.ndim != 3 or volume.shape != mask.shape:
        raise gr.Error(
            f"Expected a 3-D volume matching its segmentation; "
            f"got image {volume.shape} and mask {mask.shape}."
        )
    volume_cache["volume"] = volume
    volume_cache["mask"] = mask

    mid = volume.shape[2] // 2
    # Render through update_slice so the initial view and slider-driven views
    # share a single code path.
    return (
        update_slice(mid),
        gr.update(visible=True),
        gr.update(visible=True, maximum=volume.shape[2] - 1, value=mid),
    )

def update_slice(index):
    """Render axial slice ``index`` of the cached volume with the label-5 overlay.

    The slice is normalised to uint8, rotated 90 degrees counter-clockwise,
    upscaled by ``resize_image``, and tinted where ``volume_cache["mask"] == 5``.
    The UI labels this editor as the liver mask.

    Args:
        index: Slice position along the third axis. Floats such as ``3.0``
            are accepted and truncated to int.

    Returns:
        An ImageEditor value dict with the overlay as background and
        composite, and no layers (which clears any brush strokes).

    Raises:
        gr.Error: if no volume is loaded or ``index`` is out of range.
    """
    vol = volume_cache.get("volume")
    mask_vol = volume_cache.get("mask")
    if vol is None or mask_vol is None:
        raise gr.Error("No volume loaded.")

    # Slider values may arrive as floats; numpy indexing needs an int.
    idx = int(index)
    depth = vol.shape[2]
    if not 0 <= idx < depth:
        raise gr.Error(f"Slice index {idx} is out of range 0..{depth - 1}.")

    slice_2d = normalize_to_uint8(vol[:, :, idx])
    mask_2d = (mask_vol[:, :, idx] == 5).astype(np.uint8)
    shown_slice = resize_image(rotate_90_cc(slice_2d))
    shown_mask = resize_image(rotate_90_cc(mask_2d))
    overlay = overlay_mask_on_slice(shown_slice, shown_mask)

    return {
        "background": overlay,
        "layers": None,
        "composite": overlay
    }

def _painted_pixels(layers) -> np.ndarray:
    """Return a boolean (H, W) array marking pixels painted on any editor layer."""
    painted = None
    for layer in layers:
        pixels = np.asarray(layer)
        if pixels.ndim == 3 and pixels.shape[-1] == 4:
            # RGBA layer: use alpha so strokes register whatever the brush
            # colour (a black brush is 0 in every colour channel).
            hit = pixels[..., 3] > 0
        elif pixels.ndim == 3:
            hit = np.any(pixels > 0, axis=-1)
        else:
            hit = pixels > 0
        painted = hit if painted is None else (painted | hit)
    return painted


def apply_edited_mask(brush_type, editor_value, index):
    """Merge brush strokes from the ImageEditor into label 5 of slice ``index``.

    Args:
        brush_type: ``"Add"`` paints label 5 under the strokes. ``"Remove"``
            clears label 5 under the strokes and leaves other labels untouched.
        editor_value: ImageEditor value dict; its ``"layers"`` hold the strokes.
        index: Slice position along the third axis (coerced to int).

    Returns:
        ``(editor_update, status_message)`` for the ``[image_editor, status]``
        outputs. On invalid input the editor is left unchanged
        (``gr.update()``) and the message explains why.
    """
    layers = editor_value.get("layers") if editor_value else None
    if layers is None or len(layers) == 0:
        return gr.update(), "No edited mask provided."
    if brush_type not in ("Add", "Remove"):
        return gr.update(), "Select a brush mode (Add or Remove) before applying."
    mask_vol = volume_cache.get("mask")
    if mask_vol is None:
        return gr.update(), "No volume loaded."

    idx = int(index)
    rows, cols = mask_vol.shape[0], mask_vol.shape[1]

    strokes = _painted_pixels(layers).astype(np.uint8) * 255
    # The editor shows the slice rotated counter-clockwise (cols x rows) and
    # upscaled. PIL sizes are (width, height), so (rows, cols) gives a
    # (cols, rows) array, and a clockwise rotation restores (rows, cols).
    strokes_small = resize_image(strokes, (rows, cols))
    edited = rotate_90_c(strokes_small > 0)

    original_mask_slice = mask_vol[:, :, idx]
    if brush_type == "Add":
        new_mask_slice = np.where(edited, 5, original_mask_slice)
    else:
        # Only erase label-5 voxels; other structures under the brush, which
        # are not shown in the overlay, keep their labels.
        new_mask_slice = np.where(edited & (original_mask_slice == 5), 0, original_mask_slice)

    mask_vol[:, :, idx] = new_mask_slice
    return update_slice(idx), "Mask modified."


# UI: upload -> segment; slider -> re-render slice; editor "apply" -> merge strokes.
with gr.Blocks() as demo:
    file_input = gr.File(label="Upload NIfTI")
    # Default to "Add" so an applied edit always arrives with a brush mode.
    brush_type = gr.Dropdown(["Add", "Remove"], value="Add", label="Brush mode", visible=False, interactive=True)
    image_editor = gr.ImageEditor(label="Brush-Editable Liver Mask", type="numpy", height=600)
    # The real maximum is set by segment() once the volume depth is known.
    slice_slider = gr.Slider(minimum=0, maximum=1, step=1, label="Slice Index", visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    file_input.change(fn=segment, inputs=file_input, outputs=[image_editor, brush_type, slice_slider])
    slice_slider.change(fn=update_slice, inputs=slice_slider, outputs=image_editor)
    image_editor.apply(apply_edited_mask, inputs=[brush_type, image_editor, slice_slider], outputs=[image_editor, status])

# In a notebook kernel __name__ is "__main__", so this still launches there,
# but importing this module no longer starts a server as a side effect.
if __name__ == "__main__":
    demo.launch(show_error=True)


Running on local URL:  http://127.0.0.1:7935

To create a public link, set `share=True` in `launch()`.




No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.


Python(982) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(983) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(984) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(985) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(986) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(987) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(988) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(989) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(990) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(991) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(992) MallocStackLogging: can't tu

(256, 256, 22)
