In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
from tqdm.notebook import tqdm, trange

"""Change to the data folder"""
new_path = "./new_train/new_train"

cuda_status = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=new_path)

### Create a loader to enable batch processing

In [3]:
batch_sz = 64

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    out = torch.LongTensor(out)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, out, scene_ids, track_ids, agent_ids]

def test_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    scene_ids = [scene['scene_idx'] for scene in batch]
    track_ids = [scene['track_id'] for scene in batch]
    agent_ids = [scene['agent_id'] for scene in batch]
    inp = torch.LongTensor(inp)
    scene_ids = torch.LongTensor(scene_ids)
    return [inp, scene_ids, track_ids, agent_ids]

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [4]:
model = torch.nn.Sequential(
    torch.nn.Linear(76, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4)
)
model.to(device)
if cuda_status:
    model = model.cuda()

### Visualize the batch of sequences

In [5]:
import matplotlib.pyplot as plt
import random
from tqdm.notebook import tqdm

agent_id = 0
epoch = 3

def show_sample_batch(sample_batch, agent_id):
    """visualize the trajectory for a batch of samples with a randon agent"""
    inp, out, scene_ids, track_ids, agent_ids = sample_batch
    batch_sz = inp.size(0)
    agent_sz = inp.size(1)
    
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()   
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])
        
# Use the nn package to define our loss function
loss_fn=torch.nn.MSELoss()

# Use the optim package to define an Optimizer

learning_rate =1e-3
#optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=0.001)
#optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate, lr_decay=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
iterator = tqdm(val_loader)

for i in trange(epoch):
    
    for i_batch, sample_batch in enumerate(iterator):
        inp, out, scene_ids, track_ids, agent_ids = sample_batch
        """TODO:
          Deep learning model
          training routine
        """
        if i_batch >= 3216:
            #show_sample_batch(sample_batch, agent_id)
            #show_sample_batch([inp, y_pred.cpu().detach(), scene_ids, track_ids, agent_ids], agent_id)
            continue

        x = inp.float()
        y = out.float()

        if cuda_status:
            #model = model.cuda()
            #x = inp.cuda()
            #y = out.cuda()
            x.to(device)
            y.to(device)
            x = x.cuda()
            y = y.cuda()

        y_pred = torch.zeros(30,64,60,4, device=device).cuda()

        # Forward pass: predict y by passing x to the model.   
        for j in range(30):
            #if j < 19:
            #    diff = 19 - j
            #    next_x = torch.cat((x[:,:,j:19,:], y[:,:,0:19 - diff,:]), 2)
            #    next_x = torch.flatten(next_x, start_dim=2)
                # 64 x 60 x 19 x 4
            #    curr_y = model(next_x)
            #else:
            #    next_x = y[:,:,j - 19:j,:]
            #    next_x = torch.flatten(next_x, start_dim=2)
                # 64 x 60 x 19 x 4
            #    curr_y = model(next_x)
            # 64 x 60 x 4
            #y_pred[j] = curr_y
            if j < 19:
                diff = 19 - j

                next_x = torch.cat((x[:,:,j:19,:], y_pred.permute(1,2,0,3)[:,:,0:19 - diff,:]), 2)
                next_x = torch.flatten(next_x, start_dim=2)
                # 64 x 60 x 19 x 4
                curr_y = model(next_x)
            else:
                next_x = y_pred.permute(1,2,0,3)[:,:,j - 19:j,:]
                next_x = torch.flatten(next_x, start_dim=2)
                # 64 x 60 x 19 x 4
                curr_y = model(next_x)
                # 64 x 60 x 4
            y_pred[j] = curr_y
            
        #y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 2]))
        y_pred = y_pred.permute(1,2,0,3)

        # Compute the loss.
        loss = loss_fn(y_pred, y)

        # Before backward pass, zero outgradients to clear buffers  
        optimizer.zero_grad()

        # Backward pass: compute gradient w.r.t modelparameters
        loss.backward()

        # makes an gradient descent step to update its parameters
        optimizer.step()
        
        print(torch.sqrt(loss).item(), end='\r')


HBox(children=(FloatProgress(value=0.0, max=3218.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

3.3875164985656747
1.7350105047225952


In [6]:
torch.save(model, './models/3epochlinearincrementalnoteachingadamextralayer.pt')

In [None]:
model = torch.load('./models/6epochmodel.pt')
model.eval()
model.to(device)
if cuda_status:
    model = model.cuda()

In [7]:
import pandas as pd

# Submission output
writeCSV = True
val_path = "./new_val_in/new_val_in"

if writeCSV:
    
    dataset = ArgoverseDataset(data_path=val_path)
    test_loader = DataLoader(dataset,batch_size=batch_sz, shuffle = False, collate_fn=test_collate, num_workers=0)
    
    data = []
    
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(tqdm(test_loader)):
            inp, scene_ids, track_ids, agent_ids = sample_batch

            if cuda_status:
                model = model.cuda()
                x = inp.float().to(device).cuda()
            else:
                x = inp.float()

            y_pred = torch.zeros(30,64,60,4).to(device).cuda()
            
            # Forward pass: predict y by passing x to the model.   
            for j in range(30):
                if j < 19:
                    diff = 19 - j

                    next_x = torch.cat((x[:,:,j:19,:], y_pred.permute(1,2,0,3)[:,:,0:19 - diff,:]), 2)
                    next_x = torch.flatten(next_x, start_dim=2)
                    # 64 x 60 x 19 x 4
                    curr_y = model(next_x)
                else:
                    next_x = y_pred.permute(1,2,0,3)[:,:,j - 19:j,:]
                    next_x = torch.flatten(next_x, start_dim=2)
                    # 64 x 60 x 19 x 4
                    curr_y = model(next_x)
                # 64 x 60 x 4
                y_pred[j] = curr_y
            #y_pred = torch.reshape(y_pred, torch.Size([batch_sz, 60, 30, 2]))
            y_pred = y_pred.permute(1,2,0,3)
            
            for i in range(batch_sz):
                row = []
                row.append(scene_ids[i].item())
                curr = y_pred[i]
                
                agent_id = agent_ids[i]
                
                for j in range(30):
                    vehicle_index = 0
                    found = False
                    while not found:
                        if track_ids[i][vehicle_index][j][0] == agent_id:
                            found = True
                        else:
                            vehicle_index += 1

                    row.append(str(curr[vehicle_index][j][0].item()))
                    row.append(str(curr[vehicle_index][j][1].item()))
                    
                data.append(row)

    df = pd.DataFrame(data, columns = ['ID','v1','v2','v3','v4','v5','v6','v7','v8','v9','v10','v11','v12','v13','v14','v15','v16','v17','v18','v19','v20','v21','v22','v23','v24','v25','v26','v27','v28','v29','v30','v31','v32','v33','v34','v35','v36','v37','v38','v39','v40','v41','v42','v43','v44','v45','v46','v47','v48','v49','v50','v51','v52','v53','v54','v55','v56','v57','v58','v59','v60'])
    print(df)
    df.to_csv('submission.csv', index=False)
                
                
                

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


         ID                  v1                  v2                  v3  \
0     10002  1644.1842041015625  458.61651611328125     1643.9541015625   
1     10015   712.4559936523438    1234.95654296875         716.8359375   
2     10019   572.0779418945312   1245.367431640625     572.10595703125   
3     10028   1615.775146484375   445.6374206542969  1614.7332763671875   
4      1003      2111.544921875   691.9639892578125    2108.16259765625   
...     ...                 ...                 ...                 ...   
3195   9897  253.69850158691406    807.767822265625  254.44679260253906   
3196     99   582.8790283203125        1157.9140625     582.59130859375   
3197   9905   1713.667724609375   515.9618530273438   1711.799072265625   
3198   9910   574.7551879882812  1289.3184814453125       574.634765625   
3199   9918      580.0478515625   1167.566650390625     580.77197265625   

                      v4                  v5                  v6  \
0      459.5924987792969  1644