## Install the package dependencies before running this notebook

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import matplotlib.pyplot as plt

"""
    number of trajectories in each city
    # austin --  train: 43041 test: 6325 
    # miami -- train: 55029 test:7971
    # pittsburgh -- train: 43544 test: 6361
    # dearborn -- train: 24465 test: 3671
    # washington-dc -- train: 25744 test: 3829
    # palo-alto -- train:  11993 test:1686

    trajectories sampled at 10HZ rate, input 5 seconds, output 6 seconds
    
"""

'\n    number of trajectories in each city\n    # austin --  train: 43041 test: 6325 \n    # miami -- train: 55029 test:7971\n    # pittsburgh -- train: 43544 test: 6361\n    # dearborn -- train: 24465 test: 3671\n    # washington-dc -- train: 25744 test: 3829\n    # palo-alto -- train:  11993 test:1686\n\n    trajectories sampled at 10HZ rate, input 5 seconds, output 6 seconds\n    \n'

## Create a Torch.Dataset class for the training dataset

In [2]:
from glob import glob
import pickle
import numpy as np

ROOT_PATH = "./"

cities = ["austin", "miami", "pittsburgh", "dearborn", "washington-dc", "palo-alto"]
splits = ["train", "test"]

def get_city_trajectories(city="palo-alto", split="train", normalized=False):

    
    outputs = None
    
    if split=="train":
        f_in = ROOT_PATH + split + "/" + city + "_inputs"
        inputs = pickle.load(open(f_in, "rb"))
        n = len(inputs)
        inputs = np.asarray(inputs)[:int(n * 0.8)]
        
        f_out = ROOT_PATH + split + "/" + city + "_outputs"
        outputs = pickle.load(open(f_out, "rb"))
        outputs = np.asarray(outputs)[:int(n * 0.8)]
        
    elif split == 'val':
        f_in = ROOT_PATH + 'train' + "/" + city + "_inputs"
        inputs = pickle.load(open(f_in, "rb"))
        n = len(inputs)
        inputs = np.asarray(inputs)[int(n * 0.8):]
        
        f_out = ROOT_PATH + 'train' + "/" + city + "_outputs"
        outputs = pickle.load(open(f_out, "rb"))
        outputs = np.asarray(outputs)[int(n * 0.8):]
    
    else:
        f_in = ROOT_PATH + split + "/" + city + "_inputs"
        inputs = pickle.load(open(f_in, "rb"))
        n = len(inputs)
        inputs = np.asarray(inputs)

    return inputs, outputs

class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, city: str, split:str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.transform = transform

        self.inputs, self.outputs = get_city_trajectories(city=city, split=split, normalized=False)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):

        data = (self.inputs[idx], self.outputs[idx])
            
        if self.transform:
            data = self.transform(data)

        return data

## Create a DataLoader class for training

In [3]:
torch.cuda.current_device()

0

In [4]:
!nvidia-smi

Fri May 20 18:09:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.81                 Driver Version: 460.67                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce RTX 208...  Off  | 00000000:B1:00.0 Off |                  N/A |
|  0%   31C    P8    10W / 250W |   1436MiB / 11019MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
Internal

In [5]:
from torch import nn, optim

class Pred(nn.Module):

    def __init__(self):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(100, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        
        
        self.decoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 120)
        )
        
    def forward(self, x):
        x = x.reshape(-1, 100).float()
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.reshape(-1, 60, 2)
        return x

In [6]:
def train(pred, opt, train_dataset, train_loader, val_dataset, val_loader):
    device = torch.device('cuda:0')
    pred = pred.to(device)
    train_losses = []
    val_losses = []
    
    early_stop_counter = 0
    early_stop_criteria = 25
    dummy = 1
    for epoch in range(120):

        total_loss = 0
        for i_batch, sample_batch in enumerate(train_loader):
            inp, out = sample_batch
            out = out.to(device)
            inp = inp.to(device)
            preds = pred(inp)
            loss = ((preds - out) ** 2).sum()

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item()

        val_loss = 0
        for i_batch, sample_batch in enumerate(val_loader):
            inp, out = sample_batch
            out = out.to(device)
            inp = inp.to(device)
            preds = pred(inp)
            #print(preds)
            loss = ((preds - out) ** 2).sum()
            val_loss += loss.item()

        train_loss = np.log(total_loss / len(train_dataset))
        val_loss = np.log(val_loss / len(val_dataset))

        if dummy == 1:
            dummy -= 1
        else: 
            last_valid = val_losses[-1]
            if last_valid < val_loss:
                early_stop_counter += 1
            else:
                pickle.dump(pred, open('models/ta_model_baseline_' + city + '_large', 'wb'))

            if early_stop_counter == early_stop_criteria:
                break
            
        print('epoch {} train_loss: {} val_loss: {}'.format(epoch, train_loss, val_loss))
        train_losses.append(train_loss)
        val_losses.append(val_loss)


    plt.title("Loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.plot(train_losses, color ="red", label = "train_loss")
    plt.plot(val_losses, color ="blue", label = "val_loss")
    plt.legend()
    plt.show()

In [7]:
%%time
import pickle
#train city models
for city in cities:
    print('city: ' + city)
    batch_sz = 128  # batch size
    train_dataset  = ArgoverseDataset(city = city, split = 'train')
    train_loader = DataLoader(train_dataset,batch_size=batch_sz)
    val_dataset = ArgoverseDataset(city = city, split = 'val')
    val_loader = DataLoader(val_dataset,batch_size=batch_sz)
    
    pred = Pred()
    opt = optim.Adam(pred.parameters(), lr=1e-4)
    train(pred, opt, train_dataset, train_loader, val_dataset, val_loader)
#     pickle.dump(pred, open('models/ta_model_baseline_' + city + '_large', 'wb'))

city: austin


IndexError: list index out of range

In [8]:
import pandas as pd
cols = np.array(['v' + str(i) for i in range(120)])
all_preds = []
for city in cities:
    load_pred = pickle.load(open('models/ta_model_baseline_' + city + '_large', 'rb'))
    test_dataset = get_city_trajectories(city = city, split = 'test')
    device = torch.device('cuda:0')
    load_pred = load_pred.to(device)
    preds = load_pred(torch.from_numpy(test_dataset[0]).to(device))
    preds_reshaped = preds.reshape(preds.size()[0], 120)
    preds_numpy = preds_reshaped.cpu().detach().numpy()
    ids = np.array([str(i) + '_' + city for i in range(len(preds_numpy))])
    predictions = pd.DataFrame(preds_numpy, columns=cols)
    predictions.insert(0, 'ID', ids)
    all_preds.append(predictions)
    
all_predictions = pd.concat(all_preds, ignore_index = True)

In [9]:
all_predictions

Unnamed: 0,ID,v0,v1,v2,v3,v4,v5,v6,v7,v8,...,v110,v111,v112,v113,v114,v115,v116,v117,v118,v119
0,0_austin,-28.801285,-560.009399,-27.892868,-561.753113,-31.578941,-558.108459,-31.028038,-562.863953,-27.999935,...,-32.094719,-560.802612,-31.520973,-560.729370,-30.207336,-561.039795,-30.379959,-562.401001,-33.874695,-558.281067
1,1_austin,-351.791077,-16.518909,-350.631744,-11.963470,-349.960754,-16.469337,-350.095886,-11.953239,-348.949127,...,-349.922546,-18.126869,-346.669312,-17.889891,-345.552368,-15.173858,-344.196350,-16.205135,-345.413177,-15.435160
2,2_austin,52.436302,-248.489532,52.140957,-250.146164,52.351112,-249.089142,51.662369,-249.291809,51.956688,...,52.267677,-249.162445,52.596104,-249.651611,52.855629,-249.619232,52.745193,-249.201508,52.498756,-249.133041
3,3_austin,-109.968552,1789.038696,-109.238091,1790.392700,-106.411926,1787.224121,-105.454437,1787.998535,-104.229691,...,-94.058929,1785.341187,-93.369354,1781.733032,-97.497498,1784.602173,-94.718445,1788.680664,-94.420776,1785.015503
4,4_austin,1218.504883,-654.344177,1218.668457,-650.626709,1222.852783,-655.712952,1216.936890,-651.783203,1222.586182,...,1237.391113,-656.272949,1232.532959,-658.153809,1235.907471,-659.775391,1233.119873,-651.960144,1236.646729,-656.925293
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29838,1681_palo-alto,-1376.467407,-464.125031,-1384.580322,-458.017181,-1384.734375,-462.706421,-1379.104614,-462.642120,-1379.713989,...,-1373.551147,-460.315674,-1370.754395,-462.993774,-1377.638794,-465.140717,-1377.243774,-468.786407,-1376.171143,-456.997375
29839,1682_palo-alto,128.995255,-35.810993,128.545120,-35.341633,128.960480,-35.607285,129.079666,-37.788258,130.129791,...,128.674194,-37.168591,130.155670,-35.534267,129.042145,-35.315475,130.105423,-36.422218,129.495880,-35.221157
29840,1683_palo-alto,-1447.390747,2154.197266,-1450.735596,2152.534668,-1446.267334,2149.142822,-1445.245117,2154.157959,-1448.742432,...,-1445.182739,2152.987549,-1444.835815,2154.348633,-1443.962036,2155.707520,-1442.448730,2156.274658,-1442.447266,2153.673828
29841,1684_palo-alto,1054.554443,1372.162842,1057.992676,1371.576294,1053.897827,1374.461670,1056.356689,1369.133179,1052.059448,...,1050.925171,1368.172607,1052.122925,1377.685547,1056.445801,1371.140869,1056.937622,1370.141113,1051.179932,1367.896973


In [10]:
all_predictions.to_csv('out.csv', index=False)

In [11]:
pd.read_csv('out.csv')

Unnamed: 0,ID,v0,v1,v2,v3,v4,v5,v6,v7,v8,...,v110,v111,v112,v113,v114,v115,v116,v117,v118,v119
0,0_austin,-28.801285,-560.009400,-27.892868,-561.753100,-31.578941,-558.108460,-31.028038,-562.863950,-27.999935,...,-32.094720,-560.80260,-31.520973,-560.729400,-30.207336,-561.039800,-30.379960,-562.401000,-33.874695,-558.281070
1,1_austin,-351.791080,-16.518910,-350.631740,-11.963470,-349.960750,-16.469337,-350.095900,-11.953239,-348.949130,...,-349.922550,-18.12687,-346.669300,-17.889890,-345.552370,-15.173858,-344.196350,-16.205135,-345.413180,-15.435160
2,2_austin,52.436302,-248.489530,52.140957,-250.146160,52.351112,-249.089140,51.662370,-249.291810,51.956688,...,52.267677,-249.16245,52.596104,-249.651610,52.855630,-249.619230,52.745193,-249.201500,52.498756,-249.133040
3,3_austin,-109.968550,1789.038700,-109.238090,1790.392700,-106.411930,1787.224100,-105.454440,1787.998500,-104.229690,...,-94.058930,1785.34120,-93.369354,1781.733000,-97.497500,1784.602200,-94.718445,1788.680700,-94.420780,1785.015500
4,4_austin,1218.504900,-654.344200,1218.668500,-650.626700,1222.852800,-655.712950,1216.936900,-651.783200,1222.586200,...,1237.391100,-656.27295,1232.533000,-658.153800,1235.907500,-659.775400,1233.119900,-651.960140,1236.646700,-656.925300
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29838,1681_palo-alto,-1376.467400,-464.125030,-1384.580300,-458.017180,-1384.734400,-462.706420,-1379.104600,-462.642120,-1379.714000,...,-1373.551100,-460.31567,-1370.754400,-462.993770,-1377.638800,-465.140720,-1377.243800,-468.786400,-1376.171100,-456.997380
29839,1682_palo-alto,128.995250,-35.810993,128.545120,-35.341633,128.960480,-35.607285,129.079670,-37.788258,130.129790,...,128.674200,-37.16859,130.155670,-35.534267,129.042140,-35.315475,130.105420,-36.422220,129.495880,-35.221157
29840,1683_palo-alto,-1447.390700,2154.197300,-1450.735600,2152.534700,-1446.267300,2149.142800,-1445.245100,2154.158000,-1448.742400,...,-1445.182700,2152.98750,-1444.835800,2154.348600,-1443.962000,2155.707500,-1442.448700,2156.274700,-1442.447300,2153.673800
29841,1684_palo-alto,1054.554400,1372.162800,1057.992700,1371.576300,1053.897800,1374.461700,1056.356700,1369.133200,1052.059400,...,1050.925200,1368.17260,1052.122900,1377.685500,1056.445800,1371.140900,1056.937600,1370.141100,1051.179900,1367.897000
