In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
# Inject a CSS rule that forces the page's ".container" element to 100% width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

# Default matplotlib figure size (width, height) for every figure created in this notebook.
plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
def load_cfg(cfg_path):
    """Load a run configuration from a JSON file.

    The ``camels_root`` and ``run_dir`` entries are converted to ``Path``
    objects, then every key of ``GLOBAL_SETTINGS`` is merged in, overriding any
    key of the same name read from the file.

    Args:
        cfg_path: Location of the JSON configuration file.

    Returns:
        dict: The parsed and augmented configuration.

    Raises:
        KeyError: If ``camels_root`` or ``run_dir`` is missing from the file.
    """
    # Use a context manager so the handle is closed as soon as parsing is done
    # instead of being left for the garbage collector.
    with open(cfg_path, "r") as cfg_file:
        cfg = json.load(cfg_file)
    cfg["camels_root"] = Path(cfg["camels_root"])
    cfg["run_dir"] = Path(cfg["run_dir"])
    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [3]:
import json
from pathlib import Path

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    """Per-basin source of input features and targets for a configured run.

    Attributes:
        cfg: Configuration dict returned by ``load_cfg``.
        basins: Basin identifiers returned by ``get_basin_list``.
        db_path: Path of ``attributes.db`` inside the run directory.
        attributes: Result of ``load_attributes`` for ``basins`` (called with
            ``drop_lat_lon=True``).
    """

    def __init__(self, cfg_path):
        """Read the run configuration and load the basin attribute table.

        Args:
            cfg_path: Location of the run's JSON configuration file.
        """
        self.cfg = load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)

    def generator(self, is_train=False, with_attributes=True):
        """Yield ``(X, y, basin)`` for every basin, in ``self.basins`` order.

        Each basin is loaded through ``CamelsTXT`` over the
        ``GLOBAL_SETTINGS["val_start"]`` .. ``GLOBAL_SETTINGS["val_end"]``
        date range.

        Args:
            is_train: Forwarded unchanged to ``CamelsTXT``.
            with_attributes: When True (default) the dataset's attribute vector
                is repeated for every row and appended to the dataset's ``x``
                as extra columns. When False only ``x`` is yielded.

        Yields:
            tuple: ``(X, y, basin)`` where ``y`` is the dataset's ``y`` and
            ``basin`` the basin identifier.
        """
        # The normalisation statistics do not depend on the basin, so compute
        # them once rather than on every loop iteration.
        attribute_means = self.attributes.mean()
        attribute_stds = self.attributes.std()

        for basin in self.basins:
            ds_test = CamelsTXT(camels_root=self.cfg["camels_root"],
                                basin=basin,
                                dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                                is_train=is_train,
                                seq_length=self.cfg["seq_length"],
                                # Honour the caller's choice; this used to be
                                # hard-coded to True, silently ignoring the argument.
                                with_attributes=with_attributes,
                                attribute_means=attribute_means,
                                attribute_stds=attribute_stds,
                                concat_static=self.cfg["concat_static"],
                                db_path=self.db_path,
                                reshape=False,
                                torchify=False
                               )
            X = ds_test.x
            if with_attributes:
                # One copy of the attribute vector per row of x.
                # NOTE(review): assumes CamelsTXT only exposes `.attributes`
                # when with_attributes is True - confirm against its source.
                static_columns = np.tile(np.array(ds_test.attributes), (ds_test.x.shape[0], 1))
                X = np.concatenate((ds_test.x, static_columns), axis=1)
            yield X, ds_test.y, basin

In [4]:
import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    """Mean squared error averaged per sample first, then across the batch."""

    def __init__(self):
        """Nothing to configure; the loss keeps no state."""

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        """Return the batch mean of the per-sample mean squared error.

        Args:
            y_pred: Predicted values.
            y_true: Target values; its rank decides which axes are reduced
                per sample (every axis after the first).

        Returns:
            Scalar loss value.
        """
        # Every axis except the leading one is reduced within a sample.
        per_sample_axes = tuple(range(1, y_true.ndim))
        squared_error = (y_pred - y_true) ** 2
        per_sample_mse = np.mean(squared_error, axis=per_sample_axes)
        return np.mean(per_sample_mse)

In [5]:
import jax.numpy as np
import jax

from timecast.learners import BaseLearner

def batch_window(X, window_size, offset=0):
    """Stack sliding windows of ``window_size`` consecutive rows of ``X``.

    Window ``j`` holds rows ``j + offset .. j + offset + window_size - 1``.
    Rows are fetched with a circular ``roll``, so with a non-zero ``offset``
    the windows nearest the edge wrap around to the other end of ``X``.

    Args:
        X: Array whose first axis is the one being windowed.
        window_size: Number of consecutive rows per window.
        offset: Extra shift applied to the start of every window.

    Returns:
        Array of shape ``(num_windows, window_size) + X.shape[1:]`` where
        ``num_windows = max(X.shape[0] - window_size + 1, 0)``.
    """
    # Clamp at zero: a negative count would be read by the slice below as
    # "drop rows from the end" and return wrapped-around windows instead of
    # an empty result when the window is longer than the series.
    num_windows = max(X.shape[0] - window_size + 1, 0)
    shifted = [np.roll(X, shift=-(i + offset), axis=0) for i in range(window_size)]
    # stack -> (window_size, n_rows, ...); swap so the window start comes first.
    return np.swapaxes(np.stack(shifted), 0, 1)[:num_windows]

class ARStateless(BaseLearner):
    """Linear autoregressive learner over flattened input windows.

    A prediction is ``[1, flatten(window)] @ W`` with ``W`` of shape
    ``(window_size * input_dim + 1, output_dim)``. Row 0 of ``W`` multiplies the
    constant 1 and therefore plays the role of the bias; ``"W"`` is the only
    entry of the parameter dict.
    """

    def __init__(self, input_dim: int, output_dim: int, window_size: int, optimizer=None, loss=None):
        """Create zero-initialised parameters and the jitted prediction/gradient functions.

        Args:
            input_dim: Number of features per time step.
            output_dim: Number of outputs per prediction.
            window_size: Number of time steps per input window.
            optimizer: Object exposing ``update(params, gradients) -> params``;
                defaults to ``SGD()``.
            loss: Object exposing ``compute(y_pred, y_true)``; defaults to
                ``BatchedMeanSquareError()``.
        """
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._window_size = window_size
        if not optimizer:
            # Imported here so the default works even when the notebook cell
            # that imports SGD has not been executed yet.
            from timecast.optim import SGD
            optimizer = SGD()
        self._optimizer = optimizer
        self._loss = loss or BatchedMeanSquareError()

        W = np.zeros((window_size * input_dim + 1, output_dim))
        self._params = {"W": W}

        def _predict(params, x):
            # A single 2-D window is promoted to a batch of one.
            if x.ndim == 2:
                x = x[np.newaxis, :]
            shape = x.shape

            # Prepend a column of ones (bias input) to each flattened window.
            X = np.hstack((np.ones((shape[0], 1)), x.reshape(shape[0], -1))).reshape(shape[0], -1, 1)
            X = np.tensordot(X, params["W"], axes=(1, 0))
            return X.reshape(shape[0], -1)

        def _objective(params, X, y):
            return self._loss.compute(_predict(params, X), y)

        self._predict_jit = jax.jit(_predict)
        self._grad = jax.jit(jax.grad(_objective))
        self._value_and_grad = jax.jit(jax.value_and_grad(_objective))

        def _predict_and_update(params, xy):
            # Predict with the current parameters first, then take one step.
            value = self._predict_jit(params, xy[0])
            if value.shape[0] > 1:
                # Per-sample gradients, averaged over the batch.
                gradients = jax.vmap(self._grad, in_axes=({"W": None}, 0, 0), out_axes=0)(params, xy[0], xy[1])
                gradients = {key: val.mean(axis=0) for key, val in gradients.items()}
            else:
                gradients = self._grad(params, xy[0], xy[1])
            params = self._optimizer.update(params, gradients)
            return params, value

        self._predict_and_update_jit = jax.jit(_predict_and_update)

    def predict(self, X):
        """Predict for every window in ``X`` with the current parameters.

        Args:
            X: Array whose first axis indexes windows.

        Returns:
            Predictions reshaped to a single column, ``(-1, 1)``.
        """
        return jax.vmap(self._predict_jit, in_axes=(None, 0))(self._params, X).reshape(-1, 1)

    def update(self, X, y):
        """Take one optimizer step using the batch-averaged per-sample gradient.

        Args:
            X: Array whose first axis indexes windows.
            y: Targets aligned with ``X`` along the first axis.
        """
        gradients = jax.vmap(self._grad, in_axes=({"W": None}, 0, 0), out_axes=0)(self._params, X, y)
        # Average whatever parameters exist. The dict only holds "W", so
        # indexing a separate "b" entry here would raise KeyError.
        gradients = {key: val.mean(axis=0) for key, val in gradients.items()}
        self._params = self._optimizer.update(self._params, gradients)

    def predict_and_update(self, X, y, batch_size=1):
        """Predict each batch with the current parameters, then update on it.

        The rows are processed in order in groups of ``batch_size``; a final
        shorter group is handled by a second scan.

        Args:
            X: Array whose first axis indexes windows.
            y: Targets aligned with ``X`` along the first axis.
            batch_size: Number of rows per update step.

        Returns:
            The predictions made before each update, reshaped to ``(-1, 1)``.
        """
        last_batch_size = X.shape[0] % batch_size
        trim_size = X.shape[0] - last_batch_size
        # Fold the rows that fill whole batches into (n_batches, batch_size, ...).
        trimmed_X = X[:trim_size].reshape((-1, batch_size) + X.shape[1:])
        trimmed_y = y[:trim_size].reshape((-1, batch_size) + y.shape[1:])
        self._params, value = jax.lax.scan(self._predict_and_update_jit, self._params, (trimmed_X, trimmed_y))
        value = value.reshape((value.shape[0] * value.shape[1],) + value.shape[2:])

        if last_batch_size > 0:
            # The leftover rows form one extra, smaller batch.
            last_X = X[trim_size:][np.newaxis, :]
            last_y = y[trim_size:][np.newaxis, :]
            self._params, trim_value = jax.lax.scan(self._predict_and_update_jit, self._params, (last_X, last_y))
            value = np.vstack((value, trim_value.squeeze(axis=0)))
        return value.reshape(-1, 1)

In [8]:
from ealstm.gaip import FloodLSTM
from ealstm.gaip import FloodData
from ealstm.gaip.utils import MSE, NSE

from timecast.optim import SGD
from timecast.learners import AR

# NOTE(review): the `FloodData` imported from ealstm.gaip in this cell shadows
# the class of the same name defined earlier in the notebook.
cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
flood_data = FloodData(cfg_path)
# SECURITY: unpickling executes arbitrary code; only load files you trust.
# The context manager closes the handle once the object has been read.
with open('../data/models/runs/run_2006_0032_seed444/lstm_seed444.p', "rb") as ea_file:
    ea_data = pickle.load(ea_file)

# Hyper-parameters of the per-basin AR model.
LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1

In [9]:
# %%timeit
# %%prun -q -D out.pstats
# For each basin, fit a fresh AR model online on the residual between the
# observations (qobs) and the stored simulation (qsim), then score the sum.
results = {}
mses = []
nses = []
for X, _, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    sgd = SGD(learning_rate=LR_AR, online=False)
    ar = ARStateless(input_dim=AR_INPUT_DIM, output_dim=AR_OUTPUT_DIM, window_size=flood_data.cfg["seq_length"], optimizer=sgd)
    X = batch_window(X, flood_data.cfg["seq_length"])
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_lstm = np.array(ea_data[basin].qsim).reshape(-1, 1)
    Y_target = Y - Y_lstm

    Y_ar = ar.predict_and_update(X, Y_target, batch_size=1)

    Y_hat = Y_lstm + Y_ar

    squared_error = (Y - Y_hat) ** 2
    mse = squared_error.mean()
    nse = 1 - squared_error.sum() / ((Y - Y.mean()) ** 2).sum()

    # Record this basin before averaging so the stored running averages
    # include it and match the values printed below.
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse
    }
    print(basin, mse, nse, avg_mse, avg_nse)


HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.7921798 0.8391017 0.7921798 0.8391017
01031500 1.034634 0.88986236 0.9134069 0.86448205
01047000 1.2811565 0.87995404 1.0359901 0.8696394
01052500 1.9988087 0.847376 1.2766948 0.8640735
01054200 7.819855 0.73467207 2.5853267 0.83819324
01055000 4.73436 0.73252463 2.9434988 0.82058173
01057000 0.9145993 0.87551165 2.653656 0.82842886
01073000 0.9153605 0.86096424 2.4363692 0.8324958
01078000 0.67313606 0.88788176 2.2404544 0.8386498
01123000 0.8993281 0.8173188 2.1063418 0.8365167
01134500 1.715929 0.79124093 2.0708497 0.8324007
01137500 2.376047 0.766341 2.0962827 0.8268957
01139000 0.80212957 0.7747908 1.9967326 0.82288766
01139800 0.84583026 0.76805705 1.9145253 0.81897116
01142500 1.1238302 0.77238315 1.8618124 0.8158653
01144000 0.729486 0.85160327 1.791042 0.8180989
01162500 0.7964086 0.84338397 1.732534 0.8195863
01169000 2.2953029 0.7448952 1.7637991 0.8154367
01170100 1.8372197 0.7727082 1.7676632 0.81318784
01181000 1.9368116 0.78878325 1.7761208 0.8119677
01187300 

02479155 6.0705748 0.6767037 1.9745607 0.7358027
02479300 1.6803375 0.8036434 1.9728094 0.73620653
02479560 1.3658031 0.84836113 1.9692175 0.7368701
02481000 6.3722224 0.74335706 1.9951175 0.7369083
02481510 2.4130614 0.80391026 1.9975618 0.73730016
04015330 2.252899 0.66174245 1.9990463 0.7368609
04024430 0.4891839 0.7830616 1.9903187 0.73712784
04027000 0.27036124 0.887639 1.9804341 0.7379929
04040500 0.33802435 0.8444825 1.971049 0.73860145
04043050 0.812655 0.8166368 1.9644669 0.7390448
04045500 0.02412607 0.97218746 1.9535044 0.740362
04057510 0.038344175 0.9494198 1.9427452 0.7415365
04057800 0.47224835 0.7852626 1.9345303 0.7417807
04059500 0.11727201 0.8931821 1.9244344 0.74262184
04063700 0.05912218 0.9114649 1.9141285 0.7435547
04074950 0.03584394 0.86431634 1.9038085 0.7442183
04105700 0.024202451 0.854716 1.8935375 0.744822


KeyboardInterrupt: 