In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
from tqdm.notebook import tqdm, trange

"""Change to the data folder"""
new_path = "./new_train/new_train"

cuda_status = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [3]:
batch_sz = 64

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    out = torch.LongTensor(out)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, out, scene_ids, track_ids, agent_ids]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, scene_ids, track_ids, agent_ids]

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(76, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [4]:
model = torch.nn.Sequential(
    torch.nn.Linear(76, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
model = torch.nn.Sequential(
    torch.nn.Conv2d(4, 8, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    #torch.nn.Conv2d(8, 24, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    #torch.nn.Linear(24 * 5 * 5, 512),
    #torch.nn.ReLU(),
    #torch.nn.Linear(512, 256),
    #torch.nn.ReLU(),
    #torch.nn.Linear(256, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
class LSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(76, hidden_size=98, batch_first=True)
        self.linear = torch.nn.Linear(19*98, 7200)
        self.linear2 = torch.nn.Linear(5880, 6500)
        self.relu = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(6500, 6500)
        self.linear4 = torch.nn.Linear(6500, 7200)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = torch.flatten(out, start_dim=1)
        #predict = self.linear(out)
        #predict = torch.reshape(predict, torch.Size([64, 60, 30, 4]))
        out = self.relu(self.linear2(out))
        out = self.relu(self.linear3(out))
        out = self.relu(self.linear4(out))
        return out

model = LSTM()
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
model = torch.load('./models/1epochAdam.pt')
model.train()
model.to(device)
if cuda_status:
    model = model.cuda()

### Visualize the batch of sequences

In [5]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm, trange

agent_id = 0
epoch = 3

def show_sample_batch(sample_batch, agent_id):
    """visualize the trajectory for a batch of samples with a randon agent"""
    inp, out, scene_ids, track_ids, agent_ids = sample_batch
    batch_sz = inp.size(0)
    agent_sz = inp.size(1)
    
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()   
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(tqdm(val_loader)):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        x = torch.flatten(inp, start_dim=2)

        x = x.float()
        y = out.float()

        if cuda_status:
            #model = model.cuda()
            #x = inp.cuda()
            #y = out.cuda()
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = None

        # Forward pass: predict y by passing x to the model.    
        y_pred = model(x)
        y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))

        # Compute the loss.
        loss = loss_fn(y_pred, y)
        print(torch.sqrt(loss).item())

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()

        if i_batch == 3216:
            #show_sample_batch(sample_batch, agent_id)
            #show_sample_batch([inp, y_pred.cpu().detach(), scene_ids, track_ids, agent_ids], agent_id)
            break


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

591.3389282226562
510.1405029296875
364.1410217285156
380.54803466796875
339.64373779296875
263.7250671386719
207.22512817382812
241.1396942138672
235.02902221679688
194.53053283691406
187.10562133789062
138.29075622558594
128.87527465820312
100.16696166992188
78.4795150756836
70.75957489013672
59.284732818603516
51.26582336425781
44.05048751831055
36.15816879272461
37.134830474853516
36.5481071472168
45.3492317199707
60.45429229736328
58.32893371582031
46.97327423095703
34.26049041748047
30.284862518310547
24.438486099243164
22.49806022644043
19.69540023803711
17.0848388671875
21.54589080810547
15.746970176696777
12.987625122070312
15.638594627380371
13.63847827911377
13.88553237915039
12.062273979187012
11.95132064819336
9.97717571258545
11.85484790802002
11.121954917907715
14.438239097595215
17.14938735961914
27.699996948242188
46.73644256591797
65.38426971435547
64.41777801513672
67.64765167236328
69.47669219970703
44.88310241699219
37.89505386352539
28.550357818603516
30.567119598

40.277278900146484
27.749359130859375
30.800582885742188
21.518003463745117
27.037233352661133
26.401626586914062
28.056642532348633
33.45091247558594
36.870933532714844
24.609128952026367
19.214614868164062
20.347108840942383
23.356718063354492
27.54403305053711
34.26687240600586
31.09455680847168
30.72806739807129
30.829275131225586
31.178190231323242
24.42462921142578
18.43550682067871
15.041125297546387
14.350192070007324
14.99195384979248
15.479307174682617
8.985339164733887
8.442371368408203
11.08674430847168
12.351001739501953
12.341403007507324
14.423948287963867
17.949186325073242
20.324678421020508
24.524898529052734
28.666540145874023
20.24176597595215
20.388288497924805
30.323518753051758
43.213924407958984
30.29634666442871
25.77224349975586
17.252595901489258
13.523077011108398
13.354144096374512
12.433534622192383
11.758606910705566
14.339805603027344
15.029251098632812
17.496644973754883
18.291648864746094
18.687841415405273
16.43010711669922
18.775609970092773
15.15016

10.408207893371582
9.089292526245117
8.538784980773926
8.158997535705566
10.65070915222168
11.2975435256958
9.674598693847656
9.7103910446167
9.632647514343262
8.344672203063965
6.838675498962402
7.163899898529053
8.982710838317871
8.531081199645996
8.89802360534668
9.780816078186035
12.039338111877441
14.532703399658203
21.356258392333984
27.156709671020508
29.43937110900879
39.369102478027344
37.61640930175781
54.718082427978516
61.9666633605957
65.371826171875
57.04351043701172
52.80165481567383
45.74650192260742
55.15934753417969
46.99909591674805
39.22880935668945
14.730813026428223
15.514636993408203
17.443748474121094
20.899688720703125
17.119129180908203
21.702529907226562
18.81441307067871
18.85500144958496
15.27920913696289
18.440088272094727
22.716720581054688
25.94973373413086
26.005311965942383
35.0128059387207
24.058696746826172
25.163787841796875
30.366600036621094
40.39186096191406
47.4726448059082
61.045658111572266
45.36009216308594
41.1576042175293
17.754594802856445

9.127861022949219
9.400681495666504
7.6171112060546875
8.897871971130371
7.973166465759277
7.981858730316162
7.037312030792236
6.969513893127441
7.376958847045898
9.49114990234375
11.639394760131836
13.139945983886719
10.910870552062988
12.08969497680664
16.36188316345215
21.23627281188965
22.418888092041016
25.258010864257812
26.998205184936523
33.44584274291992
31.313156127929688
32.77125930786133
32.354331970214844
41.68947982788086
35.78799057006836
35.20857620239258
46.77165222167969
40.329429626464844
31.817428588867188
28.44019889831543
24.049943923950195
28.726715087890625
31.239992141723633
46.687923431396484
44.247276306152344
53.857078552246094
32.42512512207031
21.8680477142334
16.763275146484375
12.694793701171875
12.213449478149414
9.969342231750488
8.2186861038208
10.038126945495605
10.6347017288208
9.463020324707031
10.931924819946289
10.635299682617188
10.733054161071777
9.016732215881348
9.246521949768066
9.539690017700195
11.039238929748535
14.154894828796387
18.0598

25.930177688598633
27.800445556640625
29.77199935913086
34.379783630371094
28.19690704345703
32.32735061645508
28.541473388671875
23.7292423248291
20.39584732055664
19.268510818481445
16.392345428466797
17.78312873840332
19.778854370117188
17.6463623046875
16.63947296142578
18.688770294189453
12.528841972351074
8.026556015014648
6.45755672454834
6.221462726593018
5.2821149826049805
5.149555683135986
5.32208251953125
5.4966535568237305
6.381940841674805
7.362220764160156
6.752787113189697
7.212629795074463
9.163078308105469
11.687586784362793
13.932896614074707
15.69753646850586
17.56745147705078
23.725494384765625
31.94611930847168
51.56427764892578
47.54702377319336
51.26913070678711
54.58168029785156
63.57542419433594
22.911273956298828
16.4333553314209
14.583452224731445
15.183259010314941
13.993948936462402
12.285335540771484
12.299912452697754
12.776561737060547
14.515886306762695
18.17057228088379
25.510801315307617
30.12038230895996
29.759187698364258
27.69938850402832
29.958984

17.40021324157715
21.7824764251709
23.13424301147461
23.382930755615234
24.321685791015625
19.211170196533203
13.192806243896484
15.246378898620605
13.150519371032715
9.508370399475098
8.868868827819824
8.65328598022461
6.102212429046631
5.999510288238525
6.3631391525268555
5.962695598602295
5.586803436279297
5.925004959106445
5.591587066650391
5.85904598236084
6.297406196594238
6.620095252990723
6.527953624725342
7.097530841827393
7.675814151763916
10.016035079956055
14.733682632446289
18.053335189819336
21.720809936523438
24.893871307373047
22.998594284057617
25.368078231811523
24.552824020385742
28.341957092285156
26.09598159790039
30.952714920043945
33.585304260253906
37.64487838745117
26.460023880004883
22.69746971130371
21.77487564086914
25.028322219848633
29.661069869995117
38.957035064697266
34.16048812866211
39.85025405883789
50.051883697509766
71.36760711669922
54.21951675415039
56.73912048339844
34.81419372558594
30.795228958129883
20.9081974029541
14.449402809143066
12.6110

11.96838092803955
10.332971572875977
11.863348960876465
12.281006813049316
12.508631706237793
9.60970687866211
7.748051166534424
7.126616477966309
6.260611057281494
5.908865451812744
5.748982906341553
6.398118495941162
9.008164405822754
9.587454795837402
12.611726760864258
13.773879051208496
17.543498992919922
15.707306861877441
18.43694496154785
22.195350646972656
27.275209426879883
33.32527160644531
34.806766510009766
31.291446685791016
25.840505599975586
18.73984718322754
13.527649879455566
12.38463020324707
12.342999458312988
10.63126277923584
12.103124618530273
12.148630142211914
11.814029693603516
10.592889785766602
13.364513397216797
18.377962112426758
34.628604888916016
53.17545700073242
55.63437271118164
38.01451110839844
24.998374938964844
15.551360130310059
18.61500358581543
21.18741798400879
23.25872802734375
17.38092613220215
17.445423126220703
17.15882682800293
18.360761642456055
19.569856643676758
22.442611694335938
25.939834594726562
29.559398651123047
17.02237510681152

20.62387466430664
14.67235279083252
7.429820537567139
5.877523422241211
5.054591655731201
4.806180477142334
4.518045902252197
4.289010047912598
4.056962490081787
3.8052115440368652
3.820965051651001
3.3748440742492676
4.068356037139893
3.5964267253875732
3.4612810611724854
3.3926124572753906
3.57348895072937
3.9120473861694336
3.9827699661254883
4.420044422149658
5.723140716552734
5.523085594177246
6.105018138885498
8.249370574951172
6.6849565505981445
5.355803966522217
4.962578773498535
4.76914644241333
4.531096935272217
4.77574348449707
5.593735694885254
6.812426567077637
8.2686767578125
10.742202758789062
12.794757843017578
14.23768138885498
15.521028518676758
15.109413146972656
15.4542818069458
16.53513526916504
14.58430290222168
13.86801815032959
17.027896881103516
17.619857788085938
19.17190933227539
20.659818649291992
20.568992614746094
16.345989227294922
10.683022499084473
9.23552417755127
9.724886894226074
7.617445468902588
7.579308032989502
8.441028594970703
7.644414901733398

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

4.256950855255127
4.568853378295898
5.686650276184082
8.319928169250488
8.928044319152832
9.740351676940918
10.674898147583008
8.743149757385254
8.810323715209961
11.286203384399414
12.163317680358887
14.872488975524902
15.96142292022705
18.456056594848633
18.93609046936035
22.455188751220703
21.5587158203125
19.2357177734375
21.425186157226562
19.021486282348633
19.642297744750977
17.462261199951172
21.07330894470215
22.897443771362305
19.815744400024414
14.122958183288574
13.378373146057129
12.654179573059082
13.910462379455566
14.19821548461914
10.578369140625
8.83142375946045
8.954848289489746
7.0447516441345215
5.486727237701416
4.8497314453125
4.911346912384033
5.017590045928955
5.028367042541504
3.9208006858825684
4.967020034790039
6.584144592285156
8.208046913146973
11.078889846801758
14.516828536987305
19.348478317260742
29.85614776611328
34.47998809814453
37.818084716796875
47.700172424316406
63.985679626464844
54.384273529052734
60.74955749511719
36.30118942260742
40.1697387

26.714513778686523
26.015968322753906
17.859222412109375
18.50107765197754
12.190608024597168
14.593918800354004
13.664634704589844
12.929301261901855
14.664032936096191
13.921414375305176
9.319050788879395
6.596216201782227
6.441091060638428
7.337599754333496
7.214077472686768
7.242478370666504
7.012293338775635
7.423924922943115
7.985632419586182
8.89588451385498
8.984444618225098
7.542693138122559
6.426973342895508
6.049750804901123
6.311272144317627
6.290185928344727
4.4441633224487305
4.975798606872559
6.477880477905273
7.690883636474609
8.488384246826172
9.212584495544434
12.05755615234375
14.445916175842285
18.445220947265625
19.975759506225586
14.689980506896973
13.464231491088867
17.169376373291016
19.546875
15.60470199584961
14.448609352111816
12.78168773651123
12.213057518005371
11.104083061218262
9.797457695007324
9.890789985656738
11.754375457763672
13.570672035217285
17.239660263061523
19.539306640625
22.753849029541016
21.439916610717773
25.739105224609375
22.60456466674

10.771292686462402
12.860306739807129
14.30185317993164
15.124369621276855
15.244300842285156
18.63607406616211
19.607484817504883
12.577625274658203
11.33521556854248
11.137441635131836
8.531825065612793
5.469926357269287
5.042148113250732
4.301712512969971
4.226119518280029
3.6726887226104736
4.844563961029053
4.969043731689453
6.121121406555176
7.323793888092041
7.873185157775879
7.43520450592041
9.248152732849121
9.046096801757812
11.98748779296875
14.365787506103516
15.38088607788086
18.13577651977539
21.38429832458496
25.869855880737305
38.81737518310547
45.45397186279297
42.11333465576172
16.05039405822754
16.719127655029297
15.379414558410645
16.535337448120117
14.176824569702148
15.838120460510254
12.824507713317871
11.290237426757812
8.42378044128418
9.95920467376709
11.346969604492188
11.5867919921875
11.599096298217773
13.88050651550293
10.476460456848145
10.287776947021484
12.077873229980469
15.765707015991211
19.895797729492188
25.590368270874023
26.79322624206543
26.1949

16.486732482910156
15.259641647338867
10.622729301452637
8.43995189666748
8.787663459777832
9.696231842041016
8.42161750793457
9.198504447937012
8.131570816040039
7.640812397003174
6.9253034591674805
7.261985778808594
8.114770889282227
11.480457305908203
15.026651382446289
17.18475914001465
12.722787857055664
14.415987014770508
20.61707878112793
28.819116592407227
27.710235595703125
29.661712646484375
30.20350456237793
34.17771530151367
28.422483444213867
25.623323440551758
22.682979583740234
26.6966495513916
22.216812133789062
18.19512367248535
21.938720703125
16.40213966369629
13.679515838623047
11.19056224822998
9.912842750549316
11.643575668334961
13.449188232421875
19.770456314086914
21.364246368408203
26.067527770996094
18.50848960876465
12.29839038848877
10.270291328430176
7.568815231323242
7.399605751037598
6.214550495147705
5.965364456176758
7.532680034637451
7.5608086585998535
6.783332347869873
7.841564655303955
7.5163984298706055
8.069391250610352
6.925225734710693
6.8917388

26.40743637084961
30.062753677368164
35.560733795166016
40.55399703979492
39.88180160522461
19.586978912353516
14.033012390136719
14.80583381652832
14.830936431884766
15.479536056518555
17.142946243286133
15.557888984680176
18.935754776000977
20.785991668701172
17.50668716430664
16.29090690612793
15.855382919311523
15.999554634094238
18.96646499633789
23.27322006225586
21.373369216918945
19.106582641601562
19.677146911621094
11.37339973449707
5.981594562530518
4.63235330581665
4.5316267013549805
4.213138103485107
3.710597038269043
3.2674269676208496
2.992906093597412
2.986356496810913
3.0923779010772705
3.0024662017822266
3.1041858196258545
3.3091495037078857
3.459299087524414
3.653092622756958
3.697758436203003
4.045141220092773
4.4610819816589355
5.634225845336914
8.09813404083252
9.115087509155273
11.268110275268555
16.812877655029297
25.21274185180664
20.91781234741211
18.254179000854492
18.432645797729492
20.78536605834961
19.35699462890625
18.234853744506836
18.270748138427734
15

16.753116607666016
14.747690200805664
11.069710731506348
11.433501243591309
11.791104316711426
11.228340148925781
11.85381031036377
10.112528800964355
10.255729675292969
12.383766174316406
13.591147422790527
14.005839347839355
15.87663459777832
13.576176643371582
10.008755683898926
11.906209945678711
10.950560569763184
7.319946765899658
7.8623175621032715
7.375671863555908
4.991100788116455
5.098104953765869
5.5148091316223145
5.383939743041992
4.850991249084473
5.161860942840576
4.590724945068359
4.775842189788818
5.054262638092041
4.934108734130859
5.079324722290039
5.654652118682861
5.806509017944336
7.135455131530762
9.974848747253418
11.573762893676758
14.468629837036133
18.732851028442383
17.046585083007812
19.20940399169922
21.385099411010742
23.84463119506836
20.320354461669922
22.84315299987793
26.12949562072754
29.04448699951172
23.224449157714844
19.68722915649414
19.39864730834961
21.03299903869629
23.953868865966797
28.512516021728516
23.52815055847168
24.567276000976562
2

10.084667205810547
11.000215530395508
9.471179008483887
8.001459121704102
8.977375030517578
9.408467292785645
8.532244682312012
10.296330451965332
11.266347885131836
10.51679515838623
11.503689765930176
12.499582290649414
11.970398902893066
10.148561477661133
6.234185695648193
5.664183616638184
4.748720645904541
4.574054718017578
4.2178568840026855
4.841770172119141
6.19489049911499
6.049548625946045
7.230568885803223
7.967469692230225
9.537789344787598
8.697345733642578
8.96982192993164
10.8386812210083
13.418632507324219
19.0360164642334
19.24655532836914
21.247751235961914
17.021223068237305
13.060543060302734
10.776803970336914
10.98160171508789
12.299957275390625
12.292529106140137
11.988865852355957
11.960505485534668
9.79423999786377
8.058786392211914
8.899642944335938
12.027178764343262
20.1369686126709
31.34880256652832
26.84440803527832
27.470796585083008
16.84206199645996
17.686569213867188
24.509668350219727
36.18259811401367
34.08308029174805
27.443767547607422
26.27446174

40.096473693847656
35.18111801147461
33.7390022277832
20.417865753173828
14.092979431152344
6.800227165222168
5.767316818237305
4.751881122589111
4.2329630851745605
4.060980796813965
3.931966781616211
3.4376044273376465
3.4387359619140625
3.522001266479492
3.113568067550659
3.420217752456665
3.008488893508911
2.9592294692993164
2.9640636444091797
3.254204034805298
3.260631561279297
3.359527349472046
3.655768394470215
4.586330413818359
3.8725640773773193
3.864823818206787
4.783745288848877
3.6098973751068115
3.438399314880371
2.7773587703704834
3.1455814838409424
2.845371961593628
2.9422249794006348
3.14701509475708
3.324448585510254
3.653212785720825
4.387912273406982
4.5284104347229
4.727606296539307
5.310740947723389
5.23524284362793
5.413225173950195
5.733494281768799
5.70336389541626
6.186033725738525
8.583574295043945
10.327194213867188
12.388672828674316
15.066584587097168
16.37308692932129
14.992396354675293
10.159421920776367
9.454489707946777
11.114041328430176
9.5152444839477

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

3.6124088764190674
4.182312965393066
5.309777736663818
8.33752727508545
9.274818420410156
10.131546974182129
11.90077018737793
9.448267936706543
10.457038879394531
13.443161010742188
15.402971267700195
19.534210205078125
22.705432891845703
26.637916564941406
33.40229797363281
41.955421447753906
46.33366775512695
35.85482406616211
33.26353073120117
23.338701248168945
21.18864631652832
17.8084659576416
23.01042366027832
23.72405242919922
19.313154220581055
11.047490119934082
7.432075500488281
6.780575752258301
5.118831634521484
5.0878143310546875
4.330990791320801
4.295603275299072
4.234121799468994
3.6107451915740967
3.228891134262085
3.2418110370635986
3.394822359085083
3.4758551120758057
3.303734064102173
2.5760605335235596
2.9415600299835205
3.337789535522461
3.270686626434326
3.9924089908599854
3.8455066680908203
4.253268718719482
4.943889617919922
4.790688991546631
4.173239707946777
4.608926296234131
5.223701000213623
5.87853479385376
6.4758100509643555
6.354328155517578
7.61003780

20.783111572265625
18.674774169921875
15.867166519165039
15.952437400817871
12.842817306518555
14.862325668334961
14.43010139465332
12.505880355834961
13.959077835083008
12.687952995300293
9.417306900024414
5.860565662384033
5.719303131103516
6.159188747406006
5.9979658126831055
5.989514350891113
5.7694292068481445
5.991286754608154
6.1998209953308105
8.776719093322754
10.022191047668457
6.954673767089844
5.736851692199707
4.991382598876953
4.4247541427612305
4.131148815155029
2.9999258518218994
3.0897626876831055
3.7205581665039062
4.145037651062012
4.501914978027344
4.932329177856445
5.880410194396973
6.590512752532959
8.310775756835938
9.434121131896973
7.931723117828369
7.9986186027526855
12.714773178100586
18.427202224731445
17.557693481445312
15.3070650100708
11.945591926574707
10.695001602172852
11.88986873626709
10.600831031799316
11.987334251403809
15.148794174194336
18.6455020904541
22.215290069580078
25.72641372680664
24.480857849121094
24.20867347717285
23.878345489501953
2

4.760653495788574
6.5757975578308105
7.677708148956299
8.64977741241455
10.011034965515137
15.162554740905762
20.128196716308594
15.99852180480957
16.006813049316406
15.652493476867676
14.132963180541992
9.715309143066406
9.636402130126953
8.855705261230469
9.64314079284668
9.389754295349121
11.25063419342041
14.604860305786133
16.36667823791504
23.268932342529297
25.37441635131836
23.900482177734375
29.753589630126953
23.80986785888672
28.1857967376709
29.114925384521484
28.825321197509766
27.671329498291016
26.81338119506836
28.60328483581543
39.29084014892578
36.92564010620117
28.041921615600586
6.951506614685059
6.082265853881836
5.283878326416016
4.5643534660339355
4.415277481079102
4.127160549163818
3.8270466327667236
3.8257901668548584
2.9111146926879883
3.7639355659484863
3.767357110977173
3.9664323329925537
4.299001693725586
4.7091064453125
3.8440630435943604
3.7477691173553467
4.049989700317383
5.054727077484131
5.882488250732422
7.582890510559082
8.711369514465332
8.39141178

8.787331581115723
10.37529182434082
10.060213088989258
9.116289138793945
8.180543899536133
7.607625484466553
6.597752094268799
6.665823936462402
7.825440883636475
10.229144096374512
12.156359672546387
13.392577171325684
12.11819076538086
15.158720016479492
25.389846801757812
35.80489730834961
35.81804656982422
26.8629150390625
25.28717803955078
20.203439712524414
17.04166603088379
13.096699714660645
13.07351303100586
14.960426330566406
14.501669883728027
11.370450019836426
15.605413436889648
11.855192184448242
12.210790634155273
9.81502628326416
10.081003189086914
11.516083717346191
13.993552207946777
20.362863540649414
24.850337982177734
25.619155883789062
21.27440643310547
10.854726791381836
9.84807014465332
6.67665433883667
6.739532947540283
5.292797565460205
5.177453994750977
6.257515907287598
6.1293134689331055
4.837601661682129
5.522220134735107
5.093448162078857
5.4171462059021
4.349102020263672
4.440375804901123
4.171017169952393
4.818692684173584
5.668050765991211
7.0319747924

17.104093551635742
16.55742073059082
20.245201110839844
21.569677352905273
20.882234573364258
22.44004249572754
21.758028030395508
13.916069030761719
12.699115753173828
10.807236671447754
10.259577751159668
10.507627487182617
12.419785499572754
9.844120025634766
9.5592041015625
9.8739013671875
6.5074286460876465
4.023645877838135
3.3700809478759766
3.3430025577545166
3.540045738220215
3.5321097373962402
3.1786997318267822
3.4845879077911377
3.5671591758728027
3.9594058990478516
3.622685670852661
3.8746628761291504
4.879148960113525
5.560553073883057
5.862729072570801
6.466144561767578
7.612485885620117
8.983522415161133
12.492259979248047
19.979291915893555
22.991737365722656
27.969301223754883
40.42268753051758
46.10238265991211
23.00987434387207
13.662919044494629
11.185213088989258
12.220603942871094
10.382627487182617
9.665374755859375
8.63598918914795
7.874218940734863
8.558116912841797
9.979403495788574
12.434175491333008
13.880834579467773
12.177008628845215
11.528597831726074
1

16.048372268676758
18.528011322021484
19.596250534057617
20.222570419311523
15.790986061096191
15.161026000976562
8.390791893005371
5.992912769317627
5.484631538391113
4.982629776000977
3.5937230587005615
4.1878342628479
3.7882978916168213
2.9314019680023193
3.236027956008911
3.002960443496704
2.9987120628356934
2.7418031692504883
2.8343842029571533
2.560626983642578
2.8732950687408447
2.6609184741973877
2.7591488361358643
2.794355630874634
2.978844404220581
3.235793352127075
3.697795867919922
4.538057327270508
5.277252197265625
7.010167121887207
9.4821195602417
10.627083778381348
12.385139465332031
16.43798065185547
18.089550018310547
19.124910354614258
23.868701934814453
30.832651138305664
30.723215103149414
28.140634536743164
17.94417381286621
18.32526206970215
17.453458786010742
20.861225128173828
21.387054443359375
19.30052947998047
17.044668197631836
19.31476593017578
23.55022430419922
27.299072265625
26.112911224365234
25.41141128540039
19.97442054748535
17.450284957885742
10.25

4.5115966796875
5.535925388336182
6.068522930145264
5.438405990600586
6.052835941314697
6.247211456298828
6.189121246337891
4.9854559898376465
3.981598138809204
3.86750864982605
3.5561423301696777
3.592986822128296
3.4808273315429688
4.101999282836914
5.334102630615234
5.229530334472656
6.312582015991211
6.768925666809082
8.229239463806152
7.1633992195129395
8.04489517211914
9.736791610717773
12.539167404174805
16.99000358581543
19.243675231933594
19.43582534790039
17.342262268066406
12.845117568969727
10.480371475219727
10.701438903808594
11.835412979125977
10.469596862792969
11.588345527648926
11.081424713134766
9.560017585754395
7.273994445800781
8.529391288757324
11.469315528869629
18.605932235717773
27.48828125
28.85066795349121
24.81787872314453
16.946956634521484
14.7811861038208
20.227760314941406
25.204113006591797
25.462112426757812
13.648008346557617
11.40013313293457
10.720355987548828
10.341290473937988
10.66771411895752
11.333823204040527
12.485039710998535
12.61864566802

4.457620620727539
3.3518447875976562
3.1144213676452637
3.0036158561706543
2.907374143600464
2.793710470199585
2.6858911514282227
2.7636377811431885
2.440716505050659
2.7072036266326904
2.3723785877227783
2.3679049015045166
2.351680040359497
2.4520182609558105
2.4475722312927246
2.59736704826355
2.925673723220825
3.735053777694702
3.3003058433532715
3.5014798641204834
4.707547187805176
3.7761287689208984
3.3302295207977295
3.0076522827148438
3.1286914348602295
2.8093109130859375
2.9436535835266113
3.1895384788513184
3.550863265991211
3.9929683208465576
4.777846336364746
5.12544584274292
4.849237442016602
5.16641092300415
4.516904830932617
4.376368045806885
4.138889789581299
3.847177505493164
3.7894604206085205
4.992290019989014
5.847424507141113
7.47797966003418
9.69159984588623
11.145092010498047
11.056868553161621
8.267319679260254
8.493212699890137
11.247624397277832
10.33993911743164
11.996743202209473
15.587486267089844
14.649280548095703
16.319162368774414
25.482967376708984
41.0

In [None]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm, trange

agent_id = 0
epoch = 10
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
#learning_rate =0.01
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(tqdm(val_loader)):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        #x = torch.flatten(inp, start_dim=2)
        x = inp.permute(0,2,1,3).float()
        x = torch.flatten(x, start_dim=2)
        x = x.float()
        y = out.float()

        if cuda_status:
            #model = model.cuda()
            #x = inp.cuda()
            #y = out.cuda()
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = None
        # Forward pass: predict y by passing x to the model.    
        y_pred = model(x)
        #y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))

        # Compute the loss.
        loss = loss_fn(y_pred, y)
        print(torch.sqrt(loss))

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()

        if i_batch == 3216:
            #show_sample_batch(sample_batch, agent_id)
            #show_sample_batch([inp, y_pred.cpu().detach(), scene_ids, track_ids, agent_ids], agent_id)
            break


In [None]:
torch.save(model, './models/12InitialWorkspace.pt')

In [None]:
model = torch.load('./models/6epochmodel.pt')
model.eval()
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
import pandas as pd

# Submission output
writeCSV = True
val_path = "./new_val_in/new_val_in"

if writeCSV:
    
    dataset = ArgoverseDataset(data_path=val_path)
    test_loader = DataLoader(dataset,batch_size=batch_sz, shuffle = False, collate_fn=test_collate, num_workers=0)
    
    data = []
    
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(tqdm(test_loader)):
            inp, scene_ids, track_ids, agent_ids = sample_batch
            inp = torch.flatten(inp, start_dim=2)

            if cuda_status:
                model = model.cuda()
                x = inp.cuda()
            else:
                x = inp

            y_pred = None

            # Forward pass: predict y by passing x to the model.    
            y_pred = model(x.float())
            y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))
            
            for i in range(batch_sz):
                row = []
                row.append(scene_ids[i].item())
                curr = y_pred[i]
                
                agent_id = agent_ids[i]
                
                for j in range(30):
                    vehicle_index = 0
                    found = False
                    while not found:
                        if track_ids[i][vehicle_index][j][0] == agent_id:
                            found = True
                        else:
                            vehicle_index += 1

                    row.append(str(curr[vehicle_index][j][0].item()))
                    row.append(str(curr[vehicle_index][j][1].item()))
                    
                data.append(row)

    df = pd.DataFrame(data, columns = ['ID','v1','v2','v3','v4','v5','v6','v7','v8','v9','v10','v11','v12','v13','v14','v15','v16','v17','v18','v19','v20','v21','v22','v23','v24','v25','v26','v27','v28','v29','v30','v31','v32','v33','v34','v35','v36','v37','v38','v39','v40','v41','v42','v43','v44','v45','v46','v47','v48','v49','v50','v51','v52','v53','v54','v55','v56','v57','v58','v59','v60'])
    print(df)
    df.to_csv('submission.csv', index=False)
                
                
                