<a href="https://colab.research.google.com/github/lczamprogno/ct2us/blob/main/CT2US.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CT2US

This tool is intended to automate the generation of simulated ultrasound images, paired with tissue labels, from CT volumes (.nii/.nii.gz).

CT to Ultrasound simulation with tissue label maps
- Developed a modular tool to supplement ultrasound segmentation datasets.
- Created a pipeline to process computed tomography (CT) volumes, extract labels for different tissue types and simulate ultrasound slices.
- Improved performance through CPU and GPU optimizations.
- Created an interface and visualizations for previewing results through overlaid slice annotations and sampled point clouds.
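The slicing step can be pictured with a minimal sketch, assuming the CT volume and its label map are NumPy arrays of the same shape and that slices are taken along the last axis every `step_size` voxels (the role the `step_size` argument plays in the pipeline). `iter_slices` is a hypothetical helper, not part of the repository, and the real pipeline may resample or choose a different axis.

```python
import numpy as np

def iter_slices(volume: np.ndarray, step_size: int = 1, axis: int = 2):
    """Yield (index, 2D slice) pairs taken every `step_size` voxels along `axis`.

    Hypothetical helper: the actual CT2US pipeline may slice along a different
    axis or resample the volume first.
    """
    if step_size < 1:
        raise ValueError("step_size must be >= 1")
    for idx in range(0, volume.shape[axis], step_size):
        # np.take with a scalar index drops the sliced axis, leaving a 2D array
        yield idx, np.take(volume, idx, axis=axis)

# A toy 4x4x10 "CT volume" and a matching label map
ct = np.random.default_rng(0).normal(0, 100, size=(4, 4, 10)).astype(np.float32)
labels = np.zeros_like(ct, dtype=np.uint8)
labels[1:3, 1:3, :] = 11  # pretend liver region (label id 11)

# Pair each CT slice with the label slice at the same index
pairs = [(ct_slice, labels[..., idx]) for idx, ct_slice in iter_slices(ct, step_size=3)]
print(len(pairs))  # indices 0, 3, 6, 9 -> 4 pairs
```

Each pair is the kind of (image, label) input a per-slice ultrasound renderer would consume; a larger `step_size` trades slice density for speed.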

---

### Purpose
Intended to supplement datasets for ultrasound image labeling.

## Expandability
The image generation process depends heavily on tissue attenuation, so specialized US renderers would be needed to extend this tool to other body parts. For this reason, much of the following code was designed with modularity as a core goal: methods can be added or replaced, which matters because, for example, segmentation quality or speed can have a significant impact on the overall results.
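To illustrate the kind of name-based plug-in registration this modularity relies on, here is a minimal sketch. `MethodRegistry`, `RenderingMethod` and `ToyRenderer` are hypothetical stand-ins; the notebook's own entry point is `CT2USPipelineFactory.register_rendering_method`, whose internals may differ.

```python
from typing import Dict, Type

class RenderingMethod:
    """Hypothetical base class; the notebook imports its real counterpart,
    UltrasoundRenderingMethod, from pipeline.component_classes."""

    @staticmethod
    def name() -> str:
        raise NotImplementedError

    def render(self, label_slice):
        raise NotImplementedError

class MethodRegistry:
    """Maps a method name to its class so pipelines can be assembled by name."""

    def __init__(self) -> None:
        self._methods: Dict[str, Type[RenderingMethod]] = {}

    def register(self, cls: Type[RenderingMethod]) -> Type[RenderingMethod]:
        self._methods[cls.name()] = cls
        return cls  # returning cls lets register() double as a class decorator

    def create(self, name: str, **kwargs) -> RenderingMethod:
        if name not in self._methods:
            raise KeyError(f"Unknown method {name!r}; available: {sorted(self._methods)}")
        return self._methods[name](**kwargs)

registry = MethodRegistry()

@registry.register
class ToyRenderer(RenderingMethod):
    @staticmethod
    def name() -> str:
        return "Toy"

    def render(self, label_slice):
        # Placeholder: a real renderer would simulate acoustic propagation here
        return label_slice

renderer = registry.create("Toy")
print(type(renderer).__name__)  # ToyRenderer
```

Keeping `name()` a static method means the registry can read it from the class without instantiating anything.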

---

## Current use:
- ![example](https://github.com/lczamprogno/ct2us/blob/main/assets/Full%20Demo.gif)
  

## Further goals:
- Code for two alternative optimized segmentation pipelines is still being developed.

- An improved version of the TotalSegmentator nnU-Net is still a work in progress. Once that is done, plugging it into the pipeline with the stacked assemble should yield a significant speedup. [ ]

---

This cell needs to be run once; restart the session afterwards.

In [None]:
%pip install totalsegmentator numba cupy-cuda12x torchvision xmltodict torchio cucim "bokeh>=3.1.0" di gradio pathlib trimesh[easy]

# Import pipeline components and ultrasound rendering

In [1]:
try:
    # Import pipeline components
    from pipeline.dataset import CTDataset
    from pipeline.pipeline_config import CT2USPipelineFactory, PipelineConfig
except ImportError:
    try:
        !git clone https://github.com/lczamprogno/ct2us.git
        # Make the freshly cloned repository importable
        import sys
        sys.path.append("ct2us")

        from pipeline.dataset import CTDataset
        from pipeline.pipeline_config import CT2USPipelineFactory, PipelineConfig
    
    except Exception as e:
        print(e)

# Classes and methods are gathered here

## Run this block

IMPORTANT: Acquire a TotalSegmentator license key (https://backend.totalsegmentator.com/license-academic/) and store it as a Google Colab secret named `license_key`, as shown:

![a](https://github.com/lczamprogno/ct2us/blob/main/assets/secret.png)

In [2]:
try:
    from google.colab import userdata
    license = userdata.get('license_key')
except ImportError:
    print("Not running in Google Colab, using default license key.")
    # If you are running this in a different environment, set your license key here
    license = ""

import os
os.environ["TS_LICENSE_KEY"] = license

Not running in Google Colab, using default license key.


In [3]:
import sys

from pathlib import Path as pthlib  # Path resolves to PosixPath or WindowsPath as appropriate
from zipfile import ZipFile
import random

from math import pi
import matplotlib.pyplot as plt
import tqdm

from itertools import islice
import glob
import shutil

# Use string for path to avoid PosixPath issues with sys.path
this_folder = str(pthlib("../CT2US").resolve())

import numpy as np

sys.path.append(this_folder)
ts_cfg_path = pthlib(this_folder).joinpath(".totalsegmentator")
ts_cfg_path.mkdir(exist_ok=True, parents=True)
os.environ["TOTALSEG_HOME_DIR"] = str(ts_cfg_path)

# First check CUDA availability and print status
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA is not available. Will use CPU instead.")

# Now try to import numba and CUDA components
from numba import jit, njit, cuda

# Try to import cupy if CUDA is available
if torch.cuda.is_available():
    try:
        import cupy as cp
        import cupyx.scipy.ndimage as cusci
        print(f"CuPy version: {cp.__version__}")
        print("CuPy and cusci loaded successfully")
    except ImportError as e:
        print(f"Error loading cupy or cusci even though CUDA is available: {e}")
        cp = None
        cusci = None
else:
    print("Not attempting to load CuPy since CUDA is not available")
    cp = None
    cusci = None

import scipy.ndimage
import scipy

from torchvision import transforms
from torch import device
from torch import uint8

from torch.utils.data import DataLoader

import gradio as gr

# Set default device based on CUDA availability
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
torch.set_default_device(device)

PyTorch version: 2.5.1+cu124
CUDA available: True
CUDA version: 12.4
CUDA device count: 1
CUDA device name: NVIDIA GeForce RTX 4060 Laptop GPU
CuPy version: 13.3.0
CuPy and cusci loaded successfully
Using device: cuda:0
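The CPU/GPU split above (CuPy when CUDA is present, SciPy otherwise) lets the same operation run on either backend. A minimal sketch, assuming the `binary_dilation_iterations` and `binary_erosion_iterations` parameters passed to `process_ct_images` mean "dilate, then erode a binary mask"; the pipeline's actual order and structuring element are not shown in this notebook, and `clean_mask` is a hypothetical name.

```python
import numpy as np
import scipy.ndimage as ndi_cpu

try:
    import cupy as cp
    import cupyx.scipy.ndimage as ndi_gpu
    # CuPy can import without a usable GPU, so check for a device explicitly
    HAVE_GPU = cp.cuda.runtime.getDeviceCount() > 0
except Exception:
    cp, ndi_gpu, HAVE_GPU = None, None, False

def clean_mask(mask: np.ndarray, dilation_iters: int = 2, erosion_iters: int = 3) -> np.ndarray:
    """Dilate then erode a boolean mask on the GPU if available, otherwise on the CPU.

    The `if ..._iters:` guards matter: in SciPy, iterations < 1 means
    "repeat until the result stops changing", not "do nothing".
    """
    if HAVE_GPU:
        m = cp.asarray(mask, dtype=bool)
        if dilation_iters:
            m = ndi_gpu.binary_dilation(m, iterations=dilation_iters)
        if erosion_iters:
            m = ndi_gpu.binary_erosion(m, iterations=erosion_iters)
        return cp.asnumpy(m)  # hand results back as NumPy for the rest of the pipeline
    m = mask.astype(bool)
    if dilation_iters:
        m = ndi_cpu.binary_dilation(m, iterations=dilation_iters)
    if erosion_iters:
        m = ndi_cpu.binary_erosion(m, iterations=erosion_iters)
    return m

# A single foreground pixel grows into a plus shape, then erodes back to its center
mask = np.zeros((9, 9), dtype=bool)
mask[4, 4] = True
out = clean_mask(mask, dilation_iters=1, erosion_iters=1)
```

Copying to and from the GPU has a cost, so on small masks the CPU path can be just as fast.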


# Configure and run the CT2US pipeline

In [4]:
# Define base directories
img_dir = pthlib(this_folder).joinpath("imgs")
label_dir = pthlib(this_folder).joinpath("labels")
us_dir = pthlib(this_folder).joinpath("us")
gen_dir = pthlib(this_folder).joinpath("gen")

# Create directories if they don't exist
os.makedirs(str(img_dir), exist_ok=True)
os.makedirs(str(label_dir), exist_ok=True)
os.makedirs(str(us_dir), exist_ok=True)
os.makedirs(str(gen_dir), exist_ok=True)

# Define tissue types mapping for the UI
TISSUE_TYPES = {
    0: "Background",
    1: "Background",
    2: "Lung",
    3: "Fat",
    4: "Vessel",
    5: "Unused",
    6: "Kidney",
    7: "Unused",
    8: "Muscle",
    9: "Background",
    10: "Unused",
    11: "Liver",
    12: "Soft Tissue",
    13: "Bone"
}
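The label ids above are what the point-cloud preview counts and samples per tissue. A minimal sketch of that idea, assuming a 3D integer label map in NumPy: count voxels per label, then draw at most the requested number of voxel coordinates for each tissue. `label_counts` and `sample_points` are hypothetical helpers; the notebook's own sampler (`pipeline.pcd_sampler`) may work differently, and unlike the preview UI this sketch also skips "Background".

```python
import numpy as np

def label_counts(label_map: np.ndarray) -> dict:
    """Return {label_id: voxel_count} for every label present in the map."""
    ids, counts = np.unique(label_map, return_counts=True)
    return {int(i): int(c) for i, c in zip(ids, counts)}

def sample_points(label_map, points_per_label, tissue_types, rng=None):
    """Sample up to points_per_label[label] voxel coordinates per tissue label.

    Returns an (N, 3) array of voxel coordinates and an (N,) array of label ids.
    """
    rng = np.random.default_rng() if rng is None else rng
    coords, ids = [], []
    for label, count in label_counts(label_map).items():
        if tissue_types.get(label) in (None, "Unused", "Background"):
            continue  # skip unknown, unused and background labels
        n = min(points_per_label.get(label, 0), count)  # never request more points than voxels
        if n == 0:
            continue
        voxels = np.argwhere(label_map == label)
        chosen = voxels[rng.choice(len(voxels), size=n, replace=False)]
        coords.append(chosen)
        ids.append(np.full(n, label, dtype=label_map.dtype))
    if not coords:
        return np.empty((0, label_map.ndim), dtype=np.int64), np.empty(0, dtype=label_map.dtype)
    return np.concatenate(coords), np.concatenate(ids)

tissue_types = {0: "Background", 11: "Liver", 13: "Bone"}  # excerpt of TISSUE_TYPES
vol = np.zeros((10, 10, 10), dtype=np.uint8)
vol[2:5, 2:5, 2:5] = 11  # 27 liver voxels
vol[7:9, 7:9, 7:9] = 13  # 8 bone voxels
pts, lbl = sample_points(vol, {11: 20, 13: 100}, tissue_types, rng=np.random.default_rng(0))
print(pts.shape)  # (28, 3): 20 liver points + all 8 bone voxels
```

Capping each request at the voxel count mirrors why the preview sliders take their maximum from the per-label counts.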

In [5]:
# Define base factory for the pipeline
from pipeline.component_classes import UltrasoundRenderingMethod


_factory = CT2USPipelineFactory()

class CACTUSS(UltrasoundRenderingMethod):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.description = "CACTUSS ultrasound rendering method"
        self.tissue_types = TISSUE_TYPES

    def render(self, data):
        # Placeholder: CACTUSS rendering is not implemented yet
        raise NotImplementedError("Placeholder for CACTUSS rendering logic")

    # Static so the name can be read from the class or an instance
    @staticmethod
    def name():
        return "CACTUSS"
        
_factory.register_rendering_method(CACTUSS)

GPU 0 (NVIDIA GeForce RTX 4060 Laptop GPU) - Total: 8.00GB, Used: 0.00GB, Free: 8.00GB
Using GPU 0 for processing (free memory: 8.00GB)
Pipeline configured to use device: cuda


In [6]:
def process_ct_images(segmentation_method, rendering_method, step_size=1, **kwargs):
    """Process the CT images in `img_dir` using the CT2US pipeline with configurable parameters.

    Args:
        segmentation_method: Segmentation method to use (e.g. "TotalSegmentator" or "TotalSegmentatorFast")
        rendering_method: Rendering method to use (e.g. "LOTUS")
        step_size: Step size for slicing the volume
        **kwargs: Additional configuration parameters to pass to components

    Returns:
        Tuple containing:
        - List of label images
        - List of ultrasound images
        - List of warped labels
        - List of viewable label images
        - Dictionary with timing information
        - The pipeline's point cloud sampler (`pipeline.pcd_sampler`)
    """
    # Initialize dataset
    local_dataset = CTDataset(
        img_dir=str(img_dir),
        resample=None
    )
    
    # Create data loader
    ct_dataloader = DataLoader(
        local_dataset, 
        batch_size=1, 
        collate_fn=local_dataset.collate_fn
    )
    
    # Organize kwargs into component-specific configs
    # Extract parameters for each component type
    segmentation_config = {}
    rendering_config = {}
    pointcloud_config = {}
    
    # Binary operations parameters (used by both segmentation and rendering)
    for param in ['binary_dilation_iterations', 'binary_erosion_iterations', 'density_min', 'density_max']:
        if param in kwargs:
            segmentation_config[param] = kwargs[param]
            rendering_config[param] = kwargs[param]
    
    # Rendering-specific parameters
    for param in ['resize_size', 'crop_size']:
        if param in kwargs:
            rendering_config[param] = kwargs[param]
    
    # Point cloud-specific parameters (if any)
    if 'pointcloud_settings' in kwargs:
        pointcloud_config.update(kwargs['pointcloud_settings'])
    
    # Resolve a relative intermediate directory against the project folder
    intermediate_dir = kwargs.get('intermediate_dir', './intermediates')
    if isinstance(intermediate_dir, str) and not os.path.isabs(intermediate_dir):
        intermediate_dir = os.path.join(this_folder, intermediate_dir)
    
    # Create pipeline configuration
    pipeline_config = PipelineConfig(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        save_intermediates=kwargs.get('save_intermediates', False),
        intermediate_dir=intermediate_dir,
        segmentation_config=segmentation_config,
        rendering_config=rendering_config,
        pointcloud_config=pointcloud_config
    )
    
    # Set segmentation method
    pipeline_config.set_segmentator(segmentation_method)
    
    # Set rendering method
    pipeline_config.set_renderer(rendering_method)
    
    pipeline = _factory.create_pipeline(pipeline_config)
    
    # Process each batch of data
    labels = []
    us_images = []
    warped_labels = []
    viewable_labels = []
    timing_info = {}
    
    print("Processing data...")
    for data in tqdm.tqdm(ct_dataloader, desc="Processing batch"):
        imgs, properties, dest_labels, dest_us = data
        
        # Process with pipeline
        label_imgs, batch_us, batch_warped, batch_viewable, batch_timing = pipeline(
            imgs, properties, dest_labels, dest_us, step_size, False
        )
        
        # Store processed data
        labels.extend(label_imgs)
        us_images.extend(batch_us)
        warped_labels.extend(batch_warped)
        viewable_labels.extend(batch_viewable)
        timing_info.update(batch_timing)
    
    return labels, us_images, warped_labels, viewable_labels, timing_info, pipeline.pcd_sampler
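The parameter routing at the top of `process_ct_images` can be factored into a small pure function, which makes it testable without a GPU or any CT data. A minimal sketch that mirrors those loops; `split_component_kwargs` is a hypothetical name and `"max_points"` is an illustrative key only.

```python
from typing import Any, Dict, Tuple

# Parameters shared by segmentation and rendering, as in process_ct_images
SHARED_PARAMS = ("binary_dilation_iterations", "binary_erosion_iterations", "density_min", "density_max")
RENDERING_ONLY_PARAMS = ("resize_size", "crop_size")

def split_component_kwargs(**kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Route flat keyword arguments into (segmentation, rendering, pointcloud) configs."""
    segmentation = {k: kwargs[k] for k in SHARED_PARAMS if k in kwargs}
    rendering = dict(segmentation)  # shared params go to both components, as a separate copy
    rendering.update({k: kwargs[k] for k in RENDERING_ONLY_PARAMS if k in kwargs})
    pointcloud = dict(kwargs.get("pointcloud_settings", {}))
    return segmentation, rendering, pointcloud

seg, ren, pcd = split_component_kwargs(
    density_min=-200, crop_size=256, pointcloud_settings={"max_points": 1000}
)
```

Returning fresh dictionaries means a component that mutates its own config cannot leak changes into another component's config.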

# Gradio UI for interactive use

In [None]:
# Set environment for Gradio
os.environ["GRADIO_ALLOWED_PATHS"] = this_folder

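def update_license(x):
    # Keep the global and the environment variable read by TotalSegmentator in sync
    global license
    license = x
    os.environ["TS_LICENSE_KEY"] = x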

with gr.Blocks() as ct_2_us: 
    # Create state objects for storing data
    files = gr.State({})
    us_list = gr.State({})
    warped_list = gr.State({})
    label_list = gr.State({})
    pcdb_list = gr.State({})

    pcd_method_obj = gr.State({})

    img_idx = gr.State(0)
    slice_idx = gr.State(0)

    with gr.Row():
        with gr.Column(scale=1):
            # Input configuration
            gr.Markdown("Input Configuration")
            ct_imgs = gr.Files(
                file_types=['.nii', '.nii.gz'], 
                type='filepath', 
                label="Select CT images", 
                interactive=True, 
                file_count='multiple'
            )
            
            step_size = gr.Slider(
                label="Slicing step interval", 
                minimum=1, 
                maximum=20, 
                value=1, 
                step=1, 
                interactive=True
            )   

            license_key = gr.Textbox(
                label="License Key",
                placeholder="Enter your totalsegmentator license",
                value=license,
                interactive=True
            )
            

            license_key.change(fn=update_license, inputs=[license_key], outputs=None)

            # Advanced configuration (accessed via kwargs)
            with gr.Accordion("Advanced Configuration", open=False):
                binary_dilation_iterations = gr.Slider(
                    label="Binary Dilation Iterations", 
                    minimum=0, 
                    maximum=5, 
                    value=2, 
                    step=1, 
                    interactive=True
                )
                
                binary_erosion_iterations = gr.Slider(
                    label="Binary Erosion Iterations", 
                    minimum=0, 
                    maximum=5, 
                    value=3, 
                    step=1, 
                    interactive=True
                )
                
                density_min = gr.Slider(
                    label="Density Minimum", 
                    minimum=-500, 
                    maximum=0, 
                    value=-200, 
                    step=10, 
                    interactive=True
                )
                
                density_max = gr.Slider(
                    label="Density Maximum", 
                    minimum=0, 
                    maximum=500, 
                    value=250, 
                    step=10, 
                    interactive=True
                )
                
                resize_size = gr.Slider(
                    label="Resize Size", 
                    minimum=256, 
                    maximum=512, 
                    value=380, 
                    step=8, 
                    interactive=True
                )
                
                crop_size = gr.Slider(
                    label="Crop Size", 
                    minimum=128, 
                    maximum=384, 
                    value=256, 
                    step=8, 
                    interactive=True
                )
                
                save_intermediates = gr.Checkbox(
                    label="Save Intermediate Results", 
                    value=False, 
                    interactive=True
                )

            with gr.Row():
                btn = gr.Button("Generate", variant="primary")
                reset = gr.Button("Reset")

            # Define reset handler
            @gr.on([reset.click], inputs=None, 
                  outputs=[files, us_list, warped_list, label_list, pcdb_list])
            def reset_all():
                # Cleanup files
                for f in glob.glob(str(label_dir / '*.nii.gz')):
                    os.remove(f)
                for f in glob.glob(str(img_dir / '*.nii.gz')):
                    os.remove(f)
                for f in glob.glob(str(gen_dir / '*.glb')):
                    os.remove(f)
                for f in glob.glob(f"{str(us_dir)}/*"):
                    shutil.rmtree(f, ignore_errors=True)
                try:
                    os.remove(f"{this_folder}/results.zip")
                except FileNotFoundError:
                    print("No need to delete results")
                return {}, {}, {}, {}, {}

            with gr.Row():
                with gr.Column():
                    # Sample selection for demo
                    sample_in = gr.Dropdown(
                        choices=[i+1 for i in range(len(glob.glob(f"{this_folder}/sample/*.nii.gz")))], 
                        label='Amount of samples to randomly select',
                        info='Used for demo with no input',
                        value=1
                    )
                    
                    # Choose pipeline method
                    available_methods = _factory.config.methods["segmentation"].keys()
                    available_us = _factory.config.methods["rendering"].keys()

                    seg_method = gr.Radio(
                        choices=available_methods, 
                        value="TotalSegmentator", 
                        label="Segmentation method", 
                        interactive=True
                    )
                    
                    us_method = gr.Radio(
                        choices=available_us, 
                        value="LOTUS", 
                        label="US rendering method", 
                        interactive=True
                    )
        
        with gr.Column(scale=2):
            with gr.Tab(label='Pointcloud Settings', visible=False) as pcd_tab:
                @gr.render(inputs=[step_size], triggers=[us_list.change])
                def pcd_control(step):      
                    # Point cloud settings - only show if pipeline is available
                    gr.Markdown("### Point Cloud Settings")
                    
                    # Get tissue types with voxels
                    label_counts = _pcd_method.get_label_counts()
                    if label_counts and 0 in label_counts:
                        available_labels = sorted(label_counts[0].keys())
                        
                        # Create sliders for each tissue type
                        with gr.Row():
                            with gr.Column():
                                pcd_sliders = []
                                for label in available_labels:
                                    if label in TISSUE_TYPES and TISSUE_TYPES[label] != "Unused":
                                        # Get current point count (indexed by label id, as in update_point_cloud)
                                        if label < len(_pcd_method.points_per_label):
                                            current_value = _pcd_method.points_per_label[label]
                                        else:
                                            current_value = 0
                                            
                                        # Calculate max points
                                        max_points = min(label_counts[0].get(label, 0), 400000)
                                        
                                        # Skip labels with no points
                                        if max_points == 0:
                                            continue
                                        
                                        # Create slider
                                        slider = gr.Slider(
                                            label=f"{TISSUE_TYPES[label]} Points",
                                            minimum=0,
                                            maximum=max_points,
                                            value=current_value,
                                            step=1000,
                                            interactive=True
                                        )
                                        pcd_sliders.append((label, slider))
                                        
                                # Create update button
                                update_pcd = gr.Button("Resample Pointcloud")
                                
                                # Handle updates
                                def update_point_cloud(x, y, *slider_values):
                                    # Get current points per label
                                    new_counts = _pcd_method.points_per_label.copy()
                                    
                                    # Update with slider values
                                    for (label, _), value in zip(pcd_sliders, slider_values):
                                        if label < len(new_counts):
                                            new_counts[label] = int(value)
                                        
                                    # Update point cloud
                                    _pcd_method.update_points_per_label(new_counts)

                                    # Re-export
                                    pcdb_new = _pcd_method.sample(x)
                                    
                                    try:
                                        # Make sure to use a safe value for the slice index
                                        safe_slice = min(int(y * step), pcdb_new[2][2]-1) if len(pcdb_new) >= 3 and hasattr(pcdb_new[2], "__len__") and len(pcdb_new[2]) >= 3 else 0
                                        _pcd_method.add_axis_pcd(pcdb_new, safe_slice).export(str(gen_dir / "current_pcd.glb"))
                                    except Exception as e:
                                        print(f"Error creating axis point cloud: {e}")
                                        # Create a fallback point cloud if there's an error
                                        import trimesh as tri
                                        fallback_pcd = tri.PointCloud([[0, 0, 0]], colors=[[255, 255, 255, 255]])
                                        fallback_pcd.export(str(gen_dir / "current_pcd.glb"))
                                    
                                    # Return the resampled point cloud data to update the pcdb_list state
                                    return pcdb_new
                                
                                # Connect button to handler
                                if pcd_sliders:
                                    update_pcd.click(
                                        fn=update_point_cloud,
                                        inputs=[img_idx, slice_idx] + [s[1] for s in pcd_sliders],
                                        outputs=pcdb_list
                                    )
                                

            with gr.Tab(label='Preview'):
                note = gr.Markdown(value="Generate US images first through the input tab")

                # Dynamic UI based on available data
                @gr.render(inputs=[files, us_list, warped_list, label_list, step_size], 
                         triggers=[us_list.change])
                def dynamic(fl, us, warped, ll, step):       
                    with gr.Column():
                        if len(us) > 0:
                            # Image selection
                            dropdown = gr.Dropdown(
                                choices=[(f, n) for n, f in fl.items()], 
                                label='Select image to preview', 
                                value=0
                            )
                            
                            # Slice selection
                            slider = gr.Slider(
                                minimum=0, 
                                maximum=len(warped[0]) - 1, 
                                step=step, 
                                label='Slice selection', 
                                value=0
                            )
                            
                            # Identity function for state updates
                            iden = lambda x: x

                            slider.release(fn=iden, inputs=[slider], outputs=[slice_idx])
                            dropdown.select(fn=iden, inputs=[dropdown], outputs=[img_idx])

                            # Results display
                            with gr.Column():
                                # Top row: US and label images
                                with gr.Row():
                                    base = gr.Image(
                                        label='US slice',
                                        value=np.asarray(us[img_idx.value][slice_idx.value], dtype=np.float32) if len(us) > 0 and len(us[0]) > 0 else np.zeros((256, 256), dtype=np.float32),
                                        height=300
                                    )
                                    
                                    label_preview = gr.Image(
                                        label='Label slice',
                                        value=ll[img_idx.value][slice_idx.value] if len(ll) > 0 else None,
                                        type='pil',
                                        height=300
                                    )
                                
                                # Bottom row: Annotation and 3D view
                                with gr.Row():
                                    comp = gr.AnnotatedImage(
                                        value=(us[img_idx.value][slice_idx.value], warped[img_idx.value][slice_idx.value]),
                                        height=300
                                    )
                                    
                                    volume_preview = gr.Model3D(
                                        clear_color=(0, 0, 0, 1), 
                                        label="Label map view", 
                                        value=str(gen_dir / "current_pcd.glb") if os.path.exists(str(gen_dir / "current_pcd.glb")) else None, 
                                        height=300
                                    )
                                    
                                pcdb_list.change(fn=lambda : str(gen_dir / "current_pcd.glb"), outputs=volume_preview)
                                    
                                # Update function for image/slice selection
                                def route(x, y, pcdb):
                                    # Fall back to the first slice if the index is out of range for this image
                                    new_y = y if y < len(us[x]) else 0

                                    b = us[x][new_y]
                                    w = warped[x][new_y]
                                    l = ll[x][new_y]

                                    # Resample the point cloud for the newly selected image
                                    pcdb = _pcd_method.sample(x)

                                    # Make sure to use a safe value for the slice index
                                    safe_slice = min(int(new_y * step), pcdb[2][2]-1) if len(pcdb) >= 3 and hasattr(pcdb[2], "__len__") and len(pcdb[2]) >= 3 else 0
                                    _pcd_method.add_axis_pcd(pcdb, safe_slice).export(str(gen_dir / "current_pcd.glb"))

                                    # Rebuild the slider so its range matches the selected image
                                    new_slider = gr.Slider(
                                        minimum=0,
                                        maximum=len(warped[x]) - 1,
                                        step=step,
                                        label='Slice selection',
                                        value=new_y
                                    )

                                    # Return updated UI components
                                    return (b, w), b, l, str(gen_dir / "current_pcd.glb"), new_y, new_slider, pcdb
                                    
                                def route_y(x, y, pcdb):
                                    b = us[x][y]
                                    w = warped[x][y]
                                    l = ll[x][y]

                                    # Adjust the currently highlighted slice
                                    try:
                                        # Make sure to use a safe value for the slice index
                                        safe_slice = min(int(y * step), pcdb[2][2]-1) if len(pcdb) >= 3 and hasattr(pcdb[2], "__len__") and len(pcdb[2]) >= 3 else 0
                                        _pcd_method.add_axis_pcd(pcdb, safe_slice).export(str(gen_dir / "current_pcd.glb"))
                                    except Exception as e:
                                        print(f"Error creating axis point cloud: {e}")
                                        # Create a fallback point cloud if there's an error
                                        import trimesh as tri
                                        fallback_pcd = tri.PointCloud([[0, 0, 0]], colors=[[255, 255, 255, 255]])
                                        fallback_pcd.export(str(gen_dir / "current_pcd.glb"))

                                    # Return updated UI components
                                    return (b, w), b, l, str(gen_dir / "current_pcd.glb")
                                
                                # Connect route function to UI events
                                gr.on(
                                    triggers=[img_idx.change],
                                    fn=route,
                                    inputs=[img_idx, slice_idx, pcdb_list],
                                    outputs=[comp, base, label_preview, volume_preview, slice_idx, slider, pcdb_list]
                                )
                                gr.on(
                                    triggers=[slice_idx.change],
                                    fn=route_y,
                                    inputs=[img_idx, slice_idx, pcdb_list],
                                    outputs=[comp, base, label_preview, volume_preview]
                                )
                            
            with gr.Tab(label='Download'):
                download = gr.DownloadButton(label="", visible=False)
                
                # Dynamic UI for download options
                @gr.render(inputs=[files, us_list, warped_list, label_list, pcdb_list, step_size], 
                         triggers=[us_list.change])
                def download_controls(fl, us, warped, ll, pcdb, step):
                    descr = gr.Markdown("This can be used to adjust the contents of the results zip.")
                    configs = gr.CheckboxGroup(
                        choices=["Save labels", "Save US images"], 
                        value=["Save labels", "Save US images"], 
                        label="Options", 
                        interactive=True
                    )
                    filename_in = gr.Textbox(label="Filename for result zip", value="results")

                    r = []
                    r.append(glob.glob(str(label_dir / '*.nii.gz')))
                    r.append(glob.glob(str(us_dir / '**/*.png')))

                    rezip = gr.Button("Reassemble results.zip")

                    @gr.on(rezip.click, inputs=[configs, filename_in], outputs=[download, descr])
                    def rezip_files(save_configs, name, r=r):
                        # Write the archive under the chosen name so the returned path exists
                        zip_path = f"{this_folder}/{name}.zip"
                        with ZipFile(zip_path, 'w') as zipObj:
                            # Save labels if selected
                            if "Save labels" in save_configs:
                                for f in r[0]:
                                    zipObj.write(f, os.path.relpath(f, this_folder))
                                    # Don't delete labels until we're done

                            # Save US images if selected
                            if "Save US images" in save_configs:
                                for f in r[1]:
                                    zipObj.write(f, os.path.relpath(f, this_folder))

                            # Clean up if requested
                            if "Clean up after export" in save_configs:
                                for f in r[0]:
                                    os.remove(f)
                                for f in glob.glob(f"{str(us_dir)}/*"):
                                    shutil.rmtree(f, ignore_errors=True)

                        return zip_path, "Results have been rezipped"

                def start(ct, step, method, method_us, fl_s, us_s, warped_s, ll_s, pcdb_s, nr_samples,
                         binary_dilation_iters, binary_erosion_iters, dens_min, dens_max, resize, crop, save_int,
                         progress=gr.Progress(track_tqdm=True)):                    
                    # Fall back to the bundled sample volumes if nothing was uploaded
                    if not ct:
                        ct = glob.glob(f"{this_folder}/sample/*.nii.gz")

                    # Randomly subsample if more volumes than requested were provided
                    if len(ct) > nr_samples:
                        ct = random.sample(ct, k=int(nr_samples))

                    # Copy files to working directory
                    for f in ct:
                        shutil.copyfile(f, str(img_dir / os.path.basename(f)))
                    
                    # Organize parameters into component-specific configs;
                    # rendering reuses the segmentation parameters and adds resize/crop sizes
                    segmentation_config = {
                        'binary_dilation_iterations': binary_dilation_iters,
                        'binary_erosion_iterations': binary_erosion_iters,
                        'density_min': dens_min,
                        'density_max': dens_max
                    }

                    rendering_config = {
                        **segmentation_config,
                        'resize_size': resize,
                        'crop_size': crop
                    }
                    
                    # Process images using the pipeline with all configuration parameters
                    labels, us_images, warped_labels, viewable_labels, timing_info, sampler = process_ct_images(
                        ct_images=ct,
                        step_size=step,
                        segmentation_method=method,
                        rendering_method=method_us,
                        save_intermediates=save_int,
                        segmentation_config=segmentation_config,
                        rendering_config=rendering_config
                    )
                    
                    # Get point cloud sampler for later adjustment
                    global _pcd_method
                    _pcd_method = sampler
                        
                    # Update state
                    fl_s.update(enumerate(labels))
                    us_s.update(enumerate(us_images))
                    ll_s.update(enumerate(viewable_labels))
                    warped_s.update(enumerate(warped_labels))
                    
                    # Sample point clouds and save initial 3D view
                    pcdb_s = _pcd_method.sample(0)
                    
                    try:
                        # Create initial view with slice 0
                        _pcd_method.add_axis_pcd(pcdb_s, 0).export(str(gen_dir / "current_pcd.glb"))
                    except Exception as e:
                        print(f"Error creating initial point cloud: {e}")
                        # Create a fallback point cloud if there's an error
                        import trimesh as tri
                        fallback_pcd = tri.PointCloud([[0, 0, 0]], colors=[[255, 255, 255, 255]])
                        fallback_pcd.export(str(gen_dir / "current_pcd.glb"))

                    # Return downloadable zip and updated states
                    return (
                        gr.DownloadButton(label="Download results as zip", visible=True, value=f"{this_folder}/results.zip"), 
                        fl_s, 
                        us_s, 
                        warped_s, 
                        ll_s, 
                        pcdb_s,
                        gr.Markdown(value="Processing complete!", height=30)
                    )
                
                def finalize(x):
                    return gr.Markdown(label="", value="", height=0, visible=False), gr.Tab(label="Pointcloud Settings", visible=True)
                                
                # Connect generate button
                btn.click(
                    fn=reset_all,
                    inputs=None, 
                    outputs=[files, us_list, warped_list, label_list, pcdb_list]
                ).success(
                    fn=lambda x: gr.Markdown(label="Status", value="Processing...", height=80), 
                    inputs=btn, 
                    outputs=note
                ).success(
                    fn=start, 
                    inputs=[
                        ct_imgs, step_size, seg_method, us_method, files, us_list, warped_list, 
                        label_list, pcdb_list, sample_in, binary_dilation_iterations, 
                        binary_erosion_iterations, density_min, density_max, resize_size, 
                        crop_size, save_intermediates
                    ],
                    outputs=[download, files, us_list, warped_list, label_list, pcdb_list, note]
                ).success(
                    fn=finalize,
                    inputs=btn,
                    outputs=[note, pcd_tab]
                )                           

# Launch the Gradio app
ct_2_us.launch(debug=True)

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


No need to delete results
Pipeline configured to use device: cuda
Processing data...
SEGMENTATING:
CUDA detected: True
Cleared CUDA cache
Processing image with TotalSegmentator (on device: cuda:0) for task: total...

If you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024

Resampling...
  Resampled in 1.60s
Predicting part 1 of 5 ...
Predicting part 2 of 5 ...
Predicting part 3 of 5 ...
Predicting part 4 of 5 ...
Predicting part 5 of 5 ...
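The archive assembly done by `rezip_files` in the Download tab can also be written as a plain function, which makes it easy to test outside Gradio. The following is a minimal sketch, assuming the label maps and US slices are already on disk under one common base folder (as `label_dir` and `us_dir` are under `this_folder`). `build_results_zip` is a hypothetical helper, not part of this repository. It stores member names relative to the base folder and returns exactly the path it wrote.

```python
import os
import tempfile
import zipfile
from pathlib import Path


def build_results_zip(base_dir, name, label_files, us_files,
                      save_labels=True, save_us=True):
    """Hypothetical helper: bundle selected result files into <base_dir>/<name>.zip.

    Member names are stored relative to base_dir, mirroring the
    os.path.relpath(f, this_folder) calls in rezip_files.
    """
    base_dir = Path(base_dir)
    # Fall back to "results" if the requested name is empty or whitespace
    out_path = base_dir / f"{(name or '').strip() or 'results'}.zip"

    selected = []
    if save_labels:
        selected.extend(label_files)
    if save_us:
        selected.extend(us_files)

    with zipfile.ZipFile(out_path, "w") as zf:
        for f in selected:
            zf.write(f, os.path.relpath(f, base_dir))

    # Return the path that was actually written, so a download link built
    # from it cannot point at a stale or missing archive.
    return str(out_path)


# Usage: build archives from a throwaway directory tree with placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    (base / "labels").mkdir()
    (base / "us" / "case0").mkdir(parents=True)
    label = base / "labels" / "case0.nii.gz"
    slice_png = base / "us" / "case0" / "slice_000.png"
    label.write_bytes(b"placeholder label map")
    slice_png.write_bytes(b"placeholder png")

    # Labels only
    out = build_results_zip(tmp, "results", [str(label)], [str(slice_png)],
                            save_us=False)
    with zipfile.ZipFile(out) as zf:
        labels_only = zf.namelist()

    # Labels and US slices
    out_all = build_results_zip(tmp, "everything", [str(label)], [str(slice_png)])
    with zipfile.ZipFile(out_all) as zf:
        everything = sorted(zf.namelist())

print(labels_only)  # ['labels/case0.nii.gz']
print(everything)   # ['labels/case0.nii.gz', 'us/case0/slice_000.png']
```

Since `.nii.gz` volumes and `.png` slices are already compressed, the sketch keeps `ZipFile`'s default stored (uncompressed) mode, as the notebook does.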
