In [120]:
import jax
import jax.numpy as jnp
import jaxlie

import numpy as onp

from jaxmp import JaxKinTree

from jaxmp.extras.urdf_loader import load_urdf

In [4]:
# Load the robot description for the "panda" arm.
urdf = load_urdf("panda")

In [10]:
# Build the JAX kinematic tree (forward kinematics, joint twists) from the URDF.
kin = JaxKinTree.from_urdf(urdf)



In [19]:
from jac_visualizer import jac_position_cost, get_idx_applied_to_target

In [17]:
# Track the last joint listed in the URDF; its frame plays the role of the end
# effector in the cost/Jacobian cells below.
# NOTE(review): `.index(names[-1])` returns the first occurrence of that name,
# i.e. len(urdf.joint_names) - 1 whenever joint names are unique.
target_joint_idx = urdf.joint_names.index(urdf.joint_names[-1])

In [20]:
# One entry per joint: the Jacobian column (actuated-joint index) that joint
# contributes to, or -1 for joints the Jacobian code masks out.
idx_applied_to_target = get_idx_applied_to_target(kin, target_joint_idx)

In [21]:
# Inspect the joint -> Jacobian-column mapping (-1 = masked out).
print(idx_applied_to_target)

[ 0  1  2  3  4  5  6 -1 -1 -1 -1  7]


In [22]:
# Number of actuated joints == number of Jacobian columns.
kin.num_actuated_joints

8

In [173]:
# Test configuration. The commented-out line is the alternative of evaluating
# the Jacobians at the zero configuration.
# cfg = jnp.zeros((kin.num_actuated_joints,))
import numpy as np  # NOTE(review): same package as `onp`; later cells use `np`.

# Seeded RNG so the sampled configuration (and every output below) is reproducible.
cfg = np.random.default_rng(0).normal(size=(kin.num_actuated_joints,))
print(cfg)

[ 0.12573022 -0.13210486  0.64042265  0.10490012 -0.53566937  0.36159505
  1.30400005  0.94708096]


In [174]:
# Analytic position Jacobian from the local `jac_visualizer` helper.
jac = jac_position_cost(kin, cfg, target_joint_idx, idx_applied_to_target)


def position_cost(cfg):
    """World-frame translation of joint `target_joint_idx` at configuration `cfg`.

    Reads the notebook globals `kin` and `target_joint_idx`. Each forward-
    kinematics row is laid out as (qw, qx, qy, qz, x, y, z), so the position
    is the trailing three entries.
    """
    pose_wxyz_xyz = kin.forward_kinematics(cfg)[target_joint_idx]
    assert pose_wxyz_xyz.shape == (7,)
    translation_xyz = pose_wxyz_xyz[-3:]
    return translation_xyz


# Reference Jacobian via forward-mode autodiff of the same cost.
jac_autodiff = jax.jacfwd(position_cost)(cfg)

print("Manual Jacobian\n", onp.array(jac).round(3))
print("Autodiff Jacobian\n", onp.array(jac_autodiff).round(3))
print()
print("Difference\n", onp.array(jac_autodiff - jac).round(5))

# Fail loudly (instead of relying on eyeballing the printout) if the analytic
# Jacobian drifts from autodiff. The tolerance is loose enough for float32.
onp.testing.assert_allclose(onp.asarray(jac), onp.asarray(jac_autodiff), atol=1e-4)

Manual Jacobian
 [[ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [ 0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]
Autodiff Jacobian
 [[ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [-0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]

Difference
 [[-0.  0. -0. -0. -0. -0.  0.  0.]
 [ 0. -0.  0. -0.  0.  0.  0.  0.]
 [-0. -0. -0.  0. -0.  0.  0.  0.]]


In [175]:
# NOTE(review): this cell repeats the previous cell verbatim; it can be deleted.
jac = jac_position_cost(kin, cfg, target_joint_idx, idx_applied_to_target)


# `position_cost` is already defined, with an identical body, in the previous
# cell. The duplicate definition was removed so that there is a single source
# of truth that cannot silently diverge.


# Reference Jacobian via forward-mode autodiff of the same cost.
jac_autodiff = jax.jacfwd(position_cost)(cfg)

print("Manual Jacobian\n", onp.array(jac).round(3))
print("Autodiff Jacobian\n", onp.array(jac_autodiff).round(3))
print()
print("Difference\n", onp.array(jac_autodiff - jac).round(5))

# Fail loudly if the analytic Jacobian drifts from autodiff (float32 tolerance).
onp.testing.assert_allclose(onp.asarray(jac), onp.asarray(jac_autodiff), atol=1e-4)

Manual Jacobian
 [[ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [ 0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]
Autodiff Jacobian
 [[ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [-0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]

Difference
 [[-0.  0. -0. -0. -0. -0.  0.  0.]
 [ 0. -0.  0. -0.  0.  0.  0.  0.]
 [-0. -0. -0.  0. -0.  0.  0.  0.]]


In [176]:
def _skew(omega: jax.Array) -> jax.Array:
    """Returns the skew-symmetric form of a length-3 vector."""

    wx, wy, wz = jnp.moveaxis(omega, -1, 0)
    zeros = jnp.zeros_like(wx)
    return jnp.stack(
        [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
        axis=-1,
    ).reshape((*omega.shape[:-1], 3, 3))


def V_inv(theta: jax.Array) -> jax.Array:
    """Inverse of the SO(3) left Jacobian at tangent vector(s) `theta`.

    With S = skew(theta) and t = |theta|, the closed form is

        V^{-1} = I - S/2 + (1 - (t/2) cot(t/2)) / t^2 * S @ S.

    For small angles (t^2 < 1e-5) the coefficient of S @ S is replaced by its
    limit 1/12 to avoid dividing by (nearly) zero.

    Args:
        theta: Array of shape (..., 3).

    Returns:
        Array of shape (..., 3, 3).
    """
    sq_norm = jnp.sum(jnp.square(theta), axis=-1)
    small_angle = sq_norm < 1e-5

    # Feed a harmless value (1.0) to the closed-form branch wherever the Taylor
    # branch is selected: jnp.where evaluates both branches, and NaN/inf from
    # the unused one would still poison reverse-mode gradients.
    sq_norm_safe = jnp.where(small_angle, jnp.ones_like(sq_norm), sq_norm)
    norm_safe = jnp.sqrt(sq_norm_safe)
    half_angle = norm_safe / 2.0

    skew = _skew(theta)
    skew_sq = jnp.einsum("...ij,...jk->...ik", skew, skew)
    first_order = jnp.eye(3) - 0.5 * skew

    closed_form_coeff = (
        1.0 - norm_safe * jnp.cos(half_angle) / (2.0 * jnp.sin(half_angle))
    ) / sq_norm_safe

    return jnp.where(
        small_angle[..., None, None],
        first_order + skew_sq / 12.0,
        first_order + closed_form_coeff[..., None, None] * skew_sq,
    )


# Goal pose for the orientation + position check: the current pose of the
# tracked joint composed with a random SE(3) offset (fixed PRNG key, so the goal
# is reproducible). Indexed with `target_joint_idx` rather than the last FK row
# so the goal is defined relative to the same joint the costs below track.
target = jaxlie.SE3(
    kin.forward_kinematics(cfg)[target_joint_idx]
) @ jaxlie.SE3.sample_uniform(jax.random.PRNGKey(5))


def jac_position_and_orientation_cost(
    kin: JaxKinTree,
    cfg: jnp.ndarray,
    target_joint_idx: int,
    idx_applied_to_target: jax.Array,
    target_pose: "jaxlie.SE3 | None" = None,
) -> jnp.ndarray:
    """Analytic Jacobian of `position_and_orientation_cost` with respect to `cfg`.

    The residual being differentiated is

        [ log(R_world_ee^{-1} @ R_world_goal) ;  p_world_ee - p_world_goal ]

    where "ee" is the frame of joint `target_joint_idx` and "goal" is
    `target_pose`.

    Args:
        kin: Kinematic tree providing forward kinematics and joint twists.
        cfg: Joint configuration, one entry per actuated joint.
        target_joint_idx: Index (into the forward-kinematics output) of the
            joint whose pose is being driven; called "ee" below.
        idx_applied_to_target: One entry per joint giving the Jacobian column
            that joint contributes to, or -1 for joints that are masked out.
        target_pose: Goal pose. Defaults to the notebook-level `target`, which
            is what this function always used before the parameter existed.

    Returns:
        Array of shape (6, kin.num_actuated_joints): rows 0-2 are the
        orientation block, rows 3-5 the translation block.
    """
    if target_pose is None:
        target_pose = target  # Backward-compatible fallback to the global goal.

    Ts_world_joint = kin.forward_kinematics(cfg)
    T_world_ee = jaxlie.SE3(Ts_world_joint[target_joint_idx])
    # Batched poses of *every* joint (not only the actuated ones).
    Ts_world_joint_se3 = jaxlie.SE3(Ts_world_joint)

    # Per-joint twists: linear part first, angular part second. They are rotated
    # into the world frame below, so presumably they are expressed in each
    # joint's local frame -- TODO confirm against JaxKinTree.
    joint_twists = kin.joint_twists[kin.idx_actuated_joint]
    vel = joint_twists[..., :3]
    omega = joint_twists[..., 3:]

    # Joints marked -1 must contribute zeros: the scatter-add below still uses
    # their -1 as an index (which would hit the last column), so this mask is
    # required for correctness.
    in_chain = (idx_applied_to_target != -1)[None]
    num_cols = kin.num_actuated_joints

    ### Translation block: dp_ee/dq_i = omega_i x (p_ee - p_i) + v_i, world frame.
    ee_minus_joint = T_world_ee.translation() - Ts_world_joint_se3.translation()
    omega_world = Ts_world_joint_se3.rotation() @ omega
    vel_world = Ts_world_joint_se3.rotation() @ vel

    linear_part = (
        jnp.cross(omega_world, b=ee_minus_joint).squeeze() + vel_world.squeeze()
    )
    jac_translation = jnp.zeros((3, num_cols))
    jac_translation = jac_translation.at[:, idx_applied_to_target].add(
        jnp.where(in_chain, linear_part.T, jnp.zeros((3, 1)))
    )
    assert jac_translation.shape == (3, num_cols)

    ### Orientation block.
    # With E = R_world_ee^{-1} @ R_world_goal, moving joint i by dq rotates the
    # ee as R_world_ee -> R_world_ee @ exp(omega_ee_i * dq), where omega_ee_i is
    # the joint axis expressed in the ee frame. Hence
    #     log(exp(-omega_ee_i * dq) @ E) ~= log(E) - V_inv(log E) @ omega_ee_i * dq,
    # i.e. the minus sign comes from the inverse on the *end-effector* rotation.
    R_ee_world = T_world_ee.rotation().inverse()
    omega_ee = R_ee_world @ omega_world
    ori_error = R_ee_world @ target_pose.rotation()
    Jlog3 = V_inv(ori_error.log())

    jac_orientation = jnp.zeros((3, num_cols))
    jac_orientation = jac_orientation.at[:, idx_applied_to_target].add(
        jnp.where(in_chain, Jlog3 @ omega_ee.T, jnp.zeros((3, 1)))
    )
    jac_orientation = -jac_orientation

    return jnp.concatenate([jac_orientation, jac_translation])


# Analytic (6, num_actuated_joints) Jacobian: rows 0-2 orientation, rows 3-5
# translation.
jac = jac_position_and_orientation_cost(
    kin, cfg, target_joint_idx, idx_applied_to_target
)


def position_and_orientation_cost(cfg, target_pose: "jaxlie.SE3 | None" = None):
    """Residual [orientation error; position error] of joint `target_joint_idx`.

    Returns a length-6 vector

        [ log(R_world_ee^{-1} @ R_world_goal) ;  p_world_ee - p_world_goal ]

    i.e. the SO(3) log of the goal rotation expressed in the end-effector
    frame, followed by the world-frame position difference. Reads the notebook
    globals `kin` and `target_joint_idx`.

    Args:
        cfg: Joint configuration, one entry per actuated joint. This is the
            argument `jax.jacfwd` differentiates with respect to.
        target_pose: Goal pose. Defaults to the notebook-level `target`, which
            preserves the original behavior.
    """
    if target_pose is None:
        target_pose = target  # Backward-compatible fallback to the global goal.

    T_world_ee = kin.forward_kinematics(cfg)[target_joint_idx]
    assert T_world_ee.shape == (7,)  # (qw, qx, qy, qz, x, y, z)

    R_world_ee = jaxlie.SO3(T_world_ee[:4])
    orientation_error = (R_world_ee.inverse() @ target_pose.rotation()).log()
    position_error = T_world_ee[4:7] - target_pose.wxyz_xyz[4:7]
    return jnp.concatenate([orientation_error, position_error])


# Reference Jacobian via forward-mode autodiff of the same residual.
jac_autodiff = jax.jacfwd(position_and_orientation_cost)(cfg)

print("Manual Jacobian\n", onp.array(jac).round(3))
print("Autodiff Jacobian\n", onp.array(jac_autodiff).round(3))
print()
print("Difference\n", onp.array(jac_autodiff - jac).round(5))

# Fail loudly if the analytic Jacobian drifts from autodiff (float32 tolerance).
onp.testing.assert_allclose(onp.asarray(jac), onp.asarray(jac_autodiff), atol=1e-4)

Manual Jacobian
 [[ 0.633  0.36   0.754 -0.811  0.801 -0.491 -1.019 -0.   ]
 [-0.242  0.935 -0.297 -0.512 -0.387 -0.87   0.525 -0.   ]
 [ 1.059 -0.03   0.961  0.509  0.889  0.065 -0.518 -0.   ]
 [ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [ 0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]
Autodiff Jacobian
 [[ 0.633  0.36   0.754 -0.811  0.801 -0.491 -1.019  0.   ]
 [-0.242  0.935 -0.297 -0.512 -0.387 -0.87   0.525  0.   ]
 [ 1.059 -0.03   0.961  0.509  0.889  0.065 -0.518 -0.   ]
 [ 0.878  0.297  0.866  0.09   0.852  0.336 -0.782 -0.199]
 [-0.148  0.038 -0.108  0.003 -0.106  0.091  0.31  -0.931]
 [-0.     0.257  0.112 -0.759  0.169 -0.227 -0.434 -0.307]]

Difference
 [[-0. -0. -0.  0. -0.  0.  0.  0.]
 [ 0. -0.  0.  0.  0.  0. -0.  0.]
 [-0.  0. -0. -0. -0.  0.  0.  0.]
 [-0.  0. -0. -0. -0. -0.  0.  0.]
 [ 0. -0.  0. -0.  0.  0.  0.  0.]
 [-0. -0. -0.  0. -0.  0.  0.  0.]]


In [177]:
# Per-column norms of the orientation block (rows 0-2), manual vs autodiff.
# The last column is 0 in both: the last actuated joint does not rotate the
# tracked frame (presumably a prismatic gripper joint -- TODO confirm).
print(np.linalg.norm(jac[:3, :], axis=0))
print(np.linalg.norm(jac_autodiff[:3, :], axis=0))

[1.257433  1.0019157 1.2572322 1.0860132 1.2580618 1.0008051 1.257996
 0.       ]
[1.2574327 1.0019152 1.2572321 1.086013  1.258061  1.0008049 1.2579957
 0.       ]


In [179]:
# Inverse SO(3) left Jacobian evaluated at the current orientation residual.
# (Previously called `_V_inv`, which no visible cell defines; the function
# defined above is `V_inv`.)
V_inv(position_and_orientation_cost(cfg)[:3])

Array([[ 0.6162893 ,  0.2136889 ,  1.0189798 ],
       [ 0.17484526,  0.9014716 , -0.5252336 ],
       [-1.0263586 ,  0.5106646 ,  0.51803756]], dtype=float32)