In [1]:
import pickle
from dataclasses import fields
from pathlib import Path

import basix
import dolfinx as dlx
import dolfinx.fem.petsc  # noqa: F401 -- makes dlx.fem.petsc (LinearProblem) available; dolfinx does not import it by default
import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
import scifem as dlx_helper
import scipy as sp
import seaborn as sns
import ufl
from mpi4py import MPI

from cardiac_electrophysiology.ulac import preprocessing

In [2]:
# Inputs for this run: which patient's processed mesh to read, and where the
# boundary parameterization (pickled by an earlier step) lives.
patient_id = "01"
infile_name = "../data/processed/patient_{}/mesh_with_fibers_tags.vtk".format(patient_id)
load_file = Path("parameterization.pkl")

In [None]:
# Load the patient surface, convert it to a triangle-only PolyData, and build the
# matching serial dolfinx mesh. Then load the boundary parameterization.
vtk_mesh = pv.read(infile_name)
triangular_mesh = preprocessing.convert_unstructured_to_polydata_mesh(vtk_mesh)
vertices = triangular_mesh.points
simplices = preprocessing.convert_pv_cells_to_numpy_cells(triangular_mesh.faces)

# The coordinate element's value shape must match the number of coordinate
# columns. PyVista points always carry 3 components, so this is a triangulated
# surface embedded in 3D. Declaring shape=(2,) here would tell UFL the geometry
# is planar while dolfinx stores 3D coordinates, so the forms would only see
# the x/y coordinates of the surface.
triangle_element = ufl.Mesh(
    basix.ufl.element("Lagrange", "triangle", 1, shape=(vertices.shape[1],))
)
dolfinx_mesh = dlx.mesh.create_mesh(MPI.COMM_SELF, simplices, vertices, triangle_element)

# SECURITY: pickle.load can execute arbitrary code -- only load parameterization
# files produced by this project, never untrusted ones.
with open(load_file, "rb") as f:
    uac_pv_boundaries, uac_laa_boundary, uac_mv_boundary = pickle.load(f)

In [None]:
# dolfinx reorders mesh nodes. `geometry.input_global_indices[j]` is the
# original (PyVista) index of dolfinx node j, so inverting that array gives
# `dolfinx_index_permutation[original_index] -> dolfinx node index`.
_input_global_indices = np.asarray(dolfinx_mesh.geometry.input_global_indices)
dolfinx_index_permutation = np.full(vertices.shape[0], -1, dtype=np.intp)
# O(n) scatter inversion. The previous per-vertex `np.where` scan was O(n^2).
dolfinx_index_permutation[_input_global_indices] = np.arange(_input_global_indices.size)
if np.any(dolfinx_index_permutation < 0):
    _missing = np.flatnonzero(dolfinx_index_permutation < 0)
    raise ValueError(
        f"{_missing.size} original vertices (e.g. {_missing[:5].tolist()}) have no "
        "dolfinx geometry node; are there points not referenced by any cell?"
    )
# NOTE(review): these are dolfinx *geometry* node indices, but later cells pass
# them to scifem's vertex_to_dofmap, which is presumably indexed by *topology*
# vertex. This assumes both numberings coincide for this serial P1 mesh -- TODO confirm.


def get_dolfinx_indices_from_original_indices(original_indices):
    """Translate original (PyVista) vertex indices to dolfinx geometry node indices.

    Args:
        original_indices: Integer index or index array into the original vertex list.

    Returns:
        The corresponding dolfinx node index/indices, same shape as the input.
    """
    dolfinx_indices = dolfinx_index_permutation[original_indices]
    return dolfinx_indices


def get_original_indices_from_dolfinx_indices(dolfinx_indices):
    """Translate dolfinx geometry node indices back to original (PyVista) vertex indices.

    Args:
        dolfinx_indices: Integer index or index array of dolfinx geometry nodes.

    Returns:
        The corresponding original vertex index/indices, same shape as the input.
    """
    original_indices = dolfinx_mesh.geometry.input_global_indices[dolfinx_indices]
    return original_indices

In [None]:
# Vector-valued P1 space with one component per topological dimension of the
# surface (here the two UAC coordinates alpha and beta), plus vertex<->dof maps.
function_space = dlx.fem.functionspace(
    dolfinx_mesh, ("Lagrange", 1, (dolfinx_mesh.topology.dim,))
)
vertex_to_dof_map = dlx_helper.vertex_to_dofmap(function_space)
dof_to_vertex_map = dlx_helper.dof_to_vertexmap(function_space)
# dolfinx node -> original vertex index (same array used in the index helpers).
dolfinx_indices = dolfinx_mesh.geometry.input_global_indices

# Component-wise Laplace problem:  find u such that  inner(grad u, grad v) dx = inner(0, v) dx.
trial_function, test_function = (
    ufl.TrialFunction(function_space),
    ufl.TestFunction(function_space),
)
rhs_constant = dlx.fem.Constant(dolfinx_mesh, [0.0, 0.0])
weak_form = ufl.inner(ufl.grad(trial_function), ufl.grad(test_function)) * ufl.dx
right_hand_side = ufl.inner(rhs_constant, test_function) * ufl.dx

In [None]:
# Dirichlet data for the UAC Laplace solve: every PV boundary whose field name
# does not contain "inner", plus the mitral-valve boundary, is pinned to its
# (alpha, beta) values.
# NOTE(review): `uac_laa_boundary` is loaded but not imposed here -- confirm intended.
bc_function = dlx.fem.Function(function_space)
bc_function.x.array[:] = 0.0
# Writable (num_nodes, 2) view onto the blocked dof array: row = node, column = (alpha, beta).
# Sizing it from the function itself, rather than from the PyVista vertex count,
# keeps it consistent with the dof layout.
bc_values = bc_function.x.array.reshape(-1, 2)


def _make_dirichlet_bc(boundary):
    """Write `boundary`'s (alpha, beta) values into `bc_function` and return a BC on its nodes.

    Values are *assigned*, not accumulated. The earlier ``+=`` summed the
    prescribed values at any vertex shared by two boundaries; now the boundary
    processed last wins. All BCs share `bc_function`; dolfinx reads its values
    when the BCs are applied, so each BC sees the values written for its own dofs.
    """
    boundary_inds = get_dolfinx_indices_from_original_indices(boundary.inds)
    boundary_dofs = vertex_to_dof_map[boundary_inds]
    bc_values[boundary_dofs, 0] = boundary.alpha
    bc_values[boundary_dofs, 1] = boundary.beta
    return dlx.fem.dirichletbc(bc_function, boundary_dofs)


dirichlet_boundaries = [
    getattr(uac_pv_boundaries, boundary_field.name)
    for boundary_field in fields(uac_pv_boundaries)
    if "inner" not in boundary_field.name
]
dirichlet_boundaries.append(uac_mv_boundary)

boundary_conditions = [_make_dirichlet_bc(boundary) for boundary in dirichlet_boundaries]

In [None]:
# Solve the vector Laplace problem with a direct LU factorization.
problem = dlx.fem.petsc.LinearProblem(
    weak_form,
    right_hand_side,
    bcs=boundary_conditions,
    petsc_options={"ksp_type": "preonly", "pc_type": "lu"},
)
solution = problem.solve()
# Blocked layout: row = dof node, column = component (alpha, beta).
solution_array = solution.x.array.reshape((-1, 2))

# Dof order -> dolfinx vertex order. vertex_to_dof_map[v] is the dof of vertex v,
# so gathering with it gives per-vertex values. The previous indexing with
# dof_to_vertex_map went in the wrong direction: it read the dof array at vertex
# indices. This now matches how the BC cell uses vertex_to_dof_map.
solution_alpha_dolfinx = solution_array[vertex_to_dof_map, 0]
solution_beta_dolfinx = solution_array[vertex_to_dof_map, 1]

# dolfinx vertex order -> original (PyVista) vertex order.
solution_alpha_original = solution_alpha_dolfinx[dolfinx_index_permutation]
solution_beta_original = solution_beta_dolfinx[dolfinx_index_permutation]

In [None]:
# Attach the solved UAC coordinates to the 3D surface and colour it by beta.
for coordinate_name, coordinate_values in (
    ("alpha", solution_alpha_original),
    ("beta", solution_beta_original),
):
    triangular_mesh.point_data[coordinate_name] = coordinate_values

plotter = pv.Plotter(window_size=[600, 600])
plotter.add_mesh(triangular_mesh, scalars="beta", show_edges=True)
plotter.show()

In [None]:
# Inspect a single boundary: draw its vertices, coloured by the solved beta,
# on top of a light wireframe of the whole surface.
feature_to_visualize = uac_pv_boundaries.RSPV_outer
feature_inds = feature_to_visualize.inds

feature_mesh = pv.PolyData(triangular_mesh.points[feature_inds])
for coordinate_name, coordinate_values in (
    ("alpha", solution_alpha_original),
    ("beta", solution_beta_original),
):
    feature_mesh.point_data[coordinate_name] = coordinate_values[feature_inds]

plotter = pv.Plotter(window_size=[600, 600])
plotter.add_mesh(triangular_mesh, style="wireframe", color="lightgrey")
plotter.add_points(
    feature_mesh,
    scalars="beta",
    point_size=10,
    render_points_as_spheres=True,
)
plotter.show()

In [None]:
# Flattened view of the parameterization: place each vertex at (alpha, beta, 0)
# while keeping the original triangle connectivity. Useful for spotting folded
# or inverted triangles in the UAC map.
reference_coords_x = solution_alpha_original
reference_coords_y = solution_beta_original
reference_coords_z = np.zeros(reference_coords_x.size)
reference_coords = np.column_stack((reference_coords_x, reference_coords_y, reference_coords_z))
reference_mesh = pv.PolyData.from_regular_faces(reference_coords, simplices)

plotter = pv.Plotter(window_size=[900, 900])
plotter.add_mesh(
    reference_mesh,
    style="wireframe",
    show_edges=True,
    edge_opacity=1,
    color="lightgray",
)
plotter.view_xy()
plotter.show()