In [1]:
import pickle
from dataclasses import fields
from pathlib import Path

import numpy as np
import pyvista as pv

from cardiac_electrophysiology import mesh_processing as mp
from cardiac_electrophysiology.ulac import base
from cardiac_electrophysiology.ulac import interface as ulac

In [2]:
def plot_over_dict(plotter, mesh, input_value, transformation, color, size):
    """Recursively add a point cloud to ``plotter`` for every leaf of a nested dict.

    Dict values are walked depth first in insertion order; any non-dict value is
    treated as a leaf. For each leaf, ``transformation(leaf)`` must return
    ``(indices, scalars)``. The selected points ``mesh.points[indices]`` are drawn
    as spheres of ``size``. They are coloured by ``scalars`` when that is not None,
    otherwise by the fixed ``color``.
    """
    if not isinstance(input_value, dict):
        point_indices, point_scalars = transformation(input_value)
        cloud = pv.PolyData(mesh.points[point_indices])
        style_kwargs = {"render_points_as_spheres": True, "point_size": size}
        if point_scalars is None:
            style_kwargs["color"] = color
        else:
            # Attach the per-point values so add_mesh can colour by them by name.
            cloud["scalars"] = point_scalars
            style_kwargs["scalars"] = "scalars"
        plotter.add_mesh(cloud, **style_kwargs)
        return

    for child in input_value.values():
        plot_over_dict(plotter, mesh, child, transformation, color, size)

In [3]:
# Input locations for one patient; every file lives in that patient's processed directory.
patient_id = "01"
patient_dir = Path("../data/processed") / f"patient_{patient_id}"
input_mesh_file = patient_dir / "mesh_with_fibers_tags.vtk"
segmentation_file_name = patient_dir / "segmentation.pkl"

# Region name -> integer tag, forwarded to ulac.construct_segmentation.
# NOTE(review): presumably these match the tag values stored in the mesh — confirm.
anatomical_tags = {
    "MV": 0,
    "LAA": 1,
    "LIPV": 2,
    "LSPV": 3,
    "RIPV": 4,
    "RSPV": 5,
}

# Read the mesh and convert it with mp.convert_unstructured_to_polydata_mesh.
mesh = mp.convert_unstructured_to_polydata_mesh(pv.read(input_mesh_file))

In [4]:
# Load the cached segmentation if present; otherwise compute it and write the cache.
# NOTE(review): pickle can execute arbitrary code on load — only read cache files this
# notebook wrote itself. The cache is only checked for existence, so it is NOT refreshed
# when the mesh or anatomical_tags change; delete the file to force recomputation.
if not segmentation_file_name.exists():
    segmentation_paths = ulac.construct_segmentation(mesh, anatomical_tags)
    segmentation_file_name.parent.mkdir(parents=True, exist_ok=True)
    with open(segmentation_file_name, "wb") as cache_file:
        pickle.dump(segmentation_paths, cache_file)
else:
    segmentation_paths = pickle.loads(segmentation_file_name.read_bytes())


def transformation(input_value):
    """Segmentation-path leaves are used directly as indices into ``mesh.points``."""
    return input_value, None


plotter = pv.Plotter(window_size=[700, 500])
plotter.add_mesh(mesh, style="wireframe", color="grey", show_edges=True)
plot_over_dict(plotter, mesh, segmentation_paths, transformation, color="blue", size=10)
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24edeeba10_0&reconnect=auto" class="pyvi…

In [5]:
def _as_path_indices(input_value):
    """Segmentation-path leaves are used directly as point indices (no scalars)."""
    return input_value, None


def _as_marker_index(marker):
    """A marker exposes a single point index as ``.ind``; wrap it in a 1-element array."""
    return np.array((marker.ind,)), None


# Extract markers from the segmentation and overlay them (red) on the paths (blue).
markers = ulac.get_markers(segmentation_paths)

plotter = pv.Plotter(window_size=[700, 500])
plotter.add_mesh(mesh, style="wireframe", color="grey")
plot_over_dict(plotter, mesh, segmentation_paths, _as_path_indices, "blue", 5)
transformation = _as_marker_index
plot_over_dict(plotter, mesh, markers, transformation, "red", 10)
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24cefa20d0_1&reconnect=auto" class="pyvi…

In [6]:
# Parameterize each segmentation path using the markers.
parameterized_paths = ulac.parameterize_paths(mesh, segmentation_paths, markers)


def transformation(input_value):
    """Colour a parameterized path's points by ``relative_lengths`` (None -> solid colour)."""
    return input_value.inds, input_value.relative_lengths


plotter = pv.Plotter(window_size=[700, 500])
plotter.add_mesh(mesh, style="wireframe", color="grey")
plot_over_dict(plotter, mesh, parameterized_paths, transformation, color="blue", size=5)
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24cc5bc190_2&reconnect=auto" class="pyvi…

In [7]:
# Build UAC paths from the parameterized paths and the markers.
uac_paths = ulac.construct_uac_paths(parameterized_paths, markers)


def transformation(input_value):
    """Colour a UAC path's points by its ``alpha`` values (None -> solid colour)."""
    return input_value.inds, input_value.alpha


plotter = pv.Plotter(window_size=[700, 500])
plotter.add_mesh(mesh, style="wireframe", show_edges=True, color="grey")
plot_over_dict(plotter, mesh, uac_paths, transformation, color="blue", size=10)
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24cc5be210_3&reconnect=auto" class="pyvi…

In [8]:
# Derive patch boundaries from the UAC paths, then compute the per-patch UAC data
# consumed by the plotting/stitching cells below.
patch_boundaries = ulac.get_patch_boundaries(uac_paths)
uac_patches = ulac.compute_patch_uacs(
    mesh,
    patch_boundaries,
    segmentation_paths,
)

In [9]:
def plot_values_over_dict(plotter, mesh, input_value):
    """Recursively draw every UAC patch of a nested dict as a grey wireframe.

    Each non-dict leaf must unpack to ``(mapping, simplices)``. ``mapping`` is
    expected to be (n_points, 2) so that appending a zero column yields 3-D points
    in the z = 0 plane. ``simplices`` are the regular faces indexing into it.

    ``mesh`` is not used here; it is presumably kept for signature parity with
    ``plot_over_dict``.
    """
    if not isinstance(input_value, dict):
        mapping, simplices = input_value
        z_column = np.zeros((mapping.shape[0], 1))
        flat_points = np.hstack((mapping, z_column))
        patch_mesh = pv.PolyData.from_regular_faces(flat_points, simplices)
        plotter.add_mesh(patch_mesh, style="wireframe", color="gray")
        return

    for child in input_value.values():
        plot_values_over_dict(plotter, mesh, child)


# Overlay all UAC patches in the z = 0 plane and look straight down onto the xy plane.
plotter = pv.Plotter(window_size=[900, 900])
plot_values_over_dict(plotter, mesh, input_value=uac_patches)
plotter.view_xy()
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24cc5bf750_4&reconnect=auto" class="pyvi…

In [13]:
def collect_uac_data(input_value):
    """Flatten a (possibly nested) dict of UAC patches into one point/simplex set.

    Dict values are walked depth first in insertion order. Every non-dict leaf
    must unpack to ``(mapping, simplices)``. Each patch's simplices are shifted by
    the number of points in all preceding patches, so that after
    ``np.vstack(mappings)`` they index into the stacked point array.

    All state is local, so calling this repeatedly does not accumulate results
    across calls (the previous version appended to module-level lists and a
    global counter).

    Returns
    -------
    tuple
        ``(mappings, offset_simplices, total_points)``: the per-patch mapping
        arrays, the per-patch offset simplex arrays, and the total point count.
    """
    mappings = []
    offset_simplices = []
    offset = 0

    def _visit(node):
        # Depth-first walk; leaves contribute their points and shifted simplices.
        nonlocal offset
        if isinstance(node, dict):
            for child in node.values():
                _visit(child)
            return
        mapping, simplices = node
        mappings.append(mapping)
        offset_simplices.append(simplices + offset)
        offset += mapping.shape[0]

    _visit(input_value)
    return mappings, offset_simplices, offset


# Same module-level results as before: per-patch mappings, global-index simplices,
# and the total number of points.
uac_list, simplex_list, num_points = collect_uac_data(uac_patches)

In [None]:
# Stitch all patches into a single flat mesh: stack the per-patch points, lift them
# to the z = 0 plane, and use the already-offset simplices as faces.
uacs = np.vstack(uac_list)
uacs = np.hstack((uacs, np.zeros((uacs.shape[0], 1))))
simplices = np.vstack(simplex_list)
# TODO: the former line `unique_mask = np.u` was incomplete and raised AttributeError.
# It presumably intended to merge points duplicated on boundaries shared by adjacent
# patches (e.g. np.unique(uacs, axis=0, return_inverse=True), then remap `simplices`
# through the inverse). Confirm the intent before adding that.
uac_mesh = pv.PolyData.from_regular_faces(uacs, simplices)
# NOTE(review): smooth_mesh is computed but not displayed below — plot it or drop it.
smooth_mesh = uac_mesh.smooth(n_iter=100, relaxation_factor=0.1)
plotter = pv.Plotter(window_size=[700, 700])
plotter.add_mesh(uac_mesh, style="wireframe", color="gray")
plotter.view_xy()
plotter.show()

Widget(value='<iframe src="http://localhost:35879/index.html?ui=P_0x7f24ed57d6d0_8&reconnect=auto" class="pyvi…