In [1]:
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
# Make helper modules under ../utils importable, inserted right after path[0].
# NOTE(review): no visible cell imports anything from ../utils — confirm this is still needed.
sys.path.insert(1, '../utils')

In [2]:
import os
import open3d as o3d
import numpy as np
import glob
from tqdm import tqdm
import trimesh

In [3]:
# NOTE: "preproccessed" is the spelling used by the on-disk dataset directory
# (meshes are loaded successfully from it below) — do not "fix" it here alone.
data_root = "../data/ModelResource_RigNetv1_preproccessed"
obj_folder = f'{data_root}/obj'
rig_folder = f'{data_root}/rig_info'

# All raw .obj meshes, in a stable (sorted) order.
# (Previously `glob.glob` was bound without being called, so this held the function object.)
objs_files = sorted(glob.glob(f'{obj_folder}/*.obj'))

In [4]:
def get_mesh_joint_pair(obj_index, obj_folder, rig_folder):
    """
    Load mesh ``{obj_index}.obj`` and the joint positions from rig ``{obj_index}.txt``,
    both translated so that the mesh centroid sits at the origin.

    Only the leading run of ``joints <name> x y z`` lines in the rig file is read;
    parsing stops at the first line that is not a joint line (blank lines included)
    or at end of file. Bone / skinning information is ignored.

    Args:
        obj_index:  integer (or string) id of the sample.
        obj_folder: directory containing the ``.obj`` meshes.
        rig_folder: directory containing the rig ``.txt`` files.

    Returns:
        (mesh, joints): the centered trimesh mesh, and a (J, 3) float array of joint
        positions expressed in the same centered frame (J may be 0).
    """
    obj_path = f'{obj_folder}/{obj_index}.obj'
    rig_path = f'{rig_folder}/{obj_index}.txt'

    # Load and center the mesh on its centroid.
    mesh = trimesh.load_mesh(obj_path)
    centroid = mesh.centroid
    mesh.apply_translation(-centroid)

    # Extract joint locations only; bone info is disregarded.
    joints = []
    with open(rig_path, "r") as f:
        for line in f:
            tokens = line.split()
            # Iterating the file ends cleanly at EOF, and `not tokens` handles blank
            # lines — the old `while tokens[0] == ...` raised IndexError in both cases.
            if not tokens or tokens[0] != "joints":
                break
            joints.append([float(v) for v in tokens[2:]])

    # reshape keeps the result (J, 3) even when no joints were found, so the
    # centroid subtraction broadcasts instead of failing on a (0,) array.
    joints = np.array(joints, dtype=float).reshape(len(joints), 3) - centroid

    return mesh, joints

In [5]:
# Sanity check on one sample (index 341): the loaded mesh and its joint positions,
# both translated by the mesh centroid.
mesh, joints = get_mesh_joint_pair(341, obj_folder, rig_folder)
mesh, joints

(<trimesh.Trimesh(vertices.shape=(1623, 3), faces.shape=(3172, 3))>,
 array([[-4.80618542e-06, -7.20965129e-02, -9.80135224e-02],
        [ 3.53591938e-02, -8.28225129e-02, -1.18517522e-01],
        [-4.78618542e-06, -6.49445129e-02, -7.98179224e-02],
        [-3.53688062e-02, -8.28225129e-02, -1.18517522e-01],
        [-4.82618542e-06, -6.12175129e-02, -1.34724022e-01],
        [ 6.55807938e-02, -1.60356913e-01, -8.75791224e-02],
        [-4.79618542e-06, -6.00325129e-02, -3.48797224e-02],
        [-6.55904062e-02, -1.60356813e-01, -8.75793224e-02],
        [-4.78618542e-06, -6.06325129e-02, -2.31827522e-01],
        [ 6.57489938e-02, -2.09598813e-01, -1.19307122e-01],
        [-4.77618542e-06, -5.54125129e-02,  1.72098776e-02],
        [-6.57586062e-02, -2.09598713e-01, -1.19307122e-01],
        [-4.75618542e-06, -6.00465129e-02, -3.28931522e-01],
        [ 1.89700938e-02, -3.96565129e-02,  5.59514776e-02],
        [ 7.90632381e-03, -5.52555129e-02,  5.40774776e-02],
        [-4.7461

### Data Analysis

In [6]:
# Each *_final.txt split file lists one mesh index per line; collect them all.
# Sorted so the index order is reproducible across machines/filesystems.
mesh_idx_files = sorted(glob.glob(os.path.join(data_root, '*_final.txt')))
mesh_idxs = []
for fpath in mesh_idx_files:
    # Open read-only: these split files are inputs and must never be modified.
    with open(fpath, 'r') as f:
        # Skip blank lines (e.g. a trailing empty line), which int() would reject.
        mesh_idxs.extend(int(line) for line in f.read().splitlines() if line.strip())

len(mesh_idxs)

2703

In [9]:
def compute_mesh_stats(mesh_idxs, obj_folder, rig_folder):
    """
    Summarize vertex/face counts of the raw (centered, otherwise unmodified) meshes.

    For each mesh index in mesh_idxs, loads the mesh via get_mesh_joint_pair and
    records its number of vertices and faces (triangles).

    Args:
        mesh_idxs:  iterable of mesh ids (``{idx}.obj`` / ``{idx}.txt``).
        obj_folder: directory containing the ``.obj`` files.
        rig_folder: directory containing the rig ``.txt`` files.

    Returns:
        dict with
          - 'vertices', 'faces': {'mean', 'min', 'max', 'std'} of the per-mesh counts
          - 'other': {'below1k': number of meshes with < 1000 vertices,
                      'above5k': number of meshes with > 5000 vertices}

    Raises:
        ValueError: if mesh_idxs is empty (min/max would be undefined).
    """
    mesh_idxs = list(mesh_idxs)
    if not mesh_idxs:
        raise ValueError("compute_mesh_stats: mesh_idxs is empty")

    vert_counts = []
    face_counts = []
    for idx in tqdm(mesh_idxs):
        mesh, _ = get_mesh_joint_pair(idx, obj_folder, rig_folder)
        vert_counts.append(len(mesh.vertices))
        # len(faces) == len(triangles), without materializing the (F, 3, 3) coordinate array
        face_counts.append(len(mesh.faces))

    vert_counts = np.asarray(vert_counts)
    face_counts = np.asarray(face_counts)

    def _describe(counts):
        # Summary statistics reported for both vertex and face counts.
        return {
            'mean': counts.mean(),
            'min': counts.min(),
            'max': counts.max(),
            'std': counts.std(),
        }

    return {
        'vertices': _describe(vert_counts),
        'faces': _describe(face_counts),
        'other': {
            'below1k': int((vert_counts < 1000).sum()),
            'above5k': int((vert_counts > 5000).sum()),
        },
    }

In [8]:
stats = compute_mesh_stats(mesh_idxs, obj_folder, rig_folder)
print("Vertex count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['vertices']))
print("Face   count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['faces']))
# 'below1k' counts meshes with < 1000 vertices (the label previously said below_5k).
print("Other     : below_1k={below1k}, above_5k={above5k}".format(**stats['other']))

100%|██████████| 2703/2703 [00:13<00:00, 202.09it/s]

Vertex count: mean=1308.3, min=56, max=13858, std=1194.2
Face   count: mean=2499.8, min=102, max=33520, std=2384.8
Other     : below_5k=1549, above_5k=53





Large meshes need to be decimated to roughly 5,000 triangles (which keeps them at ≤ 5,000 vertices).

Also, many meshes have fewer than 1,000 vertices, so we subdivide their triangles before decimating.

In [10]:
def subdivide_to_min_verts(mesh: "trimesh.Trimesh", min_verts: int = 1000) -> "trimesh.Trimesh":
    """
    Repeatedly subdivide `mesh` until it has at least `min_verts` vertices.

    Uses trimesh's ``Trimesh.subdivide()``, which is *midpoint* subdivision: every
    triangle is split into four by inserting edge midpoints, and existing vertices
    are not moved, so the surface shape is unchanged. (This is NOT Loop
    subdivision — that is ``subdivide_loop()``, which repositions vertices to
    smooth the surface.)

    Args:
        mesh:      input mesh; returned as-is if it already has >= min_verts vertices.
        min_verts: target minimum vertex count.

    Returns:
        The (possibly) subdivided mesh. Stops early if a subdivision step does not
        increase the vertex count (e.g. a mesh with no faces), so it always terminates.
    """
    n_verts = len(mesh.vertices)
    while n_verts < min_verts:
        refined = mesh.subdivide()
        n_refined = len(refined.vertices)

        # No growth means further subdivision is pointless — bail out instead of looping forever.
        if n_refined <= n_verts:
            break

        mesh, n_verts = refined, n_refined

    return mesh

In [11]:
def decimate_to_range(mesh: trimesh.Trimesh, min_tris=1000, max_tris=8000, shrink_factor=0.8):
    """
    Shrink `mesh` with quadric decimation until it has at most `max_tris` triangles.

    Each pass asks the decimator for ``shrink_factor`` times the current triangle
    count, clamped so it never targets fewer than `min_tris`. Passes repeat while
    the count exceeds `max_tris`. They stop early when the clamped target would
    remove nothing, or when the decimator fails to lower the count. Meshes with
    <= max_tris triangles are never decimated; meshes below `min_tris` are not
    grown here.

    A mesh already inside [min_tris, max_tris] is returned untouched. Every other
    result has its unreferenced vertices removed (in place on the returned mesh).
    """
    n_tris = len(mesh.triangles)

    # Nothing to do for meshes already inside the target band.
    if min_tris <= n_tris <= max_tris:
        return mesh

    while n_tris > max_tris:
        goal = max(int(n_tris * shrink_factor), min_tris)
        if goal >= n_tris:
            break  # clamping left no triangles to remove

        simplified = mesh.simplify_quadric_decimation(face_count=goal)
        n_simplified = len(simplified.triangles)
        if n_simplified >= n_tris:
            break  # decimator made no progress

        mesh, n_tris = simplified, n_simplified

    mesh.remove_unreferenced_vertices()
    return mesh

In [12]:
def compute_mesh_stats(mesh_idxs, obj_folder, rig_folder):
    """
    Summarize vertex/face counts after subdivision + decimation.

    For each mesh index in mesh_idxs:
      1. loads the mesh via get_mesh_joint_pair,
      2. subdivides it towards >= 1000 vertices (subdivide_to_min_verts),
      3. decimates anything above 8000 triangles (decimate_to_range),
      4. records its number of vertices and faces.
    Any mesh still below 1000 vertices is reported with ``tqdm.write``, so the
    message does not garble the progress bar.

    Args:
        mesh_idxs:  iterable of mesh ids (``{idx}.obj`` / ``{idx}.txt``).
        obj_folder: directory containing the ``.obj`` files.
        rig_folder: directory containing the rig ``.txt`` files.

    Returns:
        dict with
          - 'vertices', 'faces': {'mean', 'min', 'max', 'std'} of the per-mesh counts
          - 'other': {'below1k': number of meshes with < 1000 vertices,
                      'above5k': number of meshes with > 5000 vertices}

    Raises:
        ValueError: if mesh_idxs is empty (min/max would be undefined).
    """
    mesh_idxs = list(mesh_idxs)
    if not mesh_idxs:
        raise ValueError("compute_mesh_stats: mesh_idxs is empty")

    vert_counts = []
    face_counts = []
    for idx in tqdm(mesh_idxs):
        mesh, _ = get_mesh_joint_pair(idx, obj_folder, rig_folder)

        mesh = subdivide_to_min_verts(mesh, 1000)
        mesh = decimate_to_range(mesh, 1000, 8000)

        n_verts = len(mesh.vertices)
        if n_verts < 1000:
            # print() inside a tqdm loop breaks the bar; tqdm.write is the safe way to log.
            tqdm.write(f"mesh {idx}: {n_verts} vertices after preprocessing")

        vert_counts.append(n_verts)
        # len(faces) == len(triangles), without materializing the (F, 3, 3) coordinate array
        face_counts.append(len(mesh.faces))

    vert_counts = np.asarray(vert_counts)
    face_counts = np.asarray(face_counts)

    def _describe(counts):
        # Summary statistics reported for both vertex and face counts.
        return {
            'mean': counts.mean(),
            'min': counts.min(),
            'max': counts.max(),
            'std': counts.std(),
        }

    return {
        'vertices': _describe(vert_counts),
        'faces': _describe(face_counts),
        'other': {
            'below1k': int((vert_counts < 1000).sum()),
            'above5k': int((vert_counts > 5000).sum()),
        },
    }

In [13]:
# Re-run the statistics; this version of compute_mesh_stats subdivides small meshes
# and decimates large ones before counting.
stats = compute_mesh_stats(mesh_idxs, obj_folder, rig_folder)
print("Vertex count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['vertices']))
print("Face   count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['faces']))
print("Other     : below_1k={below1k}, above_5k={above5k}".format(**stats['other']))


100%|██████████| 2703/2703 [00:14<00:00, 191.37it/s]

Vertex count: mean=2356.8, min=1000, max=4494, std=846.3
Face   count: mean=4573.9, min=1500, max=7999, std=1694.3
Other     : below_1k=0, above_5k=0





### Load and Preprocess Mesh Subroutine

In [14]:
def load_and_preprocess_mesh(obj_path,
                             min_verts=1000,
                             min_tris=1000,
                             max_tris=8000):
    """
    Load an .obj mesh, repair it, center it, and normalize its resolution.

    Pipeline:
      1) Load the file with ``process=False`` (no automatic trimesh processing),
         then drop degenerate and duplicate faces, remove unreferenced vertices,
         and fill small holes.
      2) Center the mesh at the origin by subtracting its centroid.
      3) Subdivide until it has >= min_verts vertices (subdivide_to_min_verts).
      4) Quadric-decimate meshes with more than max_tris triangles, never
         targeting fewer than min_tris (decimate_to_range).
      5) Final cleanup: drop degenerate faces, remove unreferenced vertices,
         fill holes. (No normals are computed here.)

    The default thresholds were chosen empirically: dataset statistics were
    recomputed until no mesh had fewer than 1k or more than 5k vertices, while
    pushing the maximum vertex count as close to 5k as possible.

    Args:
        obj_path:  path to the .obj file.
        min_verts: minimum vertex count to subdivide up to.
        min_tris:  lower bound on the decimation target.
        max_tris:  maximum triangle count after decimation.

    Returns:
        (mesh, centroid): the processed mesh and the translation subtracted in
        step 2. Subtract the same centroid from the rig's joint positions to keep
        them aligned with the mesh. NOTE(review): this centroid is computed after
        the step-1 repair, so it may differ slightly from the one used by
        get_mesh_joint_pair — pair joints with this returned value.
    """
    # --- 1) Load & initial repair ---
    mesh = trimesh.load_mesh(obj_path, process=False)

    # drop any zero-area or duplicate faces
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.update_faces(mesh.unique_faces())
    mesh.remove_unreferenced_vertices()
    mesh.fill_holes()  # closes small cracks that might break subdivision

    # --- 2) Center at origin ---
    centroid = mesh.centroid
    mesh.apply_translation(-centroid)

    # --- 3) Grow small meshes up to min_verts ---
    mesh = subdivide_to_min_verts(mesh, min_verts=min_verts)

    # --- 4) Shrink big meshes into [min_tris, max_tris] ---
    mesh = decimate_to_range(mesh,
                             min_tris=min_tris,
                             max_tris=max_tris)

    # --- 5) Final cleanup ---
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.remove_unreferenced_vertices()
    mesh.fill_holes()

    return mesh, centroid

In [15]:
# Preprocess a single sample (index 13) for visual inspection.
obj_path = f'{obj_folder}/13.obj'
mesh, centroid = load_and_preprocess_mesh(obj_path)

In [16]:
# Display the preprocessed mesh's repr (vertex/face array shapes).
mesh

<trimesh.Trimesh(vertices.shape=(3377, 3), faces.shape=(6862, 3))>

In [17]:
import numpy as np
import open3d as o3d
import trimesh

def visualize_trimesh(mesh_tm: trimesh.Trimesh,
                      mesh_color=(0.7, 0.7, 0.7),
                      joints: np.ndarray = None,
                      joint_color=(1.0, 0.0, 0.0),
                      joint_radius: float = 0.01):
    """
    Visualize a Trimesh in Open3D as a wireframe (LineSet), with optional joints.

    Args:
        mesh_tm:      trimesh.Trimesh instance to draw.
        mesh_color:   RGB triple in [0, 1] applied to every edge (default light gray).
        joints:       optional (J, 3) array of joint positions, in the mesh's frame.
        joint_color:  RGB triple in [0, 1] for the joint spheres (default red).
        joint_radius: sphere radius in mesh units (default 0.01, the previously
                      hard-coded value); adjust for meshes of a different scale.
    """
    # --- Convert to LineSet wireframe ---
    verts = np.asarray(mesh_tm.vertices)
    # (E, 2) unique [u, v] vertex-index pairs; Vector2iVector holds 32-bit ints.
    edges = np.asarray(mesh_tm.edges_unique, dtype=np.int32)

    wireframe = o3d.geometry.LineSet()
    wireframe.points = o3d.utility.Vector3dVector(verts)
    wireframe.lines = o3d.utility.Vector2iVector(edges)
    # Same color for every edge.
    edge_colors = np.tile(np.asarray(mesh_color, dtype=float), (len(edges), 1))
    wireframe.colors = o3d.utility.Vector3dVector(edge_colors)

    # --- Prepare geometries for rendering ---
    geometries = [wireframe]

    if joints is not None and len(joints) > 0:
        # Render each joint as a small solid sphere.
        for joint in np.asarray(joints, dtype=float):
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=joint_radius)
            sphere.translate(joint)
            sphere.paint_uniform_color(list(joint_color))
            sphere.compute_vertex_normals()
            geometries.append(sphere)

    # --- Draw all ---
    o3d.visualization.draw_geometries(geometries)


In [18]:
# Wireframe view of the preprocessed sample (no joints passed here).
visualize_trimesh(mesh)

### Stats analysis after preprocessing

In [19]:
def compute_mesh_stats(mesh_idxs, obj_folder):
    """
    Summarize vertex/face counts of meshes produced by load_and_preprocess_mesh.

    For each mesh index in mesh_idxs, runs the full preprocessing pipeline on
    ``{obj_folder}/{idx}.obj`` and records the resulting number of vertices and
    faces. Any mesh still below 1000 vertices is reported with ``tqdm.write``, so
    the message does not garble the progress bar.

    Args:
        mesh_idxs:  iterable of mesh ids.
        obj_folder: directory containing the ``.obj`` files.

    Returns:
        dict with
          - 'vertices', 'faces': {'mean', 'min', 'max', 'std'} of the per-mesh counts
          - 'other': {'below1k': number of meshes with < 1000 vertices,
                      'above5k': number of meshes with > 5000 vertices}

    Raises:
        ValueError: if mesh_idxs is empty (min/max would be undefined).
    """
    mesh_idxs = list(mesh_idxs)
    if not mesh_idxs:
        raise ValueError("compute_mesh_stats: mesh_idxs is empty")

    vert_counts = []
    face_counts = []
    for idx in tqdm(mesh_idxs):
        mesh, _ = load_and_preprocess_mesh(f'{obj_folder}/{idx}.obj')

        n_verts = len(mesh.vertices)
        if n_verts < 1000:
            # print() inside a tqdm loop breaks the bar; tqdm.write is the safe way to log.
            tqdm.write(f"mesh {idx}: {n_verts} vertices after preprocessing")

        vert_counts.append(n_verts)
        # len(faces) == len(triangles), without materializing the (F, 3, 3) coordinate array
        face_counts.append(len(mesh.faces))

    vert_counts = np.asarray(vert_counts)
    face_counts = np.asarray(face_counts)

    def _describe(counts):
        # Summary statistics reported for both vertex and face counts.
        return {
            'mean': counts.mean(),
            'min': counts.min(),
            'max': counts.max(),
            'std': counts.std(),
        }

    return {
        'vertices': _describe(vert_counts),
        'faces': _describe(face_counts),
        'other': {
            'below1k': int((vert_counts < 1000).sum()),
            'above5k': int((vert_counts > 5000).sum()),
        },
    }

In [20]:
# Final check: statistics over the whole dataset using the load_and_preprocess_mesh
# pipeline (this version of compute_mesh_stats takes no rig_folder).
stats = compute_mesh_stats(mesh_idxs, obj_folder)
print("Vertex count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['vertices']))
print("Face   count: mean={mean:.1f}, min={min}, max={max}, std={std:.1f}".format(**stats['faces']))
print("Other     : below_1k={below1k}, above_5k={above5k}".format(**stats['other']))

100%|██████████| 2703/2703 [00:20<00:00, 130.80it/s]

Vertex count: mean=2359.8, min=1000, max=4505, std=849.8
Face   count: mean=4573.4, min=1519, max=7992, std=1694.7
Other     : below_1k=0, above_5k=0



