In [1]:
%load_ext autoreload
%autoreload 2

# `display` and `HTML` are part of the public IPython.display API; importing
# them from IPython.core.display is a deprecated path, so everything display
# related is pulled from IPython.display here.
from IPython.display import HTML, Image, display

# Inject CSS so the notebook container spans the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


# Default figure size (width, height) applied to every matplotlib figure created after this cell.
plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

# Before parallelization

In [21]:
from ealstm.gaip.flood_data import FloodData

from timecast.learners._ar import _ar_gram
from timecast.learners._pcr import _compute_pca_projection

# Training split of the flood data, configured by the saved run's cfg.json.
cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
train_data = FloodData(cfg_path, is_train=True)

# Accumulate the gram statistics over the training generator.
# NOTE(review): 32 input features and a 270-step history presumably mirror the
# run configuration — confirm against cfg.json.
gram_kwargs = dict(input_dim=32, output_dim=1, history_len=270)
XTX, XTY = _ar_gram(train_data.generator(), **gram_kwargs)

# Build the PCA projection from the normalized gram matrix, keeping k components.
# Later cells read both `k` and `projection`.
k = 10
normalized_gram = XTX.matrix(normalize=True)
projection = _compute_pca_projection(normalized_gram, k=k)



In [None]:
import pickle
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

import jax.numpy as jnp

from timecast.learners import Sequential, Parallel, Index, PCR
from timecast import smap
from timecast.objectives import residual
from timecast.optim import GradientDescent, RMSProp

import tqdm.notebook as tqdm

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"

# Load the pickled per-basin LSTM results through a context manager so the file
# handle is closed as soon as the data is read (the bare open() leaked it).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted artifacts.
with open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb") as lstm_file:
    ea_data = pickle.load(lstm_file)
flood_data = FloodData(cfg_path)

lr = 1e-5

# Per-basin metrics, plus running lists used for the cumulative averages.
results = {}
mses = []
nses = []

# NOTE(review): total=531 is a hard-coded progress-bar length, presumably the
# number of basins yielded by the generator — confirm.
for X, _, basin in tqdm.tqdm(flood_data.generator(), total=531):
    # Read the configured sequence length once per basin instead of twice.
    seq_length = flood_data.cfg["seq_length"]

    # Input 0 (the LSTM prediction) is passed through unchanged; input 1 (the
    # features) is routed through PCR, seeded with the first seq_length - 1 rows.
    lstm = Index.partial(index=0)
    pcr = PCR.partial(projection=projection, history_len=270, history=X[:seq_length - 1])
    pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    # (the rows before index seq_length - 1 were already handed to PCR as history).
    X_t = X[seq_length - 1:]
    Y_lstm = jnp.array(ea_data[basin].qsim)
    Y = jnp.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)
    # Keep the optimizer's current target available as `model` after the pass.
    model = optimizer.target

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    mses.append(mse)
    nses.append(nse)

    # "avg_*" entries are running means over the basins processed so far.
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": jnp.mean(jnp.array(mses)),
        "avg_nse": jnp.mean(jnp.array(nses))
    }
    # NOTE(review): this stops after the first basin, so the summary below
    # covers a single basin only; remove the break to evaluate every basin.
    break
print({"basin": basin, "k": k, "mse": jnp.mean(jnp.array(mses))})

# After parallelization

In [22]:
import pickle

# Read the pickled collection of basin ids through a context manager so the
# file handle is closed once loading finishes (the bare open() leaked it).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("../data/basins.p", "rb") as basins_file:
    basins = pickle.load(basins_file)

In [52]:
@tc.experiment("k,projection", [(k, projection)])
@tc.experiment("basin", [(basin,) for basin in basins])
def runner(basin, k, projection, lr=1e-5):
    """Run the LSTM + PCR model over one basin and report its MSE.

    The model is a ``Parallel`` of two branches: ``Index(0)`` passes the LSTM
    prediction through, and ``Sequential[Index(1), PCR]`` applies PCR to the
    features. ``smap`` is run over the paired inputs with a ``GradientDescent``
    optimizer and the ``residual`` objective.

    Args:
        basin: Basin identifier used to locate ``../data/camels/<basin>.p``
            (features) and ``../data/ealstm/<basin>.p`` (object exposing
            ``qsim`` and ``qobs``).
        k: Recorded in the returned dict; not otherwise used in this function.
        projection: Projection handed to the PCR learner.
        lr: Learning rate for ``GradientDescent``.

    Returns:
        dict with keys ``"basin"``, ``"k"`` and ``"mse"``, where ``"mse"`` is
        ``MSE(Y, Y_hat)`` between the observations and the predictions.
    """
    # Imports are kept function-local, as in the original cell.
    # NOTE(review): presumably so the function is self-contained when executed
    # in worker processes — confirm against timecast's experiment runner.
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Index, PCR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    # Context managers close the pickle files as soon as they are read.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("../data/camels/{}.p".format(basin), "rb") as features_file:
        X = pickle.load(features_file)
    with open("../data/ealstm/{}.p".format(basin), "rb") as lstm_file:
        ea_output = pickle.load(lstm_file)

    history_len = 270

    lstm = Index.partial(index=0)
    # PCR is seeded with the first history_len - 1 feature rows.
    pcr = PCR.partial(projection=projection, history_len=history_len, history=X[:history_len - 1])
    pcr = Sequential.partial(learners=[Index.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])

    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    # (the rows before index history_len - 1 were already handed to PCR as history).
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(ea_output.qsim)
    Y = jnp.array(ea_output.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "k": k, "mse": MSE(Y, Y_hat)}

In [53]:
# Launch the experiments declared by the decorators on `runner`.
# NOTE(review): `processes=10` presumably fans the runs out over 10 worker
# processes and `tqdm=tqdm` supplies the notebook progress bar — confirm
# against timecast's experiment API.
runner.run(processes=10, tqdm=tqdm)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




[{'basin': '01022500', 'k': 10, 'mse': array(0.8243627, dtype=float32)},
 {'basin': '01031500', 'k': 10, 'mse': array(1.0610197, dtype=float32)},
 {'basin': '01047000', 'k': 10, 'mse': array(1.308005, dtype=float32)},
 {'basin': '01052500', 'k': 10, 'mse': array(2.0327582, dtype=float32)},
 {'basin': '01054200', 'k': 10, 'mse': array(7.8569345, dtype=float32)},
 {'basin': '01055000', 'k': 10, 'mse': array(4.757536, dtype=float32)},
 {'basin': '01057000', 'k': 10, 'mse': array(0.93694293, dtype=float32)},
 {'basin': '01073000', 'k': 10, 'mse': array(0.9335359, dtype=float32)},
 {'basin': '01078000', 'k': 10, 'mse': array(0.69459224, dtype=float32)},
 {'basin': '01123000', 'k': 10, 'mse': array(0.90753907, dtype=float32)},
 {'basin': '01134500', 'k': 10, 'mse': array(1.7401067, dtype=float32)},
 {'basin': '01137500', 'k': 10, 'mse': array(2.4031172, dtype=float32)},
 {'basin': '01139000', 'k': 10, 'mse': array(0.8189094, dtype=float32)},
 {'basin': '01139800', 'k': 10, 'mse': array(0.854