In [1]:
import glob
import numpy as np
import os
import smplx
import trimesh
from PIL import Image

from tools.utils import parse_npz, prepare_params, params2torch, append2dict, np2torch, to_cpu
from tools.utils import makepath, euler, to_tensor
from tools.meshviewer import Mesh, MeshViewer, colors
from tools.objectmodel import ObjectModel

data_path = "dataset/grab/"
model_path = "smplx_models/"
all_seqs = glob.glob(data_path + "/*/*.npz")
seq = parse_npz(all_seqs[0])
print(seq)

bj_name = seq.obj_name
sbj_id   = seq.sbj_id
n_comps  = seq.n_comps # default: 24 from GRAB
gender   = seq.gender
seq_len = len(seq['contact']['object'])
fps = 120

mv = MeshViewer(width=1600, height=1200, offscreen=True)
camera_pose = np.eye(4)
camera_pose[:3, :3] = euler([80, -15, 0], 'xzx')
camera_pose[:3, 3] = np.array([-.5, -1.4, 1.5])
mv.update_camera_pose(camera_pose)

{'gender': 'female', 'sbj_id': 's5', 'framerate': 120.0, 'obj_name': 'cubemedium', 'body': {'params': {'transl': array([[0.08219146, 1.3606334 , 0.9163717 ],
       [0.08221214, 1.360683  , 0.91636825],
       [0.08223021, 1.3607304 , 0.91636986],
       ...,
       [0.10422771, 1.3602108 , 0.917547  ],
       [0.10421915, 1.3605403 , 0.91756487],
       [0.10421263, 1.3608717 , 0.9175787 ]], dtype=float32), 'global_orient': array([[ 1.5995356 ,  0.01279119, -0.03933163],
       [ 1.5998636 ,  0.01279056, -0.03941234],
       [ 1.6000395 ,  0.01281127, -0.0394329 ],
       ...,
       [ 1.5680282 ,  0.02041454, -0.0397973 ],
       [ 1.5677646 ,  0.01971703, -0.04063402],
       [ 1.5675946 ,  0.01902752, -0.04141111]], dtype=float32), 'body_pose': array([[ 0.00871489,  0.09809992,  0.07467163, ...,  0.24217927,
        -0.15459926,  0.08223481],
       [ 0.00827465,  0.09833907,  0.07473552, ...,  0.2363947 ,
        -0.1529555 ,  0.08589435],
       [ 0.00798909,  0.09846441,  0.0748

In [2]:
# Some random functions

def filter_contact_frames(seq_data):
  # find the first contact point
  frame_mask = (seq_data['contact']['object']>0).any(axis=1)
  start_frame = np.argmax(frame_mask)
  return start_frame

def points2sphere(points, radius = .001, vc = [0., 0., 1.], count = [5,5]):
    points = points.reshape(-1,3)
    n_points = points.shape[0]

    spheres = []
    for p in range(n_points):
        sphs = trimesh.creation.uv_sphere(radius=radius, count = count)
        sphs.apply_translation(points[p])
        sphs = Mesh(vertices=sphs.vertices, faces=sphs.faces, vc=vc)

        spheres.append(sphs)

    spheres = Mesh.concatenate_meshes(spheres)
    return spheres

In [3]:
frame_at_contact = filter_contact_frames(seq) # find the first contact point
assert frame_at_contact > fps
frame_mask = np.zeros((seq_len,)).astype(bool)
frame_mask[frame_at_contact-fps : frame_at_contact+fps//2] = 1 # example frame mask: -1s ~ +0.5s (120 fps)

n_comps = seq['n_comps']
T = len(np.where(frame_mask)[0])

rh_mesh = os.path.join(data_path, '..', seq.rhand.vtemp)
rh_vtemp = np.array(Mesh(filename=rh_mesh).vertices)
rh_m = smplx.create(model_path=model_path,
  model_type='mano',
  is_rhand=True,
  v_template=rh_vtemp,
  num_pca_comps=n_comps,
  flat_hand_mean=True,
  batch_size=T
)

rh_params  = prepare_params(seq.rhand.params, frame_mask)
obj_params = prepare_params(seq.object.params, frame_mask)
table_params = prepare_params(seq.table.params, frame_mask)

rh_parms = params2torch(rh_params)
verts_rh = to_cpu(rh_m(**rh_parms).vertices)
points_rh = to_cpu(rh_m(**rh_parms).joints) # --> Where the finger joints' global positions are saved

obj_mesh = os.path.join(data_path, '..', seq.object.object_mesh)
obj_mesh = Mesh(filename=obj_mesh)
obj_vtemp = np.array(obj_mesh.vertices)
obj_m = ObjectModel(v_template=obj_vtemp,
                    batch_size=T)
obj_parms = params2torch(obj_params)
verts_obj = to_cpu(obj_m(**obj_parms).vertices)

table_mesh = os.path.join(data_path, '..', seq.table.table_mesh)
table_mesh = Mesh(filename=table_mesh)
table_vtemp = np.array(table_mesh.vertices)
table_m = ObjectModel(v_template=table_vtemp,
                    batch_size=T)
table_parms = params2torch(table_params)
verts_table = to_cpu(table_m(**table_parms).vertices)

skip_frame = 4
frames = []
for frame in range(0, T, skip_frame):
  o_mesh = Mesh(vertices=verts_obj[frame], faces=obj_mesh.faces, vc=colors['yellow'])
  o_mesh.set_vertex_colors(vc=colors['red'], vertex_ids=seq['contact']['object'][frame] > 0)

  rh_mesh = Mesh(vertices=verts_rh[frame], faces=rh_m.faces, vc=colors['grey'], wireframe=True)
  t_mesh = Mesh(vertices=verts_table[frame], faces=table_mesh.faces, vc=colors['white'])

  rh_joint_mesh = points2sphere(points=points_rh[frame], radius = .01)

  mv.set_static_meshes([o_mesh, rh_mesh, t_mesh, rh_joint_mesh])
  color, depth = mv.viewer.render(mv.scene)
  img = Image.fromarray(color)
  frames.append(img)

frames[0].save(
  "rendered.gif",
  append_images=frames[1:],
  save_all=True,
  duration=100,
  loop=0
)

