In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
def load_cfg(cfg_path):
    cfg = json.load(open(cfg_path, "r"))
    cfg["camels_root"] = Path(cfg["camels_root"])
    cfg["run_dir"] = Path(cfg["run_dir"])
    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [3]:
import json
from pathlib import Path

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    def __init__(self, cfg_path):
        self.cfg =load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)
        
    def generator(self, is_train=False, with_attributes=True):
        for basin in self.basins:
            ds_test = CamelsTXT(camels_root=self.cfg["camels_root"],
                                basin=basin,
                                dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                                is_train=is_train,
                                seq_length=self.cfg["seq_length"],
                                with_attributes=True,
                                attribute_means=self.attributes.mean(),
                                attribute_stds=self.attributes.std(),
                                concat_static=self.cfg["concat_static"],
                                db_path=self.db_path,
                                reshape=False,
                                torchify=False
                               )
            X = np.concatenate((ds_test.x, np.tile(np.array(ds_test.attributes), (ds_test.x.shape[0], 1))), axis=1)
            yield X, ds_test.y, basin

In [4]:
from timecast.optim import SGD
from timecast.learners import BaseLearner

LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1

In [5]:
import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    def __init__(self):
        pass

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        return np.mean(np.mean((y_pred - y_true) ** 2, axis=tuple(range(1, y_true.ndim))))

In [6]:
cfg_path = "/home/dsuo/src/toy_flood/ealstm/runs/run_2503_0429_seed283956/cfg.json"
data = FloodData(cfg_path)
lstm_pred = pickle.load(open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb"))

In [19]:
import jax.numpy as np
import jax

def batch_window(X, window_size, offset=0):
    num_windows = X.shape[0] - window_size + 1
    return np.swapaxes(np.stack([np.roll(X, shift=-(i + offset), axis=0) for i in range(window_size)]), 0, 1)[:num_windows]

class ARStateless(BaseLearner):
    def __init__(self, input_dim: int, output_dim: int, window_size: int, optimizer=None, loss=None):
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._window_size = window_size
        self._optimizer = optimizer or SGD()
        self._loss = loss or BatchedMeanSquareError()
        
        W = np.zeros((window_size, input_dim, output_dim))
        b = np.zeros((output_dim, 1))
        self._params = {"W": W, "b": b}
        
        def _predict(params, x):
            print(params["W"].shape, x.shape)
            return np.tensordot(params["W"], x, ([0, 1], [0, 1])) + params["b"]
        
        self._predict_jit = jax.jit(lambda params, X: _predict(params, X))
        self._grad = jax.jit(jax.grad(lambda params, X, y: self._loss.compute(_predict(params, X), y)))
        self._value_and_grad = jax.jit(jax.value_and_grad(lambda params, X, y: self._loss.compute(_predict(params, X), y)))
        
        def _predict_and_update(params, xy):
            value = self._predict_jit(params, xy[0])
            gradients = self._grad(params, xy[0], xy[1])
            params = self._optimizer.update(params, gradients)
            return params, value
        
        self._predict_and_update_jit = jax.jit(lambda params, xy: _predict_and_update(params, xy))

    def predict(self, X):
        return jax.vmap(self._predict_jit, in_axes=(None, 0))(self._params, X).reshape(-1, 1)

    def update(self, X, y):
        gradients = jax.vmap(self._grad, in_axes=({"W": None, "b": None}, 0, 0), out_axes=0)(self._params, X, y)
        gradients["W"] = gradients["W"].mean(axis=0)
        gradients["b"] = gradients["b"].mean(axis=0)
        self._params = self._optimizer.update(self._params, gradients)
        
    def predict_and_update(self, X, y):
        # TODO: think about batching
        self._params, value = jax.lax.scan(self._predict_and_update_jit, self._params, (X, y))
        return value.reshape(-1, 1)

In [20]:
# %%timeit
# %%prun -q -D out.pstats
results = {}
mses = []
nses = []
for X, _, basin in tqdm.tqdm(data.generator(), total=len(data.basins)):
    sgd = SGD(learning_rate=LR_AR, online=False)
    ar = ARStateless(input_dim=AR_INPUT_DIM, output_dim=AR_OUTPUT_DIM, window_size=data.cfg["seq_length"], optimizer=sgd)
    X = batch_window(X, data.cfg["seq_length"])
    Y = np.array(lstm_pred[basin].qobs).reshape(-1, 1)
    
    Y_lstm = np.array(lstm_pred[basin].qsim).reshape(-1, 1)
    Y_target = Y - Y_lstm
    
    Y_ar = ar.predict_and_update(X, Y_target)
    
    Y_hat = Y_lstm + Y_ar
    
    mse = ((Y - Y_hat) ** 2).mean()
    nse = 1 - ((Y - Y_hat) ** 2).sum() / ((Y - Y.mean()) ** 2).sum()
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": np.mean(np.array(mses)),
        "avg_nse": np.mean(np.array(nses))
    }
    mses.append(mse)
    nses.append(nse)
    print(Y.sum(), Y_lstm.sum(), Y_ar.sum(), Y_hat.sum())
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))


HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
7344.714 7240.897 97.372696 7338.269
01022500 0.59074104 0.8800156 0.59074104 0.8800156
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
7141.21 7755.338 -610.79047 7144.5474
01031500 1.002769 0.8932544 0.796755 0.886635
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
7770.477 7483.0864 287.93332 7771.0195
01047000 1.4249041 0.8664847 1.006138 0.8799183
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
8346.02 8322.72 22.376762 8345.097
01052500 1.7676423 0.8650273 1.1965141 0.87619555
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
9908.52 10008.48 -117.526215 9890.953
01054200 7.3498154 0.7506206 2.4271743 0.85108054
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
8030.0264 7952.967 65.19052 8018.157
01055000 4.5261536 0.7442876 2.7770042 0.8332817
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
6163.7393 6422.6855 -259.26996 6163.4155
01057000 0.9297008 0.8734561 2.5131037 0.83902085
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
5748.851 5532.0293 215.

5444.896 5458.4395 -27.189135 5431.2505
01568000 1.7282537 0.74545836 1.7975678 0.7781133
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
4992.863 4331.603 663.65076 4995.2534
01580000 0.8801472 0.6436259 1.7830055 0.77597857
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
4360.011 3596.8096 760.23865 4357.048
01583500 0.7091803 0.6047565 1.766227 0.7733033
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
4660.7764 4187.607 478.14926 4665.7554
01586610 1.1733024 0.6358969 1.7571052 0.77118933
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
4061.6155 3968.764 95.43463 4064.1982
01591400 1.4696093 0.62457514 1.7527492 0.7689679
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
6689.5376 7112.8335 -421.19565 6691.6387
01594950 2.060491 0.75200677 1.7573425 0.7687148
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
5637.326 5487.597 151.68364 5639.2793
01596500 2.1027353 0.69710195 1.7624217 0.76766163
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
3794.5266 4154.092 -354.17004 3799.9219
01605500 1.1939986

6612.3774 5396.2754 1223.209 6619.485
02149000 1.200001 0.64963233 1.86163 0.7280423
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
5433.0596 5699.3013 -262.80432 5436.498
02152100 1.3299211 0.6501297 1.857342 0.727414
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
11947.078 11188.832 749.9151 11938.747
02177000 0.8619463 0.86133087 1.849379 0.72848535
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
12337.602 11751.957 582.3672 12334.327
02178400 1.4640473 0.82753104 1.8463206 0.7292714
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
3308.6133 3179.244 131.4032 3310.6467
02193340 1.8416048 0.73942924 1.8462836 0.7293515
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
2882.0447 2418.2014 465.42456 2883.626
02196000 1.1130134 0.7577707 1.8405547 0.7295735
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
3929.4368 2828.885 1104.6884 3933.5735
02198100 3.9394422 0.5943083 1.8568254 0.72852486
(270, 32, 1) (270, 32)
(270, 32, 1) (270, 32)
3564.7717 3452.8706 111.92069 3564.7913
02202600 0.9459475 0.80

KeyboardInterrupt: 