In [1]:
from xml.etree import ElementTree as element_tree

In [64]:
import pdb

In [2]:
import numpy as jnp

In [3]:
# Path of the URDF robot description to load (relative, so it is resolved
# against the kernel's current working directory).
urdf_path = 'data/kuka_iiwa.urdf'

In [4]:
# Read the whole URDF file into memory as text.
with open(urdf_path, 'r') as f:
    urdf_string = f.read()

# Parse the XML text; fromstring returns the document's root element.
urdf_root =  element_tree.fromstring(urdf_string)

In [5]:
def get_link_map(joint_nodes, link_reference):
    """Group joint elements by the link named in one of their sub-elements.

    Args:
        joint_nodes: Iterable of joint elements. Each one must contain a
            sub-element whose tag is ``link_reference``.
        link_reference: Tag of the sub-element to read the ``link``
            attribute from (this file passes 'parent' and 'child').

    Returns:
        dict mapping each link name to the list of joint elements that
        reference it, in the order the joints were encountered.
    """
    joints_by_link = {}
    for joint in joint_nodes:
        referenced_link = joint.find(link_reference).get('link')
        # setdefault creates the bucket the first time a link name is seen.
        joints_by_link.setdefault(referenced_link, []).append(joint)
    return joints_by_link

def get_joint_chains(parent_link_map, child_link, list_of_chains=None):
    """Recursively enumerate the joint paths that start at ``child_link``.

    Args:
        parent_link_map: dict mapping a link name to the joint elements
            that have that link as their parent.
        child_link: Name of the link to expand next.
        list_of_chains: Partial chains accumulated so far (each a list of
            joint elements). ``None`` or an empty list starts from a
            single empty chain.

    Returns:
        A list of chains. Every incoming chain is extended along each
        branch below ``child_link`` until a link with no outgoing joints
        is reached; if ``child_link`` itself has none, the incoming chains
        are returned unchanged.
    """
    chains_so_far = list_of_chains or [[]]

    # A link that is nobody's parent ends the recursion.
    if child_link not in parent_link_map:
        return chains_so_far

    extended = []
    for joint in parent_link_map[child_link]:
        next_link = joint.find('child').get('link')
        grown = [chain + [joint] for chain in chains_so_far]
        extended.extend(get_joint_chains(parent_link_map, next_link, grown))
    return extended
            
def get_all_chains(urdf_root):
    """Enumerate every root-to-leaf joint chain in a parsed URDF.

    Args:
        urdf_root: Root element of the URDF document; its direct ``joint``
            children are examined.

    Returns:
        A list of chains, each a list of joint elements ordered from a
        root link towards a leaf. Chains are grouped by root link, and
        root links are visited in the order they first appear as a parent
        in the document.
    """
    joint_nodes = urdf_root.findall('joint')

    child_link_map = get_link_map(joint_nodes, 'child')
    parent_link_map = get_link_map(joint_nodes, 'parent')

    # A root link is one that is some joint's parent but no joint's child.
    # Walking the dict (insertion-ordered) instead of taking a set
    # difference keeps the result order stable from run to run; set
    # iteration order for strings varies with hash randomisation.
    root_links = [link for link in parent_link_map if link not in child_link_map]

    return [chain
            for root_link in root_links
            for chain in get_joint_chains(parent_link_map, root_link)]

def parse_string_to_numeric_list(vec_string):
    """Parse a whitespace-separated string of numbers into a list of floats.

    Args:
        vec_string: Text such as ``'0 0 1'``. Any run of whitespace
            (repeated spaces, tabs, newlines) separates values, and
            leading/trailing whitespace is ignored.

    Returns:
        list of float, one per token; an empty or all-whitespace string
        yields an empty list.

    Raises:
        ValueError: if a token is not a valid float literal.
    """
    # split() with no argument collapses whitespace runs; split(' ') would
    # emit empty tokens for '0  0 1' and float('') raises ValueError.
    return [float(token) for token in vec_string.split()]

In [6]:
def make_pose(translation, rotation, axis):
    """Build a 7-element pose array from a translation and an axis-angle rotation.

    Args:
        translation: Indexable with at least three entries (x, y, z).
        rotation: Rotation angle; half of it is passed to cos/sin.
        axis: Indexable with at least three entries. It is used as given
            and is not normalised here.

    Returns:
        Array ``[tx, ty, tz, qw, qx, qy, qz]`` with ``qw = cos(rotation/2)``
        and ``(qx, qy, qz) = sin(rotation/2) * axis``.
    """
    half_angle = 0.5 * rotation
    cos_half = jnp.cos(half_angle)
    sin_half = jnp.sin(half_angle)

    tx, ty, tz = translation[0], translation[1], translation[2]
    qx, qy, qz = sin_half * axis[0], sin_half * axis[1], sin_half * axis[2]
    return jnp.array([tx, ty, tz, cos_half, qx, qy, qz])

def multiply(p_left, p_right):
    """Compose two poses, applying ``p_right`` in the frame of ``p_left``.

    Each pose is a 7-element sequence ``[tx, ty, tz, qw, qx, qy, qz]``:
    a translation followed by a quaternion with the scalar part first.

    Args:
        p_left: Outer (left-hand) pose.
        p_right: Inner (right-hand) pose.

    Returns:
        Array ``[tx, ty, tz, qw, qx, qy, qz]`` where the translation is
        ``left_t + rotate(left_q, right_t)`` and the quaternion is the
        Hamilton product ``left_q * right_q``. The rotation is evaluated
        as ``q * v * conj(q)``, so it is a pure rotation only when the
        left quaternion has unit norm.
    """
    ltx, lty, ltz, lqw, lqx, lqy, lqz = p_left
    rtx, rty, rtz, rqw, rqx, rqy, rqz = p_right

    # First half of the sandwich product: h = left_q * (0, right_t).
    hw = -lqx*rtx - lqy*rty - lqz*rtz
    hx = lqw*rtx + lqy*rtz - lqz*rty
    hy = lqw*rty - lqx*rtz + lqz*rtx
    hz = lqw*rtz + lqx*rty - lqy*rtx

    # Second half: vector part of h * conj(left_q), plus the left translation.
    # The results get their own names because every line below needs the
    # unmodified hx, hy and hz; reusing the same names would feed an
    # already-rotated component into the next line.
    tx = -hw*lqx + hx*lqw - hy*lqz + hz*lqy + ltx
    ty = -hw*lqy + hx*lqz + hy*lqw - hz*lqx + lty
    tz = -hw*lqz - hx*lqy + hy*lqx + hz*lqw + ltz

    # Orientation: Hamilton product left_q * right_q.
    qw = lqw*rqw - lqx*rqx - lqy*rqy - lqz*rqz
    qx = lqw*rqx + lqx*rqw + lqy*rqz - lqz*rqy
    qy = lqw*rqy - lqx*rqz + lqy*rqw + lqz*rqx
    qz = lqw*rqz + lqx*rqy - lqy*rqx + lqz*rqw
    return jnp.array([tx, ty, tz, qw, qx, qy, qz])

def make_rpy_xyz_pose(rpy, xyz):
    """Build a pose from roll/pitch/yaw angles and a translation.

    Args:
        rpy: Indexable ``(roll, pitch, yaw)``; the angles are applied about
            the x, y and z axes respectively.
        xyz: Translation, attached to the yaw pose (the outermost factor).

    Returns:
        The pose ``yaw * pitch * roll`` as composed by ``multiply``.
    """
    no_offset = [0., 0., 0.]
    about_z = make_pose(xyz, rpy[2], [0.0, 0.0, 1.0])
    about_y = make_pose(no_offset, rpy[1], [0.0, 1.0, 0.0])
    about_x = make_pose(no_offset, rpy[0], [1.0, 0.0, 0.0])

    yaw_then_pitch = multiply(about_z, about_y)
    return multiply(yaw_then_pitch, about_x)

In [7]:
def extract_origin_pose(joint_node):
    """Return the pose described by a joint's ``<origin>`` element.

    Args:
        joint_node: Joint element to inspect.

    Returns:
        A 7-element array ``[tx, ty, tz, qw, qx, qy, qz]``. When the joint
        has no ``<origin>`` the identity pose (zero translation, quaternion
        ``(1, 0, 0, 0)``) is returned. A missing ``xyz`` or ``rpy``
        attribute is treated as all zeros.
    """
    origin = joint_node.find('origin')

    # Compare against None explicitly: an Element's truth value reflects
    # whether it has children, and <origin xyz=... rpy=.../> has none, so
    # `if origin:` would skip every real origin.
    if origin is None:
        # Identity needs qw = 1; an all-zero vector is not a valid rotation.
        return jnp.array([0., 0., 0., 1., 0., 0., 0.])

    xyz = parse_string_to_numeric_list(origin.get('xyz', '0 0 0'))
    rpy = parse_string_to_numeric_list(origin.get('rpy', '0 0 0'))
    return make_rpy_xyz_pose(rpy, xyz)

def extract_axis_function(joint_node):
    """Return a function mapping a joint angle to a rotation about the joint axis.

    Args:
        joint_node: Joint element. Its ``<axis xyz="...">`` gives the
            rotation axis; when the element or attribute is absent the
            URDF default axis ``(1, 0, 0)`` is used.

    Returns:
        A callable ``rotation -> pose`` producing a zero-translation pose
        (see ``make_pose``) that rotates by ``rotation`` about the axis.
        The axis is used as written and is not normalised here.
    """
    axis_node = joint_node.find('axis')
    # find() returns None when there is no <axis>; fall back to the default
    # instead of failing with AttributeError on None.get.
    axis_text = '1 0 0' if axis_node is None else axis_node.get('xyz', '1 0 0')
    xyz = parse_string_to_numeric_list(axis_text)
    return lambda rotation: make_pose([0., 0., 0.], rotation, xyz)

def CreateKinematicChainFunction(joint_chain):
    """Placeholder for building a kinematics function from a joint chain.

    Args:
        joint_chain: Presumably a list of joint elements such as one entry
            of the list returned by ``get_all_chains`` -- TODO confirm once
            the function is implemented.

    Raises:
        NotImplementedError: always; the body has not been written yet.
    """
    # A `def` with no body is a SyntaxError that stops the whole cell from
    # executing, so an explicit stub is kept until the real body is written.
    raise NotImplementedError('CreateKinematicChainFunction is not implemented yet')

In [8]:
# Collect the <joint> and <link> elements that are direct children of the root.
joint_nodes = urdf_root.findall('joint')
link_nodes = urdf_root.findall('link')

# Index both collections by their 'name' attribute for lookup by name.
joint_dict = {n.get('name'): n for n in joint_nodes}
link_dict = {n.get('name'): n for n in link_nodes}


In [9]:
# Enumerate every joint chain in the URDF and keep the first one to inspect.
joint_chains = get_all_chains(urdf_root)
joint_chain = joint_chains[0]

In [10]:
# Show the joint names along the selected chain, in chain order.
[j.get('name') for j in joint_chain]

['lbr_iiwa_joint_1',
 'lbr_iiwa_joint_2',
 'lbr_iiwa_joint_3',
 'lbr_iiwa_joint_4',
 'lbr_iiwa_joint_5',
 'lbr_iiwa_joint_6',
 'lbr_iiwa_joint_7']

In [11]:
# Inspect the first joint of the chain: print its attribute names, then
# display its child elements.
node = joint_chain[0]
print(node.keys())
list(node)

['type', 'name']


[<Element 'parent' at 0x7fb9201122d0>,
 <Element 'child' at 0x7fb920112310>,
 <Element 'origin' at 0x7fb920112290>,
 <Element 'axis' at 0x7fb920112350>,
 <Element 'limit' at 0x7fb9201123d0>,
 <Element 'dynamics' at 0x7fb920112390>]

In [101]:
# Attribute names present on this joint's <axis> element.
node.find('axis').keys()

['xyz']