In [8]:
import robot
from torch import nn
import torch

In [218]:
# net_L: x(b x dim)->l(b x ndim*(ndim+1)/2)
# net_g: x(b x dim)->g(b x ndim)
# q(b x ndim)--net-->l(b x ndim*(ndim+1)/2)--reshape-->L(b x ndim x ndim,lower triangle), dLdq(b x ndim x ndim x ndim, lower triangle)
# dLdt(b x ndim x ndim)
# dHdt(b x ndim x ndim)
# dHdq(b x ndim x ndim x ndim)

def get_batch_jacobian(net, x, noutputs, create_graph=False):
    """Evaluate ``net`` on a batch and return its per-sample Jacobian.

    Every sample is replicated ``noutputs`` times so that a single
    vector-Jacobian product with a batched identity matrix yields one
    Jacobian row per replica.

    Args:
        net: callable mapping a tensor ``(..., ndim)`` to ``(..., noutputs)``;
            it must act independently on each row of its input.
        x: input batch of shape ``(nbatch, ndim)``.
        noutputs: size of the last dimension of ``net``'s output.
        create_graph: when True the returned Jacobian stays attached to the
            autograd graph, so a loss that depends on it can be
            back-propagated into ``net``. Defaults to False, which returns a
            detached Jacobian as before.

    Returns:
        Tuple ``(y, jac)`` where ``y`` has shape ``(nbatch, noutputs)`` and
        ``jac[b, i, j] = d y[b, i] / d x[b, j]`` has shape
        ``(nbatch, noutputs, ndim)``.
    """
    n = x.size(0)
    x = x.unsqueeze(1).repeat(1, noutputs, 1)
    if not x.requires_grad:
        x.requires_grad_(True)
    y = net(x)
    # Build the identity seed on the output's device/dtype so CUDA or
    # double-precision inputs do not hit a device/dtype mismatch.
    seed = torch.eye(noutputs, dtype=y.dtype, device=y.device)
    seed = seed.unsqueeze(0).repeat(n, 1, 1)
    # autograd.grad (unlike y.backward) does not accumulate into the .grad of
    # net's parameters, and retain_graph=True keeps the returned output usable
    # in a later backward pass.
    (jac,) = torch.autograd.grad(
        y, x, grad_outputs=seed, retain_graph=True, create_graph=create_graph
    )
    return y[:, 0, :], jac


def get_LdLdq(net, q):
    """Build the lower-triangular factor ``L(q)`` and its derivative w.r.t. ``q``.

    ``net`` maps ``q`` of shape ``(nbatch, ndim)`` to the
    ``ndim * (ndim + 1) / 2`` entries of a lower-triangular matrix, listed in
    the order produced by ``torch.tril_indices``.

    Args:
        net: callable accepted by ``get_batch_jacobian``.
        q: tensor of shape ``(nbatch, ndim)``.

    Returns:
        Tuple ``(L, dLdq)`` with ``L`` of shape ``(nbatch, ndim, ndim)`` and
        ``dLdq[b, i, j, k] = d L[b, i, j] / d q[b, k]`` of shape
        ``(nbatch, ndim, ndim, ndim)``; entries above the diagonal are zero.
    """
    nbatch, ndim = q.shape
    l, l_jac = get_batch_jacobian(net, q, ndim * (ndim + 1) // 2)
    rows, cols = torch.tril_indices(row=ndim, col=ndim, offset=0, device=l.device)
    # Allocate on the same device/dtype as the network output so the scatter
    # below also works for CUDA or double-precision inputs.
    L = torch.zeros((nbatch, ndim, ndim), dtype=l.dtype, device=l.device)
    L[:, rows, cols] = l
    dLdq = torch.zeros((nbatch, ndim, ndim, ndim), dtype=l_jac.dtype, device=l_jac.device)
    dLdq[:, rows, cols, :] = l_jac
    return L, dLdq

def inverse_model(net_L, net_g, q, dq, ddq):
    """Inverse dynamics with a learned mass-matrix factor and gravity term.

    Computes ``tau = H ddq + dH/dt dq - 0.5 * d(dq^T H dq)/dq + g(q)`` with
    ``H = L L^T``, where ``L`` is the lower-triangular matrix whose entries are
    produced by ``net_L`` and ``g`` is produced by ``net_g``.

    Args:
        net_L: callable mapping ``(…, ndim)`` to ``(…, ndim*(ndim+1)/2)``.
        net_g: callable mapping ``(nbatch, ndim)`` to ``(nbatch, ndim)``.
        q, dq, ddq: tensors of shape ``(nbatch, ndim)``.

    Returns:
        ``tau`` of shape ``(nbatch, ndim)``.
    """
    nbatch, ndim = q.shape
    l, l_jac = get_batch_jacobian(net_L, q, ndim * (ndim + 1) // 2)
    rows, cols = torch.tril_indices(row=ndim, col=ndim, offset=0, device=l.device)
    L = torch.zeros((nbatch, ndim, ndim), dtype=l.dtype, device=l.device)
    L[:, rows, cols] = l
    # dLdq[b, i, j, k] = d L[b, i, j] / d q[b, k]
    dLdq = torch.zeros((nbatch, ndim, ndim, ndim), dtype=l_jac.dtype, device=l_jac.device)
    dLdq[:, rows, cols, :] = l_jac

    # Chain rule over q. einsum keeps the batch axis aligned; a plain matmul of
    # a (b,n,n,n) tensor with a (b,n,1) tensor would broadcast the batch axis
    # of dq against the row axis of dLdq instead.
    dLdt = torch.einsum('bijk,bk->bij', dLdq, dq)
    H = L @ L.transpose(1, 2)
    dHdt = L @ dLdt.transpose(1, 2) + dLdt @ L.transpose(1, 2)
    # dHdq[b, k] = (dL/dq_k) L^T + L (dL/dq_k)^T; unsqueeze(1) broadcasts L over k.
    dHdq = (dLdq.permute(0, 3, 1, 2) @ L.transpose(1, 2).unsqueeze(1)
            + L.unsqueeze(1) @ dLdq.permute(0, 3, 2, 1))
    # quad[b, k] = dq[b]^T dHdq[b, k] dq[b]  -- d(dq H dq)/dq
    quad = torch.einsum('bi,bkij,bj->bk', dq, dHdq, dq)
    # squeeze(2) removes only the trailing column axis, so a batch of one or
    # ndim == 1 keeps its (nbatch, ndim) shape.
    tau = ((H @ ddq.unsqueeze(2)).squeeze(2)
           + (dHdt @ dq.unsqueeze(2)).squeeze(2)
           - 0.5 * quad
           + net_g(q))
    return tau

In [184]:
class TestModule(nn.Module):
    """Parameter-free test network (intended for ndim=3).

    Maps ``x`` to the concatenation ``[exp(x), exp(2x)]`` along the last axis,
    so an input with 3 features yields the 6 entries of a 3x3 lower triangle.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        first_half = x.exp()
        second_half = (2 * x).exp()
        return torch.cat((first_half, second_half), dim=-1)

In [216]:
# Step-by-step evaluation of the inverse-dynamics terms on a small batch,
# keeping every intermediate tensor available for inspection.
model = TestModule()
# Gravity network used in the last line; created here so the cell does not
# rely on a definition made by a later cell.
model_g = nn.Linear(3, 3)

q = torch.stack([torch.log(torch.arange(1, 4).float()),
                 torch.log(torch.arange(2, 5).float()),
                 torch.log(torch.arange(5, 8).float())])
dq = torch.ones(q.shape)
ddq = torch.ones(q.shape)

L, dLdq = get_LdLdq(model, q)
# einsum / unsqueeze(1) keep the batch axis aligned; the plain matmuls would
# broadcast the batch axis of dq and L against the second axis of dLdq.
dLdt = torch.einsum('bijk,bk->bij', dLdq, dq)
H = L @ L.transpose(1, 2)
dHdt = L @ dLdt.transpose(1, 2) + dLdt @ L.transpose(1, 2)
dHdq = (dLdq.permute(0, 3, 1, 2) @ L.transpose(1, 2).unsqueeze(1)
        + L.unsqueeze(1) @ dLdq.permute(0, 3, 2, 1))
quad = torch.einsum('bi,bkij,bj->bk', dq, dHdq, dq)  # d(dqHdq)dq

tau = ((H @ ddq.unsqueeze(2)).squeeze(2)
       + (dHdt @ dq.unsqueeze(2)).squeeze(2)
       - 0.5 * quad
       + model_g(q))

In [219]:
# Smoke test of inverse_model on a 3-sample, 3-dof batch.
model_L = TestModule()
model_g = nn.Linear(3, 3)
# Rows are log([1,2,3]), log([2,3,4]), log([5,6,7]).
q = torch.log(torch.tensor([[1, 2, 3], [2, 3, 4], [5, 6, 7]]).float())
dq = torch.ones_like(q)
ddq = torch.ones_like(q)

tau = inverse_model(model_L, model_g, q, dq, ddq)

In [112]:
# Scratch network: 3 inputs -> 6 outputs, with a nested Sequential in the
# middle. Layers are listed in their original order so that parameter
# initialisation draws from the RNG in the same sequence.
model = nn.Sequential(
    nn.Linear(3, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Sequential(
        nn.Softplus(),
        nn.Linear(256, 256),
        nn.Linear(256, 256),
        nn.ReLU(),
    ),
    nn.Linear(256, 256), nn.Softplus(),
    nn.Linear(256, 6),
)

# Single random sample that tracks gradients.
x = torch.rand(1, 3).requires_grad_()

y = model(x)

#### RobotArm test

In [1]:
from gym.wrappers import TimeLimit
from robot.envs.hyrule.rl_env import ArmReachWithXYZ
import numpy as np
from robot.model.arm.extra import lagrangian_v2 as lg
from torch import nn
import torch
from robot.model.arm.dataset import *

Using default glsl path /home/derek/anaconda3/lib/python3.7/site-packages/sapien/glsl_shader/130
USE sapien core


In [2]:
# Generate the 'arm' dataset. make_dataset is not imported by name; it
# presumably comes from the star import of robot.model.arm.dataset -- confirm.
# The recorded output shows pickles being written under ./dataset/arm and a
# run time of roughly 56 minutes, so avoid re-running this cell needlessly.
make_dataset('arm')

Using default glsl path /home/derek/anaconda3/lib/python3.7/site-packages/sapien/glsl_shader/130
USE sapien core


  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity


  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
  delta = np.linalg.lstsq(jac[:3], goal-achieved)[0] * 10 # desired_velocity
 76%|███████▋  | 3824/5000 [44:15<12:38,  1.55it/s] 

saving...  ./dataset/arm/11.pkl


 77%|███████▋  | 3828/5000 [44:17<10:50,  1.80it/s]

saving...  ./dataset/arm/13.pkl
saving...  ./dataset/arm/10.pkl


 77%|███████▋  | 3832/5000 [44:19<12:29,  1.56it/s]

saving...  ./dataset/arm/14.pkl


 77%|███████▋  | 3835/5000 [44:21<10:51,  1.79it/s]

saving...  ./dataset/arm/12.pkl


 79%|███████▉  | 3945/5000 [45:36<11:35,  1.52it/s]

saving...  ./dataset/arm/9.pkl


 81%|████████▏ | 4066/5000 [46:57<11:25,  1.36it/s]

saving...  ./dataset/arm/8.pkl


 84%|████████▎ | 4179/5000 [48:14<09:17,  1.47it/s]

saving...  ./dataset/arm/7.pkl


 86%|████████▌ | 4293/5000 [49:27<07:32,  1.56it/s]

saving...  ./dataset/arm/6.pkl


 87%|████████▋ | 4359/5000 [50:10<06:57,  1.54it/s]

saving...  ./dataset/arm/5.pkl


 88%|████████▊ | 4421/5000 [50:45<06:13,  1.55it/s]

saving...  ./dataset/arm/4.pkl


 90%|█████████ | 4517/5000 [51:44<04:51,  1.66it/s]

saving...  ./dataset/arm/3.pkl


 92%|█████████▏| 4596/5000 [52:32<04:23,  1.53it/s]

saving...  ./dataset/arm/2.pkl


 94%|█████████▍| 4694/5000 [53:30<02:52,  1.77it/s]

saving...  ./dataset/arm/1.pkl


 96%|█████████▌| 4792/5000 [54:27<02:17,  1.51it/s]

saving...  ./dataset/arm/0.pkl


 98%|█████████▊| 4915/5000 [55:33<00:30,  2.76it/s]

saving...  ./dataset/arm/18.pkl


 98%|█████████▊| 4925/5000 [55:38<00:40,  1.87it/s]

saving...  ./dataset/arm/17.pkl


 99%|█████████▉| 4938/5000 [55:44<00:26,  2.36it/s]

saving...  ./dataset/arm/15.pkl


 99%|█████████▉| 4948/5000 [55:49<00:22,  2.33it/s]

saving...  ./dataset/arm/16.pkl


100%|██████████| 5000/5000 [56:16<00:00,  1.48it/s]


saving...  ./dataset/arm/19.pkl


In [59]:
# Wrap the reach environment in gym's TimeLimit.
# NOTE(review): the positional 50 is presumably the per-episode step cap;
# confirm against the installed gym version's TimeLimit signature.
env=TimeLimit(ArmReachWithXYZ(), 50)
# reset() returns an object indexed by key; only its 'observation' entry is shown.
obs=env.reset()
print(obs['observation'])

[ 0.0000000e+00 -1.3825793e+00 -2.9427279e-02  5.7432674e-02
 -9.7876757e-01  3.7991413e-01  6.0942268e-01  2.4935515e+00
  8.0001003e-01  8.0001003e-01  8.0000019e-01  1.1211396e-13
  3.5619316e-12 -1.8626451e-09 -6.7977220e-02 -1.1120176e+00
  2.8683582e-01 -1.0252906e+00 -2.1472794e-01  7.6243766e-02
  2.1023731e-01  1.1284603e-05  1.0477379e-09 -1.4394755e-06
  4.8441251e-12  0.0000000e+00 -1.2245178e-03 -1.9696816e+00
 -2.1866995e+01  6.0620375e+00 -1.8629780e+01 -4.7180300e+00
  2.4600406e+00  5.1464810e+00  1.8125772e-03 -5.9604645e-06
  1.0341406e-05  3.6397523e-08  5.2817632e-07  8.8724220e-01
  3.2058075e-02  1.2223468e+00]




In [None]:
# Print the name of every joint reported by the environment's agent.
for i in env.agent.get_joints():
    print(i.name)
    
# Names of seven right-arm joints. This list is not referenced by any other
# cell visible in this notebook -- presumably the actuated joints; confirm.
actuator=['right_shoulder_pan_joint',
                          'right_shoulder_lift_joint',
                          'right_arm_half_joint',
                          'right_elbow_joint',
                          'right_wrist_spherical_1_joint',
                          'right_wrist_spherical_2_joint',
                          'right_wrist_3_joint',
                          ]

In [29]:
# Open the dataset stored under ./dataset/arm, with tensors placed on GPU 3.
# NOTE(review): the device index is hard-coded; Dataset presumably comes from
# the star import of robot.model.arm.dataset -- confirm.
dataset=Dataset('./dataset/arm', device='cuda:3')

def get_info(data):
    """Slice three 7-column groups out of time index 1 of ``data[0]``.

    ``data[0]`` is indexed as ``[batch, time, feature]``. Columns 1:8, 14:21
    and 27:34 are returned as ``(q, dq, ddq)``.
    NOTE(review): presumably joint positions, velocities and accelerations of
    the 7 arm joints -- verify against the dataset's observation layout.
    """
    frame = data[0][:, 1]
    q = frame[:, 1:8]
    dq = frame[:, 14:21]
    ddq = frame[:, 27:34]
    return q, dq, ddq

MAX ACTION [1. 1. 1. 1. 1. 1. 1.]
MAX Q [6.28318214 2.47949553 6.28318214 6.13466597 6.28318501 2.6306982
 6.28317213]
MAX DQ [ 494.97387695  102.62372589  481.76843262  259.27737427 1409.93457031
  554.1619873  1217.09179688]
num train 80000
num valid 20000


In [136]:
# Make GPU 3 the current CUDA device so the bare .cuda() calls further down
# this cell place the models on it. NOTE(review): hard-coded device index.
torch.cuda.set_device('cuda:3')

# model_l = nn.Sequential(
#     nn.Linear(7, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 28),
# )

# model_g = nn.Sequential(
#     nn.Linear(7, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 256),
#     nn.ReLU(),
#     nn.Linear(256, 7),
# )

class LagModel(nn.Module):
    """Shared-trunk network with three linear heads.

    ``forward(q)`` returns a pair:
      * the concatenation of the ``ndim*(ndim-1)/2`` outputs of ``tril`` and
        the ``ndim`` outputs of ``diag`` (in that order), and
      * the ``ndim`` outputs of ``gravity``.
    """

    def __init__(self, ndim):
        super().__init__()
        # Two-layer ReLU trunk shared by all heads.
        self.feat = nn.Sequential(
            nn.Linear(ndim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        # No activation follows this head, so its outputs are unconstrained in sign.
        self.diag = nn.Sequential(nn.Linear(256, ndim))
        self.tril = nn.Linear(256, ndim * (ndim - 1) // 2)
        self.gravity = nn.Linear(256, ndim)

    def forward(self, q):
        feature = self.feat(q)
        heads = (self.tril(feature), self.diag(feature))
        return torch.cat(heads, dim=-1), self.gravity(feature)
        
# Structured model for the 7-dof arm.
model_lag = LagModel(7)
# Unstructured baseline: consumes 21 features and predicts 7 outputs.
model_naive = nn.Sequential(
    nn.Linear(21, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 7),
)

# Move both models to the current CUDA device.
model_lag = model_lag.cuda()
model_naive = model_naive.cuda()
crit = nn.MSELoss()

# Loss histories, filled in by later cells.
train_loss_lag, train_loss_naive = [], []
val_loss_lag, val_loss_naive = [], []

In [None]:
from importlib import reload
# Re-import lg so edits made to the module on disk are picked up.
reload(lg)
# Fresh optimizers are created on every run of this cell, which resets their state.
optimizer_lag=torch.optim.AdamW(model_lag.parameters(), lr=5e-4, weight_decay=1e-6)
optimizer_naive=torch.optim.Adam(model_naive.parameters(), lr=5e-4)
for t in range(1):
    # One optimisation step per iteration, dataset.num_train iterations in total.
    for i in range(dataset.num_train):
        data=dataset.sample()
        q,dq,ddq=get_info(data)
        
        # Lagrangian network
        # NOTE(review): zero_grad() is called after the forward pass. If
        # inverse_model_v2 runs an internal backward() to obtain the Jacobian
        # (as get_batch_jacobian in this notebook does), this ordering also
        # clears those stray parameter gradients -- confirm before reordering.
        tau=lg.inverse_model_v2(model_lag,q,dq,ddq)
        optimizer_lag.zero_grad()
        # Targets are data[1] scaled by 50 -- presumably a torque normalisation; TODO confirm.
        loss_lag=crit(tau,data[1].squeeze()*50)
        loss_lag.backward()
        
        optimizer_lag.step()
        
        # Naive network
        # Baseline sees q, dq and ddq concatenated along the feature axis.
        tau=model_naive(torch.cat([q,dq,ddq], axis=1))
        optimizer_naive.zero_grad()
        loss_naive=crit(tau,data[1].squeeze()*50)
        loss_naive.backward()
        optimizer_naive.step()
        
        # Columns: epoch, step, Lagrangian loss, naive loss.
        print(t, i, loss_lag.data.item(), loss_naive.data.item())
#         if i%1000==0:
#             data=dataset.sample('valid')
            
#             q,dq,ddq=get_info(data)
#             tau=lg.inverse_model(model_lag,q,dq,ddq)
#             loss_val_lag=crit(tau,data[1].squeeze()*50)
# #             train_loss_lag.append(loss_lag.data.item())
# #             val_loss_lag.append(loss_val_lag.item())
            
#             # Naive network
#             tau=model_naive(torch.cat([q,dq,ddq], axis=1))
#             loss_val_naive=crit(tau,data[1].squeeze()*50)
# #             train_loss_naive.append(loss_naive.data.item())
# #             val_loss_naive.append(loss_val_naive.item())   
            
#             print('Val:', t, i, loss_val_lag.data.item() ,loss_val_naive.data.item())


0 0 587.8056640625 540.7299194335938
0 1 628.6410522460938 806.023193359375
0 2 622.036865234375 612.7535400390625
0 3 569.6346435546875 559.8724365234375
0 4 546.64453125 664.2753295898438
0 5 537.3187255859375 587.8859252929688
0 6 668.9659423828125 647.7855834960938
0 7 606.235107421875 572.4244995117188
0 8 659.0186767578125 676.05615234375
0 9 626.4241943359375 743.9520263671875
0 10 620.9656372070312 650.6654663085938
0 11 615.0822143554688 592.386474609375
0 12 664.6119384765625 751.105224609375
0 13 527.2022094726562 544.55859375
0 14 722.9551391601562 901.4373779296875
0 15 678.4617919921875 680.3306884765625
0 16 603.312255859375 658.2486572265625
0 17 656.2230224609375 689.4296264648438
0 18 569.5133056640625 585.6622924804688
0 19 599.6051025390625 595.4862670898438
0 20 572.4472045898438 554.93017578125
0 21 547.12890625 573.6093139648438
0 22 552.9281616210938 565.3452758789062
0 23 618.9705810546875 611.6369018554688
0 24 584.397216796875 591.0956420898438
0 25 669.34484

0 206 540.6292724609375 506.8505859375
0 207 597.1697998046875 599.067626953125
0 208 571.6842651367188 511.65606689453125
0 209 611.6302490234375 571.9480590820312
0 210 617.6044921875 601.7863159179688
0 211 524.1285400390625 484.4566955566406
0 212 622.6765747070312 607.9529418945312
0 213 573.5654296875 531.9588012695312
0 214 591.2113647460938 535.0735473632812
0 215 567.2802124023438 511.37677001953125
0 216 759.1730346679688 598.968994140625
0 217 643.3486328125 597.5198974609375
0 218 571.9302368164062 517.39208984375
0 219 592.0443725585938 567.3264770507812
0 220 592.833984375 552.6376953125
0 221 599.1173706054688 593.7313232421875
0 222 558.7228393554688 549.8895874023438
0 223 587.2657470703125 541.5393676757812
0 224 532.67529296875 507.51092529296875
0 225 594.9597778320312 600.3469848632812
0 226 630.9389038085938 641.0387573242188
0 227 585.9430541992188 574.219970703125
0 228 624.8361206054688 639.0629272460938
0 229 647.2478637695312 603.41943359375
0 230 658.5124511

0 408 592.4925537109375 548.9274291992188
0 409 633.0385131835938 583.7510986328125
0 410 675.3596801757812 727.0318603515625
0 411 630.4176635742188 634.2998046875
0 412 622.981201171875 547.900634765625
0 413 565.01123046875 520.771240234375
0 414 600.004150390625 562.3606567382812
0 415 595.70703125 537.1122436523438
0 416 569.9954833984375 530.125244140625
0 417 631.9882202148438 587.0762939453125
0 418 636.4598388671875 577.5580444335938
0 419 579.8860473632812 555.366943359375
0 420 578.38525390625 540.401123046875
0 421 586.6476440429688 540.8370971679688
0 422 503.9578857421875 450.0553283691406
0 423 611.3875732421875 579.810791015625
0 424 625.1076049804688 583.6784057617188
0 425 587.9523315429688 551.65185546875
0 426 583.610595703125 479.9137268066406
0 427 634.7781372070312 552.072509765625
0 428 594.777099609375 531.58544921875
0 429 554.8361206054688 504.4642333984375
0 430 649.597900390625 648.4436645507812
0 431 547.2799072265625 509.10150146484375
0 432 588.145141601

0 613 602.4093017578125 566.5960083007812
0 614 605.7681884765625 549.3026733398438
0 615 583.1143798828125 521.3280029296875
0 616 561.9100341796875 520.5199584960938
0 617 583.2505493164062 548.8411865234375
0 618 605.5113525390625 580.3529663085938
0 619 569.5308227539062 530.4879760742188
0 620 632.5184326171875 592.4310913085938
0 621 613.6659545898438 597.3270263671875
0 622 588.2557983398438 518.4569702148438
0 623 583.6057739257812 534.3577880859375
0 624 574.4822998046875 524.3930053710938
0 625 608.1764526367188 549.7669677734375
0 626 597.6439819335938 566.4361572265625
0 627 557.2603759765625 514.042724609375
0 628 705.6441650390625 633.0676879882812
0 629 797.8514404296875 670.9027709960938
0 630 626.7675170898438 571.5738525390625
0 631 689.2020263671875 558.232666015625
0 632 635.3078002929688 589.0037841796875
0 633 628.5949096679688 573.1763916015625
0 634 598.7346801757812 539.1296997070312
0 635 601.5501098632812 544.0578002929688
0 636 648.4567260742188 563.24737548

0 817 584.3302001953125 517.0408935546875
0 818 681.0071411132812 566.6397705078125
0 819 633.240966796875 582.5360717773438
0 820 595.8133544921875 541.6658325195312
0 821 592.5655517578125 534.4789428710938
0 822 516.220703125 463.51971435546875
0 823 646.894287109375 583.098876953125
0 824 527.7635498046875 474.1202087402344
0 825 683.2171630859375 556.3270263671875
0 826 589.0138549804688 515.173095703125
0 827 691.1856079101562 623.771484375
0 828 533.6968994140625 477.6539001464844
0 829 539.1232299804688 494.83380126953125
0 830 584.1528930664062 511.3109436035156
0 831 596.490966796875 517.7747802734375
0 832 597.0392456054688 537.983642578125
0 833 616.611572265625 548.1220703125
0 834 587.154052734375 541.2533569335938
0 835 587.22998046875 542.900634765625
0 836 597.501220703125 526.320556640625
0 837 581.0277709960938 512.621337890625
0 838 642.5203857421875 572.6798706054688
0 839 595.0560302734375 555.0368041992188
0 840 574.0382690429688 492.60107421875
0 841 572.5585937

0 1018 605.0701904296875 524.7799682617188
0 1019 606.7208251953125 546.05322265625
0 1020 685.2824096679688 606.2532348632812
0 1021 602.9535522460938 574.880859375
0 1022 654.7977294921875 563.4201049804688
0 1023 653.34765625 570.7408447265625
0 1024 577.5402221679688 522.878173828125
0 1025 530.3646850585938 463.257080078125
0 1026 590.8369140625 518.470458984375
0 1027 628.1809692382812 535.318359375
0 1028 560.2815551757812 485.9876708984375
0 1029 571.6157836914062 511.109619140625
0 1030 638.7293701171875 550.4301147460938
0 1031 613.5220947265625 535.4511108398438
0 1032 579.5169067382812 504.5515441894531
0 1033 606.4596557617188 543.145263671875
0 1034 658.9039306640625 582.5916137695312
0 1035 555.8758544921875 514.7695922851562
0 1036 589.6810913085938 505.9252014160156
0 1037 586.4093017578125 501.1935119628906
0 1038 597.7955322265625 540.7326049804688
0 1039 612.853515625 595.720703125
0 1040 585.0634765625 514.2837524414062
0 1041 610.7396850585938 535.2041015625
0 104

0 1215 561.502197265625 507.281005859375
0 1216 566.4124755859375 507.33734130859375
0 1217 571.2066650390625 533.09814453125
0 1218 608.025146484375 545.7240600585938
0 1219 581.8450927734375 489.671630859375


In [73]:
# List the qualified name of every submodule of model_lag.
# named_modules() yields the root module first, under an empty name.
for name, module in model_lag.named_modules():
    print(name)


feat
feat.0
feat.1
feat.2
feat.3
diag
diag.0
tril
gravity


In [78]:
# Show (name, module) pairs for the shared trunk; the first pair is the
# Sequential container itself, printed with an empty name.
for i,j in model_lag.feat.named_modules():
    print(i,j)

 Sequential(
  (0): Linear(in_features=7, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): ReLU()
)
0 Linear(in_features=7, out_features=256, bias=True)
1 ReLU()
2 Linear(in_features=256, out_features=256, bias=True)
3 ReLU()


In [96]:
# Inspect the gradient currently stored on the diagonal head's weight matrix
# (None until a backward pass has populated it).
model_lag.diag[0].weight.grad

tensor([[ 0.0000e+00, -7.7214e+01,  1.3162e+01,  ...,  5.9595e+01,
         -6.8283e-02, -2.4130e+02],
        [ 0.0000e+00, -2.2030e+01,  4.2801e+01,  ...,  8.9456e+00,
         -2.9239e-02, -1.8897e+01],
        [ 0.0000e+00, -3.2994e+01, -5.4616e+01,  ..., -2.1125e+01,
          1.3532e-01,  1.5908e+02],
        ...,
        [ 0.0000e+00, -7.7974e+02, -7.8320e+02,  ..., -1.6180e+02,
         -5.2150e+00, -7.1043e+02],
        [ 0.0000e+00, -3.9131e+01, -5.8152e+02,  ..., -5.6040e+01,
         -5.8999e-03, -6.0449e+01],
        [ 0.0000e+00, -1.1678e+02, -8.9093e+01,  ..., -3.7642e+01,
         -1.3136e-04,  3.2812e+01]], device='cuda:3')

In [28]:
    
class TestModule(nn.Module):
    """Two-output test network (intended for ndim=3).

    ``forward(x)`` returns ``(cat([exp(x), exp(2x)], -1), g(x))`` where ``g``
    is a 3->3 linear layer.
    """

    def __init__(self):
        super().__init__()
        self.g = nn.Linear(3, 3)

    def forward(self, x):
        stacked = torch.cat((x.exp(), (2 * x).exp()), dim=-1)
        return stacked, self.g(x)
# Pick up the latest on-disk version of the lg module.
reload(lg)

torch.cuda.set_device('cuda:3')

model_L = TestModule().cuda()
model_g = nn.Linear(3, 3).cuda()
# Four samples: log([1,2,3]), log([2,3,4]), log([5,6,7]), log([7,8,9]).
q = torch.log(torch.tensor([[1, 2, 3], [2, 3, 4], [5, 6, 7], [7, 8, 9]]).float()).cuda()
dq = torch.ones_like(q)
ddq = torch.ones_like(q)

# With the trailing True flag, lg.inverse_model returns every intermediate
# term in addition to tau.
(L, dLdq, dLdt, H,
 dHdq, dHdt, quad, tau) = lg.inverse_model(model_L, q, dq, ddq, True)

# Show each term for the first sample only.
print(L[0])
print(dLdq[0])
print(dLdt[0])
print(H[0])
print(dHdq[0])
print(dHdt[0])
print(quad[0])
print(tau[0])

tensor([[1., 0., 0.],
        [1., 4., 0.],
        [2., 3., 9.]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[[ 2.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[ 1.,  0.,  0.],
         [ 0.,  8.,  0.],
         [ 0.,  0.,  0.]],

        [[ 0.,  2.,  0.],
         [ 0.,  0.,  3.],
         [ 0.,  0., 18.]]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[ 2.,  0.,  0.],
        [ 1.,  8.,  0.],
        [ 2.,  3., 18.]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[ 1.,  1.,  2.],
        [ 1., 17., 14.],
        [ 2., 14., 94.]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[[  4.,   3.,   4.],
         [  3.,   2.,   2.],
         [  4.,   2.,   0.]],

        [[  0.,   0.,   2.],
         [  0.,  64.,  26.],
         [  2.,  26.,   8.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,  12.],
         [  0.,  12., 342.]]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[  4.,   3.,   6.],
        [  3.,  66.,  40.],
   