In [1]:
GIT_ROOT_LINES = !git rev-parse --show-toplevel
WORK_DIR = GIT_ROOT_LINES[0]
%cd $WORK_DIR
%load_ext autoreload
%autoreload 2

/scratch/xz653/code/shape-of-motion


In [2]:
from icecream import ic
import pyvista as pv
from copy import deepcopy
from pyvista import examples
from motionblender.lib.animate import *
from tqdm.auto import tqdm, trange

pv.set_jupyter_backend('trame')
pv.set_plot_theme('paraview')

In [3]:
def vis_link(joints, connections, prefix=""):
    colors = ['green', 'blue', 'yellow', 'purple', 'orange', 'white', 'black']
    for i, joint in enumerate(joints):
        plotter['instance'].add_mesh(pv.Sphere(center=joint, radius=0.1), color=colors[i], name=f'{prefix}joint-{i}')
    for i, connection in enumerate(connections):
        joint1 = joints[connection[0]]
        joint2 = joints[connection[1]]
        plotter['instance'].add_mesh(pv.Tube(joint1, joint2, radius=0.05), color='red', name=f'{prefix}tube-{i}')
        
def clear_plotter():
    try:
        plotter['instance'].Startpos = {'curve': 12, 'mask': 12}
        plotter['instance'].clear()
        plotter['instance'].clear_slider_widgets()
        plotter['instance'].clear_button_widgets()
    except NameError:
        pass

def close_plotter(): # close it before debugging!
    plotter['cam_pos'] = plotter['instance'].camera.position
    plotter['roll'] = plotter['instance'].camera.roll
    plotter['azimuth'] = plotter['instance'].camera.azimuth
    plotter['elevation'] = plotter['instance'].camera.elevation
    plotter['instance'].close()
    plotter['instance'] = None
    import gc
    gc.collect()

def open_plotter():
    plotter['instance'] = new_plotter()
    if 'cam_pos' in plotter:
        plotter['instance'].camera.position = plotter['cam_pos']
        plotter['instance'].camera.roll = plotter['roll']
        plotter['instance'].camera.azimuth = plotter['azimuth']
        plotter['instance'].camera.elevation = plotter['elevation']


try:
    v = plotter['instance'].show(return_viewer=True)
    print(v.value.split("src=\"")[1].split("\"")[0])
except:
    if 'plotter' not in globals():
        plotter = {'instance': pv.Plotter(notebook=True)}
    else:
        open_plotter()
    v = plotter['instance'].show(return_viewer=True)
    print(v.value.split("src=\"")[1].split("\"")[0])

http://localhost:40690/index.html?ui=P_0x2ada0be4c7f0_0&reconnect=auto


wslink is not expecting text message:
> 
wslink is not expecting text message:
> 
wslink is not expecting text message:
> 
wslink is not expecting text message:
> 
wslink is not expecting text message:
> 
wslink is not expecting text message:
> 


# Basics

The following code generates simple deformable graph and kinematic tree, and verifies some basic APIs from `animate.py`.

## Deformable Graph

In [4]:
clear_plotter()

N = 10000
joints = torch.as_tensor([
    [0, 0, 0], 
    [0, 0, 1],
    [0, 0, 2] 
]).float()

new_joints = torch.randn_like(joints)

triangles = torch.as_tensor([
    [[1, 0], [1, 2]]
])
connections = triangles.reshape(-1, 2)

xyz = torch.from_numpy(np.concatenate([np.random.normal(0, 0.25, size=(N, 2)), np.random.uniform(-0.5, 2.5, size=(N, 1))], axis=1)).float()

xyz_weights = weight_inpaint(xyz, joints, connections, gamma=1.0, temperature=0.1)
_, _, falloff = compute_distance_from_link(xyz, joints[connections[:, 0]], joints[connections[:, 1]]) # projection 

def rpt(x):
    return repeat(x, 'a b -> (p a) b', p=N)

normals = rpt(compute_normals(new_joints, triangles.long()))

deform_mats = find_T_between_poses(
                    find_link_ctrl_pt_pose(rpt(joints[connections[:, 0]]), rpt(joints[connections[:, 1]]), normals, falloff.flatten()),
                    find_link_ctrl_pt_pose(rpt(new_joints[connections[:, 0]]), rpt(new_joints[connections[:, 1]]), normals, falloff.flatten()))
deform_mats = rearrange(deform_mats, '(p m) a b -> p m a b', m=len(connections))

final_xyz = apply_mat4(skinning(xyz_weights, deform_mats, blend_mode='dq'), xyz) # apply deformation from graph links to points

vis_link(new_joints, connections)
plotter['instance'].add_mesh(pv.PolyData(final_xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=1.0,
                             name="pts", scalars=xyz_weights.numpy()[:, 1], cmap='viridis');


## Kinematic Chain

In [5]:
clear_plotter()

N = 10000
joints = torch.as_tensor([
    [0, 0, 0], 
    [0, 0, 1],
    [0, 0, 2],
    [0, 0, 3]
]).float()
connections = [(0, 1), (1, 2), (2, 3)]
new_joints = joints.clone()
chain = inverse_kinematic(joints, connections)

chain['chain'][0]['rot6d'] = rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [30, 70, 70], degrees=True))
chain['chain'][0]['chain'][0]['rot6d'] =  rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [45, 65, 30], degrees=True))
chain['chain'][0]['chain'][0]['chain'][0]['rot6d'] =  rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [45, 60, 0], degrees=True))

new_joints[1:] = forward_kinematic(chain)[:, :3, 3]

new_chain = inverse_kinematic(new_joints, connections) # reconstruct kinematic tree from just joints and links 
new_joints_recover = new_joints.clone()
new_joints_recover[1:] = forward_kinematic(new_chain)[:, :3, 3]
print('============= original =============')
print(new_joints)
print('============= recover =============')
print(new_joints_recover)
assert torch.allclose(new_joints, new_joints_recover)

vis_link(new_joints_recover, connections)

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.7482,  0.5937,  0.2962],
        [ 1.2687,  0.8478, -0.5190],
        [ 1.2352,  0.1177, -1.2016]])
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.7482,  0.5937,  0.2962],
        [ 1.2687,  0.8478, -0.5190],
        [ 1.2352,  0.1177, -1.2016]])


# Learning

The following code generate some random points, and then fit a graph structure to align with the points. Here we assume the weight (soft connections) between the points and graphs are predefined. 

## Kinematic Chain

In [6]:
clear_plotter()
N = 10000
joints = torch.as_tensor([
    [0, 0, 0], 
    [0, 0, 1],
    [0, 0, 2],
    [0, 0, 3]
]).float()
connections = [(0, 1), (1, 2), (2, 3)]
xyz = torch.from_numpy(np.concatenate([np.random.normal(0, 0.25, size=(N, 2)), 
                                       np.random.uniform(joints.min().item()-0.5, joints.max().item()+0.5, size=(N, 1))], axis=1)).float()


xyz_weights = weight_inpaint(xyz, joints, torch.as_tensor(connections), gamma=1.0, temperature=0.1)
new_joints = joints.clone()
chain = inverse_kinematic(joints, connections)
init_chain = deepcopy(chain)
old_link_poses = forward_kinematic(chain)

chain['chain'][0]['rot6d'] = rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [30, 70, 70], degrees=True))
chain['chain'][0]['chain'][0]['rot6d'] =  rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [45, 65, 30], degrees=True))
chain['chain'][0]['chain'][0]['chain'][0]['rot6d'] =  rmat_to_cont_6d(roma.euler_to_rotmat('xyz', [45, 0, 0], degrees=True))

new_link_poses = forward_kinematic(chain)
new_joints[1:] = new_link_poses[:, :3, 3]

mat4 = find_T_between_poses(old_link_poses, new_link_poses)
final_xyz = apply_mat4(skinning(xyz_weights, mat4, blend_mode='dq'), xyz)

# --- VISUALIZATION --- #
# the plotter now visualizes displaced points and graph
# before
vis_link(joints, connections, prefix="original-") 
# plotter['instance'].add_mesh(pv.PolyData(xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=0.5, 
#                              name="pts-before", scalars=xyz_weights.numpy()[:, -1], cmap='viridis');

# after
# vis_link(new_joints, connections, prefix="new-")
plotter['instance'].add_mesh(pv.PolyData(final_xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=0.5, 
                             name="pts-new", scalars=xyz_weights.numpy()[:, -1], cmap='viridis');

In [7]:
def to_parameter_chain(chain: KinematicLink, 
            length_inv_activation: Callable=lambda x: x, 
            rot_inv_activation: Callable=lambda x: x) \
            -> KinematicLink:
    """ return the same chain, but replace dict with ParameterDict, and 
    apply the inv activation"""
    new_chain = nn.ParameterDict()
    
    def walk(from_c, to_c):
        to_c['id'] = from_c['id']
        if from_c.get('length', None) is not None:
            to_c['length'] = nn.Parameter(length_inv_activation(from_c['length']), requires_grad=True)
            to_c['rot6d'] = nn.Parameter(rot_inv_activation(from_c['rot6d']), requires_grad=True)
        
        to_c['chain'] = nn.ParameterList()
        for child in from_c.get('chain', []):
            to_c['chain'].append(nn.ParameterDict())
            walk(child, to_c['chain'][-1])

    walk(chain, new_chain)
    return new_chain

In [8]:
pchain = to_parameter_chain(init_chain)
opt = optim.Adam(pchain.parameters(), lr=1e-3)

def train_step():
    """ 
    depends on pre-generated: old_link_poses, xyz_weights, final_xyz
    """
    new_link_poses = forward_kinematic(pchain)
    mat4 = find_T_between_poses(old_link_poses, new_link_poses)
    xyz_after_skinning = apply_mat4(skinning(xyz_weights, mat4, blend_mode='dq'), xyz)
    return F.l1_loss(xyz_after_skinning, final_xyz)

for i in trange(1000):
    opt.zero_grad()
    loss = train_step()
    loss.backward()
    opt.step()
    if i % 50 == 0:
        print(loss.item())

  0%|          | 0/1000 [00:00<?, ?it/s]

1.1545584201812744
1.0416420698165894
0.9306448101997375
0.7828100323677063
0.5491570234298706
0.29944494366645813
0.25183573365211487
0.24004189670085907
0.23307541012763977
0.22679348289966583
0.23553216457366943
0.2253325879573822
0.21686141192913055
0.2082044780254364
0.1998181790113449
0.19207166135311127
0.18476508557796478
0.17809826135635376
0.14708375930786133
0.13309122622013092


In [9]:
clear_plotter()
plotter['instance'].add_mesh(pv.PolyData(final_xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=0.5, 
                             name="pts-new", scalars=xyz_weights.numpy()[:, -1], cmap='viridis')
with torch.no_grad():
    learned_link_poss = forward_kinematic(pchain)
    learned_joints = joints.clone()
    learned_joints[1:] = learned_link_poss[:, :3, 3]
    vis_link(learned_joints, connections, prefix="learned-")
    

## Deformable Grid

In [18]:
clear_plotter()

N = 10000
joints = torch.as_tensor([
    [0, 0, 0], 
    [0, 0, 1],
    [0, 0, 2] 
]).float()

new_joints = torch.randn_like(joints)

triangles = torch.as_tensor([
    [[1, 0], [1, 2]]
])
connections = triangles.reshape(-1, 2)

xyz = torch.from_numpy(np.concatenate([np.random.normal(0, 0.25, size=(N, 2)), np.random.uniform(-0.5, 2.5, size=(N, 1))], axis=1)).float()




xyz_weights = weight_inpaint(xyz, joints, connections, gamma=1.0, temperature=0.1)
_, _, falloff = compute_distance_from_link(xyz, joints[connections[:, 0]], joints[connections[:, 1]]) 

def rpt(x): return repeat(x, 'a b -> (p a) b', p=N)

origin_normals = rpt(compute_normals(joints, triangles.long()))
normals = rpt(compute_normals(new_joints, triangles.long()))

deform_mats = find_T_between_poses(
                    find_link_ctrl_pt_pose(rpt(joints[connections[:, 0]]), rpt(joints[connections[:, 1]]), normals, falloff.flatten()),
                    find_link_ctrl_pt_pose(rpt(new_joints[connections[:, 0]]), rpt(new_joints[connections[:, 1]]), normals, falloff.flatten()))
deform_mats = rearrange(deform_mats, '(p m) a b -> p m a b', m=len(connections))
final_xyz = apply_mat4(skinning(xyz_weights, deform_mats, blend_mode='dq'), xyz)

vis_link(joints, connections) # before
# plotter['instance'].add_mesh(pv.PolyData(xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=1.0,
#                              name="pts", scalars=xyz_weights.numpy()[:, 1], cmap='viridis'); # after

# vis_link(new_joints, connections) # after
plotter['instance'].add_mesh(pv.PolyData(final_xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=1.0,
                             name="pts", scalars=xyz_weights.numpy()[:, 1], cmap='viridis'); # after

In [19]:
pjoints = nn.Parameter(joints.clone(), requires_grad=True)  # after avoiding the use of arun, much better

opt = optim.Adam([pjoints], lr=1e-2)

def train_step():
    """ 
    depends on pre-generated: xyz, xyz_weights, final_xyz, falloff, connections, triangles
    """
    def rpt(x): return repeat(x, 'a b -> (p a) b', p=len(xyz))
    pred_normals = compute_normals(pjoints, triangles.long())
    # illegal_normals = torch.all(pred_normals == 0, dim=1)
    # pred_normals[illegal_normals] = origin_normals[illegal_normals]
    pred_normals = rpt(pred_normals)
    deform_mats = find_T_between_poses(
                    find_link_ctrl_pt_pose(rpt(joints[connections[:, 0]]), rpt(joints[connections[:, 1]]), origin_normals, falloff.flatten()),
                    find_link_ctrl_pt_pose(rpt(pjoints[connections[:, 0]]), rpt(pjoints[connections[:, 1]]), pred_normals, falloff.flatten()))
    deform_mats = rearrange(deform_mats, '(p m) a b -> p m a b', m=len(connections))

    xyz_after_skinning = apply_mat4(skinning(xyz_weights, deform_mats, blend_mode='dq'), xyz)
    return F.l1_loss(xyz_after_skinning, final_xyz)

for i in trange(1000):
    opt.zero_grad()
    loss = train_step()
    loss.backward()
    opt.step()
    if i % 50 == 0:
        print(loss.item())

  0%|          | 0/1000 [00:00<?, ?it/s]

0.863994836807251
0.6059651970863342
0.44048967957496643
0.3776163160800934
0.34170520305633545
0.3229825794696808
0.316812127828598
0.31532153487205505
0.31496870517730713
0.31491872668266296
0.3149138391017914
0.31491294503211975
0.31491291522979736
0.314912885427475
0.31491294503211975
0.3149128556251526
0.31491291522979736
0.31491294503211975
0.31491294503211975
0.31491294503211975


In [20]:
clear_plotter()
plotter['instance'].add_mesh(pv.PolyData(final_xyz.numpy()), point_size=4, render_points_as_spheres=False, opacity=0.5, 
                             name="pts-new", scalars=xyz_weights.numpy()[:, -1], cmap='viridis')
with torch.no_grad():
    learned_joints = pjoints.detach().clone()
    vis_link(learned_joints, connections, prefix="learned-")
    

: 