In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
from tqdm.notebook import tqdm, trange

"""Change to the data folder"""
new_path = "./new_train/new_train"

cuda_status = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [3]:
batch_sz = 64

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    out = torch.LongTensor(out)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, out, scene_ids, track_ids, agent_ids]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, scene_ids, track_ids, agent_ids]

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [None]:
model = torch.nn.Sequential(
    torch.nn.Linear(76, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [4]:
model = torch.nn.Sequential(
    torch.nn.Linear(76, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 98),
    torch.nn.ReLU(),
    torch.nn.Linear(98, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
model = torch.nn.Sequential(
    torch.nn.Conv2d(4, 8, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    #torch.nn.Conv2d(8, 24, 5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    #torch.nn.Linear(24 * 5 * 5, 512),
    #torch.nn.ReLU(),
    #torch.nn.Linear(512, 256),
    #torch.nn.ReLU(),
    #torch.nn.Linear(256, 120)
)
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
class LSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(76, hidden_size=98, batch_first=True)
        self.linear = torch.nn.Linear(19*98, 7200)
        self.linear2 = torch.nn.Linear(5880, 6500)
        self.relu = torch.nn.ReLU()
        self.linear3 = torch.nn.Linear(6500, 6500)
        self.linear4 = torch.nn.Linear(6500, 7200)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = torch.flatten(out, start_dim=1)
        #predict = self.linear(out)
        #predict = torch.reshape(predict, torch.Size([64, 60, 30, 4]))
        out = self.relu(self.linear2(out))
        out = self.relu(self.linear3(out))
        out = self.relu(self.linear4(out))
        return out

model = LSTM()
model.to(device)
if cuda_status:
    model = model.cuda()

In [None]:
model = torch.load('./models/1epochAdam.pt')
model.train()
model.to(device)
if cuda_status:
    model = model.cuda()

### Visualize the batch of sequences

In [6]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm, trange

agent_id = 0
epoch = 3

def show_sample_batch(sample_batch, agent_id):
    """visualize the trajectory for a batch of samples with a randon agent"""
    inp, out, scene_ids, track_ids, agent_ids = sample_batch
    batch_sz = inp.size(0)
    agent_sz = inp.size(1)
    
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()   
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(tqdm(val_loader)):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        x = torch.flatten(inp, start_dim=2)

        x = x.float()
        y = out.float()

        if cuda_status:
            #model = model.cuda()
            #x = inp.cuda()
            #y = out.cuda()
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = None

        # Forward pass: predict y by passing x to the model.    
        y_pred = model(x)
        y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))

        # Compute the loss.
        loss = loss_fn(y_pred, y)
        loss = torch.sqrt(loss)
        print(loss.item())

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()

        if i_batch == 3216:
            #show_sample_batch(sample_batch, agent_id)
            #show_sample_batch([inp, y_pred.cpu().detach(), scene_ids, track_ids, agent_ids], agent_id)
            break


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

595.5176391601562
518.7504272460938
350.3236999511719
471.8230285644531
358.3848571777344
260.1774597167969
195.37254333496094
214.919921875
183.5379180908203
169.55279541015625
173.09649658203125
214.51007080078125
68.60237884521484
128.9435272216797
162.61798095703125
75.94810485839844
175.1658935546875
30.801687240600586
130.76414489746094
104.53498840332031
69.88712310791016
129.19036865234375
38.41622543334961
128.32618713378906
27.962196350097656
88.03697967529297
45.86211013793945
109.01971435546875
23.50116539001465
57.83266067504883
78.68704986572266
78.75364685058594
63.00243377685547
60.8730354309082
41.0143928527832
85.49250030517578
32.722904205322266
93.15188598632812
36.73838424682617
72.16399383544922
58.88746643066406
90.59579467773438
23.912216186523438
72.8470458984375
37.87156295776367
80.13037109375
40.00065231323242
74.3509521484375
27.84984588623047
80.3686294555664
48.223270416259766
86.92333221435547
32.940757751464844
73.53670501708984
38.266963958740234
89.91

46.28207778930664
8.439879417419434
36.185123443603516
28.182039260864258
40.782020568847656
31.441625595092773
38.692710876464844
16.637990951538086
40.75779724121094
23.53154754638672
46.14322280883789
19.3338680267334
48.50819396972656
11.660279273986816
48.298709869384766
16.1064510345459
46.28834533691406
12.062610626220703
40.031768798828125
14.805492401123047
44.59483337402344
19.462018966674805
39.121826171875
7.300034999847412
24.01363182067871
43.76752853393555
34.525413513183594
27.743408203125
37.380027770996094
31.814687728881836
35.83980941772461
30.243083953857422
34.5100212097168
15.853996276855469
44.888893127441406
24.198623657226562
45.424259185791016
9.43684196472168
31.03265953063965
23.85213851928711
38.30720901489258
21.956518173217773
36.78919982910156
22.110610961914062
45.17974853515625
15.175603866577148
47.85655975341797
10.845549583435059
45.29816436767578
8.79875659942627
34.86021041870117
20.093631744384766
47.26080322265625
7.704500198364258
32.954574584

26.8977108001709
33.80451202392578
29.74704360961914
30.3781681060791
16.861112594604492
37.59554672241211
19.222291946411133
28.791818618774414
16.222827911376953
34.60624313354492
19.64168930053711
36.27064514160156
15.063058853149414
41.396732330322266
19.78878402709961
35.48930358886719
26.757442474365234
28.175373077392578
24.313020706176758
37.03563690185547
12.627825736999512
45.21857833862305
13.851129531860352
41.408958435058594
15.876930236816406
42.01236343383789
19.790382385253906
50.65245056152344
11.925741195678711
37.68379592895508
4.43537712097168
25.190019607543945
30.975366592407227
32.829917907714844
17.344572067260742
40.21299362182617
9.542991638183594
33.00434494018555
12.607256889343262
41.80349349975586
15.880282402038574
39.30500411987305
15.51285457611084
40.75192642211914
13.229883193969727
31.882287979125977
27.438199996948242
33.973567962646484
29.070894241333008
36.794151306152344
28.834184646606445
24.119869232177734
25.410261154174805
29.058839797973633


35.213253021240234
19.307268142700195
29.230796813964844
17.972858428955078
26.012039184570312
22.3624210357666
25.476634979248047
23.37856674194336
33.396278381347656
24.8433895111084
34.63470458984375
18.919574737548828
28.691280364990234
21.541948318481445
41.863739013671875
16.94753074645996
33.944339752197266
14.8948335647583
36.511539459228516
16.079200744628906
32.023929595947266
14.616643905639648
32.337345123291016
23.120147705078125
29.61158561706543
15.312162399291992
41.98564910888672
5.50853967666626
26.26788902282715
24.183944702148438
26.371023178100586
31.135372161865234
22.832170486450195
40.87346649169922
10.711368560791016
39.78963088989258
8.727266311645508
35.992000579833984
18.287145614624023
28.44399642944336
20.96584129333496
26.517826080322266
21.406251907348633
39.74195098876953
13.391550064086914
35.209754943847656
21.812211990356445
30.6662540435791
22.486417770385742
27.297819137573242
23.715227127075195
29.011791229248047
24.665061950683594
34.355705261230

25.60818862915039
13.394283294677734
26.867509841918945
22.90096092224121
24.738338470458984
25.997947692871094
27.721450805664062
20.896265029907227
34.585845947265625
17.643468856811523
25.717138290405273
21.80167579650879
26.812183380126953
20.102153778076172
32.64482116699219
20.532028198242188
26.101999282836914
20.446393966674805
30.59779930114746
9.494751930236816
26.054203033447266
16.09580421447754
30.03498649597168
12.540115356445312
31.465984344482422
7.405955791473389
28.975770950317383
19.2502498626709
30.31793975830078
11.192266464233398
31.905546188354492
19.87421226501465
29.142004013061523
17.692472457885742
30.40207862854004
20.22066307067871
25.438343048095703
27.84687614440918
29.786006927490234
19.731918334960938
33.00468444824219
27.061870574951172
28.765180587768555
12.371057510375977
30.954694747924805
18.133249282836914
32.58881759643555
12.838953018188477
33.211544036865234
22.80354881286621
20.786584854125977
31.006412506103516
22.95780372619629
32.7454872131

32.9620361328125
8.735406875610352
37.64059829711914
12.06283187866211
35.0748176574707
15.195510864257812
28.969655990600586
16.355897903442383
36.37114334106445
14.3162260055542
35.41444396972656
17.291118621826172
28.051660537719727
11.699934005737305
38.502708435058594
7.269145488739014
28.859819412231445
20.507488250732422
24.271440505981445
11.883821487426758
32.75225830078125
14.862527847290039
27.380483627319336
13.667075157165527
28.889169692993164
13.463048934936523
27.381174087524414
20.658727645874023
19.661190032958984
25.343326568603516
19.159666061401367
24.071226119995117
20.57901954650879
31.77225112915039
11.933157920837402
33.48356246948242
11.600646018981934
20.971166610717773
20.718372344970703
24.2818660736084
21.119077682495117
25.296560287475586
23.60056495666504
28.855213165283203
19.277374267578125
23.9481258392334
17.985107421875
30.44151496887207
15.504549026489258
33.940223693847656
13.688199996948242
30.98776626586914
18.262544631958008
26.423547744750977


31.22299575805664
19.21657371520996
28.82587242126465
12.700361251831055
26.831411361694336
21.535547256469727
21.58916473388672
22.97490119934082
24.16203498840332
21.551963806152344
17.33738899230957
27.880403518676758
15.441813468933105
26.905447006225586
11.294564247131348
24.138517379760742
15.670943260192871
23.355756759643555
15.762539863586426
25.561004638671875
19.1821346282959
34.310943603515625
9.165656089782715
35.23196792602539
7.095068454742432
26.67576789855957
14.232826232910156
34.09937286376953
12.9716796875
32.71326446533203
15.134004592895508
27.37445640563965
16.102569580078125
24.408157348632812
16.00892448425293
20.203527450561523
23.981409072875977
20.69270133972168
17.584678649902344
27.005189895629883
13.110247611999512
24.269447326660156
11.174549102783203
33.13493728637695
12.775618553161621
40.00304412841797
13.368343353271484
27.45669174194336
13.710354804992676
22.470579147338867
17.49921989440918
33.40135955810547
15.338321685791016
28.864953994750977
8.

14.636665344238281
25.945505142211914
11.015106201171875
29.57359504699707
12.422648429870605
28.164026260375977
12.442728042602539
26.807369232177734
16.95203399658203
23.32044219970703
15.176974296569824
22.830589294433594
19.439462661743164
23.112600326538086
19.23163414001465
25.16836166381836
18.097414016723633
23.36742401123047
21.823314666748047
19.906715393066406
21.45368194580078
19.93912124633789
19.818618774414062
21.056909561157227
18.408748626708984
26.701141357421875
14.50642204284668
27.367218017578125
12.643930435180664
30.03648567199707
10.161337852478027
32.198726654052734
18.251325607299805
17.472171783447266
16.022947311401367
25.146224975585938
12.253957748413086
22.970901489257812
16.042430877685547
26.53636932373047
16.515642166137695
26.250463485717773
17.511533737182617
26.812397003173828
12.117379188537598
26.14999008178711
10.843317031860352
28.369680404663086
10.761425018310547
23.58399200439453
12.376784324645996
31.26200294494629
7.182143688201904
31.99346

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

24.836565017700195
16.805187225341797
26.141965866088867
18.477657318115234
16.450511932373047
22.10698699951172
18.65924835205078
19.38802719116211
16.523311614990234
27.190353393554688
8.339810371398926
27.74309539794922
10.784688949584961
29.550203323364258
7.784900188446045
31.56035614013672
9.680991172790527
26.61287498474121
13.416044235229492
26.576154708862305
13.427664756774902
27.503639221191406
18.099180221557617
27.219993591308594
12.1777982711792
24.348796844482422
12.473944664001465
26.894691467285156
8.614736557006836
29.51480484008789
9.588005065917969
16.100116729736328
22.05018424987793
17.38741683959961
16.755250930786133
22.80368423461914
14.330825805664062
24.680002212524414
14.346274375915527
18.553733825683594
22.42852020263672
22.260395050048828
17.20258140563965
24.999805450439453
13.35335636138916
26.593971252441406
16.252038955688477
23.166645050048828
9.634811401367188
30.277299880981445
13.602702140808105
28.222909927368164
14.970658302307129
24.27966308593

27.694610595703125
8.431120872497559
28.426755905151367
13.937292098999023
23.607839584350586
19.68752670288086
19.703182220458984
18.62920570373535
18.143163681030273
21.310800552368164
16.173404693603516
21.26909065246582
18.919544219970703
20.112394332885742
16.323497772216797
16.297529220581055
24.429367065429688
15.987933158874512
23.225711822509766
14.013764381408691
23.13788604736328
17.449796676635742
21.388404846191406
17.228708267211914
21.661584854125977
9.765700340270996
22.58687400817871
10.936488151550293
24.24016571044922
16.25121307373047
11.742331504821777
27.611858367919922
14.095026016235352
27.522342681884766
12.623392105102539
29.044994354248047
13.133746147155762
27.639188766479492
14.964797973632812
26.565582275390625
6.25385046005249
30.456499099731445
15.02563190460205
30.264890670776367
5.752388000488281
17.205373764038086
16.859813690185547
21.135595321655273
17.107295989990234
19.51398468017578
18.367431640625
24.706377029418945
14.20701789855957
26.67716789

12.813860893249512
21.437122344970703
13.914056777954102
26.233516693115234
8.74560832977295
18.151012420654297
15.257621765136719
13.024914741516113
22.06554412841797
18.797849655151367
20.57143211364746
17.788776397705078
17.972713470458984
20.993724822998047
17.8901309967041
13.068757057189941
21.17333221435547
13.284870147705078
17.52284812927246
11.646271705627441
22.278654098510742
10.008273124694824
22.93828582763672
9.432046890258789
25.870182037353516
11.95421028137207
20.629507064819336
20.6020565032959
17.023738861083984
15.92526912689209
22.65732192993164
10.533695220947266
28.60507583618164
9.973860740661621
25.1538143157959
11.729519844055176
24.383359909057617
15.038129806518555
29.211519241333008
10.590171813964844
21.80232810974121
4.501715660095215
24.191709518432617
13.171781539916992
20.68865203857422
11.524410247802734
24.848522186279297
8.94143009185791
22.17547035217285
9.92910385131836
24.455463409423828
14.439859390258789
25.4162654876709
14.166460990905762
21.

25.107257843017578
13.02476978302002
23.814228057861328
13.426555633544922
22.018600463867188
9.621562004089355
24.27739715576172
10.509990692138672
18.745954513549805
11.17469596862793
21.494094848632812
8.426758766174316
22.741212844848633
11.582125663757324
24.288843154907227
7.006014823913574
25.87994384765625
10.696316719055176
17.941814422607422
10.96207046508789
22.58855438232422
12.68592643737793
24.043983459472656
11.425111770629883
20.70439338684082
9.739333152770996
25.875011444091797
14.873710632324219
25.95722198486328
7.52901029586792
23.109514236450195
12.538739204406738
21.78072738647461
12.022568702697754
19.443527221679688
16.10535430908203
23.146564483642578
11.65683364868164
21.701339721679688
16.14336395263672
13.62608814239502
18.20689582824707
13.83983325958252
19.921728134155273
17.455778121948242
19.096168518066406
25.43305778503418
11.583497047424316
27.327880859375
5.8094282150268555
23.296539306640625
11.458292007446289
18.45985984802246
17.27805519104004
16

22.021581649780273
10.23840618133545
19.944705963134766
7.9022722244262695
21.429534912109375
7.006595134735107
21.4638729095459
11.332845687866211
20.073453903198242
15.732933044433594
19.427400588989258
12.702439308166504
23.499879837036133
8.576591491699219
18.763601303100586
8.038894653320312
25.235595703125
6.284645080566406
26.04458999633789
8.5585355758667
24.842914581298828
13.618267059326172
24.59444808959961
5.896817684173584
17.15701675415039
14.722505569458008
21.12226676940918
13.433290481567383
24.51736831665039
5.793562889099121
24.83221435546875
9.857556343078613
15.874002456665039
9.448260307312012
21.159095764160156
11.672131538391113
21.43492317199707
11.170161247253418
14.889737129211426
17.615327835083008
14.126770973205566
17.868669509887695
9.673288345336914
22.205341339111328
15.25174331665039
16.71720314025879
14.94626522064209
17.299959182739258
16.20905876159668
14.196091651916504
20.844684600830078
17.648975372314453
14.508132934570312
19.372976303100586
19.

10.31032657623291
20.896461486816406
10.908148765563965
24.201635360717773
9.37113094329834
22.462648391723633
12.990336418151855
16.164535522460938
16.07382583618164
17.900238037109375
11.352293968200684
20.36665153503418
8.195204734802246
23.675230026245117
10.200064659118652
20.883272171020508
13.522682189941406
16.732999801635742
16.35198402404785
22.14134979248047
13.873178482055664
20.023727416992188
15.485321998596191
14.834156036376953
13.781258583068848
22.64751434326172
9.624441146850586
18.536407470703125
15.278471946716309
17.918760299682617
9.150777816772461
22.096811294555664
10.918869018554688
20.624286651611328
11.203707695007324
22.733856201171875
6.271303176879883
22.8926944732666
10.655479431152344
20.09326934814453
10.890064239501953
20.771011352539062
11.306878089904785
23.56867218017578
13.609954833984375
19.383211135864258
12.205049514770508
18.27875328063965
12.351655006408691
18.238079071044922
10.564252853393555
17.85304069519043
12.385335922241211
22.14360809

22.443565368652344
9.93227481842041
24.088804244995117
9.487238883972168
21.983491897583008
12.516487121582031
16.460786819458008
7.698635578155518
23.78962516784668
8.619976043701172
24.550390243530273
6.692249298095703
22.504587173461914
6.399683952331543
22.932785034179688
11.448294639587402
21.58592987060547
8.34808349609375
20.664569854736328
13.20978832244873
18.70848274230957
11.580086708068848
25.200180053710938
10.082751274108887
21.251798629760742
13.17595386505127
18.776094436645508
12.753300666809082
14.907437324523926
12.44222354888916
17.413982391357422
9.918265342712402
20.400833129882812
8.788846015930176
24.840694427490234
12.519564628601074
20.118589401245117
14.805526733398438
17.18147087097168
16.492267608642578
13.211219787597656
19.30712127685547
14.806723594665527
15.092455863952637
17.62207794189453
12.8688383102417
16.855850219726562
8.556178092956543
16.378841400146484
10.23288345336914
20.457658767700195
13.123276710510254
18.80208969116211
12.735213279724121

13.206759452819824
19.405073165893555
9.251437187194824
19.938560485839844
11.917820930480957
20.014934539794922
10.631767272949219
14.445219039916992
17.127525329589844
20.74734115600586
8.56065559387207
22.63654136657715
7.415245532989502
21.469207763671875
10.855742454528809
16.727781295776367
14.92848014831543
15.269217491149902
16.904298782348633
13.633353233337402
16.78805923461914
17.388160705566406
14.600064277648926
18.026473999023438
11.260335922241211
17.446706771850586
8.837871551513672
20.13101577758789
9.03662395477295
23.67515754699707
6.684001445770264
18.947072982788086
9.229747772216797
21.238862991333008
12.798891067504883
19.036884307861328
9.19093132019043
18.960113525390625
8.046998977661133
20.586833953857422
10.856011390686035
19.668701171875
10.962218284606934
21.42767906188965
13.595708847045898
14.410603523254395
14.306890487670898
19.45242691040039
5.743720531463623
18.162796020507812
11.530314445495605
16.753524780273438
13.727987289428711
17.43715858459472

HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

15.66600513458252
15.460906982421875
16.91621208190918
20.310604095458984
8.903860092163086
20.676849365234375
10.552762031555176
17.33518409729004
12.584372520446777
22.109806060791016
6.706892967224121
23.630544662475586
6.0922393798828125
24.874547958374023
5.054913520812988
21.088754653930664
8.876914978027344
18.13497543334961
12.76789665222168
18.7529354095459
12.009268760681152
20.996549606323242
15.490334510803223
19.730833053588867
8.71025276184082
16.491369247436523
10.665071487426758
17.873403549194336
9.451172828674316
21.401315689086914
6.648909568786621
14.979330062866211
15.595287322998047
13.793951034545898
12.102607727050781
17.760055541992188
10.11772632598877
18.747467041015625
11.833638191223145
12.900761604309082
18.50606346130371
14.713733673095703
13.28826904296875
17.60373306274414
12.470890045166016
17.553159713745117
13.749968528747559
15.580854415893555
10.40047836303711
21.474992752075195
11.469688415527344
20.190155029296875
12.745454788208008
16.7138442993

17.897233963012695
15.959827423095703
18.958223342895508
10.189994812011719
19.07039451599121
8.653599739074707
22.429309844970703
3.9693400859832764
21.602293014526367
9.520722389221191
20.848690032958984
14.40493392944336
16.790441513061523
9.630166053771973
16.535938262939453
13.946614265441895
17.18999481201172
13.970685005187988
18.221813201904297
10.590503692626953
18.78296661376953
11.456206321716309
20.485637664794922
10.21025562286377
17.400150299072266
10.519767761230469
17.825368881225586
11.2747220993042
16.160797119140625
6.9908246994018555
21.998077392578125
12.312414169311523
19.930448532104492
9.724752426147461
19.712398529052734
13.435263633728027
17.51475715637207
14.05443286895752
15.84636402130127
8.558321952819824
20.72933006286621
14.527044296264648
20.946866989135742
5.200212001800537
17.88683319091797
7.560575485229492
17.885391235351562
10.970198631286621
15.600707054138184
12.931514739990234
19.448144912719727
11.448324203491211
21.26826286315918
9.16129016876

7.355264186859131
16.453968048095703
11.03429889678955
15.810833930969238
17.174034118652344
17.684070587158203
15.22984790802002
16.29916763305664
13.652913093566895
19.842023849487305
13.348427772521973
11.79475212097168
15.963812828063965
13.37975025177002
12.939261436462402
10.766737937927246
16.14882469177246
11.107285499572754
17.913881301879883
9.267935752868652
20.21605110168457
12.053550720214844
15.576103210449219
17.271610260009766
11.322371482849121
15.48583984375
16.65350914001465
10.369365692138672
20.964080810546875
10.30339527130127
18.75173568725586
11.657306671142578
17.9188232421875
14.501222610473633
22.208694458007812
10.819920539855957
16.435672760009766
5.38090705871582
21.554899215698242
9.919986724853516
14.686580657958984
10.553056716918945
19.15534019470215
7.151330947875977
16.873109817504883
7.842532157897949
19.64354133605957
9.854643821716309
19.912521362304688
9.900542259216309
19.019630432128906
8.996930122375488
16.73277473449707
12.795511245727539
18.

21.4864559173584
9.873369216918945
15.985414505004883
10.832572937011719
14.417308807373047
11.747172355651855
18.069110870361328
11.629948616027832
16.51857566833496
14.48609447479248
12.642477989196777
14.366971969604492
13.562095642089844
14.16628360748291
16.97157096862793
16.996517181396484
17.310060501098633
12.024194717407227
11.967395782470703
15.798147201538086
18.31348991394043
14.648662567138672
14.807442665100098
13.493988037109375
18.29271697998047
11.409502983093262
17.180618286132812
8.673980712890625
19.13686752319336
12.730985641479492
17.06989288330078
8.9564208984375
21.863718032836914
3.1209421157836914
11.780136108398438
11.292614936828613
16.19009780883789
13.286222457885742
18.268308639526367
17.11054039001465
15.146900177001953
16.278614044189453
11.052759170532227
12.606277465820312
15.232075691223145
9.397686958312988
19.278823852539062
5.861858367919922
19.61968421936035
12.920731544494629
16.343643188476562
9.57571792602539
20.631977081298828
5.9810585975646

17.899658203125
9.40440845489502
16.132043838500977
14.217649459838867
15.436402320861816
11.600654602050781
18.54608154296875
7.8750152587890625
14.968427658081055
7.997065544128418
21.18132972717285
5.444719314575195
22.811189651489258
6.514308929443359
19.797277450561523
10.707226753234863
18.776212692260742
6.825650691986084
17.37550926208496
10.036449432373047
19.831710815429688
9.678683280944824
21.918607711791992
4.145171165466309
21.226341247558594
8.904726028442383
13.486846923828125
8.17460823059082
17.716440200805664
10.229897499084473
17.806766510009766
9.756468772888184
12.15327262878418
15.508913040161133
11.567541122436523
15.937952041625977
7.6941351890563965
18.569217681884766
12.885784149169922
14.5629301071167
12.16781997680664
15.252346992492676
13.168610572814941
13.389787673950195
16.773061752319336
16.478078842163086
11.220833778381348
18.01103973388672
15.709996223449707
17.206867218017578
6.613095283508301
18.804401397705078
9.485597610473633
20.489315032958984

13.505184173583984
15.145330429077148
9.098734855651855
17.683300018310547
12.681142807006836
11.942285537719727
14.705928802490234
10.849325180053711
19.025470733642578
10.722939491271973
17.456363677978516
11.984549522399902
14.898319244384766
14.367640495300293
17.854183197021484
12.197649002075195
16.48641586303711
14.228132247924805
12.516047477722168
12.674972534179688
19.170169830322266
8.742992401123047
14.82038688659668
14.738608360290527
12.942717552185059
9.143960952758789
17.540578842163086
10.609476089477539
15.765775680541992
8.547093391418457
17.757047653198242
6.588712215423584
18.252317428588867
9.342040061950684
15.828694343566895
10.44375228881836
16.50904655456543
9.962846755981445
18.770376205444336
12.869542121887207
15.292524337768555
13.237778663635254
15.306806564331055
10.922475814819336
16.263456344604492
10.751160621643066
14.918782234191895
9.754952430725098
18.71693229675293
9.917462348937988
17.353321075439453
8.720942497253418
14.719512939453125
12.96359

7.120171070098877
17.922954559326172
7.772352695465088
17.325576782226562
7.728084087371826
14.894978523254395
9.614907264709473
16.50470542907715
12.476367950439453
16.43975257873535
8.475576400756836
17.168397903442383
11.96102523803711
15.984654426574707
11.811199188232422
22.26247215270996
9.206243515014648
18.564367294311523
11.770718574523926
16.49119758605957
11.216054916381836
13.363691329956055
10.240222930908203
16.332504272460938
8.229671478271484
17.964218139648438
7.397693157196045
22.33977508544922
11.269124984741211
16.43527603149414
13.969915390014648
14.321573257446289
15.244935035705566
10.440062522888184
17.593822479248047
13.027743339538574
15.360696792602539
16.183591842651367
11.098997116088867
16.296588897705078
7.839303970336914
14.147294998168945
10.85300350189209
15.999774932861328
11.979421615600586
14.234807968139648
11.552680969238281
13.89344596862793
9.375153541564941
13.580381393432617
13.402746200561523
16.817657470703125
16.587398529052734
16.916006088

13.193522453308105
17.68488311767578
7.516947269439697
19.28411865234375
6.631808757781982
19.387094497680664
9.759186744689941
15.026904106140137
13.884410858154297
13.325935363769531
15.02768325805664
13.810629844665527
14.244483947753906
15.238351821899414
14.337815284729004
15.552027702331543
9.689080238342285
15.132210731506348
7.186839580535889
16.29486656188965
7.551083564758301
18.758317947387695
5.915826797485352
16.45699691772461
7.08294677734375
17.938940048217773
9.002881050109863
16.208173751831055
7.30322790145874
16.356338500976562
6.649960517883301
18.398569107055664
8.877555847167969
18.045658111572266
8.376441955566406
17.434986114501953
10.640861511230469
11.490188598632812
15.125722885131836
17.199289321899414
5.1392388343811035
16.266389846801758
10.205132484436035
15.045831680297852
9.2868070602417
15.580253601074219
10.670269966125488
16.52410888671875
9.183772087097168
18.92286491394043
9.27044677734375
16.206323623657227
12.856078147888184
12.86277961730957
11.

In [None]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm, trange

agent_id = 0
epoch = 10
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
#learning_rate =0.01
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
#optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(tqdm(val_loader)):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        #x = torch.flatten(inp, start_dim=2)
        x = inp.permute(0,2,1,3).float()
        x = torch.flatten(x, start_dim=2)
        x = x.float()
        y = out.float()

        if cuda_status:
            #model = model.cuda()
            #x = inp.cuda()
            #y = out.cuda()
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = None
        # Forward pass: predict y by passing x to the model.    
        y_pred = model(x)
        #y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))

        # Compute the loss.
        loss = loss_fn(y_pred, y)
        print(torch.sqrt(loss))

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()

        if i_batch == 3216:
            #show_sample_batch(sample_batch, agent_id)
            #show_sample_batch([inp, y_pred.cpu().detach(), scene_ids, track_ids, agent_ids], agent_id)
            break


In [None]:
torch.save(model, './models/12InitialWorkspace.pt')

In [2]:
model = torch.load('./models/12InitialWorkspace.pt')
model.eval()
model.to(device)
if cuda_status:
    model = model.cuda()

In [7]:
import pandas as pd

# Submission output
writeCSV = True
val_path = "./new_val_in/new_val_in"

if writeCSV:
    
    dataset = ArgoverseDataset(data_path=val_path)
    test_loader = DataLoader(dataset,batch_size=batch_sz, shuffle = False, collate_fn=test_collate, num_workers=0)
    
    data = []
    
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(tqdm(test_loader)):
            inp, scene_ids, track_ids, agent_ids = sample_batch
            inp = torch.flatten(inp, start_dim=2)

            if cuda_status:
                model = model.cuda()
                x = inp.cuda()
            else:
                x = inp

            y_pred = None

            # Forward pass: predict y by passing x to the model.    
            y_pred = model(x.float())
            y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 4]))
            
            for i in range(batch_sz):
                row = []
                row.append(scene_ids[i].item())
                curr = y_pred[i]
                
                agent_id = agent_ids[i]
                
                for j in range(30):
                    vehicle_index = 0
                    found = False
                    while not found:
                        if track_ids[i][vehicle_index][j][0] == agent_id:
                            found = True
                        else:
                            vehicle_index += 1

                    row.append(str(curr[vehicle_index][j][0].item()))
                    row.append(str(curr[vehicle_index][j][1].item()))
                    
                data.append(row)

    df = pd.DataFrame(data, columns = ['ID','v1','v2','v3','v4','v5','v6','v7','v8','v9','v10','v11','v12','v13','v14','v15','v16','v17','v18','v19','v20','v21','v22','v23','v24','v25','v26','v27','v28','v29','v30','v31','v32','v33','v34','v35','v36','v37','v38','v39','v40','v41','v42','v43','v44','v45','v46','v47','v48','v49','v50','v51','v52','v53','v54','v55','v56','v57','v58','v59','v60'])
    print(df)
    df.to_csv('submission.csv', index=False)
                
                
                

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


         ID                  v1                  v2                  v3  \
0     10002   1830.747802734375      535.7841796875   1744.435302734375   
1     10015   729.9448852539062   1250.381591796875     738.62451171875   
2     10019   566.5105590820312  1277.9290771484375   573.7442626953125   
3     10028  1810.7493896484375   522.7767333984375  1721.6607666015625   
4      1003   2149.365478515625   730.1478271484375   2123.237548828125   
...     ...                 ...                 ...                 ...   
3195   9897   260.5370788574219    827.177001953125  258.47503662109375   
3196     99   581.5123901367188   1175.595458984375   588.9920654296875   
3197   9905  1820.1151123046875   566.4986572265625  1767.1998291015625   
3198   9910    565.965576171875  1324.4290771484375   571.4410400390625   
3199   9918   580.6732177734375   1183.749755859375   586.4964599609375   

                      v4                 v5                 v6  \
0     441.82391357421875   1782.