In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [151]:
import pickle

import timecast as tc
import pandas as pd
import matplotlib.pyplot as plt

import flax
import jax.numpy as jnp
import numpy as np

from timecast.learners import AR
from timecast.utils.ar import historify, compute_gram

In [200]:
data = jnp.array(pd.read_csv("../data/wind/original/MS_winds.dat", names=list(range(57))))
pickle.dump(data, open("../data/wind/original/MS_winds.pkl", "wb"))

In [153]:
# Columns = 57 stations
# Rows = wind speed readings (m/s)
data

DeviceArray([[5.0963, 2.0564, 3.0399, ..., 3.0399, 3.5763, 2.5481],
             [5.0963, 1.5199, 2.5481, ..., 2.5481, 3.5763, 2.5481],
             [5.588 , 1.5199, 2.0564, ..., 2.5481, 3.5763, 1.5199],
             ...,
             [4.6045, 4.0681, 5.0963, ..., 4.6045, 0.    , 3.0399],
             [7.1526, 6.1244, 4.6045, ..., 4.0681, 0.    , 4.0681],
             [7.1526, 3.5763, 3.0399, ..., 4.0681, 0.    , 4.6045]],            dtype=float32)

In [154]:
# Normalization
# NOTE: This is a bug; they claim this normalizes from 0 to 1, but it doesn't
# NOTE: Their variable also refer to min and max as mean and std, respectively so...
data_min = data.min()
data_max = data.max()
data = (data - data_min) / data_max

In [155]:
models = pickle.load(open("../data/wind/original/models.pkl", "rb"))["models"]

In [182]:
# Mostly from https://github.com/amirstar/Deep-Forecast/blob/4dcdf66f8ae3070ab706b30a6e3cf888f36e0536/multiLSTM.py#L210
def predict(X, models):
    X = X.reshape(X.shape[0], history_len, num_stations)
    results = np.zeros_like(np.zeros((X.shape[0], num_stations)))
    
    for ind in range(len(X)):
        modelInd = ind % 6
        if modelInd == 0:
            testInputRaw = X[ind]
            testInputShape = testInputRaw.shape
            testInput = np.reshape(testInputRaw, [1, testInputShape[0], testInputShape[1]])
        else:
            testInputRaw = np.vstack((testInputRaw, results[ind-1]))
            testInput = np.delete(testInputRaw, 0, axis=0)
            testInputShape = testInput.shape
            testInput = np.reshape(testInput, [1, testInputShape[0], testInputShape[1]])
    
        pred = models[modelInd].predict(testInput)
        results[ind] = pred

    return jnp.array(results)

In [183]:
num_train = 6000
num_test = 361

history_len = 12
num_stations = 57

In [184]:
# 12..5999
train_true = data[history_len:num_train]
# 0..11, 1..12, ..., 5987..5998
train_data = historify(data, history_len=history_len, num_histories=train_true.shape[0])

# 6012..8386
test_true = data[num_train + history_len:]
# 6000..6011, ..., 8374..8385
test_data = historify(data, history_len=history_len, num_histories=test_true.shape[0], offset=num_train)

In [185]:
# 6012..8386
test_pred = predict(test_data, models)

In [186]:
data.shape

(8387, 57)

In [187]:
# Metric: mean absolute error
jnp.absolute((test_true - test_pred) * data_max + data_min).mean(axis=0).mean()

DeviceArray(1.3113904, dtype=float32)

In [188]:
# 12..5999
train_pred = predict(train_data, models)

In [189]:
train_pred.shape

(5988, 57)

In [190]:
# 1..5998
train_pred = jnp.vstack((jnp.zeros((history_len - 1, num_stations)), train_pred))

In [191]:
print(test_pred.shape, test_data.shape)

(2375, 57) (2375, 684)


In [192]:
print(train_pred.shape, train_data.shape)

(5999, 57) (5988, 684)


In [193]:
ars, states = [None] * num_stations, [None] * num_stations
for station in tqdm.tqdm(range(num_stations)):
    ars[station], states[station] = AR.fit(
        data=[(data[:num_train - 1], train_pred[:, station], None)],
        input_dim=num_stations,
        output_dim=1,
        history=data[num_train : num_train + history_len],
        history_len=history_len
    )

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))




In [194]:
for station in tqdm.tqdm(range(num_stations)):
    pickle.dump(ars[station].params, open("../data/wind/ar/{}.pkl".format(station), "wb"))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))




In [195]:
for station in tqdm.tqdm(range(num_stations)):
    pickle.dump(test_pred[:, station], open("../data/wind/base/{}.pkl".format(station), "wb"))

HBox(children=(FloatProgress(value=0.0, max=57.0), HTML(value='')))




In [260]:
@tc.experiment("station", range(num_stations))
@tc.experiment("lr", jnp.linspace(-7, -4, 13))
@tc.experiment("history_len", [4, 8, 12, 16])
def runner(station, history_len, lr=-5):
    import jax.numpy as jnp
    import pickle
    
    from timecast.learners import Sequential, Parallel, BlackBox, AR
    from timecast import tmap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent
    
    num_train = 6000

#     history_len = 4
    num_stations = 57

    data = jnp.asarray(pickle.load(open("../data/wind/original/MS_winds.pkl", "rb")))
    data_min = data.min()
    data_max = data.max()
    data = (data - data_min) / data_max
    
    Y_lstm = jnp.asarray(pickle.load(open("../data/wind/base/{}.pkl".format(station), "rb")))
#     params = pickle.load(open("../data/wind/ar/{}.pkl".format(station), "rb"))
    
    lstm = BlackBox.partial(arr=Y_lstm)
    ar = AR.partial(
        output_dim=1,
        history=data[num_train : num_train + history_len - 1],
        history_len=history_len
    )
    model, state = Parallel.new(shape=(1, num_stations), learners=[lstm, ar])
#     model.params["AR"] = params

    optim_def = GradientDescent(learning_rate=(10 ** lr))
    optimizer = optim_def.create(model)

    X = data[num_train + history_len - 1:-1]
    Y = data[num_train + history_len:, station]

    Y_hat, optimizer, state = tmap(X, Y, optimizer, state=state, objective=residual)
    
    return {
        "station": station,
        "lr": lr,
        "history_len": history_len,
        "mae": jnp.absolute((Y - Y_hat) * data_max + data_min).mean()
    }

In [261]:
results = runner.run(processes=15, tqdm=tqdm)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [263]:
df = {}
for lr in jnp.linspace(-7, -4, 13):
    df[lr] = {}
    for history_len in [4, 8, 12, 16]:
        df[lr][history_len] = np.mean([result["mae"] for result in results if (result["lr"] == lr and result["history_len"] == history_len)])

In [264]:
df = pd.DataFrame.from_records(df)

In [265]:
df

Unnamed: 0,-7.00,-6.75,-6.50,-6.25,-6.00,-5.75,-5.50,-5.25,-5.00,-4.75,-4.50,-4.25,-4.00
4,1.980253,1.98024,1.980217,1.98018,1.980125,1.980063,1.980054,1.980314,1.981452,1.984768,1.99209,2.00389,2.017094
8,1.981006,1.980998,1.980986,1.980974,1.980983,1.981082,1.981488,1.98275,1.986006,1.992746,2.00334,2.015344,2.026235
12,1.981896,1.981895,1.981899,1.981921,1.982003,1.982271,1.983055,1.985095,1.989593,1.997494,2.008014,2.019119,2.031024
16,1.982891,1.982889,1.982891,1.982917,1.983024,1.983379,1.984402,1.986951,1.992145,2.00041,2.010522,2.021352,2.034401
