In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



# AR baseline: average test MSE = 2.7122574

In [4]:
# Basin identifiers for the AR experiment; `with` closes the file promptly.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [13]:
@tc.experiment("basin", [(basin,) for basin in basins])
def runner(basin, lr=1e-5):
    """Fit the LSTM + AR model online (residual objective) on one test basin.

    Args:
        basin: basin identifier used to locate the per-basin pickle files.
        lr: gradient-descent learning rate.

    Returns:
        dict with the basin id and the MSE of the prediction against ``Y.qobs``.
    """
    # Local imports, presumably so the function is self-contained when
    # `runner.run(processes=...)` executes it in worker processes.
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    # `Index` must be imported: it is used for both branches below. (Previously
    # only `Take` was imported, so `Index.partial(index=1)` raised NameError.)
    from timecast.learners import Sequential, Parallel, Index, AR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    # `with` closes each file as soon as it has been unpickled.
    with open("../data/test/{}.p".format(basin), "rb") as fh:
        X = pickle.load(fh)
    with open("../data/ealstm/{}.p".format(basin), "rb") as fh:
        Y = pickle.load(fh)

    history_len = 270

    # Branch 0 selects the first model input (Y_lstm); branch 1 feeds the
    # second input (X_t) through the AR learner.
    lstm = Index.partial(index=0)
    ar = AR.partial(history_len=history_len, history=X[:history_len - 1])
    ar = Sequential.partial(learners=[Index.partial(index=1), ar])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, ar])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "mse": MSE(Y, Y_hat)}

NameError: name 'tc' is not defined

In [4]:
# Run the AR experiment over every basin, with processes=10.
results = runner.run(tqdm=tqdm, processes=10)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [5]:
# Mean of the per-basin MSEs for the AR experiment.
ar_mses = jnp.array([result["mse"] for result in results])
print("Average MSE: {}".format(jnp.average(ar_mses)))

Average MSE: 2.7122573852539062


# PCR (principal component regression)

## Train

In [2]:
from timecast.learners._ar import _ar_gram
from timecast.learners._pcr import _compute_pca_projection

In [3]:
from ealstm.gaip.flood_data import FloodData

# Accumulate the AR Gram statistics (presumably X^T X and X^T Y) over the
# training basins. These dimensions must match the test-time models below:
# 32 input features per step (`shape=(1, 32)`) and `history_len=270`.
INPUT_DIM = 32
OUTPUT_DIM = 1
HISTORY_LEN = 270

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
train_data = FloodData(cfg_path, is_train=True)

XTX, XTY = _ar_gram(
    train_data.generator(),
    input_dim=INPUT_DIM,
    output_dim=OUTPUT_DIM,
    history_len=HISTORY_LEN,
)

In [4]:
# PCA projections of the normalized training Gram matrix, one per candidate k.
# The normalized matrix does not depend on k, so build it once rather than on
# every iteration (assumes `XTX.matrix(normalize=True)` has no side effects).
gram_normalized = XTX.matrix(normalize=True)
projections = {}
for k in tqdm.tqdm([10, 50, 100, 500, 1000, 5000]):
    projections[k] = _compute_pca_projection(gram_normalized, k)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




## Test

In [6]:
# NOTE: this replaces the `basins` loaded earlier from ../data/flood/meta.pkl.
with open("../data/basins.p", "rb") as basins_file:
    basins = pickle.load(basins_file)

In [15]:
@tc.experiment("k,projection", projections.items())
@tc.experiment("basin", basins)
def runner(basin, k, projection, lr=1e-5):
    """Fit the LSTM + PCR model online (residual objective) on one test basin.

    Args:
        basin: basin identifier used to locate the per-basin pickle files.
        k: number of principal components; only recorded in the result here.
        projection: the PCA projection computed for ``k`` in `projections`.
        lr: gradient-descent learning rate.

    Returns:
        dict with the basin id, ``k`` and the MSE of the prediction against
        ``Y.qobs``.
    """
    # Local imports, presumably so the function is self-contained when
    # `runner.run(processes=...)` executes it in worker processes.
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Index, PCR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    # `with` closes each file as soon as it has been unpickled.
    with open("../data/test/{}.p".format(basin), "rb") as fh:
        X = pickle.load(fh)
    with open("../data/ealstm/{}.p".format(basin), "rb") as fh:
        Y = pickle.load(fh)

    history_len = 270

    # Branch 0 selects the first model input (Y_lstm); branch 1 feeds the
    # second input (X_t) through PCR.
    lstm = Index.partial(index=0)
    pcr = PCR.partial(projection=projection, history_len=history_len, history=X[:history_len - 1])
    pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "k": k, "mse": MSE(Y, Y_hat)}

In [16]:
# Evaluate every (k, basin) combination, with processes=10.
results = runner.run(tqdm=tqdm, processes=10)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [19]:
# Average test MSE for each number of principal components k. The k values
# come from the results themselves, so this cannot drift out of sync with the
# list of k values used to build `projections`.
for k in sorted({result["k"] for result in results}):
    k_mses = jnp.array([result["mse"] for result in results if result["k"] == k])
    print("Average MSE (k={}): {}".format(k, jnp.average(k_mses)))

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Average MSE (k=10): 2.7545559406280518
Average MSE (k=50): 2.739213466644287
Average MSE (k=100): 2.7256996631622314
Average MSE (k=500): 2.712125062942505
Average MSE (k=1000): 2.7122278213500977
Average MSE (k=5000): 2.712256908416748



In [9]:
@tc.experiment("k,projection", [(500, projections[500])])
@tc.experiment("basin", basins)
@tc.experiment("lr", [10 ** -8, 10 ** -7.5, 10 ** -7, 10 ** -6.5, 10 ** -6, 10 ** -5.5, 10 ** -5, 10 ** -4.5, 10 ** -4, 10 ** -3.5])
def runner(basin, k, projection, lr=1e-5):
    """Learning-rate sweep for the LSTM + PCR model (k=500) on one test basin.

    Args:
        basin: basin identifier used to locate the per-basin pickle files.
        k: number of principal components; only recorded in the result here.
        projection: the PCA projection computed for ``k``.
        lr: gradient-descent learning rate (a grid value from the decorator).

    Returns:
        dict with the basin id, ``k``, ``lr`` and the MSE of the prediction
        against ``Y.qobs``.
    """
    # Local imports, presumably so the function is self-contained when
    # `runner.run(processes=...)` executes it in worker processes.
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Index, PCR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    # `with` closes each file as soon as it has been unpickled.
    with open("../data/test/{}.p".format(basin), "rb") as fh:
        X = pickle.load(fh)
    with open("../data/ealstm/{}.p".format(basin), "rb") as fh:
        Y = pickle.load(fh)

    history_len = 270

    # Branch 0 selects the first model input (Y_lstm); branch 1 feeds the
    # second input (X_t) through PCR.
    lstm = Index.partial(index=0)
    pcr = PCR.partial(projection=projection, history_len=history_len, history=X[:history_len - 1])
    pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "k": k, "lr": lr, "mse": MSE(Y, Y_hat)}

In [10]:
# Evaluate every (lr, basin) combination for k=500, with processes=15.
results = runner.run(tqdm=tqdm, processes=15)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [13]:
# Average test MSE for each learning rate in the sweep. The lr values come
# from the results themselves, so this cannot drift out of sync with the
# decorator's grid. A `nan` average means at least one basin diverged.
for lr in sorted({result["lr"] for result in results}):
    lr_mses = jnp.array([result["mse"] for result in results if result["lr"] == lr])
    print("Average MSE (lr={0:.10f}): {1:.2f}".format(lr, jnp.average(lr_mses)))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Average MSE (lr=0.0000000100): 3.13
Average MSE (lr=0.0000000316): 3.11
Average MSE (lr=0.0000001000): 3.09
Average MSE (lr=0.0000003162): 3.05
Average MSE (lr=0.0000010000): 2.96
Average MSE (lr=0.0000031623): 2.84
Average MSE (lr=0.0000100000): 2.71
Average MSE (lr=0.0000316228): nan
Average MSE (lr=0.0001000000): nan
Average MSE (lr=0.0003162278): nan



In [9]:
@tc.experiment("k,projection", [(10, projections[10]), (500, projections[500])])
@tc.experiment("basin", basins)
# Plain Python floats rather than a JAX array: each grid point is then an
# ordinary hashable scalar, both as an experiment argument and in the results.
@tc.experiment("lr", onp.linspace(-5, -4.5, 6).tolist())
def runner(basin, k, projection, lr):
    """Fine learning-rate sweep for the LSTM + PCR model on one test basin.

    Args:
        basin: basin identifier used to locate the per-basin pickle files.
        k: number of principal components; only recorded in the result here.
        projection: the PCA projection computed for ``k``.
        lr: base-10 exponent of the learning rate (the rate is ``10 ** lr``).

    Returns:
        dict with the basin id, ``k``, the ``lr`` exponent (as a float) and the
        MSE of the prediction against ``Y.qobs``.
    """
    # Local imports, presumably so the function is self-contained when
    # `runner.run(processes=...)` executes it in worker processes.
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Index, PCR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    # Coerce to float in case a JAX/NumPy scalar is passed. This keeps both the
    # optimizer hyperparameter and the returned record a plain Python number.
    log10_lr = float(lr)

    # `with` closes each file as soon as it has been unpickled.
    with open("../data/test/{}.p".format(basin), "rb") as fh:
        X = pickle.load(fh)
    with open("../data/ealstm/{}.p".format(basin), "rb") as fh:
        Y = pickle.load(fh)

    history_len = 270

    # Branch 0 selects the first model input (Y_lstm); branch 1 feeds the
    # second input (X_t) through PCR.
    lstm = Index.partial(index=0)
    pcr = PCR.partial(projection=projection, history_len=history_len, history=X[:history_len - 1])
    pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

    optim_def = GradientDescent(learning_rate=10.0 ** log10_lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "k": k, "lr": log10_lr, "mse": MSE(Y, Y_hat)}

In [10]:
# Evaluate every (k, basin, lr-exponent) combination, with processes=15.
results = runner.run(tqdm=tqdm, processes=15)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [7]:
# Scratch reproduction of a single AR run (first basin, lr=1e-5) at notebook
# scope, so the fitted `model` can be inspected in the next cell.
import pickle

import jax.numpy as jnp

from timecast.learners import Sequential, Parallel, Index, AR
from timecast import smap
from timecast.objectives import residual
from timecast.optim import GradientDescent

import tqdm.notebook as tqdm

basin = basins[0]
lr = 1e-5

# NOTE(review): these paths (../data/flood/{test,base}/*.pkl) differ from the
# ones used by the `runner` cells (../data/test/*.p, ../data/ealstm/*.p);
# confirm which data layout is current.
with open("../data/flood/test/{}.pkl".format(basin), "rb") as fh:
    X = pickle.load(fh)
with open("../data/flood/base/{}.pkl".format(basin), "rb") as fh:
    Y = pickle.load(fh)

history_len = 270

# Branch 0 selects Y_lstm; branch 1 runs AR on X_t. The variable is named
# `pcr` only for parity with the runner cells; it holds the AR learner.
lstm = Index.partial(index=0)
pcr = AR.partial(history_len=history_len, history=X[:history_len - 1])
pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

optim_def = GradientDescent(learning_rate=lr)
optimizer = optim_def.create(model)

# NOTE: difference in indexing convention, so need to pad one row
X_t = X[history_len - 1:]
Y_lstm = jnp.array(Y.qsim)
Y = jnp.array(Y.qobs).reshape(-1, 1)

Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)


In [12]:
# Parameters of the AR learner's linear layer inside the Sequential branch.
# NOTE(review): the kernel's middle dimension (8640) equals 270 * 32, presumably
# history_len * input features flattened — confirm against timecast's AR.
ar_linear = model.params["Sequential_1"]["AR_1"]["linear"]
ar_linear["kernel"].shape

(1, 8640, 1)