In [6]:
# Enable automatic reloading of imported modules so edits to project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
# Inject CSS so the notebook container spans the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

# Default size applied to matplotlib figures created after this point.
plt.rcParams['figure.figsize'] = [20, 10]

# NOTE: this binds the name `tqdm` to the `tqdm.notebook` module, so progress
# bars are created with `tqdm.tqdm(...)`.
import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def load_cfg(cfg_path):
    """Load a run configuration from a JSON file.

    The JSON document at ``cfg_path`` is parsed into a dict, its
    ``"camels_root"`` and ``"run_dir"`` entries are converted to ``Path``
    objects, and ``GLOBAL_SETTINGS`` is merged on top (global settings win
    when a key exists in both).

    Note: ``Path`` and ``GLOBAL_SETTINGS`` are not imported in this cell; they
    are resolved from the notebook namespace when the function is called, so
    the cell importing them must have run first.

    Args:
        cfg_path: Location of the JSON configuration file.

    Returns:
        dict: The parsed configuration with path entries converted and global
        settings applied.

    Raises:
        KeyError: If the file lacks a ``"camels_root"`` or ``"run_dir"`` entry.
    """
    # Use a context manager so the file handle is closed deterministically;
    # passing a bare open() to json.load leaves it open until garbage collection.
    with open(cfg_path, "r") as cfg_file:
        cfg = json.load(cfg_file)

    for path_key in ("camels_root", "run_dir"):
        cfg[path_key] = Path(cfg[path_key])

    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [3]:
import json
from pathlib import Path

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    """Per-basin access to CAMELS data for a given run configuration.

    On construction the run configuration, the basin list and the basin
    attributes (without latitude/longitude) are loaded. ``generator`` then
    yields one ``(X, y, basin)`` triple per basin.
    """

    def __init__(self, cfg_path):
        """Load configuration, basin list and attributes.

        Args:
            cfg_path: Path of the run's JSON configuration file; its
                ``"run_dir"`` entry is where ``attributes.db`` is looked up.
        """
        self.cfg = load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)

    def generator(self, is_train=False, with_attributes=True):
        """Yield ``(X, y, basin)`` for every basin in ``self.basins``.

        The date range is always ``GLOBAL_SETTINGS["val_start"]`` to
        ``GLOBAL_SETTINGS["val_end"]``, independent of ``is_train``.

        Note: ``np`` is not imported in this cell; it is resolved from the
        notebook namespace at call time (a later cell binds it to
        ``jax.numpy``).

        Args:
            is_train: Forwarded to ``CamelsTXT``.
            with_attributes: Forwarded to ``CamelsTXT``. When true, the basin's
                attribute vector is repeated for every row of ``ds.x`` and
                appended as extra feature columns. When false, only the
                inputs in ``ds.x`` are yielded.
                NOTE(review): assumes ``CamelsTXT`` accepts
                ``with_attributes=False`` with these arguments -- confirm.

        Yields:
            tuple: ``(X, y, basin)`` where ``X`` is the feature matrix, ``y``
            is ``ds.y`` and ``basin`` is the basin identifier.
        """
        # These values do not depend on the basin, so compute them once instead
        # of once per loop iteration.
        attribute_means = self.attributes.mean()
        attribute_stds = self.attributes.std()
        dates = [GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]]

        for basin in self.basins:
            ds = CamelsTXT(camels_root=self.cfg["camels_root"],
                           basin=basin,
                           dates=dates,
                           is_train=is_train,
                           seq_length=self.cfg["seq_length"],
                           # Honour the caller's choice; this used to be
                           # hard-coded to True, silently ignoring the argument.
                           with_attributes=with_attributes,
                           attribute_means=attribute_means,
                           attribute_stds=attribute_stds,
                           concat_static=self.cfg["concat_static"],
                           db_path=self.db_path,
                           reshape=False,
                           torchify=False)

            X = np.asarray(ds.x)
            if with_attributes:
                # Repeat the static attribute vector once per row of X and
                # append it as additional feature columns.
                static = np.tile(np.array(ds.attributes), (X.shape[0], 1))
                X = np.concatenate((X, static), axis=1)

            yield X, ds.y, basin

In [4]:
from timecast.optim import SGD
from timecast.learners import BaseLearner

# Experiment hyperparameters.
# NOTE(review): in the cells reviewed here only LR_AR, AR_INPUT_DIM and
# AR_OUTPUT_DIM are referenced; the remaining constants are presumably used
# elsewhere or left over from earlier experiments -- confirm before removing.
REG = 0.0
TRAINING_STEPS = 1e6  # NOTE: a float literal, not an int
BATCH_SIZE = 1
SEQUENCE_LENGTH = 270
HIDDEN_DIM = 256
DP_RATE = 0.0
LR_AR = 1e-5  # learning rate handed to the SGD optimizer of the AR learner
AR_INPUT_DIM=32  # input_dim handed to ARStateless
AR_OUTPUT_DIM=1  # output_dim handed to ARStateless

In [5]:
import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    """Mean squared error averaged per sample first, then across the batch."""

    def __init__(self):
        """Nothing to configure; the loss keeps no state."""

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        """Return the batch mean of the per-sample mean squared errors.

        The squared error is averaged over every axis of ``y_true`` except the
        leading one, and the resulting per-sample values are averaged again.

        Args:
            y_pred: Predicted values.
            y_true: Target values; its rank decides which axes count as
                per-sample axes.

        Returns:
            Scalar loss value.
        """
        squared_error = (y_pred - y_true) ** 2
        # Every axis after the first belongs to a single sample.
        sample_axes = tuple(range(1, y_true.ndim))
        per_sample = np.mean(squared_error, axis=sample_axes)
        return np.mean(per_sample)

In [306]:
import jax.numpy as np
import jax

def batch_window(X, window_size, offset=0):
    """Stack sliding windows taken along the first axis of ``X``.

    Window ``j`` holds rows ``j + offset`` to ``j + offset + window_size - 1``
    of ``X``. Row indices wrap around modulo ``X.shape[0]`` (the windows are
    built with ``np.roll``), so a non-zero ``offset`` makes the trailing
    windows contain rows from the start of ``X``.

    Args:
        X: Array whose first axis is the one being windowed.
        window_size: Number of consecutive rows per window.
        offset: Extra shift applied to the start of every window.

    Returns:
        Array of shape ``(num_windows, window_size, *X.shape[1:])`` with
        ``num_windows = max(X.shape[0] - window_size + 1, 0)``.
    """
    # Clamp at zero: a negative count would turn the final slice into
    # `[:-k]`, which silently returns wrapped-around windows instead of an
    # empty batch when the input is shorter than one window.
    num_windows = max(X.shape[0] - window_size + 1, 0)

    # shifted[i][j] == X[(j + i + offset) % len(X)], i.e. the i-th row of window j.
    shifted = [np.roll(X, shift=-(step + offset), axis=0) for step in range(window_size)]

    # Stack to (window_size, rows, ...) and swap to (rows, window_size, ...).
    windows = np.swapaxes(np.stack(shifted), 0, 1)
    return windows[:num_windows]

class ARStateless(BaseLearner):
    """Linear autoregressive learner operating on pre-windowed inputs.

    A single example is a window of shape ``(window_size, input_dim)``. The
    prediction contracts the window with the weight tensor ``W`` of shape
    ``(window_size, input_dim, output_dim)`` and adds the bias ``b``, stored
    as an ``(output_dim, 1)`` column. All parameters start at zero.
    """

    def __init__(self, input_dim: int, output_dim: int, window_size: int, optimizer=None, loss=None):
        """Create zero-initialised parameters and the jitted kernels.

        Args:
            input_dim: Number of features per time step.
            output_dim: Number of predicted values per window.
            window_size: Number of time steps per window.
            optimizer: Object exposing ``update(params, gradients)``; defaults
                to ``SGD()``.
            loss: Object exposing ``compute(y_pred, y_true)``; defaults to
                ``BatchedMeanSquareError()``.
        """
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._window_size = window_size
        self._optimizer = optimizer or SGD()
        self._loss = loss or BatchedMeanSquareError()

        W = np.zeros((window_size, input_dim, output_dim))
        b = np.zeros((output_dim, 1))
        self._params = {"W": W, "b": b}

        def _predict(params, X):
            # Contracting the window and feature axes leaves shape (output_dim,).
            # The bias is stored as an (output_dim, 1) column; add it as a flat
            # vector, because adding the column directly broadcasts the result
            # to (output_dim, output_dim) whenever output_dim > 1.
            return np.tensordot(params["W"], X, ([0, 1], [0, 1])) + params["b"][:, 0]

        def _loss_fn(params, X, y):
            return self._loss.compute(_predict(params, X), y)

        self._predict_jit = jax.jit(_predict)
        self._grad = jax.jit(jax.grad(_loss_fn))
        self._value_and_grad = jax.jit(jax.value_and_grad(_loss_fn))

        def _predict_and_update(params, xy):
            # One online step: predict with the current parameters, then take
            # a gradient step on the same example.
            x, y = xy
            value = self._predict_jit(params, x)
            gradients = self._grad(params, x, y)
            params = self._optimizer.update(params, gradients)
            return params, value

        self._predict_and_update_jit = jax.jit(_predict_and_update)

    def predict(self, X):
        """Predict for a batch of windows without changing the parameters.

        Args:
            X: Batch of windows; the leading axis indexes examples.

        Returns:
            Predictions of shape ``(num_examples, output_dim)``.
        """
        batched = jax.vmap(self._predict_jit, in_axes=(None, 0))(self._params, X)
        return batched.reshape(-1, self._output_dim)

    def update(self, X, y):
        """Apply one optimizer step using gradients averaged over the batch.

        Args:
            X: Batch of windows; the leading axis indexes examples.
            y: Targets aligned with ``X`` along the leading axis.
        """
        gradients = jax.vmap(self._grad, in_axes=({"W": None, "b": None}, 0, 0), out_axes=0)(self._params, X, y)
        # vmap returns one gradient per example; average them into one step.
        gradients["W"] = gradients["W"].mean(axis=0)
        gradients["b"] = gradients["b"].mean(axis=0)
        self._params = self._optimizer.update(self._params, gradients)

    def predict_and_update(self, X, y):
        """Process examples sequentially: predict, then update, one at a time.

        Each prediction is made with the parameters as they were before the
        update on that same example. The final parameters are stored on the
        learner.

        Args:
            X: Sequence of windows; the leading axis indexes examples.
            y: Targets aligned with ``X`` along the leading axis.

        Returns:
            Predictions of shape ``(num_examples, output_dim)``.
        """
        # TODO: think about batching
        self._params, value = jax.lax.scan(self._predict_and_update_jit, self._params, (X, y))
        return value.reshape(-1, self._output_dim)

In [307]:
# NOTE(review): this is a machine-specific absolute path, while the pickle
# below is located relative to the notebook's working directory.
cfg_path = "/home/dsuo/src/toy_flood/ealstm/runs/run_2503_0429_seed283956/cfg.json"
data = FloodData(cfg_path)

# Saved LSTM results, indexed by basin; each entry exposes `qobs` and `qsim`.
# The file is opened in a `with` block so the handle is closed after loading.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb") as pred_file:
    lstm_pred = pickle.load(pred_file)

In [308]:
# %%timeit
# %%prun -q -D out.pstats
# For every basin, fit a fresh AR learner online to the residual between the
# observed discharge and the LSTM simulation, then score LSTM + AR against the
# observations.
results = {}
mses = []
nses = []
for X, _, basin in tqdm.tqdm(data.generator(), total=len(data.basins)):
    sgd = SGD(learning_rate=LR_AR, online=False)
    ar = ARStateless(input_dim=AR_INPUT_DIM, output_dim=AR_OUTPUT_DIM, window_size=data.cfg["seq_length"], optimizer=sgd)
    X_windows = batch_window(X, data.cfg["seq_length"])
    Y = np.array(lstm_pred[basin].qobs).reshape(-1, 1)

    Y_lstm = np.array(lstm_pred[basin].qsim).reshape(-1, 1)
    # The AR model learns the LSTM's residual rather than the raw target.
    Y_target = Y - Y_lstm

    Y_ar = ar.predict_and_update(X_windows, Y_target)

    Y_hat = Y_lstm + Y_ar

    mse = ((Y - Y_hat) ** 2).mean()
    nse = 1 - ((Y - Y_hat) ** 2).sum() / ((Y - Y.mean()) ** 2).sum()

    # Record this basin before computing the running averages. Previously the
    # averages stored in `results` were taken before appending, so the first
    # entry was the mean of an empty array (NaN) and every later entry lagged
    # one basin behind the printed values.
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X_windows.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse
    }
    print(Y.sum(), Y_lstm.sum(), Y_ar.sum(), Y_hat.sum())
    print(basin, mse, nse, avg_mse, avg_nse)


HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

(270, 32)
(270, 32)
7344.714 7240.897 97.372696 7338.269
01022500 0.59074104 0.8800156 0.59074104 0.8800156
(270, 32)
(270, 32)
7141.21 7755.338 -610.79047 7144.548
01031500 1.0027689 0.8932544 0.79675496 0.886635
(270, 32)
(270, 32)


KeyboardInterrupt: 