In [1]:
import dolfinx
from mpi4py import MPI
from petsc4py import PETSc
from dolfinx import mesh as dmesh
from dolfinx import fem
from dolfinx import io
from dolfinx import plot
import ufl  

In [2]:
import numpy as np
import sys
import trimesh
import tetgen 
import meshio 
import math

import matplotlib.pyplot as plt
import pymeshfix as pfix
import pyvista as pv
import iso2mesh as i2m
import pygalmesh as pygm
import pymeshlab as ml

from skimage.filters import threshold_otsu
from tqdm.notebook import tqdm
from skimage.transform import resize
from typing import Any, Dict, Optional, Tuple, List

sys.path.append('./src')  # to import alveoRVE from parent directory

from alveoRVE.plot.mpl import show_four_panel_volume
from alveoRVE.plot.pv import view_surface

%load_ext autoreload
%autoreload 2

In [3]:
def _triangle_quality(V, F):
    # 2*sqrt(3)*A / sum(l^2)  -> 1 for equilateral; 0 for degenerate
    p0, p1, p2 = V[F[:,0]], V[F[:,1]], V[F[:,2]]
    e0 = np.linalg.norm(p1 - p0, axis=1)
    e1 = np.linalg.norm(p2 - p1, axis=1)
    e2 = np.linalg.norm(p0 - p2, axis=1)
    A  = 0.5*np.linalg.norm(np.cross(p1 - p0, p2 - p0), axis=1)
    denom = e0**2 + e1**2 + e2**2
    with np.errstate(divide='ignore', invalid='ignore'):
        q = (2.0*math.sqrt(3.0))*A/denom
        q[~np.isfinite(q)] = 0.0
    return np.clip(q, 0.0, 1.0)


def mesh_report(mesh: trimesh.Trimesh, name="mesh", plot=True, tol_dup=1e-12) -> dict:
    """
    Print thorough stats and return a dict. Does NOT mutate the mesh.
    """
    m = mesh.copy()

    print(f"\n=== {name}: geometric/graph checks START ===")
    print(f"trimesh metrics:")
    print(f"mesh is_winding_consistent? {m.is_winding_consistent}")
    print(f"mesh is_watertight? {m.is_watertight}")
    print(f"verts={len(m.vertices):,}, faces={len(m.faces):,}")
    bbox = m.bounds
    L = bbox[1] - bbox[0]
    print(f"bbox min={bbox[0]}, max={bbox[1]}, extents={L}")

    # basic areas/volumes
    A = float(m.area)
    Vvol = float(m.volume) if m.is_watertight else None
    print(f"surface area = {A:.6g}")
    if Vvol is not None:
        print(f"enclosed volume (watertight) = {Vvol:.6g}")
    else:
        print("enclosed volume = N/A (mesh not watertight)")

    # edge-length stats (proxy for 'h')
    edges = m.edges_unique
    elen  = np.linalg.norm(m.vertices[edges[:,0]] - m.vertices[edges[:,1]], axis=1)
    print(f"h (edge length): min={elen.min():.4g}, mean={elen.mean():.4g}, max={elen.max():.4g}")

    # duplicates (faces & vertices)
    F_sorted = np.sort(m.faces, axis=1)
    _, idx = np.unique(F_sorted, axis=0, return_index=True)
    dup_faces = len(m.faces) - len(idx)
    Vkey = np.round(m.vertices / tol_dup).astype(np.int64)
    _, vidx = np.unique(Vkey, axis=0, return_index=True)
    dup_verts = len(m.vertices) - len(vidx)
    print(f"duplicate faces: {dup_faces:,}, duplicate vertices (<= {tol_dup:g}): {dup_verts:,}")

    V, F = m.vertices.copy(), m.faces.copy()

    print(f"== pymeshlab metrics:")

    ms = ml.MeshSet()
    ms.add_mesh(ml.Mesh(V, F))

    geo_measures = ms.get_geometric_measures()
    topo_measures = ms.get_topological_measures()
    print("Geometric measures:")
    for k, v in geo_measures.items():
        if k in ['surface_area', 'avg_edge_length', 'volume']:
            v = float(v)
            print(f"  {k}: {v}")
    print("Topological measures:")
    for k, v in topo_measures.items():
        if k in ['non_two_manifold_edges', 'boundary_edges', 'non_two_manifold_vertices', 'genus', 'faces_number', 'vertices_number', 'edges_number', 'connected_components_number']:
            v = int(v)
            print(f"  {k}: {v}")
    print(f"== custom metrics:")
    # triangle quality
    q = _triangle_quality(m.vertices, m.faces)
    q_stats = dict(min=float(q.min()), p5=float(np.percentile(q,5)),
                   mean=float(q.mean()), p95=float(np.percentile(q,95)),
                   max=float(q.max()))
    print(f"triangle quality q in [0,1] (equilateral=1): "
          f"min={q_stats['min']:.3f}, p5={q_stats['p5']:.3f}, "
          f"mean={q_stats['mean']:.3f}, p95={q_stats['p95']:.3f}, max={q_stats['max']:.3f}")


    if plot:
        plt.figure(figsize=(5,3))
        plt.hist(q, bins=40, range=(0,1), alpha=0.8)
        plt.xlabel("triangle quality q"); plt.ylabel("count"); plt.title(f"Quality histogram: {name}")
        plt.tight_layout(); plt.show()

    return dict(
        verts=len(m.vertices), faces=len(m.faces),
        area=A, volume=Vvol, bbox=bbox, h_stats=(float(elen.min()), float(elen.mean()), float(elen.max())),
        watertight=bool(m.is_watertight),
        dup_faces=int(dup_faces), dup_verts=int(dup_verts),
        tri_quality=q_stats
    )

def quick_mesh_report(ms: ml.MeshSet | trimesh.Trimesh, i: int = 0):
    # print(f"\n == pymeshlab quick metrics:")
    # number of vertices and faces

    flag = False
    if isinstance(ms, trimesh.Trimesh):
        flag = True
        V, F = ms.vertices, ms.faces
        ms = ml.MeshSet()
        ms.add_mesh(ml.Mesh(V, F))
    else: 
        V = ms.current_mesh().vertex_matrix()
        F = ms.current_mesh().face_matrix()

    n_verts = ms.current_mesh().vertex_number()
    n_faces = ms.current_mesh().face_number()
    geo_measures = ms.get_geometric_measures()
    topo_measures = ms.get_topological_measures()
    h = geo_measures['avg_edge_length']

    connected_components = topo_measures['connected_components_number']

    # trimesh watertightness
    trimesh_mesh = trimesh.Trimesh(
        vertices=np.asarray(ms.current_mesh().vertex_matrix(), float),
        faces=np.asarray(ms.current_mesh().face_matrix(), int),
        process=False
    )
    is_watertight = trimesh_mesh.is_watertight
    is_winding_consistent = trimesh_mesh.is_winding_consistent

    # trimesh volume
    vol = trimesh_mesh.volume if is_watertight else None
    area = trimesh_mesh.area

    # pymeshlab nonmanifold edges/faces
    n_nonmanifold_edges = int(topo_measures['non_two_manifold_edges'])
    n_nonmanifold_vertices = int(topo_measures['non_two_manifold_vertices'])

    print(f"[quick {i} 1/3] {n_verts} verts, {n_faces} faces, watertight: {is_watertight}, genus: {topo_measures['genus']}, wind-consistent: {is_winding_consistent}, h = {np.round(h, 3)}, components: {connected_components}\n[quick {i} 2/3] vol = {vol}, area = {area}, non-manifold edges: {n_nonmanifold_edges}/ vertices: {n_nonmanifold_vertices}\n[quick {i} 3/3] bbox: {V.min(axis=0)} to {V.max(axis=0)}")

    if flag: 
        del ms

In [4]:
def print_python_or_matlab_indexing(fc, name=""): 
    if np.min(fc) == 0:
        print(f" - {name} (python)")
    elif np.min(fc) == 1:
        print(f" - {name} (matlab)")
    else:
        print(f" - {name} (unknown indexing)")

def normalize_vertices_inplace(no):
    no[:, :3] = (no[:, :3] - no[:, :3].min(axis=0)) / (no[:, :3].max(axis=0) - no[:, :3].min(axis=0))

def normalize_vertices(no): 
    no_out = no.copy()
    no_out[:, :3] = (no_out[:, :3] - no_out[:, :3].min(axis=0)) / (no_out[:, :3].max(axis=0) - no_out[:, :3].min(axis=0))
    return no_out

def view_wireframe(V: np.ndarray, F: np.ndarray, title="surface"):
    faces = np.hstack([np.full((len(F),1),3), F]).ravel()
    mesh = pv.PolyData(V, faces)
    p = pv.Plotter()
    p.add_mesh(mesh, show_edges=True, color='black', style='wireframe')
    p.add_axes(); p.show(title=title)

def view_cropped(V, F, title, point_on_plane=None, normal=None, mode="wireframe"):
    # default is half the model
    if point_on_plane is None:
        point_on_plane = V.mean(axis=1)
    if normal is None:
        normal = np.array([0,0,1])
    faces = np.hstack([np.full((len(F),1),3), F]).ravel()
    mesh = pv.PolyData(V, faces)
    clipped = mesh.clip(normal=normal, origin=point_on_plane, invert=False)
    p = pv.Plotter()
    p.add_mesh(clipped, show_edges=True, style="wireframe" if mode=="wireframe" else "surface")
    p.add_axes(); p.show(title=title)

In [16]:
mesh = meshio.read('./final_stitched.stl')
nodes = mesh.points
cells_dict = mesh.cells_dict
print(f"cells_dict: {cells_dict}")
elems = mesh.cells_dict['triangle']
print(nodes.shape)
print(elems.shape)

# quick report
m = trimesh.Trimesh(nodes, elems)
ms = ml.MeshSet()
ms.add_mesh(ml.Mesh(nodes, elems))
print("[ORIGINAL]")
quick_mesh_report(ms, i=0)

print("[NORMALIZING NODES]")
normalize_vertices_inplace(nodes)
m = trimesh.Trimesh(nodes, elems)
quick_mesh_report(m, i=1)

cells_dict: {'triangle': array([[     0,      1,      2],
       [     3,      2,      4],
       [     5,      6,      7],
       ...,
       [200203, 200213, 200206],
       [199634, 199643, 200212],
       [199643, 199641, 199642]], shape=(401532, 3))}
(200214, 3)
(401532, 3)
[ORIGINAL]
[quick 0 1/3] 200214 verts, 401532 faces, watertight: True, genus: 277, wind-consistent: True, h = 0.017, components: 1
[quick 0 2/3] vol = 1.8257260034360687, area = 50.91224601954718, non-manifold edges: 0/ vertices: 0
[quick 0 3/3] bbox: [0. 0. 0.] to [2.         2.         1.99899995]
[NORMALIZING NODES]
[quick 1 1/3] 200214 verts, 401532 faces, watertight: True, genus: 277, wind-consistent: True, h = 0.009, components: 1
[quick 1 2/3] vol = 0.22832992072480432, area = 12.732050087799307, non-manifold edges: 0/ vertices: 0
[quick 1 3/3] bbox: [0. 0. 0.] to [1. 1. 1.]


In [6]:
# tetgen tetrahedralization

print("[TETGEN TETRAHEDRALIZATION]")
tg = tetgen.TetGen(clean_nodes, clean_elems)
tg.tetrahedralize(order=1, minratio=1.1)
nodes_tet = tg.node
elems_tet = tg.elem
print(f"tetgen output: {nodes_tet.shape}, {elems_tet.shape}")

[TETGEN TETRAHEDRALIZATION]
tetgen output: (300215, 3), (916112, 4)


In [7]:
# meshio write as xdmf

tg.write('./final_stitched_tetra.vtu')
vtumsh = meshio.read('./final_stitched_tetra.vtu')
print(vtumsh.points.shape)
print(vtumsh.cells_dict)
vtumsh.write('./final_stitched_tetra.xdmf')

(300215, 3)
{'tetra': array([[213661, 213638, 213672, 257105],
       [ 29223,  29260,  29203,  29259],
       [  6708,   5415,   6707,   5470],
       ...,
       [124327, 124339, 124340, 206156],
       [157824, 157682, 157832, 217395],
       [260840, 271149, 272358, 276223]], shape=(916112, 4))}


In [8]:
grid = tg.grid
grid.plot(show_edges=True)
cells = grid.cells.reshape(-1, 5)[:, 1:]
cell_center = grid.points[cells].mean(1)

mask = cell_center[:, 1] < 0.5
cell_ind = mask.nonzero()[0]
subgrid = grid.extract_cells(cell_ind)

# advanced plotting
plotter = pv.Plotter()
plotter.add_mesh(subgrid, 'lightgrey', lighting=True, show_edges=True)
plotter.add_mesh(tg.mesh, 'red', 'wireframe', opacity=0.2) # style = 'wireframe'
plotter.add_legend([[' Input Mesh ', 'r'],
                    [' tetrahedralize mesh ', 'black']])
plotter.show()

Widget(value='<iframe src="http://localhost:37507/index.html?ui=P_0x7f5662f9a610_0&reconnect=auto" class="pyvi…

Widget(value='<iframe src="http://localhost:37507/index.html?ui=P_0x7f56510afad0_1&reconnect=auto" class="pyvi…

In [9]:
# def create_mesh(mesh, cell_type):
#     points = mesh.points
#     cells = mesh.get_cells_type(cell_type)
#     # This line needs to get cell data from what ever the markers have been called in the vtk file.
#     # Expore this by printing mesh.cell_data
#     # cell_data = mesh.get_cell_data("gmsh:physical", cell_type)
#     out_mesh = meshio.Mesh(points=points, cells={cell_type: cells})
#     #, cell_data={"name_to_read":[cell_data]})
#     return out_mesh

# in_mesh = meshio.read('./meshes/lung-cut/lung_cut_right.msh')
# print(in_mesh.cells_dict)
# print(in_mesh.points)
# print(in_mesh.points.shape)

# out_mesh = create_mesh(in_mesh, "tetra")
# meshio.write('./meshes/lung-cut/lung_cut_right.xdmf', out_mesh)

In [12]:
with io.XDMFFile(MPI.COMM_WORLD,  './final_stitched_tetra.xdmf', "r") as xdmf:
    xdmfmesh = xdmf.read_mesh(name="Grid")
    dim = xdmfmesh.topology.dim
fdim = dim - 1
    
# Scale mesh to [0,1]³ for periodic boundary conditions
x = xdmfmesh.geometry.x
min_coords = np.min(x, axis=0)
max_coords = np.max(x, axis=0)

# Scale to [0,1]³
for i in range(3):
    x[:, i] = (x[:, i] - min_coords[i]) / (max_coords[i] - min_coords[i])

In [15]:
# dolfinx integrate volume to compare
# tet_mesh = dmesh.create_mesh(MPI.COMM_WORLD, elems_tet, nodes_tet, dmesh.CellType.tetrahedron)

V = fem.functionspace(xdmfmesh, ("Lagrange", 1))
one = fem.Constant(xdmfmesh, PETSc.ScalarType(1.0))
# u = fem.Function(V)
# u.interpolate(one)
volume = fem.assemble_scalar(fem.form(one*ufl.dx))
print(f"tet mesh volume: {volume}")

tet mesh volume: 0.22832954898885044


In [None]:
# solve the poisson equation