<a href="https://colab.research.google.com/github/lczamprogno/ct2us/blob/main/CT2US.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CT2US

This tool automates the generation of simulated ultrasound image and label pairs from CT volumes (.nii/.nii.gz).

---

## Purpose
The generated image and label pairs are intended to supplement datasets for ultrasound image labeling.

## Expandability
Image generation depends heavily on tissue attenuation, so specialized US renderers would be necessary (or at least ideal) to extend this tool to other body parts. For this reason, much of the following code is designed with modularity as a core goal, so that individual methods can be added or replaced. For example, segmentation quality or speed could have a significant impact on overall results.
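
The modular design can be illustrated with a small registry that maps a method name to a callable, similar in spirit to how the `CT2US` class further down picks its `segmentator` and `composer` from dictionaries. This is only a sketch: the two segmentation functions and the `SEGMENTERS` registry are hypothetical stand-ins, not part of this notebook.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins for real segmentation back-ends.
def segment_threshold(volume: List[int]) -> List[int]:
    """Label every positive voxel as foreground (1)."""
    return [1 if v > 0 else 0 for v in volume]

def segment_all_background(volume: List[int]) -> List[int]:
    """Label every voxel as background (0)."""
    return [0 for _ in volume]

# Registry: adding a new method only requires a new entry here.
SEGMENTERS: Dict[str, Callable[[List[int]], List[int]]] = {
    "threshold": segment_threshold,
    "background": segment_all_background,
}

def pick_segmenter(method: str) -> Callable[[List[int]], List[int]]:
    """Look up a segmentation callable by name, failing loudly on unknown names."""
    if method not in SEGMENTERS:
        raise KeyError(f"Method not supported, choose from {sorted(SEGMENTERS)}")
    return SEGMENTERS[method]

labels = pick_segmenter("threshold")([-5, 0, 40])
print(labels)
```

Swapping in a faster or more accurate back-end then only touches the registry, not the calling code.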

---

## Current use:
- ![example](../assets/Full%20Demo.gif)
  

## Further goals:
- Code for two alternative optimized segmentation pipelines is still being developed:
  - one focuses on avoiding internal TotalSegmentator steps being saved to memory [ ]
  - the other optimizes further by properly using GPU and CPU acceleration. [ ]

- The intermediate results produced while acquiring the information needed for rendering (and even for visualization) could be a useful resource. With that in mind, saving these intermediate results is a goal, albeit one that has yet to be fully achieved. The download tab is intended for adjusting which information is saved. [ ]

- An improved version of the TotalSegmentator nnU-Net is still WIP. Once that is done, plugging it into the pipeline together with the stacked assemble should yield a significant speed-up. [ ]

- There is currently some support for CPU, but this script is mainly designed with GPU acceleration in mind. For CPU optimization, the label_map dict needs to be used to gather the names of all labels it contains. This list can then be passed to the totalsegmentator function as a parameter, and the built-in ROI feature should yield faster results. [ ]

- Using the ROI feature brought no significant improvement in the GPU version. Note that allocating GPU memory and copying data over carries a significant cost; reducing it is one of the main improvements of the new TotalSegmentator pipeline implemented here. [ ]
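
A minimal sketch of the CPU idea from the list above: collect every class name that matches one of the name fragments used for the rendering labels, then hand that list to TotalSegmentator. The fragment dict is a subset of `name2label["total"]` defined later in this notebook; `CLASS_NAMES` is a short excerpt chosen for illustration, and `collect_roi_names` is a hypothetical helper.

```python
# Subset of the name fragments used per rendering label (see name2label["total"]).
NAME_FRAGMENTS = {
    "2": ["lung"],
    "6": ["kidney"],
    "11": ["liver"],
}

# Short illustrative excerpt of segmentation class names, not the full list.
CLASS_NAMES = [
    "spleen",
    "kidney_right",
    "kidney_left",
    "liver",
    "lung_upper_lobe_left",
    "aorta",
]

def collect_roi_names(fragments, class_names):
    """Return the class names containing at least one of the wanted fragments."""
    wanted = [frag for frags in fragments.values() for frag in frags]
    return [name for name in class_names if any(frag in name for frag in wanted)]

roi = collect_roi_names(NAME_FRAGMENTS, CLASS_NAMES)
print(roi)

# The list could then be passed on, for example (not executed here):
# ts.totalsegmentator(input=img, task="total", roi_subset=roi)
```

The order of `roi` follows the order of `CLASS_NAMES`, so classes that match no fragment (here `spleen` and `aorta`) are simply left out.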

---


The following cell needs to be run once, after which the session needs to be restarted.

In [None]:
%pip install colab
%pip install totalsegmentator numba cupy-cuda12x torchvision xmltodict torchio cucim "bokeh>=3.1.0" di gradio pathlib trimesh[easy]

#numpy-stl

# Classes and methods are here

## Run this block

IMPORTANT: Acquire a TotalSegmentator license key (https://backend.totalsegmentator.com/license-academic/) and store it as a Google Colab secret named `license_key`, as shown:

![a](../assets/secret.png)
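
When the notebook runs outside Colab, the secret store is not available. One possible way to handle both cases is sketched below. The environment variable name `TOTALSEG_LICENSE_KEY` and the helper `read_license_key` are assumptions made for this example; only `userdata.get('license_key')` is taken from the cell that follows.

```python
import os

def read_license_key(secret_name="license_key", env_var="TOTALSEG_LICENSE_KEY"):
    """Read the key from Colab secrets, or from an environment variable elsewhere.

    Returns None when neither source provides a key.
    """
    try:
        from google.colab import userdata  # only importable inside Colab
    except ImportError:
        # Assumed fallback: a user-defined environment variable.
        return os.environ.get(env_var)
    return userdata.get(secret_name)

# Example usage (the result depends on the environment):
# key = read_license_key()
```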

In [6]:
import os
import sys

from pathlib import PosixPath as pthlib
from zipfile import ZipFile
import random

import matplotlib.pyplot as plt
from math import pi
import tqdm

from itertools import islice
import glob
import shutil

import numpy as np
from numpy import uint8


from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
from nibabel import nifti1

import totalsegmentator.python_api as ts
from totalsegmentator.config import setup_nnunet, setup_totalseg, set_config_key, get_weights_dir

this_folder = pthlib("../CT2US").resolve()

sys.path.append(str(this_folder))  # sys.path entries must be strings, not Path objects
ts_cfg_path = this_folder / ".totalsegmentator"
ts_cfg_path.mkdir(exist_ok=True, parents=True)
os.environ["TOTALSEG_HOME_DIR"] = str(ts_cfg_path)

try: 
    from google.colab import userdata
    license = userdata.get('license_key')
except ImportError as e:
    # Not running in Colab, so no secret store is available.
    print(e)
    license = None

from totalsegmentator.libs import download_model_with_license_and_unpack, download_url_and_unpack
from totalsegmentator import resampling as rs
from totalsegmentator.map_to_binary import commercial_models

setup_nnunet()
setup_totalseg()

# Skip registration when no key was read (e.g. outside Colab).
if license is not None:
    ts.set_license_number(license)

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from nnunetv2.inference.sliding_window_prediction import compute_gaussian

from cucim.skimage.transform import resize
from batchgenerators.utilities.file_and_folder_operations import join

from numba import jit, njit, cuda

try:
    import cupy as cp
    import cupyx.scipy.ndimage as cusci
except ImportError:
    print("Error loading cupy and cusci, GPU not available?")

import scipy.ndimage
import scipy

from torchvision import transforms
import torch
from torch import device

from torch.utils.data import DataLoader

import gradio as gr
import trimesh as tri
from PIL import Image

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
torch.set_default_device(device)
print(device)

No module named 'google'
cuda:0


US slice simulation code (from https://github.com/danivelikova/lotus/blob/main/models/us_rendering_model.py)
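
The `rendering` method in the cell below accumulates attenuation along each ray as a cumulative product of `exp(-alpha * distance)` and then multiplies it by a linear time gain compensation (TGC) ramp from 1 to `TGC`. The following plain-Python sketch reproduces that step for a single ray; the function name and the sample values are chosen for illustration only.

```python
import math

def attenuation_with_tgc(alphas, step, tgc):
    """Cumulative attenuation along one ray, multiplied by a linear gain ramp.

    alphas: attenuation coefficient at each sample along the ray
    step:   distance between two consecutive samples
    tgc:    gain applied at the deepest sample (the shallowest gets 1.0)
    """
    n = len(alphas)
    out = []
    running = 1.0
    for i, alpha in enumerate(alphas):
        running *= math.exp(-alpha * step)  # cumulative product along depth
        gain = (1.0 + (tgc - 1.0) * i / (n - 1)) if n > 1 else 1.0
        out.append(running * gain)
    return out

# Five samples with the default attenuation value listed for liver (0.7), TGC = 8.
profile = attenuation_with_tgc([0.7] * 5, step=1.0, tgc=8)
print(profile)
```

The gain ramp partly compensates for the exponential decay, so deeper samples remain visible in the rendered slice.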

In [7]:
# 2 - lung; 3 - fat; 4 - vessel; 6 - kidney; 8 - muscle; 9 - background; 11 - liver; 12 - soft tissue; 13 - bone;
# Default Parameters from: https://github.com/Blito/burgercpp/blob/master/examples/ircad11/liver.scene , labels 8, 9 and 12 approximated from other labels

                     # indexes:           2       3     4     6     8      9     11    12    13
acoustic_imped_def_dict = torch.tensor([0.0004, 1.38, 1.61,  1.62, 1.62,  0.3,  1.65, 1.63, 7.8], requires_grad=True).to(device=device)    # Z in MRayl
attenuation_def_dict =    torch.tensor([1.64,   0.63, 0.18,  1.0,  1.09, 0.54,  0.7,  0.54, 5.0], requires_grad=True).to(device=device)    # alpha in dB cm^-1 at 1 MHz
mu_0_def_dict =           torch.tensor([0.78,   0.5,  0.001, 0.45,  0.45,  0.3,  0.4, 0.45, 0.78], requires_grad=True).to(device=device) # mu_0 - scattering_mu   mean brightness
mu_1_def_dict =           torch.tensor([0.56,   0.5,  0.0,   0.6,  0.64,  0.2,  0.8,  0.64, 0.56], requires_grad=True).to(device=device) # mu_1 - scattering density, Nr of scatterers/voxel
sigma_0_def_dict =        torch.tensor([0.1,    0.0,  0.01,  0.3,  0.1,   0.0,  0.14, 0.1,  0.1], requires_grad=True).to(device=device) # sigma_0 - scattering_sigma - brightness std


alpha_coeff_boundary_map = 0.1
beta_coeff_scattering = 10  #100 approximates it closer
TGC = 8
CLAMP_VALS = True


def gaussian_kernel(size: int, mean: float, std: float):
    d1 = torch.distributions.Normal(mean, std)
    d2 = torch.distributions.Normal(mean, std*3)
    vals_x = d1.log_prob(torch.arange(-size, size+1, dtype=torch.float32)).exp()
    vals_y = d2.log_prob(torch.arange(-size, size+1, dtype=torch.float32)).exp()

    gauss_kernel = torch.einsum('i,j->ij', vals_x, vals_y)

    return gauss_kernel / torch.sum(gauss_kernel).reshape(1, 1)

g_kernel = gaussian_kernel(3, 0., 0.5)
# Add batch and channel dimensions; g_kernel is already a tensor, so it is not re-wrapped in torch.tensor()
g_kernel = g_kernel[None, None, :, :].to(dtype=torch.float32, device=device)


class UltrasoundRendering(torch.nn.Module):
    def __init__(self, params, default_param=False):
        super(UltrasoundRendering, self).__init__()
        self.params = params

        if default_param:
            self.acoustic_impedance_dict = acoustic_imped_def_dict.detach().clone()
            self.attenuation_dict = attenuation_def_dict.detach().clone()
            self.mu_0_dict = mu_0_def_dict.detach().clone()
            self.mu_1_dict = mu_1_def_dict.detach().clone()
            self.sigma_0_dict = sigma_0_def_dict.detach().clone()

        else:
            self.acoustic_impedance_dict = torch.nn.Parameter(acoustic_imped_def_dict)
            self.attenuation_dict = torch.nn.Parameter(attenuation_def_dict)

            self.mu_0_dict = torch.nn.Parameter(mu_0_def_dict)
            self.mu_1_dict = torch.nn.Parameter(mu_1_def_dict)
            self.sigma_0_dict = torch.nn.Parameter(sigma_0_def_dict)

        self.labels = ["lung", "fat", "vessel", "kidney", "muscle", "background", "liver", "soft tissue", "bone"]

        self.attenuation_medium_map, self.acoustic_imped_map, self.sigma_0_map, self.mu_1_map, self.mu_0_map  = ([] for i in range(5))


    def map_dict_to_array(self, dictionary, arr):
        mapping_keys = torch.tensor([2, 3, 4, 6, 8, 9, 11, 12, 13], dtype=torch.long).to(device=device)
        keys = torch.unique(arr).to(device=device)

        index = torch.where(mapping_keys[None, :] == keys[:, None])[1]
        values = torch.gather(dictionary, dim=0, index=index)
        values = values.to(device=device)
        # values.register_hook(lambda grad: print(grad))    # check the gradient during training

        mapping = torch.zeros(keys.max().item() + 1).to(device=device)
        mapping[keys] = values
        return mapping[arr]


    def plot_fig(self, fig, fig_name, grayscale):
        save_dir='results_test/'
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        plt.clf()

        if torch.is_tensor(fig):
            fig = fig.cpu().detach().numpy()

        if grayscale:
            plt.imshow(fig, cmap='gray', vmin=0, vmax=1, interpolation='none', norm=None)
        else:
            plt.imshow(fig, interpolation='none', norm=None)
        plt.axis('off')
        plt.savefig(save_dir + fig_name + '.png', bbox_inches='tight',transparent=True, pad_inches=0)


    def clamp_map_ranges(self):
        self.attenuation_medium_map = torch.clamp(self.attenuation_medium_map, 0, 10)
        self.acoustic_imped_map = torch.clamp(self.acoustic_imped_map, 0, 10)
        self.sigma_0_map = torch.clamp(self.sigma_0_map, 0, 1)
        self.mu_1_map = torch.clamp(self.mu_1_map, 0, 1)
        self.mu_0_map = torch.clamp(self.mu_0_map, 0, 1)


    def rendering(self, H, W, z_vals=None, refl_map=None, boundary_map=None):

        dists = torch.abs(z_vals[..., :-1, None] - z_vals[..., 1:, None])     # dists.shape=(W, H-1, 1)
        dists = dists.squeeze(-1)                                             # dists.shape=(W, H-1)
        dists = torch.cat([dists, dists[:, -1, None]], dim=-1)                # dists.shape=(W, H)

        attenuation = torch.exp(-self.attenuation_medium_map * dists)
        attenuation_total = torch.cumprod(attenuation, dim=1, dtype=torch.float32, out=None)

        gain_coeffs = np.linspace(1, TGC, attenuation_total.shape[1])
        gain_coeffs = np.tile(gain_coeffs, (attenuation_total.shape[0], 1))
        gain_coeffs = torch.tensor(gain_coeffs).to(device=device)
        attenuation_total = attenuation_total * gain_coeffs     # apply TGC

        reflection_total = torch.cumprod(1. - refl_map * boundary_map, dim=1, dtype=torch.float32, out=None)
        reflection_total = reflection_total.squeeze(-1)
        reflection_total_plot = torch.log(reflection_total + torch.finfo(torch.float32).eps)

        texture_noise = torch.randn(H, W, dtype=torch.float32).to(device=device)
        scattering_probability = torch.randn(H, W, dtype=torch.float32).to(device=device)

        scattering_zero = torch.zeros(H, W, dtype=torch.float32).to(device=device)

        z = self.mu_1_map - scattering_probability
        sigmoid_map = torch.sigmoid(beta_coeff_scattering * z)

        # approximating  Eq. (4) to be differentiable:
        # where(scattering_probability <= mu_1_map,
        #                     texture_noise * sigma_0_map + mu_0_map,
        #                     scattering_zero)
        scatterers_map =  (sigmoid_map) * (texture_noise * self.sigma_0_map + self.mu_0_map) + (1 -sigmoid_map) * scattering_zero   # Eq. (6)

        psf_scatter_conv = torch.nn.functional.conv2d(input=scatterers_map[None, None, :, :], weight=g_kernel, stride=1, padding="same")
        psf_scatter_conv = psf_scatter_conv.squeeze()

        b = attenuation_total * psf_scatter_conv    # Eq. (3)

        border_convolution = torch.nn.functional.conv2d(input=boundary_map[None, None, :, :], weight=g_kernel, stride=1, padding="same")
        border_convolution = border_convolution.squeeze()

        r = attenuation_total * reflection_total * refl_map * border_convolution # Eq. (2)

        intensity_map = b + r   # Eq. (1)
        intensity_map = intensity_map.squeeze()
        intensity_map = torch.clamp(intensity_map, 0, 1)

        return intensity_map, attenuation_total, reflection_total_plot, scatterers_map, scattering_probability, border_convolution, texture_noise, b, r


    def render_rays(self, W, H):
        N_rays = W
        t_vals = torch.linspace(0., 1., H).to(device=device)   # 0-1 linearly spaced, shape H
        z_vals = t_vals.unsqueeze(0).expand(N_rays , -1) * 4

        return z_vals

    # warp the linear US image to approximate US image from curvilinear US probe
    def warp_img(self, inputImage):
        resultWidth = 360
        resultHeight = 220
        centerX = resultWidth / 2
        centerY = -120.0
        maxAngle =  60.0 / 2 / 180 * pi #rad
        minAngle = -maxAngle
        minRadius = 140.0
        maxRadius = 340.0

        h, w = inputImage.squeeze().shape

        import torch.nn.functional as F

        # Create x and y grids
        x = torch.arange(resultWidth).float() - centerX
        y = torch.arange(resultHeight).float() - centerY
        xx, yy = torch.meshgrid(x, y, indexing="ij")  # "ij" is the previous default, made explicit

        # Calculate angle and radius
        angle = torch.atan2(xx, yy)
        radius = torch.sqrt(xx ** 2 + yy ** 2)

        # Create masks for angle and radius
        angle_mask = (angle > minAngle) & (angle < maxAngle)
        radius_mask = (radius > minRadius) & (radius < maxRadius)

        # Calculate original column and row
        origCol = (angle - minAngle) / (maxAngle - minAngle) * w
        origRow = (radius - minRadius) / (maxRadius - minRadius) * h

        # Reshape input image to be a batch of 1 image
        inputImage = inputImage.float().unsqueeze(0).unsqueeze(0)

        # Scale original column and row to be in the range [-1, 1]
        origCol = origCol / (w - 1) * 2 - 1
        origRow = origRow / (h - 1) * 2 - 1

        # Transpose input image to have channels first
        inputImage = inputImage.permute(0, 1, 3, 2)

        # Use grid_sample to interpolate
        grid = torch.stack([origCol, origRow], dim=-1).unsqueeze(0).to(device)
        resultImage = F.grid_sample(inputImage, grid, mode='bilinear', align_corners=True)

        # Apply masks and set values outside of mask to 0
        resultImage[~(angle_mask.unsqueeze(0).unsqueeze(0) & radius_mask.unsqueeze(0).unsqueeze(0))] = 0.0
        resultImage_resized = transforms.Resize((256,256))(resultImage).float().squeeze()

        return resultImage_resized


    def forward(self, ct_slice):
        if self.params["debug"]: self.plot_fig(ct_slice, "ct_slice", False)

        #init tissue maps
        #generate 2D acoustic_imped map
        self.acoustic_imped_map = self.map_dict_to_array(self.acoustic_impedance_dict, ct_slice)#.astype('int64'))

        #generate 2D attenuation map
        self.attenuation_medium_map = self.map_dict_to_array(self.attenuation_dict, ct_slice)

        if self.params["debug"]:
            self.plot_fig(self.acoustic_imped_map, "acoustic_imped_map", False)
            self.plot_fig(self.attenuation_medium_map, "attenuation_medium_map", False)

        self.mu_0_map = self.map_dict_to_array(self.mu_0_dict, ct_slice)

        self.mu_1_map = self.map_dict_to_array(self.mu_1_dict, ct_slice)

        self.sigma_0_map = self.map_dict_to_array(self.sigma_0_dict, ct_slice)

        self.acoustic_imped_map = torch.rot90(self.acoustic_imped_map, 1, [0, 1])
        diff_arr = torch.diff(self.acoustic_imped_map, dim=0)

        diff_arr = torch.cat((torch.zeros(diff_arr.shape[1], dtype=torch.float32).unsqueeze(0).to(device=device), diff_arr))

        boundary_map =  -torch.exp(-(diff_arr**2)/alpha_coeff_boundary_map) + 1

        boundary_map = torch.rot90(boundary_map, 3, [0, 1])

        if self.params["debug"]:
           self.plot_fig(diff_arr, "diff_arr", False)
           self.plot_fig(boundary_map, "boundary_map", True)

        shifted_arr = torch.roll(self.acoustic_imped_map, -1, dims=0)
        shifted_arr[-1:] = 0

        sum_arr = self.acoustic_imped_map + shifted_arr
        sum_arr[sum_arr == 0] = 1
        div = diff_arr / sum_arr

        refl_map = div ** 2
        refl_map = torch.sigmoid(refl_map)      # 1 / (1 + (-refl_map).exp())
        refl_map = torch.rot90(refl_map, 3, [0, 1])

        if self.params["debug"]: self.plot_fig(refl_map, "refl_map", True)

        z_vals = self.render_rays(ct_slice.shape[0], ct_slice.shape[1])

        if CLAMP_VALS:
            self.clamp_map_ranges()

        ret_list = self.rendering(ct_slice.shape[0], ct_slice.shape[1], z_vals=z_vals, refl_map=refl_map, boundary_map=boundary_map)

        intensity_map  = ret_list[0]

        if self.params["debug"]:
            self.plot_fig(intensity_map, "intensity_map", True)

            result_list = ["intensity_map", "attenuation_total", "reflection_total",
                            "scatters_map", "scattering_probability", "border_convolution",
                            "texture_noise", "b", "r"]

            for k in range(len(ret_list)):
                result_np = ret_list[k]
                if torch.is_tensor(result_np):
                    result_np = result_np.detach().cpu().numpy()

                if k==2:
                    self.plot_fig(result_np, result_list[k], False)
                else:
                    self.plot_fig(result_np, result_list[k], True)
                # print(result_list[k], ", ", result_np.shape)

        intensity_map_masked = self.warp_img(intensity_map)
        intensity_map_masked = torch.rot90(intensity_map_masked, 3)

        if self.params["debug"]:  self.plot_fig(intensity_map_masked, "intensity_map_masked", True)

        return intensity_map_masked


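
The scatterer map in `rendering` replaces a hard threshold (`scattering_probability <= mu_1_map`) with a sigmoid so that the expression stays differentiable. The scalar sketch below compares both gates; it also illustrates the remark next to `beta_coeff_scattering` that a value of 100 approximates the hard threshold more closely than 10. The function names are made up for this example.

```python
import math

def hard_gate(mu_1, p):
    """Original, non-differentiable rule: scatter only if p <= mu_1."""
    return 1.0 if p <= mu_1 else 0.0

def soft_gate(mu_1, p, beta):
    """Differentiable approximation: sigmoid(beta * (mu_1 - p))."""
    return 1.0 / (1.0 + math.exp(-beta * (mu_1 - p)))

mu_1 = 0.6
for p in (0.4, 0.8):
    print(p, hard_gate(mu_1, p), soft_gate(mu_1, p, 10), soft_gate(mu_1, p, 100))
```

A larger `beta` sharpens the transition, at the price of smaller gradients away from the threshold.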


Segmentation, Composition and US slicing code
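
The `total_lmap` dict in the cell below maps each segmentation class index (stored as a string key) to one of the rendering labels. A small lookup-table sketch shows how such a dict can be applied to label values. It uses only the first six entries of `total_lmap`, and `build_lut` and `remap` are hypothetical helpers, not functions of this notebook.

```python
# First six entries of total_lmap: class index (as str) -> rendering label.
LMAP_EXCERPT = {"0": 0, "1": 12, "2": 6, "3": 6, "4": 8, "5": 11}

def build_lut(lmap):
    """Turn the string-keyed dict into a list indexed by class index."""
    lut = [0] * (max(int(k) for k in lmap) + 1)
    for key, value in lmap.items():
        lut[int(key)] = value
    return lut

def remap(values, lut):
    """Replace every class index with its rendering label."""
    return [lut[v] for v in values]

lut = build_lut(LMAP_EXCERPT)
print(remap([0, 1, 5, 2], lut))
```

On full volumes the same lookup would typically be done with array indexing (for example `lut_array[segmentation]`) instead of a Python loop.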

In [8]:
total_lmap = {"0": 0, "1": 12, "2": 6, "3": 6, "4": 8, "5": 11, "6": 8, "7": 12, "8": 12, "9": 12, "10": 2, "11": 2, "12": 2, "13": 2, "14": 2, "15": 8, "16": 0, "17": 0, "18": 8, "19": 12, "20": 8, "21": 0, "22": 0, "23": 6, "24": 6, "25": 13, "26": 13, "27": 13, "28": 13, "29": 13, "30": 13, "31": 13, "32": 13, "33": 13, "34": 13, "35": 13, "36": 13, "37": 13, "38": 13, "39": 13, "40": 13, "41": 13, "42": 13, "43": 13, "44": 13, "45": 13, "46": 13, "47": 13, "48": 13, "49": 13, "50": 13, "51": 8, "52": 4, "53": 4, "54": 8, "55": 4, "56": 4, "57": 4, "58": 4, "59": 4, "60": 4, "61": 4, "62": 4, "63": 4, "64": 4, "65": 4, "66": 4, "67": 4, "68": 4, "69": 13, "70": 13, "71": 13, "72": 13, "73": 13, "74": 13, "75": 0, "76": 0, "77": 0, "78": 0, "79": 0, "80": 0, "81": 0, "82": 0, "83": 0, "84": 0, "85": 0, "86": 0, "87": 0, "88": 0, "89": 0, "90": 0, "91": 0, "92": 13, "93": 13, "94": 13, "95": 13, "96": 13, "97": 13, "98": 13, "99": 13, "100": 13, "101": 13, "102": 13, "103": 13, "104": 13, "105": 13, "106": 13, "107": 13, "108": 13, "109": 13, "110": 13, "111": 13, "112": 13, "113": 13, "114": 13, "115": 13, "116": 13, "117": 0}

name2label = {
            "total": {
                "2": ["lung"],
                "4": ["aorta", "artery", "atrial", "iliac", "vein","vena"],
                "6": ["kidney"],
                "8": ["bowel", "colon", "esophagus", "gallbladder", "heart", "stomach", "trunk", "autochthon", "iliopsoas", "gluteus"],
                "11": ["liver"],
                "12": ["adrenal_gland", "duodenum", "pancreas", "spleen"],
                "13": ["clavicula", "humerus", "rib_", "vertebrae_", "sacrum", "scapula", "sternum", "femur", "hip", "fibula", "tibia", "radius", "ulna", "carpal", "tarsal", "patella"]
            },
            "body":{
                "bg": 9,
                "skin": 12,
                "fat": 3,
                "muscle": 8
            }
        }


palettedata = [ 0,0,0, 0,0,0, 220,30,30, 170,80,0, 0,170,0, 0,0,0, 0,175,20, 0,0,0, 0,170,190, 0,0,0, 0,0,0, 0,120,230, 115,65,200, 255,0,150] 

pointpalette = [torch.tensor([[0,0,0, 255],
                            [0,0,0, 255],
                            [220,30,30, 255],
                            [170,80,0, 31], 
                            [0,170,0, 255], 
                            [0,0,0, 255],
                            [0,175,20, 255], 
                            [0,0,0, 255],
                            [0,170,190, 255], 
                            [0,0,0, 255],
                            [0,0,0, 255], 
                            [0,120,230, 255], 
                            [115,65,200, 31], 
                            [255,0,150, 255]]),
                [0, 0, 100000, 100000, 100000, 0, 100000, 0, 100000, 0, 0, 100000, 400000, 100000]]

# from torch.nn import OptimizedModule
def dict_2_map(d: dict[str, int]) -> list[list[uint8]]:
    """Invert {class index (as str): rendering label} into per-label lists of class indices."""
    label_map = [[] for _ in range(15)]

    for k, v in d.items():
        label_map[v].append(uint8(k))

    return label_map

def batched(self, iterable, n):
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch

# Save time by initializing predictors once, instead of for each task
def initialize_predictors(device,
                        folds: tuple = (0,)) -> dict:
    """
    Initialize nnUNetPredictor instances for each segmentation task.

    Args:
        device (torch.device): Device to run predictions on ('cuda', 'cpu' or 'mps').
        folds (tuple): Fold indices to use for prediction.

    Returns:
        dict: Dictionary mapping task IDs to their respective nnUNetPredictor instances.
    """
    # Define tasks
    tasks = [("total",
            [291, 292, 293, 294, 295],
            ["Dataset291_TotalSegmentator_part1_organs_1559subj",
            "Dataset292_TotalSegmentator_part2_vertebrae_1532subj",
            "Dataset293_TotalSegmentator_part3_cardiac_1559subj",
            "Dataset294_TotalSegmentator_part4_muscles_1559subj",
            "Dataset295_TotalSegmentator_part5_ribs_1559subj"],
            ["/v2.0.0-weights/Dataset291_TotalSegmentator_part1_organs_1559subj.zip",
            "/v2.0.0-weights/Dataset292_TotalSegmentator_part2_vertebrae_1532subj.zip",
            "/v2.0.0-weights/Dataset293_TotalSegmentator_part3_cardiac_1559subj.zip",
            "/v2.0.0-weights/Dataset294_TotalSegmentator_part4_muscles_1559subj.zip",
            "/v2.0.0-weights/Dataset295_TotalSegmentator_part5_ribs_1559subj.zip"],
            "nnUNetTrainerNoMirroring",
            False),
            ("tissue_types",
            [481],
            ["Dataset481_tissue_1559subj"],
            [],
            "nnUNetTrainer",
            True),
            ("body",
            [299],
            ["Dataset299_body_1559subj"],
            ["/v2.0.0-weights/Dataset299_body_1559subj.zip"],
            "nnUNetTrainer",
            False)]

    commercial_models_inv = {v: k for k, v in commercial_models.items()}
    base_url = "https://github.com/wasserth/TotalSegmentator/releases/download"

    # Get weights directory
    weights_dir = get_weights_dir()
    os.makedirs(weights_dir, exist_ok=True)

    predictors = {}
    for task_name, task_ids, paths, urls, trainer,with_license in tasks:
        print(f"INIT: {task_name} predictor")
        if with_license:
            for i in range(len(task_ids)):
                cfg_dataset = weights_dir / paths[i] / (trainer + '__nnUNetPlans__3d_fullres') / 'dataset.json'
                if paths[i] not in os.listdir(weights_dir):
                    download_model_with_license_and_unpack(commercial_models_inv[task_ids[i]], weights_dir)

                    # # directly remaps label assignments, saving time
                    # with open(cfg_dataset, mode='r') as f:
                    #     data = json.load(f)

                    # data['labels']['subcutaneous_fat'] = name2label['body']['fat']
                    # data['labels']['torso_fat'] = name2label['body']['fat']
                    # data['labels']['skeletal_muscle'] = name2label['body']['muscle']

                    # with open(cfg_dataset, mode='w') as f:
                    #     json.dump(data, f, indent=4)
                    

                # Initialize the predictor
                predictor = nnUNetPredictor(
                    tile_step_size=0.5,
                    use_gaussian=True,
                    use_mirroring=False,
                    perform_everything_on_device=(device.type == "cuda"),
                    device=device,
                    verbose=True,
                    allow_tqdm=True
                )
                # Initialize from the trained model folder
                predictor.initialize_from_trained_model_folder(
                    str(weights_dir / paths[i] / (trainer + "__nnUNetPlans__3d_fullres")),
                    use_folds=folds,
                    checkpoint_name='checkpoint_final.pth'
                )

                predictors[task_ids[i]] = predictor

        else:
            for i in range(len(urls)):
                cfg_dataset = weights_dir / paths[i] / (trainer + '__nnUNetPlans__3d_fullres') / 'dataset.json'
                if paths[i] not in os.listdir(weights_dir):
                    download_url_and_unpack(base_url + urls[i], weights_dir)

                    # # directly remaps label assignments, saving time
                    # with open(cfg_dataset, mode='r') as f:
                    #     data = json.load(f)

                    # if task_name == 'total':
                    #     for name, value in data['labels'].items():
                    #         data['labels'][name] = total_lmap[str(value)]
                    # else:
                    #     # TODO we could merge body trunc with extremeties and get free skin generation for the whole body
                    #     pass

                    # with open(cfg_dataset, mode='w') as f:
                    #     json.dump(data, f, indent=4)


                # Initialize the predictor
                predictor = nnUNetPredictor(
                    tile_step_size=0.5,
                    use_gaussian=True,
                    # use_mirroring=(task_name!='total'),
                    use_mirroring=False,
                    perform_everything_on_device=(device.type == "cuda"),
                    device=device,
                    verbose=True,
                    allow_tqdm=True
                )
                # Initialize from the trained model folder
                predictor.initialize_from_trained_model_folder(
                    str(weights_dir / paths[i] / (trainer + "__nnUNetPlans__3d_fullres")),
                    use_folds=folds,
                    checkpoint_name='checkpoint_final.pth'
                )
                predictors[task_ids[i]] = predictor

    return predictors

def bin_erosion(kernel:torch.Tensor, padded:torch.Tensor, ret:torch.Tensor):
    # Assumes stacked 3d and no normalization needed
    i, hdx, idx, jdx = cuda.grid(4)

    # Run kernel
    window = padded[i,
                    hdx-int((kernel.shape[0]-1) / 2):hdx+int((kernel.shape[0]-1) / 2),
                    idx-int((kernel.shape[0]-1) / 2):idx+int((kernel.shape[0]-1) / 2),
                    jdx-int((kernel.shape[0]-1) / 2):jdx+int((kernel.shape[0]-1) / 2)]
    # TODO: does this also get JITed?
    match = torch.all(kernel == window)
    ret[i, hdx, idx, jdx] = 1 if match else 0

def bin_dilation(kernel:torch.Tensor, padded:torch.Tensor, ret:torch.Tensor):
    # Assumes stacked 3d and no normalization needed
    i, hdx, idx, jdx = cuda.grid(4)

    # Run kernel
    window = padded[i,
                    hdx-int((kernel.shape[0]-1) / 2):hdx+int((kernel.shape[0]-1) / 2),
                    idx-int((kernel.shape[0]-1) / 2):idx+int((kernel.shape[0]-1) / 2),
                    jdx-int((kernel.shape[0]-1) / 2):jdx+int((kernel.shape[0]-1) / 2)]
    # TODO: does this also get JITed?
    match = torch.any(kernel == window)
    ret[i, hdx, idx, jdx] = 1 if match else 0

class CT2US(torch.nn.Module):
    def seg_predictor(self, imgs, properties, task, resamp_thr):
        return self.predictors[task].predict_from_list_of_npy_arrays(imgs,
                                                    None,
                                                    properties,
                                                    None, 2, save_probabilities=False,
                                                    num_processes_segmentation_export=resamp_thr)
    
    def seg_new(self, imgs, properties, task, resamp_thr):
        return self.predictors[task].predict_from_data_iterator(
                                        self.iterator(self.predictors[task], imgs, properties),
                                        save_probabilities=False,
                                        num_processes_segmentation_export=resamp_thr
                                    )

    # Does not work for some reason, totalsegmentator returns zeros instead of labels
    def seg_old(self, imgs, properties, task, resamp_thr):
        ret = []

        # TODO Get list from "total" labels and find matches in totalsegmentator
        roi = [l for _, l in self.name2label["total"].items()]
        roi = np.concatenate(roi).tolist()
        if task == "total":
            for img in imgs:
                ret.append(np.asarray(ts.totalsegmentator(
                                    input=img,
                                    task=task,
                                    nr_thr_resamp=resamp_thr
                                    # roi_subset=roi
                                ).dataobj, dtype=np.uint8))

        else:
            for img in imgs:
                ret.append(np.asarray(ts.totalsegmentator(
                                    input=img,
                                    task=task,
                                    nr_thr_resamp=resamp_thr
                                ).dataobj, dtype=np.uint8))


        return ret

    def __init__(self, method: str = 'old'):
        super(CT2US, self).__init__()
        methods = {'old', 'new', 'predictor'}
        if method not in methods:
            raise KeyError(f"Method not supported, choose from {methods}")
        self.method = method


        self.device = device
        if device.type == "cuda" and torch.cuda.is_available():
            self.m = cp
            self.dil_t = cuda.jit(bin_dilation)
            self.er_t = cuda.jit(bin_erosion)
            self.ops = cusci

        else:
            self.m = np
            self.dil_t = njit(bin_dilation)
            self.er_t = njit(bin_erosion)
            self.ops = scipy.ndimage

        if not method == 'old':
            predictors = initialize_predictors(device=device, folds=[0])
            self.predictors = predictors
            self.predictor_keys = predictors.keys()

        segmentator = {
            # 'new': self.predict_tensor_iter,
            'new': self.seg_new,
            'predictor': self.seg_predictor,
            'old': self.seg_old
        }
        self.segmentator = segmentator[method]

        us = {
            'new': self.to_us_sim_old,
            'predictor': self.to_us_sim_old,
            'old': self.to_us_sim_old
        }
        self.us = us[method]

        composer = {
            'new': self.stacked_assemble,
            'predictor': self.stacked_assemble,
            'old': self.assemble
        }
        self.composer = composer[method]

        hparams = {
            'debug' : False,
            'device' : device
        }

        self.ultrasound_rendering = UltrasoundRendering(hparams, default_param=True)

        self.total_lmap = total_lmap
        self.name2label = name2label
        self.tmap = dict_2_map(self.total_lmap)

    def bin_dilation(self, imgs: torch.Tensor, kernel_size: int = 3, iterations: int = 1) -> torch.Tensor:
        # Binary dilation of a batch of volumes shaped (N, X, Y, Z) with a cubic kernel.
        # The compiled kernel (self.dil_t) is expected to write its result into `ret`.
        pad = kernel_size // 2
        kernel = torch.ones((kernel_size, kernel_size, kernel_size), dtype=torch.uint8, device=imgs.device)
        d_imgs = imgs.detach().to(torch.uint8)
        if imgs.is_cuda:
            kernel = cuda.as_cuda_array(kernel)
            # NOTE: CUDA launch configurations have at most three dimensions, so this
            # 4D grid/block layout still has to be folded into 3D before it can run.
            threadsperblock = (1, kernel_size, kernel_size, kernel_size)
            blocks = (imgs.shape[0],
                        int(np.ceil(imgs.shape[1] / threadsperblock[1])),
                        int(np.ceil(imgs.shape[2] / threadsperblock[2])),
                        int(np.ceil(imgs.shape[3] / threadsperblock[3])))
        else:
            kernel = kernel.numpy()

        for _ in range(iterations):
            # Zero-pad the three spatial axes; pad the previous result so iterations compound
            padded = torch.nn.functional.pad(d_imgs, (pad,) * 6)
            ret = torch.zeros_like(d_imgs)
            if imgs.is_cuda:
                self.dil_t[blocks, threadsperblock](kernel, cuda.as_cuda_array(padded), cuda.as_cuda_array(ret))
            else:
                # numpy() returns views that share memory with the tensors
                self.dil_t(kernel, padded.numpy(), ret.numpy())
            d_imgs = ret

        return d_imgs

    def bin_erosion(self, imgs: torch.Tensor, kernel_size: int = 3, iterations: int = 1) -> torch.Tensor:
        # Binary erosion of a batch of volumes shaped (N, X, Y, Z) with a cubic kernel.
        # The compiled kernel (self.er_t) is expected to write its result into `ret`.
        pad = kernel_size // 2
        kernel = torch.ones((kernel_size, kernel_size, kernel_size), dtype=torch.uint8, device=imgs.device)
        d_imgs = imgs.detach().to(torch.uint8)
        if imgs.is_cuda:
            kernel = cuda.as_cuda_array(kernel)
            # NOTE: CUDA launch configurations have at most three dimensions, so this
            # 4D grid/block layout still has to be folded into 3D before it can run.
            threadsperblock = (1, kernel_size, kernel_size, kernel_size)
            blocks = (imgs.shape[0],
                        int(np.ceil(imgs.shape[1] / threadsperblock[1])),
                        int(np.ceil(imgs.shape[2] / threadsperblock[2])),
                        int(np.ceil(imgs.shape[3] / threadsperblock[3])))
        else:
            kernel = kernel.numpy()

        for _ in range(iterations):
            # Zero-pad the three spatial axes; pad the previous result so iterations compound
            padded = torch.nn.functional.pad(d_imgs, (pad,) * 6)
            ret = torch.zeros_like(d_imgs)
            if imgs.is_cuda:
                self.er_t[blocks, threadsperblock](kernel, cuda.as_cuda_array(padded), cuda.as_cuda_array(ret))
            else:
                # numpy() returns views that share memory with the tensors
                self.er_t(kernel, padded.numpy(), ret.numpy())
            d_imgs = ret

        return d_imgs

    # Adapted from nnUNetPredictor
    def iterator(self,
                predictor: nnUNetPredictor,
                imgs: list[np.ndarray],
                properties: list[dict]):

        # MAYBE: look at data_iterators.preprocess_fromnpy_save_to_queue for vstack use for ROI foreground masking

        # pp = predictor.get_data_iterator_from_raw_npy_data(
        #     imgs,
        #     properties
        # )

        # preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)

        # properties = {key: [i[key] for i in properties] for key in properties[0]}

        # data, seg = preprocessor.run_case_npy(
        #                 np.stack(imgs),
        #                 None,
        #                 properties,
        #                 predictor.plans_manager,
        #                 predictor.configuration_manager,
        #                 predictor.dataset_json
        #             )

        # pass

        preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
        for a, p in zip(imgs, properties):
            data, seg = preprocessor.run_case_npy(a,
                                                  None,
                                                  p,
                                                  predictor.plans_manager,
                                                  predictor.configuration_manager,
                                                  predictor.dataset_json)
            yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properties': p, 'ofile': None}

    def convert_logits_to_segmentation(self, prediction, properties, predictor):
        # Resample the logits to the pre-resampling shape, convert them to labels,
        # then undo the preprocessing crop and transpose.
        spacing_transposed = [properties['spacing'][i] for i in predictor.plans_manager.transpose_forward]
        current_spacing = predictor.configuration_manager.spacing if \
            len(predictor.configuration_manager.spacing) == \
            len(properties['shape_after_cropping_and_before_resampling']) else \
            [spacing_transposed[0], *predictor.configuration_manager.spacing]
        predicted_logits = predictor.configuration_manager.resampling_fn_probabilities(prediction,
                                                properties['shape_after_cropping_and_before_resampling'],
                                                current_spacing,
                                                spacing_transposed)
        # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
        # apply_inference_nonlin will convert to torch
        predicted_probabilities = predictor.label_manager.apply_inference_nonlin(predicted_logits)
        del predicted_logits
        segmentation = predictor.label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

        # segmentation may be a torch.Tensor; continue with numpy
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # put segmentation in bbox (revert cropping)
        segmentation_reverted_cropping = np.zeros(properties['shape_before_cropping'],
                                                dtype=np.uint8 if len(predictor.label_manager.foreground_labels) < 255 else np.uint16)
        slicer = tuple([slice(*i) for i in properties['bbox_used_for_cropping']])
        segmentation_reverted_cropping[slicer] = segmentation
        del segmentation

        # revert the transpose applied during preprocessing
        return segmentation_reverted_cropping.transpose(predictor.plans_manager.transpose_backward)

    # Adapted from nnUNetPredictor
    def predict_tensor_iter(self,
                        data_iterator) -> list[list[np.ndarray]]:
        # Returns, for every image, one segmentation per predictor.

        # HYPERPARAMETER: number of threads to use for prediction
        default_num_processes = 4
        old_threads = torch.get_num_threads()
        torch.set_num_threads(min(default_num_processes, old_threads))

        r = []
        for preprocessed in data_iterator:
            asm = []
            data = preprocessed['data']

            properties = preprocessed['data_properties']

            for predictor in self.predictors.values():
                prediction = None

                for params in predictor.list_of_parameters:

                    # messing with state dict names...
                    # if not isinstance(predictor.network, OptimizedModule):
                    #     predictor.network.load_state_dict(params)
                    # else:
                    #     predictor.network._orig_mod.load_state_dict(params)

                    if prediction is None:
                        prediction = predictor.predict_sliding_window_return_logits(data)
                    else:
                        prediction += predictor.predict_sliding_window_return_logits(data)

                # average the accumulated logits over the folds
                if len(predictor.list_of_parameters) > 1:
                    prediction /= len(predictor.list_of_parameters)

                asm.append(self.convert_logits_to_segmentation(prediction, properties, predictor))

            print(f'\nDone with image of shape {data.shape}:')

            # clear lru cache
            compute_gaussian.cache_clear()
            # clear device cache
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

            r.append(asm)

        # restore the thread count that was active before prediction
        torch.set_num_threads(old_threads)

        return r

    def assemble(self,
                task:str,
                segs:list[np.ndarray],
                bases:list[np.ndarray],
                prev:list[np.ndarray]) -> list[np.ndarray]:

        print(f"ASSEMBLY STARTED: {task}")
        # Process total segmentation

        if task == 'total':
            for j in range(len(segs)):
                for i in range(len(self.tmap)):
                    if len(self.tmap[i]) > 0:  # if there are any keys for this value
                        a = self.m.where(self.m.isin(self.m.asarray(segs[j], dtype=self.m.uint8), self.m.array(self.tmap[i])), self.m.uint8(i), self.m.uint8(0))
                        prev[j] += a

        if task == 'tissue_types':
            for j in range(len(segs)):
                t = self.m.asarray(segs[j])

                prev[j][t == 1] = self.m.uint8(self.name2label["body"]["fat"])
                prev[j][t == 2] = self.m.uint8(self.name2label["body"]["fat"])

        if task == 'body':
            for j in range(len(segs)):
                t = self.m.asarray(segs[j])

                body = self.ops.binary_dilation(t == 1, iterations=1).astype(self.m.uint8)
                body_inner = self.ops.binary_erosion(t, iterations=3, brute_force=True).astype(self.m.uint8)
                
                skin = body - body_inner
                
                # Segment by density
                # Roughly the skin density range. Made large to make segmentation not have holes
                # (0 to 250 would have many small holes in skin)
                density_mask = (bases[j] > -200) & (bases[j] < 250)
                skin[~density_mask] = 0

                # Fill holes
                # skin = binary_closing(skin, iterations=1)  # no real difference
                # skin = binary_dilation(skin, iterations=1)  # not good

                # Removing blobs
                # NOTE: the filtered `mask` is not used below; the final dilation runs on `skin`.
                mask, _ = self.ops.label(skin)
                counts = self.m.bincount(mask.flatten())  # number of pixels in each blob

                # If only one blob (only background) abort because nothing to remove
                if len(counts) > 1:
                    remove = self.m.where((counts <= 10) | (counts > 30), True, False)
                    remove_idx = self.m.nonzero(remove)[0]
                    mask[self.m.isin(self.m.array(mask), remove_idx)] = 0
                    mask[mask > 0] = 1

                # End of snippet from totalsegmentator

                dilation_kernel = self.m.ones(shape=(2, 2, 2))

                skin = self.m.where(self.ops.binary_dilation(skin == 1, structure=dilation_kernel), self.m.uint8(1), self.m.uint8(0))

                prev[j][skin == 1] = self.m.uint8(self.name2label["body"]["skin"])

                tmp = prev[j].copy()
                prev[j][tmp == 0] = self.m.uint8(self.name2label["body"]["bg"])

        print("ASSEMBLY COMPLETED")

        del segs, bases

        return prev

    def stacked_assemble(self, task:str,
                segs:list[np.ndarray],
                stacked_bases:torch.Tensor,
                prev: torch.Tensor) -> torch.Tensor:

        print("ASSEMBLY STARTED")

        # Process total segmentation
        labels = prev

        if task == "total":
            stacked_totals = torch.stack([torch.as_tensor(segs[i], device=self.device) for i in range(stacked_bases.shape[0])], dim=0)
            # copy every non-background voxel of the segmentation into the label volume
            fg = stacked_totals != 0
            labels[fg] = stacked_totals[fg].to(labels.dtype)

        elif task == "tissue_types":
            stacked_tissues = torch.stack([torch.as_tensor(segs[i], device=self.device) for i in range(stacked_bases.shape[0])], dim=0)
            fg = stacked_tissues != 0
            labels[fg] = stacked_tissues[fg].to(labels.dtype)

        elif task == "body":
            stacked_outers = torch.stack([torch.as_tensor(segs[i], device=self.device, dtype=torch.uint8) for i in range(stacked_bases.shape[0])], axis=0)

            # Adapted code snippet from totalsegmentator
            body = torch.as_tensor(self.bin_dilation(stacked_outers == 1, kernel_size=3, iterations=1), device=self.device).to(torch.uint8)
            body_inner = torch.as_tensor(self.bin_erosion(stacked_outers == 1, kernel_size=3, iterations=3), device=self.device).to(torch.uint8)
            skin = body - body_inner

            # Segment by density
            # Roughly the skin density range. Made large to make segmentation not have holes
            # (0 to 250 would have many small holes in skin)
            density_mask = (stacked_bases > -200) & (stacked_bases < 250)
            skin[~density_mask] = 0

            # Fill holes
            # skin = binary_closing(skin, iterations=1)  # no real difference
            # skin = binary_dilation(skin, iterations=1)  # not good

            # Label connected components with (cupyx.)scipy, then convert back to torch
            if skin.is_cuda:
                mask, _ = cusci.label(cp.asarray(skin))
            else:
                mask, _ = scipy.ndimage.label(skin.numpy())
            mask = torch.as_tensor(mask, device=skin.device).long()

            counts = torch.bincount(mask.flatten())  # number of pixels in each blob

            # If only one blob (only background) abort because nothing to remove
            if len(counts) > 1:
                remove = (counts <= 10) | (counts > 30)
                # torch.nonzero returns an (N, 1) tensor, so flatten it to get the blob ids
                remove_idx = torch.nonzero(remove).flatten()
                mask[torch.isin(mask, remove_idx)] = 0
                mask[mask > 0] = 1

            # Removing blobs
            # End of snippet from totalsegmentator
            mask = torch.as_tensor(self.bin_dilation(mask == 1, kernel_size=3, iterations=2), device=self.device).to(torch.uint8)

            labels[mask == 1] = int(self.name2label["body"]["skin"])

            # everything that is still unlabelled becomes background
            labels[labels == 0] = int(self.name2label["body"]["bg"])

        print("ASSEMBLY COMPLETED")
        del segs, stacked_bases

        return labels

    def to_us_sim_new(self, segs:list[np.ndarray], properties:dict, dest_us: list[str], step_size:int) -> list[list[np.ndarray]]:
        print("US SIMULATION STARTED")

        hparams = {
            'debug' : False,
            'device' : self.device
        }

        us_r = UltrasoundRendering(params=hparams, default_param=True).to(hparams['device'])

        transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize([380, 380], transforms.InterpolationMode.NEAREST),
                    transforms.CenterCrop((256)),
                ])

        results = []
        for i in range(len(segs)):
            us_images = []
            labelmap = segs[i].get()
            dest = pthlib(dest_us[i]).joinpath("slice_")
            os.makedirs(dest.parent, exist_ok=True)

            for slice_idx in range(0, labelmap.shape[2], step_size):
                slice_data = labelmap[:, :, slice_idx].astype('int64')
                labelmap_slice = transform(slice_data).squeeze()

                us_image = us_r(labelmap_slice)
                us_images.append(us_image.cpu().numpy())

                us_image_pil = transforms.ToPILImage()(us_image.cpu().squeeze())
                us_image_pil.save(f"{dest}_{slice_idx}.png")

            results.append(us_images)

        print("US SIMULATION COMPLETED")

        return results

    def to_us_sim_old(self, segs:list[np.ndarray], properties:dict, dest_us: list[str], step_size:int) -> tuple[list, list, list]:
        print("US SIMULATION STARTED")

        hparams = {
            'debug' : False,
            'device' : self.device
        }

        l_dict = {
            2:'lung',
            3:'fat',
            4:'vessel',
            6:'kidney',
            8:'muscle',
            11:'liver',
            12:'soft tissue',
            13:'bone'
        }
        us_r = UltrasoundRendering(params=hparams, default_param=True).to(self.device)

        transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize([380, 380], transforms.InterpolationMode.NEAREST),
                    transforms.CenterCrop((256)),
                ])

        us_imgs = []
        pcd_base = []

        for i in tqdm.tqdm(range(len(segs)), desc="Rendering"):
            # p = properties[i]["spacing"]
            warped = []
            us_slices = []
            if self.m == cp:
                labelmap = segs[i].get()
            else:
                labelmap = segs[i]

            dest = pthlib(dest_us[i]).joinpath("slice_")
            os.makedirs(dest.parent, exist_ok=True)

            for slice_idx in tqdm.tqdm(range(0, labelmap.shape[2], step_size), desc="US slice rendering"):
                slice_data = labelmap[:, :, slice_idx].astype('int64')
                labelmap_slice = transform(slice_data).squeeze()

                us_slice = us_r(labelmap_slice)
                if self.m == np:
                    us_slices.append(us_slice.cpu().numpy())
                else:
                    us_slices.append(us_slice)    

                us_image_pil = transforms.ToPILImage()(us_slice.cpu().squeeze())
                us_image_pil.save(f"{dest}_{slice_idx}.png")

                # Warp slice to match US and create individual label maps
                
                if self.m == np:
                    # .to() also works on CPU-only machines, unlike .cuda()
                    temp = us_r.warp_img(labelmap_slice.to(self.device)).cpu().numpy()
                else:
                    temp = self.m.asarray(us_r.warp_img(labelmap_slice.to(self.device)))

                a = self.m.fliplr(self.m.asarray(temp.copy()).transpose(1, 0))
 
                warped.append([(
                                    self.ops.binary_fill_holes(
                                        # scipy.ndimage.binary_erosion(
                                            self.ops.binary_dilation(
                                                self.ops.binary_closing(a==tag, iterations=3), iterations=2)),
                                                # scipy.ndimage.binary_closing(a==tag, iterations=3), iterations=2), iterations=1)),
                                    # scipy.ndimage.binary_erosion(scipy.ndimage.binary_dilation(a==tag, iterations=2), iterations=2),
                                    l_dict[tag]
                                ) for tag in [2, 3, 4, 6, 8, 11, 12, 13]])
 
                del temp, a

            torch.set_default_device(self.device)
            temp = torch.as_tensor(labelmap)

            list_pos = []
            list_color = []
            for label in tqdm.tqdm(range(len(pointpalette[1])), desc="Pointcloud rendering"):
                n_points = pointpalette[1][label]
                if n_points != 0:
                    # coordinates of every voxel carrying this label
                    voxels = (temp == label).nonzero()
                    if voxels.shape[0] == 0:
                        continue  # label absent from this volume, nothing to sample

                    # sample n_points voxel positions with replacement
                    select_idx = torch.randint(voxels.shape[0], (n_points,), device=voxels.device)
                    list_pos.append(voxels[select_idx])

                    colors = torch.zeros((n_points, 4), dtype=torch.uint8)
                    colors[:,...] = torch.as_tensor(pointpalette[0][label])
                    list_color.append(colors)

            pcd_base.append([list_pos, list_color, temp.shape])

            us_imgs.append(us_slices)
            # warped_labels.append(warped)
            

        print("US SIMULATION COMPLETED")

        return us_imgs, warped, pcd_base


    def forward(self,
                    imgs: list[nifti1.Nifti1Image|np.ndarray],
                    properties:list[dict],
                    dest_label: list[str],
                    dest_us: list[str],
                    step_size: int,
                    save_labels: bool,
                ) -> tuple[list[str], list, list, list, list]:

        if not self.method == 'old':
            # Tensors are created directly on self.device, so no extra .cuda() call is
            # needed (it would fail on CPU-only machines)
            bases = torch.stack([
                        torch.as_tensor(
                            img,
                            dtype=torch.float32,
                            device=self.device
                        ).squeeze() for img in imgs]
                    )

            f_labels = torch.zeros(
                            bases.shape,
                            dtype=torch.uint8,
                            device=self.device
                        )
        else:
            bases = [self.m.array(img.dataobj, dtype=self.m.float32) for img in imgs]
            f_labels = [self.m.zeros(bases[idx].shape, dtype=self.m.uint8) for idx in range(len(imgs))]

        if self.method == "old":
            tasks = ["total", "tissue_types", "body"]
        else:
            tasks = list(self.predictor_keys)

        print("SEGMENTING:")

        tmp = []
        for idx in tqdm.tqdm(range(len(tasks)), desc="Segmenting batch"):

            tmp.append(self.segmentator(imgs, properties, tasks[idx], 4))
            print("SEG DONE!")

        if self.method == 'old':
            for idx in tqdm.tqdm(range(len(tasks)), desc="Composing into suitable intermediate"):
                f_labels = self.composer(tasks[idx], tmp[idx], bases, f_labels)

        us_imgs, warped_labels, pcdb = self.us(f_labels.copy(), properties, dest_us, step_size)

        if save_labels:
            for idx in tqdm.tqdm(range(len(f_labels)), desc="Saving labels"):
                if self.m == cp:
                    SimpleITKIO().write_seg(f_labels[idx].get().transpose(2, 1, 0), dest_label[idx], properties[idx])
                else:
                    SimpleITKIO().write_seg(f_labels[idx].transpose(2, 1, 0), dest_label[idx], properties[idx])
                    
                print(f"SAVED TO '{dest_label[idx]}'")

        t = []
        for f in tqdm.tqdm(f_labels, desc="Generating warped labels for annotated view"):
            temp = []
            if self.m == cp:
                labels = f.get()
            else:
                labels = f
                
            for arr in list(np.flip(labels.copy().transpose(2, 1, 0), 2))[::step_size]:
                img = Image.fromarray(arr)
                img.putpalette(palettedata *16)
                temp.append(img)
            t.append(temp)

        return [str(pthlib(d).name) for d in dest_label], us_imgs, warped_labels, t, pcdb
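
The blob filter in `assemble` and `stacked_assemble` removes a connected component when it has at most 10 voxels or more than 30 voxels, so only components with 11 to 30 voxels survive. The following is a minimal pure-Python sketch of that rule on a flat list of component ids. `keep_blobs_in_range` and the toy data are hypothetical and only serve as illustration; the pipeline itself works on labelled arrays.

```python
from collections import Counter


def keep_blobs_in_range(labels, low=10, high=30):
    """Binarise a flat list of blob ids, keeping blobs whose size is in (low, high].

    Mirrors `remove = (counts <= low) | (counts > high)`: a blob is removed
    when it has at most `low` voxels or more than `high` voxels.
    Id 0 is treated as background and never ends up in the mask.
    """
    counts = Counter(labels)
    removed = {blob for blob, n in counts.items() if n <= low or n > high}
    removed.add(0)  # background stays 0 regardless of its size
    return [0 if blob in removed else 1 for blob in labels]


# Toy labelling: blob 1 has 3 voxels, blob 2 has 12, blob 3 has 40.
labels = [0] * 5 + [1] * 3 + [2] * 12 + [3] * 40
binary = keep_blobs_in_range(labels)
print(sum(binary))  # 12: only blob 2 falls inside (10, 30]
```

With these bounds, a large connected skin surface would be discarded as well, which is worth keeping in mind when the filtered mask is used for the final skin label.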

# Dataset loader

In [9]:
def collate_tensor(data):
    imgs, properties, dest_labels, dest_us = zip(*data)

    # zip() yields a tuple of arrays; stack them into one batch before converting
    return torch.from_numpy(np.stack(imgs)), properties, dest_labels, dest_us

def collate_list(data):
    imgs, properties, dest_labels, dest_us = zip(*data)
    return imgs, properties, dest_labels, dest_us

# import pandas as pd
class CTDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir:str, method:str="old", annotations_file:str|None=None, resample:float=1.5):
        self.device = device
        self.method = method
        # use cupy on the GPU and numpy otherwise
        if self.device.type == 'cuda':
            self.m = cp

        else:
            self.m = np

        self.collate_fn = collate_list

        if not isinstance(img_dir, pthlib):
            img_dir = pthlib(img_dir)

        self.resample = resample

        self.img_dir = img_dir
        self.annotations_file = annotations_file

        l = {"predictor" : self.load_bases, "new" : self.load_bases, "old": self.load_bases}
        self.load = l[method]

        # r = {"predictor" : self.resampler_old, "new" : self.resampler_new, "old": None}
        r = {"predictor" : None, "new" : None, "old": None}
        self.resampler = r[method]

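        if self.annotations_file is not None:
            # TODO IMPLEMENT LOADING FROM CSV
            # self.img_paths = pd.read_csv(annotations_file)
            # Fail early: without img_paths, __len__ and __getitem__ cannot work
            raise NotImplementedError("Loading image paths from an annotations file is not implemented yet")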
        else:
            self.img_paths = glob.glob(f"{str(self.img_dir)}/*.nii.gz")
            self.img_paths = [(pth,
                                pth.replace(".nii.gz", "_label.nii.gz").replace("/imgs/", "/labels/"),
                                pth.replace(".nii.gz", "_us").replace("/imgs/", "/us/")
                                ) for pth in self.img_paths]

    def __len__(self):
        return len(self.img_paths)

    def load_bases(self, pths: str) -> tuple[np.ndarray, dict]:
        img, prop = SimpleITKIO().read_images(image_fnames=[join(pths)])
        return (img, prop)

    def resampler_new(self, img) -> tuple:
        spacing = np.array(img[1]["spacing"])

        new_spacing = np.array(self.resample)

        # Nothing to do when the volume already has the target spacing
        if np.allclose(spacing, new_spacing):
            return img

        zoom = spacing / new_spacing

        data = self.m.array(img[0], self.m.float32)

        new_shape = (np.array(data.shape) * zoom).round().astype(self.m.int32)

        if self.device.type == 'cuda':
            resampled_img = (self.m.array(resize(data, output_shape=new_shape, order=3, mode="edge", anti_aliasing=False)), img[1], img[2], img[3])
        else:
            resampled_img = (self.m.array(rs.resample_img(data, zoom)), img[1], img[2], img[3])

        return resampled_img

    def resampler_old(self, img) -> tuple:
        spacing = np.array(img[0].header.get_zooms())
        new_spacing = np.array(self.resample)

        # Nothing to do when the volume already has the target spacing
        if np.allclose(spacing, new_spacing):
            return img

        zoom = spacing / new_spacing

        data = self.m.array(img[0].get_fdata(), self.m.float32)

        new_shape = (np.array(data.shape) * zoom).round().astype(self.m.int32)

        if self.device.type == 'cuda':
            resampled_img = (self.m.array(resize(data, output_shape=new_shape, order=3, mode="edge", anti_aliasing=False)), {"spacing": spacing}, img[1], img[2])
        else:
            resampled_img = (self.m.array(rs.resample_img(data, zoom)), {"spacing": spacing}, img[1], img[2])

        return resampled_img

    # TODO CLOSEST CANONICAL through transforms?
    def __getitem__(self, idx): # -> list[tuple[np.ndarray, dict, str, str]]:
        todo = self.img_paths[idx]
        img, prop = self.load(todo[0])
        dest_label = todo[1]
        dest_us = todo[2]

        # nnunet predictors require shape (1, x, y, z)
        if self.method == "predictor":
            ret = (img[0][None,...], prop, dest_label, dest_us)
        elif self.method == "new":
            ret = (img, prop, dest_label, dest_us)
        elif self.method == "old":
            affine = np.eye(4)
            affine[:3, :3] = np.array(prop["sitk_stuff"]["direction"]).reshape(3, 3) * prop["sitk_stuff"]["spacing"]
            affine[:3, 3] = prop["sitk_stuff"]["origin"]

            ret = (nifti1.Nifti1Image(img[0].transpose(2, 1, 0), affine=affine), prop, dest_label, dest_us)

        return ret
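
For the `"old"` method, `__getitem__` rebuilds a NIfTI affine from the SimpleITK metadata: each column of the 3x3 direction matrix is scaled by the spacing of that axis, and the origin becomes the translation column. The sketch below reproduces that arithmetic with plain lists so it can be checked by hand. `build_affine` and the sample values are hypothetical and not part of the dataset class.

```python
def build_affine(direction, spacing, origin):
    """Return a 4x4 affine (nested lists) from a flat row-major 3x3 direction,
    a per-axis spacing and an origin."""
    affine = [[0.0] * 4 for _ in range(4)]
    for row in range(3):
        for col in range(3):
            # Column `col` of the direction matrix is scaled by spacing[col],
            # which is what `direction.reshape(3, 3) * spacing` does via broadcasting.
            affine[row][col] = direction[3 * row + col] * spacing[col]
        affine[row][3] = origin[row]  # translation column
    affine[3][3] = 1.0
    return affine


# Hypothetical metadata: identity orientation, anisotropic voxels.
affine = build_affine(
    direction=(1, 0, 0, 0, 1, 0, 0, 0, 1),
    spacing=(1.5, 1.5, 3.0),
    origin=(-120.0, -80.0, 10.0),
)
for row in affine:
    print(row)
```

With an identity direction, the diagonal of the result is simply the voxel spacing and the last column is the origin.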

# Acquire samples for demo

In [None]:
todo_dir = this_folder / "sample"
todo_dir.mkdir(exist_ok=True)

# Only download when the sample folder is still empty
if not any(todo_dir.iterdir()):
    print("Downloading sample data")
    !wget -O "{todo_dir}/sample.zip" "https://www.dropbox.com/scl/fi/y44t2wu7eyg0t3fpknoxv/img.zip?rlkey=5uza6964xrrtffzfc7w977m3z&st=q6u2hzkh&dl=1"
    !unzip "{todo_dir}/sample.zip" -d "{todo_dir}"
    !rm "{todo_dir}/sample.zip"
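
Outside of a notebook, where the `!` shell commands are unavailable, the same download-and-unpack step could be written with the standard library alone. This is only a sketch under that assumption: `fetch_and_extract` is a hypothetical helper and has not been tested against the Dropbox link above.

```python
import urllib.request
import zipfile
from pathlib import Path


def fetch_and_extract(url: str, dest_dir: Path) -> bool:
    """Download a zip archive from `url` into `dest_dir` and unpack it there.

    Returns False without downloading when `dest_dir` already contains files.
    """
    dest_dir.mkdir(parents=True, exist_ok=True)
    if any(dest_dir.iterdir()):
        return False

    archive = dest_dir / "sample.zip"
    urllib.request.urlretrieve(url, archive)  # save the archive to disk
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest_dir)  # unpack next to the archive
    archive.unlink()  # remove the archive once unpacked
    return True


# Hypothetical usage, mirroring the cell above:
# fetch_and_extract(sample_url, this_folder / "sample")
```

Returning a boolean makes it easy to tell whether the data was downloaded or already present.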


# Run this for the actual UI

Some work remains here, in particular a segmentation method choice once the new pipeline components are finished, and previews of the results. The planned volume preview applies the affine to the label map and erases the background label, producing a volume that Gradio can display directly (e.g. as an .obj file). The slice preview is a simple slice selector with a 2D image preview. Generating a volume that makes the slice location easy to identify is still WIP.

In [None]:
img_dir = this_folder / "imgs"
label_dir = this_folder / "labels"
us_dir = this_folder / "us"
gen_dir = this_folder / "gen"

os.makedirs(img_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)
os.makedirs(us_dir, exist_ok=True)
os.makedirs(gen_dir, exist_ok=True)

os.environ["GRADIO_ALLOWED_PATHS"]=str(this_folder)

axis_sample_size = 20000
axis_points_rgba = [255,255,255,255]

def add_axis_pcd(points, colors, shape, y) -> tri.PointCloud:
    pcd_points = points.copy()
    pcd_colors = colors.copy()

    axis = torch.zeros(shape, device=device)
    axis[:,:,y] = 1
    tmp = axis.nonzero()
    
    idx_list = list(range(tmp.shape[0]))
    select_idx = np.random.choice(idx_list, size=axis_sample_size)
    pcd_points.append(tmp[select_idx])

    tmp = torch.zeros((axis_sample_size, 4), dtype=torch.uint8, device=device)
    tmp[:,...] = torch.as_tensor(axis_points_rgba)
    pcd_colors.append(tmp)

    point_pos = torch.concat(pcd_points)
    colors = torch.concat(pcd_colors)

    point_pos = point_pos.float()

    point_displacement = torch.rand(point_pos.shape).to(device)
    point_pos += point_displacement

    shape = torch.FloatTensor(np.array(shape)).to(device)
    point_pos /= shape
    point_pos = point_pos - .5

    pcd = tri.PointCloud(point_pos.cpu(), colors.cpu()).apply_transform(
                    np.dot(
                        tri.transformations.rotation_matrix(np.pi, [1, 0, 0]),
                        tri.transformations.rotation_matrix(np.pi/2, [0, -1, 0])
                    )
            )

    return pcd


with gr.Blocks() as ct_2_us: 
    with gr.Row():
        files = gr.State({})
        us_list = gr.State({})
        warped_list = gr.State({})
        label_list = gr.State({})
        pcdb_list = gr.State({})

        with gr.Column(scale=1):
            ct_imgs = gr.Files(file_types=['.nii', '.nii.gz'], type='filepath', label="Select CT images", interactive=True, file_count='multiple')
            step_size = gr.Slider(label="Slicing step interval", minimum=1, maximum=20, value=2, step=1, interactive=True)       

            with gr.Row():
                btn = gr.Button("Generate")
                reset = gr.Button("Reset")

            @gr.on([reset.click, ct_2_us.load], inputs=None, outputs=[files, us_list, warped_list, label_list, pcdb_list])
            def reset_all():
                for f in label_dir.glob('*.nii.gz'):
                    os.remove(f)
                for f in img_dir.glob('*.nii.gz'):
                    os.remove(f)
                for f in gen_dir.glob('*.glb'):
                    os.remove(f)
                for f in glob.glob(f"{us_dir}/*"):
                    shutil.rmtree(f, ignore_errors=True)
                try:
                    os.remove(f"{this_folder}/results.zip")
                except FileNotFoundError:
                    print("No need to delete results")
                return {}, {}, {}, {}, {}

            # gr.Examples([item for item in reduce(lambda result, x: result + [subset + [x] for subset in result], examples, [[]]) if len(item)>0], ct_imgs)
            # gr.Examples(
            #                 examples=[[str(path)] for path in sorted(pthlib(this_folder / 'sample').glob('**/*.nii'))]
            #                         + [[str(path)] for path in sorted(pthlib(this_folder / 'sample').glob('**/*.nii.gz'))], 
            #                 inputs=[ct_imgs],
            #                 cache_examples=False,
            #                 label='Sample CT volumes',
            #                 examples_per_page=10
            #             )

            with gr.Row():
                with gr.Column():
                    sample_in = gr.Dropdown(
                                                choices=[i+1 for i in range(len(glob.glob(str(this_folder / 'sample' / '*.nii.gz'))))], 
                                                label='Amount of samples to randomly select',
                                                info='Used for demo with no input'    
                                            )
                    seg_method = gr.Radio(choices=["predictor", "new", "old"], value="old", label="Segmentation method", interactive=True)
                    us_method = gr.Radio(choices=["lotus"], value="lotus", label="US rendering method", interactive=True)
        
        with gr.Column(scale=2):

            with gr.Tab(label='Preview'):
                note = gr.Markdown(label="status", value="Generate US images first through the input tab")

                img_idx = gr.State(0)
                slice_idx = gr.State(0)

                @gr.render(inputs=[files, us_list, warped_list, label_list, pcdb_list, step_size], triggers=[us_list.change])
                def dynamic(fl, us, warped, ll, pcdb, step):            
                    with gr.Column():
                        if len(us) > 0:
                            dropdown = gr.Dropdown(choices=[(f[0], n) for n, f in fl.items()], label='Select image to preview', value=0)
                            slider = gr.Slider(minimum=0, maximum=len(warped[0]) - 1, step=step, label='Slice selection', value=0)
                            
                            iden = lambda x: x

                            slider.release(fn=iden, inputs=[slider], outputs=[slice_idx])
                            dropdown.select(fn=iden, inputs=[dropdown], outputs=[img_idx])

                            # States for results and dropdown selection
                            with gr.Column():
                                # "original"
                                with gr.Row():
                                    base = gr.Image(
                                                label='US slice',
                                                value=np.asarray(us[img_idx.value][slice_idx.value], dtype=np.float32),
                                                height=300
                                            )
                                    
                                    label_preview = gr.Image(
                                                        label='Label slice',
                                                        value=ll[0][0],
                                                        type='pil',
                                                        height=300
                                                    )
                                # BOTTOM ROW -> annotation and 3ddw
                                with gr.Row():
                                    comp = gr.AnnotatedImage(
                                                value=(np.asarray(us[img_idx.value][slice_idx.value], dtype=np.float32), warped[0][0]),
                                                height=300
                                            )
                                    
                                    volume_preview = gr.Model3D(clear_color=(0, 0, 0, 1), label="Label map view", value=str(gen_dir / "current_pcd.glb"), height=300)
                                    
                                def route(x, y):
                                    # Clamp the slice index before indexing: the newly selected volume may have fewer slices.
                                    y = y if y < len(us[x]) else 0
                                    b = np.asarray(us[x][y], dtype=np.float32)
                                    w = warped[x][y]
                                    l = ll[x][y]

                                    # Re-export the point cloud with the slice-plane marker at the current position.
                                    p = str(gen_dir / "current_pcd.glb")
                                    add_axis_pcd(pcdb[x][0], pcdb[x][1], pcdb[x][2], y * step).export(p)

                                    return (b, w), b, l, p, y, \
                                        gr.Slider(minimum=0, maximum=len(warped[x]) - 1, step=step, label='Slice selection', value=y)

                                    
                                gr.on(triggers=[img_idx.change, slice_idx.change],
                                        fn=route,
                                        inputs=[img_idx, slice_idx],
                                        outputs=[comp, base, label_preview, volume_preview, slice_idx, slider]
                                        # outputs=[comp, base, label_preview]
                                    )
                            
            with gr.Tab(label='Download'):
                download = gr.DownloadButton(label="", visible=False)
                # TODO: allow picking specific results; potentially also write point clouds to files and let the user download them.
                @gr.render(inputs=[files, us_list, warped_list, label_list, pcdb_list, step_size], triggers=[us_list.change])
                def dynamic(fl, us, warped, ll, pcdb, step):
                    descr = gr.Markdown(value="This can be used to adjust contents of results.zip")
                    configs = gr.CheckboxGroup(choices=["Save labels"], value=["Save labels"], label="Options", interactive=True)
                    filename_in = gr.Textbox(label="Filename for result zip", value="results")

                    r = []
                    r.append(label_dir.glob('*.nii.gz'))
                    r.append(us_dir.glob('**/*.png'))


                    rezip = gr.Button("Reassemble results.zip")

                    @gr.on(rezip.click, inputs=[configs, filename_in], outputs=[download, descr])
                    def rezip_files(save_configs, name, r=r):
                        # Write to the user-chosen filename so the path returned below actually exists.
                        with ZipFile(f"{this_folder}/{name}.zip", 'w') as zipObj:
                            for f in r[0]:
                                zipObj.write(f, os.path.relpath(f, str(this_folder)))
                                os.remove(f)
                            for f in r[1]:
                                zipObj.write(f, os.path.relpath(f, str(this_folder)))

                            for f in glob.glob(f"{us_dir}/*"):
                                shutil.rmtree(f, ignore_errors=True)

                        return f"{this_folder}/{name}.zip", "Results have been rezipped"

                def start(ct, step, method, method_us, fl_s, us_s, warped_s, ll_s, pcdb_s, nr_samples, progress=gr.Progress(track_tqdm=True)):
                    if ct is None:
                        # Demo mode: nothing was uploaded, so draw a random subset of the bundled sample volumes.
                        ct = list((this_folder / 'sample').glob('*.nii.gz'))
                        # Fall back to all samples when no amount was picked in the dropdown.
                        ct = random.sample(ct, k=nr_samples or len(ct))

                    # Copy the selected volumes into the working image directory.
                    for f in ct:
                        shutil.copyfile(f, img_dir / f.name)

                    local_dataset = CTDataset(
                        img_dir=img_dir,
                        method=method,
                        resample=None
                    )
                    batch_size = 1

                    ct_dataloader = DataLoader(local_dataset, batch_size=batch_size, collate_fn=local_dataset.collate_fn)

                    ct2us = CT2US(method=method)

                    fl = []
                    us = []
                    warped = []
                    ll = []
                    pcdb = []

                    for data in progress.tqdm(ct_dataloader, desc="Processing batches"):
                        imgs, properties, dest_labels, dest_us = data
                        
                        # TODO: The last parameter saves labels. It is intended to allow for more download options
                        # and for storing other intermediary results that might be of interest.
                        n, u, w, l, b = ct2us(imgs, properties, dest_labels, dest_us, step, True)
                        
                        fl.append(n)
                        # extend() matches append(*x) at batch_size=1 and does not raise for other batch sizes.
                        us.extend(u)
                        warped.append(w)
                        ll.extend(l)
                        pcdb.extend(b)
                    

                    add_axis_pcd(pcdb[0][0], pcdb[0][1], pcdb[0][2], 0).export(str(gen_dir / "current_pcd.glb"))

                    fl_s.update(enumerate(fl))
                    us_s.update(enumerate(us))
                    warped_s.update(enumerate(warped))
                    ll_s.update(enumerate(ll))
                    pcdb_s.update(enumerate(pcdb))

                    return gr.DownloadButton(label="Download results as zip", visible=True, value=f"{this_folder}/results.zip"), \
                            fl_s, \
                            us_s, \
                            warped_s, \
                            ll_s, \
                            pcdb_s, \
                            gr.Markdown(value="Status", height=30)
                                
                btn.click(
                            fn=lambda x: gr.Markdown(label="Status", value="", height=80), 
                            inputs=btn, 
                            outputs=note
                        ).success(
                            fn=start, 
                            inputs=[ct_imgs, step_size, seg_method, us_method, files, us_list, warped_list, label_list, pcdb_list, sample_in], 
                            outputs=[download, files, us_list, warped_list, label_list, pcdb_list, note]
                        ).success(
                            fn=lambda x: gr.Markdown(label="", value="", height=0, visible=False), 
                            inputs=btn, 
                            outputs=note
                        )
                
            # def run(progress=gr.Progress(track_tqdm=True)):
                
            # btn.click(fn=run,
            #             inputs=None,
            #             outputs=[files, us_imgs, warped_labels, labels, note])

ct_2_us.launch(debug=True)
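
The re-zipping step in the Download tab can also be tried outside the Gradio app. The following is a minimal sketch, not the notebook's own implementation: `zip_results` and its parameters are hypothetical names, and it assumes every file lives below a common root folder (as `label_dir` and `us_dir` do relative to `this_folder`). It writes the archive under the requested filename and only includes the label volumes when asked to, which is the behaviour the "Save labels" option is meant to control.

```python
from pathlib import Path
from zipfile import ZIP_DEFLATED, ZipFile


def zip_results(root, name, label_files, image_files, include_labels=True):
    """Write the selected files into <root>/<name>.zip and return that path.

    Hypothetical helper: `root` plays the role of `this_folder`, and the two
    iterables correspond to the label volumes and the rendered US images.
    """
    root = Path(root)
    archive = root / f"{name}.zip"

    # Materialise the iterables up front: Path.glob() returns a one-shot generator.
    selected = list(image_files)
    if include_labels:
        selected = list(label_files) + selected

    with ZipFile(archive, "w", compression=ZIP_DEFLATED) as zf:
        for f in selected:
            # Store paths relative to root so the archive mirrors the folder layout.
            zf.write(f, Path(f).relative_to(root).as_posix())

    return archive
```

Unlike `rezip_files`, this sketch leaves the source files in place, so it can be called repeatedly with different options; deleting them afterwards would be a separate, explicit step.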