In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
from tqdm.notebook import tqdm, trange

"""Change to the data folder"""
new_path = "./new_train/new_train"

cuda_status = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [3]:
batch_sz = 64

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    out = torch.LongTensor(out)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, out, scene_ids, track_ids, agent_ids]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, scene_ids, track_ids, agent_ids]

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [22]:
class CNNRNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(4, 64, (3,3))
        self.relu = torch.nn.ReLU()
        self.maxPool = torch.nn.MaxPool2d((3,3), stride=2)
        self.conv2 = torch.nn.Conv2d(64, 128, (3,3))
        self.linear = torch.nn.Linear(4800, 7200)
        self.rnn = torch.nn.RNN(128, 200, nonlinearity='relu', batch_first=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.maxPool(x)
        x = self.relu(self.conv2(x))
        x = self.maxPool(x)
        x = torch.flatten(x, start_dim = 2)
        x = x.permute(0, 2, 1)
        x, _ = self.rnn(x)
        x = torch.flatten(x, start_dim = 1)
        x = self.linear(x)
        x = torch.reshape(x, torch.Size([64, 60, 30, 4]))
        return x

model = CNNRNN()
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
model = torch.load('./models/1epochAdam.pt')
model.train()
model.to(device)
if cuda_status:
    model = model.cuda()

### Visualize the batch of sequences

In [None]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm

agent_id = 0
epoch = 3
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
#learning_rate =0.01
#optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
iterator = tqdm(val_loader, total=int(len(val_loader)))

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(iterator):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        x = inp.permute(0,3,1,2).float()
        x = x.float()
        y = out.float()

        if cuda_status:
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = None
        # Forward pass: predict y by passing x to the model.    
        y_pred = model(x)
        #y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))

        # Compute the loss.
        loss = loss_fn(y_pred, y)

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()
        
        print(torch.sqrt(loss).item())
        iterator.set_postfix(loss=(torch.sqrt(loss).item()))

        if i_batch == 3216:
            break


HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

594.7294311523438
575.9655151367188
581.56201171875
482.83984375
418.925537109375
519.5551147460938
329.73663330078125
374.31402587890625
415.0807189941406
376.6632995605469
401.1958923339844
352.32855224609375
348.4886779785156
336.4626770019531
357.71197509765625
339.9645690917969
306.7245178222656
301.5850830078125
280.14373779296875
280.5069580078125
266.3636779785156
266.8194885253906
260.28204345703125
243.26776123046875
209.39060974121094
219.85113525390625
256.3262023925781
200.3320770263672
214.6450653076172
205.1403045654297
196.83396911621094
202.61184692382812
191.03883361816406
178.88827514648438
194.75416564941406
171.91204833984375
166.2494354248047
181.68203735351562
154.53826904296875
170.274169921875
174.53929138183594
158.52012634277344
155.66146850585938
163.94752502441406
159.83648681640625
166.2750701904297
151.91387939453125
138.37469482421875
150.5624237060547
140.54583740234375
157.25010681152344
138.6998291015625
137.51300048828125
138.6575927734375
137.505233

15.243630409240723
15.333113670349121
20.610031127929688
18.831235885620117
13.691628456115723
17.77817726135254
16.292255401611328
23.51111602783203
12.916882514953613
17.00955581665039
12.816092491149902
12.547394752502441
14.207308769226074
12.190974235534668
16.0218448638916
12.89465045928955
15.81289291381836
19.49411964416504
10.887709617614746
12.048445701599121
11.0712890625
10.533095359802246
11.087474822998047
10.899065017700195
12.334254264831543
25.594942092895508
14.69421100616455
14.360763549804688
11.210768699645996
16.19485855102539
27.864587783813477
16.192527770996094
13.77562427520752
10.493185043334961
11.377900123596191
17.915088653564453
12.456106185913086
10.578797340393066
11.30079174041748
10.595511436462402
12.543694496154785
10.459762573242188
9.982020378112793
11.255632400512695
10.48917293548584
10.498064041137695
10.700643539428711
10.159385681152344
29.733747482299805
16.332408905029297
14.17668628692627
11.738049507141113
10.852721214294434
10.9926462173

30.119569778442383
17.40519905090332
21.285478591918945
13.61147403717041
14.510533332824707
11.326326370239258
19.962223052978516
12.142400741577148
11.969976425170898
10.609624862670898
11.773473739624023
13.874035835266113
10.961104393005371
10.690508842468262
11.137889862060547
11.451205253601074
18.08829689025879
10.053861618041992
10.187716484069824
17.155715942382812
15.188077926635742
11.773284912109375
11.695616722106934
9.928312301635742
9.727876663208008
12.924023628234863
10.290773391723633
10.407732009887695
11.734395980834961
12.68596363067627
7.7805280685424805
12.153243064880371
9.814044952392578
25.07418441772461
67.83936309814453
15.923508644104004
14.570269584655762
15.573183059692383
15.39831829071045
23.177608489990234
15.264952659606934
14.586956977844238
16.672988891601562
44.95372009277344
20.318172454833984
21.774574279785156
20.84856605529785
29.22330665588379
21.465559005737305
19.678573608398438
19.953805923461914
20.117582321166992
16.929786682128906
18.362

21.958927154541016
22.680824279785156
24.704435348510742
29.27224349975586
22.273723602294922
22.482120513916016
38.14503479003906
26.289291381835938
19.458160400390625
25.36684226989746
19.934436798095703
20.347158432006836
17.458541870117188
20.98039436340332
17.99822425842285
17.301298141479492
27.5249080657959
22.40322494506836
18.32665252685547
82.25408172607422
36.44700241088867
41.21138381958008
38.39756393432617
30.86717414855957
31.012954711914062
32.28047180175781
34.40625762939453
22.2817440032959
31.638824462890625
21.354293823242188
25.589298248291016
22.03679847717285
22.236711502075195
23.8471736907959
18.68403434753418
43.09326171875
24.83765983581543
28.41200828552246
21.974946975708008
18.451923370361328
68.6041259765625
26.053064346313477
25.776966094970703
31.863327026367188
21.46206283569336
26.546173095703125
41.74263381958008
23.16876792907715
41.64002227783203
20.12685203552246
24.422273635864258
19.068811416625977
23.03491973876953
17.96945571899414
21.51858139

7.660918235778809
8.595407485961914
9.674986839294434
8.836640357971191
12.31118392944336
7.960610389709473
8.142959594726562
9.05069351196289
8.487266540527344
8.92641544342041
8.059305191040039
9.33009147644043
9.51667308807373
9.9293212890625
11.443182945251465
7.117691516876221
7.877203464508057
9.833468437194824
19.75649642944336
8.744176864624023
8.245122909545898
7.461423873901367
9.08963394165039
7.842501640319824
8.944602012634277
7.427175045013428
8.011818885803223
8.672486305236816
15.435098648071289
8.9089937210083
8.51931381225586
8.322671890258789
7.933389663696289
9.07852554321289
15.278083801269531
8.998819351196289
8.250596046447754
9.895263671875
10.996227264404297
7.593086242675781
10.736411094665527
8.39509391784668
9.187813758850098
8.14037799835205
8.92597484588623
11.394649505615234
8.92151927947998
8.438835144042969
9.404266357421875
9.987244606018066
8.21226692199707
7.761626243591309
7.7328643798828125
10.416475296020508
7.524131774902344
8.968767166137695
8.2

17.735275268554688
11.46117877960205
13.356889724731445
12.617570877075195
13.166984558105469
17.986722946166992
11.764846801757812
17.662824630737305
14.69351863861084
12.21744155883789
12.262085914611816
12.005695343017578
12.605237007141113
10.793499946594238
12.066510200500488
11.989123344421387
17.367124557495117
10.368706703186035
12.495145797729492
13.006308555603027
11.912640571594238
12.171361923217773
11.857437133789062
12.073734283447266
12.390270233154297
10.601548194885254
11.823184967041016
18.473957061767578
11.324109077453613
12.57424259185791
9.55239200592041
13.181439399719238
10.21240234375
10.590703964233398
11.127732276916504
11.579955101013184
10.300776481628418
14.027910232543945
13.033432006835938
14.48148250579834
12.613571166992188
12.803606033325195
11.667356491088867
13.168434143066406
15.869915008544922
9.84764575958252
13.154995918273926
9.86141300201416
13.831295013427734
9.29454517364502
13.121302604675293
11.12038516998291
12.777693748474121
11.31748485

8.582626342773438
7.768979072570801
9.077269554138184
8.08152961730957
8.456355094909668
10.308656692504883
8.278218269348145
10.014095306396484
8.675564765930176
9.296465873718262
9.086162567138672
8.040620803833008
8.587908744812012
47.159629821777344
25.647274017333984
15.286344528198242
17.357786178588867
16.03846549987793
13.454063415527344
13.923691749572754
15.564959526062012
18.534997940063477
12.460803985595703
21.548429489135742
13.022139549255371
11.056380271911621
10.883183479309082
13.785370826721191
14.101127624511719
12.426668167114258
18.141386032104492
11.673300743103027
13.953142166137695
13.294540405273438
11.9475679397583
10.807204246520996
13.415942192077637
14.742558479309082
10.130636215209961
13.328744888305664
12.429306030273438
10.571825981140137
10.682947158813477
10.76710319519043
9.759588241577148
12.495565414428711
8.490911483764648
9.784432411193848
10.129024505615234
10.542411804199219
8.797638893127441
9.780421257019043
10.531987190246582
8.516985893249

7.460062503814697
7.458831310272217
7.618724822998047
7.334355354309082
7.094327449798584
7.9253973960876465
7.119197368621826
6.912120342254639
7.904712200164795
7.025936126708984
8.070719718933105
7.525928497314453
7.252161026000977
7.617103576660156
7.327266216278076
7.746463298797607
44.376747131347656
10.93591594696045
29.605207443237305
12.792924880981445
18.33574676513672
12.00269889831543
11.845202445983887
15.266902923583984
12.651336669921875
11.050979614257812
13.587920188903809
11.919463157653809
12.652820587158203
12.272672653198242
9.445661544799805
11.958849906921387
12.010958671569824
10.292852401733398
10.394519805908203
95.36680603027344
19.272247314453125
28.04538917541504
25.88650894165039
30.839107513427734
36.84658432006836
20.157733917236328
23.474817276000977
23.69671630859375
27.14269256591797
21.300745010375977
20.326322555541992
21.735546112060547
23.030677795410156
22.641298294067383
28.935548782348633
43.33173370361328
23.53460693359375
32.30253982543945
22

13.912593841552734
11.617936134338379
19.82314682006836
13.29040241241455
11.911526679992676
10.448795318603516
15.085775375366211
11.271620750427246
27.42650604248047
14.743020057678223
11.488797187805176
11.439901351928711
11.08912181854248
13.092913627624512
13.276636123657227
13.698864936828613
11.817030906677246
19.261930465698242
13.902609825134277
10.273612022399902
12.003880500793457
16.892810821533203
10.045317649841309
11.677440643310547
10.919276237487793
10.670299530029297
11.809388160705566
8.904473304748535
12.901647567749023
10.743103981018066
12.96519947052002
11.580881118774414
9.762470245361328
9.664901733398438
10.507121086120605
10.195351600646973
9.427093505859375
12.76950454711914
11.971635818481445
10.333524703979492
9.197564125061035
11.099163055419922
40.718963623046875
14.591039657592773
16.153676986694336
16.121849060058594
14.35827350616455
13.964056015014648
15.385919570922852
24.842060089111328
14.915217399597168
13.560962677001953
27.753217697143555
15.09

28.160999298095703
22.516834259033203
23.568828582763672
23.782032012939453
24.965885162353516
15.329285621643066
25.41643524169922
19.357295989990234
23.092744827270508
15.745077133178711
20.489299774169922
18.467960357666016
25.031774520874023
17.334897994995117
16.414331436157227
15.445107460021973
25.832584381103516
13.149041175842285
24.87025260925293
17.638370513916016
17.218538284301758
17.014001846313477
17.18682289123535
15.582051277160645
12.293243408203125
19.97165870666504
14.021663665771484
14.851066589355469
15.416180610656738
21.821876525878906
16.16913604736328
15.900100708007812
13.517029762268066
14.34799861907959
13.934881210327148
13.269356727600098
10.96173095703125
10.494911193847656
12.013872146606445
13.753880500793457
15.136186599731445
14.0999755859375
10.206740379333496
12.55682373046875
10.048526763916016
12.433149337768555
10.312613487243652
9.115998268127441
10.866625785827637
9.394083023071289
11.294976234436035
9.842961311340332
9.536972045898438
8.70211

8.910845756530762
9.042332649230957
9.809041976928711
10.046895980834961
8.337801933288574
9.542539596557617
8.514914512634277
9.075095176696777
10.221796035766602
9.925095558166504
8.768895149230957
8.587600708007812
8.196124076843262
8.064218521118164
8.746851921081543
9.142465591430664
10.565985679626465
10.337421417236328
7.988521575927734
9.308000564575195
8.463133811950684
9.093055725097656
7.917481422424316
7.3170247077941895
8.884920120239258
7.218078136444092
8.581652641296387
7.73235559463501
8.418998718261719
8.12998104095459
7.8432464599609375
9.154781341552734
7.492988109588623
8.372509956359863
13.188016891479492
9.285905838012695
7.420551776885986
8.466790199279785
8.700949668884277
8.726256370544434
9.216141700744629
8.984498977661133
32.82200622558594
11.263876914978027
9.020772933959961
13.699190139770508
11.718893051147461
12.41557502746582
11.025496482849121
10.509464263916016
9.492877960205078
10.659645080566406
10.254228591918945
9.872990608215332
8.79545879364013

15.414608001708984
12.995285034179688
12.075014114379883
10.772849082946777
12.011895179748535
9.494464874267578
11.896889686584473
9.67486572265625
9.554027557373047
12.704066276550293
9.96422004699707
11.236322402954102
8.56611442565918
9.070515632629395
9.93618106842041
8.655802726745605
6.965576171875
9.10813045501709
8.196104049682617
9.495698928833008
8.04482650756836
11.263607025146484
10.1419095993042
8.504281997680664
9.867961883544922
8.55693244934082
8.276273727416992
8.501702308654785
7.804042339324951
7.7006707191467285
7.6263275146484375
8.821186065673828
8.712732315063477
7.264102935791016
9.958571434020996
7.716718673706055
9.710930824279785
11.64848518371582
9.0738525390625
7.700290679931641
9.227025032043457
10.092373847961426
7.869169235229492
9.199626922607422
7.691209316253662
7.943917274475098
9.62044620513916
8.249975204467773
9.654722213745117
8.141192436218262
7.086182594299316
8.546838760375977
7.862727642059326
7.771894454956055
7.446408271789551
7.6118402481

6.48036527633667
6.791483402252197
7.190149307250977
7.251340389251709
35.116233825683594
11.528045654296875
18.192995071411133
15.635436058044434
17.728620529174805
14.943014144897461
11.501189231872559
15.719639778137207
12.26141357421875
12.710999488830566
13.76074504852295
14.476768493652344
13.155263900756836
14.737285614013672
12.225241661071777
12.57503604888916
11.776371002197266
12.368950843811035
13.526071548461914
11.942717552185059
13.804927825927734
11.980624198913574
11.115288734436035
11.93875503540039
9.831661224365234
10.601862907409668
10.98769760131836
10.111457824707031
9.890198707580566
10.63974666595459
10.688438415527344
9.903958320617676
9.264299392700195
10.855118751525879
10.116421699523926
10.575804710388184
10.047157287597656
27.093957901000977
16.626808166503906
10.709076881408691
13.503253936767578
10.838305473327637
15.293994903564453
10.662237167358398
11.791962623596191
10.796707153320312
14.04698371887207
11.17740249633789
10.59008502960205
12.01918506

8.837276458740234
8.841404914855957
15.404407501220703
9.604965209960938
7.98579216003418
10.84703254699707
8.656888008117676
9.152779579162598
12.497796058654785
10.172080993652344
9.312104225158691
12.517719268798828
9.7360258102417
11.286155700683594
11.018900871276855
9.096280097961426
10.633426666259766
9.0617094039917
9.83651351928711
7.443286895751953
9.157661437988281
9.150335311889648
8.590048789978027
9.114670753479004
8.772379875183105
8.317770004272461
8.234222412109375
8.40880012512207
7.8775153160095215
8.408560752868652
8.255929946899414
9.04178524017334
9.446549415588379
9.10487174987793
8.058707237243652
7.498234272003174
7.830175399780273
8.372512817382812
10.164206504821777
8.394947052001953
8.474780082702637
8.553683280944824
7.306196689605713
7.097385406494141
8.607359886169434
8.578483581542969
7.927817344665527
11.71138858795166
9.340668678283691
8.86469841003418
8.390554428100586
7.561664581298828
7.630246162414551
9.469332695007324
8.268446922302246
11.33404541

12.476045608520508
15.550742149353027
9.536615371704102
13.154569625854492
16.616737365722656
11.416186332702637
12.842191696166992
13.589820861816406
13.445051193237305
12.681970596313477
9.110712051391602
9.346240043640137
10.261639595031738
9.980412483215332
10.311429977416992
10.182414054870605
10.703571319580078
11.14535140991211
8.687989234924316
9.702679634094238
8.806044578552246
8.580419540405273
10.340639114379883
9.346427917480469
8.830355644226074
9.275806427001953
9.119366645812988
11.855806350708008
10.387596130371094
9.660174369812012
8.819836616516113
8.833429336547852
8.851300239562988
10.160737991333008
9.835186004638672
7.998276233673096
9.961441040039062
9.30618667602539
10.205910682678223
8.708843231201172
8.232224464416504
10.145345687866211
8.555548667907715
7.458498477935791
7.863020896911621
8.08400821685791
7.596018314361572
7.456208229064941
9.736612319946289
8.647027969360352
9.16195297241211
8.746532440185547
8.216816902160645
8.225833892822266
9.0167493820

10.324644088745117
10.692886352539062
10.842305183410645
10.489540100097656
10.5965576171875
10.587282180786133
8.588488578796387
10.115431785583496
10.940690994262695
17.833728790283203
11.902771949768066
9.622040748596191
10.179718017578125
9.852415084838867
8.871844291687012
9.634642601013184
10.394725799560547
10.127933502197266
10.407044410705566
8.614588737487793
9.932208061218262
9.18262004852295
8.335481643676758
16.379650115966797
8.975564002990723
15.900111198425293
10.090209007263184
12.917316436767578
9.176913261413574
10.3281888961792
11.009967803955078
8.904423713684082
8.92564582824707
9.894001007080078
33.37468338012695
26.565156936645508
14.288630485534668
11.461930274963379
12.815619468688965
18.754732131958008
11.779644966125488
13.271387100219727
16.04144859313965
11.593571662902832
13.978689193725586
26.19463348388672
13.407222747802734
14.66883659362793
17.191425323486328
11.437093734741211
16.717199325561523
11.534978866577148
13.627811431884766
11.65138244628906

7.498052597045898
7.801900386810303
7.02216911315918
7.292203426361084
12.241825103759766
7.844273090362549
7.76335334777832
9.211490631103516
7.4706878662109375
7.1710357666015625
8.18510913848877
8.969977378845215
8.62057113647461
6.940980434417725
7.80289363861084
8.17719841003418
7.743444919586182
7.528343677520752
6.939727306365967
8.977143287658691
6.724529266357422
7.500753879547119
6.641271114349365
6.98681640625
6.783323287963867
6.838065147399902
7.011281490325928
7.1987481117248535
7.840015888214111
7.1858601570129395
7.875670433044434
8.04955768585205
7.693275451660156
7.137086391448975
7.4967145919799805
8.60524845123291
7.397894382476807
9.489933967590332
8.057745933532715
6.743692874908447
6.999166488647461
8.41417121887207
8.26951789855957
7.135112762451172
7.059249401092529
7.481542587280273
6.567617893218994
8.672398567199707
8.045868873596191
8.090086936950684
8.055587768554688
7.738253593444824
10.146122932434082
12.206879615783691
10.55795955657959
11.0118513107299

9.22711181640625
9.277081489562988
7.460833549499512
9.476103782653809
9.553438186645508
9.752928733825684
8.524442672729492
9.996675491333008
10.748296737670898
9.438680648803711
13.533977508544922
10.53748893737793
8.680867195129395
11.257095336914062
9.398992538452148
10.764253616333008
8.100381851196289
11.296296119689941
10.043725967407227
9.02004337310791
10.608321189880371
9.579338073730469
8.522499084472656
8.467400550842285
7.991243839263916
9.155688285827637
9.614164352416992
8.48126220703125
9.107695579528809
48.325340270996094
11.159247398376465
15.108429908752441
11.715049743652344
12.366527557373047
10.201427459716797
10.44519329071045
14.90476131439209
19.370376586914062
11.512521743774414
11.105737686157227
9.780184745788574
10.641773223876953
10.101161003112793
37.55167770385742
21.147567749023438
18.53959083557129
17.319190979003906
16.353591918945312
13.084301948547363
13.994503021240234
14.7406644821167
15.484294891357422
12.441043853759766
15.416182518005371
11.867

9.506477355957031
9.649038314819336
11.69465446472168
8.81785774230957
9.581585884094238
8.282940864562988
9.760848999023438
10.873865127563477
10.187711715698242
9.185844421386719
8.072103500366211
11.368229866027832
11.205147743225098
8.806939125061035
9.831287384033203
9.700753211975098
7.741852760314941
9.30950927734375
9.709616661071777
15.149245262145996
8.490006446838379
15.18242073059082
17.709407806396484
11.092486381530762
15.128565788269043
15.567118644714355
10.89529037475586
17.37526512145996
14.957648277282715
17.223373413085938
19.710628509521484
14.317806243896484
14.228703498840332
15.503886222839355
9.978690147399902
16.09454345703125
10.252867698669434
13.996560096740723
12.519330024719238
13.223160743713379
17.706941604614258
14.582226753234863
13.324845314025879
11.492321014404297
12.011975288391113
11.1897611618042
12.036266326904297
11.931264877319336
10.528386116027832
15.16398811340332
9.454675674438477
13.296274185180664
10.762520790100098
11.184191703796387
1

7.374809741973877
6.755706787109375
7.081999778747559
7.443924427032471
7.208517074584961
7.881009578704834
6.933271884918213
6.698119640350342
6.704709053039551
6.8552961349487305
7.2039337158203125
6.979747295379639
7.599488735198975
7.170906066894531
8.24282169342041
6.257002830505371
7.06118631362915
6.763549327850342
7.123348712921143
7.498103141784668
7.0487823486328125
6.679437160491943
6.650183200836182
6.984043121337891
6.719775676727295
6.801955223083496
6.985523700714111
6.7932891845703125
6.844898700714111
6.30672025680542
6.578513145446777
6.515974521636963
6.479632377624512
6.501984119415283
6.696750640869141
7.763308048248291
7.422365188598633
5.94155740737915
7.555491924285889
7.070761203765869
7.48053503036499
7.3703293800354
6.986769676208496
6.836428642272949
6.541720867156982
6.4423370361328125
6.939096450805664
7.069042205810547
6.412127494812012
6.600634574890137
7.249037742614746
6.755077362060547
7.508634567260742
6.586719989776611
6.734256744384766
6.7768926620

9.864013671875
7.484788417816162
7.539833068847656
9.825994491577148
8.244465827941895
8.918496131896973
7.6591105461120605
9.524226188659668
17.768095016479492
9.994976997375488
9.493535041809082
9.221126556396484
8.994174003601074
13.127188682556152
7.499195575714111
9.071470260620117
9.63753890991211
8.939693450927734
9.682388305664062
9.594602584838867
9.268794059753418
8.830257415771484
15.202508926391602
9.07169246673584
8.843047142028809
11.237918853759766
9.334935188293457
9.099499702453613
9.06544017791748
13.203266143798828
10.198596000671387
9.688087463378906
10.2483491897583
13.988946914672852
11.180054664611816
9.363032341003418
8.023948669433594
9.51634407043457
9.313786506652832
9.499195098876953
8.86860466003418
9.27186107635498
8.865400314331055
9.64098834991455
8.583378791809082
9.146130561828613
8.868063926696777
10.293728828430176
9.271628379821777
8.720540046691895
8.259936332702637
8.116036415100098
9.28732967376709
9.17952823638916
9.074334144592285
7.81438398361

8.489103317260742
9.522379875183105
10.142930030822754
8.308022499084473
10.171451568603516
10.28135871887207
11.493746757507324
8.13460922241211
10.005019187927246
16.785865783691406
11.175625801086426
14.035133361816406
9.97645092010498
9.761173248291016
10.771228790283203
8.516413688659668
10.647113800048828
8.357850074768066
9.599406242370605
15.463603019714355
9.719164848327637
10.092331886291504
10.322280883789062
9.520381927490234
9.236308097839355
8.227405548095703
8.39622688293457
8.550291061401367
8.44284725189209
9.54448127746582
8.881814956665039
9.300704956054688
9.182685852050781
9.499702453613281
9.336150169372559
8.932205200195312
12.708673477172852
10.366949081420898
8.135384559631348
10.529436111450195
9.151999473571777
9.690240859985352
9.842415809631348
9.392627716064453
10.953994750976562
11.539877891540527
13.229338645935059
9.99111270904541
9.557202339172363
12.26036262512207
9.773870468139648
13.431772232055664
11.571300506591797
10.039437294006348
15.4310693740

In [None]:
torch.save(model, './models/3epoch-CNN-RNN.pt')

In [None]:
model = torch.load('./models/6epochmodel.pt')
model.eval()
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
import pandas as pd

# Submission output
writeCSV = True
val_path = "./new_val_in/new_val_in"

if writeCSV:
    
    dataset = ArgoverseDataset(data_path=val_path)
    test_loader = DataLoader(dataset,batch_size=batch_sz, shuffle = False, collate_fn=test_collate, num_workers=0)
    
    data = []
    
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(tqdm(test_loader)):
            inp, scene_ids, track_ids, agent_ids = sample_batch

            if cuda_status:
                model = model.cuda()
                x = inp.permute(0,3,1,2).float().cuda()
            else:
                x = inp.permute(0,3,1,2).float()

            y_pred = None

            # Forward pass: predict y by passing x to the model.    
            y_pred = model(x)
            y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))
            
            for i in range(batch_sz):
                row = []
                row.append(scene_ids[i].item())
                curr = y_pred[i]
                
                agent_id = agent_ids[i]
                
                for j in range(30):
                    vehicle_index = 0
                    found = False
                    while not found:
                        if track_ids[i][vehicle_index][j][0] == agent_id:
                            found = True
                        else:
                            vehicle_index += 1

                    row.append(str(curr[vehicle_index][j][0].item()))
                    row.append(str(curr[vehicle_index][j][1].item()))
                    
                data.append(row)

    df = pd.DataFrame(data, columns = ['ID','v1','v2','v3','v4','v5','v6','v7','v8','v9','v10','v11','v12','v13','v14','v15','v16','v17','v18','v19','v20','v21','v22','v23','v24','v25','v26','v27','v28','v29','v30','v31','v32','v33','v34','v35','v36','v37','v38','v39','v40','v41','v42','v43','v44','v45','v46','v47','v48','v49','v50','v51','v52','v53','v54','v55','v56','v57','v58','v59','v60'])
    print(df)
    df.to_csv('submission.csv', index=False)
                
                
                