In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
def load_cfg(cfg_path):
    """Load a run's JSON config and merge in the EA-LSTM ``GLOBAL_SETTINGS``.

    Args:
        cfg_path: path to the run's ``cfg.json`` file.

    Returns:
        dict: the parsed config with ``camels_root`` and ``run_dir`` converted to
        ``pathlib.Path``, then updated with ``GLOBAL_SETTINGS`` (its values win on
        key collisions).
    """
    # Local import so this cell does not rely on a later cell having imported Path.
    from pathlib import Path

    # `with` guarantees the file handle is closed.
    with open(cfg_path, "r") as fh:
        cfg = json.load(fh)
    cfg["camels_root"] = Path(cfg["camels_root"])
    cfg["run_dir"] = Path(cfg["run_dir"])
    # GLOBAL_SETTINGS is imported from ealstm.main in a later cell; it is only
    # looked up when this function is called.
    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [3]:
import json
from pathlib import Path

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    """Per-basin CAMELS validation data for the run whose config is at ``cfg_path``.

    NOTE(review): later in this notebook ``from ealstm.gaip import FloodData``
    rebinds this name, so the experiment cells use that class, not this one.
    """

    def __init__(self, cfg_path):
        self.cfg = load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)

    def generator(self, is_train=False, with_attributes=True):
        """Yield ``(X, y, basin)`` for every basin over the validation period.

        Args:
            is_train: forwarded to ``CamelsTXT``.
            with_attributes: when True (default), the basin's static attributes
                are tiled onto every row of ``X`` and concatenated along axis 1;
                when False, ``X`` is the dataset's ``x`` unchanged.

        Yields:
            tuple: ``(X, ds.y, basin)`` for each basin in ``self.basins``.
        """
        # Normalisation statistics are identical for every basin; compute them once.
        attribute_means = self.attributes.mean()
        attribute_stds = self.attributes.std()
        for basin in self.basins:
            ds = CamelsTXT(camels_root=self.cfg["camels_root"],
                           basin=basin,
                           dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                           is_train=is_train,
                           seq_length=self.cfg["seq_length"],
                           with_attributes=with_attributes,
                           attribute_means=attribute_means,
                           attribute_stds=attribute_stds,
                           concat_static=self.cfg["concat_static"],
                           db_path=self.db_path,
                           reshape=False,
                           torchify=False)
            X = ds.x
            if with_attributes:
                # Use plain numpy (`onp`, imported in the first cell) rather than
                # `np`, which is only bound to jax.numpy by a later cell.
                static = onp.tile(onp.asarray(ds.attributes), (X.shape[0], 1))
                X = onp.concatenate((X, static), axis=1)
            yield X, ds.y, basin

In [4]:
import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    """Squared error averaged within each example, then averaged across the batch.

    Axis 0 is treated as the batch axis; the per-example reduction runs over every
    remaining axis of ``y_true``.
    """

    def __init__(self):
        # Stateless loss: nothing to initialise.
        pass

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        """Return the scalar batch-mean of the per-example mean squared error."""
        squared_error = (y_pred - y_true) ** 2
        example_axes = tuple(range(1, y_true.ndim))
        per_example = np.mean(squared_error, axis=example_axes)
        return np.mean(per_example)

In [5]:
import jax.numpy as np
import jax

from timecast.learners import BaseLearner

def batch_window(X, window_size, offset=0):
    """Stack overlapping, contiguous windows of ``X`` along a new second axis.

    Window ``j`` is ``X[j + offset : j + offset + window_size]``. Only windows that
    lie entirely inside ``X`` are returned; nothing wraps around to the start of
    ``X`` (a roll-based construction does wrap when ``offset > 0``).

    Args:
        X: array whose first axis is time; trailing axes are carried through.
        window_size: number of consecutive rows per window (must be >= 1).
        offset: non-negative number of leading rows skipped before the first window.

    Returns:
        A jax.numpy array of shape ``(num_windows, window_size) + X.shape[1:]`` with
        ``num_windows = max(0, X.shape[0] - window_size - offset + 1)``.

    Raises:
        ValueError: if ``window_size < 1`` or ``offset < 0``.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if offset < 0:
        raise ValueError(f"offset must be >= 0, got {offset}")
    X = np.asarray(X)
    num_windows = max(0, X.shape[0] - window_size - offset + 1)
    # idx[j, i] = j + offset + i, so row j of the result gathers window j.
    idx = np.arange(num_windows)[:, None] + np.arange(window_size)[None, :] + offset
    return X[idx]

class ARStateless(BaseLearner):
    """Linear autoregressive learner over flattened windows, trained online.

    Each example is flattened and a constant 1 is prepended as a bias feature. The
    prediction is a single product with ``W`` of shape
    ``(window_size * input_dim + 1, output_dim)``, which assumes every example holds
    ``window_size * input_dim`` values (presumably a ``(window_size, input_dim)``
    window — confirm against callers).

    Args:
        input_dim: features per time step.
        output_dim: predicted values per example.
        window_size: time steps per example.
        optimizer: object with ``update(params, gradients) -> params``;
            defaults to ``timecast.optim.SGD()``.
        loss: object with ``compute(y_pred, y_true)``; defaults to
            ``BatchedMeanSquareError()``.
    """

    def __init__(self, input_dim: int, output_dim: int, window_size: int, optimizer=None, loss=None):
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._window_size = window_size
        if optimizer is None:
            # Local import: SGD is otherwise only imported by a later notebook cell.
            from timecast.optim import SGD
            optimizer = SGD()
        self._optimizer = optimizer
        self._loss = loss or BatchedMeanSquareError()

        # One weight row per flattened feature, plus a leading bias row.
        W = np.zeros((window_size * input_dim + 1, output_dim))
        self._params = {"W": W}

        def _predict(params, x):
            # A single un-batched example (ndim == 2) is promoted to a batch of one.
            if x.ndim == 2:
                x = x[np.newaxis, :]
            shape = x.shape

            # Flatten each example, prepend the bias feature, then apply W.
            X = np.hstack((np.ones((shape[0], 1)), x.reshape(shape[0], -1))).reshape(shape[0], -1, 1)
            X = np.tensordot(X, params["W"], axes=(1, 0))
            return X.reshape(shape[0], -1)

        def _loss_fn(params, X, y):
            return self._loss.compute(_predict(params, X), y)

        def _batch_mean_grad(params, X, y):
            # Per-example gradients averaged over the batch axis. in_axes=None
            # broadcasts the whole params pytree, whatever keys it contains.
            gradients = jax.vmap(self._grad, in_axes=(None, 0, 0), out_axes=0)(params, X, y)
            return {key: val.mean(axis=0) for key, val in gradients.items()}

        self._predict_jit = jax.jit(_predict)
        self._grad = jax.jit(jax.grad(_loss_fn))
        self._value_and_grad = jax.jit(jax.value_and_grad(_loss_fn))
        self._batch_mean_grad = jax.jit(_batch_mean_grad)

        def _predict_and_update(params, xy):
            X, y = xy
            # Predict with the current params first, so each prediction is made
            # before the model is updated on that batch's targets.
            value = self._predict_jit(params, X)
            if value.shape[0] > 1:
                gradients = self._batch_mean_grad(params, X, y)
            else:
                gradients = self._grad(params, X, y)
            params = self._optimizer.update(params, gradients)
            return params, value

        self._predict_and_update_jit = jax.jit(_predict_and_update)

    def predict(self, X):
        """Predict every example in ``X`` (stacked on axis 0); result reshaped to ``(-1, 1)``."""
        return jax.vmap(self._predict_jit, in_axes=(None, 0))(self._params, X).reshape(-1, 1)

    def update(self, X, y):
        """Take one optimizer step using per-example gradients averaged over the batch.

        Every entry of ``self._params`` is averaged. Only ``"W"`` exists, so the old
        lookup of a ``"b"`` gradient raised ``KeyError``.
        """
        gradients = self._batch_mean_grad(self._params, X, y)
        self._params = self._optimizer.update(self._params, gradients)

    def predict_and_update(self, X, y, batch_size=1):
        """Online pass: for each consecutive batch, predict, then take one optimizer step.

        ``X``/``y`` are split along axis 0 into batches of ``batch_size``. A smaller
        final batch is processed as one extra step.

        Args:
            X: examples stacked along axis 0.
            y: targets aligned with ``X`` along axis 0.
            batch_size: examples per optimizer step (must be >= 1).

        Returns:
            Predictions for every example, in order, reshaped to ``(-1, 1)``.

        Raises:
            ValueError: if ``batch_size < 1``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        num_full = X.shape[0] // batch_size
        trim_size = num_full * batch_size
        chunks = []

        if num_full > 0:
            batched_X = X[:trim_size].reshape((num_full, batch_size) + X.shape[1:])
            batched_y = y[:trim_size].reshape((num_full, batch_size) + y.shape[1:])
            self._params, value = jax.lax.scan(self._predict_and_update_jit, self._params, (batched_X, batched_y))
            chunks.append(value.reshape((trim_size,) + value.shape[2:]))

        if trim_size < X.shape[0]:
            # Remainder batch, scanned as a single step so its update is applied too.
            last_X = X[trim_size:][np.newaxis, :]
            last_y = y[trim_size:][np.newaxis, :]
            self._params, last_value = jax.lax.scan(self._predict_and_update_jit, self._params, (last_X, last_y))
            chunks.append(last_value.squeeze(axis=0))

        if not chunks:
            # Empty input: nothing to predict or learn from.
            return np.zeros((0, 1))
        return np.vstack(chunks).reshape(-1, 1)

In [6]:
from ealstm.gaip import FloodLSTM
from ealstm.gaip import FloodData
from ealstm.gaip.utils import MSE, NSE

from timecast.optim import SGD
from timecast.learners import AR

# Trained EA-LSTM run whose config and cached predictions are analysed below.
RUN_DIR = "../data/models/runs/run_2006_0032_seed444"

cfg_path = f"{RUN_DIR}/cfg.json"
flood_data = FloodData(cfg_path)
# NOTE(review): pickle.load can execute arbitrary code; only load files produced
# by a trusted run. `with` ensures the file handle is closed.
with open(f"{RUN_DIR}/lstm_seed444.p", "rb") as fh:
    ea_data = pickle.load(fh)

LR_AR = 1e-5       # SGD learning rate for the residual AR model
AR_INPUT_DIM = 32  # per-time-step feature count of X fed to ARStateless — TODO confirm against the data
AR_OUTPUT_DIM = 1  # one predicted value per window

In [7]:
# For each basin: fit a fresh online AR model to the EA-LSTM residual (qobs - qsim)
# and score the corrected prediction qsim + AR against qobs.
results = {}
mses = []
nses = []
seq_length = flood_data.cfg["seq_length"]
for X_raw, _, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    sgd = SGD(learning_rate=LR_AR, online=False)
    ar = ARStateless(input_dim=AR_INPUT_DIM, output_dim=AR_OUTPUT_DIM, window_size=seq_length, optimizer=sgd)
    X = batch_window(X_raw, seq_length)

    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)
    Y_lstm = np.array(ea_data[basin].qsim).reshape(-1, 1)
    # Each window must pair with exactly one observation/simulation. Fail loudly
    # here rather than through broadcasting further down.
    assert X.shape[0] == Y.shape[0] == Y_lstm.shape[0], (basin, X.shape, Y.shape, Y_lstm.shape)
    Y_target = Y - Y_lstm

    Y_ar = ar.predict_and_update(X, Y_target, batch_size=2)
    Y_hat = Y_lstm + Y_ar

    sq_err = (Y - Y_hat) ** 2
    mse = sq_err.mean()
    nse = 1 - sq_err.sum() / ((Y - Y.mean()) ** 2).sum()
    # Append before averaging so the stored running averages include this basin.
    # Averaging first makes them lag one basin behind and be NaN for the first basin.
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse,
    }
    print(basin, mse, nse, avg_mse, avg_nse)


HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.91817844 0.8135103 0.91817844 0.8135103
01031500 1.1610899 0.876401 1.0396342 0.8449557
01047000 1.4554787 0.8636198 1.178249 0.85117704
01052500 2.25135 0.82809263 1.4465243 0.84540594
01054200 7.9875073 0.72898364 2.754721 0.82212144
01055000 4.912552 0.7224574 3.1143596 0.80551076
01057000 1.0984551 0.8504866 2.8263733 0.8119359
01073000 0.98797137 0.84993523 2.596573 0.8166858
01078000 0.7965708 0.8673223 2.396573 0.8223121
01123000 0.9518823 0.8066434 2.2521038 0.82074517
01134500 1.8486514 0.775094 2.2154264 0.8165951
01137500 2.5455236 0.7496748 2.2429345 0.8110184
01139000 0.88057077 0.7527673 2.1381373 0.80653757
01139800 0.91736186 0.7484417 2.050939 0.80238783
01142500 1.2183942 0.7532305 1.9954361 0.79911065
01144000 0.8705375 0.8229096 1.9251299 0.8005981
01162500 0.94441104 0.8142789 1.8674405 0.8014028
01169000 2.4088166 0.73227906 1.8975168 0.7975626
01170100 1.9442805 0.75946313 1.899978 0.7955574
01181000 2.0418077 0.777333 1.9070696 0.79464614
01187300 2.8

02479155 6.105443 0.67484665 2.0514762 0.7233357
02479300 1.858087 0.78287244 2.0503252 0.72369003
02479560 1.5419898 0.82879996 2.0473173 0.724312
02481000 6.539118 0.7366353 2.0737395 0.7243845
02481510 2.4975028 0.79704845 2.0762177 0.72480947
04015330 2.2948647 0.6554415 2.0774887 0.7244062
04024430 0.5902369 0.73824763 2.068892 0.72448623
04027000 0.3252176 0.8648409 2.0588708 0.72529286
04040500 0.4508688 0.7925653 2.0496824 0.72567725
04043050 0.9730037 0.7804566 2.0435648 0.72598845
04045500 0.046247125 0.9466863 2.0322804 0.7272354
04057510 0.072422564 0.9044666 2.02127 0.728231
04057800 0.57518107 0.73845774 2.0131915 0.7282881
04059500 0.15532552 0.85852087 2.00287 0.7290117
04063700 0.0810927 0.87856424 1.9922525 0.72983783
04074950 0.049473174 0.81272423 1.9815778 0.73029333
04105700 0.028265666 0.8303251 1.9709039 0.73084
04115265 0.13058229 0.53210604 1.9609022 0.72976
04122200 0.034918524 0.8245848 1.9504915 0.7302725
04122500 0.031076608 0.8242738 1.9401721 0.73077786


06910800 1.7515635 0.82504857 2.0017238 0.7134119
06911900 1.9241946 0.835909 2.0014906 0.71377975
06917000 3.40823 0.7082039 2.0057025 0.713763
06918460 1.2706416 0.7937612 2.003508 0.7140019
06919500 1.3326459 0.83678913 2.0015118 0.71436733
06921070 1.9669566 0.82096803 2.0014095 0.7146837
06921200 5.256485 0.6755648 2.0110397 0.71456796
07057500 0.99419546 0.66049814 2.0080402 0.7144084
07060710 1.6049944 0.6526916 2.0068548 0.7142269
07066000 1.8733165 0.6177937 2.006463 0.713944
07083000 0.22945657 0.9499468 2.0012672 0.7146341
07142300 0.017893119 0.23794585 1.9954846 0.7132444
07145700 1.4306543 0.7353616 1.9938427 0.7133087
07167500 9.765283 0.44794083 2.0163684 0.71253955
07180500 1.9782257 0.7088028 2.016258 0.71252877
07184000 8.034253 0.73896015 2.0336013 0.71260494
07195800 2.1866884 0.6551273 2.0340414 0.7124397
07196900 7.994441 0.58185387 2.0511198 0.7120656
07197000 2.0345454 0.75932705 2.0510724 0.7122006
07208500 0.06421145 0.63336694 2.0454118 0.711976
07261000 3.3

14305500 16.188831 0.8610291 2.7339344 0.716541
14306340 9.743691 0.85034746 2.7480104 0.71680975
14306500 3.8080924 0.9055227 2.7501345 0.7171879
14308990 1.9732285 0.7916229 2.7485807 0.71733665
14309500 3.6830068 0.8659608 2.750446 0.71763337
14316700 3.5967443 0.8699613 2.7521322 0.7179368
14325000 8.742933 0.8556009 2.764042 0.7182105
14362250 0.19256367 0.83333224 2.7589397 0.7184389
14400000 16.619587 0.90941125 2.7863867 0.7188171
10259000 0.42413095 0.35552603 2.781718 0.7180992
11124500 1.5339868 0.40205443 2.779257 0.7174758
11141280 0.7283734 0.5482255 2.77522 0.71714264
11143000 2.740168 0.89653647 2.7751513 0.7174951
11148900 5.620408 0.78302246 2.78073 0.71762353
11151300 0.13326111 0.6262143 2.7755492 0.71744466
11176400 1.0380373 0.57416385 2.7721555 0.7171648
11230500 0.45347446 0.94527876 2.7676358 0.7176094
11237500 2.3723853 0.89149606 2.7668667 0.7179477
11264500 1.0587503 0.9054266 2.7635503 0.7183118
11266500 1.2611852 0.8899219 2.7606387 0.7186444
11284400 2.40