In [1]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import scipy.stats as stats
from collections import defaultdict
from itertools import product
from sklearn.metrics import mean_absolute_error as mae
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
from sklearn.preprocessing import StandardScaler

In [13]:
class MyModel(nn.Module):
    def __init__(self, input_feature, hidden_size, output_feature, num_layers=1):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(input_feature, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, num_layers=num_layers, dropout=0.2)
        ''' gru input is (N,L,H_in=H_hidden), output is (N,L,H_hidden), hidden is (num_layers, h_hidden)'''
        self.linear_out = nn.Linear(hidden_size, output_feature)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
    
    def forward(self, input, hidden):
        ''' X is in the shape of (N,L,input_feature) '''
        output = F.relu(self.linear(input))
        output, hidden = self.gru(output, hidden)
        output = self.linear_out(F.relu(output))
        return output
    
    def initHidden(self, batch_size):
        return torch.zeros((self.num_layers, batch_size, self.hidden_size))

In [2]:
look_back = 72
batch_size = 512
linear_node = 32

In [3]:
dat = pd.read_csv('train.csv', index_col='row_id')

In [84]:
uniques[:5]

array(['00EB', '00NB', '00SB', '01EB', '01NB'], dtype=object)

In [4]:
def preprocess(dat):
    time_mapper = {}
    ii = 0
    for h in range(24):
        for mm in ['00','20','40']:
            hh = '{0:02d}'.format(h)
            time_mapper[hh+':'+mm] = ii
            ii += 1

    dat['unique'] = dat['x'].astype(str) + dat['y'].astype(str) + dat['direction']
    uniques = dat['unique'].unique()
    dat['day'] = pd.to_datetime(dat['time']).dt.weekday
    dat['time_stamp'] = dat['time'].apply(lambda x:time_mapper[x.split()[1][:5]])

    tmp = dat.groupby(['unique','day','time_stamp']).agg({'congestion':np.median})
    median_mapper = tmp.to_dict()['congestion']
    dat['median'] = dat.apply(lambda x: \
                              median_mapper[x['unique'],x['day'],x['time_stamp']], axis=1)
    dat['congestion-median'] = dat['congestion'] - dat['median']
    
    all_time = pd.DataFrame(pd.date_range('1991-04-01 00:00:00', '1991-09-30 11:40:00', freq='20Min'), columns=['time'])
    all_time['time'] = all_time['time'].astype(str)
    
    return uniques, median_mapper, time_mapper, all_time

In [5]:
uniques, median_mapper, time_mapper, all_time = preprocess(dat)

In [None]:
test_periods = [
    ['1991-09-16 12:00:00', '1991-09-16 24:00:00'],
    ['1991-09-23 12:00:00', '1991-09-23 24:00:00']]


In [6]:
def getseries(unique):
    df = dat.loc[dat['unique']==unique, ['time', 'congestion-median']]
    df = pd.merge(all_time, df, left_on='time', right_on='time', how='outer')
    df = df.set_index('time')
    df['congestion-median'] = df['congestion-median'].fillna(0)
    ss = StandardScaler()
    df['congestion-median-normalized'] = ss.fit_transform(df['congestion-median'].values.reshape(-1,1)).reshape(-1)
    return df, ss

In [8]:
def preprocess_test(dat):
    dat['unique'] = dat['x'].astype(str) + dat['y'].astype(str) + dat['direction']
    dat['day'] = pd.to_datetime(dat['time']).dt.weekday
    dat['time_stamp'] = dat['time'].apply(lambda x:time_mapper[x.split()[1][:5]])

    dat['median'] = dat.apply(lambda x: \
                              median_mapper[x['unique'],x['day'],x['time_stamp']], axis=1)


In [10]:
test = pd.read_csv('test.csv')
preprocess_test(test)
test

Unnamed: 0,row_id,time,x,y,direction,unique,day,time_stamp,median
0,848835,1991-09-30 12:00:00,0,0,EB,00EB,0,36,47.0
1,848836,1991-09-30 12:00:00,0,0,NB,00NB,0,36,35.0
2,848837,1991-09-30 12:00:00,0,0,SB,00SB,0,36,56.5
3,848838,1991-09-30 12:00:00,0,1,EB,01EB,0,36,22.0
4,848839,1991-09-30 12:00:00,0,1,NB,01NB,0,36,72.0
...,...,...,...,...,...,...,...,...,...
2335,851170,1991-09-30 23:40:00,2,3,NB,23NB,0,71,68.0
2336,851171,1991-09-30 23:40:00,2,3,NE,23NE,0,71,25.0
2337,851172,1991-09-30 23:40:00,2,3,SB,23SB,0,71,71.0
2338,851173,1991-09-30 23:40:00,2,3,SW,23SW,0,71,11.0


In [15]:
model = MyModel(1, linear_node, 1, num_layers=3)

In [85]:
test_uniques = {}

with torch.no_grad():
    for unique in uniques[0:3]:
                                
        print(f"doing {unique}")
        df, ss = getseries(unique)
        pfile = torch.load('model_'+unique+'.pickle')
        model.load_state_dict(pfile['model'])
        model.eval()
        
        X = torch.tensor(df['congestion-median-normalized'].values[-look_back:], 
                         dtype=torch.float32)
        X = X.reshape(1,-1,1)
        ys = np.zeros(36)
        for idx in range(36):
            h0 = model.initHidden(1)
            y = model.forward(X,h0)[0,0,-1].item()
            ys[idx] = y
            X = X.reshape(-1).numpy()
            X = np.hstack([X[1:],y])
            X = torch.tensor(X, dtype=torch.float32).reshape(1,-1,1)
        
        ss.inverse_transform(ys.reshape(1,-1)).reshape(-1)
        test_unique = test[test['unique']==unique].copy()
        test_unique['to_add'] = ys
        test_unique['congestion'] = test_unique['median'] + test_unique['to_add']
        
        test_uniques[unique] = test_unique

doing 00EB
doing 00NB
doing 00SB


In [87]:
test_uniques['00NB']

Unnamed: 0,row_id,time,x,y,direction,unique,day,time_stamp,median,to_add,congestion
0,848835,1991-09-30 12:00:00,0,0,EB,00EB,0,36,47.0,0.547127,47.547127
65,848900,1991-09-30 12:20:00,0,0,EB,00EB,0,37,42.5,0.038401,42.538401
130,848965,1991-09-30 12:40:00,0,0,EB,00EB,0,38,46.0,0.093753,46.093753
195,849030,1991-09-30 13:00:00,0,0,EB,00EB,0,39,50.0,-0.452418,49.547582
260,849095,1991-09-30 13:20:00,0,0,EB,00EB,0,40,43.0,-0.376944,42.623056
325,849160,1991-09-30 13:40:00,0,0,EB,00EB,0,41,56.0,0.293775,56.293775
390,849225,1991-09-30 14:00:00,0,0,EB,00EB,0,42,51.0,-0.013637,50.986363
455,849290,1991-09-30 14:20:00,0,0,EB,00EB,0,43,50.0,0.236289,50.236289
520,849355,1991-09-30 14:40:00,0,0,EB,00EB,0,44,53.0,-0.236929,52.763071
585,849420,1991-09-30 15:00:00,0,0,EB,00EB,0,45,55.0,-0.394138,54.605862
