### Constructing Per-Pixel Binary Attention Masks from Joint and Bone Information

The RigNet paper finds that pre-training the attention modules weights with a cross-entropy loss function with a per-vertex attention mask can improve performance. In this notebook, we explore how to construct these masks.

The main idea is that vertices orthogonal to bones at joint locations should have higher attention values as compared to other verticies. 

The paper doesn't explicitly say how it constructs these masks. I have two main ideas.

**Ideas**
- For each joint, pick one connected bone and find the orthogonal plane.
    - Idea 1: Pick all mesh vertices within some radius r that lie on this plane
    - Idea 2: Find the vertex p_min closest to joint j that lies on the plane (d = || j - p_min ||). Pick all mesh vertices 
    that lie at a distance of d + eps fom the joint, where eps is some "slack" threshold. 

The main issue is that this introduces yet another hyperparameter that may need to be tuned. 

**RigNet Implementation: Ray-casting**

Found their actual impl. 
- Cast K rays from the joint in the direction of the plane orthogonal to the bone centered at the joint (K=14)
- Perform triangle-ray intersection and gather vertices closest to intersection points.
    - If <6 vertices found, just to kNN centered at the joint with k=6
    - If triangle-ray intersection fails, do a "nearby triangle" search 
        - We need this because there ARE rare instances where we rays won't intersect with triangles. This ensures that there at least some training signal retained for each joint
- Find the 20th percentile distance of vertices. Multiply this distance by 2. This distance is threshold for retaining points.

Notes: 
- RigNet decimates meshes to 3k verts before doing this. Why? If computation time is available, what's the need for this? The max number of verts in the dataset is already only 5k. 

In [1]:
import os
import glob
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import open3d as o3d
import trimesh

In [2]:
data_root = "../data/ModelResource_RigNetv1_preproccessed"
obj_folder = f'{data_root}/obj'
rig_folder = f'{data_root}/rig_info'

In [3]:
obj_files = glob.glob(os.path.join(obj_folder, "*.obj"))
len(obj_files)

2703

In [4]:
def get_obj_path_from_idx(idx: int):
    return os.path.join(obj_folder, f"{idx}.obj")

def get_rig_path_from_idx(idx: int):
    return os.path.join(rig_folder, f"{idx}.txt")

In [5]:
mesh_idx = 13
obj_path = get_obj_path_from_idx(mesh_idx)
rig_path = get_rig_path_from_idx(mesh_idx)
obj_path, rig_path

('../data/ModelResource_RigNetv1_preproccessed/obj/13.obj',
 '../data/ModelResource_RigNetv1_preproccessed/rig_info/13.txt')

### Visualization Code from `visualization.ipynb`

In [30]:
def get_mesh_joint_pair(obj_index, obj_folder, rig_folder):
    
    obj_path = f'{obj_folder}/{obj_index}.obj'
    rig_path = f'{rig_folder}/{obj_index}.txt'

    # Load and pre-process mesh
    mesh = o3d.io.read_triangle_mesh(obj_path, enable_post_processing=True)

    verts = np.asarray(mesh.vertices)
    centroid = verts.mean(axis=0)
    mesh.translate(-centroid)

    # extraxt join locations
    # disregard bone info. Only want joint locations
    joints = []
    with open(rig_path, "r") as f:
        tokens = f.readline().split()
        while(tokens[0] == "joints"):
            joints.append(list(map(float, tokens[2:])))
            tokens = f.readline().split()

    # Return joints as point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(joints))
    pcd.translate(-centroid)

    return mesh, pcd

In [31]:
def process_and_visualize(mesh: o3d.geometry.TriangleMesh, joints: o3d.geometry.PointCloud):

    # Get verices and centroid
    verts = np.asarray(mesh.vertices)
    centroid = verts.mean(axis=0)   

    # Mesh frame
    mesh_frame = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)

    # Draw joints
    spheres = []
    for (x, y, z) in np.asarray(joints.points):
        sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.01)
        sphere.translate((x, y, z))
        sphere.paint_uniform_color([0, 1, 1])
        spheres.append(sphere)

    # Compute AABB and its longest side
    aabb = mesh.get_axis_aligned_bounding_box()
    min_bound = aabb.min_bound  # [x_min, y_min, z_min]
    max_bound = aabb.max_bound  # [x_max, y_max, z_max]
    lengths = max_bound - min_bound  # [Lx, Ly, Lz]
    longest = lengths.max()

    # Create a box for visualization
    box = o3d.geometry.LineSet.create_from_axis_aligned_bounding_box(aabb)
    box.translate(-centroid) 
    box.colors = o3d.utility.Vector3dVector([[1, 0, 0] for _ in box.lines]) 

    # Axes
    axes = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=longest * 0.5,
        origin=[0, 0, 0]
    )

    o3d.visualization.draw_geometries([mesh_frame, *spheres, box, axes],
                                      window_name="Mesh + AABB",
                                      width=800, height=600)

### Visualization

In [33]:
mesh, pcd = get_mesh_joint_pair(mesh_idx, obj_folder, rig_folder)
process_and_visualize(mesh, pcd)

2025-04-24 06:16:15.299 python3[42968:1348644] +[IMKClient subclass]: chose IMKClient_Modern
2025-04-24 06:16:15.299 python3[42968:1348644] +[IMKInputSession subclass]: chose IMKInputSession_Modern


### Ray Casting

In [37]:
mesh = trimesh.load_mesh(obj_path)
mesh

<trimesh.Trimesh(vertices.shape=(4307, 3), faces.shape=(8627, 3))>

In [38]:
# center mesh
mesh.apply_translation(-mesh.centroid)
mesh.centroid

array([1.30718479e-17, 1.82078577e-15, 4.37174042e-17])

In [39]:
def parse_rig(rig_path: str, centroid):
    # extraxt join locations
    # disregard bone info. Only want joint locations

    # file structure:
    # joint name x y z
    # root name
    # skin name1 weight1 name2 weight2 ...
    # hier name1 name2

    # Will assume this rigid order

    jointname2idx = {}
    joints = []
    bones = []
    root_idx = ""
    with open(rig_path, "r") as f:
        tokens = f.readline().split()
        while(tokens[0] == "joints"):
            name = tokens[1]
            loc = list(map(float, tokens[2:]))

            # name to index mapping
            next_idx = len(joints)
            jointname2idx[name] = next_idx

            joints.append(loc)

            tokens = f.readline().split()
        
        # root
        root_idx = jointname2idx[tokens[1]]

        # Skip skin info
        while tokens[0] != "hier":
            tokens = f.readline().split()
        
        # Hier info
        while tokens:
            b1 = jointname2idx[tokens[1]]
            b2 = jointname2idx[tokens[2]]
            bones.append([b1, b2])
            tokens = f.readline().split()

    joints = np.array(joints) - centroid
    
    return joints, bones, root_idx

In [40]:
joints, bones, root_idx = parse_rig(rig_path, mesh.centroid)
len(joints), bones, root_idx

(26,
 [[0, 1],
  [0, 2],
  [0, 3],
  [1, 4],
  [2, 5],
  [3, 6],
  [4, 7],
  [5, 8],
  [6, 9],
  [6, 10],
  [6, 11],
  [7, 12],
  [8, 13],
  [9, 14],
  [10, 15],
  [11, 16],
  [14, 17],
  [15, 18],
  [16, 19],
  [18, 20],
  [19, 21],
  [20, 22],
  [20, 23],
  [21, 24],
  [21, 25]],
 0)

### Form Rays

In [41]:
# Method for finding orthonormal plane to v

def pick_arbitrary(v):
    # pick axis least aligned with v
    abs_v = np.abs(v)

    # If the x-component is the smallest, pick the x-axis
    if abs_v[0] <= abs_v[1] and abs_v[0] <= abs_v[2]:
        return np.array([1.0, 0.0, 0.0])
    # x-component is not the smallest. Check if y is smaller than z. 
    elif abs_v[1] <= abs_v[2]:
        return np.array([0.0, 1.0, 0.0])
    # z-component is the smallest.
    else:
        return np.array([0.0, 0.0, 1.0])

def get_orthonormal_plane(v):
    # assume v is already unit length
    a = pick_arbitrary(v)
    # project a on-to v, then remove component along v
    u0 = a - np.dot(a, v) * v
    # normalize
    u0 /= np.linalg.norm(u0)
    # second orthonormal vector
    w = np.cross(v, u0)
    w /= np.linalg.norm(w)
    return u0, w

In [42]:
p, c = 0, 1

p_pos = np.array(joints[p])
c_pos = np.array(joints[c])
p_pos, c_pos

(array([-1.30718479e-17,  4.61480000e-01,  1.29070000e-02]),
 array([0.0453662, 0.4405827, 0.0022953]))

In [43]:
# bone direction
v = c_pos - p_pos
v = v / (np.linalg.norm(v) + 1e-10)
v

array([ 0.8884414 , -0.40924799, -0.20781713])

In [44]:
# orthonormal plane basis
u0, w = get_orthonormal_plane(v)
u0, w

(array([ 0.18875428, -0.086947  ,  0.9781677 ]),
 array([-0.41838224, -0.90827105,  0.        ]))

In [45]:
K = 3

# sample K angles around the circle
# np.linspace lets you specify the number of points you want in an interval
# np.arange lets you specify the step size when you don't know your sample size
thetas = np.linspace(0, 2*np.pi, K, endpoint=False)

# [K, 3]
circle_points = [np.cos(t) * u0 + np.sin(t) * w for t in thetas]
dirs_k = np.stack(circle_points, axis=0)

# Normalize each ray
dirs_k /= (np.linalg.norm(dirs_k, axis=1, keepdims=True) + 1e-10)

dirs_k.shape, dirs_k

((3, 3),
 array([[ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385]]))

In [46]:
# We are shooting rays from both ends of the bone at once. 
# create 2 origins (parent & child), each repeated K times
bone_origins = np.vstack([p_pos, c_pos]) # shape [2,3]
bone_origins.shape, bone_origins

((2, 3),
 array([[-1.30718479e-17,  4.61480000e-01,  1.29070000e-02],
        [ 4.53662000e-02,  4.40582700e-01,  2.29530000e-03]]))

In [47]:
# np.repeat repeats entries along an existing dimension
# For each item in bone_origins, it will repeat it K times before moving on to the next one
origin_2K = np.repeat(bone_origins, K, axis=0) # [2*K,3]
origin_2K.shape, origin_2K

((6, 3),
 array([[-1.30718479e-17,  4.61480000e-01,  1.29070000e-02],
        [-1.30718479e-17,  4.61480000e-01,  1.29070000e-02],
        [-1.30718479e-17,  4.61480000e-01,  1.29070000e-02],
        [ 4.53662000e-02,  4.40582700e-01,  2.29530000e-03],
        [ 4.53662000e-02,  4.40582700e-01,  2.29530000e-03],
        [ 4.53662000e-02,  4.40582700e-01,  2.29530000e-03]]))

In [48]:
# tile dirs twice to match origins
# in this case tile repeats the "block" represented by dirs_k twice
# reps takes the repetitions for each dimension.
# reps=(2, 1): repeat dim_0 twice, repeat dim_1 once
# reps=(1, 2): repeat dim_0 once, repeat dim_1 twice
# reps=2: repeat dim_1 twice 

dirs_2K = np.tile(dirs_k, (2, 1)) # [2*K,3]
dirs_2K.shape, dirs_2K

((6, 3),
 array([[ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385],
        [ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385]]))

In [49]:
# joint indices 
joint_idxs = np.vstack([p, c])
joints_2K = np.repeat(joint_idxs, K, axis=0)
joints_2K

array([[0],
       [0],
       [0],
       [1],
       [1],
       [1]])

In [50]:
def form_rays(joints, bones, K=14):
    """
    joints: list of (x,y,z) arrays, shape [J,3]
    bones: list of (parent_idx, child_idx) pairs
    K: number of rays per bone-end
    
    Returns:
      origins: np.ndarray [2*K*len(bones), 3]
      dirs: np.ndarray [2*K*len(bones), 3]
    """
    origins_list = []
    dirs_list = []
    joint_idx_list = []
    
    for p_idx, c_idx in bones:
        p_pos = np.array(joints[p_idx])
        c_pos = np.array(joints[c_idx])
        
        # bone direction
        v = c_pos - p_pos
        v = v / (np.linalg.norm(v) + 1e-10)
        
        # orthonormal plane basis
        u0, w = get_orthonormal_plane(v)
        
        # sample K angles around the circle
        thetas = np.linspace(0, 2*np.pi, K, endpoint=False)

        # [K, 3]
        dirs_k = np.stack([np.cos(t) * u0 + np.sin(t) * w for t in thetas], axis=0)
        dirs_k /= (np.linalg.norm(dirs_k, axis=1, keepdims=True) + 1e-10)
        
        # create 2 origins (parent & child), each repeated K times
        bone_origins = np.vstack([p_pos, c_pos]) # [2,3]
        origin_2K = np.repeat(bone_origins, K, axis=0) # [2*K,3]

        # joint indices 
        joint_idxs = np.vstack([p_idx, c_idx])
        joints_2K = np.repeat(joint_idxs, K, axis=0)
        
        # tile dirs twice to match origins
        dirs_2K = np.tile(dirs_k, (2, 1)) # [2*K,3]
        
        origins_list.append(origin_2K)
        dirs_list.append(dirs_2K)
        joint_idx_list.append(joints_2K)
    
    # concatenate all bones so we can shoot all rays from every joint together
    origins = np.concatenate(origins_list, axis=0)
    dirs = np.concatenate(dirs_list, axis=0)
    joint_idxs = np.concatenate(joint_idx_list, axis=0)
    
    return origins, dirs, joint_idxs

In [51]:
origins, dirs, joint_idxs = form_rays(joints, bones, K=14)

# origins and dirs should be (2 x K x #Bones, 3)
len(joints), len(bones), origins.shape, dirs.shape, joint_idxs.shape

(26, 25, (700, 3), (700, 3), (700, 1))

### Shoot Rays

In [52]:
# current variables
mesh, joints, origins, dirs, joint_idxs;

In [None]:
intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
locs, ray_ids, tri_ids = intersector.intersects_location(origins, dirs)

# locs: coords of every intersection [M, 3]
# ray_ids: which ray produced each hit [M]
# tri_ids: which triangle was hit [M]

# M=32 with K=2, and M=171 with K=14

locs.shape, ray_ids.shape, tri_ids.shape

((171, 3), (171,), (171,))

In [63]:
hits_by_ray = defaultdict(list)
tris_by_ray = defaultdict(list)

# iterate over every intersection
for pt, r, t in zip(locs, ray_ids, tri_ids):
    hits_by_ray[r].append(pt)
    tris_by_ray[r].append(t)

# We want to pick the closest of each ray's hits 

len(hits_by_ray.keys()), list(map(len, hits_by_ray.values()))

(98,
 [2,
  2,
  3,
  1,
  1,
  1,
  5,
  3,
  3,
  1,
  3,
  2,
  2,
  2,
  1,
  1,
  1,
  6,
  1,
  1,
  1,
  1,
  2,
  2,
  1,
  3,
  3,
  1,
  3,
  3,
  5,
  1,
  4,
  2,
  5,
  2,
  3,
  1,
  1,
  3,
  3,
  2,
  1,
  1,
  1,
  1,
  3,
  3,
  1,
  1,
  1,
  3,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  1,
  3,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  2,
  2,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  2,
  2,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  3,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1])

For each of 171 rays, we have
- hits_by_ray[r]: list of points where ray r intersected the mesh
- tris: corresponding triangle idxs
- origins: origin joints of ray r

For each ray: 
- compute euc. dist. from each hit point to origin[r] and pick the hit with the min dist.
- If there are no hits (empty list, since hits_by_ray is a defaultdict), use "nearby_faces" fallback