In [1]:
import os
import os.path
import pickle
import random
from glob import glob

import numpy
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm, trange

# Folder holding the training pickles.
new_path = "./new_train/new_train"

cuda_status = torch.cuda.is_available()
device = torch.device("cuda" if cuda_status else "cpu")

# Seed every RNG the notebook draws from (torch weight init, numpy, and the
# random.random() teacher-forcing coin flips) so reruns are comparable.
SEED = 0
random.seed(SEED)
numpy.random.seed(SEED)
torch.manual_seed(SEED)

# number of sequences in each dataset
# train:205942  val:3200 test: 36272
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Map-style dataset over a folder of pickled Argoverse scenes.

    Every entry matched by ``<data_path>/*`` is one sample. Entries are kept in
    sorted path order, so an index always refers to the same file.

    Args:
        data_path: folder containing one pickle per scene.
        transform: optional callable applied to each loaded sample.
    """
    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.pkl_list = sorted(glob(os.path.join(data_path, '*')))

    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):
        # NOTE: pickle.load can execute arbitrary code; only point this at trusted data.
        with open(self.pkl_list[idx], 'rb') as handle:
            sample = pickle.load(handle)
        return self.transform(sample) if self.transform else sample


# Initialize a dataset.
# NOTE(review): despite its name, this wraps new_path, i.e. the training folder.
val_dataset = ArgoverseDataset(new_path)

### Create a loader to enable batch processing

In [None]:
# Number of scenes per training batch.
batch_sz = 512

def my_collate(batch):
    """Collate a list of training scenes into batched tensors.

    Each scene's positions and velocities are stacked on the last axis
    (``numpy.dstack([p, v])``), giving ``[batch_sz x agent_sz x seq_len x feature]``.

    Returns:
        ``[inp, out, scene_ids, track_ids, agent_ids, car_mask]``. ``inp`` and ``out``
        are float32 tensors; ``scene_ids`` and ``car_mask`` are int64 tensors;
        ``track_ids`` and ``agent_ids`` are plain lists of the per-scene values.
    """
    # Positions and velocities are fractional, so keep them as float32: an integer
    # tensor would truncate every value toward zero, and Tensor.mean()/std() are not
    # defined for integer dtypes. Stacking with numpy first also avoids torch's slow
    # list-of-ndarrays conversion path.
    inp = torch.as_tensor(
        numpy.stack([numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]),
        dtype=torch.float32,
    )
    out = torch.as_tensor(
        numpy.stack([numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]),
        dtype=torch.float32,
    )
    scene_ids = torch.as_tensor([scene['scene_idx'] for scene in batch], dtype=torch.long)
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    car_mask = torch.as_tensor(numpy.stack([scene['car_mask'] for scene in batch]), dtype=torch.long)
    return [inp, out, scene_ids, track_ids, agent_ids, car_mask]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, scene_ids, track_ids, agent_ids]

def _dataset_mean_std(loader):
    """Return ``[inp_mean, inp_std, out_mean, out_std]`` over every element ``loader`` yields.

    Sums are accumulated in float64 one batch at a time, so memory is bounded by a
    single batch rather than by the whole dataset. The std is the unbiased (n - 1)
    estimate, matching the default of ``torch.Tensor.std``.

    Raises:
        ValueError: if fewer than two elements were seen for either tensor.
    """
    # Per tensor: [element count, sum, sum of squares].
    totals = {'inp': [0, 0.0, 0.0], 'out': [0, 0.0, 0.0]}
    for batch in loader:
        for key, tensor in (('inp', batch[0]), ('out', batch[1])):
            values = tensor.double()
            acc = totals[key]
            acc[0] += values.numel()
            acc[1] += values.sum().item()
            acc[2] += (values * values).sum().item()

    stats = []
    for key, (count, total, total_sq) in totals.items():
        if count < 2:
            raise ValueError(f"need at least two '{key}' elements to estimate a std, got {count}")
        mean = total / count
        var = (total_sq - count * mean * mean) / (count - 1)
        stats += [mean, max(var, 0.0) ** 0.5]
    return stats


# Dataset-wide statistics of the model inputs and targets. Streaming in batch_sz
# chunks keeps memory bounded; a single batch the size of the training set
# (~206k scenes) would hold every scene in memory at once.
stats_loader = DataLoader(val_dataset, batch_size=batch_sz, shuffle=False, collate_fn=my_collate, num_workers=0)
inp_mean, inp_std, out_mean, out_std = _dataset_mean_std(stats_loader)
print(inp_mean)
print(inp_std)
print(out_mean)
print(out_std)

# drop_last=True keeps every training batch at exactly batch_sz scenes.
val_loader = DataLoader(val_dataset, batch_size=batch_sz, shuffle=False, collate_fn=my_collate, num_workers=0, drop_last=True)

In [11]:
class RNNEncoderDecoder(torch.nn.Module):
    """Seq2seq ReLU-RNN that predicts 30 future steps for all 60 agent slots of a scene.

    The 19 observed steps are encoded one scene-step at a time: each step's
    60 agents x 4 features are flattened into a 240-d vector. Decoding runs for
    30 steps. Before each step after the first, a softmax attention over the
    19 encoder states, scored from those states plus the previous decoder state,
    produces a context vector that is used as the decoder's hidden state.
    Sizes (60 agents, 19 input steps, 30 output steps, 4 features) are fixed
    by the layer shapes.
    """
    def __init__(self):
        super().__init__()
        # Both RNNs consume one flattened scene step: 60 agents x 4 features = 240.
        self.encoder = torch.nn.RNN(240, hidden_size=512, batch_first=True, nonlinearity='relu')
        self.decoder = torch.nn.RNN(240, hidden_size=512, batch_first=True, nonlinearity='relu')
        # Attention scorer: (19 encoder states + 1 decoder state) * 512 = 10240 inputs,
        # one score per encoder state.
        self.align1 = torch.nn.Linear(10240, 19)
        # Projects a decoder state back to one flattened scene step.
        self.linear = torch.nn.Linear(512, 240)

    def forward(self, x, y, teach=False, teaching_ratio=0.5):
        """Predict 30 future steps for every agent slot.

        Args:
            x: float tensor ``(batch, 60, 19, 4)`` of observed steps.
            y: ground-truth future ``(batch, 60, 30, 4)``. Only read when ``teach`` is
                true, so ``None`` is fine for inference.
            teach: enable teacher forcing. The input to step 1 is then always the
                ground truth of step 0.
            teaching_ratio: probability, for each later step, of feeding the ground
                truth of the step just predicted instead of the model's own output.

        Returns:
            Tensor ``(batch, 60, 30, 4)`` on the same device as ``x``.
        """
        # Batch size comes from the input, so any batch (e.g. a smaller final batch
        # at inference time) works, on whatever device x lives on.
        batch = x.size(0)

        # (batch, 60, 19, 4) -> (batch, 19, 60, 4) -> (batch, 19, 240)
        enc_in = torch.flatten(x.permute(0, 2, 1, 3), start_dim=2)
        # enc_states: (batch, 19, 512); hidden: (1, batch, 512)
        enc_states, hidden = self.encoder(enc_in)

        # Step 0 starts from a constant -1 "start token". A float fill keeps x's dtype;
        # an integer fill value produces an int64 tensor that the RNN rejects.
        start = x.new_full((batch, 1, 240), -1.0)
        dec_out, dec_hidden = self.decoder(start, hidden)
        step = self.linear(dec_out[:, 0])  # (batch, 240)
        steps = [step.reshape(batch, 60, 4)]

        if teach:
            next_in = torch.flatten(y[:, :, 0, :], start_dim=1).unsqueeze(1)
        else:
            next_in = step.unsqueeze(1)

        prev_state = dec_hidden.permute(1, 0, 2)  # (batch, 1, 512)
        for i in range(1, 30):
            # Score each encoder state from all encoder states plus the previous decoder state.
            scores = self.align1(torch.flatten(torch.cat((enc_states, prev_state), 1), start_dim=1))
            weights = torch.nn.functional.softmax(scores, dim=1).unsqueeze(1)  # (batch, 1, 19)
            # The attended context replaces the decoder's hidden state for this step.
            context = torch.bmm(weights, enc_states).permute(1, 0, 2)  # (1, batch, 512)

            dec_out, dec_hidden = self.decoder(next_in, context)
            step = self.linear(dec_out[:, 0])  # (batch, 240)
            steps.append(step.reshape(batch, 60, 4))

            if teach and random.random() < teaching_ratio:
                # Feed the ground truth of step i (the step just predicted) into step i + 1.
                next_in = torch.flatten(y[:, :, i, :], start_dim=1).unsqueeze(1)
            else:
                next_in = step.unsqueeze(1)

            prev_state = dec_hidden.permute(1, 0, 2)

        # (30, batch, 60, 4) -> (batch, 60, 30, 4)
        return torch.stack(steps).permute(1, 2, 0, 3)

# Build the network and move its parameters to the selected device (GPU when available).
model = RNNEncoderDecoder().to(device)

### Train the model

In [12]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm

epoch = 3
learning_rate = 1e-3

# Mean squared error over every predicted value of every agent slot.
# NOTE(review): car_mask is not applied, so all 60 agent slots count toward the
# loss. Confirm that slots with car_mask == 0 should be included.
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for ep in trange(epoch, desc="epochs"):
    # Create the progress bar per epoch so each epoch gets its own display.
    batches = tqdm(val_loader, desc=f"epoch {ep}", leave=False)
    for sample_batch in batches:
        inp, out, scene_ids, track_ids, agent_ids, car_mask = sample_batch

        # Tensor.to returns a new tensor rather than moving in place, so reassign.
        x = inp.float().to(device)
        y = out.float().to(device)

        # Forward pass without teacher forcing.
        y_pred = model(x, y, False)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the batch RMSE (square root of the MSE loss).
        batches.set_postfix(rmse=torch.sqrt(loss).item())


HBox(children=(FloatProgress(value=0.0, max=402.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

{'city': 'PIT', 'lane': array([[3278.8716, 1968.7596,    0.    ],
       [3282.6606, 1972.2533,    0.    ],
       [3286.4429, 1975.7545,    0.    ],
       [3290.2249, 1979.2559,    0.    ],
       [3294.007 , 1982.7572,    0.    ],
       [3297.789 , 1986.2584,    0.    ],
       [3301.5713, 1989.7598,    0.    ],
       [3305.3533, 1993.2611,    0.    ],
       [3309.1355, 1996.7623,    0.    ],
       [3269.4802, 1967.0625,    0.    ],
       [3267.2764, 1965.0217,    0.    ],
       [3265.0664, 1962.9877,    0.    ],
       [3262.8562, 1960.9536,    0.    ],
       [3260.646 , 1958.9197,    0.    ],
       [3258.436 , 1956.8856,    0.    ],
       [3256.2258, 1954.8517,    0.    ],
       [3254.0132, 1952.8207,    0.    ],
       [3251.7898, 1950.8013,    0.    ],
       [3235.7622, 1935.9915,    0.    ],
       [3231.5237, 1932.1249,    0.    ],
       [3227.2852, 1928.2582,    0.    ],
       [3223.0923, 1924.343 ,    0.    ],
       [3218.9045, 1920.4214,    0.    ],
       [32

{'city': 'PIT', 'lane': array([[3855.859 , 2232.4053,    0.    ],
       [3859.7212, 2233.9846,    0.    ],
       [3863.5168, 2235.7087,    0.    ],
       [3867.3005, 2237.4583,    0.    ],
       [3871.1543, 2239.0593,    0.    ],
       [3875.049 , 2240.549 ,    0.    ],
       [3878.9993, 2241.8894,    0.    ],
       [3883.0676, 2242.8047,    0.    ],
       [3887.1987, 2243.3994,    0.    ],
       [3929.784 , 2252.613 ,    0.    ],
       [3927.6108, 2252.3289,    0.    ],
       [3925.4377, 2252.0447,    0.    ],
       [3923.2646, 2251.7605,    0.    ],
       [3921.0916, 2251.4763,    0.    ],
       [3918.9185, 2251.1921,    0.    ],
       [3916.7456, 2250.908 ,    0.    ],
       [3914.5725, 2250.6238,    0.    ],
       [3912.3994, 2250.3396,    0.    ],
       [3869.8562, 2227.841 ,    0.    ],
       [3868.2893, 2231.8813,    0.    ],
       [3866.7227, 2235.9214,    0.    ],
       [3865.1636, 2239.9646,    0.    ],
       [3863.648 , 2244.0225,    0.    ],
       [38

{'city': 'MIA', 'lane': array([[ 779.42664, 2408.2197 ,    0.     ],
       [ 778.8532 , 2408.2004 ,    0.     ],
       [ 778.27985, 2408.181  ,    0.     ],
       [ 777.7064 , 2408.1616 ,    0.     ],
       [ 777.13306, 2408.142  ,    0.     ],
       [ 776.5597 , 2408.1228 ,    0.     ],
       [ 775.98627, 2408.1033 ,    0.     ],
       [ 775.4129 , 2408.0837 ,    0.     ],
       [ 774.8395 , 2408.0645 ,    0.     ],
       [ 767.7139 , 2404.6855 ,    0.     ],
       [ 768.617  , 2404.7122 ,    0.     ],
       [ 769.52   , 2404.7385 ,    0.     ],
       [ 770.42303, 2404.765  ,    0.     ],
       [ 771.32605, 2404.7913 ,    0.     ],
       [ 772.22906, 2404.8176 ,    0.     ],
       [ 773.1321 , 2404.844  ,    0.     ],
       [ 774.03516, 2404.8704 ,    0.     ],
       [ 774.9382 , 2404.8967 ,    0.     ],
       [ 773.93054, 2408.0374 ,    0.     ],
       [ 773.02155, 2408.0103 ,    0.     ],
       [ 772.1126 , 2407.9832 ,    0.     ],
       [ 771.2036 , 2407.956  ,

{'city': 'PIT', 'lane': array([[2321.7842 ,  840.37213,    0.     ],
       [2323.066  ,  841.8145 ,    0.     ],
       [2323.9546 ,  843.44226,    0.     ],
       [2324.2334 ,  845.2709 ,    0.     ],
       [2323.8342 ,  847.1187 ,    0.     ],
       [2322.8538 ,  848.7902 ,    0.     ],
       [2321.6895 ,  850.3564 ,    0.     ],
       [2320.4426 ,  851.8506 ,    0.     ],
       [2319.0408 ,  853.20483,    0.     ],
       [2309.0488 ,  863.90686,    0.     ],
       [2308.0974 ,  864.9077 ,    0.     ],
       [2307.1462 ,  865.90857,    0.     ],
       [2306.197  ,  866.91156,    0.     ],
       [2305.2595 ,  867.92554,    0.     ],
       [2304.3162 ,  868.9339 ,    0.     ],
       [2303.3677 ,  869.9375 ,    0.     ],
       [2302.4192 ,  870.9411 ,    0.     ],
       [2301.4705 ,  871.94464,    0.     ],
       [2330.518  ,  854.5796 ,    0.     ],
       [2329.2285 ,  853.4998 ,    0.     ],
       [2327.9265 ,  852.43634,    0.     ],
       [2326.5186 ,  851.5277 ,

{'city': 'PIT', 'lane': array([[3166.9822, 1673.7092,    0.    ],
       [3165.898 , 1672.7544,    0.    ],
       [3164.814 , 1671.7996,    0.    ],
       [3163.7297, 1670.8448,    0.    ],
       [3162.6455, 1669.89  ,    0.    ],
       [3161.5613, 1668.9353,    0.    ],
       [3160.4773, 1667.9805,    0.    ],
       [3159.393 , 1667.0258,    0.    ],
       [3158.3088, 1666.0709,    0.    ],
       [3213.692 , 1641.4071,    0.    ],
       [3213.4585, 1641.6428,    0.    ],
       [3213.225 , 1641.8784,    0.    ],
       [3212.9868, 1642.1095,    0.    ],
       [3212.7485, 1642.3406,    0.    ],
       [3212.5105, 1642.5715,    0.    ],
       [3212.2722, 1642.8026,    0.    ],
       [3212.0342, 1643.0337,    0.    ],
       [3211.796 , 1643.2646,    0.    ],
       [3209.9124, 1645.1428,    0.    ],
       [3208.0288, 1647.021 ,    0.    ],
       [3206.1453, 1648.8992,    0.    ],
       [3204.2617, 1650.7773,    0.    ],
       [3202.3782, 1652.6554,    0.    ],
       [32

{'city': 'MIA', 'lane': array([[ 547.81964, 3005.2327 ,    0.     ],
       [ 546.1771 , 3004.6748 ,    0.     ],
       [ 544.6534 , 3003.852  ,    0.     ],
       ...,
       [ 550.35284, 3005.6516 ,    0.     ],
       [ 549.93054, 3005.654  ,    0.     ],
       [ 549.50824, 3005.6565 ,    0.     ]], dtype=float32), 'lane_norm': array([[-1.6886553 , -0.4239167 ,  0.        ],
       [-1.6424944 , -0.5578527 ,  0.        ],
       [-1.5237297 , -0.8227594 ,  0.        ],
       ...,
       [-0.42227837,  0.00237911,  0.        ],
       [-0.42227837,  0.00237911,  0.        ],
       [-0.42227837,  0.00237911,  0.        ]], dtype=float32), 'scene_idx': 100054, 'agent_id': '00000000-0000-0000-0000-000000070752', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
  

{'city': 'PIT', 'lane': array([[2188.4702 ,  753.46875,    0.     ],
       [2190.7456 ,  750.9887 ,    0.     ],
       [2193.208  ,  748.7103 ,    0.     ],
       [2196.18   ,  747.1589 ,    0.     ],
       [2199.4504 ,  746.4228 ,    0.     ],
       [2202.8052 ,  746.20966,    0.     ],
       [2206.1667 ,  746.3303 ,    0.     ],
       [2209.4521 ,  747.00336,    0.     ],
       [2212.3904 ,  748.6006 ,    0.     ],
       [2206.5957 ,  751.31665,    0.     ],
       [2204.5198 ,  749.6973 ,    0.     ],
       [2202.0674 ,  748.82556,    0.     ],
       [2199.6746 ,  749.6208 ,    0.     ],
       [2197.7922 ,  751.4288 ,    0.     ],
       [2196.0881 ,  753.44165,    0.     ],
       [2194.434  ,  755.4968 ,    0.     ],
       [2192.7803 ,  757.55194,    0.     ],
       [2191.1265 ,  759.6071 ,    0.     ],
       [2237.4492 ,  777.0036 ,    0.     ],
       [2233.8428 ,  774.0072 ,    0.     ],
       [2230.2366 ,  771.01086,    0.     ],
       [2226.6304 ,  768.01447,

{'city': 'PIT', 'lane': array([[2836.3933, 1365.9226,    0.    ],
       [2834.9219, 1364.5897,    0.    ],
       [2833.4507, 1363.2568,    0.    ],
       ...,
       [2795.1362, 1374.9858,    0.    ],
       [2795.443 , 1374.6564,    0.    ],
       [2795.75  , 1374.3269,    0.    ]], dtype=float32), 'lane_norm': array([[-1.4712793, -1.3328346,  0.       ],
       [-1.4712793, -1.3328346,  0.       ],
       [-1.4712793, -1.3328346,  0.       ],
       ...,
       [ 0.3069587, -0.3294321,  0.       ],
       [ 0.3069587, -0.3294321,  0.       ],
       [ 0.3069587, -0.3294321,  0.       ]], dtype=float32), 'scene_idx': 100078, 'agent_id': '00000000-0000-0000-0000-000000105523', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],

{'city': 'PIT', 'lane': array([[2660.4836, 1387.2975,    0.    ],
       [2658.8218, 1389.0569,    0.    ],
       [2657.1602, 1390.8163,    0.    ],
       [2655.4985, 1392.5758,    0.    ],
       [2653.837 , 1394.3352,    0.    ],
       [2652.175 , 1396.0946,    0.    ],
       [2650.5134, 1397.8541,    0.    ],
       [2648.8518, 1399.6135,    0.    ],
       [2647.1902, 1401.3729,    0.    ],
       [2663.0964, 1402.0398,    0.    ],
       [2664.5662, 1403.39  ,    0.    ],
       [2666.036 , 1404.7402,    0.    ],
       [2667.5056, 1406.0905,    0.    ],
       [2668.9756, 1407.4407,    0.    ],
       [2670.4453, 1408.7909,    0.    ],
       [2671.915 , 1410.1411,    0.    ],
       [2673.385 , 1411.4913,    0.    ],
       [2674.8547, 1412.8416,    0.    ],
       [2646.5361, 1386.7582,    0.    ],
       [2648.4224, 1388.4996,    0.    ],
       [2650.3086, 1390.241 ,    0.    ],
       [2652.195 , 1391.9824,    0.    ],
       [2654.0813, 1393.7239,    0.    ],
       [26

{'city': 'PIT', 'lane': array([[2416.831  ,  928.1459 ,    0.     ],
       [2414.8352 ,  926.2918 ,    0.     ],
       [2412.8394 ,  924.4377 ,    0.     ],
       [2410.8423 ,  922.585  ,    0.     ],
       [2408.843  ,  920.7346 ,    0.     ],
       [2406.838  ,  918.89056,    0.     ],
       [2404.8352 ,  917.04395,    0.     ],
       [2402.8423 ,  915.18665,    0.     ],
       [2400.8496 ,  913.3294 ,    0.     ],
       [2485.3618 ,  991.58344,    0.     ],
       [2477.044  ,  983.8863 ,    0.     ],
       [2468.7212 ,  976.1946 ,    0.     ],
       [2460.3918 ,  968.5099 ,    0.     ],
       [2452.0757 ,  960.8109 ,    0.     ],
       [2443.7559 ,  953.11584,    0.     ],
       [2435.4333 ,  945.4239 ,    0.     ],
       [2427.0974 ,  937.7463 ,    0.     ],
       [2418.827  ,  930.     ,    0.     ],
       [2399.6138 ,  912.2013 ,    0.     ],
       [2398.3618 ,  911.09106,    0.     ],
       [2397.1096 ,  909.9808 ,    0.     ],
       [2395.8577 ,  908.8705 ,

{'city': 'MIA', 'lane': array([[ 559.5046 , 4004.9236 ,    0.     ],
       [ 550.3973 , 4004.625  ,    0.     ],
       [ 541.29004, 4004.3262 ,    0.     ],
       ...,
       [ 378.2971 , 3993.984  ,    0.     ],
       [ 376.6509 , 3993.9358 ,    0.     ],
       [ 375.00473, 3993.888  ,    0.     ]], dtype=float32), 'lane_norm': array([[-9.107284  , -0.29876003,  0.        ],
       [-9.107284  , -0.29876003,  0.        ],
       [-9.107284  , -0.29876003,  0.        ],
       ...,
       [-1.6461741 , -0.04798343,  0.        ],
       [-1.6461741 , -0.04798343,  0.        ],
       [-1.6461741 , -0.04798343,  0.        ]], dtype=float32), 'scene_idx': 100112, 'agent_id': '00000000-0000-0000-0000-000000044738', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
  

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



{'city': 'MIA', 'lane': array([[ 211.30756, 2343.6929 ,    0.     ],
       [ 211.1441 , 2343.6846 ,    0.     ],
       [ 210.98067, 2343.6765 ,    0.     ],
       ...,
       [ 154.0243 , 2325.284  ,    0.     ],
       [ 154.014  , 2322.5234 ,    0.     ],
       [ 154.00371, 2319.7627 ,    0.     ]], dtype=float32), 'lane_norm': array([[-0.1634445 , -0.00819138,  0.        ],
       [-0.1634445 , -0.00819138,  0.        ],
       [-0.1634445 , -0.00819138,  0.        ],
       ...,
       [-0.01029973, -2.7605793 ,  0.        ],
       [-0.01029973, -2.7605793 ,  0.        ],
       [-0.01029973, -2.7605793 ,  0.        ]], dtype=float32), 'scene_idx': 100198, 'agent_id': '00000000-0000-0000-0000-000000017149', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
  

{'city': 'MIA', 'lane': array([[ 406.76813, 1659.3225 ,    0.     ],
       [ 408.54257, 1659.4855 ,    0.     ],
       [ 410.1956 , 1660.1122 ,    0.     ],
       ...,
       [ 401.6829 , 1655.7319 ,    0.     ],
       [ 403.33102, 1655.7976 ,    0.     ],
       [ 404.97946, 1655.8195 ,    0.     ]], dtype=float32), 'lane_norm': array([[1.783487  , 0.09639139, 0.        ],
       [1.7744435 , 0.16296893, 0.        ],
       [1.653029  , 0.6266    , 0.        ],
       ...,
       [1.6481392 , 0.06563015, 0.        ],
       [1.6481392 , 0.06563015, 0.        ],
       [1.6484305 , 0.02194648, 0.        ]], dtype=float32), 'scene_idx': 100213, 'agent_id': '00000000-0000-0000-0000-000000062949', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       

{'city': 'MIA', 'lane': array([[ 584.4091 , 2105.4539 ,    0.     ],
       [ 586.358  , 2105.5125 ,    0.     ],
       [ 588.3069 , 2105.571  ,    0.     ],
       [ 590.25574, 2105.6296 ,    0.     ],
       [ 592.2046 , 2105.6882 ,    0.     ],
       [ 594.15344, 2105.7466 ,    0.     ],
       [ 596.1023 , 2105.8052 ,    0.     ],
       [ 598.05115, 2105.8638 ,    0.     ],
       [ 600.     , 2105.9224 ,    0.     ],
       [ 567.7272 , 2120.3804 ,    0.     ],
       [ 567.85034, 2117.8613 ,    0.     ],
       [ 568.287  , 2115.3816 ,    0.     ],
       [ 568.92523, 2112.9521 ,    0.     ],
       [ 570.0415 , 2110.7024 ,    0.     ],
       [ 571.61426, 2108.7324 ,    0.     ],
       [ 573.57416, 2107.153  ,    0.     ],
       [ 575.8246 , 2106.0276 ,    0.     ],
       [ 578.2311 , 2105.2734 ,    0.     ],
       [ 581.5442 , 2116.3857 ,    0.     ],
       [ 581.0305 , 2116.376  ,    0.     ],
       [ 580.5167 , 2116.367  ,    0.     ],
       [ 580.003  , 2116.358  ,

{'city': 'PIT', 'lane': array([[2551.5457, 1220.915 ,    0.    ],
       [2552.3306, 1221.6533,    0.    ],
       [2553.1523, 1222.3373,    0.    ],
       [2553.91  , 1223.1003,    0.    ],
       [2554.5176, 1223.9644,    0.    ],
       [2554.7034, 1225.0186,    0.    ],
       [2554.6208, 1226.0828,    0.    ],
       [2554.2073, 1227.0647,    0.    ],
       [2553.548 , 1227.9144,    0.    ],
       [2558.9697, 1222.3204,    0.    ],
       [2558.292 , 1223.0197,    0.    ],
       [2557.6143, 1223.7189,    0.    ],
       [2556.9365, 1224.4182,    0.    ],
       [2556.2588, 1225.1174,    0.    ],
       [2555.581 , 1225.8167,    0.    ],
       [2554.9033, 1226.516 ,    0.    ],
       [2554.2258, 1227.2152,    0.    ],
       [2553.548 , 1227.9144,    0.    ],
       [2622.2808, 1155.2864,    0.    ],
       [2621.7412, 1155.8472,    0.    ],
       [2621.2017, 1156.4081,    0.    ],
       [2620.662 , 1156.969 ,    0.    ],
       [2620.1226, 1157.5298,    0.    ],
       [26

{'city': 'MIA', 'lane': array([[ 110.606766, 3265.9146  ,    0.      ],
       [ 109.86044 , 3266.2178  ,    0.      ],
       [ 109.1218  , 3266.5544  ,    0.      ],
       ...,
       [ 126.76297 , 3327.2346  ,    0.      ],
       [ 126.76604 , 3328.6174  ,    0.      ],
       [ 126.769104, 3330.      ,    0.      ]], dtype=float32), 'lane_norm': array([[-0.81068623,  0.00518932,  0.        ],
       [-0.7463248 ,  0.30334026,  0.        ],
       [-0.73863584,  0.33658883,  0.        ],
       ...,
       [ 0.0030659 ,  1.3826602 ,  0.        ],
       [ 0.0030659 ,  1.3826602 ,  0.        ],
       [ 0.0030659 ,  1.3826602 ,  0.        ]], dtype=float32), 'scene_idx': 100246, 'agent_id': '00000000-0000-0000-0000-000000046296', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.

{'city': 'MIA', 'lane': array([[ 106.74009, 2331.1863 ,    0.     ],
       [ 105.05172, 2331.1187 ,    0.     ],
       [ 103.36335, 2331.0508 ,    0.     ],
       ...,
       [ 154.0243 , 2325.284  ,    0.     ],
       [ 154.014  , 2322.5234 ,    0.     ],
       [ 154.00371, 2319.7627 ,    0.     ]], dtype=float32), 'lane_norm': array([[-1.6883711 , -0.0678477 ,  0.        ],
       [-1.6883711 , -0.0678477 ,  0.        ],
       [-1.6883711 , -0.0678477 ,  0.        ],
       ...,
       [-0.01029973, -2.7605793 ,  0.        ],
       [-0.01029973, -2.7605793 ,  0.        ],
       [-0.01029973, -2.7605793 ,  0.        ]], dtype=float32), 'scene_idx': 100257, 'agent_id': '00000000-0000-0000-0000-000000028592', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
  

{'city': 'MIA', 'lane': array([[ 554.06866, 2350.8757 ,    0.     ],
       [ 555.2835 , 2350.7217 ,    0.     ],
       [ 556.4623 , 2350.4028 ,    0.     ],
       ...,
       [ 553.5297 , 2397.1045 ,    0.     ],
       [ 553.8585 , 2397.1177 ,    0.     ],
       [ 554.1873 , 2397.1216 ,    0.     ]], dtype=float32), 'lane_norm': array([[ 1.2221984 , -0.12197297,  0.        ],
       [ 1.2148497 , -0.1538906 ,  0.        ],
       [ 1.1788079 , -0.31897703,  0.        ],
       ...,
       [ 0.328852  ,  0.01710322,  0.        ],
       [ 0.3288244 ,  0.01318312,  0.        ],
       [ 0.32875937,  0.00393784,  0.        ]], dtype=float32), 'scene_idx': 100271, 'agent_id': '00000000-0000-0000-0000-000000096067', 'car_mask': array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
  

{'city': 'MIA', 'lane': array([[ 573.872  , 1888.8588 ,    0.     ],
       [ 573.9139 , 1887.7175 ,    0.     ],
       [ 573.95575, 1886.5762 ,    0.     ],
       [ 573.9976 , 1885.4349 ,    0.     ],
       [ 574.0395 , 1884.2937 ,    0.     ],
       [ 574.0814 , 1883.1525 ,    0.     ],
       [ 574.1233 , 1882.0112 ,    0.     ],
       [ 574.16516, 1880.87   ,    0.     ],
       [ 574.20703, 1879.7286 ,    0.     ],
       [ 562.03516, 1910.2266 ,    0.     ],
       [ 563.7824 , 1910.2301 ,    0.     ],
       [ 565.51733, 1910.0596 ,    0.     ],
       [ 567.079  , 1909.314  ,    0.     ],
       [ 568.31726, 1908.0927 ,    0.     ],
       [ 569.24866, 1906.6204 ,    0.     ],
       [ 569.8999 , 1905.004  ,    0.     ],
       [ 570.2895 , 1903.3058 ,    0.     ],
       [ 570.42786, 1901.5675 ,    0.     ],
       [ 577.0103 , 1891.283  ,    0.     ],
       [ 576.96655, 1892.5659 ,    0.     ],
       [ 576.9228 , 1893.8488 ,    0.     ],
       [ 576.879  , 1895.1317 ,





KeyboardInterrupt: 

In [10]:
# torch.save does not create missing parent directories.
os.makedirs('./models', exist_ok=True)
torch.save(model, './models/3epoch-RNN-Encoder-Decoder-Attention-512-batch-no-teach.pt')

  "type " + obj.__name__ + ". It won't be checked "


In [None]:
# NOTE(review): this file name differs from the one written by the save cell
# ('3epoch-RNN-Encoder-Decoder-Attention-512-batch-no-teach.pt'); a fresh
# Restart & Run All will not produce this file. Confirm which checkpoint is intended.
# torch.load unpickles arbitrary objects, so only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load('./models/3epoch-RNN-Encoder-Decoder-Attention.pt', map_location=device).to(device)
model.eval()

In [None]:
import pandas as pd

# Submission output
writeCSV = True
val_path = "./new_val_in/new_val_in"


def _agent_row(scene_track_ids, agent_id, step):
    """Return the agent-slot index whose track id at ``step`` equals ``agent_id``.

    Raises:
        ValueError: if no slot matches, instead of scanning past the end.
    """
    for row, track in enumerate(scene_track_ids):
        if track[step][0] == agent_id:
            return row
    raise ValueError(f"agent {agent_id!r} not found at step {step}")


if writeCSV:
    dataset = ArgoverseDataset(data_path=val_path)
    test_loader = DataLoader(dataset, batch_size=64, shuffle=False, collate_fn=test_collate, num_workers=0)

    model.eval()
    rows = []
    with torch.no_grad():
        for inp, scene_ids, track_ids, agent_ids in tqdm(test_loader):
            x = inp.float().to(device)
            # Use the real batch size: the loader yields 64 scenes (fewer in the last
            # batch), not the training batch_sz.
            n_scenes = x.size(0)
            # (n_scenes, 60, 30, 4). Move to CPU once instead of one .item() per element on the GPU.
            y_pred = model(x, None, False).reshape(n_scenes, 60, 30, 4).cpu()

            for k in range(n_scenes):
                row = [scene_ids[k].item()]
                for step in range(30):
                    slot = _agent_row(track_ids[k], agent_ids[k], step)
                    # Only the first two channels (positions) are submitted.
                    row.append(str(y_pred[k, slot, step, 0].item()))
                    row.append(str(y_pred[k, slot, step, 1].item()))
                rows.append(row)

    # 'ID' followed by v1..v60: 30 predicted steps x 2 position channels.
    columns = ['ID'] + [f'v{n}' for n in range(1, 61)]
    df = pd.DataFrame(rows, columns=columns)
    print(df)
    df.to_csv('submission.csv', index=False)
                
                
                