In [103]:
import numpy as np, pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [104]:
# These are all of the files you are given
df_tr = pd.read_csv("train.csv")

In [105]:
# Over every single 
def polyline_to_trip_duration(polyline):
  return max(polyline.count("[") - 2, 0) * 15

df_tr["LEN"] = df_tr["POLYLINE"].apply(polyline_to_trip_duration)

In [147]:
from datetime import datetime
def parse_time(x):
  # We are using python's builtin datetime library
  # https://docs.python.org/3/library/datetime.html#datetime.date.fromtimestamp

  # Each x is essentially a 1 row, 1 column pandas Series
  dt = datetime.fromtimestamp(x["TIMESTAMP"])
  return dt.year, dt.month, dt.day, dt.hour, dt.weekday()

# Because we are assigning multiple values at a time, we need to "expand" our computed (year, month, day, hour, weekday) tuples on 
# the column axis, or axis 1
# https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html
df_tr[["YR", "MON", "DAY", "HR", "WK"]] = df_tr[["TIMESTAMP"]].apply(parse_time, axis=1, result_type="expand")

In [9]:
# cab_data = pd.read_csv("metaData_taxistandsID_name_GPSlocation.csv")
# cab_dict = {}
# for i in range(1, 64):
#     cab_dict[i] = (cab_data[cab_data["ID"] == i].values[0][2], cab_data[cab_data["ID"] == i].values[0][3])
# lat_mean = np.mean([cab_dict[i][0] for i in range(1, 64)])
# lat_std = np.std([cab_dict[i][0] for i in range(1, 64)])
# long_mean = np.mean([cab_dict[i][1] for i in range(1, 64)])
# long_std = np.std([cab_dict[i][1] for i in range(1, 64)])
# for i in range(1, 64):
#     old_lat = cab_dict[i][0]
#     old_long = cab_dict[i][1]
#     new_lat = (old_lat - lat_mean) / lat_std
#     new_long = (old_long - long_mean) / long_std
#     cab_dict[i] = (new_lat + 3, new_long + 3) # push up the z-normalized values so we can use 0 as a placeholder for "null"

In [10]:
# cab_data = pd.read_csv("metaData_taxistandsID_name_GPSlocation.csv")
# cab_dict = {}
# # for i in range(1, 64):
#     cab_dict[i] = (cab_data[cab_data["ID"] == i].values[0][2], cab_data[cab_data["ID"] == i].values[0][3])

In [12]:
# def get_lat_long(x):
#     i = x["ORIGIN_STAND"]
# #     if (i != i): return 0, 0 # placeholder for null values
#     i = int(i)
#     return cab_dict[i][0], cab_dict[i][1]

# df_tr[["LAT", "LONG"]] = df_tr[["ORIGIN_STAND"]].apply(get_lat_long, axis=1, result_type="expand")

In [13]:
# entry = df_tr.iloc[103]
torch.tensor(entry["LEN"]).to(device)

NameError: name 'device' is not defined

In [None]:
# print(len(df_tr))

In [108]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [109]:
class TaxiDataset(Dataset):
    def __init__(self, dataframe, transform=None, target_transform=None):
        self.dataframe = dataframe
        self.transform = None
        self.target_transform = None
    def __len__(self):
        return len(self.dataframe)
    def __getitem__(self, idx):
        entry = self.dataframe.iloc[idx]
        time = torch.tensor([entry["LEN"]]).to(torch.float32).to(device)
        if (entry["ORIGIN_STAND"] != entry["ORIGIN_STAND"]): # if ORIGIN_STAND is NaN
            # idea to do one-hot encoding comes from https://towardsdatascience.com/deep-neural-networks-for-regression-problems-81321897ca33
            origin_stand = [0 for _ in range(63)]
        else:
            origin_stand = F.one_hot(torch.tensor(int(entry["ORIGIN_STAND"]) - 1), num_classes=63).tolist()
        feature_tuple = (entry["YR"], entry["MON"], entry["WK"], entry["DAY"], entry["HR"], *origin_stand)
        feature_tensor = torch.tensor(feature_tuple).to(torch.float32).to(device)
        return feature_tensor, time

In [110]:
# credit to https://stackoverflow.com/questions/54730276/how-to-randomly-split-a-dataframe-into-several-smaller-dataframes
shuffled = df_tr.sample(frac=1)
result = np.array_split(shuffled, 5)

In [111]:
train_df = pd.concat(result[:-1])
outlier_threshold = 3
mean, std = train_df["LEN"].mean(), train_df["LEN"].std()
train_df = train_df[train_df["LEN"] < mean + outlier_threshold * std]
test_df = result[-1]
train_set = TaxiDataset(train_df)
test_set = TaxiDataset(test_df)

In [112]:
from torch.utils.data import DataLoader

In [244]:
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=True)

Features are going to be:
- Year
- Month
- Day
- Hr
- Week
- One-hot encoding of taxi stand

In [256]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(68, 75)
        self.layer2 = nn.Linear(75, 75)
        self.layer3 = nn.Linear(75, 75)
        self.layer4 = nn.Linear(75, 75)
#         self.layer5 = nn.Linear(256, 256)
#         self.layer6 = nn.Linear(256, 256)
#         self.layer7 = nn.Linear(256, 256)
        self.layer8 = nn.Linear(75, 1)
        self.dropout = nn.Dropout(p=0.3)
        self.batchnorm1 = nn.BatchNorm1d(75)
        self.batchnorm2 = nn.BatchNorm1d(75)
        self.batchnorm3 = nn.BatchNorm1d(75) 
        self.batchnorm4 = nn.BatchNorm1d(75) 
#         self.batchnorm5 = nn.BatchNorm1d(256)
#         self.batchnorm6 = nn.BatchNorm1d(256)
#         self.batchnorm7 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.batchnorm1(self.relu(self.layer1(x)))
        x = self.batchnorm2(self.relu(self.dropout((self.layer2(x)))))
        x = self.batchnorm3(self.relu(self.layer3(x)))
        x = self.batchnorm4(self.relu(self.dropout(self.layer4(x))))
#         x = self.batchnorm5(self.relu(self.layer5(x)))
#         x = self.batchnorm6(self.relu(self.dropout(self.layer6(x))))
#         x = self.batchnorm7(self.relu(self.layer7(x)))
        x = self.relu(self.layer8(x))
        return x

In [236]:
# i tried many designs, but next time i'm gonna use https://towardsdatascience.com/deep-neural-networks-for-regression-problems-81321897ca33
# for inspiration, just make each layer 256 (and use lots of them)

In [257]:
model = Net().to(device)

In [258]:
# code taken from https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
def train_loop(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    model.eval()

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

    test_loss /= num_batches
    print(f"Avg loss: {test_loss:>8f} \n")

In [None]:
# code taken from https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
learning_rate = 1e-5
batch_size = 64

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9) # TODO: try tweaking momentum
epochs = 5

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_loader, model, loss_fn, optimizer)
    test_loop(test_loader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 582456.250000  [   64/1354391]
loss: 118335.023438  [64064/1354391]
loss: 154138.312500  [128064/1354391]
loss: 214907.640625  [192064/1354391]
loss: 126040.765625  [256064/1354391]
loss: 209021.609375  [320064/1354391]
loss: 236529.218750  [384064/1354391]
loss: 119143.304688  [448064/1354391]
loss: 138834.359375  [512064/1354391]
loss: 107427.914062  [576064/1354391]
loss: 202971.468750  [640064/1354391]
loss: 146651.437500  [704064/1354391]
loss: 165495.203125  [768064/1354391]
loss: 140445.468750  [832064/1354391]
loss: 135398.125000  [896064/1354391]
loss: 110275.265625  [960064/1354391]
loss: 201624.375000  [1024064/1354391]
loss: 250858.171875  [1088064/1354391]
loss: 206823.656250  [1152064/1354391]
loss: 212402.937500  [1216064/1354391]
loss: 94231.609375  [1280064/1354391]
loss: 175632.406250  [1344064/1354391]
Avg loss: 445541.851726 

Epoch 2
-------------------------------
loss: 126525.742188  [   64/1354391]
loss: 143270.75000

In [223]:
torch.save(model.state_dict(), 'model_weights_3.pth')

In [180]:
model.state_dict()

OrderedDict([('layer1.weight',
              tensor([[-6.0009e+00, -1.0272e-01, -1.4309e-01,  ..., -2.9638e-02,
                       -1.0018e-01, -8.4046e-02],
                      [-8.2913e-02,  8.4076e-02, -6.4861e-03,  ..., -1.1879e-01,
                        9.0569e-02,  6.0004e-02],
                      [ 8.4488e+00, -9.5632e-02,  1.7476e-01,  ...,  4.0618e-02,
                       -6.6144e-03, -7.0138e-02],
                      ...,
                      [-6.5091e-01, -3.4877e-02, -1.2291e-01,  ...,  8.3111e-02,
                        1.0635e-01, -1.1597e-01],
                      [-6.8323e-01, -1.1419e-01,  8.9300e-02,  ..., -1.0473e-01,
                        2.8171e-02,  3.3060e-02],
                      [-5.9991e-01,  8.9902e-02, -7.7266e-02,  ..., -1.0984e-01,
                        1.1263e-03,  1.0224e-01]], device='cuda:0')),
             ('layer1.bias',
              tensor([-5.1394e-03, -8.7806e-02, -3.2915e-02,  2.3216e-02, -5.7221e-02,
                    

In [224]:
for x, y in test_loader:
    print(model(x))

tensor([[709.7649],
        [701.7513],
        [714.7644],
        [708.8948],
        [647.3586],
        [642.0899],
        [659.2987],
        [649.5165],
        [720.4346],
        [644.0352],
        [698.8878],
        [715.2222],
        [676.2192],
        [713.7262],
        [719.2209],
        [720.9607],
        [723.8368],
        [703.9745],
        [702.2076],
        [722.4414],
        [722.4331],
        [712.5272],
        [706.4847],
        [725.3876],
        [704.5151],
        [644.6284],
        [655.2440],
        [714.2769],
        [644.9001],
        [646.3492],
        [714.6024],
        [707.3200],
        [702.7755],
        [715.0449],
        [647.2408],
        [713.2944],
        [712.8351],
        [718.6614],
        [715.4042],
        [715.4598],
        [671.9119],
        [713.2944],
        [730.3537],
        [704.0322],
        [644.0861],
        [648.7922],
        [707.9645],
        [710.7996],
        [728.5804],
        [652.7434],


tensor([[719.0234],
        [651.8455],
        [647.7273],
        [722.7735],
        [680.4702],
        [729.8893],
        [712.9273],
        [710.9506],
        [709.1403],
        [665.8143],
        [650.9576],
        [703.3046],
        [724.8683],
        [710.0390],
        [687.2653],
        [648.8015],
        [650.5892],
        [717.2235],
        [718.8282],
        [718.0242],
        [720.3637],
        [668.3231],
        [703.1558],
        [646.2755],
        [679.8284],
        [696.1437],
        [717.3530],
        [726.4556],
        [661.5309],
        [661.6179],
        [644.2765],
        [719.3159],
        [705.9010],
        [645.7632],
        [709.8542],
        [648.4883],
        [682.7766],
        [704.3840],
        [717.9229],
        [720.9132],
        [711.7587],
        [658.6305],
        [723.1679],
        [714.6504],
        [653.2051],
        [718.2143],
        [652.3988],
        [647.8443],
        [718.6461],
        [715.2949],


tensor([[673.9623],
        [676.2834],
        [715.6962],
        [647.3835],
        [661.0580],
        [709.1508],
        [705.9334],
        [642.9057],
        [696.0711],
        [702.6757],
        [654.1226],
        [683.9793],
        [706.9826],
        [645.4001],
        [724.1534],
        [700.6540],
        [644.4546],
        [712.2341],
        [711.7920],
        [720.0056],
        [718.2488],
        [659.9018],
        [662.0830],
        [652.7932],
        [707.9558],
        [651.3689],
        [706.8025],
        [704.7007],
        [724.9948],
        [701.6818],
        [707.9283],
        [679.4816],
        [724.4171],
        [712.4689],
        [713.2979],
        [721.0132],
        [729.8853],
        [706.6652],
        [669.1924],
        [710.7276],
        [711.8571],
        [706.2205],
        [718.4659],
        [644.2852],
        [701.3256],
        [704.8810],
        [706.3167],
        [704.1176],
        [719.6100],
        [728.3258],


tensor([[683.0710],
        [717.4042],
        [715.1134],
        [645.2672],
        [701.9400],
        [700.9382],
        [708.6130],
        [688.9644],
        [702.4941],
        [677.3546],
        [712.0088],
        [644.8430],
        [647.2936],
        [643.2911],
        [646.5510],
        [719.3448],
        [707.5072],
        [729.7245],
        [687.3479],
        [702.9677],
        [648.1700],
        [646.3289],
        [722.1058],
        [656.0929],
        [663.4479],
        [645.3192],
        [714.0710],
        [720.7222],
        [641.5226],
        [697.9026],
        [718.6234],
        [648.8745],
        [704.9589],
        [717.2803],
        [653.2051],
        [650.1227],
        [723.5601],
        [666.2139],
        [663.8320],
        [711.4570],
        [725.7637],
        [644.6064],
        [721.3156],
        [727.3908],
        [715.1497],
        [720.0713],
        [725.0358],
        [721.1857],
        [666.0775],
        [712.5632],


tensor([[652.4556],
        [644.0687],
        [644.7250],
        [708.5101],
        [705.5233],
        [725.1271],
        [709.3048],
        [707.5704],
        [718.6833],
        [643.7630],
        [704.3319],
        [723.9095],
        [710.2613],
        [652.3998],
        [645.7454],
        [645.7632],
        [644.6847],
        [713.7493],
        [648.3651],
        [712.5023],
        [647.9863],
        [645.9587],
        [663.1756],
        [652.2532],
        [648.2300],
        [648.4325],
        [647.0894],
        [724.7474],
        [648.4085],
        [718.1801],
        [648.3187],
        [722.3976],
        [647.7414],
        [661.2903],
        [649.8483],
        [704.8068],
        [647.6060],
        [711.3210],
        [648.0063],
        [716.2585],
        [695.5588],
        [709.8646],
        [725.1746],
        [684.8671],
        [726.8647],
        [666.2614],
        [653.9222],
        [704.6019],
        [725.5719],
        [702.5627],


tensor([[706.1772],
        [667.9232],
        [645.5237],
        [720.7152],
        [704.8138],
        [656.2307],
        [705.0109],
        [721.6010],
        [690.2886],
        [656.5363],
        [645.7531],
        [706.8077],
        [716.4158],
        [718.4407],
        [644.7516],
        [711.3620],
        [713.0937],
        [713.2678],
        [678.4930],
        [715.4594],
        [649.5339],
        [647.2559],
        [709.9215],
        [708.1555],
        [722.4495],
        [650.4390],
        [704.3638],
        [664.3889],
        [700.5917],
        [642.0231],
        [707.9661],
        [641.7922],
        [659.0319],
        [717.0035],
        [653.2051],
        [710.1408],
        [645.4694],
        [704.6852],
        [683.7446],
        [718.4814],
        [710.3254],
        [673.2482],
        [703.5153],
        [709.1198],
        [710.2465],
        [717.2166],
        [711.0344],
        [706.4532],
        [648.7148],
        [705.4575],


tensor([[648.7954],
        [645.9044],
        [711.9837],
        [704.1163],
        [706.5038],
        [680.3016],
        [717.9446],
        [706.0434],
        [654.1643],
        [708.9347],
        [704.9658],
        [720.9621],
        [645.4294],
        [655.5463],
        [645.9667],
        [715.9429],
        [657.2977],
        [648.8985],
        [697.4977],
        [657.0228],
        [713.1533],
        [712.2637],
        [686.9894],
        [672.7267],
        [653.2249],
        [643.3557],
        [717.3815],
        [650.0607],
        [715.1770],
        [644.5526],
        [670.6744],
        [644.5767],
        [653.6609],
        [721.0680],
        [709.4108],
        [705.0616],
        [723.5944],
        [647.4017],
        [712.4081],
        [669.4845],
        [714.1922],
        [676.2368],
        [703.6453],
        [649.9618],
        [698.4299],
        [718.8057],
        [665.2957],
        [652.2608],
        [710.3459],
        [704.7819],


tensor([[645.2636],
        [718.9797],
        [729.8253],
        [710.9010],
        [716.7164],
        [703.0028],
        [703.9176],
        [704.7528],
        [723.7070],
        [714.9468],
        [650.5728],
        [714.9945],
        [706.2205],
        [698.5190],
        [695.6166],
        [648.1009],
        [717.0858],
        [722.4331],
        [700.5917],
        [651.0654],
        [646.4887],
        [649.8494],
        [694.9972],
        [718.1785],
        [683.1605],
        [706.8123],
        [713.1735],
        [709.2203],
        [703.5554],
        [708.8991],
        [668.0123],
        [720.9245],
        [677.1267],
        [664.0900],
        [711.4344],
        [699.6240],
        [709.0833],
        [656.9469],
        [720.8599],
        [713.5526],
        [642.6005],
        [647.2397],
        [645.3737],
        [658.7516],
        [724.2550],
        [704.5851],
        [645.3839],
        [728.6764],
        [707.3474],
        [705.6903],


tensor([[709.7938],
        [720.7070],
        [717.6817],
        [648.0112],
        [678.6965],
        [719.5606],
        [667.0795],
        [712.4072],
        [718.2642],
        [648.1088],
        [705.7969],
        [647.8328],
        [721.0130],
        [642.8270],
        [712.8926],
        [644.1362],
        [723.8458],
        [705.1624],
        [645.2837],
        [683.0700],
        [669.2070],
        [644.3145],
        [725.8939],
        [714.7207],
        [714.9968],
        [702.4273],
        [723.8727],
        [721.2911],
        [704.7637],
        [657.3281],
        [678.1988],
        [647.9086],
        [714.2633],
        [709.1622],
        [727.6982],
        [719.0053],
        [703.8021],
        [708.5528],
        [714.1917],
        [643.7856],
        [661.2903],
        [650.1121],
        [660.0056],
        [710.5139],
        [722.0408],
        [673.8549],
        [710.7435],
        [643.2343],
        [648.8942],
        [650.3870],


tensor([[648.2312],
        [659.9413],
        [705.7661],
        [644.2455],
        [649.3237],
        [659.2686],
        [644.8375],
        [644.5786],
        [693.1516],
        [708.8840],
        [677.5745],
        [646.3875],
        [643.8329],
        [717.6338],
        [711.0074],
        [708.8950],
        [722.0587],
        [648.0665],
        [669.9898],
        [679.2647],
        [644.6105],
        [713.2390],
        [714.4578],
        [712.6741],
        [707.8724],
        [688.1486],
        [725.0733],
        [703.2388],
        [704.5195],
        [714.5632],
        [715.4204],
        [720.6199],
        [654.8384],
        [655.9053],
        [707.4112],
        [714.6530],
        [640.8380],
        [722.3483],
        [711.0026],
        [643.8802],
        [644.4031],
        [644.7871],
        [722.9721],
        [645.1510],
        [715.4014],
        [682.7206],
        [646.7667],
        [717.8799],
        [711.5583],
        [723.0093],


tensor([[690.3129],
        [673.7137],
        [713.4441],
        [654.1005],
        [703.9933],
        [714.0271],
        [714.1711],
        [650.7403],
        [717.9932],
        [715.7847],
        [649.8337],
        [710.2686],
        [711.2218],
        [654.1667],
        [701.5366],
        [647.8676],
        [711.7062],
        [703.7253],
        [706.7252],
        [724.1648],
        [667.4592],
        [644.4630],
        [713.0853],
        [711.4410],
        [708.3212],
        [706.6500],
        [715.6586],
        [642.2759],
        [721.3564],
        [662.1260],
        [650.0884],
        [692.9330],
        [666.1967],
        [647.0221],
        [680.6764],
        [720.2557],
        [725.4635],
        [699.5182],
        [710.6034],
        [645.0201],
        [717.8771],
        [729.6547],
        [677.2529],
        [698.5719],
        [714.9703],
        [721.8732],
        [648.8394],
        [646.1778],
        [668.4183],
        [666.9546],


tensor([[700.7949],
        [696.2048],
        [650.0172],
        [706.5159],
        [686.3626],
        [649.4229],
        [650.1687],
        [645.3586],
        [646.1373],
        [695.2551],
        [681.5609],
        [669.9334],
        [702.1096],
        [709.3909],
        [721.9242],
        [711.2540],
        [649.7887],
        [716.0922],
        [721.0923],
        [650.7102],
        [643.9423],
        [717.4741],
        [676.7483],
        [707.4365],
        [647.7049],
        [722.9721],
        [708.6212],
        [664.2363],
        [643.5314],
        [707.6769],
        [653.9869],
        [713.6638],
        [662.1436],
        [695.3411],
        [725.1027],
        [680.0447],
        [712.8373],
        [696.1110],
        [704.5640],
        [695.3908],
        [675.8901],
        [707.3708],
        [720.5657],
        [710.4327],
        [661.3803],
        [704.1067],
        [710.3063],
        [706.6500],
        [712.7034],
        [718.6057],


tensor([[714.6504],
        [721.9446],
        [706.4573],
        [654.4932],
        [672.9108],
        [648.0992],
        [715.1885],
        [704.4340],
        [647.9086],
        [641.2928],
        [694.5987],
        [649.6024],
        [719.6420],
        [654.5058],
        [719.0901],
        [649.2039],
        [661.0170],
        [709.3519],
        [730.7819],
        [671.3393],
        [729.0718],
        [648.7245],
        [676.0576],
        [681.3879],
        [650.2757],
        [677.5681],
        [719.8604],
        [665.9076],
        [715.0396],
        [715.1651],
        [721.0891],
        [673.0573],
        [712.8293],
        [711.3101],
        [647.9289],
        [726.6646],
        [701.9036],
        [682.0677],
        [676.4850],
        [657.7324],
        [686.8989],
        [724.9095],
        [650.4628],
        [711.7186],
        [721.2588],
        [714.2783],
        [650.8035],
        [646.3284],
        [664.9297],
        [700.1724],


tensor([[716.2579],
        [716.4125],
        [723.1230],
        [718.6613],
        [723.0284],
        [707.3331],
        [719.0575],
        [713.5577],
        [720.0532],
        [706.1689],
        [706.1718],
        [716.7884],
        [645.3856],
        [655.3543],
        [647.4313],
        [720.9958],
        [718.5521],
        [709.0001],
        [647.3676],
        [718.1809],
        [676.6125],
        [648.6197],
        [654.6028],
        [695.2520],
        [706.7308],
        [649.0503],
        [703.2861],
        [698.0696],
        [649.0496],
        [663.8213],
        [727.7578],
        [725.2098],
        [707.8801],
        [700.1506],
        [667.9589],
        [643.8548],
        [652.9521],
        [713.0367],
        [657.4885],
        [643.3122],
        [713.5433],
        [714.0242],
        [723.9854],
        [656.5363],
        [715.1224],
        [663.8719],
        [682.4649],
        [715.7756],
        [644.8669],
        [724.9948],


tensor([[655.1491],
        [699.9475],
        [722.3150],
        [704.8997],
        [710.7766],
        [653.2838],
        [702.0731],
        [713.9403],
        [702.9369],
        [714.4194],
        [660.2291],
        [657.8618],
        [716.0818],
        [695.2704],
        [723.4906],
        [667.0681],
        [712.8541],
        [701.8083],
        [713.0265],
        [647.5725],
        [691.4191],
        [665.3032],
        [709.4399],
        [717.2707],
        [645.7687],
        [708.6066],
        [720.9995],
        [657.1852],
        [713.2119],
        [668.7717],
        [711.7616],
        [713.2864],
        [723.9703],
        [651.5547],
        [724.1937],
        [707.9999],
        [713.1399],
        [654.2711],
        [714.9866],
        [706.9008],
        [723.4271],
        [699.2558],
        [700.9382],
        [710.4327],
        [725.7144],
        [644.6337],
        [649.6078],
        [711.9931],
        [649.7534],
        [703.3336],


tensor([[712.9273],
        [655.3042],
        [718.0786],
        [700.1604],
        [719.3182],
        [703.3480],
        [679.7468],
        [710.6228],
        [671.2315],
        [704.3151],
        [720.5771],
        [682.1979],
        [721.2371],
        [696.6534],
        [647.0841],
        [675.8594],
        [645.6723],
        [647.3309],
        [704.0494],
        [723.2401],
        [713.3998],
        [656.1151],
        [721.9924],
        [725.2343],
        [648.9866],
        [652.9787],
        [676.3740],
        [645.3194],
        [693.4725],
        [706.2552],
        [648.2535],
        [706.3464],
        [706.8859],
        [717.2707],
        [649.9719],
        [654.6038],
        [657.5220],
        [659.5727],
        [645.9439],
        [710.0876],
        [717.7042],
        [653.7927],
        [650.4049],
        [696.8604],
        [712.8391],
        [712.2665],
        [647.1434],
        [640.7975],
        [693.5296],
        [724.1817],


tensor([[718.7761],
        [642.8611],
        [719.3857],
        [721.6279],
        [653.2252],
        [716.3000],
        [726.3972],
        [674.2041],
        [716.5408],
        [649.4818],
        [723.0495],
        [723.2318],
        [713.2986],
        [711.4712],
        [650.7556],
        [702.8224],
        [651.0228],
        [701.2746],
        [709.7388],
        [648.2037],
        [715.5077],
        [704.5036],
        [715.9830],
        [680.1169],
        [715.1913],
        [716.1959],
        [700.7505],
        [705.2340],
        [712.9578],
        [657.3109],
        [716.2995],
        [717.9973],
        [715.8864],
        [722.0178],
        [703.0331],
        [647.7491],
        [649.1802],
        [707.0527],
        [645.2690],
        [709.2239],
        [723.4149],
        [705.2421],
        [649.5972],
        [719.9976],
        [706.2183],
        [728.8574],
        [647.5016],
        [648.8394],
        [715.2407],
        [649.1201],


KeyboardInterrupt: 

In [225]:
def get_input_tensor(entry):    
    if (entry["ORIGIN_STAND"] != entry["ORIGIN_STAND"]): # if ORIGIN_STAND is NaN
        # idea to do one-hot encoding comes from https://towardsdatascience.com/deep-neural-networks-for-regression-problems-81321897ca33
        origin_stand = [0 for _ in range(63)]
    else:
        origin_stand = F.one_hot(torch.tensor(int(entry["ORIGIN_STAND"]) - 1), num_classes=63).tolist()
    feature_tuple = (entry["YR"], entry["MON"], entry["WK"], entry["DAY"], entry["HR"], *origin_stand)
    feature_tensor = torch.tensor(feature_tuple).to(torch.float32).to(device)
    return feature_tensor

I ran into a problem where the model would just output the exact same value every time. I combatted this by adding more layers + batch normalization + reduced learning rate + less epochs.

https://datascience.stackexchange.com/questions/58220/how-to-deal-with-a-constant-value-as-an-output-from-neural-network

https://stackoverflow.com/questions/4493554/neural-network-always-produces-same-similar-outputs-for-any-input

https://stackoverflow.com/questions/39217567/keras-neural-network-outputs-same-result-for-every-input

In [226]:
dt = datetime.fromtimestamp(1408039037)
dt.year, dt.month, dt.day, dt.hour, dt.weekday()

(2014, 8, 14, 17, 3)

In [228]:
df_public_test = pd.read_csv("test_public.csv")
df_public_test[["YR", "MON", "DAY", "HR", "WK"]] = df_public_test[["TIMESTAMP"]].apply(parse_time, axis=1, result_type="expand")
pred_dict = {}
for row in df_public_test.iloc:
    pred_dict[row["TRIP_ID"]] = model(torch.unsqueeze(get_input_tensor(row), dim=0)).item()

In [229]:
def get_prediction(x):
    return pred_dict[x]

In [230]:
# Sample submission file that is given on kaggle
df_sample = pd.read_csv("sampleSubmission.csv")
df_sample["TRAVEL_TIME"] = df_sample["TRIP_ID"].apply(get_prediction)
df_sample.to_csv("my_pred.csv", index=None)