In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from src.ESRNN import ESRNN
plt.style.use('ggplot')

In [2]:
# Plot an observed series followed by its forecast on a shared integer time axis
def plot_prediction(y, y_hat, title=None):
    """Plot observed values and the forecast that continues them.

    The observed values are drawn at x = 0 .. len(y) - 1 and the forecast at
    x = len(y) .. len(y) + len(y_hat) - 1, so the forecast starts right after
    the last observation.

    Parameters
    ----------
    y : array-like
        Observed values (e.g. a pandas Series or a 1-D numpy array).
    y_hat : array-like
        Forecast values, drawn after ``y``.
    title : str, optional
        Figure title; no title is set when None (the default).
    """
    n_y = len(y)
    n_yhat = len(y_hat)
    # Positional x values: dates are not passed in, so steps are plain integers.
    ds_y = np.arange(n_y)
    ds_yhat = np.arange(n_y, n_y + n_yhat)

    _, ax = plt.subplots()
    ax.plot(ds_y, y, label='y')
    ax.plot(ds_yhat, y_hat, label='y_hat')
    ax.set_xlabel('time step')
    ax.set_ylabel('value')
    if title is not None:
        ax.set_title(title)
    ax.legend(loc='upper left')
    plt.show()

In [3]:
# Load the raw stock prices and build a full date by joining the separate
# Year and Date columns as "<Year>-<Date>" before parsing.
raw = pd.read_csv('data/train.csv')
full_date = pd.to_datetime(raw['Year'].astype(str) + '-' + raw['Date'].astype(str))
data = raw.assign(Date=full_date).loc[:, ['Company', 'Year', 'Date', 'Close']]
data.head()

Unnamed: 0,Company,Year,Date,Close
0,abbv,2013,2013-01-04,28.81
1,abbv,2013,2013-01-07,28.869
2,abbv,2013,2013-01-08,28.242
3,abbv,2013,2013-01-09,28.399
4,abbv,2013,2013-01-10,28.481


In [4]:
# Clean data: the model expects these column names.
#   unique_id -> one series per company and year
#   ds        -> timestamp (renamed from Date)
#   y         -> target value (renamed from Close)
#   x         -> the year as a string; passed to the model via X_train
data = (
    data
    .assign(unique_id=data['Company'] + "_" + data['Year'].astype(str),
            x=data['Year'].astype(str))
    .rename(columns={'Date': 'ds', 'Close': 'y'})
)
data.head()

Unnamed: 0,Company,Year,ds,y,unique_id,x
0,abbv,2013,2013-01-04,28.81,abbv_2013,2013
1,abbv,2013,2013-01-07,28.869,abbv_2013,2013
2,abbv,2013,2013-01-08,28.242,abbv_2013,2013
3,abbv,2013,2013-01-09,28.399,abbv_2013,2013
4,abbv,2013,2013-01-10,28.481,abbv_2013,2013


In [5]:
# Split features (X) from target (y) and keep only the first N_SERIES series
# (in order of first appearance). Both frames are filtered to the same ids so
# they describe exactly the same set of series.
# For a quick single-series run, filter to one id instead (e.g. 'baf_2013').
N_SERIES = 15
X_train = data[['unique_id', 'ds', 'x']]
y_train = data[['unique_id', 'ds', 'y']]
unique_ids = X_train['unique_id'].unique()[:N_SERIES]
X_train = X_train[X_train['unique_id'].isin(unique_ids)]
y_train = y_train[y_train['unique_id'].isin(unique_ids)]

In [6]:
# Model hyperparameters; max_epochs=1 keeps this training run short.
esrnn_params = dict(
    max_epochs=1,
    batch_size=1,
    learning_rate=1e-3,
    seasonality=30,
    input_size=30,
    output_size=30,
    level_variability_penalty=0,
)
esrnn = ESRNN(**esrnn_params)
esrnn.fit(X_train, y_train, random_seed=1)


0.073444135
[0]
levels.data.numpy() [[17.474148 17.496422 17.268116 17.241194 17.26199  17.350634 17.461117
  17.666306 18.043919 18.468882 18.917007 18.798508 19.195114 19.236118
  19.296148 19.091532 19.14944  19.050491 18.934982 19.04843  19.123354
  19.12559  19.10718  18.870306 18.726131 18.543716 18.337042 18.210314
  18.57965  19.041887 19.411966 19.680153 19.90435  19.830921 19.442032
  19.187565 18.971178 18.927822 19.152473 19.36417  19.209757 19.360315
  19.21468  19.194632 19.288618 19.226877 19.141668 19.271273 19.626307
  19.595015 19.656652 19.921383 19.844013 20.023367 20.180128 20.535263
  20.65854  20.88578  20.948263 20.819803 20.751911 21.07415  21.187126
  21.393951 21.792744 22.230833 22.37352  22.316557 21.647629 21.71042
  21.761856 21.8781   21.885523 22.534279 23.04438  23.091766 23.014595
  23.487232 23.41185  23.688854 23.61792  23.053913 23.162197 23.16766
  23.252728 22.787289 22.920713 23.170126 23.011375 23.237171 23.552536
  23.649384 24.145592 24.0343

0.03873451
[4]
levels.data.numpy() [[36.521034 36.841564 37.13738  37.25661  37.454533 37.47791  36.64444
  36.38074  36.54282  36.55606  36.48782  36.358986 36.225777 36.10564
  35.920216 36.14469  36.188667 35.73825  35.72639  35.972225 35.983864
  35.98826  35.909126 35.872074 35.824486 35.79217  35.89739  35.7827
  35.77262  35.88321  36.230015 36.22422  36.334793 36.4395   36.28862
  36.48455  36.816498 36.85472  36.628986 36.942585 37.235    37.39302
  37.53136  37.632282 37.852356 37.966362 38.5616   38.901943 38.851944
  39.090576 39.00174  38.892647 38.91351  38.74487  38.74659  38.790947
  38.76182  38.983665 38.99599  38.892918 38.681385 38.500584 38.362312
  38.384205 38.381725 38.381348 38.73039  38.582497 38.43023  38.34839
  38.25441  38.332924 38.12362  37.97603  38.031506 37.990635 38.219536
  38.655228 38.77659  39.070587 39.271217 39.55477  39.70788  39.841362
  40.03534  40.0051   39.68577  39.548096 39.44781  39.153767 39.290913
  39.571625 39.710995 39.222282 39.1

0.13478038
[8]
levels.data.numpy() [[5.470664  5.5200844 5.558563  5.586984  5.620669  5.648525  5.654172
  5.6767287 5.711635  5.7425203 5.7500653 5.7393227 5.6781826 5.675789
  5.7026725 5.6950397 5.6921577 5.6780825 5.6495123 5.649108  5.623358
  5.616204  5.62615   5.6299057 5.6417055 5.6491055 5.644122  5.6393337
  5.6611223 5.6016927 5.4631243 5.40476   5.4168735 5.4148893 5.3980355
  5.35695   5.408787  5.4600325 5.4221234 5.3236637 5.21718   5.196289
  5.2328687 5.256934  5.251905  5.283315  5.317546  5.3074007 5.2751665
  5.223576  5.2020073 5.1898317 5.1815324 5.1847205 5.1815863 5.2276554
  5.2379    5.234022  5.2267904 5.2567725 5.2793555 5.292239  5.2907276
  5.3087125 5.348896  5.4490194 5.443672  5.432435  5.473057  5.525215
  5.531785  5.48362   5.490815  5.486169  5.5081162 5.4809704 5.4847665
  5.5121956 5.5085573 5.5056143 5.484082  5.4626956 5.453928  5.457415
  5.45847   5.4478254 5.4405866 5.4355145 5.42843   5.4376535 5.4538803
  5.4240623 5.400078  5.396956  5.3

0.103774846
[12]
levels.data.numpy() [[6.787078  6.764803  6.8054733 6.801196  6.844887  6.8579836 6.8753867
  6.9080076 6.927497  6.942783  6.9357176 6.922102  6.898839  6.8957195
  6.8771753 6.8999996 6.982992  7.0128155 7.020299  7.0389814 7.046412
  7.080176  7.074424  7.063946  7.0781116 7.0947866 7.1097655 7.1218386
  7.156223  7.13145   7.143995  7.0896316 7.065448  7.067265  7.007369
  6.9713297 7.0046115 7.0075355 7.0042377 7.0028663 7.044286  7.0497875
  7.0553303 7.037852  7.0429    7.0522213 7.062625  7.0905204 7.1320343
  7.055275  6.836459  6.648138  6.63282   6.660945  6.710992  6.715755
  6.6711597 6.6123304 6.629792  6.688258  6.733765  6.793209  6.78081
  6.7936115 6.814193  6.8263416 6.8354917 6.8457823 6.8602047 6.85676
  6.862931  6.88371   6.90025   6.9134855 6.9395976 6.9251003 6.924285
  6.920699  6.9161882 6.933455  6.998513  7.005847  6.9941015 6.9649706
  6.9674807 6.986824  7.012393  7.02942   7.0301266 7.029102  7.036892
  7.0463095 7.0646005 7.070182  7.03

In [None]:
# Forecast one training series (an in-sample check) and keep its observed
# target for plotting. Features come from X_train; the observed values must
# come from y_train, because X_train has no 'y' column.
test_id = 'abbv_2013'
X_test = X_train.loc[X_train['unique_id'] == test_id]
y_test = y_train.loc[y_train['unique_id'] == test_id]
y_hat = esrnn.predict(X_test)

In [None]:
# Observed series followed by the model's forecast.
plot_prediction(y=y_test['y'], y_hat=y_hat['y_hat'])

In [None]:
# Preview the training features rather than rendering the whole frame.
X_train.head()

In [None]:
# Inspect the fitted model's `sort_key` attribute.
getattr(esrnn, 'sort_key')