In [1]:
import jax.numpy as jnp
import jax
import plotly.graph_objects as go
import json
from jaxlie import SO3
from scipy.spatial.transform import Rotation

In [53]:
class KinematicChain:
    """Tree of bones with forward kinematics (FK) and a damped least-squares IK solver.

    ``bones`` maps a bone name to a dict that provides ``"idx"`` (row of the
    bone in the angle array), ``"quat"`` (rest rotation as a wxyz quaternion),
    ``"len"`` (bone length) and ``"child"`` (names of the child bones).
    ``root`` is the name of the bone the FK traversal starts from.
    """

    def __init__(self, bones, root):
        """Store the skeleton and collect the rest quaternions into ``rest_angles``.

        ``rest_angles`` has one wxyz row per bone, addressed by ``bone["idx"]``.
        """
        self.bones = bones
        self.root = root
        # Pre-fill with the identity quaternion (w=1): an all-zero quaternion is
        # not a rotation and would turn into NaNs once a direction is normalised.
        rest = [[1., 0., 0., 0.] for _ in range(len(bones))]
        for bone in bones.values():
            rest[bone["idx"]] = bone["quat"]
        # Force a floating dtype so integer-only JSON quaternions stay differentiable.
        self.rest_angles = jnp.array(rest, dtype=float)

    def forward(self, angles=None):
        """Compute joint positions for the given per-bone quaternions.

        Args:
            angles: array with one wxyz quaternion per bone (row ``bone["idx"]``).
                Defaults to ``self.rest_angles``.

        Returns:
            Array of shape ``(len(self.bones) + 1, 3)``. Row 0 is left at the
            origin; row ``idx + 1`` holds the end point of bone ``idx``.
        """
        if angles is None:
            angles = self.rest_angles

        joints = jnp.zeros((len(self.bones) + 1, 3), dtype=float)
        # The root bone starts at the origin and rotates the +z axis.
        stack = [(self.root, jnp.zeros(3, dtype=float), jnp.array([0., 0., 1.]))]
        while stack:
            bone_name, loc, vec = stack.pop()
            bone = self.bones[bone_name]
            bone_idx = bone["idx"]
            # Each quaternion rotates the *parent's* direction, so rotations
            # accumulate down the chain.
            direction = SO3(wxyz=angles[bone_idx]).apply(vec)
            # Re-normalise so a non-unit quaternion cannot change the bone length.
            direction = direction / jnp.linalg.norm(direction)
            tip = direction * bone["len"] + loc
            joints = joints.at[bone_idx + 1].set(tip)
            for child in bone["child"]:
                stack.append((child, tip, direction))

        return joints

    def IK(self, target, max_iter, mse_threshold, init=None,
           damping=1e-3, damping_factor=1.5, verbose=True):
        """Fit per-bone quaternions so that ``forward`` reproduces ``target``.

        Runs damped Gauss-Newton steps on the raw quaternion components. The
        loop stops as soon as the mean squared error changes by less than
        ``mse_threshold`` between two consecutive iterations; because the
        previous error starts at 0, a start that is already below the threshold
        returns immediately.

        Args:
            target: joint positions with the same shape as ``forward``'s output.
            max_iter: maximum number of update steps.
            mse_threshold: stop when ``|mse - previous mse|`` falls below this.
            init: starting quaternions, shape ``(num_bones, 4)``; defaults to
                ``self.rest_angles``.
            damping: initial damping added to the diagonal of ``J^T J``.
            damping_factor: factor by which the damping is divided after an
                accelerating improvement and multiplied otherwise.
            verbose: print the error after every update step.

        Returns:
            Array of shape ``(num_bones, 4)`` with the fitted quaternions
            (not re-normalised).
        """
        num_bones = len(self.bones)
        num_rows = (num_bones + 1) * 3
        params = self.rest_angles if init is None else init

        # The transformed function does not depend on the iteration; build it once.
        jacobian_fn = jax.jacrev(self.forward)

        u = damping
        last_update = 0
        last_mse = 0
        for i in range(max_iter):
            residual = (self.forward(params) - target).reshape(num_rows, 1)
            mse = jnp.mean(jnp.square(residual))

            # Check convergence before paying for the Jacobian.
            if abs(mse - last_mse) < mse_threshold:
                return params

            j = jacobian_fn(params).reshape(num_rows, -1)
            jtj = j.T @ j + u * jnp.eye(j.shape[1])
            # Solve the damped normal equations instead of forming an explicit inverse.
            delta = jnp.linalg.solve(jtj, j.T @ residual).ravel()
            params = params - delta.reshape(num_bones, 4)

            # Improvement achieved by the previous step (evaluated at this iterate).
            update = last_mse - mse
            if update > last_update and update > 0:
                u /= damping_factor
            else:
                u *= damping_factor

            last_update = update
            last_mse = mse

            if verbose:
                print(f"Iteration {i}: {mse}")

        return params

In [3]:
def plot(predicted, target, filename="keypoints.html", axis_range=(-4, 4)):
    """Write an interactive 3D scatter plot of two point sets to an HTML file.

    Args:
        predicted: array indexed as ``[:, 0..2]`` for x, y, z; drawn as the first trace.
        target: array indexed the same way; drawn as the second trace.
        filename: path of the HTML file to write.
        axis_range: ``(low, high)`` limits applied to all three axes.
    """
    def _markers(points):
        # One marker-only trace per point set.
        return go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers'
        )

    fig = go.Figure([_markers(predicted), _markers(target)])
    limits = list(axis_range)
    # Same limits on every axis plus a cubic aspect keep distances comparable.
    fig.update_layout(scene=dict(
        aspectmode='cube',
        xaxis=dict(range=limits),
        yaxis=dict(range=limits),
        zaxis=dict(range=limits)
    ))
    fig.write_html(filename)

In [4]:
# Read the skeleton description and build the kinematic chain from it.
with open("./rest_hand.json", "r") as skeleton_file:
    bones = json.load(skeleton_file)

chain = KinematicChain(bones=bones["bones"], root=bones["root"])



In [5]:
# Joint positions of the rest pose: with no angles given, the chain falls back
# to its stored rest quaternions.
joints = chain.forward(angles=None)

In [6]:
# Fit the chain to the joint positions stored in target.json.
with open("target.json", "r") as f:
    target = jnp.array(json.load(f))

angles = chain.IK(target, max_iter=100, mse_threshold=1e-9)
pred = chain.forward(angles)
# Plot the fit against the positions it was optimised for (not the rest pose),
# so the figure shows the remaining IK error.
plot(pred, target)

Iteration 0: 0.0015547601506114006
Iteration 1: 0.0003543452185112983
Iteration 2: 8.447992149740458e-05
Iteration 3: 2.9204022212070413e-05
Iteration 4: 1.7335587472189218e-05
Iteration 5: 1.2771195542882197e-05
Iteration 6: 9.997445886256173e-06
Iteration 7: 8.385212822759058e-06
Iteration 8: 7.448848009516951e-06
Iteration 9: 6.889071755722398e-06
Iteration 10: 6.544075858982978e-06
Iteration 11: 6.326258699118625e-06
Iteration 12: 6.186321570567088e-06
Iteration 13: 6.095340722822584e-06
Iteration 14: 6.035690603312105e-06
Iteration 15: 5.9963722378597595e-06
Iteration 16: 5.970351139694685e-06
Iteration 17: 5.953091203991789e-06
Iteration 18: 5.941621566307731e-06
Iteration 19: 5.933995453233365e-06
Iteration 20: 5.9289186538080685e-06
Iteration 21: 5.925538516748929e-06
Iteration 22: 5.923285243625287e-06
Iteration 23: 5.921786851104116e-06
Iteration 24: 5.920781404711306e-06


In [59]:
# Rebuild the chain from the initial pose and cache its rest-pose joints.
with open("./init.json", "r") as init_file:
    bones = json.load(init_file)

chain = KinematicChain(bones["bones"], bones["root"])
joints = chain.forward()

for i in range(20):
    print(f"Optimizing for frame {i}")

    with open(f"./hands_maximo/keypoints_lerp/keypoints_{i}.json") as keypoints_file:
        keypoints = jnp.array(json.load(keypoints_file))

    # Keep the first three columns of every key-point and prepend the
    # rest-pose row 0 so the target has one row per chain joint.
    target = jnp.vstack([joints[0], keypoints[:, :3]])

    # Shift the frame so its row 1 coincides with rest-pose joint 1, then pin
    # row 0 to the origin.
    target = (target + joints[1] - target[1]).at[0].set(0)

    angles = chain.IK(target, max_iter=100, mse_threshold=1e-6)

    # Predicted joints with an extra trailing column of ones.
    fitted = chain.forward(angles)
    pred = jnp.hstack([fitted, jnp.ones((fitted.shape[0], 1))])

    with open(f"./hands_maximo/angles/angles_{i}.json", "w") as angles_file:
        json.dump(angles.tolist(), angles_file)

    with open(f"./hands_maximo/keypoints_ik/keypoints_{i}.json", "w") as pred_file:
        json.dump(pred.tolist(), pred_file)

Optimizing for frame 0
Optimizing for frame 1
Iteration 0: 0.00013755576219409704
Iteration 1: 1.902964640976279e-06
Optimizing for frame 2
Iteration 0: 0.000987681676633656
Iteration 1: 2.22903963731369e-05
Iteration 2: 2.012632648984436e-05
Optimizing for frame 3
Iteration 0: 0.003250508802011609
Iteration 1: 4.435678056324832e-05
Iteration 2: 2.6659878130885772e-05
Optimizing for frame 4
Iteration 0: 0.006138625554740429
Iteration 1: 7.325123442569748e-05
Iteration 2: 2.935641896328889e-05
Optimizing for frame 5
Iteration 0: 0.007666280027478933
Iteration 1: 7.226689922390506e-05
Iteration 2: 1.9656594304251485e-05
Optimizing for frame 6
Iteration 0: 0.00821548979729414
Iteration 1: 6.539761670865119e-05
Iteration 2: 1.9582348613766953e-05
Optimizing for frame 7
Iteration 0: 0.008760804310441017
Iteration 1: 5.1788014388876036e-05
Iteration 2: 2.1253097656881437e-05
Optimizing for frame 8
Iteration 0: 0.01039836835116148
Iteration 1: 4.3282750993967056e-05
Iteration 2: 9.28884583117

In [61]:
# Visual check of the most recently fitted frame: drop the trailing column of
# ones from the prediction and compare with that frame's target.
plot(predicted=pred[:, :3], target=target)