In [1]:
import numpy as np
import pyvista as pv
import trimesh as tm
import vedo as vd

vd.settings.set_vtk_verbosity(0)

# Load Dataset

In [2]:
patient_id = "patient_01"
infile_name = f"../data/raw/{patient_id}_mesh_with_uacs_fibers_tags.vtk"
outfile_name = f"{patient_id}_mesh_with_fibers_tags.vtk"
decimation_fraction = 0.8
tag_threshold_portion = 0.5
legacy_anatomical_tags = {
    "Body": 11,
    "LAA": 13,
    "LIPV": 21,
    "LSPV": 23,
    "RIPV": 25,
    "RSPV": 27,
}

vtk_mesh = pv.read(infile_name)
pv_mesh = pv.PolyData.from_regular_faces(
    vtk_mesh.points, vtk_mesh.cells.reshape(-1, 4)[:, 1:]
)
pv_mesh.point_data["alpha"] = vtk_mesh.point_data["alpha"]
pv_mesh.point_data["beta"] = vtk_mesh.point_data["beta"]
pv_mesh.cell_data["fibers"] = vtk_mesh.cell_data["fibers"]
pv_mesh.cell_data["anatomical_tags"] = vtk_mesh.cell_data["anatomical_tags"]
body_region = np.where(pv_mesh.cell_data["anatomical_tags"] == 11)[0]
pv_mesh.cell_data["anatomical_tags"][body_region] = 0

# Coarsen Mesh for faster computations

In [3]:
pv_mesh_interpolated = pv_mesh.cell_data_to_point_data()
coarse_mesh_interpolated = pv_mesh_interpolated.decimate(
    decimation_fraction, volume_preservation=False, scalars=True, vectors=True
)
alpha_data = coarse_mesh_interpolated.point_data["alpha"]
beta_data = coarse_mesh_interpolated.point_data["beta"]
del coarse_mesh_interpolated.point_data["alpha"]
del coarse_mesh_interpolated.point_data["beta"]
coarse_mesh = coarse_mesh_interpolated.point_data_to_cell_data()
coarse_mesh.point_data["alpha"] = alpha_data
coarse_mesh.point_data["beta"] = beta_data

In [4]:
plotter = pv.Plotter(window_size=[900, 900])
plotter.add_mesh(
    coarse_mesh,
    show_edges=True,
    scalars="anatomical_tags",
    cmap="tab20",
    show_scalar_bar=False,
    edge_color="lightgray",
    edge_opacity=0.3,
)
plotter.show()

Widget(value='<iframe src="http://localhost:37977/index.html?ui=P_0x7fc9ae039a90_0&reconnect=auto" class="pyvi…

In [5]:
body_mesh = coarse_mesh.extract_values(values=0, scalars="anatomical_tags")
feature_meshes = coarse_mesh.extract_values(
    ranges=[1e-5, 100], scalars="anatomical_tags"
)

plotter = pv.Plotter(window_size=[900, 900])
plotter.add_mesh(
    coarse_mesh,
    style="wireframe",
    color="lightgray",
    opacity=0.3,
)
plotter.add_mesh(
    feature_meshes,
    show_edges=True,
    scalars="anatomical_tags",
    cmap="tab20",
    show_scalar_bar=False,
    edge_color="lightgray",
    edge_opacity=0.3,
)
plotter.show()

Widget(value='<iframe src="http://localhost:37977/index.html?ui=P_0x7fc999343110_1&reconnect=auto" class="pyvi…

In [7]:
separated_meshes = feature_meshes.split_bodies()

for label in separated_meshes.keys():
    mesh_portion = separated_meshes[label]
    maximum_tag_value = np.max(mesh_portion.cell_data["anatomical_tags"])

    threshold_value = maximum_tag_value * tag_threshold_portion
    feature_cell_inds = np.where(
        mesh_portion.cell_data["anatomical_tags"] >= threshold_value
    )[0]
    mesh_portion.cell_data["anatomical_tags"][feature_cell_inds] = maximum_tag_value
    body_cell_inds = np.setdiff1d(np.arange(mesh_portion.n_cells), feature_cell_inds)
    mesh_portion.cell_data["anatomical_tags"][body_cell_inds] = 0
    separated_meshes[label] = mesh_portion
    coarse_mesh.cell_data["anatomical_tags"][
        mesh_portion.cell_data["vtkOriginalCellIds"]
    ] = mesh_portion.cell_data["anatomical_tags"]

plotter = pv.Plotter(window_size=[900, 900])
plotter.add_mesh(
    coarse_mesh,
    show_edges=True,
    scalars="anatomical_tags",
    cmap="tab20",
    show_scalar_bar=False,
    edge_color="lightgray",
    edge_opacity=0.3,
)
plotter.show()

Widget(value='<iframe src="http://localhost:37977/index.html?ui=P_0x7fc999232e90_2&reconnect=auto" class="pyvi…

# Clean Up Anatomical Regions

In [8]:
def get_connected_main_segment(input_mesh, region_tag):
    pv_region_mesh = input_mesh.extract_values(region_tag, scalars="anatomical_tags")
    tm_region_mesh = tm.Trimesh(
        pv_region_mesh.points, pv_region_mesh.cells.reshape(-1, 4)[:, 1:], process=False
    )
    tm_region_mesh.fill_holes()
    connected_segments = tm_region_mesh.split(only_watertight=False)
    if not connected_segments:
        main_segment  = vtk_mesh
    else:
        num_segment_vertices = [len(seg.vertices) for seg in connected_segments]
        maximum_segment_index = np.argmax(num_segment_vertices)
        main_segment = connected_segments[maximum_segment_index]
    return main_segment


def convert_cell_to_point_data(input_mesh, main_segment):
    segment_cell_inds = main_segment.cell_data["vtkOriginalCellIds"]
    new_anatomical_tags = np.zeros(input_mesh.n_cells)
    new_anatomical_tags[segment_cell_inds] = 1
    input_mesh.cell_data["anatomical_tags"] = new_anatomical_tags
    input_mesh_interpolated = input_mesh.cell_data_to_point_data()
    return input_mesh_interpolated


def get_smoothened_region(input_mesh):
    vd_mesh = vd.Mesh([input_mesh.points, input_mesh.faces.reshape(-1, 4)[:, 1:]])
    vd_mesh.pointdata["anatomical_tags"] = input_mesh.point_data["anatomical_tags"]
    vd_mesh = vd_mesh.smooth_data(niter=50)
    input_mesh.point_data["anatomical_tags"] = vd_mesh.pointdata["anatomical_tags"]
    input_mesh = input_mesh.extract_values(
        ranges=[0.5, 1.1], scalars="anatomical_tags"
    )
    return input_mesh

In [9]:
new_anatomical_tags = np.zeros(coarse_mesh.n_points)

for i, region_tag in enumerate(legacy_anatomical_tags.values()):
    main_segment = get_connected_main_segment(coarse_mesh, region_tag)
    interpolated_coarse_mesh = convert_cell_to_point_data(coarse_mesh, main_segment)
    smoothened_coarse_mesh = get_smoothened_region(interpolated_coarse_mesh)
    new_anatomical_tags[interpolated_coarse_mesh.point_data["vtkOriginalPointIds"]] = i + 1

coarse_mesh.point_data["anatomical_tags"] = new_anatomical_tags

KeyError: 'vtkOriginalCellIds'

In [None]:
np.unique(coarse_mesh.point_data["anatomical_tags"])

In [None]:
plotter = pv.Plotter(window_size=[900, 900])
plotter.add_mesh(
    coarse_mesh,
    show_edges=True,
    scalars="anatomical_tags",
    cmap="tab20",
    show_scalar_bar=False,
    edge_color="lightgray",
    edge_opacity=0.3,
)
plotter.show()