In [2]:
import jax.numpy as jnp
import jax
import plotly.graph_objects as go
import json
from jaxlie import SO3
from scipy.spatial.transform import Rotation

In [3]:
class KinematicChain:
    """Tree of bones with per-bone quaternion parameters, forward kinematics and IK.

    `bones` maps a bone name to a dict. This class reads the keys "idx" (row of
    the bone in the parameter array), "quat" (rest-pose quaternion, passed to
    `SO3(wxyz=...)`), "len" (bone length) and "child" (iterable of child bone
    names). `root` is the name of the bone the traversal starts from.
    """

    def __init__(self, bones, root):
        self.bones = bones
        self.root = root
        # One quaternion row per bone, addressed by bone["idx"]. Rows start as
        # all-zero placeholders and are overwritten for every index found below.
        rest = [[0., 0., 0., 0.] for _ in range(len(bones))]
        for bone in bones.values():
            rest[bone["idx"]] = bone["quat"]
        self.rest_angles = jnp.array(rest)

    def forward(self, angles=None):
        """Compute joint positions for the given per-bone quaternions.

        Args:
            angles: array indexed as `angles[bone_idx]`, one wxyz quaternion per
                bone. Defaults to the rest pose (`self.rest_angles`).

        Returns:
            Array of shape (len(self.bones) + 1, 3). Row 0 stays at the origin;
            row `idx + 1` is the end point of the bone whose index is `idx`.
        """
        if angles is None:
            angles = self.rest_angles

        joints = jnp.zeros((len(self.bones)+1, 3), dtype=float)
        # Depth-first walk. Each entry is (bone name, start point, incoming
        # direction); the root starts at the origin with +z as its direction.
        stk = [(self.root, jnp.array([0, 0, 0]), jnp.array([0, 0, 1]))]
        while len(stk):
            bone_name, loc, vec = stk.pop()
            bone = self.bones[bone_name]
            bone_idx = bone["idx"]
            # Rotate the parent's direction by this bone's quaternion, so the
            # rotations compose along the chain.
            dir_vec = SO3(wxyz=angles[bone_idx]).apply(vec)
            # IK updates quaternion components freely (no unit-norm constraint),
            # so rescale the direction before applying the bone length.
            dir_vec = dir_vec / jnp.linalg.norm(dir_vec)
            joints = joints.at[bone_idx+1].set(dir_vec*bone["len"] + loc)
            for child in bone["child"]:
                stk.append((child, joints[bone_idx+1], dir_vec))

        return joints

    def IK(self, target, max_iter, mse_threshold, to_use=None, init=None):
        """Fit per-bone quaternions so that `forward` matches `target`.

        Uses damped least-squares steps on the stacked joint residuals, with the
        damping factor adapted after every step.

        Args:
            target: joint positions, indexed like the output of `forward`.
            max_iter: maximum number of update steps.
            mse_threshold: stop once the MSE changes by less than this between
                consecutive iterations. The first iteration compares against 0,
                so an initial MSE below the threshold returns the start
                parameters unchanged.
            to_use: boolean mask over joints selecting the rows that contribute
                to the residual. Defaults to all joints.
            init: start parameters, shape (num_bones, 4). Defaults to the rest
                pose. The array passed in is never modified.

        Returns:
            Array of shape (num_bones, 4) holding the fitted parameters.
        """
        num_bones = len(self.bones)
        num_joints = num_bones+1
        if init is None:
            init = self.rest_angles
        if to_use is None:
            to_use = jnp.ones(num_joints, dtype=bool)

        u = 1e-3          # damping added to the diagonal of J^T J
        v = 1.5           # factor by which the damping is grown / shrunk
        last_update = 0
        last_mse = 0
        # Work on a JAX array so the update below can never write into a
        # caller-owned NumPy buffer.
        params = jnp.asarray(init)
        # Building the Jacobian function does not depend on the iteration.
        jac_fn = jax.jacrev(self.forward)

        for i in range(max_iter):
            residual = (self.forward(params)[to_use] - target[to_use]).reshape(-1, 1)
            mse = jnp.mean(jnp.square(residual))

            if abs(mse - last_mse) < mse_threshold:
                return params

            # Rows: selected joints x 3 coordinates. Columns: bones x 4 components.
            j = jac_fn(params)[to_use].reshape(-1, 4*num_bones)

            jtj = jnp.matmul(j.T, j)
            jtj = jtj + u * jnp.eye(jtj.shape[0])

            update = last_mse - mse
            # Solve (J^T J + u I) delta = J^T r directly instead of forming the inverse.
            delta = jnp.linalg.solve(jtj, jnp.matmul(j.T, residual)).ravel()
            params = params - delta.reshape(num_bones, 4)

            # Relax the damping only while the improvement keeps growing;
            # otherwise increase it.
            if update > last_update and update > 0:
                u /= v
            else:
                u *= v

            last_update = update
            last_mse = mse

            print(f"Iteration {i}: {mse}")

        return params

In [3]:
def plot(predicted, target, axis_range=(-4, 4), out_path="keypoints.html"):
    """Write an interactive 3D scatter of predicted vs. target points to an HTML file.

    Args:
        predicted: array indexed as [:, 0], [:, 1], [:, 2] for x, y and z.
        target: same layout as `predicted`; drawn as a second marker trace.
        axis_range: (low, high) limits applied to all three axes.
        out_path: path of the HTML file to write.
    """
    def _markers(points):
        # One marker trace per point set; columns 0..2 are the coordinates.
        return go.Scatter3d(
            x=points[:,0],
            y=points[:,1],
            z=points[:,2],
            mode='markers'
        )

    fig = go.Figure([_markers(predicted), _markers(target)])
    # Equal-sized axes with identical limits keep the two point sets comparable.
    fig.update_layout(scene=dict(
        aspectmode='cube',
        xaxis = dict(range=list(axis_range)),
        yaxis = dict(range=list(axis_range)),
        zaxis = dict(range=list(axis_range))
    ))
    fig.write_html(out_path)

In [4]:
# Load the skeleton description and build the chain. `joints` holds the joint
# positions returned by forward() with its default (rest-pose) parameters.
with open("./init.json") as f:
    bones = json.load(f)

chain = KinematicChain(bones["bones"], bones["root"])
joints = chain.forward()

for i in range(20):
    print(f"Optimizing for frame {i}")
    with open(f"./hands_maximo/keypoints/keypoints_{i}.json") as f:
        keypoints = jnp.array(json.load(f))

    # Row 0 of the target is the chain's own first joint and is always fitted.
    # A keypoint is left out of the fit when its 4th column is (close to) 0 --
    # presumably a confidence/visibility value; confirm against the exporter.
    to_use = jnp.hstack([True, ~jnp.isclose(keypoints[:,3], 0)])
    target = jnp.vstack([jnp.hstack(joints[0]), keypoints[:,:3]])
    # Shift the target so its row 1 coincides with the chain's row 1, then
    # pin row 0 to the origin.
    target = (target + joints[1] - target[1]).at[0].set(0)

    angles = chain.IK(target, max_iter=100, mse_threshold=1e-6, to_use=to_use)
    pred = chain.forward(angles)
    # Append a column of ones so the output has four columns like the input.
    pred = jnp.concatenate([pred, jnp.ones((pred.shape[0], 1))], axis=1)

    with open(f"./hands_maximo/angles/angles_{i}.json", "w") as f:
        json.dump(angles.tolist(), f)

    with open(f"./hands_maximo/keypoints_ik/keypoints_{i}.json", "w") as f:
        json.dump(pred.tolist(), f)



Optimizing for frame 0
Optimizing for frame 1
Iteration 0: 0.00013755576219409704
Iteration 1: 1.902964640976279e-06
Optimizing for frame 2
Iteration 0: 0.0008946952293626964
Iteration 1: 2.209007834608201e-05
Iteration 2: 2.0395498722791672e-05
Optimizing for frame 3
Iteration 0: 0.0031319214031100273
Iteration 1: 2.700101504160557e-05
Iteration 2: 1.6281454009003937e-05
Optimizing for frame 4
Iteration 0: 0.006385145243257284
Iteration 1: 6.939285958651453e-05
Iteration 2: 2.7715801479644142e-05
Optimizing for frame 5
Iteration 0: 0.007666280027478933
Iteration 1: 7.226689922390506e-05
Iteration 2: 1.9656594304251485e-05
Optimizing for frame 6
Iteration 0: 0.00821548979729414
Iteration 1: 6.539761670865119e-05
Iteration 2: 1.9582348613766953e-05
Optimizing for frame 7
Iteration 0: 0.008760804310441017
Iteration 1: 5.1788014388876036e-05
Iteration 2: 2.1253097656881437e-05
Optimizing for frame 8
Iteration 0: 0.01039836835116148
Iteration 1: 4.3282750993967056e-05
Iteration 2: 9.288845

In [7]:
# Rebuild the chain from the skeleton file; `joints` is the output of
# forward() with its default (rest-pose) parameters.
with open("./init.json") as f:
    bones = json.load(f)

chain = KinematicChain(bones["bones"], bones["root"])
joints = chain.forward()

for i in range(20):
    print(f"Optimizing for frame {i:05}")
    with open(f"./hands_maximo/keypoints_3d/{i:05}.json") as f:
        keypoints = jnp.array(json.load(f))

    # Only keypoints whose 4th column is below 10 take part in the fit
    # (NOTE(review): looks like an error/uncertainty score -- confirm).
    # The leading True keeps the chain's own first joint in the fit.
    to_use = jnp.hstack([True, keypoints[:,3] < 10])
    target = jnp.vstack([jnp.hstack(joints[0]), keypoints[:,:3]])
    # Shift the target so its row 1 coincides with the chain's row 1, then
    # pin row 0 to the origin.
    target = (target + joints[1] - target[1]).at[0].set(0)

    angles = chain.IK(target, max_iter=100, mse_threshold=1e-6, to_use=to_use)
    pred = chain.forward(angles)
    # Append a column of ones so the output has four columns like the input.
    pred = jnp.concatenate([pred, jnp.ones((pred.shape[0], 1))], axis=1)

    with open(f"./hands_maximo/angles/angles_{i}.json", "w") as f:
        json.dump(angles.tolist(), f)

    with open(f"./hands_maximo/keypoints_ik/keypoints_{i}.json", "w") as f:
        json.dump(pred.tolist(), f)

Optimizing for frame 00000
Iteration 0: 0.00010934772581094876
Iteration 1: 9.424074050912168e-06
Optimizing for frame 00001
Iteration 0: 0.0005309314001351595
Iteration 1: 4.0294577047461644e-05
Iteration 2: 3.8071499147918075e-05
Iteration 3: 5.9814585256390274e-05
Iteration 4: 3.98204501834698e-05
Optimizing for frame 00002
Iteration 0: 0.0013195665087550879
Iteration 1: 4.4614775106310844e-05
Iteration 2: 4.071483272127807e-05
Optimizing for frame 00003
Iteration 0: 0.003631086088716984
Iteration 1: 0.00010552824096521363
Iteration 2: 8.440624515060335e-05
Iteration 3: 8.30395074444823e-05
Optimizing for frame 00004
Iteration 0: 0.00882847048342228
Iteration 1: 0.00015286130656022578
Iteration 2: 8.045370486797765e-05
Iteration 3: 7.92860082583502e-05
Optimizing for frame 00005
Iteration 0: 0.009454334154725075
Iteration 1: 0.00014018654474057257
Iteration 2: 5.954165681032464e-05
Optimizing for frame 00006
Iteration 0: 0.009484287351369858
Iteration 1: 0.00013828850933350623
Itera