# Visualize Solution

In [20]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
import numpy as np
import trimesh
import tempfile
import cadquery as cq
from typing import Optional
import importlib.util
from pathlib import Path
import open3d as o3d

## Solution

In [22]:
# Reference mesh; it is passed as the first (baseline) mesh when plotting below.
solution = trimesh.load_mesh("./baseline.stl")

## Predicted

In [23]:
def cq_to_trimesh(workplane: cq.Workplane) -> trimesh.Trimesh:
    """Convert a CadQuery workplane into a mesh by round-tripping through STL.

    The workplane is exported to an STL file inside a throwaway directory and
    read back with ``trimesh.load_mesh``. The directory, including the STL
    file, is removed before returning, so repeated calls leave nothing behind.

    Args:
        workplane: CadQuery workplane to export.

    Returns:
        The mesh loaded from the exported STL file.
    """
    # Export to a path inside a temporary directory instead of an open
    # NamedTemporaryFile(delete=False): that variant never removed the file
    # and wrote by name to a file that was still open, which is not portable.
    with tempfile.TemporaryDirectory() as temp_dir:
        stl_path = str(Path(temp_dir) / "workplane.stl")
        workplane.export(stl_path)
        mesh = trimesh.load_mesh(stl_path)
    return mesh


def load_from_checkpoint(
    checkpoint: int,
    folder_path: Path = Path(
        "./openevolve_output_20250725_204139/checkpoints",
    ),
) -> trimesh.Trimesh:
    """Build the mesh produced by the best program of one checkpoint.

    Loads ``<folder_path>/checkpoint_<checkpoint>/best_program.py`` as a module
    named ``program``, calls its ``build_3d_figure()`` function and converts
    the result with ``cq_to_trimesh``.

    Args:
        checkpoint: Checkpoint number used to build the directory name.
        folder_path: Directory containing the ``checkpoint_<n>`` folders.

    Returns:
        The mesh converted from what ``build_3d_figure()`` returns.

    Raises:
        AssertionError: If no import spec/loader could be created for the file.
        RuntimeError: If the loaded module has no ``build_3d_figure`` attribute.
    """
    program_file = folder_path / f"checkpoint_{checkpoint}" / "best_program.py"
    module_spec = importlib.util.spec_from_file_location("program", program_file)
    assert (
        module_spec is not None and module_spec.loader is not None
    ), "Failed to load program module"

    # NOTE: this executes the checkpoint file's top-level code.
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)

    if hasattr(module, "build_3d_figure"):
        return cq_to_trimesh(module.build_3d_figure())
    raise RuntimeError(
        "Error: program does not have the main 'build_3d_figure' function"
    )


# predicted = [load_from_checkpoint(4), load_from_checkpoint(5)]

## Plot

In [24]:
def tri_to_o(trimesh_mesh: trimesh.Trimesh) -> o3d.geometry.TriangleMesh:
    """Create an Open3D triangle mesh from a trimesh mesh.

    Only the vertex array and the face (triangle index) array are carried over.
    """
    vertex_array = np.asarray(trimesh_mesh.vertices)
    face_array = np.asarray(trimesh_mesh.faces)

    converted = o3d.geometry.TriangleMesh()
    converted.vertices = o3d.utility.Vector3dVector(vertex_array)
    converted.triangles = o3d.utility.Vector3iVector(face_array)
    return converted


def o_to_tri(o3d_mesh):
    """Create a trimesh mesh from an object exposing ``vertices`` and ``triangles``.

    Both attributes are converted with ``np.asarray`` and passed to
    ``trimesh.Trimesh`` as ``vertices`` and ``faces`` respectively.
    """
    return trimesh.Trimesh(
        vertices=np.asarray(o3d_mesh.vertices),
        faces=np.asarray(o3d_mesh.triangles),
    )


def preprocess_point_cloud(pcd, voxel_size):
    """Downsample a point cloud and compute normals and FPFH features for it.

    Args:
        pcd: Point cloud offering ``voxel_down_sample``.
        voxel_size: Voxel size for downsampling; the normal search radius is
            2x this value and the feature search radius is 5x this value.

    Returns:
        A ``(downsampled_cloud, fpfh_features)`` tuple.
    """
    downsampled = pcd.voxel_down_sample(voxel_size)

    # Normals are estimated in place on the downsampled cloud.
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(
        radius=voxel_size * 2, max_nn=30
    )
    downsampled.estimate_normals(normal_search)

    features = o3d.pipelines.registration.compute_fpfh_feature(
        downsampled,
        o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 5, max_nn=100),
    )
    return downsampled, features


def execute_global_registration(
    source_down, target_down, source_fpfh, target_fpfh, voxel_size
):
    """Run Open3D's feature-matching RANSAC registration of source onto target.

    Args:
        source_down: Source point cloud passed to the registration.
        target_down: Target point cloud passed to the registration.
        source_fpfh: Features for ``source_down``.
        target_fpfh: Features for ``target_down``.
        voxel_size: The distance threshold used by the registration and by the
            distance-based correspondence checker is 1.5x this value.

    Returns:
        Whatever ``registration_ransac_based_on_feature_matching`` returns.
    """
    registration = o3d.pipelines.registration
    max_distance = voxel_size * 1.5

    estimation = registration.TransformationEstimationPointToPoint(False)
    checkers = [
        registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        registration.CorrespondenceCheckerBasedOnDistance(max_distance),
    ]
    criteria = registration.RANSACConvergenceCriteria(100000, 0.999)

    # Arguments are passed positionally; presumably they map to Open3D's
    # (mutual_filter, max_correspondence_distance, estimation_method,
    # ransac_n, checkers, criteria) -- verify against the installed version.
    return registration.registration_ransac_based_on_feature_matching(
        source_down,
        target_down,
        source_fpfh,
        target_fpfh,
        True,
        max_distance,
        estimation,
        3,
        checkers,
        criteria,
    )


def align_rot(
    source_mesh: trimesh.Trimesh,
    target_mesh: trimesh.Trimesh,
    n_points: int = 10000,
    voxel_size: float = 0.1,
) -> trimesh.Trimesh:
    """Return ``source_mesh`` transformed by a global registration onto ``target_mesh``.

    Both meshes are converted to Open3D, sampled into point clouds,
    preprocessed with ``preprocess_point_cloud`` and registered with
    ``execute_global_registration``. The resulting transformation is applied
    to a fresh Open3D conversion of the source mesh.

    Args:
        source_mesh: Mesh to be moved.
        target_mesh: Mesh to register against.
        n_points: Number of points sampled from each mesh.
        voxel_size: Voxel size forwarded to preprocessing and registration.

    Returns:
        A new trimesh mesh built from the transformed source geometry.
    """
    # The target is sampled before the source.
    target_cloud = tri_to_o(target_mesh).sample_points_uniformly(n_points)
    source_cloud = tri_to_o(source_mesh).sample_points_uniformly(n_points)

    source_down, source_fpfh = preprocess_point_cloud(source_cloud, voxel_size)
    target_down, target_fpfh = preprocess_point_cloud(target_cloud, voxel_size)

    registration = execute_global_registration(
        source_down, target_down, source_fpfh, target_fpfh, voxel_size
    )

    # Apply the estimated transformation to the full-resolution source mesh.
    moved = tri_to_o(source_mesh)
    moved.transform(registration.transformation)
    return o_to_tri(moved)


def center_mesh(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
    """Return a copy of ``mesh`` translated by the negative of its centroid.

    The input mesh is not modified; the translation is applied to a copy.
    """
    # 4x4 homogeneous matrix whose translation column is -centroid.
    shift = np.eye(4)
    shift[:3, 3] = -mesh.centroid

    moved = mesh.copy()
    moved.apply_transform(shift)
    return moved


def transform(obj: trimesh.Trimesh) -> trimesh.Trimesh:
    """Return a normalized copy of a mesh: centered and scaled to unit size.

    Steps, all applied to a copy so that ``obj`` itself is left untouched:
      1. translate so the center of the bounding box is at the origin,
      2. scale uniformly so the largest extent becomes 1 (skipped when the
         largest extent is at most 1e-7, to avoid dividing by ~0),
      3. re-center on the centroid via ``center_mesh``.

    Args:
        obj: Mesh to normalize.

    Returns:
        A new, normalized mesh.
    """
    # Work on a copy: translating/scaling ``obj`` directly would silently
    # modify the caller's mesh (e.g. a notebook-level mesh passed to a plot).
    normalized = obj.copy()

    bounds_center = normalized.bounds.mean(axis=0)
    normalized.apply_translation(-bounds_center)

    scale = normalized.extents.max()
    if scale > 1e-7:
        normalized.apply_scale(1.0 / scale)
    return center_mesh(normalized)


def plot_mesh_comparison_scene(
    meshes: list[Optional[trimesh.Trimesh]],
    colors: Optional[list[Optional[np.ndarray]]] = None,
    align: Optional[bool] = False,
):
    """
    Plot valid meshes side by side using trimesh's Scene system.

    Every valid mesh is normalized with ``transform`` (on a copy, so the
    caller's meshes are not modified), optionally aligned to the first valid
    mesh, and placed along the x axis. The scene is also written to
    ``./scene.stl``.

    Parameters:
        meshes: List of trimesh.Trimesh objects (some may be None or invalid).
        colors: Optional list of face colors (same length as meshes). A None
            entry leaves that mesh without an explicit color.
        align: Whether to align meshes to the first valid one (default: False).

    Returns:
        The result of ``scene.show()``.

    Raises:
        ValueError: If ``colors`` is given with a different length than
            ``meshes``, or if no valid mesh remains after filtering.
    """
    if not colors:
        colors = [None] * len(meshes)
    elif len(colors) != len(meshes):
        # zip() would otherwise silently drop the meshes without a color entry.
        raise ValueError(
            f"colors has {len(colors)} entries but meshes has {len(meshes)}; "
            "they must have the same length."
        )

    valid_meshes = []
    valid_colors = []

    # Filter out None or empty/broken meshes
    for mesh, color in zip(meshes, colors):
        if mesh is None or not isinstance(mesh, trimesh.Trimesh) or mesh.is_empty:
            print("Warning: Skipping empty or invalid mesh.")
            continue
        valid_meshes.append(mesh)
        valid_colors.append(color)

    if not valid_meshes:
        raise ValueError("No valid meshes to display.")

    # Normalize copies so the meshes passed in by the caller stay unchanged.
    placed = [transform(m.copy()) for m in valid_meshes]

    # The first valid mesh is the baseline the others are aligned to.
    baseline = placed[0]
    if align:
        placed = [baseline] + [transform(align_rot(m, baseline)) for m in placed[1:]]

    # Spacing is derived from the meshes as they will be drawn, i.e. after
    # alignment, since alignment can change a mesh's extent along x.
    offset = max(m.extents[0] for m in placed) * 1.2

    scene = trimesh.Scene()
    for idx, (mesh, color) in enumerate(zip(placed, valid_colors)):
        mesh.visual.face_colors = color
        scene.add_geometry(
            mesh,
            transform=trimesh.transformations.translation_matrix([offset * idx, 0, 0]),
        )

    scene.export("./scene.stl")
    return scene.show()


# plot_mesh_comparison_scene(
#     [solution, solution.copy().apply_transform(trimesh.transformations.translation_matrix([1,2,3]) @ trimesh.transformations.rotation_matrix(3.14 / 4, [1,0,0], [0,0,0]))],
#     colors=[
#         np.array([0, 255, 0]),
#         np.array([255, 0, 0]),
#     ],
#     align=True,
# )

In [30]:
start = 0
end = 19

# indices = range(start, end + 1)
indices = [0, 1, 4, 37]

# Reference solution first (green), then one mesh per checkpoint index; the
# last checkpoint is colored red and the ones in between get no explicit color.
plot_mesh_comparison_scene(
    [solution, *map(load_from_checkpoint, indices)],
    # [*[load_from_checkpoint(idx) for idx in range(start, end + 1)]],
    colors=[
        np.array([0, 255, 0]),
        *([None] * (len(indices) - 1)),
        np.array([255, 0, 0]),
    ],
    align=True,
)



In [26]:
# Debug check: print the vertex coordinates of the mesh built from checkpoint 2.
aaa = load_from_checkpoint(2)
print(aaa.vertices)

[[ 5.27028923e+01 -1.29084857e-14 -8.49847336e+01]
 [ 4.84122925e+01 -1.18575914e-14 -8.75000000e+01]
 [ 4.83521118e+01  2.41315365e+00 -8.75000000e+01]
 ...
 [-2.72541103e+01 -1.49271941e+00  8.76059189e+01]
 [-3.24070091e+01 -5.28335047e+00  9.04191284e+01]
 [-3.21300354e+01 -2.64812589e+00  9.04191284e+01]]


In [27]:
# Debug check: print the vertex coordinates of the solution mesh.
print(solution.vertices)

[[-0.26168037 -0.00253304 -0.17307182]
 [-0.26731267 -0.0059901  -0.16173609]
 [-0.27098333  0.00549365 -0.16809163]
 ...
 [ 0.10048498 -0.24538569  0.04858939]
 [ 0.09846355 -0.24087664  0.04561778]
 [ 0.09839722 -0.24556119  0.05424616]]
