In [1]:
## This is a demo of a KUKA arm reaching a target with MPC and diff_qp
## Author : Avadesh Meduri
## Date : 25/02/2022

import time
import numpy as np
import pinocchio as pin
from robot_properties_kuka.config import IiwaConfig

import meshcat
import meshcat.transformations as tf
import meshcat.geometry as g

from diff_pin_costs import DiffFK
from inverse_qp import IOC

import torch
from torch.autograd import Function
from torch.nn import functional as F

In [2]:
# Build the KUKA iiwa robot wrapper and keep direct handles to its
# pinocchio model/data pair, which the later cells use as globals.
robot = IiwaConfig.buildRobotWrapper()
model = robot.model
data = robot.data

# Index of the frame named "EE" in the model.
f_id = model.getFrameId("EE")

In [3]:
# Meshcat visualizer built from the robot's kinematic, collision and
# visual models (passed positionally, in that order).
viz_models = (robot.model, robot.collision_model, robot.visual_model)
viz = pin.visualize.MeshcatVisualizer(*viz_models)

# Start the viewer with open=False, then load the robot model into it.
viz.initViewer(open=False)
viz.loadViewerModel()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [4]:
dfk = DiffFK.apply


def quadratic_loss(q_pred, x_des, nq):
    """Scalar loss on the tail of a predicted trajectory vector.

    The loss is the sum of two (non-squared) norms:

    * the norm of ``dfk(q_pred[-2*nq:], model, data, f_id) - x_des``, i.e.
      the mismatch between the differentiable forward kinematics evaluated
      on the last ``2*nq`` entries and the desired position ``x_des``;
    * ``0.2`` times the norm of the last ``nq`` entries of ``q_pred``.

    Elsewhere in this notebook the trajectory vector is indexed as
    ``[q, dq]`` blocks, so the last ``nq`` entries are presumably the final
    joint velocity -- confirm against ``IOC``.

    Uses the notebook globals ``model``, ``data`` and ``f_id``.

    Args:
        q_pred: 1-D torch tensor holding the predicted trajectory.
        x_des: torch tensor with the desired end-effector position.
        nq: number of joint coordinates.

    Returns:
        A scalar torch tensor.
    """
    terminal_block = q_pred[-2*nq:]
    tracking_error = dfk(terminal_block, model, data, f_id) - x_des
    tracking_term = torch.linalg.norm(tracking_error)

    tail_term = 0.2*torch.linalg.norm(q_pred[-nq:])

    return tracking_term + tail_term

In [5]:
# Dimensions of the configuration and velocity vectors.
nq, nv = model.nq, model.nv

# Arguments later passed to IOC: number of collocation points and a
# per-joint bound list (one entry per configuration coordinate).
n_col = 10
u_max = [5] * model.nq

# Adam learning rate and the cap on optimisation iterations per MPC step.
lr = 1e-1
eps = 100

# q_des = np.array([-0.58904862, -0.58904862, -0.58904862,  0.58904862,  0.        ,
#         0.        ,  0.        ])
# Random desired configuration: first five joints uniform in
# [-pi/3, pi/3), last two joints fixed at zero.
random_joints = (np.pi/3)*(np.random.rand(5) - 0.5)*2
q_des = np.hstack((random_joints, np.zeros(2)))
dq_des = np.zeros_like(q_des)

# Forward kinematics at q_des gives the Cartesian target for the "EE" frame.
pin.forwardKinematics(model, data, q_des, dq_des, np.zeros(nv))
pin.updateFramePlacements(model, data)

x_des = torch.tensor(data.oMf[f_id].translation)
print(x_des)

# Random initial state [q, dq]: positions uniform in [-pi/3, pi/3),
# velocities uniform in [-0.2, 0.2).
n_joints = len(q_des)
q_init = (np.pi/3.0)*(np.random.rand(n_joints) - 0.5)*2
dq_init = 0.2*2*(np.random.rand(n_joints) - 0.5)

x_init = np.zeros(2*nq)
x_init[:nq] = q_init
x_init[nq:] = dq_init



tensor([-0.6971,  0.3109,  0.9090], dtype=torch.float64)


In [6]:
# Receding-horizon demo: for each MPC step, fit a fresh IOC problem with
# Adam until the loss drops below the tolerance (or the iteration cap `eps`
# is hit), play the resulting trajectory in the viewer, then restart from
# the last predicted state.
x_in = x_init

# Draw the Cartesian target as a sphere in the viewer.
viz.viewer["box"].set_object(g.Sphere(0.05),
                         g.MeshLambertMaterial(
                             color=0xff22dd,
                             reflectivity=0.8))
viz.viewer["box"].set_transform(tf.translation_matrix(x_des.detach().numpy()))

for mpc_step in range(8):

    # A new IOC module and optimizer are created for every MPC step.
    ioc = IOC(n_col, nq, u_max, 0.05, eps = 1.0, isvec=False)
    optimizer = torch.optim.Adam(ioc.parameters(), lr=lr)

    n_iter = 0
    # Python-float copy of the loss for the loop test; starts at +inf so the
    # first iteration always runs.
    loss_value = float("inf")

    while loss_value > 0.03 and n_iter < eps:
        # perf_counter is monotonic, so the printed intervals cannot be
        # distorted by system clock adjustments.
        t_start = time.perf_counter()
        x_pred = ioc(x_in)
        t_forward = time.perf_counter()
        print("total forward time:", t_forward - t_start)

        loss = quadratic_loss(x_pred, x_des, nq)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        n_iter += 1
        t_end = time.perf_counter()
        print("total time:", t_end - t_start)
        print("-------------------")

    # Re-evaluate with the updated parameters and convert to numpy for playback.
    x_pred = ioc(x_in).detach().numpy()

    # Each knot occupies 3*nq entries; the first nq are read as q and the
    # next nq as dq (the remaining nq are not used here).
    for knot in range(n_col+1):
        offset = 3*nq*knot
        q = x_pred[offset:offset + nq]
        dq = x_pred[offset + nq:offset + 2*nq]

        pin.forwardKinematics(model, data, q, dq, np.zeros(nv))
        pin.updateFramePlacements(model, data)

        viz.display(q)
        time.sleep(0.05)

    # The last 2*nq entries become the initial state of the next MPC step.
    x_in = torch.tensor(x_pred[-2*nq:])


solve time: 0.033745527267456055
total forward time: 0.0352632999420166
backward time: 0.07416868209838867
backward linalg time: 0.06104135513305664
kkt : 0.012810945510864258
rest : 0.0003101825714111328
total time: 0.11269259452819824
-------------------
solve time: 0.032685041427612305
total forward time: 0.03344845771789551
backward time: 0.06119894981384277
backward linalg time: 0.0021691322326660156
kkt : 0.008438348770141602
rest : 0.05058741569519043
total time: 0.1081078052520752
-------------------
solve time: 0.03520059585571289
total forward time: 0.036373138427734375
backward time: 0.035759687423706055
backward linalg time: 0.018456459045410156
kkt : 0.008980035781860352
rest : 0.008316993713378906
total time: 0.07445764541625977
-------------------
solve time: 0.0467069149017334
total forward time: 0.04754972457885742
backward time: 0.0597224235534668
backward linalg time: 0.0020475387573242188
kkt : 0.014354228973388672
rest : 0.04331636428833008
total time: 0.1221735477

total time: 0.07419180870056152
-------------------
solve time: 0.05639481544494629
total forward time: 0.07280778884887695
backward time: 0.026234149932861328
backward linalg time: 0.0024385452270507812
kkt : 0.01554107666015625
rest : 0.00825047492980957
total time: 0.1295609474182129
-------------------
solve time: 0.04686546325683594
total forward time: 0.052713871002197266
backward time: 0.07792401313781738
backward linalg time: 0.002111196517944336
kkt : 0.07555913925170898
rest : 0.00024819374084472656
total time: 0.13318681716918945
-------------------
solve time: 0.04031968116760254
total forward time: 0.04110383987426758
backward time: 0.07502484321594238
backward linalg time: 0.0019066333770751953
kkt : 0.07289814949035645
rest : 0.00021576881408691406
total time: 0.1180260181427002
-------------------
solve time: 0.03983306884765625
total forward time: 0.040471792221069336
backward time: 0.0753941535949707
backward linalg time: 0.002141237258911133
kkt : 0.07304215431213379

KeyboardInterrupt: 

In [None]:
# Scratch tensor: 5x5 matrix filled with ones.
a = torch.ones(5, 5)
