In [13]:
from dotmap import DotMap

In [14]:
import torch
import numpy

In [None]:
from nn import beye, bsm

In [16]:
def cross_product_matrix(v):
    """Return the skew-symmetric cross-product matrix of each 3-vector in a batch.

    For a row ``(v1, v2, v3)`` the corresponding matrix is::

        [[  0, -v3,  v2],
         [ v3,   0, -v1],
         [-v2,  v1,   0]]

    so that ``v_hat[i] @ u`` equals ``cross(v[i], u)``.

    Args:
        v: tensor of shape ``(nb, 3)``.

    Returns:
        Tensor of shape ``(nb, 3, 3)`` with the same dtype and device as ``v``.

    Raises:
        ValueError: if ``v`` is not 2-D (shape unpacking fails).
        AssertionError: if the second dimension of ``v`` is not 3.
    """
    nb, dim = v.shape
    assert dim==3
    v1, v2, v3 = v[:, 0:1], v[:, 1:2], v[:, 2:3]
    # Build the zero entries from `v` itself so they inherit its dtype and
    # device; a bare torch.zeros(nb, 1) is always default-dtype and on the
    # default device, which breaks concatenation for e.g. CUDA inputs.
    zero = torch.zeros_like(v1)
    v_hat = torch.cat((zero, -v3, v2,
                       v3, zero, -v1,
                       -v2, v1, zero), dim=1)
    return v_hat.reshape(nb, 3, 3)

In [25]:
def rodrigues_formula(w, th):
    """Batched rotation matrices ``exp(w_hat * th)`` via Rodrigues' formula.

    With ``n = ||w||`` the result for each batch element is::

        I + sin(n * th) / n * w_hat + (1 - cos(n * th)) / n**2 * w_hat @ w_hat

    so ``w`` does not have to be unit length.

    Args:
        w: rotation axes, shape ``(nb, 3)`` or ``(3,)``.
        th: rotation angles, shape ``(nb, 1)``; 0-D and 1-D tensors are
            reshaped to a column.

    Returns:
        Tensor of shape ``(nb, 3, 3)``.

    Raises:
        AssertionError: if the last dimension of ``w`` is not 3.

    NOTE(review): ``beye`` and ``bsm`` come from the project-local ``nn``
    module; they are used here as "batched identity" and "per-batch scalar
    times matrix" -- confirm against their definitions.
    """
    if w.ndimension() == 1:
        w = w.unsqueeze(0)
    # Turn scalar (0-D) and flat (1-D) angle tensors into a column so they
    # broadcast against the (nb, 1) axis norms below.
    if th.ndimension() <= 1:
        th = th.reshape(-1, 1)
    nb, dim = w.shape
    assert dim==3
    w_hat = cross_product_matrix(w)
    w_norm = torch.norm(w, dim=1, keepdim=True)
    # A zero-length axis would give 0/0 = NaN. In that case w_hat is all zeros
    # too, so dividing by 1 instead leaves both correction terms at zero and
    # the result is the identity, which is the correct limit.
    safe_norm = torch.where(w_norm > 0, w_norm, torch.ones_like(w_norm))
    angle = w_norm * th
    exp_wth = beye(nb, 3, 3)
    exp_wth += bsm(torch.sin(angle) / safe_norm, w_hat)
    w_hat_squared = torch.bmm(w_hat, w_hat)
    exp_wth += bsm((1 - torch.cos(angle)) / (safe_norm ** 2), w_hat_squared)
    return exp_wth

In [21]:
def Rx(th):
    """Rotation matrices about the x-axis, one per angle in ``th``.

    Args:
        th: tensor of angles; it is flattened to a column of shape ``(nb, 1)``.

    Returns:
        Whatever ``rodrigues_formula`` returns for the axis ``[1, 0, 0]``
        repeated ``nb`` times and the column of angles.
    """
    angles = th.reshape(-1, 1)
    n_angles = angles.shape[0]
    x_axis = torch.Tensor([1,0,0]).repeat(n_angles, 1)
    return rodrigues_formula(x_axis, angles)
def Ry(th):
    """Rotation matrices about the y-axis, one per angle in ``th``.

    Args:
        th: tensor of angles; it is flattened to a column of shape ``(nb, 1)``.

    Returns:
        Whatever ``rodrigues_formula`` returns for the axis ``[0, 1, 0]``
        repeated ``nb`` times and the column of angles.
    """
    # reshape() is not in-place: the result has to be assigned, otherwise a
    # 1-D `th` makes the two-value shape unpacking below fail.
    th = th.reshape(-1,1)
    nb, _ = th.shape
    axis = torch.Tensor([0,1,0]).repeat(nb, 1)
    return rodrigues_formula(axis, th)
def Rz(th):
    """Rotation matrices about the z-axis, one per angle in ``th``.

    Args:
        th: tensor of angles; it is flattened to a column of shape ``(nb, 1)``.

    Returns:
        Whatever ``rodrigues_formula`` returns for the axis ``[0, 0, 1]``
        repeated ``nb`` times and the column of angles.
    """
    # reshape() is not in-place: the result has to be assigned, otherwise a
    # 1-D `th` makes the two-value shape unpacking below fail.
    th = th.reshape(-1,1)
    nb, _ = th.shape
    axis = torch.Tensor([0,0,1]).repeat(nb, 1)
    return rodrigues_formula(axis, th)

In [31]:
def rpy_to_rotation_matrix(rpy):
    """Convert a batch of roll-pitch-yaw angles to rotation matrices.

    The result is ``Rz(yaw) @ Ry(pitch) @ Rx(roll)`` for every row.

    Args:
        rpy: 2-D tensor whose columns 0, 1, 2 hold roll, pitch and yaw.

    Returns:
        The batched product of the three elementary rotations.
    """
    roll = rpy[:, 0:1]
    pitch = rpy[:, 1:2]
    yaw = rpy[:, 2:3]
    # Right-most factor first: roll about x, then pitch about y, then yaw about z.
    pitch_roll = torch.bmm(Ry(pitch), Rx(roll))
    return torch.bmm(Rz(yaw), pitch_roll)

In [347]:
class ForwardKinematics(torch.nn.Module):
    """Forward kinematics of a serial chain of revolute joints.

    Every joint ``j`` contributes the transform ``joints[j] @ rot(axis[j], q[j])``,
    where ``joints[j]`` is a fixed frame built from ``xyz[j]`` / ``rpy[j]`` and
    ``rot`` is a pure rotation. These transforms are multiplied in joint order,
    and a fixed per-link offset frame is applied last.

    Attributes:
        nq: number of joints as passed to the constructor.
        axis: ``(nq, 3)`` rotation axis of every joint.
        joints: ``(nq, 4, 4)`` fixed joint frames.
        links: ``(n_links, 4, 4)`` fixed link frames (identity rotation).
    """

    def __init__(self, nq, xyz, rpy, axis, xyz_links):
        """
        Args:
            nq: number of joints.
            xyz: ``(nq, 3)`` translation of each fixed joint frame.
            rpy: ``(nq, 3)`` roll-pitch-yaw angles of each fixed joint frame.
            axis: ``(nq, 3)`` rotation axis of each joint.
            xyz_links: ``(n_links, 3)`` translation of each link frame.
        """
        super().__init__()
        self.nq, self.axis = nq, axis
        self.joints = self.joint_frame_chain(xyz, rpy)
        self.links = self.link_frame_chain(xyz_links)

    def forward(self, q, idx):
        """Pose of link ``idx`` for a batch of joint configurations.

        Args:
            q: joint angles, shape ``(nb, nq)``.
            idx: integer index of the link; negative values count from the end.

        Returns:
            ``(nb, 4, 4)`` homogeneous transforms: the product of the joint
            transforms ``0 .. idx`` followed by ``links[idx]``.
        """
        nb, nq = q.shape[0:2]
        # Resolve negative indices so the chain length below is well defined.
        n_links = self.links.shape[0]
        last_joint = idx + n_links if idx < 0 else idx

        # Flatten (batch, joint) into one batch dimension; q is row-major, so
        # repeating the per-joint tensors nb times lines up with q.reshape.
        axis = self.axis.repeat(nb, 1)
        Tq = self.joint_rotation(axis, q.reshape(-1, 1))
        T = torch.bmm(self.joints.repeat(nb, 1, 1), Tq).reshape(nb, -1, 4, 4)

        T_world = torch.eye(4, dtype=T.dtype, device=T.device).repeat(nb, 1, 1)
        # Only joints up to and including `idx` move link `idx`; joints further
        # down the chain must not enter its pose.
        for j in range(min(last_joint + 1, nq)):
            T_world = torch.bmm(T_world, T[:, j, :, :])

        T_local = self.links[idx, :, :].unsqueeze(0).repeat(nb, 1, 1)
        return torch.bmm(T_world, T_local)

    def joint_frame_chain(self, xyz, rpy):
        """Fixed joint frames as ``(n, 4, 4)`` homogeneous transforms.

        Args:
            xyz: translations, shape ``(n, 3)`` or ``(3,)``.
            rpy: roll-pitch-yaw angles, shape ``(n, 3)`` or ``(3,)``.
        """
        return self._frame_chain(xyz, rpy)

    def link_frame_chain(self, xyz, rpy=None):
        """Fixed link frames as ``(n, 4, 4)`` homogeneous transforms.

        Args:
            xyz: translations, shape ``(n, 3)`` or ``(3,)``.
            rpy: optional roll-pitch-yaw angles with the same shape as ``xyz``;
                all zeros when omitted.
        """
        if rpy is None:
            rpy = torch.zeros_like(xyz)
        return self._frame_chain(xyz, rpy)

    def joint_rotation(self, axis, q):
        """Homogeneous transforms that rotate by ``q`` about ``axis`` without translating.

        Args:
            axis: rotation axes, shape ``(n, 3)`` or ``(3,)``.
            q: rotation angles, forwarded unchanged to ``rodrigues_formula``.

        Returns:
            ``(n, 4, 4)`` tensor.
        """
        if axis.ndimension() == 1:
            axis = axis.unsqueeze(0)
        R = rodrigues_formula(axis, q)
        no_translation = torch.zeros_like(R[:, :, 0])
        return self._homogeneous(R, no_translation)

    def _frame_chain(self, xyz, rpy):
        """Shared body of joint_frame_chain / link_frame_chain."""
        if xyz.ndimension() == 1:
            xyz = xyz.unsqueeze(0)
        if rpy.ndimension() == 1:
            rpy = rpy.unsqueeze(0)
        return self._homogeneous(rpy_to_rotation_matrix(rpy), xyz)

    @staticmethod
    def _homogeneous(R, p):
        """Stack rotations ``R`` (n, 3, 3) and translations ``p`` (n, 3) into (n, 4, 4)."""
        nb = R.shape[0]
        top = torch.cat((R, p.unsqueeze(-1)), dim=-1)
        # The last row of a 4x4 identity is exactly the [0, 0, 0, 1] bottom row.
        bottom = torch.eye(4, dtype=top.dtype, device=top.device)[3:].expand(nb, 1, 4)
        return torch.cat((top, bottom), dim=1)

In [348]:
# ## two-link 
# xyz = torch.Tensor([[0, 0, 0.075], 
#                     [1.05, 0, 0]])
# rpy = torch.Tensor([[0, -pi/2, 0], 
#                     [0, 0, 0]])
# axis = torch.Tensor([[0, 1, 0], 
#                      [0, 1, 0]])
# xyz_links = torch.Tensor([[0.5, 0, 0],
#                          [0.5 , 0, 0]])

# twolink_fwk = ForwardKinematics(2, xyz, rpy, axis, xyz_links)
# q = torch.Tensor([[0.25, 0.75], [0.25*pi, 0.75*pi]])
# link_world = twolink_fwk(q, 1)

In [355]:
## kuka
# Parameters handed to ForwardKinematics below; every tensor has seven rows.

# Translation of each fixed joint frame (one xyz row per joint).
xyz = torch.Tensor([
    [0.0,    0.0,     0.1575],
    [0.0,    0.0,     0.2025],
    [0.2045, 0.0,     0.0],
    [0.0,    0.0,     0.2155],
    [0.0,    0.1845,  0.0],
    [0.0,   -0.0607,  0.2155],
    [0.081,  0.0,     0.0607],
])

# Roll-pitch-yaw angles of each fixed joint frame (one rpy row per joint).
rpy = torch.Tensor([
    [0.0, 0.0, 0.0],
    [2.3561944901923457, -1.5707962635746238, 2.3561944901923457],
    [1.5707963267948948, -4.371139000186238e-08, 1.5707963705062866],
    [1.5707963705062866, 0.0, 0.0],
    [-1.5707963705062866, 0.0, 0.0],
    [2.3561944901923457, -1.5707962635746238, 2.3561944901923457],
    [1.5707963267948948, -4.371139000186238e-08, 1.5707963705062866],
])

# Every joint uses the same rotation axis, [0, 0, 1].
axis = torch.Tensor([[0, 0, 1]] * 7)

# Every link uses the same translation for its link frame.
xyz_links = torch.Tensor([[-0.05, 0.004, 0.127]] * 7)

In [356]:
# Forward-kinematics module for the seven-joint arm described by the tensors above.
kuka_fwk = ForwardKinematics(nq=7, xyz=xyz, rpy=rpy, axis=axis, xyz_links=xyz_links)

In [361]:
# Two test configurations: all seven joint angles at 0, and all seven at 1.
q = torch.cat((torch.zeros(1, 7), torch.ones(1, 7)), dim=0)

In [364]:
# Transform of link index 6 for both configurations in `q`.
link_world = kuka_fwk(q, idx=6)