# Triangle Meshes

This notebook demonstrates the application of `circle_bundles` to a synthetic dataset of triangle meshes, constructed by applying random rotations in $SO(3)$ to a fixed template mesh (a star pyramid).


In [None]:
# ============================================================
# Core scientific stack
# ============================================================
import numpy as np
import matplotlib.pyplot as plt


# ============================================================
# circle_bundles core API
# ============================================================
import circle_bundles as cb
import circle_bundles.synthetic as sy
import circle_bundles.viz as vz


In [None]:
from __future__ import annotations

def mesh_vertex_normals(
    X: np.ndarray,
    *,
    n_vertices: int | None = None,
    vertex_dim: int = 3,
    idx: tuple[int, int, int] = (0, 1, 2),
    eps: float = 1e-12,
) -> np.ndarray:
    """
    Compute the oriented unit normal determined by three vertices
    from flattened mesh-vertex data.

    Raises a ValueError if any triple is colinear or degenerate.

    Parameters
    ----------
    X:
        Shape (D,) for one mesh or (N, D) for batch.
    n_vertices:
        Optional expected vertex count.
    vertex_dim:
        Usually 3.
    idx:
        (i, j, k) vertex indices used.
        Orientation follows cross(vj-vi, vk-vi).
    eps:
        Tolerance for detecting degeneracy.

    Returns
    -------
    normals:
        Shape (3,) or (N,3) of unit normals.
    """
    X = np.asarray(X)
    single = (X.ndim == 1)

    if single:
        Xb = X[None, :]
    elif X.ndim == 2:
        Xb = X
    else:
        raise ValueError(f"X must be 1D or 2D. Got shape {X.shape}.")

    N, D = Xb.shape

    if D % vertex_dim != 0:
        raise ValueError(f"D={D} not divisible by vertex_dim={vertex_dim}.")

    nv = D // vertex_dim
    if n_vertices is not None and nv != int(n_vertices):
        raise ValueError(f"Expected {n_vertices} vertices, got {nv}.")

    i, j, k = map(int, idx)
    if not (0 <= i < nv and 0 <= j < nv and 0 <= k < nv):
        raise ValueError(f"indices {idx} out of range for nv={nv}")

    V = Xb.reshape(N, nv, vertex_dim)

    a = V[:, i]
    b = V[:, j]
    c = V[:, k]

    u = b - a
    v = c - a

    n = np.cross(u, v)
    norm = np.linalg.norm(n, axis=1)

    # Strict check
    bad = norm <= eps
    if np.any(bad):
        inds = np.where(bad)[0]
        raise ValueError(
            f"Colinear or degenerate vertex triples encountered at indices: {inds[:10]}"
            + (" ..." if len(inds) > 10 else "")
        )

    normals = n / norm[:, None]

    if single:
        return normals[0]
    return normals


In [None]:
from __future__ import annotations
import os
import numpy as np
import trimesh

def load_obj_as_trimesh(
    obj_path: str,
    *,
    process: bool = False,
    merge_scene: bool = True,
) -> trimesh.Trimesh:
    """
    Load an .obj (and its .mtl if referenced) into a single Trimesh.

    Parameters
    ----------
    obj_path:
        Path to the .obj file (anything accepted by ``os.fspath``).
    process:
        Forwarded to ``trimesh.load``.
    merge_scene:
        If the file loads as a ``trimesh.Scene``, merge its triangle meshes
        into one Trimesh. If False, a Scene raises ValueError.

    Returns
    -------
    trimesh.Trimesh

    Raises
    ------
    ValueError
        If a Scene is loaded with ``merge_scene=False``, or the Scene contains
        no triangle-mesh geometry.
    TypeError
        If ``trimesh.load`` returns any other type.

    Notes
    -----
    - If the OBJ has multiple objects/material groups, trimesh may load a Scene.
    - When merging, every scene-graph node that references a geometry
      contributes a copy of that geometry with the node's world transform
      applied; the copies are then concatenated. Non-triangle geometry
      (paths, point clouds) is skipped because it cannot be concatenated.
    """
    obj_path = os.fspath(obj_path)
    loaded = trimesh.load(obj_path, process=process)

    if isinstance(loaded, trimesh.Trimesh):
        return loaded

    if isinstance(loaded, trimesh.Scene):
        if not merge_scene:
            raise ValueError("OBJ loaded as a Scene. Set merge_scene=True or handle the Scene yourself.")

        meshes = []
        # The scene graph is keyed by *node* name, which need not equal the
        # geometry name; indexing it returns (world transform, geometry name).
        for node_name in loaded.graph.nodes_geometry:
            transform, geom_name = loaded.graph[node_name]
            geom = loaded.geometry.get(geom_name)
            if not isinstance(geom, trimesh.Trimesh):
                continue
            g = geom.copy()
            g.apply_transform(transform)
            meshes.append(g)

        if not meshes:
            raise ValueError(f"No geometry found in OBJ scene: {obj_path}")

        return trimesh.util.concatenate(meshes)

    raise TypeError(f"Unsupported trimesh load type: {type(loaded)}")

from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from typing import Callable, Tuple, List

def _stable_face_colors_from_mesh(mesh) -> np.ndarray:
    """
    Try to pull stable face colors from trimesh, else make a stable palette
    based on template face normals (pose-dependent but stable across rotations).
    Returns RGBA in [0,1], shape (F,4).
    """
    F = mesh.faces.shape[0]

    # Try trimesh-provided colors (from MTL, vertex colors, etc.)
    fc = getattr(getattr(mesh, "visual", None), "face_colors", None)
    if fc is not None and len(fc) == F:
        fc = np.asarray(fc)
        # trimesh usually stores uint8 RGBA
        if fc.dtype == np.uint8:
            return (fc / 255.0).astype(float)
        return fc.astype(float)

    # Fallback: cluster face normals from template mesh
    normals = np.asarray(mesh.face_normals, dtype=float)  # (F,3)

    # Quantize normals to bins to create stable groups
    # (coarser bins = fewer groups; adjust if you want more/less segmentation)
    q = np.round(normals * 8.0) / 8.0
    # map each unique quantized normal to an index
    uniq, inv = np.unique(q, axis=0, return_inverse=True)
    k = uniq.shape[0]

    cmap = plt.get_cmap("tab20") if k <= 20 else plt.get_cmap("hsv")
    colors = np.zeros((F, 4), dtype=float)
    for gi in range(k):
        colors[inv == gi] = cmap(gi / max(1, k - 1))
    return colors

def make_obj_mesh_visualizer(
    template_mesh,
    *,
    edge_color: str = "gray",
    alpha: float = 1.0,
    figsize: Tuple[float, float] = (5.0, 5.0),
    dpi: int = 150,
    elev: float = 10.0,
    azim: float = 20.0,
) -> Callable[[np.ndarray], Figure]:
    """
    Build a renderer for flattened vertex arrays of ``template_mesh``.

    Face connectivity and per-face colors are computed once, here, from the
    template. The returned closure only substitutes new vertex positions, so
    the colors stay on the same faces when the vertices are rotated.

    Parameters
    ----------
    template_mesh:
        Object providing ``faces`` and ``vertices`` (e.g. a trimesh.Trimesh).
    edge_color, alpha:
        Edge color and global opacity of the triangles.
    figsize, dpi:
        Figure size in inches and resolution.
    elev, azim:
        Camera elevation and azimuth passed to ``view_init``.

    Returns
    -------
    vis_func:
        Callable taking a flat array of shape (3*V,), where V is the template's
        vertex count, and returning a new matplotlib Figure. Any other shape
        raises ValueError.
    """
    tri_index = np.asarray(template_mesh.faces, dtype=int)
    n_verts = int(np.asarray(template_mesh.vertices).shape[0])
    expected_shape = (3 * n_verts,)
    palette = _stable_face_colors_from_mesh(template_mesh)  # (F, 4)

    def vis_func(flat_vertices: np.ndarray) -> Figure:
        flat_vertices = np.asarray(flat_vertices, dtype=float)
        if flat_vertices.shape != expected_shape:
            raise ValueError(f"Expected flat_vertices shape {expected_shape}, got {flat_vertices.shape}")

        points = flat_vertices.reshape(n_verts, 3)
        triangles = points[tri_index]

        fig = plt.figure(figsize=figsize, dpi=dpi, facecolor="none")
        ax = fig.add_subplot(111, projection="3d", facecolor="none")
        ax.set_axis_off()

        ax.add_collection3d(
            Poly3DCollection(
                triangles,
                facecolors=palette,
                edgecolor=edge_color,
                alpha=float(alpha),
                linewidths=0.2,
            )
        )

        # Cube-shaped view box centered on the vertex mean, so axes share one scale.
        half = float(np.ptp(points, axis=0).max() + 1e-12) / 2
        center = points.mean(axis=0)
        for set_lim, c in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), center):
            set_lim(float(c - half), float(c + half))
        ax.set_box_aspect([1, 1, 1])

        ax.view_init(elev=float(elev), azim=float(azim))
        return fig

    return vis_func


def simplify_mesh_to_target(
    mesh: trimesh.Trimesh,
    *,
    target_vertices: int = 1000,
    prefer_target_faces: bool = True,
) -> trimesh.Trimesh:
    """
    Simplify mesh geometry to roughly ``target_vertices`` via quadric decimation.

    Parameters
    ----------
    mesh:
        Input mesh. It is copied and never modified.
    target_vertices:
        Desired (approximate) vertex count of the result.
    prefer_target_faces:
        Kept for backward compatibility. Both settings use the same face target.

    Returns
    -------
    trimesh.Trimesh
        The decimated mesh (processed with ``validate=False``), or an
        unmodified copy when ``mesh`` already has no more faces than the target.

    Raises
    ------
    RuntimeError
        If this trimesh install exposes neither decimation method.

    Notes
    -----
    - Decimators take a target FACE count, not a vertex count. Many triangle
      meshes have about twice as many faces as vertices, so the face target is
      ``max(200, 2 * target_vertices)``.
    - trimesh >= 4 provides ``simplify_quadric_decimation`` (which needs the
      optional ``fast_simplification`` package when called); older releases
      provide ``simplify_quadratic_decimation(face_count)``.
    """
    m = mesh.copy()

    # heuristic: faces ~ 2 * vertices for many triangle meshes
    target_faces = int(max(200, 2 * target_vertices))

    # Nothing to remove; asking a decimator for more faces is a no-op at best.
    if len(m.faces) <= target_faces:
        return m

    current = getattr(m, "simplify_quadric_decimation", None)  # trimesh >= 4
    legacy = getattr(m, "simplify_quadratic_decimation", None)  # trimesh < 4
    if current is not None:
        m2 = current(face_count=target_faces)
    elif legacy is not None:
        m2 = legacy(target_faces)
    else:
        raise RuntimeError(
            "No mesh simplifier available: this trimesh install provides neither "
            "simplify_quadric_decimation nor simplify_quadratic_decimation."
        )

    # keep consistent dtype/processing minimal
    m2.process(validate=False)
    return m2
    

In [None]:
# Optional: inspect an external OBJ asset (nothing below depends on it).
# Override the location with the CB_OBJ_PATH environment variable; the cell is
# skipped when the file is missing so "Restart & Run All" still succeeds.
file_path = os.environ.get(
    "CB_OBJ_PATH",
    os.path.expanduser(
        "~/Desktop/Meshes/A_crazy_but_friendly__0729193044_texture_obj/"
        "A_crazy_but_friendly__0729193044_texture.obj"
    ),
)

if os.path.isfile(file_path):
    # Separate name so the star-pyramid `mesh` used later is never clobbered.
    obj_mesh = load_obj_as_trimesh(file_path, process=False, merge_scene=True)
    x0 = np.asarray(obj_mesh.vertices, dtype=float).reshape(-1)  # (3V,)
    print(x0.shape)
    # Texture-mapped visuals may not expose vertex_colors; show a short preview only.
    vertex_colors = getattr(obj_mesh.visual, "vertex_colors", None)
    print(None if vertex_colors is None else np.asarray(vertex_colors)[:5])
    # To render: make_obj_mesh_visualizer(obj_mesh, alpha=1.0)(x0)
else:
    print(f"OBJ file not found, skipping: {file_path}")

First, generate a dataset of triangle meshes, stored as vectors of length $3\times 6 = 18$:

In [None]:
# Template triangle mesh: a star pyramid with n_points=5 and height 1.
mesh = sy.make_star_pyramid(n_points=5, height=1)

# Visualization function for individual meshes (passed to the plotting helpers below).
vis_mesh = sy.make_star_pyramid_visualizer(mesh)

In [None]:
# Draw rotations from SO(3) with a fixed seed, then build the mesh dataset
# from them (presumably one rotated copy of the template per rotation).
n_samples = 5000
rng = np.random.default_rng(0)

R = sy.sample_so3(n_samples, rng=rng)[0]
data = sy.get_mesh_sample(mesh, R)

View a small sample of the dataset:

In [None]:
# Preview the first 8 meshes side by side in one row.
fig = vz.show_data_vis(
    data,
    vis_mesh,
    max_samples=8,
    n_cols=8,
    sampling_method="first",
    pad_frac=0.3,
)
plt.show()


Compute the base projections to $\mathbb{S}^{2}$, sending each mesh to the oriented unit normal of the triangle spanned by three chosen vertices:

In [None]:
# idx picks the three vertices whose oriented unit normal is the base point in S^2.
base_points = mesh_vertex_normals(data, vertex_dim=3, idx=(0, 1, 2))

Construct an open cover of $\mathbb{S}^{2}$ using a collection of nearly equidistant landmark points (see reference section):

In [None]:
# Star cover of S^2 built on 60 Fibonacci landmark points.
n_landmarks = 60
cover = cb.make_s2_fibonacci_star_cover(base_points, n_vertices=n_landmarks)

# Summary of the cover, with plots.
summ = cover.summarize(plot=True)

Optionally run the cell below to view a Plotly visualization of the nerve of the open cover:

In [None]:
fig = cover.show_nerve()  # Plotly view of the nerve of the cover

Compute persistence diagrams for the data in a few of the sets $\pi^{-1}(U_{j})$ (or in all of them, by setting `to_view=None`):

In [None]:
# Only the cover sets listed in to_view are processed; pass to_view=None for all.
# NOTE(review): random_state=None is unseeded; an integer seed would make the
# diagrams reproducible, if the API accepts one (confirm).
fiber_ids, dense_idx_list, rips_list = cb.get_local_rips(
    data,
    cover.U,
    to_view=[4, 25, 56],
    maxdim=1,
    n_perm=500,
    random_state=None,
)

fig, axes = cb.plot_local_rips(
    fiber_ids,
    rips_list,
    n_cols=3,
    titles="default",
    font_size=20,
)

Optionally run the cell below to show an interactive visualization of the projection map:

In [None]:
app = vz.show_bundle_vis(base_points=base_points, data=data)  # interactive projection-map viewer

Compute local circular coordinates, approximate transition matrices, and characteristic classes:

In [None]:
# Local circular coordinates, transition matrices and characteristic classes over the S^2 cover.
bundle = cb.build_bundle(data, cover, show=True)


Now, restrict the bundle to the equator $\mathbb{S}^{1}\subset \mathbb{S}^{2}$:

In [None]:
eps = 0.15  # half-width of the equatorial band, measured in the last coordinate of base_points

# Keep samples whose base point lies near the equator: |last coordinate| < eps.
eq_mask = np.abs(base_points[:, -1]) < eps
eq_data = bundle.data[eq_mask]

# Polar angle of each kept base point in the plane of its first two coordinates,
# wrapped into [0, 2*pi) so it parametrizes the equator S^1.
eq_base_angles = np.mod(
    np.arctan2(base_points[eq_mask, 1], base_points[eq_mask, 0]),
    2 * np.pi,
)

print(f"Equator band: {eq_data.shape[0]} / {bundle.data.shape[0]} samples (eps={eps}).")

Construct an open cover of $\mathbb{S}^{1}$ by metric balls around equally spaced landmark points:

In [None]:
# Equally spaced landmark angles on S^1.
n_landmarks = 12
landmarks = np.linspace(0, 2*np.pi, n_landmarks, endpoint=False)

# Landmarks are 2*pi/n_landmarks apart, so radius = overlap * (half the spacing);
# overlap > 1 makes neighbouring balls intersect.
overlap = 1.5
radius = overlap * np.pi / n_landmarks

eq_cover = cb.MetricBallCover(
    eq_base_angles,
    landmarks,
    radius,
    metric=cb.S1AngleMetric(),
)
eq_cover_data = eq_cover.build()

# Summary of the construction, with plots.
eq_summ = eq_cover.summarize(plot=True)

Compute characteristic classes for the restricted bundle:

In [None]:
# Characteristic classes of the bundle restricted to the equatorial band.
eq_bundle = cb.build_bundle(eq_data, eq_cover, show=True)


Observe that the restricted bundle is orientable, hence trivial, as expected. Construct a global toroidal coordinate system by synchronizing local circular coordinates:

In [None]:
eq_triv_result = eq_bundle.get_global_trivialization()  # synchronize the local circular coordinates

Finally, show a visualization of the coordinatized meshes:

In [None]:
# Per-sample coordinates: base angle first (x-direction), global fiber angle second (y-direction).
coords = np.column_stack([eq_base_angles, eq_triv_result.F])

# 7 x 7 grid of mesh thumbnails arranged by coords.
fig = vz.lattice_vis(
    eq_data,
    coords,
    vis_mesh,
    per_row=7,
    per_col=7,
    figsize=10,
    thumb_px=100,
    dpi=200,
)
plt.show()


The base projection angle varies from $0$ to $2\pi$ along the $x$-direction, and the fiber angle varies from $0$ to $2\pi$ along the $y$-direction. Notice that the base projection roughly corresponds to the axis of symmetry, as expected. The coordinatized meshes in each column approximately traverse a full rotation about the axis of symmetry. Meshes on opposite edges of the diagram roughly correspond, reflecting the toroidal topology of the restricted dataset.