### Constructing Per-Vertex Binary Attention Masks from Joint and Bone Information

The RigNet paper finds that pre-training the attention module's weights with a cross-entropy loss against a per-vertex attention mask can improve performance. In this notebook, we explore how to construct these masks.
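The exact loss formulation isn't spelled out here, so the following is only a minimal sketch under an assumption: the attention module outputs one logit per vertex, and pre-training applies binary cross-entropy between the sigmoid of those logits and the 0/1 per-vertex mask. The helper name `attention_bce_loss` is hypothetical.

```python
import numpy as np

def attention_bce_loss(logits, mask):
    """Mean binary cross-entropy between per-vertex attention logits and a 0/1 mask.

    logits: (V,) pre-sigmoid attention scores from a (hypothetical) attention head
    mask:   (V,) boolean or 0/1 per-vertex attention mask
    """
    x = np.asarray(logits, dtype=float)
    y = np.asarray(mask, dtype=float)
    # Numerically stable form of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]
    per_vertex = np.maximum(x, 0.0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_vertex.mean()

toy_mask = np.array([True, False, True, False])
print(attention_bce_loss(np.zeros(4), toy_mask))                       # log(2) ≈ 0.6931
print(attention_bce_loss(np.array([9.0, -9.0, 9.0, -9.0]), toy_mask))  # ≈ 1.2e-4
```

Uninformative logits (all zeros) give log(2) per vertex, while confident logits that agree with the mask drive the loss toward zero.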

The main idea is that vertices lying near the plane orthogonal to a bone at each joint should receive higher attention values than other vertices.

The paper doesn't explicitly say how it constructs these masks. I have two main ideas.

**Ideas**
- For each joint, pick one connected bone and find the plane through the joint that is orthogonal to it.
    - Idea 1: Pick all mesh vertices within some radius r of the joint that lie on this plane.
    - Idea 2: Find the vertex p_min closest to joint j that lies on the plane, and let d = || j - p_min ||. Pick all mesh vertices that lie within a distance of d + eps from the joint, where eps is some "slack" threshold.
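Both ideas can be sketched in a few lines of NumPy. This is an illustration, not RigNet's method: the function names are hypothetical, and "lies on the plane" is approximated with an extra tolerance `plane_tol` (an assumption, and yet another hyperparameter).

```python
import numpy as np

def plane_distance(verts, joint, bone_dir):
    """Unsigned distance from each vertex to the plane through `joint` with normal `bone_dir`."""
    n = np.asarray(bone_dir, dtype=float)
    n = n / np.linalg.norm(n)
    return np.abs((np.asarray(verts, dtype=float) - np.asarray(joint, dtype=float)) @ n)

def idea1_mask(verts, joint, bone_dir, r, plane_tol):
    """Idea 1: vertices within radius r of the joint that lie (approximately) on the plane."""
    dist = np.linalg.norm(np.asarray(verts, dtype=float) - np.asarray(joint, dtype=float), axis=1)
    on_plane = plane_distance(verts, joint, bone_dir) < plane_tol
    return on_plane & (dist < r)

def idea2_mask(verts, joint, bone_dir, eps, plane_tol):
    """Idea 2: d = distance to the closest on-plane vertex p_min; keep every vertex within d + eps."""
    dist = np.linalg.norm(np.asarray(verts, dtype=float) - np.asarray(joint, dtype=float), axis=1)
    on_plane = plane_distance(verts, joint, bone_dir) < plane_tol
    if not on_plane.any():
        return np.zeros(len(dist), dtype=bool)  # nothing on the plane: empty mask
    d = dist[on_plane].min()
    return dist <= d + eps

# Toy example: joint at the origin, bone along +z, so the plane is z = 0
toy_verts = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0], [3.0, 0.0, 0.01]])
print(idea1_mask(toy_verts, [0, 0, 0], [0, 0, 1], r=1.5, plane_tol=0.05))    # [ True False False False]
print(idea2_mask(toy_verts, [0, 0, 0], [0, 0, 1], eps=0.5, plane_tol=0.05))  # [ True False  True False]
```

The second result shows a property of Idea 2 as written: the d + eps ball is not restricted to the plane, so the vertex at (0, 0, 1), which lies along the bone, is selected too.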

The main issue is that either idea introduces yet another hyperparameter (r or eps) that may need to be tuned.

**RigNet Implementation: Ray-casting**

I found their actual implementation:
- Cast K rays (K=14) from the joint, spread within the plane that is orthogonal to the bone and passes through the joint.
- Perform ray-triangle intersection and gather the vertices closest to the intersection points.
    - If fewer than 6 vertices are found, fall back to kNN centered at the joint with k=6.
    - If ray-triangle intersection fails, do a "nearby triangle" search.
        - We need this because there ARE rare instances where rays won't intersect any triangle. This ensures that at least some training signal is retained for each joint.
- Find the 20th-percentile distance of the gathered vertices to the joint and multiply it by 2. This distance is the threshold for retaining points.
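The notebook below implements the ray casting, the nearby-triangle fallback, and the percentile threshold, but not the "fewer than 6 vertices" kNN fallback. A minimal sketch of those last two steps for a single joint, assuming the candidate vertex indices for that joint have already been gathered; the helper name `finalize_joint_vertices` is hypothetical, and applying the percentile filter to kNN fallback vertices as well is an assumption about RigNet's ordering. It keeps distances `<=` the threshold (the notebook cell uses `<`) so that a set of equidistant vertices is not emptied.

```python
import numpy as np
from scipy.spatial import cKDTree

def finalize_joint_vertices(verts, joint, candidate_ids, min_count=6, pct=20.0, scale=2.0):
    """Hypothetical helper: kNN fallback, then the scale * pct-th percentile distance filter.

    verts:         (V, 3) mesh vertices
    joint:         (3,) joint position
    candidate_ids: vertex indices gathered near this joint's ray hits
    """
    verts = np.asarray(verts, dtype=float)
    joint = np.asarray(joint, dtype=float)
    ids = np.unique(np.asarray(candidate_ids, dtype=np.int64))

    # Fallback: too few candidates -> use the k nearest vertices to the joint (k = min_count)
    if len(ids) < min_count:
        k = min(min_count, len(verts))
        _, knn = cKDTree(verts).query(joint, k=k)
        ids = np.atleast_1d(knn).astype(np.int64)

    # Keep vertices within scale * (pct-th percentile of the distances to the joint)
    dists = np.linalg.norm(verts[ids] - joint, axis=1)
    threshold = scale * np.percentile(dists, pct)
    return ids[dists <= threshold]

toy_verts = np.array([[float(i), 0.0, 0.0] for i in range(10)])
print(finalize_joint_vertices(toy_verts, [0, 0, 0], [1, 2]))        # kNN path: [0 1 2]
print(finalize_joint_vertices(toy_verts, [0, 0, 0], range(1, 9)))   # ray path: [1 2 3 4]
```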

Notes: 
- RigNet decimates meshes to 3k verts before doing this. Why? If computation time isn't a concern, what's the need for this? The max number of verts in the dataset is already only 5k.

In [1]:
import os
import glob
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import open3d as o3d
import trimesh

In [2]:
data_root = "../data/ModelResource_RigNetv1_preproccessed"
obj_folder = f'{data_root}/obj'
rig_folder = f'{data_root}/rig_info'

In [3]:
obj_files = glob.glob(os.path.join(obj_folder, "*.obj"))
len(obj_files)

2703

In [4]:
def get_obj_path_from_idx(idx: int):
    return os.path.join(obj_folder, f"{idx}.obj")

def get_rig_path_from_idx(idx: int):
    return os.path.join(rig_folder, f"{idx}.txt")

In [5]:
mesh_idx = 13
obj_path = get_obj_path_from_idx(mesh_idx)
rig_path = get_rig_path_from_idx(mesh_idx)
obj_path, rig_path

('../data/ModelResource_RigNetv1_preproccessed/obj/13.obj',
 '../data/ModelResource_RigNetv1_preproccessed/rig_info/13.txt')

### Visualization Code from `visualization.ipynb`

In [6]:
def get_mesh_joint_pair(obj_index, obj_folder, rig_folder):
    
    obj_path = f'{obj_folder}/{obj_index}.obj'
    rig_path = f'{rig_folder}/{obj_index}.txt'

    # Load and pre-process mesh
    mesh = o3d.io.read_triangle_mesh(obj_path, enable_post_processing=True)

    verts = np.asarray(mesh.vertices)
    centroid = verts.mean(axis=0)
    mesh.translate(-centroid)

    # Extract joint locations
    # Disregard bone info; only joint locations are needed here
    joints = []
    with open(rig_path, "r") as f:
        tokens = f.readline().split()
        while(tokens[0] == "joints"):
            joints.append(list(map(float, tokens[2:])))
            tokens = f.readline().split()

    # Return joints as point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(joints))
    pcd.translate(-centroid)

    return mesh, pcd

In [7]:
def process_and_visualize(mesh: o3d.geometry.TriangleMesh, 
                          joints: o3d.geometry.PointCloud,
                          attn_mask: np.ndarray = None,
                          coord_frame: bool = True):

    # Get vertices and centroid
    verts = np.asarray(mesh.vertices)
    centroid = verts.mean(axis=0)   

    geometries = []

    # Mesh frame
    mesh_frame = o3d.geometry.LineSet.create_from_triangle_mesh(mesh)
    mesh_frame.paint_uniform_color([0.7, 0.7, 0.7])
    geometries.append(mesh_frame)

    # Draw joints
    spheres = []
    for (x, y, z) in np.asarray(joints.points):
        sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.0075)
        sphere.translate((x, y, z))
        sphere.paint_uniform_color([0, 1, 1])
        spheres.append(sphere)

    # Despite its name, attn_mask here is an (N, 3) array of points to highlight
    if attn_mask is not None:
        for x, y, z in attn_mask:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.003)
            sphere.translate((x, y, z))
            sphere.paint_uniform_color([1, 0, 0])
            spheres.append(sphere)

    geometries.extend(spheres)

    # Compute AABB and its longest side
    aabb = mesh.get_axis_aligned_bounding_box()
    min_bound = aabb.min_bound  # [x_min, y_min, z_min]
    max_bound = aabb.max_bound  # [x_max, y_max, z_max]
    lengths = max_bound - min_bound  # [Lx, Ly, Lz]
    longest = lengths.max()

    # Create a box for visualization
    box = o3d.geometry.LineSet.create_from_axis_aligned_bounding_box(aabb)
    # aabb comes from the mesh itself, so the box is already in mesh coordinates (no translation needed)
    box.colors = o3d.utility.Vector3dVector([[1, 0, 0] for _ in box.lines]) 
    geometries.append(box)

    # Axes
    if coord_frame:
        axes = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=longest * 0.5,
            origin=[0, 0, 0]
        )
        geometries.append(axes)

    o3d.visualization.draw_geometries(geometries,
                                      window_name="Mesh + AABB",
                                      width=800, height=600)

### Visualization

In [35]:
o3d_mesh, o3d_joints = get_mesh_joint_pair(mesh_idx, obj_folder, rig_folder)

In [None]:
process_and_visualize(o3d_mesh, o3d_joints)

### Ray Casting

In [8]:
mesh = trimesh.load_mesh(obj_path)
mesh

<trimesh.Trimesh(vertices.shape=(4307, 3), faces.shape=(8627, 3))>

In [9]:
# center mesh
centroid = mesh.centroid
mesh.apply_translation(-centroid)
mesh.centroid, centroid

(array([1.30718479e-17, 1.82078577e-15, 4.37174042e-17]),
 array([0.00662568, 0.55715229, 0.00377134]))

In [10]:
def parse_rig(rig_path: str, centroid):
    # Extract joint locations and bones (parent-child pairs) from a rig file

    # File structure:
    # joints name x y z
    # root name
    # skin name1 weight1 name2 weight2 ...
    # hier parent_name child_name

    # Assumes this strict order

    jointname2idx = {}
    joints = []
    bones = []
    root_idx = None
    with open(rig_path, "r") as f:
        tokens = f.readline().split()
        while(tokens[0] == "joints"):
            name = tokens[1]
            loc = list(map(float, tokens[2:]))

            # name to index mapping
            next_idx = len(joints)
            jointname2idx[name] = next_idx
            joints.append(loc)
            tokens = f.readline().split()
        
        # root
        root_idx = jointname2idx[tokens[1]]

        # Skip skin info
        while tokens[0] != "hier":
            tokens = f.readline().split()
        
        # Hier info
        while tokens:
            b1 = jointname2idx[tokens[1]]
            b2 = jointname2idx[tokens[2]]
            bones.append([b1, b2])
            tokens = f.readline().split()

    joints = np.array(joints) - centroid
    
    return joints, bones, root_idx

In [11]:
joints, bones, root_idx = parse_rig(rig_path, centroid)
len(joints), bones, root_idx

(26,
 [[0, 1],
  [0, 2],
  [0, 3],
  [1, 4],
  [2, 5],
  [3, 6],
  [4, 7],
  [5, 8],
  [6, 9],
  [6, 10],
  [6, 11],
  [7, 12],
  [8, 13],
  [9, 14],
  [10, 15],
  [11, 16],
  [14, 17],
  [15, 18],
  [16, 19],
  [18, 20],
  [19, 21],
  [20, 22],
  [20, 23],
  [21, 24],
  [21, 25]],
 0)

In [12]:
mesh, joints.shape

(<trimesh.Trimesh(vertices.shape=(4307, 3), faces.shape=(8627, 3))>, (26, 3))

In [39]:
# Quick sanity check
# Plot the mesh and joints created with trimesh and parse_rig

_mesh_o3d = o3d.geometry.TriangleMesh()
_mesh_o3d.vertices = o3d.utility.Vector3dVector(np.asarray(mesh.vertices))
_mesh_o3d.triangles = o3d.utility.Vector3iVector(np.asarray(mesh.faces))


_pcd_o3d = o3d.geometry.PointCloud()
_pcd_o3d.points = o3d.utility.Vector3dVector(np.asarray(joints))



In [None]:
# Looks right
process_and_visualize(_mesh_o3d, _pcd_o3d)

### Form Rays

In [13]:
# Build an orthonormal basis (u0, w) for the plane orthogonal to v

def pick_arbitrary(v):
    # pick axis least aligned with v
    abs_v = np.abs(v)

    # If the x-component is the smallest, pick the x-axis
    if abs_v[0] <= abs_v[1] and abs_v[0] <= abs_v[2]:
        return np.array([1.0, 0.0, 0.0])
    # x-component is not the smallest. Check if y is smaller than z. 
    elif abs_v[1] <= abs_v[2]:
        return np.array([0.0, 1.0, 0.0])
    # z-component is the smallest.
    else:
        return np.array([0.0, 0.0, 1.0])

def get_orthonormal_plane(v):
    # assume v is already unit length
    a = pick_arbitrary(v)
    # remove a's component along v (one Gram-Schmidt step)
    u0 = a - np.dot(a, v) * v
    # normalize
    u0 /= np.linalg.norm(u0)
    # second orthonormal vector
    w = np.cross(v, u0)
    w /= np.linalg.norm(w)
    return u0, w

In [14]:
p, c = 0, 1

p_pos = np.array(joints[p])
c_pos = np.array(joints[c])
p_pos, c_pos

(array([-0.00662568, -0.09567229,  0.00913566]),
 array([ 0.03874052, -0.11656959, -0.00147604]))

In [15]:
# bone direction
v = c_pos - p_pos
v = v / (np.linalg.norm(v) + 1e-10)
v

array([ 0.8884414 , -0.40924799, -0.20781713])

In [16]:
# orthonormal plane basis
u0, w = get_orthonormal_plane(v)
u0, w

(array([ 0.18875428, -0.086947  ,  0.9781677 ]),
 array([-0.41838224, -0.90827105,  0.        ]))

In [17]:
K = 3

# sample K angles around the circle
# np.linspace lets you specify the number of points you want in an interval
# np.arange lets you specify the step size when you don't know your sample size
thetas = np.linspace(0, 2*np.pi, K, endpoint=False)

# [K, 3]
circle_points = [np.cos(t) * u0 + np.sin(t) * w for t in thetas]
dirs_k = np.stack(circle_points, axis=0)

# Normalize each ray
dirs_k /= (np.linalg.norm(dirs_k, axis=1, keepdims=True) + 1e-10)

dirs_k.shape, dirs_k

((3, 3),
 array([[ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385]]))

In [18]:
# We are shooting rays from both ends of the bone at once. 
# create 2 origins (parent & child), each repeated K times
bone_origins = np.vstack([p_pos, c_pos]) # shape [2,3]
bone_origins.shape, bone_origins

((2, 3),
 array([[-0.00662568, -0.09567229,  0.00913566],
        [ 0.03874052, -0.11656959, -0.00147604]]))

In [19]:
# np.repeat repeats entries along an existing dimension
# For each item in bone_origins, it will repeat it K times before moving on to the next one
origin_2K = np.repeat(bone_origins, K, axis=0) # [2*K,3]
origin_2K.shape, origin_2K

((6, 3),
 array([[-0.00662568, -0.09567229,  0.00913566],
        [-0.00662568, -0.09567229,  0.00913566],
        [-0.00662568, -0.09567229,  0.00913566],
        [ 0.03874052, -0.11656959, -0.00147604],
        [ 0.03874052, -0.11656959, -0.00147604],
        [ 0.03874052, -0.11656959, -0.00147604]]))

In [20]:
# tile dirs twice to match origins
# in this case tile repeats the "block" represented by dirs_k twice
# reps takes the repetitions for each dimension.
# reps=(2, 1): repeat dim_0 twice, repeat dim_1 once
# reps=(1, 2): repeat dim_0 once, repeat dim_1 twice
# reps=2: repeat dim_1 twice 

dirs_2K = np.tile(dirs_k, (2, 1)) # [2*K,3]
dirs_2K.shape, dirs_2K

((6, 3),
 array([[ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385],
        [ 0.18875428, -0.086947  ,  0.9781677 ],
        [-0.45670679, -0.7431123 , -0.48908385],
        [ 0.26795251,  0.8300593 , -0.48908385]]))
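The `np.repeat` / `np.tile` pairing above is what guarantees that every origin is matched with every direction exactly once. A tiny check with toy values (two made-up origins and K=3 identity directions, not the notebook's data) makes the row layout explicit: row i is the ray (origin i // K, direction i % K).

```python
import numpy as np

K_toy = 3
toy_origins = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # parent, child (toy values)
toy_dirs_k = np.eye(3)                                       # K_toy toy ray directions

toy_origin_2K = np.repeat(toy_origins, K_toy, axis=0)  # o0, o0, o0, o1, o1, o1
toy_dirs_2K = np.tile(toy_dirs_k, (2, 1))              # d0, d1, d2, d0, d1, d2

# Row i pairs origin i // K with direction i % K
for i in range(2 * K_toy):
    assert np.array_equal(toy_origin_2K[i], toy_origins[i // K_toy])
    assert np.array_equal(toy_dirs_2K[i], toy_dirs_k[i % K_toy])

print(np.tile(toy_dirs_k, 2).shape)  # (3, 6): reps=2 repeats along the last axis
```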

In [21]:
# joint indices 
joint_idxs = np.vstack([p, c])
joints_2K = np.repeat(joint_idxs, K, axis=0)
joints_2K

array([[0],
       [0],
       [0],
       [1],
       [1],
       [1]])

In [22]:
def form_rays(joints, bones, K=14):
    """
    joints: array-like of (x, y, z) positions, shape [J, 3]
    bones: list of (parent_idx, child_idx) pairs
    K: number of rays per bone end

    Returns:
      origins: np.ndarray [2*K*len(bones), 3]
      dirs: np.ndarray [2*K*len(bones), 3]
      joint_idxs: np.ndarray [2*K*len(bones), 1], index of the joint each ray starts from
    """
    origins_list = []
    dirs_list = []
    joint_idx_list = []
    
    for p_idx, c_idx in bones:
        p_pos = np.array(joints[p_idx])
        c_pos = np.array(joints[c_idx])
        
        # bone direction
        v = c_pos - p_pos
        v = v / (np.linalg.norm(v) + 1e-10)
        
        # orthonormal plane basis
        u0, w = get_orthonormal_plane(v)
        
        # sample K angles around the circle
        thetas = np.linspace(0, 2*np.pi, K, endpoint=False)

        # [K, 3]
        dirs_k = np.stack([np.cos(t) * u0 + np.sin(t) * w for t in thetas], axis=0)
        dirs_k /= (np.linalg.norm(dirs_k, axis=1, keepdims=True) + 1e-10)
        
        # create 2 origins (parent & child), each repeated K times
        bone_origins = np.vstack([p_pos, c_pos]) # [2,3]
        origin_2K = np.repeat(bone_origins, K, axis=0) # [2*K,3]

        # joint indices 
        joint_idxs = np.vstack([p_idx, c_idx])
        joints_2K = np.repeat(joint_idxs, K, axis=0)
        
        # tile dirs twice to match origins
        dirs_2K = np.tile(dirs_k, (2, 1)) # [2*K,3]
        
        origins_list.append(origin_2K)
        dirs_list.append(dirs_2K)
        joint_idx_list.append(joints_2K)
    
    # concatenate all bones so we can shoot all rays from every joint together
    origins = np.concatenate(origins_list, axis=0)
    dirs = np.concatenate(dirs_list, axis=0)
    joint_idxs = np.concatenate(joint_idx_list, axis=0)
    
    return origins, dirs, joint_idxs

In [23]:
origins, dirs, ray_joint_idxs = form_rays(joints, bones, K=14)

# origins and dirs should be (2 x K x #Bones, 3)
len(joints), len(bones), origins.shape, dirs.shape, ray_joint_idxs.shape

(26, 25, (700, 3), (700, 3), (700, 1))
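The per-bone loop in `form_rays` can also be written with broadcasting. The sketch below (the name `form_rays_vectorized` is hypothetical) is intended to produce the same row order: for each bone, K rays from the parent, then the same K directions from the child. Its helper-axis choice uses `np.argmin`, which breaks ties toward x and then y, like `pick_arbitrary`, so its output is expected to agree with `form_rays` up to floating-point error. With 25 bones the loop is not a bottleneck; this form may matter more when processing all meshes.

```python
import numpy as np

def form_rays_vectorized(joints, bones, K=14):
    """Broadcasting version of form_rays (same output layout, no Python loop over bones)."""
    joints = np.asarray(joints, dtype=float)
    bones = np.asarray(bones, dtype=np.int64)                   # [B, 2]
    p_pos, c_pos = joints[bones[:, 0]], joints[bones[:, 1]]     # [B, 3] each

    # Unit bone directions
    v = c_pos - p_pos
    v /= np.linalg.norm(v, axis=1, keepdims=True) + 1e-10

    # Axis least aligned with each bone, then one Gram-Schmidt step against v
    a = np.eye(3)[np.argmin(np.abs(v), axis=1)]                 # [B, 3]
    u0 = a - np.sum(a * v, axis=1, keepdims=True) * v
    u0 /= np.linalg.norm(u0, axis=1, keepdims=True)
    w = np.cross(v, u0)
    w /= np.linalg.norm(w, axis=1, keepdims=True)

    # K unit directions per bone on the circle spanned by (u0, w): [B, K, 3]
    thetas = np.linspace(0, 2 * np.pi, K, endpoint=False)
    dirs_k = (np.cos(thetas)[None, :, None] * u0[:, None, :]
              + np.sin(thetas)[None, :, None] * w[:, None, :])

    # Per bone: K rays from the parent, then the same K directions from the child
    B = len(bones)
    ends = np.stack([p_pos, c_pos], axis=1)                     # [B, 2, 3]
    origins = np.broadcast_to(ends[:, :, None, :], (B, 2, K, 3)).reshape(-1, 3)
    dirs = np.broadcast_to(dirs_k[:, None, :, :], (B, 2, K, 3)).reshape(-1, 3)
    joint_idxs = np.repeat(bones, K, axis=1).reshape(-1, 1)     # [2*K*B, 1]
    return origins, dirs, joint_idxs
```

With the notebook's variables, `np.allclose(form_rays_vectorized(joints, bones)[1], form_rays(joints, bones)[1])` is expected to hold.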

### Shoot Rays

In [24]:
# current variables
mesh, joints, origins, dirs, ray_joint_idxs;

In [25]:
intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
locs, ray_ids, tri_ids = intersector.intersects_location(origins, dirs)

# locs: coords of every intersection [M, 3]
# ray_ids: which ray produced each hit [M]
# tri_ids: which triangle was hit [M]

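# M depends on K and the mesh, and can exceed the number of rays because a ray
# may cross the surface several times; here M=1138 for 700 rays (K=14)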

locs.shape, ray_ids.shape, tri_ids.shape

((1138, 3), (1138,), (1138,))

In [26]:
hits_by_ray = defaultdict(list)
tris_by_ray = defaultdict(list)

# iterate over every intersection
for pt, r, t in zip(locs, ray_ids, tri_ids):
    hits_by_ray[r].append(pt)
    tris_by_ray[r].append(t)

# We want to pick the closest of each ray's hits 

hits_by_ray.keys(), list(map(len, hits_by_ray.values()))

(dict_keys([np.int64(108), np.int64(398), np.int64(655), np.int64(178), np.int64(580), np.int64(349), np.int64(385), np.int64(455), np.int64(297), np.int64(646), np.int64(23), np.int64(486), np.int64(584), np.int64(317), np.int64(329), np.int64(684), np.int64(347), np.int64(675), np.int64(524), np.int64(418), np.int64(358), np.int64(374), np.int64(248), np.int64(611), np.int64(687), np.int64(423), np.int64(49), np.int64(459), np.int64(367), np.int64(241), np.int64(554), np.int64(331), np.int64(488), np.int64(637), np.int64(525), np.int64(366), np.int64(240), np.int64(285), np.int64(432), np.int64(473), np.int64(579), np.int64(576), np.int64(639), np.int64(375), np.int64(249), np.int64(591), np.int64(508), np.int64(401), np.int64(107), np.int64(177), np.int64(295), np.int64(650), np.int64(575), np.int64(599), np.int64(93), np.int64(533), np.int64(480), np.int64(593), np.int64(421), np.int64(328), np.int64(589), np.int64(204), np.int64(105), np.int64(175), np.int64(302), np.int64(672), n

For each of the 700 rays, we have
- hits_by_ray[r]: list of points where ray r intersected the mesh
- tris_by_ray[r]: corresponding triangle indices
- origins[r]: origin (joint position) of ray r

For each ray:
- Compute the Euclidean distance from each hit point to origins[r] and pick the hit with the minimum distance.
- If there are no hits (`hits_by_ray.get(r, [])` returns an empty list), use the "nearby_faces" fallback.
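The per-ray selection described above can also be done with array operations over the flat `locs`, `ray_ids`, and `tri_ids` returned by `intersects_location`. A sketch (the helper `closest_hit_per_ray` is hypothetical): `np.lexsort` orders hits by ray id and then by distance, and `np.unique(..., return_index=True)` keeps the first, i.e. closest, hit of each ray. Rays that never appear in `ray_ids` are returned separately so they can go to the fallback.

```python
import numpy as np

def closest_hit_per_ray(locs, ray_ids, tri_ids, origins):
    """Closest hit for every ray that hit anything, plus the ids of rays that missed."""
    locs = np.asarray(locs, dtype=float).reshape(-1, 3)
    ray_ids = np.asarray(ray_ids, dtype=np.int64)
    tri_ids = np.asarray(tri_ids, dtype=np.int64)
    origins = np.asarray(origins, dtype=float)

    dists = np.linalg.norm(locs - origins[ray_ids], axis=1)
    # lexsort uses the LAST key as primary: sort by ray id, then by distance
    order = np.lexsort((dists, ray_ids))
    hit_rays, first = np.unique(ray_ids[order], return_index=True)
    pick = order[first]  # index of the closest hit of each ray

    missed_rays = np.setdiff1d(np.arange(len(origins)), hit_rays)
    return hit_rays, locs[pick], tri_ids[pick], dists[pick], missed_rays

# Toy data: ray 0 has two hits, ray 1 has two hits, ray 2 misses
o = np.array([[0.0, 0, 0], [10.0, 0, 0], [5.0, 5, 5]])
l = np.array([[3.0, 0, 0], [1.0, 0, 0], [12.0, 0, 0], [11.0, 0, 0]])
rays, pts, tris, d, missed = closest_hit_per_ray(l, [0, 0, 1, 1], [7, 8, 9, 10], o)
print(rays, tris, missed)  # [0 1] [ 8 10] [2]
```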

In [27]:
origins.shape

(700, 3)

In [28]:
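# Inspect the hits of a single ray. This uses `r` left over from the loop in
# the previous cell (the last ray id seen), which is what produced the output below.
r_idx = r

hits = hits_by_ray.get(r_idx, [])
tris = tris_by_ray.get(r_idx, [])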

# Stack list of lists into array of shape N x 3
pts = np.stack(hits, axis=0)
pts.shape

(6, 3)

In [29]:
selected_hits = []
# list of tuple[ray_idx, point, triangle, distance from origin]

for r in range(origins.shape[0]):
    hits = hits_by_ray.get(r, [])
    tris = tris_by_ray.get(r, [])

    if hits:
        pts = np.stack(hits, axis=0) # [m, 3]
        dists = np.linalg.norm(pts - origins[r], axis=1) # origins[r] broadcasted
        k = np.argmin(dists) # get closest point
       
        selected_hits.append((r, pts[k], tris[k], dists[k]))
    else:

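        # Fallback: get faces near the ray origin (the joint).
        # nearby_faces returns one array of candidate face ids per query point;
        # we pass a single point (origins[r][None, :] has shape [1, 3]), so take [0].
        close_tris = trimesh.proximity.nearby_faces(mesh, origins[r][None, :])[0]

        # get vertex indices of those triangles
        vs = mesh.faces[close_tris].flatten()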

        # Use np.asarray to index mesh.vertices with an array (vs)
        fallback_pts = np.asarray(mesh.vertices)[vs]
        
        # record all fallback_pts 
        for pt in fallback_pts:
            selected_hits.append((r, pt, None, None))

len(selected_hits)

1484

In [30]:
# Need to collect hits by joint

hits_by_joint = defaultdict(list)

for r, pt, *_ in selected_hits:
    # which joint generated ray r
    j = ray_joint_idxs[r][0] 
    hits_by_joint[j].append(pt)

# inspect how many hits each joint got
for j, pts in hits_by_joint.items():
    print(f"Joint {j} has {len(pts)} hit points")

Joint 0 has 42 hit points
Joint 1 has 28 hit points
Joint 2 has 28 hit points
Joint 3 has 28 hit points
Joint 4 has 28 hit points
Joint 5 has 28 hit points
Joint 6 has 56 hit points
Joint 7 has 28 hit points
Joint 8 has 28 hit points
Joint 9 has 28 hit points
Joint 10 has 28 hit points
Joint 11 has 28 hit points
Joint 12 has 14 hit points
Joint 13 has 14 hit points
Joint 14 has 28 hit points
Joint 15 has 28 hit points
Joint 16 has 28 hit points
Joint 17 has 14 hit points
Joint 18 has 28 hit points
Joint 19 has 28 hit points
Joint 20 has 42 hit points
Joint 21 has 42 hit points
Joint 22 has 382 hit points
Joint 23 has 14 hit points
Joint 24 has 430 hit points
Joint 25 has 14 hit points
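These counts can be checked against the ray layout. A ray with at least one hit contributes exactly one point, while a ray that misses contributes every vertex of its nearby faces. Since each bone shoots K rays from each end, joint j owns K × (number of bones touching j) rays, and joints whose count exceeds that most likely triggered the fallback. The bone list and counts below are copied from the outputs above for mesh 13.

```python
import numpy as np

K_rays = 14
bones_13 = np.array([[0, 1], [0, 2], [0, 3], [1, 4], [2, 5], [3, 6], [4, 7], [5, 8],
                     [6, 9], [6, 10], [6, 11], [7, 12], [8, 13], [9, 14], [10, 15],
                     [11, 16], [14, 17], [15, 18], [16, 19], [18, 20], [19, 21],
                     [20, 22], [20, 23], [21, 24], [21, 25]])
# Hit-point counts per joint (joint 0 .. 25), copied from the output above
observed = np.array([42, 28, 28, 28, 28, 28, 56, 28, 28, 28, 28, 28, 14, 14,
                     28, 28, 28, 14, 28, 28, 42, 42, 382, 14, 430, 14])

# Rays owned by each joint: K per incident bone
expected = np.bincount(bones_13.ravel(), minlength=len(observed)) * K_rays
print(np.nonzero(observed != expected)[0])  # [22 24]
```

So only the leaf joints 22 and 24 used the `nearby_faces` fallback here, and that fallback is what inflates their counts to several hundred points.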


In [31]:
len(selected_hits)

1484

#### Last step
- Compute the Euclidean distance from each joint to each of its hit points
- Compute the 20th-percentile distance per joint
- Keep only points whose distance is less than twice the 20th percentile
- The remaining points seed the attention mask

In [36]:
filtered_pts   = []
filtered_jidxs = []

for j, pts in hits_by_joint.items():
    if len(pts) == 0:
        continue
    
    # Stack to (M_j, 3)
    pts_arr = np.stack(pts, axis=0)
    
    # Joint position:
    joint_pos = np.array(joints[j])[None, :] # shape (1,3)
    
    # distances of each hit pt to joint j
    dists = np.linalg.norm(pts_arr - joint_pos, axis=1)  # shape (M_j,)
    
    # 20th percentile
    p20 = np.percentile(dists, 20)
    
    # threshold = 2 * p20
    keep_mask = (dists < 2 * p20)
    
    # collect filtered points
    kept_pts = pts_arr[keep_mask]
    filtered_pts.append(kept_pts)
    
    # record joint index for each kept point
    filtered_jidxs.append(np.full(len(kept_pts), j, dtype=int))

# Flatten lists into arrays
if filtered_pts:
    hit_pts = np.concatenate(filtered_pts, axis=0) # (P,3)
    hit_joints = np.concatenate(filtered_jidxs, axis=0) # (P,)
else:
    # In case filtered pts is empty
    hit_pts = np.zeros((0,3), dtype=float)
    hit_joints = np.zeros((0,), dtype=int)

print(f"After filtering: {hit_pts.shape[0]} total hit-points across {len(hits_by_joint)} joints")


After filtering: 1173 total hit-points across 26 joints


In [37]:
hit_pts

array([[-0.00275148, -0.09745689,  0.02921266],
       [-0.00686151, -0.1053877 ,  0.02725971],
       [-0.01188748, -0.11487658,  0.02445932],
       ...,
       [-0.41661804,  0.03090796,  0.00891139],
       [-0.41657761,  0.03077854,  0.00997678],
       [-0.41651017,  0.03055876,  0.01152709]], shape=(1173, 3))

In [44]:
# Looks right
process_and_visualize(_mesh_o3d, _pcd_o3d, hit_pts, coord_frame=False)

### Build Attention Mask
- Build a KD-tree on the mesh vertices
- For each hit point, find all vertices within a radius (0.02)
- Mark those vertices as True

In [41]:
from scipy.spatial import cKDTree

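def build_attention_mask(vtx_ori, hit_pts, radius=0.02):
    """
    vtx_ori : (V,3) array of original mesh vertices
    hit_pts : (P,3) array of filtered surface hits
    radius  : float radius threshold in mesh units

    Returns:
      attn_mask    : (V,) boolean array
      marked_verts : (N,3) positions of the vertices set to True
    """
    V = vtx_ori.shape[0]
    attn_mask = np.zeros(V, dtype=bool)

    # build KD-tree on vertices for fast radius queries
    tree = cKDTree(vtx_ori)

    # for each hit point, find all vertices within radius;
    # this returns one list of vertex indices per hit point (possibly empty)
    neighbors = tree.query_ball_point(hit_pts, r=radius)

    # Flatten with an explicit integer dtype: an empty neighbor list would otherwise
    # make np.concatenate return floats, which cannot be used as indices
    idx_lists = [np.asarray(n, dtype=np.int64) for n in neighbors]
    neighbors = np.unique(np.concatenate(idx_lists)) if idx_lists else np.zeros(0, dtype=np.int64)
    attn_mask[neighbors] = True

    return attn_mask, vtx_ori[neighbors]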


In [42]:
attn, marked_verts = build_attention_mask(np.asarray(mesh.vertices), 
                            hit_pts, 
                            radius=0.02)
print(f"{attn.sum()} / {len(attn)} vertices marked as attention.")

2436 / 4307 vertices marked as attention.


In [45]:
process_and_visualize(_mesh_o3d, _pcd_o3d, marked_verts)

In [54]:
for b in attn:
    print(b)

False
False
False
False
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
True
True
False
False
True
False
False
False
False
True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True

In [47]:

# Save the mask as one 0/1 integer per vertex (create the folder if needed)
attn_dir = f"{data_root}/attn_masks"
os.makedirs(attn_dir, exist_ok=True)
np.savetxt(f"{attn_dir}/{mesh_idx}.txt", attn.astype(int), fmt='%d')
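A quick round-trip check of this storage format, written against a toy mask in a temporary directory so it doesn't touch the dataset or overwrite `attn`. `ndmin=1` keeps a one-vertex file from loading as a 0-d array.

```python
import os
import tempfile
import numpy as np

toy_mask = np.array([False, True, True, False])  # stand-in for a real per-vertex mask

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy.txt")
    np.savetxt(path, toy_mask.astype(int), fmt="%d")   # same format as the cell above
    loaded = np.loadtxt(path, dtype=int, ndmin=1).astype(bool)

print(np.array_equal(loaded, toy_mask))  # True
```

For a real mask, loading `f"{data_root}/attn_masks/{mesh_idx}.txt"` the same way should give an array whose length equals `len(mesh.vertices)`. A binary format such as `np.save` with a boolean array would be more compact, at the cost of not being human-readable.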