In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob

# Dataset locations, relative to the notebook's working directory.
# NOTE: `new_path` is the training split; `new_test` is the validation-input split.
new_path = "./new_train/new_train"
new_test = "./new_val_in/new_val_in"

# Sequences per split -- train: 205942, val: 3200, test: 36272.
# Sequences are sampled at 10 Hz.
# sequences sampled at 10HZ rate

In [2]:
class ArgoverseDataset(Dataset):
    """Map-style dataset over a directory of pickled Argoverse scenes.

    Every entry matched by ``<data_path>/*`` is one sample. Entries are kept
    in sorted path order, so a given index always maps to the same file.

    Args:
        data_path: directory holding one pickle file per scene.
        transform: optional callable applied to each loaded sample (skipped
            when falsy).
    """

    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.pkl_list = sorted(glob(os.path.join(self.data_path, '*')))

    def __len__(self):
        """Number of scene files found under ``data_path``."""
        return len(self.pkl_list)

    def __getitem__(self, idx):
        """Unpickle the ``idx``-th scene file and apply ``transform`` if set."""
        # SECURITY: pickle.load can execute arbitrary code -- only point this
        # dataset at trusted files.
        with open(self.pkl_list[idx], 'rb') as handle:
            sample = pickle.load(handle)
        return self.transform(sample) if self.transform else sample


# Initialize the datasets. Despite its name, `val_dataset` reads the training
# split (`new_path`); `val_testset` reads the validation-input split (`new_test`).
val_dataset = ArgoverseDataset(new_path)
val_testset = ArgoverseDataset(new_test)

In [3]:
# Mini-batch sizes: training loader, and the one-scene-at-a-time test loader.
batch_sz, batch_sz_test = 128, 1
def my_collate(batch):
    """Collate scene dicts into ``[inp, out]`` float32 tensors.

    For each scene, positions and velocities are concatenated along the last
    axis with ``numpy.dstack`` (positions first). The per-scene arrays are then
    stacked along a new leading batch axis, giving
    ``[batch_sz x agent_sz x seq_len x feature]``.

    Args:
        batch: list of dicts with keys ``'p_in'``, ``'v_in'``, ``'p_out'``,
            ``'v_out'``. Within ``*_in`` (and within ``*_out``) the arrays must
            have matching shapes across scenes.

    Returns:
        ``[inp, out]``: two ``torch.float32`` tensors.
    """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]

    # Build one contiguous float32 array before handing it to torch.
    # torch.FloatTensor(list_of_ndarrays) falls back to a very slow
    # element-by-element conversion.
    inp = torch.from_numpy(numpy.asarray(inp, dtype=numpy.float32))
    out = torch.from_numpy(numpy.asarray(out, dtype=numpy.float32))
    return [inp, out]

# Shuffled loader over `val_dataset` (which reads the split at `new_path`).
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_sz,
    shuffle=True,
    collate_fn=my_collate,
    num_workers=1,
)

In [4]:
def my_collate1(batch):
    """Collate test-time scene dicts into ``(scene_idx, inp)``.

    Args:
        batch: list of dicts with keys ``'p_in'``, ``'v_in'`` and
            ``'scene_idx'``.

    Returns:
        scene_idx: list with one ``numpy.dstack([scene['scene_idx']])`` array
            per scene. For a scalar id this is a ``(1, 1, 1)`` array, which is
            why callers index it as ``[b][0][0][0]``.
        inp: ``torch.float32`` tensor ``[batch_sz x agent_sz x seq_len x feature]``.
            Positions and velocities are concatenated on the last axis,
            positions first.
    """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_idx = [numpy.dstack([scene['scene_idx']]) for scene in batch]
    # One contiguous float32 array first; torch.FloatTensor on a list of
    # ndarrays takes a very slow per-element path.
    inp = torch.from_numpy(numpy.asarray(inp, dtype=numpy.float32))
    return scene_idx, inp

# Unshuffled loader over the test scenes, so rows come out in sorted file order.
test_loader = DataLoader(
    val_testset,
    batch_size=batch_sz_test,
    shuffle=False,
    num_workers=1,
    collate_fn=my_collate1,
)

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class FCModule(nn.Module):
    """Fully connected (MLP) regressor from a flattened input window to a flattened prediction.

    Input size is ``num_agents * in_len * num_features``; output size is
    ``num_agents * out_len * num_features``. The defaults reproduce the
    original ``240*19 -> 240*30`` sizes (60 * 4 = 240), matching the
    ``reshape(-1, 60, 30, 4)`` applied to the output elsewhere in this
    notebook.

    Args:
        num_agents: size of the agent axis (default 60).
        in_len: number of input time steps (default 19).
        out_len: number of predicted time steps (default 30).
        num_features: features per agent per step (default 4).
        hidden_sizes: widths of the hidden layers, each followed by ReLU.

    The layers live in ``self.linear1`` as ``Linear, ReLU, ..., Linear``. With
    the default three hidden sizes, the module indices (and therefore the
    state_dict keys) are the same as the original hard-coded version.
    """

    def __init__(self, num_agents=60, in_len=19, out_len=30, num_features=4,
                 hidden_sizes=(8000, 10000, 8000)):
        super(FCModule, self).__init__()

        in_dim = num_agents * in_len * num_features
        out_dim = num_agents * out_len * num_features

        layers = []
        prev = in_dim
        for width in hidden_sizes:
            layers += [nn.Linear(prev, width), nn.ReLU()]
            prev = width
        layers.append(nn.Linear(prev, out_dim))  # no activation: unbounded regression output

        self.linear1 = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, in_dim)`` to ``(batch, out_dim)``."""
        return self.linear1(x)

In [6]:
from tqdm import tqdm_notebook as tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Train the MLP on `val_loader` (despite its name, it iterates the split at `new_path`).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FCModule().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 5
EMA_DECAY = 0.99                   # weight of the previous running value in the logged loss EMA
LOG_EVERY = 10                     # print every N batches
CHECKPOINT_PATH = "usingSequential"

loss_ema = -1  # sentinel: negative until the first batch seeds it
for epoch in range(NUM_EPOCHS):
    model.train()
    for i_batch, (inp, out) in enumerate(val_loader):
        inp = inp.to(device)
        out = out.to(device)
        # Flatten each scene into one vector for the MLP, then restore the
        # (batch, 60, 30, 4) layout so it lines up with `out`.
        y_pred = model(inp.reshape(len(inp), -1)).reshape(-1, 60, 30, 4)
        # RMSE over every element of the target.
        loss = (torch.mean((y_pred - out) ** 2)) ** 0.5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the running loss as a Python float. Blending the `loss` tensor
        # itself would chain a reference to every batch's autograd graph onto
        # `loss_ema`, growing without bound.
        loss_val = loss.item()
        if loss_ema < 0:
            loss_ema = loss_val
        loss_ema = loss_ema * EMA_DECAY + loss_val * (1 - EMA_DECAY)

        if i_batch % LOG_EVERY == 0:
            print('loss full', epoch, i_batch, loss_ema, loss_val)

    # Save the model after every epoch.
    torch.save(model.state_dict(), CHECKPOINT_PATH)

loss full 0 0 561.4251098632812 561.4251098632812
loss full 0 10 548.3597412109375 350.57769775390625
loss full 0 20 523.02880859375 241.37503051757812
loss full 0 30 493.0355224609375 188.54757690429688
loss full 0 40 462.55255126953125 144.61856079101562
loss full 0 50 432.34503173828125 140.5107879638672
loss full 0 60 403.5177307128906 120.67237854003906
loss full 0 70 376.5809020996094 118.95751953125
loss full 0 80 350.5466613769531 99.45633697509766
loss full 0 90 326.8785400390625 105.22321319580078
loss full 0 100 304.3461608886719 79.65991973876953
loss full 0 110 283.2421875 84.23167419433594
loss full 0 120 264.0826110839844 76.66941833496094
loss full 0 130 246.1168212890625 70.75442504882812
loss full 0 140 228.97976684570312 65.83231353759766
loss full 0 150 213.18077087402344 64.01181030273438
loss full 0 160 198.9647979736328 64.25509643554688
loss full 0 170 185.636474609375 60.630271911621094
loss full 0 180 174.03329467773438 79.82838439941406
loss full 0 190 162.65

In [7]:
import csv
import pandas as pd
import numpy as np
import gc

# Drop training-time references (if this kernel ran the training cell) so their
# GPU memory can actually be reclaimed. The optimizer still references the old
# model's parameters and Adam state, so gc/empty_cache alone cannot free them.
for _stale_name in ("optimizer", "model"):
    globals().pop(_stale_name, None)
gc.collect()

torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = FCModule().to(device)
# map_location loads the checkpoint onto `device` regardless of the device it was saved from.
model.load_state_dict(torch.load("usingSequential", map_location=device))

filename = "output_test.csv"

def test(model, device, test_loader):
    model.eval()
    store = [[]]
    strArr = []
    strArr.append('ID')
    for i in range(60):
        strArr.append('v' + str(i + 1))
    

    with torch.no_grad():
        for scene_idx, data in test_loader:
            data = data.to(device)
            y_pred = model(data.reshape(len(data),-1)).reshape(-1,60,30,4)
            y_pred = y_pred[-1,0,:,0:2]
            y_pred = y_pred.cpu()
            y_pred_np = y_pred.numpy()
            y_pred_np = y_pred_np.flatten()
            y_pred_np = np.insert(y_pred_np, 0, scene_idx[0][0][0][0])
            store.append(y_pred_np)

    with open(filename, 'w',  newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(strArr) 
        csvwriter.writerows(store)

# Write predictions for every test scene to `filename`.
test(model=model, device=device, test_loader=test_loader)

In [8]:
import pandas as pd

# Reload the submission and cast the ID column to int32, in case it was written as a float.
file = pd.read_csv('output_test.csv').astype({"ID": 'int32'})
file.to_csv(path_or_buf="./output_test.csv", header=True, index=False)