In [2]:
from gym.wrappers import TimeLimit
from robot.envs.hyrule.rl_env import ArmReachWithXYZ
import numpy as np
from robot.model.arm.extra import lagrangian_v2 as lg
from torch import nn
import torch
from robot.model.arm.dataset import *

In [5]:
make_dataset('arm')
dataset=Dataset('./dataset/acrobat2', device='cuda:3')

MAX ACTION [0.99999994 0.99999893]
MAX Q [6.28318357e+00 9.72666779e+01 9.13639145e+01 3.85837158e+03
 6.67188330e+03 9.99996364e-01 1.19208849e-07]
MAX DQ []
num train 40000
num valid 10000


In [6]:
def get_info(data, ndim):
    q=data[0][:, 1, 0:ndim]
    dq=data[0][:,1, ndim:2*ndim]
    ddq=data[0][:,1,2*ndim:3*ndim]
    tau=data[1].squeeze()
    return q, dq, ddq, tau*50

class LagrangianNetwork(nn.Module):
    def __init__(self, ndim):
        super(LagrangianNetwork, self).__init__()
        self.feat=nn.Sequential(
            nn.Linear(2*ndim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        self.diag=nn.Sequential(
            nn.Linear(256, ndim),
            nn.ReLU(),
        )
        self.tril=nn.Linear(256, int(ndim*(ndim-1)/2))
        self.g=nn.Linear(256, ndim)
        
    def get_batch_jacobian(self, q, noutputs):
        n = q.size()[0]
        q.requires_grad_(True)
        q = q.unsqueeze(1)
        q = q.repeat(1, noutputs, 1)
        x = torch.cat((torch.cos(q), torch.sin(q)), dim=-1)
        feature = self.feat(x)
        y = torch.cat([self.tril(feature), self.diag(feature)], dim=-1)
        g = self.g(feature)
        input_val = torch.eye(noutputs).reshape(1,noutputs, noutputs).repeat(n, 1, 1).cuda()
    #     x.retain_grad()
        jac=torch.autograd.grad(y,q,input_val,create_graph=True)
        return y[:,0,:], g[:,0,:], jac[0]


    def forward(self, q, dq, ddq, return_all=False):
        nbatch=q.shape[0]
        ndim=q.shape[1]
        dqr=dq.unsqueeze(1).repeat(1,ndim,1)

        l, g, l_jac=self.get_batch_jacobian(q, int(ndim*(ndim+1)/2))
        L = torch.zeros((nbatch, ndim, ndim)).cuda()
        tril_indices = torch.tril_indices(row=ndim, col=ndim, offset=-1)

        L[:,tril_indices[0], tril_indices[1]] = l[:,:-ndim]
        L[:,torch.arange(ndim),torch.arange(ndim)]=l[:,-ndim:]

        dLdq=torch.zeros((nbatch, ndim, ndim, ndim)).cuda()
        dLdq[:,tril_indices[0], tril_indices[1],:]=l_jac[:,:-ndim]
        dLdq[:,torch.arange(ndim),torch.arange(ndim)]=l_jac[:,-ndim:]   
        dLdt=(dLdq@dqr.unsqueeze(3)).squeeze()

        H=L@L.transpose(1,2)
        dHdt=L@dLdt.transpose(1,2)
        dHdt=dHdt+dHdt.transpose(1,2)
        dHdq=dLdq.permute(0,3,1,2)@(L.unsqueeze(1).repeat(1,ndim,1,1).transpose(2,3))
        dHdq=dHdq+dHdq.transpose(2,3)

        quad=(dqr.unsqueeze(2)@dHdq@dqr.unsqueeze(3)).squeeze()
        tau=(H@ddq.unsqueeze(2)).squeeze()+(dHdt@dq.unsqueeze(2)).squeeze()-quad/2+g
        if return_all:
            return L,dLdq,dLdt,H,dHdq,dHdt,quad,tau
        else:
            return tau
        

## Raw

In [11]:
torch.cuda.set_device('cuda:3')
# ndim=7 for arm, ndim=2 for acrobat
ndim=2

model_lag=LagrangianNetwork(ndim)
model_lag=model_lag.cuda()
        
model_naive = nn.Sequential(
    nn.Linear(3*ndim, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, ndim),
    nn.Tanh()
)

model_naive=model_naive.cuda()
crit = nn.MSELoss()

train_loss_lag=[]
train_loss_naive=[]
val_loss_lag=[]
val_loss_naive=[]

optimizer_lag=torch.optim.Adam(model_lag.parameters(), lr=1e-4)
optimizer_naive=torch.optim.Adam(model_naive.parameters(), lr=1e-4)

for t in range(5):
    for i in range(dataset.num_train):
        data=dataset.sample()
        q,dq,ddq,tau_target=get_info(data,ndim)
#         tau_target=tau_target/50
        
        # Lagrangian network
        tau=model_lag(q,dq,ddq)
        optimizer_lag.zero_grad()
        loss_lag=crit(tau,tau_target)
        loss_lag_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
        loss_lag.backward()
        optimizer_lag.step()
        
        # Naive network
        tau=model_naive(torch.cat([q,dq,ddq], axis=-1))
        optimizer_naive.zero_grad()
        loss_naive=crit(tau,tau_target)
        loss_naive_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
        loss_naive.backward()
        optimizer_naive.step()
        
        print(i, 'MSE:',loss_lag.data.item(), loss_naive.data.item(), 'Relative MSE:', loss_lag_rel.data.item(), loss_naive_rel.data.item(),)
        
        if i%1000==0:
            vdata=dataset.sample('valid')
            q,dq,ddq,tau_target=get_info(vdata,ndim)
            tau_target=vdata[1]*50

            # Lagrangian network
            L_lag,dLdq_lag,dLdt_lag,H_lag,dHdq_lag,dHdt_lag,quad_lag,tau=model_lag(q,dq,ddq,True)
            
            loss_lag=crit(tau,tau_target)
            loss_lag_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
#             train_loss_lag.append(loss_lag.data.item())
#             val_loss_lag.append(loss_val_lag.item())
            
#             # Naive network
            tau=model_naive(torch.cat([q,dq,ddq], axis=1))
            loss_naive=crit(tau,tau_target)
            loss_naive_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
#             train_loss_naive.append(loss_naive.data.item())
#             val_loss_naive.append(loss_val_naive.item())   
            
            print('VAL',i, 'MSE:',loss_lag.data.item(), loss_naive.data.item(), 'Relative MSE:', loss_lag_rel.data.item(), loss_naive_rel.data.item(),)


0 MSE: 852.0191650390625 844.3885498046875 Relative MSE: 57.580989837646484 9.897561073303223
VAL 0 MSE: 918.3187255859375 918.2095947265625 Relative MSE: 15.110260009765625 7.532215118408203
1 MSE: 817.190673828125 814.0492553710938 Relative MSE: 1.1197062730789185 1.0760786533355713
2 MSE: 799.38427734375 799.0001220703125 Relative MSE: 6.170267105102539 4.666138172149658
3 MSE: 891.3093872070312 888.7840576171875 Relative MSE: 1.0078439712524414 1.2643418312072754
4 MSE: 814.2274169921875 812.1486206054688 Relative MSE: 1.2386200428009033 10.739459037780762
5 MSE: 872.4290771484375 864.23779296875 Relative MSE: 1.038805603981018 1.231734275817871
6 MSE: 863.4141845703125 860.2381591796875 Relative MSE: 1.0078520774841309 1.2164288759231567
7 MSE: 814.3147583007812 814.57666015625 Relative MSE: 1.0452024936676025 2.8082189559936523
8 MSE: 750.580322265625 748.5826416015625 Relative MSE: 0.9942533373832703 1.0500636100769043
9 MSE: 829.364013671875 828.17822265625 Relative MSE: 1.0170

87 MSE: 776.470703125 776.345947265625 Relative MSE: 1.31136953830719 93.53118133544922
88 MSE: 830.7044067382812 833.1065063476562 Relative MSE: 2.0186567306518555 3.1859660148620605
89 MSE: 880.0549926757812 872.670654296875 Relative MSE: 1.018210530281067 1.0863991975784302
90 MSE: 808.390625 804.9230346679688 Relative MSE: 2.6452760696411133 1.9177677631378174
91 MSE: 779.67919921875 775.350830078125 Relative MSE: 1.4395359754562378 273.4100036621094
92 MSE: 825.3421630859375 825.7613525390625 Relative MSE: 1.3782322406768799 1.0393195152282715
93 MSE: 778.6986083984375 774.930908203125 Relative MSE: 1.0332856178283691 1.6903170347213745
94 MSE: 827.880615234375 826.5011596679688 Relative MSE: 1.099531650543213 1.191253423690796
95 MSE: 870.2650146484375 867.2070922851562 Relative MSE: 6.078710556030273 356.0982971191406
96 MSE: 860.310546875 862.538330078125 Relative MSE: 1.1665065288543701 7.8314290046691895
97 MSE: 845.5975341796875 842.6585693359375 Relative MSE: 1.207426071166

180 MSE: 824.629638671875 827.5089111328125 Relative MSE: 1.1749467849731445 1.2054953575134277
181 MSE: 865.1215209960938 871.9886474609375 Relative MSE: 1.1519479751586914 29.460247039794922
182 MSE: 795.21630859375 780.5487060546875 Relative MSE: 1.0214903354644775 1.4940333366394043
183 MSE: 788.5094604492188 766.803955078125 Relative MSE: 1.0373406410217285 1.3245186805725098
184 MSE: 814.0508422851562 803.2933349609375 Relative MSE: 1.7070741653442383 1.177677035331726
185 MSE: 815.93896484375 805.0596313476562 Relative MSE: 1.0708320140838623 1.1864805221557617
186 MSE: 764.64306640625 760.1243896484375 Relative MSE: 1.8805768489837646 3.3210999965667725
187 MSE: 827.7399291992188 823.3675537109375 Relative MSE: 3.479011297225952 2.1535251140594482
188 MSE: 822.8195190429688 823.294677734375 Relative MSE: 1.4054250717163086 3.0574264526367188
189 MSE: 832.8167724609375 832.1537475585938 Relative MSE: 1.2655926942825317 1.1145466566085815
190 MSE: 902.0321044921875 900.7075195312

277 MSE: 787.032958984375 784.9328002929688 Relative MSE: 1.2495988607406616 2.1566872596740723
278 MSE: 866.3477783203125 853.2864990234375 Relative MSE: 11.047178268432617 17.646244049072266
279 MSE: 879.5625 880.6090087890625 Relative MSE: 1.388964056968689 1.0386872291564941
280 MSE: 825.18115234375 825.7881469726562 Relative MSE: 2.6282472610473633 22.972070693969727
281 MSE: 863.8677368164062 856.1498413085938 Relative MSE: 1.222861647605896 52.124420166015625
282 MSE: 796.78125 783.6173706054688 Relative MSE: 2.4132604598999023 445.5380554199219
283 MSE: 803.41162109375 783.8045043945312 Relative MSE: 2.0898523330688477 4.971192359924316
284 MSE: 826.9012451171875 819.7841796875 Relative MSE: 2.371284008026123 4.4602766036987305
285 MSE: 874.5355834960938 866.1397094726562 Relative MSE: 1.115734338760376 1.3828275203704834
286 MSE: 778.4962158203125 773.7742919921875 Relative MSE: 1.560330867767334 15.561113357543945
287 MSE: 832.6157836914062 832.9091796875 Relative MSE: 1.0100

370 MSE: 789.422607421875 777.4564208984375 Relative MSE: 1.0720820426940918 1.083479642868042
371 MSE: 810.6583862304688 810.9993896484375 Relative MSE: 2.193463087081909 4.240206241607666
372 MSE: 919.935546875 914.935302734375 Relative MSE: 1.4442905187606812 2.2045624256134033
373 MSE: 817.5079956054688 809.46240234375 Relative MSE: 1.140305519104004 2.372974395751953
374 MSE: 863.5283203125 869.886962890625 Relative MSE: 1.2845206260681152 1.2947393655776978
375 MSE: 820.0206298828125 817.8326416015625 Relative MSE: 11.087573051452637 10.007547378540039
376 MSE: 859.1009521484375 858.716796875 Relative MSE: 1.0196447372436523 1.2364997863769531
377 MSE: 826.7193603515625 823.5880737304688 Relative MSE: 1.6061160564422607 15.146063804626465
378 MSE: 794.1676025390625 792.6334228515625 Relative MSE: 1.0072059631347656 0.9882897138595581
379 MSE: 813.731689453125 823.0098266601562 Relative MSE: 1.0357542037963867 4.4783430099487305
380 MSE: 837.615966796875 838.5923461914062 Relative

458 MSE: 831.4261474609375 826.8740234375 Relative MSE: 1.0254848003387451 2.3669745922088623
459 MSE: 836.5975341796875 845.9367065429688 Relative MSE: 1.9878888130187988 4.630314826965332
460 MSE: 806.7344970703125 810.1138916015625 Relative MSE: 1.1091631650924683 1.2062134742736816
461 MSE: 810.6364135742188 803.9908447265625 Relative MSE: 1.0227370262145996 1.1573827266693115
462 MSE: 808.29931640625 799.3534545898438 Relative MSE: 1.0249626636505127 1.1638317108154297
463 MSE: 845.0201416015625 851.392578125 Relative MSE: 1.4382374286651611 8.244803428649902
464 MSE: 805.226806640625 798.0255737304688 Relative MSE: 1.0882447957992554 1.034376621246338
465 MSE: 782.8013916015625 794.8734130859375 Relative MSE: 1.1436243057250977 1.59291410446167
466 MSE: 812.1394653320312 796.752685546875 Relative MSE: 4.800998687744141 2.0857863426208496
467 MSE: 840.45947265625 850.640625 Relative MSE: 1.0476436614990234 16.7473201751709
468 MSE: 743.09912109375 747.976318359375 Relative MSE: 1.

552 MSE: 826.3919677734375 827.2984619140625 Relative MSE: 48.514949798583984 8.615781784057617
553 MSE: 822.2638549804688 815.7161865234375 Relative MSE: 1.1332058906555176 1.1673433780670166
554 MSE: 796.551513671875 807.2734375 Relative MSE: 1.0641074180603027 4.916255950927734
555 MSE: 863.3843994140625 850.9515991210938 Relative MSE: 1.8252277374267578 1.1528778076171875
556 MSE: 826.0960693359375 823.044677734375 Relative MSE: 19.885601043701172 2.545224666595459
557 MSE: 776.6357421875 768.7928466796875 Relative MSE: 1.7792683839797974 104.84452056884766
558 MSE: 848.7535400390625 834.370361328125 Relative MSE: 1.0378146171569824 1.1690285205841064
559 MSE: 779.9188232421875 782.08935546875 Relative MSE: 1.0938198566436768 1.1507041454315186
560 MSE: 832.37255859375 832.9786376953125 Relative MSE: 1.3257267475128174 1.2952511310577393
561 MSE: 854.273193359375 844.174560546875 Relative MSE: 1.830413579940796 1.7372074127197266
562 MSE: 781.587646484375 780.1527099609375 Relative

642 MSE: 780.0408935546875 777.875732421875 Relative MSE: 24.582778930664062 4.302990436553955
643 MSE: 761.8982543945312 761.1275634765625 Relative MSE: 1.4893479347229004 1.1159734725952148
644 MSE: 790.7965087890625 793.7257080078125 Relative MSE: 1.1153076887130737 1.0680674314498901
645 MSE: 833.9318237304688 822.6927490234375 Relative MSE: 27.01076889038086 129.38946533203125
646 MSE: 842.0968017578125 821.7747802734375 Relative MSE: 1.5182888507843018 1.9139822721481323
647 MSE: 887.9116821289062 883.867431640625 Relative MSE: 1.9180631637573242 1.633152961730957
648 MSE: 827.86572265625 827.2266845703125 Relative MSE: 3.182490825653076 1.9526658058166504
649 MSE: 855.9693603515625 826.5618286132812 Relative MSE: 1.7227671146392822 1.7664783000946045
650 MSE: 846.29345703125 844.998291015625 Relative MSE: 2.6879751682281494 1.2770404815673828
651 MSE: 816.6732177734375 818.4654541015625 Relative MSE: 1.011568546295166 1.0486576557159424
652 MSE: 852.88623046875 841.9483642578125

735 MSE: 796.7550659179688 786.7081909179688 Relative MSE: 21.627351760864258 1.4996352195739746
736 MSE: 808.948974609375 814.9227294921875 Relative MSE: 1.0327073335647583 1.235966444015503
737 MSE: 815.1057739257812 823.861572265625 Relative MSE: 3.4706857204437256 15.575447082519531
738 MSE: 815.1919555664062 818.258056640625 Relative MSE: 1.739884376525879 8.722142219543457
739 MSE: 780.5440063476562 783.6204223632812 Relative MSE: 3.1433937549591064 6.109692573547363
740 MSE: 843.482421875 838.9803466796875 Relative MSE: 1.1143782138824463 2.1847028732299805
741 MSE: 818.3759765625 813.3072509765625 Relative MSE: 1.3430676460266113 4.823183059692383
742 MSE: 845.01611328125 840.5529174804688 Relative MSE: 63.66810989379883 3.3240978717803955
743 MSE: 857.0828857421875 836.35693359375 Relative MSE: 1.4874584674835205 1.673000693321228
744 MSE: 769.345947265625 781.1716918945312 Relative MSE: 2.679445505142212 1.3789321184158325
745 MSE: 869.8299560546875 863.9034423828125 Relative

834 MSE: 797.3084106445312 790.3358154296875 Relative MSE: 1.9594950675964355 2.2551748752593994
835 MSE: 868.3626708984375 867.528076171875 Relative MSE: 1.3052945137023926 1.7507038116455078
836 MSE: 793.4129638671875 795.3492431640625 Relative MSE: 1.0548620223999023 1.0788681507110596
837 MSE: 810.775146484375 816.268798828125 Relative MSE: 1.0855998992919922 1.217963457107544
838 MSE: 771.460205078125 755.4153442382812 Relative MSE: 9.387484550476074 1.4267232418060303
839 MSE: 881.80078125 887.2014770507812 Relative MSE: 1.1052887439727783 1.1025292873382568
840 MSE: 807.8568115234375 809.659912109375 Relative MSE: 1.002617597579956 1.1347713470458984
841 MSE: 845.6338500976562 849.2418212890625 Relative MSE: 1.0548174381256104 1.680189847946167
842 MSE: 807.8462524414062 798.5208740234375 Relative MSE: 2.7081360816955566 2.0110936164855957
843 MSE: 802.0678100585938 801.54150390625 Relative MSE: 1.088315725326538 1.2290349006652832
844 MSE: 822.0087890625 812.705810546875 Relati

KeyboardInterrupt: 

## Fit Random

In [39]:
torch.cuda.set_device('cuda:3')
ndim=2

model_lag=LagrangianNetwork(ndim).cuda()
model_gen=LagrangianNetwork(ndim).cuda()
model_naive = nn.Sequential(
    nn.Linear(3*ndim, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, ndim),
    nn.Tanh()
).cuda()
crit = nn.MSELoss()

train_loss_lag=[]
train_loss_naive=[]
val_loss_lag=[]
val_loss_naive=[]

optimizer_lag=torch.optim.Adam(model_lag.parameters(), lr=1e-4)
optimizer_naive=torch.optim.Adam(model_naive.parameters(), lr=1e-4)

for t in range(5):
    for i in range(dataset.num_train):
        data=dataset.sample()
        q,dq,ddq,tau_target=get_info(data,ndim)
        L_target,dLdq_target,dLdt_target,H_target,dHdq_target,dHdt_target,quad_target,tau_target=model_gen(q,dq,ddq,True)
        tau_target=tau_target.detach()
        
        # Lagrangian network
        L,dLdq,dLdt,H,dHdq,dHdt,quad,tau=model_lag(q,dq,ddq,True)
        optimizer_lag.zero_grad()
        loss_lag=crit(tau,tau_target)
        loss_lag_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
        
        
        # Different targets
#         loss_lag=crit(H, H_target)
#         loss_lag=crit(quad, quad_target)
#         loss_lag=crit(dHdt, dHdt_target)
#         loss_lag=crit((dHdt@dq.unsqueeze(2)).squeeze(), (dHdt_target@dq.unsqueeze(2)).squeeze())
#         loss_lag=crit((H@ddq.unsqueeze(2)).squeeze(), (H_target@ddq.unsqueeze(2)).squeeze())
#         loss_lag.backward(retain_graph=True)
        loss_lag.backward()
        optimizer_lag.step()
        
        # Naive network
        tau=model_naive(torch.cat([q,dq,ddq], axis=-1))
        optimizer_naive.zero_grad()
        loss_naive=crit(tau,tau_target)
        loss_naive_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
        loss_naive.backward()
        optimizer_naive.step()
        
        print('\n',i, '\nMSE:',loss_lag.data.item(), loss_naive.data.item(), '\nRelative MSE:', loss_lag_rel.data.item(), loss_naive_rel.data.item(),)
        
        if i%1000==0:
            vdata=dataset.sample('valid')
            q,dq,ddq,tau_target=get_info(vdata,ndim)
            L_target,dLdq_target,dLdt_target,H_target,dHdq_target,dHdt_target,quad_target,tau_target=model_gen(q,dq,ddq,True)

            # Lagrangian network
            L,dLdq,dLdt,H,dHdq,dHdt,quad,tau=model_lag(q,dq,ddq,True)
            
            loss_lag=crit(tau,tau_target)
            loss_lag_rel=torch.mean(((tau-tau_target)/tau_target)**2)
#             train_loss_lag.append(loss_lag.data.item())
#             val_loss_lag.append(loss_val_lag.item())
            
#             # Naive network
            tau=model_naive(torch.cat([q,dq,ddq], axis=1))
            loss_naive=crit(tau,tau_target)
            loss_naive_rel=torch.mean((tau-tau_target)**2/(tau_target)**2)
#             train_loss_naive.append(loss_naive.data.item())
#             val_loss_naive.append(loss_val_naive.item())   
            
            print('VAL',i, 'MSE:',loss_lag.data.item(), loss_naive.data.item(), 'Relative MSE:', loss_lag_rel.data.item(), loss_naive_rel.data.item(),)



 0 
MSE: 228.6424102783203 30.68705940246582 
Relative MSE: 249986.734375 204.29380798339844
VAL 0 MSE: 115.89969635009766 37.62029266357422 Relative MSE: 1768.573974609375 1197.067138671875

 1 
MSE: 92.08088684082031 28.111549377441406 
Relative MSE: 3563.269287109375 1396.2322998046875

 2 
MSE: 54.26406478881836 24.892688751220703 
Relative MSE: 479526.375 1848.933837890625

 3 
MSE: 39.247169494628906 28.125152587890625 
Relative MSE: 58.299537658691406 826.893798828125

 4 
MSE: 24.240196228027344 26.166358947753906 
Relative MSE: 206.3905487060547 1196.3570556640625

 5 
MSE: 26.359588623046875 25.556167602539062 
Relative MSE: 130.83169555664062 107.47006225585938

 6 
MSE: 22.36331558227539 27.27581214904785 
Relative MSE: 11.693815231323242 147.8741912841797

 7 
MSE: 22.307525634765625 26.967527389526367 
Relative MSE: 57.354461669921875 2007.476318359375

 8 
MSE: 22.01972198486328 25.449037551879883 
Relative MSE: 38.0400390625 5469.47265625

 9 
MSE: 23.423046112060547 2

 85 
MSE: 4.581513404846191 27.550418853759766 
Relative MSE: 4.274971008300781 393.8302001953125

 86 
MSE: 4.034107208251953 23.937938690185547 
Relative MSE: 855.560791015625 10452.8173828125

 87 
MSE: 3.688624382019043 25.924448013305664 
Relative MSE: 47.36824035644531 876.8925170898438

 88 
MSE: 4.125500679016113 28.96678924560547 
Relative MSE: 32.97437286376953 206.92787170410156

 89 
MSE: 2.9921646118164062 16.174158096313477 
Relative MSE: 83.15711212158203 2267.47314453125

 90 
MSE: 4.0791473388671875 25.309444427490234 
Relative MSE: 13.63748836517334 184.15576171875

 91 
MSE: 3.136442184448242 23.565940856933594 
Relative MSE: 19.655200958251953 256.9980773925781

 92 
MSE: 5.290219306945801 29.92436981201172 
Relative MSE: 476.6665954589844 212.15904235839844

 93 
MSE: 3.0398683547973633 20.769939422607422 
Relative MSE: 10143.546875 7486.5439453125

 94 
MSE: 3.523197650909424 21.26596450805664 
Relative MSE: 99.29069519042969 111.94456481933594

 95 
MSE: 3.506632


 170 
MSE: 3.2087795734405518 18.768238067626953 
Relative MSE: 67.6711196899414 752.4625244140625

 171 
MSE: 3.1104071140289307 22.377452850341797 
Relative MSE: 7.898682594299316 175.76962280273438

 172 
MSE: 2.671372413635254 26.154590606689453 
Relative MSE: 2374.29931640625 25560.240234375

 173 
MSE: 2.8646860122680664 25.154102325439453 
Relative MSE: 120.57400512695312 3794.251953125

 174 
MSE: 2.6984355449676514 30.665653228759766 
Relative MSE: 92.46875 25402.625

 175 
MSE: 2.5353195667266846 22.877620697021484 
Relative MSE: 49.2346305847168 362.0782775878906

 176 
MSE: 2.550034284591675 29.42755699157715 
Relative MSE: 16149.0166015625 385778.78125

 177 
MSE: 2.6674342155456543 25.78644561767578 
Relative MSE: 24.104511260986328 634.8907470703125

 178 
MSE: 3.0183300971984863 25.55221176147461 
Relative MSE: 48.75276565551758 941.5240478515625

 179 
MSE: 2.68931245803833 25.13388442993164 
Relative MSE: 718.9691772460938 2878.91357421875

 180 
MSE: 2.2233586311340

KeyboardInterrupt: 

In [35]:
L_target,dLdq_target,dLdt_target,H_target,dHdq_target,dHdt_target,quad_target,tau_target=model_gen(q,dq,ddq,True)
L,dLdq,dLdt,H,dHdq,dHdt,quad,tau=model_lag(q,dq,ddq,True)
i=5
print(tau_target[i])
print(tau[i])
print(H[i])
print(H_target[i])
print(dHdt[i])
print(dHdt_target[i])
print(L[i])
print(L_target[i])
print(quad[i])
print(quad_target[i])
print((H@ddq.unsqueeze(2)).squeeze()[i], (H_target@ddq.unsqueeze(2)).squeeze()[i])

tensor([-0.3473, -7.1716], device='cuda:3', grad_fn=<SelectBackward>)
tensor([-0.1597, -6.8257], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[0.0000, 0.0000],
        [0.0000, 0.0025]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[0.0000, 0.0000],
        [0.0000, 0.0025]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[ 0.0000,  0.0000],
        [ 0.0000, -0.0558]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[ 0.0000,  0.0000],
        [ 0.0000, -0.0567]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[ 0.0000,  0.0000],
        [-0.0502,  0.0000]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([[0.0000, 0.0000],
        [0.0345, 0.0368]], device='cuda:3', grad_fn=<SelectBackward>)
tensor([0.2176, 4.4811], device='cuda:3', grad_fn=<SelectBackward>)
tensor([0.7592, 5.2025], device='cuda:3', grad_fn=<SelectBackward>)
tensor([ 0.0000, -8.7261], device='cuda:3', grad_fn=<SelectBackward>) tensor([ 0.0000, -8.8383], device='cuda:3', grad_fn=<SelectBackwar

In [33]:
model_lag(q,dq,ddq)[1]

tensor([0.3969, 1.0341], device='cuda:3', grad_fn=<SelectBackward>)

In [21]:
tau[0]

tensor([-1.0000, -0.9989], device='cuda:3', grad_fn=<SelectBackward>)

In [22]:
tau_target[0]

tensor([ 0.0307, -0.1165], device='cuda:3', grad_fn=<SelectBackward>)