In [63]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# AR (LSTM + online AR residual correction): mean per-basin MSE = 2.7122574

In [46]:
import pickle
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE
from flax import nn

import jax
import jax.numpy as jnp

from timecast.learners import Sequential, Parallel, AR, Index, PCR
from timecast import smap
from timecast.objectives import residual
from timecast.optim import GradientDescent, RMSProp

# Artefacts of one EA-LSTM run: its config and its pickled per-basin outputs.
cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
# NOTE: pickle can execute arbitrary code on load -- only open trusted run outputs.
# The context manager closes the file handle once loading is done.
with open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb") as ea_file:
    ea_data = pickle.load(ea_file)
flood_data = FloodData(cfg_path)

lr = 1e-5
# The first `seq_len - 1` input rows seed the AR history; the online model
# (and presumably the LSTM's qsim/qobs series) starts at row `seq_len - 1`.
seq_len = flood_data.cfg["seq_length"]

results = {}
mses = []
nses = []

# total=531: presumably the number of basins the generator yields -- TODO confirm.
for X, y, basin in tqdm.tqdm(flood_data.generator(), total=531):
    # Learner 0 selects the LSTM prediction (input 0); learner 1 selects the raw
    # features (input 1) and feeds them to an AR model seeded with the warm-up rows.
    lstm = Index.partial(index=0)
    ar = AR.partial(history_len=270, history=X[:seq_len - 1])
    ar = Sequential.partial(learners=[Index.partial(index=1), ar])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, ar])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # Indexing conventions differ: drop the `seq_len - 1` warm-up rows (already
    # consumed as AR history) so X_t lines up with the LSTM's outputs.
    X_t = X[seq_len - 1:]
    Y_lstm = jnp.array(ea_data[basin].qsim)
    Y = jnp.array(ea_data[basin].qobs).reshape(-1, 1)

    # Online pass over the basin's series, fitting against the `residual` objective.
    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)
    model = optimizer.target

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    mses.append(mse)
    nses.append(nse)

    # Per-basin metrics plus the running (unweighted, per-basin) averages so far.
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": jnp.mean(jnp.array(mses)),
        "avg_nse": jnp.mean(jnp.array(nses))
    }

print("MSE: {}".format(jnp.mean(jnp.array(mses))))

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))


MSE: 2.7122573852539062


# PCR (principal component regression)

## Train

In [30]:
from timecast.learners._ar import _ar_gram
from timecast.learners._pcr import _compute_pca_projection

In [31]:
# Accumulate the AR Gram statistics (presumably X^T X and X^T Y) over the training
# split only, using the same feature width (32) and AR window (270) as the online models.
# NOTE: `_ar_gram` is a private timecast helper; `PCR.fit` may be the public alternative -- confirm.
XTX, XTY = _ar_gram(flood_data.generator(is_train=True), input_dim=32, output_dim=1, history_len=270)

In [33]:
# The normalized Gram matrix does not depend on k, so compute it once rather
# than once per k. Assumes `_compute_pca_projection` does not modify its input
# in place (true for immutable JAX arrays) -- confirm.
XTX_normalized = XTX.matrix(normalize=True)

# One PCA projection per candidate number of principal components k.
projections = {}
for k in tqdm.tqdm([10, 50, 100, 500, 1000, 5000]):
    projections[k] = _compute_pca_projection(XTX_normalized, k)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




In [65]:
@tc.experiment("k, projection", projections.items())
def runner(k, projection):
    """Evaluate the LSTM + online PCR correction for one PCA projection.

    Dependencies are imported locally, presumably so the function is
    self-contained when `runner.run(processes=...)` executes it in a worker.

    Args:
        k: number of principal components (a key of `projections`).
        projection: the PCA projection computed for `k`.

    Returns:
        dict with "k", "mse" (mean per-basin MSE) and "nse" (mean per-basin NSE),
        with the metrics as plain Python floats.
    """
    import pickle
    from ealstm.gaip.flood_data import FloodData
    from ealstm.gaip.utils import MSE, NSE
    from flax import nn

    import jax
    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, AR, Index, PCR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent, RMSProp

    import tqdm.notebook as tqdm

    cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
    # NOTE: pickle can execute arbitrary code on load -- only open trusted run outputs.
    with open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb") as ea_file:
        ea_data = pickle.load(ea_file)
    flood_data = FloodData(cfg_path)
    # Rows before index `seq_len - 1` seed the PCR history; predictions start there.
    seq_len = flood_data.cfg["seq_length"]

    lr = 1e-5

    mses = []
    nses = []

    for X, _y, basin in tqdm.tqdm(flood_data.generator(), total=531):
        # Learner 0 selects the LSTM prediction; learner 1 applies PCR to the raw features.
        lstm = Index.partial(index=0)
        pcr = PCR.partial(projection=projection, history_len=270, history=X[:seq_len - 1])
        pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
        model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

        optimizer = GradientDescent(learning_rate=lr).create(model)

        # Drop the warm-up rows (consumed as history) to align with the LSTM outputs.
        X_t = X[seq_len - 1:]
        Y_lstm = jnp.array(ea_data[basin].qsim)
        Y = jnp.array(ea_data[basin].qobs).reshape(-1, 1)

        Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

        mses.append(MSE(Y, Y_hat))
        nses.append(NSE(Y, Y_hat))

    # Unweighted averages over basins; floats are portable across worker processes.
    return {
        "k": k,
        "mse": float(jnp.mean(jnp.array(mses))),
        "nse": float(jnp.mean(jnp.array(nses))),
    }

In [None]:
# Launch the PCR sweep (one job per k in `projections`) across 6 worker processes.
# NOTE: this rebinds `results`, replacing the per-basin dict built by the AR cell.
results = runner.run(tqdm=tqdm, processes=6)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))