# Preprocessing for Argoverse Data

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy as np
import pickle
from glob import glob
import pandas as pd
import matplotlib.pyplot as plt

Sampled at a rate of 10 Hz
- Train: 205,942 sequences
- Val: 3,200 sequences
- Test: 36,272 sequences

In [4]:
class ArgoverseDataset(Dataset):
    """Map-style dataset that serves one unpickled file per index.

    Args:
        data_path: Directory whose entries (everything matched by ``*``) are
            treated as pickle files.
        transform: Optional callable applied to each loaded object before it
            is returned.
    """

    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        # Sorted so that a given index always refers to the same file.
        pattern = os.path.join(self.data_path, '*')
        self.pkl_list = sorted(glob(pattern))

    def __len__(self):
        """Number of files found under ``data_path``."""
        return len(self.pkl_list)

    def __getitem__(self, idx):
        """Unpickle the ``idx``-th file and pass it through ``transform`` if one is set."""
        with open(self.pkl_list[idx], "rb") as handle:
            sample = pickle.load(handle)

        return self.transform(sample) if self.transform else sample

In [5]:
# Build the training and validation datasets from their on-disk folders
train_data, val_data = (
    ArgoverseDataset(folder) for folder in ("data/new_train", "data/new_val_in")
)

In [6]:
# Report how many sequences each dataset holds
train_count = len(train_data)
val_count = len(val_data)
print("TRAIN DATA SEQUENCES:", train_count)
print("VAL DATA SEQUENCES:", val_count)

TRAIN DATA SEQUENCES: 205942
VAL DATA SEQUENCES: 3200


## Data Processing

---

Use $(x, y)$ positions and $(x, y)$ velocities for every agent

**Normalized Positions**

- Every trajectory begins at the same place

**Every trajectory begins at position $(0, 0)$; velocity remains the same**

- This solves our problem of "direction"

In [None]:
# List every field of the first training sample with the Python type of its value
first_sample = train_data[0]
for field, payload in first_sample.items():
    print(field, type(payload))

In [None]:
# Show the array shape of each position/velocity field in the first pickled sample
shape_fields = ("p_in", "v_in", "p_out", "v_out")

for field, payload in train_data[0].items():
    if field in shape_fields:
        print(field, payload.shape)

In [None]:
# Print every scalar stored in the first sample's "p_in" array, one per line
for field, payload in train_data[0].items():
    if field != "p_in":
        continue
    for element in np.nditer(payload):
        print(element)

In [None]:
from tqdm import tqdm

# Create Distribution

In [None]:
# Gather component 0 (x) of "p_in" over the first 3200 items of train_data.
# The previous version used the scalars yielded by np.nditer as indices and read an
# undefined `j`, so it could not run; it now follows the same row/step walk as the
# sibling cells (first 60 rows, first 19 steps).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
p_in_x_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    p_in = train_data[d]["p_in"]
    for i in range(60):
        if p_in[i].all():
            p_in_x_parts.append(p_in[i, :19, 0])
p_in_x = np.concatenate(p_in_x_parts)

fig = plt.hist(p_in_x, bins=10)

In [None]:
# Gather component 1 (y) of "p_in" over the first 3200 items of train_data
# (first 60 rows, first 19 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
p_in_y_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    p_in = train_data[d]["p_in"]
    for i in range(60):
        if p_in[i].all():
            p_in_y_parts.append(p_in[i, :19, 1])
p_in_y = np.concatenate(p_in_y_parts)

fig = plt.hist(p_in_y, bins=10)

In [None]:
# Gather component 0 (x) of "v_in" over the first 3200 items of train_data
# (first 60 rows, first 19 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
v_in_x_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    v_in = train_data[d]["v_in"]
    for i in range(60):
        if v_in[i].all():
            v_in_x_parts.append(v_in[i, :19, 0])
v_in_x = np.concatenate(v_in_x_parts)

In [None]:
# Histogram of the collected v_in x-components, using 50 bins
fig = plt.hist(v_in_x, bins=50)

In [None]:
# Gather component 1 (y) of "v_in" over the first 3200 items of train_data
# (first 60 rows, first 19 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
v_in_y_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    v_in = train_data[d]["v_in"]
    for i in range(60):
        if v_in[i].all():
            v_in_y_parts.append(v_in[i, :19, 1])
v_in_y = np.concatenate(v_in_y_parts)

In [None]:
# Histogram of the collected v_in y-components, using 50 bins
fig = plt.hist(v_in_y, bins=50)

In [None]:
# Gather component 0 (x) of "p_out" over the first 3200 items of train_data
# (first 60 rows, first 30 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
p_out_x_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    p_out = train_data[d]["p_out"]
    for i in range(60):
        if p_out[i].all():
            p_out_x_parts.append(p_out[i, :30, 0])
p_out_x = np.concatenate(p_out_x_parts)

fig = plt.hist(p_out_x, bins=10)

In [None]:
# Gather component 1 (y) of "p_out" over the first 3200 items of train_data
# (first 60 rows, first 30 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
p_out_y_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    p_out = train_data[d]["p_out"]
    for i in range(60):
        if p_out[i].all():
            p_out_y_parts.append(p_out[i, :30, 1])
p_out_y = np.concatenate(p_out_y_parts)

fig = plt.hist(p_out_y, bins=10)

In [None]:
# Gather component 0 (x) of "v_out" over the first 3200 items of train_data
# (first 60 rows, first 30 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
v_out_x_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    v_out = train_data[d]["v_out"]
    for i in range(60):
        if v_out[i].all():
            v_out_x_parts.append(v_out[i, :30, 0])
v_out_x = np.concatenate(v_out_x_parts)

In [None]:
# Histogram of the collected v_out x-components, using 50 bins
fig = plt.hist(v_out_x, bins=50)

In [None]:
# Gather component 1 (y) of "v_out" over the first 3200 items of train_data
# (first 60 rows, first 30 steps of each sample).
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
# Slices are collected in a list and joined once: np.append per value would re-copy
# the whole array on every call.
v_out_y_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    v_out = train_data[d]["v_out"]
    for i in range(60):
        if v_out[i].all():
            v_out_y_parts.append(v_out[i, :30, 1])
v_out_y = np.concatenate(v_out_y_parts)

fig = plt.hist(v_out_y, bins=50)

In [None]:
# Histogram of the collected v_out y-components, using 50 bins
fig = plt.hist(v_out_y, bins=50)

In [None]:
# Distribution of the Euclidean magnitude of "v_out" over the first 3200 items of
# train_data (first 60 rows, first 30 steps of each sample).
# The previous version never filled `euclidean`: it appended the y-components to
# `v_out_y` a second time (growing it on every run) and plotted `v_out_y` again.
# NOTE(review): the per-step magnitude sqrt(vx^2 + vy^2) is inferred from the
# variable name — confirm this is the intended quantity.
# A row is kept only when all of its entries are non-zero — presumably this drops
# zero-padded agent slots; confirm.
euclidean_parts = [np.empty([0])]  # float64 seed keeps the result dtype and covers the no-rows case
for d in tqdm(range(3200)):
    v_out = train_data[d]["v_out"]
    for i in range(60):
        if v_out[i].all():
            euclidean_parts.append(np.hypot(v_out[i, :30, 0], v_out[i, :30, 1]))
euclidean = np.concatenate(euclidean_parts)

fig = plt.hist(euclidean, bins=50)

# Compare Submissions via DataFrames

In [6]:
import pandas as pd

In [30]:
# Load the saved submission CSVs that the comparison cells below refer to.
# Each file is read with default options, so the index column stored in the CSV
# comes back as a regular "Unnamed: 0" column (visible in the displayed frames).
good_sub = pd.read_csv("submission1.csv")
curr_sub = pd.read_csv("submission12.csv")
full_sub = pd.read_csv("submission13.csv") # best submission on full dataset
trainval_sub = pd.read_csv("submission15.csv")
fulltrain = pd.read_csv("submission16.csv")
fullzero = pd.read_csv("submission17.csv")
fivelr = pd.read_csv("submission18.csv")
thirty = pd.read_csv("submission19.csv")
two = pd.read_csv("submission21.csv")

In [36]:
# Render the frame loaded from submission1.csv
display(good_sub)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1711.430176,334.143524,1704.404053,328.905273,1708.612061,333.807648,1708.448853,336.607849,1710.478027,...,1738.787598,364.024048,1739.626587,364.808350,1740.437378,365.563416,1741.222656,366.291870,1741.984985,366.996277
1,10015,720.927551,1223.142700,717.426025,1223.441772,718.368286,1223.305176,719.239563,1223.624512,719.530396,...,719.890808,1223.278687,719.839478,1223.287476,719.795593,1223.295166,719.758789,1223.302124,719.728455,1223.308472
2,10019,569.576477,1238.076782,568.307922,1237.721680,568.218445,1237.782959,569.077576,1237.640991,569.584839,...,575.810181,1236.821045,575.851257,1236.820679,575.876160,1236.821655,575.888489,1236.823853,575.891357,1236.826904
3,10028,1689.902222,316.115631,1682.990479,306.500610,1686.288330,309.022888,1685.131958,312.121948,1686.547241,...,1708.481079,335.765656,1708.986572,336.290009,1709.457520,336.780396,1709.896606,337.239563,1710.306763,337.669861
4,1003,2121.974121,674.118225,2119.932129,669.976257,2118.793945,669.619995,2117.451660,668.338989,2115.815918,...,2087.592285,644.536560,2086.603027,643.720459,2085.669434,642.949097,2084.791748,642.223572,2083.970459,641.544861
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,253.582031,804.451721,240.163620,800.650696,243.820496,805.212158,246.211700,805.805908,249.004776,...,265.395691,800.110291,265.635620,800.127563,265.804230,800.150452,265.908630,800.177429,265.957489,800.206787
3196,99,582.598694,1145.300293,579.368530,1145.138062,579.084717,1144.992676,579.837158,1144.249390,580.079346,...,581.317993,1128.047363,581.237549,1127.683472,581.162109,1127.378174,581.093384,1127.127075,581.032837,1126.924194
3197,9905,1756.488525,440.919891,1752.069580,433.045227,1752.697876,434.331085,1750.657104,436.683838,1750.500000,...,1751.608398,449.111023,1751.571411,449.327789,1751.535522,449.524078,1751.500854,449.701599,1751.467651,449.862244
3198,9910,570.892944,1283.924805,568.463867,1283.513428,568.201416,1283.361572,568.721436,1283.258789,568.746704,...,567.780762,1278.472412,567.709839,1278.362671,567.648132,1278.269775,567.595703,1278.192505,567.552368,1278.129395


In [37]:
# Render the frame loaded from submission12.csv
display(curr_sub)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1712.565552,334.922913,1713.395508,335.808716,1714.357422,336.773590,1715.135498,337.561859,1715.944702,...,1733.844482,355.382538,1734.659790,356.119507,1735.452271,356.847473,1736.219482,357.554108,1736.932861,358.231323
1,10015,725.663208,1230.206665,725.930237,1229.985718,726.067261,1229.967285,726.041504,1229.988281,725.986328,...,726.435364,1227.164429,726.449463,1227.014648,726.470093,1226.859131,726.494202,1226.729736,726.457703,1226.633423
2,10019,573.697388,1244.223145,573.734802,1244.008789,573.896301,1243.776855,574.095520,1243.606689,574.267517,...,576.894958,1237.101440,577.021362,1236.753052,577.128540,1236.422363,577.236938,1236.156738,577.375122,1236.010376
3,10028,1690.002808,314.024597,1690.518555,314.460968,1691.150513,315.138184,1691.725708,315.743408,1692.313843,...,1706.048462,328.649994,1706.660400,329.178162,1707.268555,329.706116,1707.844971,330.206116,1708.400879,330.701569
4,1003,2123.318604,676.197876,2122.089844,675.037842,2120.738770,673.815002,2119.332520,672.685669,2118.104004,...,2094.202148,649.903381,2093.056641,648.864929,2091.949951,647.848511,2090.919678,646.899963,2089.990234,646.036560
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,255.502472,805.716064,255.578796,805.245483,255.759033,804.718750,255.957520,804.297363,256.196167,...,260.576813,794.076965,260.746582,793.446106,260.918121,792.839172,261.117889,792.286072,261.333344,791.787231
3196,99,587.280334,1153.836670,587.281616,1153.175781,587.289490,1152.431885,587.337402,1151.776367,587.395020,...,587.914795,1137.415771,587.944641,1136.799805,587.986145,1136.220093,588.030823,1135.694336,588.082275,1135.269775
3197,9905,1755.765625,443.192963,1755.619263,443.376343,1755.487549,443.660767,1755.394043,443.931793,1755.316284,...,1752.897827,449.090424,1752.858032,449.210541,1752.838135,449.292877,1752.815674,449.317261,1752.797852,449.273224
3198,9910,574.902405,1287.998657,574.882202,1287.739380,574.783203,1287.446899,574.697693,1287.157471,574.635193,...,572.281067,1281.311768,572.082275,1280.940796,571.936279,1280.497681,571.881531,1279.978271,571.967834,1279.440918


In [40]:
# Render the frame loaded from submission13.csv (marked as the best submission on the full dataset)
display(full_sub)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1714.468994,336.564941,1715.505981,337.429321,1716.427002,338.324829,1717.402710,339.203125,1718.373657,...,1736.905518,357.823059,1737.725586,358.602692,1738.512207,359.362305,1739.293091,360.129944,1740.043213,360.874756
1,10015,724.592773,1230.037720,724.697205,1230.196167,724.804688,1230.227173,724.755188,1230.180908,724.701111,...,724.185730,1226.987793,724.096252,1226.911865,724.013367,1226.845947,723.944214,1226.791016,723.890991,1226.755371
2,10019,573.297119,1244.542358,573.433411,1244.372437,573.485779,1244.216431,573.611389,1244.145142,573.781738,...,576.296631,1238.967285,576.333313,1238.650391,576.352539,1238.342651,576.348511,1238.055908,576.331238,1237.812256
3,10028,1690.699707,315.181885,1691.547974,315.746338,1692.299805,316.427795,1693.000854,317.050537,1693.678467,...,1707.307007,330.236328,1707.969116,330.840393,1708.616577,331.445099,1709.213989,332.012421,1709.803589,332.592316
4,1003,2122.012939,676.655701,2120.637451,675.567444,2119.242188,674.565552,2117.931885,673.472839,2116.766602,...,2090.379395,649.389160,2089.114014,648.326111,2087.878418,647.276672,2086.710205,646.327942,2085.578857,645.400574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,255.481873,805.780212,255.782181,805.515198,255.979965,805.419556,256.263733,805.283447,256.477509,...,259.894897,793.734741,259.969452,792.922119,260.032593,792.079590,260.100433,791.231934,260.176239,790.406555
3196,99,586.758240,1154.168945,586.871643,1153.793579,586.912598,1153.399292,586.955505,1152.880981,587.042297,...,587.227173,1138.900024,587.242798,1138.268433,587.252747,1137.635254,587.262878,1137.037109,587.279053,1136.480835
3197,9905,1755.018311,443.903717,1754.779175,444.196045,1754.629639,444.581116,1754.523804,444.914307,1754.450562,...,1750.623169,450.667450,1750.534180,450.953827,1750.447876,451.210541,1750.379639,451.446198,1750.348999,451.658783
3198,9910,574.003601,1288.866455,573.976685,1288.832031,573.876282,1288.704956,573.729980,1288.529053,573.583374,...,568.223511,1282.977905,567.843811,1282.740356,567.446350,1282.522583,567.046814,1282.331787,566.700989,1282.104492


In [17]:
# Average absolute gap between good_sub and full_sub:
# element-wise |a - b|, then the mean of each row, then the mean over all rows.
# NOTE(review): the subtraction spans every column of the frames, including the
# "Unnamed: 0" and "ID" columns shown in the displayed tables, so they take part
# in each row mean.
diff_sub = good_sub.sub(full_sub).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
print(avg_diff)


#max_sub = pd.DataFrame.max(diff_sub, axis=0)

4.083687288174864


In [10]:
# Sanity check: comparing full_sub with itself
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(full_sub).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
avg_diff

0.0

In [5]:
# Average absolute gap between full_sub and trainval_sub
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(trainval_sub).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
avg_diff

0.9739181799966781

In [8]:
# Average absolute gap between full_sub and fulltrain
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(fulltrain).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
avg_diff

0.489501657720472

In [10]:
# Average absolute gap between full_sub and fullzero
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(fullzero).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
avg_diff

0.8857038233710116

In [20]:
# Average absolute gap between fivelr and two reference frames, printed in turn:
# first against full_sub, then against good_sub
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(fivelr).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
print(avg_diff)

diff_sub = good_sub.sub(fivelr).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
print(avg_diff)

0.8673327775079696
4.41289933720573


In [24]:
# Average absolute gap between full_sub and thirty
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = full_sub.sub(thirty).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
print(avg_diff)

0.525007336100594


In [31]:
# Average absolute gap between thirty and two
# (element-wise |a - b|, mean per row, then mean over rows)
diff_sub = thirty.sub(two).abs()
diff_sub_means = diff_sub.mean(axis="columns")
avg_diff = diff_sub_means.to_numpy().mean()
print(avg_diff)

0.5296079926412615


In [11]:
# Render the frame loaded from submission13.csv (marked as the best submission on the full dataset)
display(full_sub)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1714.468994,336.564941,1715.505981,337.429321,1716.427002,338.324829,1717.402710,339.203125,1718.373657,...,1736.905518,357.823059,1737.725586,358.602692,1738.512207,359.362305,1739.293091,360.129944,1740.043213,360.874756
1,10015,724.592773,1230.037720,724.697205,1230.196167,724.804688,1230.227173,724.755188,1230.180908,724.701111,...,724.185730,1226.987793,724.096252,1226.911865,724.013367,1226.845947,723.944214,1226.791016,723.890991,1226.755371
2,10019,573.297119,1244.542358,573.433411,1244.372437,573.485779,1244.216431,573.611389,1244.145142,573.781738,...,576.296631,1238.967285,576.333313,1238.650391,576.352539,1238.342651,576.348511,1238.055908,576.331238,1237.812256
3,10028,1690.699707,315.181885,1691.547974,315.746338,1692.299805,316.427795,1693.000854,317.050537,1693.678467,...,1707.307007,330.236328,1707.969116,330.840393,1708.616577,331.445099,1709.213989,332.012421,1709.803589,332.592316
4,1003,2122.012939,676.655701,2120.637451,675.567444,2119.242188,674.565552,2117.931885,673.472839,2116.766602,...,2090.379395,649.389160,2089.114014,648.326111,2087.878418,647.276672,2086.710205,646.327942,2085.578857,645.400574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,255.481873,805.780212,255.782181,805.515198,255.979965,805.419556,256.263733,805.283447,256.477509,...,259.894897,793.734741,259.969452,792.922119,260.032593,792.079590,260.100433,791.231934,260.176239,790.406555
3196,99,586.758240,1154.168945,586.871643,1153.793579,586.912598,1153.399292,586.955505,1152.880981,587.042297,...,587.227173,1138.900024,587.242798,1138.268433,587.252747,1137.635254,587.262878,1137.037109,587.279053,1136.480835
3197,9905,1755.018311,443.903717,1754.779175,444.196045,1754.629639,444.581116,1754.523804,444.914307,1754.450562,...,1750.623169,450.667450,1750.534180,450.953827,1750.447876,451.210541,1750.379639,451.446198,1750.348999,451.658783
3198,9910,574.003601,1288.866455,573.976685,1288.832031,573.876282,1288.704956,573.729980,1288.529053,573.583374,...,568.223511,1282.977905,567.843811,1282.740356,567.446350,1282.522583,567.046814,1282.331787,566.700989,1282.104492


In [28]:
# Render the frame loaded from submission21.csv
display(two)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1713.726196,336.522186,1714.818481,337.500580,1715.912964,338.451508,1716.861450,339.313171,1717.841553,...,1736.507812,358.410431,1737.380737,359.240173,1738.250732,360.075348,1739.123413,360.883057,1739.998291,361.704681
1,10015,724.291443,1229.796021,724.446777,1229.695068,724.635986,1229.645874,724.698181,1229.175903,724.918884,...,725.064026,1226.474976,725.068298,1226.250732,725.071106,1226.018311,725.092773,1225.817749,725.131287,1225.676514
2,10019,572.902954,1244.488892,572.977356,1244.401733,573.171570,1244.382080,573.352417,1244.174316,573.524048,...,576.718750,1237.493042,576.814331,1237.068359,576.913940,1236.635742,577.021240,1236.208618,577.142090,1235.786011
3,10028,1690.635498,314.911072,1691.258179,315.600342,1692.050415,316.324829,1692.751343,316.954742,1693.458130,...,1707.103394,330.296906,1707.759277,330.897308,1708.434326,331.514160,1709.131348,332.161713,1709.822388,332.801758
4,1003,2122.406738,676.573669,2121.112061,675.434204,2119.583008,674.525146,2118.267822,673.476868,2116.994385,...,2091.547119,649.957581,2090.371582,648.935303,2089.214600,647.923218,2088.141113,646.999207,2087.094482,646.126892
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,254.974533,805.996460,255.168900,805.799683,255.367676,805.439758,255.538208,804.994324,255.713257,...,259.975708,794.140198,260.195343,793.654785,260.420807,793.150085,260.662659,792.643127,260.899811,792.122131
3196,99,586.982361,1154.164307,587.078308,1153.675049,587.033569,1153.007690,587.023010,1152.335815,587.054016,...,587.629089,1138.314697,587.661194,1137.593872,587.687134,1136.874268,587.723999,1136.143921,587.762817,1135.424683
3197,9905,1755.456543,443.589569,1755.220337,443.843994,1755.054443,444.193939,1754.946777,444.490662,1754.844116,...,1751.186768,450.144745,1751.026733,450.380157,1750.895264,450.630035,1750.804688,450.882812,1750.759155,451.151093
3198,9910,574.270874,1288.614380,574.305969,1288.441406,574.211182,1288.198730,574.077332,1287.947144,573.930664,...,570.781311,1282.422485,570.543640,1282.048706,570.329895,1281.637939,570.142334,1281.217163,569.964172,1280.842773


In [12]:
# Render the frame loaded from submission17.csv
display(fullzero)

Unnamed: 0,ID,v1,v2,v3,v4,v5,v6,v7,v8,v9,...,v51,v52,v53,v54,v55,v56,v57,v58,v59,v60
0,10002,1713.376709,336.149719,1713.877808,336.814941,1714.785645,337.637543,1715.741943,338.457642,1716.724121,...,1734.944946,356.272644,1735.812012,357.062775,1736.660889,357.841370,1737.532227,358.623016,1738.371704,359.350525
1,10015,725.199951,1231.909302,725.380005,1231.402710,725.590515,1231.086182,725.659363,1230.838989,725.691406,...,725.005310,1226.673096,724.990295,1226.501831,724.966248,1226.320801,724.929504,1226.122803,724.887817,1225.911133
2,10019,573.091736,1244.743286,573.196350,1244.789551,573.345886,1244.506226,573.483582,1244.309326,573.481750,...,576.761108,1237.849487,576.908508,1237.489868,577.045105,1237.116333,577.165344,1236.742676,577.274414,1236.383667
3,10028,1689.653198,314.011749,1690.191284,314.516418,1690.704834,314.904266,1691.332031,315.509857,1692.049683,...,1706.667603,330.034973,1707.354126,330.653076,1708.049438,331.277557,1708.742188,331.902954,1709.431763,332.510742
4,1003,2123.457031,677.625977,2122.260742,676.586792,2120.917480,675.614502,2119.812988,674.610718,2118.498535,...,2090.332275,650.036377,2089.127197,649.040588,2087.897705,648.031372,2086.658936,647.027222,2085.519043,646.106323
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3195,9897,255.300385,806.155823,255.287292,805.925781,255.516495,805.430786,255.645233,805.005859,255.802551,...,258.521149,792.233154,258.548981,791.470947,258.553802,790.697937,258.534180,789.926392,258.509308,789.193726
3196,99,587.035400,1155.067261,586.975952,1154.371582,586.957397,1153.871216,586.958069,1153.358276,586.955872,...,587.269287,1139.553345,587.281494,1138.896362,587.297729,1138.239258,587.305115,1137.585449,587.308838,1136.974609
3197,9905,1755.890991,442.888123,1755.634033,443.451111,1755.459961,443.808014,1755.326050,444.042969,1755.283813,...,1750.535522,450.393616,1750.321411,450.608490,1750.095093,450.817719,1749.882202,451.022736,1749.720337,451.235352
3198,9910,574.476013,1289.000854,574.520874,1288.774658,574.500732,1288.467773,574.453674,1288.206909,574.406006,...,570.844971,1282.423340,570.614258,1282.082153,570.375671,1281.719727,570.132874,1281.347046,569.904907,1280.986328


In [14]:
import numpy as np

In [22]:
# Largest absolute element-wise gap between two submissions.
# `max_sub` has no live assignment in this notebook (its only definition is a
# commented-out line in the good_sub / full_sub comparison cell), so this cell
# raised NameError on a fresh kernel; the per-column maxima are rebuilt here.
# NOTE(review): good_sub vs. full_sub is presumably the pair the recorded value
# came from — confirm.
max_sub = (good_sub - full_sub).abs().max(axis=0)
np.amax(max_sub.to_numpy())

44.18365478515625