In [None]:
import numpy as np
import sys
import trimesh
import tetgen 
import meshio 
import math

import matplotlib.pyplot as plt
import pymeshfix as pfix
import pyvista as pv
import iso2mesh as i2m
import pygalmesh as pygm
import pymeshlab as ml

from skimage.filters import threshold_otsu
from tqdm.notebook import tqdm
from skimage.transform import resize
from typing import Any, Dict, Optional, Tuple, List

sys.path.append('./src')  # to import alveoRVE from parent directory

from alveoRVE.plot.mpl import show_four_panel_volume
from alveoRVE.plot.pv import view_surface

%load_ext autoreload
%autoreload 2

In [None]:
def matlab_to_python_conv(no, fc): 
    """Convert MATLAB-style node/face arrays to Python conventions.

    Parameters
    ----------
    no : array whose first three columns are kept as vertex coordinates.
    fc : face array with 1-based vertex indices; a 1-D input is promoted to a
         single row. Only the first three columns are kept.

    Returns
    -------
    (nodes, faces) : a copy of the first three node columns, and the
        corresponding 0-based faces as int32.
    """
    faces_2d = np.atleast_2d(fc)
    # Go through int64 first so the shift to 0-based indexing happens on integers.
    one_based = faces_2d.astype(np.int64)[:, :3]
    zero_based = (one_based - 1).astype(np.int32)
    coords = no[:, :3].copy()
    return coords, zero_based

In [None]:
def _triangle_quality(V, F):
    """Per-triangle shape quality in [0, 1].

    Uses q = 4*sqrt(3)*A / (l0^2 + l1^2 + l2^2), where A is the triangle area
    and l0..l2 are its edge lengths. An equilateral triangle gives exactly 1
    and a degenerate (zero-area) triangle gives 0.

    Parameters
    ----------
    V : (N, 3) array-like of vertex coordinates.
    F : (M, 3) integer array-like of vertex indices per triangle.

    Returns
    -------
    (M,) float array of qualities, clipped to [0, 1]. Triangles whose ratio is
    not finite (e.g. all three corners coincide) are reported as 0.
    """
    V = np.asarray(V, dtype=float)
    F = np.asarray(F)
    p0, p1, p2 = V[F[:, 0]], V[F[:, 1]], V[F[:, 2]]
    e0 = np.linalg.norm(p1 - p0, axis=1)
    e1 = np.linalg.norm(p2 - p1, axis=1)
    e2 = np.linalg.norm(p0 - p2, axis=1)
    area = 0.5 * np.linalg.norm(np.cross(p1 - p0, p2 - p0), axis=1)
    sum_sq = e0**2 + e1**2 + e2**2
    # The normalisation constant must be 4*sqrt(3): for an equilateral triangle
    # A = sqrt(3)/4 * l^2 and sum_sq = 3*l^2, so 4*sqrt(3)*A/sum_sq == 1.
    # (A factor of 2*sqrt(3) would cap the measure at 0.5.)
    with np.errstate(divide='ignore', invalid='ignore'):
        q = (4.0 * math.sqrt(3.0)) * area / sum_sq
    q[~np.isfinite(q)] = 0.0
    return np.clip(q, 0.0, 1.0)


def mesh_report(mesh: trimesh.Trimesh, name="mesh", plot=True, tol_dup=1e-12) -> dict:
    """Print a detailed diagnostic report for a surface mesh and return the key numbers.

    The input mesh is copied first and never modified. The report has three
    parts: trimesh checks (winding, watertightness, bounds, area/volume, edge
    lengths, duplicate faces/vertices), a selection of pymeshlab geometric and
    topological measures, and triangle-quality statistics.

    Parameters
    ----------
    mesh    : mesh to inspect.
    name    : label used in the printed header and the histogram title.
    plot    : when true, show a histogram of the triangle quality.
    tol_dup : grid step used to quantize vertex coordinates when counting
              duplicate vertices.

    Returns
    -------
    dict with keys verts, faces, area, volume (None when not watertight), bbox,
    h_stats (min, mean, max edge length), watertight, dup_faces, dup_verts and
    tri_quality.
    """
    work = mesh.copy()
    n_verts, n_faces = len(work.vertices), len(work.faces)

    # ---- trimesh-based checks -------------------------------------------
    print(f"\n=== {name}: geometric/graph checks START ===")
    print("trimesh metrics:")
    print(f"mesh is_winding_consistent? {work.is_winding_consistent}")
    print(f"mesh is_watertight? {work.is_watertight}")
    print(f"verts={n_verts:,}, faces={n_faces:,}")
    bounds = work.bounds
    lo_corner, hi_corner = bounds[0], bounds[1]
    print(f"bbox min={lo_corner}, max={hi_corner}, extents={hi_corner - lo_corner}")

    # Area is always defined; the enclosed volume is only reported for watertight meshes.
    area = float(work.area)
    volume = float(work.volume) if work.is_watertight else None
    print(f"surface area = {area:.6g}")
    if volume is None:
        print("enclosed volume = N/A (mesh not watertight)")
    else:
        print(f"enclosed volume (watertight) = {volume:.6g}")

    # Unique-edge lengths serve as a proxy for the mesh size 'h'.
    edge_ids = work.edges_unique
    edge_vec = work.vertices[edge_ids[:, 0]] - work.vertices[edge_ids[:, 1]]
    edge_len = np.linalg.norm(edge_vec, axis=1)
    print(f"h (edge length): min={edge_len.min():.4g}, mean={edge_len.mean():.4g}, max={edge_len.max():.4g}")

    # Duplicate faces: same vertex set regardless of ordering.
    # Duplicate vertices: same coordinates after snapping to a tol_dup grid.
    unique_faces = np.unique(np.sort(work.faces, axis=1), axis=0)
    dup_faces = n_faces - len(unique_faces)
    snapped = np.round(work.vertices / tol_dup).astype(np.int64)
    dup_verts = n_verts - len(np.unique(snapped, axis=0))
    print(f"duplicate faces: {dup_faces:,}, duplicate vertices (<= {tol_dup:g}): {dup_verts:,}")

    # ---- pymeshlab measures ---------------------------------------------
    print("== pymeshlab metrics:")
    mesh_set = ml.MeshSet()
    mesh_set.add_mesh(ml.Mesh(work.vertices.copy(), work.faces.copy()))
    geo = mesh_set.get_geometric_measures()
    topo = mesh_set.get_topological_measures()

    geo_keys = ['surface_area', 'avg_edge_length', 'volume']
    topo_keys = ['non_two_manifold_edges', 'boundary_edges', 'non_two_manifold_vertices', 'genus', 'faces_number', 'vertices_number', 'edges_number', 'connected_components_number']

    print("Geometric measures:")
    for key, val in geo.items():
        if key not in geo_keys:
            continue
        print(f"  {key}: {float(val)}")
    print("Topological measures:")
    for key, val in topo.items():
        if key not in topo_keys:
            continue
        print(f"  {key}: {int(val)}")

    # ---- triangle quality -----------------------------------------------
    print("== custom metrics:")
    q = _triangle_quality(work.vertices, work.faces)
    q_stats = {
        'min': float(q.min()),
        'p5': float(np.percentile(q, 5)),
        'mean': float(q.mean()),
        'p95': float(np.percentile(q, 95)),
        'max': float(q.max()),
    }
    print(f"triangle quality q in [0,1] (equilateral=1): "
          f"min={q_stats['min']:.3f}, p5={q_stats['p5']:.3f}, "
          f"mean={q_stats['mean']:.3f}, p95={q_stats['p95']:.3f}, max={q_stats['max']:.3f}")

    if plot:
        plt.figure(figsize=(5,3))
        plt.hist(q, bins=40, range=(0,1), alpha=0.8)
        plt.xlabel("triangle quality q")
        plt.ylabel("count")
        plt.title(f"Quality histogram: {name}")
        plt.tight_layout()
        plt.show()

    report = {
        'verts': n_verts,
        'faces': n_faces,
        'area': area,
        'volume': volume,
        'bbox': bounds,
        'h_stats': (float(edge_len.min()), float(edge_len.mean()), float(edge_len.max())),
        'watertight': bool(work.is_watertight),
        'dup_faces': int(dup_faces),
        'dup_verts': int(dup_verts),
        'tri_quality': q_stats,
    }
    return report

def quick_mesh_report(ms: ml.MeshSet | trimesh.Trimesh, i: int = 0):
    """Print a compact three-line health summary of a mesh.

    Accepts either a pymeshlab MeshSet (its current mesh is inspected) or a
    trimesh.Trimesh (wrapped in a temporary MeshSet; the caller's object is
    not modified). `i` is only a label echoed in each printed line.

    The summary covers vertex/face counts, watertightness, genus, winding
    consistency, average edge length, component count, volume (only when
    watertight), area, non-manifold edge/vertex counts and the bounding box.
    Returns None.
    """
    if isinstance(ms, trimesh.Trimesh):
        # Keep the caller's vertices for the bbox line; build a throwaway MeshSet for the measures.
        V = ms.vertices
        tri_faces = ms.faces
        ms = ml.MeshSet()
        ms.add_mesh(ml.Mesh(V, tri_faces))
    else:
        V = ms.current_mesh().vertex_matrix()

    n_verts = ms.current_mesh().vertex_number()
    n_faces = ms.current_mesh().face_number()
    geo = ms.get_geometric_measures()
    topo = ms.get_topological_measures()
    h = geo['avg_edge_length']
    n_components = topo['connected_components_number']

    # Watertightness, winding, area and volume come from trimesh (no processing/merging applied).
    probe = trimesh.Trimesh(
        vertices=np.asarray(ms.current_mesh().vertex_matrix(), float),
        faces=np.asarray(ms.current_mesh().face_matrix(), int),
        process=False
    )
    is_watertight = probe.is_watertight
    is_winding_consistent = probe.is_winding_consistent
    vol = probe.volume if is_watertight else None
    area = probe.area

    # Non-manifold counts come from pymeshlab's topological measures.
    bad_edges = int(topo['non_two_manifold_edges'])
    bad_verts = int(topo['non_two_manifold_vertices'])

    line_1 = (f"[quick {i} 1/3] {n_verts} verts, {n_faces} faces, watertight: {is_watertight}, "
              f"genus: {topo['genus']}, wind-consistent: {is_winding_consistent}, "
              f"h = {np.round(h, 3)}, components: {n_components}")
    line_2 = (f"[quick {i} 2/3] vol = {vol}, area = {area}, "
              f"non-manifold edges: {bad_edges}/ vertices: {bad_verts}")
    line_3 = f"[quick {i} 3/3] bbox: {V.min(axis=0)} to {V.max(axis=0)}"
    print(line_1 + "\n" + line_2 + "\n" + line_3)

In [None]:
def heal(ms: ml.MeshSet, manifold_method = 0):
    """Run a fixed sequence of pymeshlab cleanup filters on `ms`, in place.

    The sequence removes duplicate vertices/faces, null faces and unreferenced
    vertices, then repairs non-manifold edges and vertices. The edge repair is
    run twice: first with `manifold_method`, then with `1 - manifold_method`,
    each followed by another null-face / unreferenced-vertex sweep. After every
    filter a `quick_mesh_report` is printed, labelled with the 1-based step
    number.

    Finally, if the signed volume computed by trimesh is negative, the face
    orientation is inverted.

    Parameters
    ----------
    ms              : pymeshlab MeshSet whose current mesh is repaired in place.
    manifold_method : `method` argument for the first non-manifold edge repair
                      (the second pass uses 1 - manifold_method).

    Returns None.
    """
    # (banner, MeshSet filter name, keyword arguments) for each step, in execution order.
    pipeline = [
        ("[REMOVING DUPLICATE VERTICES]", "meshing_remove_duplicate_vertices", {}),
        ("[REMOVING DUPLICATE FACES]", "meshing_remove_duplicate_faces", {}),
        ("[REMOVING NULL FACES]", "meshing_remove_null_faces", {}),
        ("[REMOVING UNREFERENCED VERTICES]", "meshing_remove_unreferenced_vertices", {}),
        (f"[REPAIRING NON-MANIFOLD EDGES with manifold_method = {manifold_method}]",
         "meshing_repair_non_manifold_edges", {"method": manifold_method}),
        ("[REPAIRING NON-MANIFOLD VERTICES]", "meshing_repair_non_manifold_vertices", {}),
        ("[REMOVING NULL FACES AGAIN]", "meshing_remove_null_faces", {}),
        ("[REMOVING UNREFERENCED VERTICES AGAIN]", "meshing_remove_unreferenced_vertices", {}),
        (f"[REPAIRING NON-MANIFOLD EDGES AGAINx2 with manifold_method = {1-manifold_method}]",
         "meshing_repair_non_manifold_edges", {"method": 1 - manifold_method}),
        ("[REPAIRING NON-MANIFOLD VERTICES AGAINx2]", "meshing_repair_non_manifold_vertices", {}),
        ("[REMOVING NULL FACES AGAINx2]", "meshing_remove_null_faces", {}),
        ("[REMOVING UNREFERENCED VERTICES AGAINx2]", "meshing_remove_unreferenced_vertices", {}),
    ]

    # The step counter really advances now (the old `i+1;` statement discarded
    # its result, so every report was labelled 0).
    for step_no, (banner, filter_name, params) in enumerate(pipeline, start=1):
        print(banner)
        getattr(ms, filter_name)(**params)
        quick_mesh_report(ms, step_no)

    # A negative signed volume means the faces point inwards; flip them.
    signed_volume = trimesh.Trimesh(
        vertices=np.asarray(ms.current_mesh().vertex_matrix(), float),
        faces=np.asarray(ms.current_mesh().face_matrix(), int),
        process=False
    ).volume
    if signed_volume < 0:
        ms.meshing_invert_face_orientation()

def remove_close_verts(ms: ml.MeshSet, tol=1e-5):
    """Merge nearly coincident vertices of the current mesh in `ms`, in place.

    `tol` is converted to a pymeshlab PercentageValue of `tol*100`.
    NOTE(review): presumably that percentage is relative to the mesh
    bounding-box diagonal, so `tol` is a fraction of it rather than an
    absolute distance -- confirm against the pymeshlab docs.

    Prints the tolerance used and how many vertices were removed. Returns None.
    """
    print(f"Removing close vertices with tol={tol}")
    n_before = ms.current_mesh().vertex_number()
    merge_threshold = ml.PercentageValue(tol*100)
    ms.meshing_merge_close_vertices(threshold=merge_threshold)
    n_after = ms.current_mesh().vertex_number()
    print(f" - removed {n_before - n_after} vertices")

In [None]:
def print_python_or_matlab_indexing(fc, name=""): 
    """Print whether an index array looks 0-based (python) or 1-based (matlab).

    The decision is based only on the smallest entry of `fc`: 0 means python,
    1 means matlab, anything else is reported as unknown. `name` is a label
    echoed in the printed line. Returns None.
    """
    lowest = np.min(fc)
    if lowest == 0:
        convention = "python"
    elif lowest == 1:
        convention = "matlab"
    else:
        convention = "unknown indexing"
    print(f" - {name} ({convention})")

def normalize_vertices_inplace(no):
    """Rescale the first three columns of `no` to the unit box, in place.

    Each of the x, y, z columns is shifted by its minimum and divided by its
    extent (max - min), so every non-degenerate axis ends up spanning [0, 1].
    An axis with zero extent (e.g. a planar point set) is only shifted, so all
    of its values become 0 instead of NaN from a 0/0 division. Columns beyond
    the third are left untouched. An array with no rows is left unchanged.

    Returns None; `no` itself is modified.
    """
    xyz = no[:, :3]
    if xyz.shape[0] == 0:
        return
    lo = xyz.min(axis=0)
    extent = xyz.max(axis=0) - lo
    # Divide by 1 on degenerate axes so they map to 0 rather than producing NaN/inf.
    safe_extent = np.where(extent > 0, extent, 1.0)
    no[:, :3] = (xyz - lo) / safe_extent

def normalize_vertices(no): 
    """Return a copy of `no` with its first three columns rescaled to the unit box.

    Same scaling as `normalize_vertices_inplace`, but the input array is left
    unmodified.
    """
    no_out = no.copy()
    normalize_vertices_inplace(no_out)
    return no_out

def view_wireframe(V: np.ndarray, F: np.ndarray, title="surface"):
    """Open a PyVista window showing the triangle mesh (V, F) as a black wireframe.

    V holds the vertex coordinates and F the triangle vertex indices; `title`
    is passed to the plotter window. Returns None.
    """
    # PyVista expects a flat cell array where every triangle is prefixed by its size (3).
    size_prefix = np.full((len(F), 1), 3)
    vtk_faces = np.concatenate([size_prefix, F], axis=1).ravel()
    surface = pv.PolyData(V, vtk_faces)

    plotter = pv.Plotter()
    plotter.add_mesh(surface, show_edges=True, color='black', style='wireframe')
    plotter.add_axes()
    plotter.show(title=title)

def view_cropped(V, F, title, point_on_plane=None, normal=None, mode="wireframe"):
    """Show the triangle mesh (V, F) clipped by a plane in a PyVista window.

    Parameters
    ----------
    V              : (N, 3) vertex coordinates.
    F              : (M, 3) triangle vertex indices.
    title          : window title.
    point_on_plane : a point on the clipping plane; defaults to the vertex
                     centroid, so that the default cut passes through the
                     middle of the model.
    normal         : clipping-plane normal; defaults to +z.
    mode           : "wireframe" draws a wireframe, any other value draws a
                     shaded surface (edges are shown in both cases).

    Returns None.
    """
    V = np.asarray(V)
    F = np.asarray(F)
    if point_on_plane is None:
        # Centroid = mean over vertices (axis 0). Averaging over axis 1 would
        # give one number per vertex, which is not a valid 3-D plane origin.
        point_on_plane = V.mean(axis=0)
    if normal is None:
        normal = np.array([0,0,1])
    # PyVista cell array: each triangle prefixed by its vertex count.
    faces = np.hstack([np.full((len(F),1),3), F]).ravel()
    mesh = pv.PolyData(V, faces)
    clipped = mesh.clip(normal=normal, origin=point_on_plane, invert=False)
    p = pv.Plotter()
    p.add_mesh(clipped, show_edges=True, style="wireframe" if mode=="wireframe" else "surface")
    p.add_axes()
    p.show(title=title)

In [None]:
# Load the de-duplicated STL surface and extract its points and triangle connectivity.
mesh = meshio.read('./output_after_dedup.stl')
nodes = mesh.points
cells_dict = mesh.cells_dict
print(f"cells_dict: {cells_dict}")
elems = mesh.cells_dict['triangle']
print(nodes.shape)
print(elems.shape)

# Baseline report on the raw surface.
# NOTE(review): `m` is rebound after the pymeshfix step below without being read in between.
m = trimesh.Trimesh(nodes, elems)
ms = ml.MeshSet()
ms.add_mesh(ml.Mesh(nodes, elems))
print("[ORIGINAL]")
quick_mesh_report(ms, i=0)
print_python_or_matlab_indexing(elems, "elems")

# pymeshfix cleaning of the raw arrays
print("[PYMESHFIX CLEANING]")
clean_nodes, clean_elems = pfix.clean_from_arrays(nodes, elems)
m = trimesh.Trimesh(clean_nodes, clean_elems)
quick_mesh_report(m, i=1)
print_python_or_matlab_indexing(clean_elems, "clean_elems")

# Flip the face orientation, then run the pymeshlab healing sequence
# (heal() itself flips the orientation again if the signed volume ends up negative).
ms = ml.MeshSet()
ms.add_mesh(ml.Mesh(clean_nodes, clean_elems))
ms.meshing_invert_face_orientation()
quick_mesh_report(ms, i=0)
heal(ms)

# Smooth the healed surface with trimesh's mutable-diffusion Laplacian filter.
print(f"[SMOOTHING]")
trimesh_mesh = trimesh.Trimesh(ms.current_mesh().vertex_matrix(), ms.current_mesh().face_matrix())
smoothed_mesh = trimesh.smoothing.filter_mut_dif_laplacian(trimesh_mesh, lamb=0.5, iterations=10)
quick_mesh_report(smoothed_mesh, i=0)

# Second pymeshfix pass on the smoothed surface; the temporary MeshSet is only used for reporting.
print(f"[PYMESHFIX CLEANING AGAIN]")
clean_nodes2, clean_elems2 = pfix.clean_from_arrays(smoothed_mesh.vertices, smoothed_mesh.faces)
ms0 = ml.MeshSet()
ms0.add_mesh(ml.Mesh(clean_nodes2, clean_elems2))
quick_mesh_report(ms0, i=0)
del ms0
print_python_or_matlab_indexing(clean_elems2, "clean_elems2")

# Rescale the cleaned nodes to the unit box (modifies clean_nodes2 in place) and report again.
print(f"[NORMALIZING]")
normalize_vertices_inplace(clean_nodes2)
ms0 = ml.MeshSet()
ms0.add_mesh(ml.Mesh(clean_nodes2, clean_elems2))
quick_mesh_report(ms0, i=0)

In [None]:
# Margin trimmed from each side of the unit box before capping.
d = 0.01

# Six axis-aligned cutting planes, each given by a point on the plane and a normal.
# The "min" planes sit at coordinate d with a +axis normal, the "max" planes at 1-d with a -axis normal.
x_plane_min_normal = (1,0,0)
x_plane_min_point = (d,0,0)
x_plane_max_normal = (-1,0,0)
x_plane_max_point = (1-d,0,0)

y_plane_min_normal = (0,1,0)
y_plane_min_point = (0,d,0)
y_plane_max_normal = (0,-1,0)
y_plane_max_point = (0,1-d,0)   

z_plane_min_normal = (0,0,1)
z_plane_min_point = (0,0,d)
z_plane_max_normal = (0,0,-1)
z_plane_max_point = (0,0,1-d)

normalized_mesh = trimesh.Trimesh(clean_nodes2, clean_elems2)

# Triangulation engine passed to slice_plane for the caps; alternatives noted by the author: "manifold", "earcut".
engine = "triangle" # "triangle" # "manifold", "earcut"

# Consecutive slicing and capping: each step cuts the result of the previous one
# (x min, x max, y min, y max, z min, z max) and prints the resulting vertex array shape.
# NOTE(review): presumably slice_plane keeps the side the normal points towards -- confirm in the trimesh docs.
normalized_mesh_sliced_x_min = normalized_mesh.slice_plane(x_plane_min_point, x_plane_min_normal, cap=True, engine=engine)
print(f"new vertices after x min slice: {normalized_mesh_sliced_x_min.vertices.shape}")
normalized_mesh_sliced_x_max = normalized_mesh_sliced_x_min.slice_plane(x_plane_max_point, x_plane_max_normal, cap=True, engine=engine)
print(f"new vertices after x max slice: {normalized_mesh_sliced_x_max.vertices.shape}")
normalized_mesh_sliced_y_min = normalized_mesh_sliced_x_max.slice_plane(y_plane_min_point, y_plane_min_normal, cap=True, engine=engine)
print(f"new vertices after y min slice: {normalized_mesh_sliced_y_min.vertices.shape}")
normalized_mesh_sliced_y_max = normalized_mesh_sliced_y_min.slice_plane(y_plane_max_point, y_plane_max_normal, cap=True, engine=engine)
print(f"new vertices after y max slice: {normalized_mesh_sliced_y_max.vertices.shape}")
normalized_mesh_sliced_z_min = normalized_mesh_sliced_y_max.slice_plane(z_plane_min_point, z_plane_min_normal, cap=True, engine=engine)
print(f"new vertices after z min slice: {normalized_mesh_sliced_z_min.vertices.shape}")
normalized_mesh_sliced_all = normalized_mesh_sliced_z_min.slice_plane(z_plane_max_point, z_plane_max_normal, cap=True, engine=engine)
print(f"new vertices after z max slice: {normalized_mesh_sliced_all.vertices.shape}")

# Rescale the trimmed mesh back to the unit box (the vertex array is modified in place).
print(f"[NORMALIZING]")
normalize_vertices_inplace(normalized_mesh_sliced_all.vertices)

# mesh_report(normalized_mesh_sliced_all, name="after slicing and capping", plot=True)
# Report on the sliced-and-capped mesh through a temporary MeshSet.
ms0 = ml.MeshSet()
ms0.add_mesh(ml.Mesh(normalized_mesh_sliced_all.vertices, normalized_mesh_sliced_all.faces))
quick_mesh_report(ms0, i=0)
del ms0

# view_surface(normalized_mesh_sliced_all.vertices, normalized_mesh_sliced_all.faces, title="normalized")

# Isotropic remeshing of the capped mesh, followed by the healing sequence.
# NOTE(review): targetlen is given as PercentageValue(h_cap*100); presumably this is a percentage
# of the bounding-box diagonal, so the resulting edge length may not equal h_cap -- confirm.
h_cap = 0.01
iterations = 6  
ms = ml.MeshSet()
ms.add_mesh(ml.Mesh(normalized_mesh_sliced_all.vertices, normalized_mesh_sliced_all.faces))
print(f"[ISOTROPIC REMESHING to target h={h_cap}]")
ms.meshing_isotropic_explicit_remeshing(targetlen=ml.PercentageValue(h_cap*100),
                                        iterations=iterations,
                                        adaptive=False,
                                        reprojectflag=True
                                        )

heal(ms)

# Trimesh copy of the remeshed, healed surface; this is the base cell that later cells mirror.
trimesh_remeshed_full = trimesh.Trimesh(np.asarray(ms.current_mesh().vertex_matrix()), np.asarray(ms.current_mesh().face_matrix()))

# view_surface(np.asarray(ms.current_mesh().vertex_matrix()), np.asarray(ms.current_mesh().face_matrix()), title="after remeshing")
# stats1 = mesh_report(trimesh_remeshed_full, name="after remeshing", plot=True)

In [None]:
from dataclasses import dataclass

@dataclass
class InterfaceMatchResult:
    """Result record produced by `compare_interface_nodes` for one interface plane."""
    # Interface normal axis label ('x', 'y' or 'z') and the plane coordinate that was tested.
    axis: str
    plane_value: float
    # Absolute tolerance used to decide whether a vertex lies on the plane.
    tol: float
    # Quantization step applied to the two tangential coordinates when building match keys.
    quant: float
    # Number of on-plane vertices found in mesh A and mesh B.
    nA_plane: int
    nB_plane: int
    # Number of quantized keys present in both meshes, only in A, and only in B.
    n_common: int
    n_only_A: int
    n_only_B: int
    # Max / mean 3-D distance between the representative A and B vertices of each common key.
    max_pos_error: float
    mean_pos_error: float
    # Maps each common quantized key to the pair (index in A, index in B).
    mapping_B_to_A: dict   # key -> (iA, iB)
    # Sorted keys that appear in only one of the two meshes.
    unmatched_keys_A: list
    unmatched_keys_B: list

def compare_interface_nodes(
    VA: np.ndarray,
    VB: np.ndarray,
    axis: str = 'x',
    plane_value: float = 1.0,
    tol: float = 1e-9,
    quant: float | None = None,
    verbose: bool = True,
) -> InterfaceMatchResult:
    """Check whether two meshes have coinciding vertices on a shared axis-aligned plane.

    Vertices of each mesh whose `axis` coordinate is within `tol` of
    `plane_value` are selected, their two tangential coordinates are rounded
    to a grid of step `quant`, and vertices are matched between the meshes by
    identical grid keys.

    Parameters
    ----------
    VA, VB      : (N, 3) float vertex arrays of the two meshes.
    axis        : interface normal axis: 'x', 'y' or 'z'.
    plane_value : coordinate of the interface plane along `axis`.
    tol         : absolute tolerance for a vertex to count as lying on the plane.
    quant       : tangential grid step; when None it is set to 1e-6 times the
                  smallest tangential extent of the on-plane vertices (never
                  below 1e-12).
    verbose     : print a summary of the comparison.

    Returns
    -------
    InterfaceMatchResult with the counts, position errors and the key -> (iA, iB)
    mapping. When several vertices share a key, only the first one of each mesh
    is used as the representative.
    """
    normal_idx = {'x':0,'y':1,'z':2}[axis]
    tangent_idx = [c for c in range(3) if c != normal_idx]

    # Indices of the vertices lying on the plane in each mesh.
    on_A = np.flatnonzero(np.isclose(VA[:, normal_idx], plane_value, atol=tol))
    on_B = np.flatnonzero(np.isclose(VB[:, normal_idx], plane_value, atol=tol))

    if quant is None:
        # Derive the grid step from the smaller tangential extent seen in either mesh.
        extents = [
            min(np.ptp(VA[on_A, c]) if on_A.size else 1.0,
                np.ptp(VB[on_B, c]) if on_B.size else 1.0)
            for c in tangent_idx
        ]
        smallest = min(extents) if extents else 1.0
        quant = max(smallest * 1e-6, 1e-12)

    def _bucketize(points, indices):
        """Group vertex indices by their quantized tangential coordinates."""
        buckets = {}
        for idx in indices:
            grid_key = tuple(np.round(points[idx][tangent_idx]/quant).astype(np.int64))
            buckets.setdefault(grid_key, []).append(idx)
        return buckets

    buckets_A = _bucketize(VA, on_A)
    buckets_B = _bucketize(VB, on_B)

    keys_A, keys_B = set(buckets_A), set(buckets_B)
    shared = keys_A & keys_B
    only_in_A = sorted(keys_A - keys_B)
    only_in_B = sorted(keys_B - keys_A)

    # One representative pair per shared key; the error is the full 3-D distance.
    pairs = {}
    distances = []
    for grid_key in shared:
        rep_A = buckets_A[grid_key][0]
        rep_B = buckets_B[grid_key][0]
        distances.append(np.linalg.norm(VA[rep_A] - VB[rep_B]))
        pairs[grid_key] = (rep_A, rep_B)

    if distances:
        max_err = float(np.max(distances))
        mean_err = float(np.mean(distances))
    else:
        max_err = 0.0
        mean_err = 0.0

    result = InterfaceMatchResult(
        axis=axis,
        plane_value=plane_value,
        tol=tol,
        quant=quant,
        nA_plane=len(on_A),
        nB_plane=len(on_B),
        n_common=len(shared),
        n_only_A=len(only_in_A),
        n_only_B=len(only_in_B),
        max_pos_error=max_err,
        mean_pos_error=mean_err,
        mapping_B_to_A=pairs,
        unmatched_keys_A=only_in_A,
        unmatched_keys_B=only_in_B
    )

    if not verbose:
        return result

    print(f"[interface check] axis={axis} plane={plane_value} tol={tol} quant={quant}")
    print(f"  on-plane counts: A={len(on_A)} B={len(on_B)}")
    print(f"  common keys: {len(shared)}  onlyA={len(only_in_A)}  onlyB={len(only_in_B)}")
    print(f"  position error (A vs B matched reps): max={max_err:.3e} mean={mean_err:.3e}")
    if only_in_A:
        print(f"  sample only-in-A key: {only_in_A[0]}")
    if only_in_B:
        print(f"  sample only-in-B key: {only_in_B[0]}")
    if len(shared) and max_err > 10*tol:
        print("  WARNING: large mismatch vs tol. Consider snapping.")

    return result

In [None]:
# Mirror the remeshed cell in x about the plane x = 1 (x -> 2 - x), heal the mirrored copy,
# and report on it.

normalized_mesh_reflected = trimesh_remeshed_full.copy()
normalized_mesh_reflected.vertices[:,0] = 2 - trimesh_remeshed_full.vertices[:,0]
ms_reflected = ml.MeshSet()
ms_reflected.add_mesh(ml.Mesh(normalized_mesh_reflected.vertices, normalized_mesh_reflected
.faces))
heal(ms_reflected)
quick_mesh_report(ms_reflected, i=0)
# Rebuild the trimesh object from the healed arrays without any trimesh post-processing.
normalized_mesh_reflected = trimesh.Trimesh(np.asarray(ms_reflected.current_mesh().vertex_matrix()), np.asarray(ms_reflected.current_mesh().face_matrix()), process=False)

In [None]:
# view_cropped(trimesh_remeshed_full.vertices, trimesh_remeshed_full.faces, title="original", point_on_plane=(0.5,0.5,0.5), normal=(1,0,0), mode="surface")
# view_cropped(normalized_mesh_reflected.vertices, normalized_mesh_reflected.faces, title="reflected", point_on_plane=(0.5,0.5,0.5), normal=(1,0,0), mode="surface")

In [None]:
import numpy as np
import pyvista as pv
import trimesh

def _to_pv_polydata(tm: trimesh.Trimesh) -> pv.PolyData:
    """Convert a triangular Trimesh into a PyVista PolyData.

    Raises ValueError when the face array is not of shape (N, 3).
    """
    points = np.asarray(tm.vertices, dtype=float)
    tris = np.asarray(tm.faces, dtype=np.int64)
    if not (tris.ndim == 2 and tris.shape[1] == 3):
        raise ValueError("Expected triangular faces Nx3")
    # VTK cell array layout: [3, i0, i1, i2, 3, j0, j1, j2, ...]
    size_prefix = np.full((tris.shape[0], 1), 3, dtype=np.int64)
    vtk_cells = np.hstack([size_prefix, tris]).ravel()
    return pv.PolyData(points, vtk_cells)

def _from_pv_polydata(pd: pv.PolyData) -> trimesh.Trimesh:
    """Convert a PyVista PolyData into a Trimesh of triangles.

    Non-triangular input is triangulated first. A PolyData without faces
    yields a Trimesh with an empty (0, 3) face array. The Trimesh is built
    with process=False.
    """
    if not pd.is_all_triangles:
        pd = pd.triangulate()

    points = np.asarray(pd.points, dtype=float)
    flat_cells = np.asarray(pd.faces, dtype=np.int64)
    if flat_cells.size == 0:
        tris = np.empty((0,3), int)
    else:
        # Drop the leading "3" of every VTK cell record to get plain (N, 3) indices.
        tris = flat_cells.reshape(-1, 4)[:, 1:].astype(np.int32, copy=False)
    return trimesh.Trimesh(vertices=points, faces=tris, process=False)

def vtk_weld_two(triA: trimesh.Trimesh, triB: trimesh.Trimesh, tol: float = 5e-9) -> trimesh.Trimesh:
    """Join two touching shells into one mesh, welding seam vertices closer than `tol`.

    Both meshes are converted to PyVista, appended, cleaned with an absolute
    point-merging tolerance of `tol`, triangulated and converted back to
    trimesh. Vertex/face counts are printed after every stage.
    """
    poly_a = _to_pv_polydata(triA)
    poly_b = _to_pv_polydata(triB)
    print(f"[pv] A: V={poly_a.n_points:,} F={poly_a.n_cells:,} | B: V={poly_b.n_points:,} F={poly_b.n_cells:,}")

    # `+` on PolyData appends the two datasets.
    combined = poly_a + poly_b
    print(f"[pv] appended: V={combined.n_points:,} F={combined.n_cells:,}")

    # Merge coincident points using an absolute (not relative) tolerance.
    welded = combined.clean(tolerance=tol, absolute=True, point_merging=True)
    print(f"[pv] after clean: V={welded.n_points:,} F={welded.n_cells:,}")

    # NOTE(review): the message below says "triangulate+dedup", but only triangulate() is called here.
    welded = welded.triangulate()
    print(f"[pv] after triangulate+dedup: V={welded.n_points:,} F={welded.n_cells:,}")

    merged = _from_pv_polydata(welded)
    print(f"[pv] back to trimesh: V={len(merged.vertices):,} F={len(merged.faces):,} | watertight={merged.is_watertight}")
    return merged


In [None]:
# Weld the base cell with its x-mirrored copy and report on the result.
tri_welded = vtk_weld_two(trimesh_remeshed_full, normalized_mesh_reflected, tol=1e-8)
quick_mesh_report(tri_welded, i=0)  # compact health summary of the welded mesh
# view_surface(tri_welded.vertices, tri_welded.faces, title="VTK welded")  # if you want to inspect

# pymeshfix pass on the welded mesh; `m` becomes the working mesh for the following cells.
print(f"[PYMESHFIX CLEANING WELDED]")
clean_nodes3, clean_elems3 = pfix.clean_from_arrays(tri_welded.vertices, tri_welded.faces)
m = trimesh.Trimesh(clean_nodes3, clean_elems3)
quick_mesh_report(m, i=1)

In [None]:
# Visual check: shaded surface of the current working mesh `m`.
view_surface(m.vertices, m.faces, title="after final pymeshfix")

In [None]:
# Visual check: wireframe of the current working mesh `m`.
view_wireframe(m.vertices, m.faces, title="after final pymeshfix wireframe")

In [None]:
# Visual check: `m` clipped by the plane through (0.5, 0.5, 0.5) with normal +y, drawn as a surface.
view_cropped(m.vertices, m.faces, title="after final pymeshfix cropped", point_on_plane=(0.5,0.5,0.5), normal=(0,1,0), mode="surface")

In [None]:
# Compact health summary of the current working mesh `m`.
quick_mesh_report(m, i=2)

In [None]:
# Mirror the current working mesh `m` in y about the plane y = 1 (y -> 2 - y), then heal the copy.

normalized_mesh_reflected_y = m.copy()
normalized_mesh_reflected_y.vertices[:,1] = 2 - m.vertices[:,1]
ms_reflected_y = ml.MeshSet()
ms_reflected_y.add_mesh(ml.Mesh(normalized_mesh_reflected_y.vertices, normalized_mesh_reflected_y.faces))
heal(ms_reflected_y)
# Rebuild the trimesh object from the healed arrays without any trimesh post-processing.
normalized_mesh_reflected_y = trimesh.Trimesh(np.asarray(ms_reflected_y.current_mesh().vertex_matrix()), np.asarray(ms_reflected_y.current_mesh().face_matrix()), process=False)
print(f"[REFLECTED IN Y]")
quick_mesh_report(ms_reflected_y, i=0)

In [None]:
# Weld the working mesh `m` with its y-mirrored copy and report on the result.
tri_welded_y = vtk_weld_two(m, normalized_mesh_reflected_y, tol=1e-8)
quick_mesh_report(tri_welded_y, i=0)  # compact health summary of the welded mesh

# First pymeshfix pass; `m` is rebound to the cleaned result.
print(f"[PYMESHFIX CLEANING WELDED Y]")
clean_nodes4, clean_elems4 = pfix.clean_from_arrays(tri_welded_y.vertices, tri_welded_y.faces)
m = trimesh.Trimesh(clean_nodes4, clean_elems4)
quick_mesh_report(m, i=1)

# Second pymeshfix pass on the output of the first one.
print(f"[PYMESHFIX CLEANING WELDED Y AGAIN]")
clean_nodes4, clean_elems4 = pfix.clean_from_arrays(m.vertices, m.faces)
m = trimesh.Trimesh(clean_nodes4, clean_elems4)
quick_mesh_report(m, i=1)

# Heal a copy of `m` with pymeshlab; the healed mesh is stored in tri_welded_y, `m` itself is not rebound.
ms_y = ml.MeshSet()
ms_y.add_mesh(ml.Mesh(m.vertices, m.faces))
heal(ms_y)
tri_welded_y = trimesh.Trimesh(np.asarray(ms_y.current_mesh().vertex_matrix()), np.asarray(ms_y.current_mesh().face_matrix()), process=False)
quick_mesh_report(tri_welded_y, i=1)
view_surface(tri_welded_y.vertices, tri_welded_y.faces, title="VTK welded y")  # visual inspection of the healed y-welded mesh

In [None]:
# Mirror the current working mesh `m` in z about the plane z = 1 (z -> 2 - z), then heal the copy.
# NOTE(review): this mirrors `m` (the pymeshfix output), not the healed `tri_welded_y`
# built at the end of the y-weld cell -- confirm this is intended.

normalized_mesh_reflected_z = m.copy()
normalized_mesh_reflected_z.vertices[:,2] = 2 - m.vertices[:,2]
ms_reflected_z = ml.MeshSet()
ms_reflected_z.add_mesh(ml.Mesh(normalized_mesh_reflected_z.vertices, normalized_mesh_reflected_z.faces))
heal(ms_reflected_z)
# Rebuild the trimesh object from the healed arrays without any trimesh post-processing.
normalized_mesh_reflected_z = trimesh.Trimesh(np.asarray(ms_reflected_z.current_mesh().vertex_matrix()), np.asarray(ms_reflected_z.current_mesh().face_matrix()), process=False)
print(f"[REFLECTED IN Z]")
quick_mesh_report(ms_reflected_z, i=0)

In [None]:
# Weld the working mesh `m` with its z-mirrored copy and report on the result.
tri_welded_z = vtk_weld_two(m, normalized_mesh_reflected_z, tol=1e-8)
quick_mesh_report(tri_welded_z, i=0)  # compact health summary of the welded mesh

# First pymeshfix pass; `m` is rebound to the cleaned result.
print(f"[PYMESHFIX CLEANING WELDED Z]")
clean_nodes4, clean_elems4 = pfix.clean_from_arrays(tri_welded_z.vertices, tri_welded_z.faces)
m = trimesh.Trimesh(clean_nodes4, clean_elems4)
quick_mesh_report(m, i=1)

# Second pymeshfix pass on the output of the first one.
print(f"[PYMESHFIX CLEANING WELDED Z AGAIN]")
clean_nodes4, clean_elems4 = pfix.clean_from_arrays(m.vertices, m.faces)
m = trimesh.Trimesh(clean_nodes4, clean_elems4)
quick_mesh_report(m, i=1)

# Heal a copy of `m` with pymeshlab; the healed mesh is stored in tri_welded_z.
ms_z = ml.MeshSet()
ms_z.add_mesh(ml.Mesh(m.vertices, m.faces))
heal(ms_z)
tri_welded_z = trimesh.Trimesh(np.asarray(ms_z.current_mesh().vertex_matrix()), np.asarray(ms_z.current_mesh().face_matrix()), process=False)
quick_mesh_report(tri_welded_z, i=1)
view_surface(tri_welded_z.vertices, tri_welded_z.faces, title="VTK welded z")  # visual inspection of the healed z-welded mesh

In [None]:
# Rescale the final welded mesh to the unit box (vertex array modified in place) and report on it.

print(f"[NORMALIZING FINAL]")
normalize_vertices_inplace(tri_welded_z.vertices)
ms_final = ml.MeshSet()
ms_final.add_mesh(ml.Mesh(tri_welded_z.vertices, tri_welded_z.faces))
quick_mesh_report(ms_final, i=0)

# Final vertex and face arrays used by the periodicity and flat-face checks in later cells.
VVV = ms_final.current_mesh().vertex_matrix()
FFF = ms_final.current_mesh().face_matrix()

In [None]:
# --- periodic seam pairing preview (for future MPC) ---
import trimesh, triangle as tr
from scipy.spatial import cKDTree
import pyvista as pv

AXIS_ID = {'x':0,'y':1,'z':2}

def _axis_pair(ax:str):
    """Return (normal_index, tangent_index_1, tangent_index_2) for axis 'x', 'y' or 'z'.

    The two tangential indices are the remaining coordinate indices in
    ascending order. Raises KeyError for any other axis label.
    """
    normal_idx = AXIS_ID[ax]
    first, second = (c for c in range(3) if c != normal_idx)
    return normal_idx, first, second

def preview_periodic_pairs(V:np.ndarray, axis:str, tol:float=1e-6):
    """Pair vertices on the `axis`=0 face with their nearest partner on the `axis`=1 face.

    Vertices within `tol` of coordinate 0 and of coordinate 1 along `axis` are
    selected; each vertex on the 0 face is matched to its nearest neighbour on
    the 1 face in the two tangential coordinates.
    NOTE(review): the search radius passed to the KD-tree is max(tol, 10*tol),
    while the printed message reports `tol` -- confirm which one is intended.

    Prints how many 0-face vertices found a partner and returns two index
    arrays (indices into V): the matched 0-face vertices and their partners.
    """
    normal_idx, t1, t2 = _axis_pair(axis)
    tangential = [t1, t2]

    lo_face = np.flatnonzero(np.isclose(V[:, normal_idx], 0.0, atol=tol))
    hi_face = np.flatnonzero(np.isclose(V[:, normal_idx], 1.0, atol=tol))

    tree = cKDTree(V[hi_face][:, tangential])
    dist, nearest = tree.query(V[lo_face][:, tangential], distance_upper_bound=max(tol, 10*tol))

    # Queries without a neighbour inside the radius come back with an infinite distance.
    found = np.isfinite(dist)
    n_found = int(np.sum(found))
    n_total = len(lo_face)
    print(f"[pairs {axis}] matched {n_found}/{n_total} within tol={tol:g}")
    return lo_face[found], hi_face[nearest[found]]

In [None]:
# Check, for each axis, how many vertices on the 0 face have a partner on the 1 face of the final mesh.
preview_periodic_pairs(VVV, 'x', tol=1e-6)
preview_periodic_pairs(VVV, 'y', tol=1e-6)
preview_periodic_pairs(VVV, 'z', tol=1e-6)

In [None]:
# calculate area of the flat faces

def area_of_flat_faces(V: np.ndarray, F: np.ndarray, axis: str, plane_value: float, tol: float = 1e-9) -> float:
    """Sum the area of all triangles lying entirely on an axis-aligned plane.

    A triangle counts when all three of its vertices have their `axis`
    coordinate within `tol` of `plane_value`. The total (and the number of
    contributing faces) is printed; 0.0 is returned when no face qualifies.
    """
    normal_idx = AXIS_ID[axis]

    # (M, 3) matrix holding the `axis` coordinate of every triangle corner.
    corner_coord = V[F][:, :, normal_idx]
    on_plane = np.isclose(corner_coord, plane_value, atol=tol).all(axis=1)
    corners = V[F[on_plane]]

    if len(corners) == 0:
        print(f"No faces found on the {axis}={plane_value} plane within tol={tol}")
        return 0.0

    a, b, c = corners[:, 0], corners[:, 1], corners[:, 2]
    tri_areas = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
    total_area = float(tri_areas.sum())

    print(f"Total area of faces on the {axis}={plane_value} plane: {total_area:.6g} (from {len(corners)} faces)")
    return total_area

In [None]:
# Area of the flat faces lying on the x=0, y=0 and z=0 planes of the final mesh.
area_x = area_of_flat_faces(VVV, FFF, 'x', 0.0, tol=1e-9)
area_y = area_of_flat_faces(VVV, FFF, 'y', 0.0, tol=1e-9)
area_z = area_of_flat_faces(VVV, FFF, 'z', 0.0, tol=1e-9)

: 