In [29]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch

from ealstm.main import get_basin_list, load_attributes, Model, GLOBAL_SETTINGS, evaluate

import tqdm.notebook as tqdm

%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Wrap EALSTM code

In [10]:
def load_cfg(cfg_path):
    cfg = json.load(open(cfg_path, "r"))
    cfg["camels_root"] = Path(cfg["camels_root"])
    cfg["run_dir"] = Path(cfg["run_dir"])
    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [4]:
from torch.utils.data import TensorDataset

from timecast.learners import BaseLearner

from ealstm.main import DEVICE
from ealstm.main import evaluate_basin
from ealstm.main import Model
from ealstm.papercode.datautils import reshape_data
        
class FloodLSTM(BaseLearner):
    def __init__(self, cfg_path, input_dim=5, output_dim=1):
        self._input_dim = input_dim,
        self._output_dim = output_dim
        
        self.cfg = load_cfg(cfg_path)
        self.model = Model(input_size_dyn=(5 if (self.cfg["no_static"] or not self.cfg["concat_static"]) else 32),
                           input_size_stat=(0 if self.cfg["no_static"] else 27),
                           hidden_size=self.cfg["hidden_size"],
                           dropout=self.cfg["dropout"],
                           concat_static=self.cfg["concat_static"],
                           no_static=self.cfg["no_static"]).to(DEVICE)
        
        weight_file = os.path.join(self.cfg["run_dir"], "model_epoch30.pt")
        self.model.load_state_dict(torch.load(weight_file, map_location=DEVICE))
        
    def predict(self, X):
        """Assumes we get one basin's data at a time
        """
        y = np.ones((X.shape[0], 1))
        X, y = reshape_data(X, y, self.cfg["seq_length"])
        
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        
        loader = DataLoader(TensorDataset(X, y), batch_size=1024, shuffle=False)
        preds, obs = evaluate_basin(self.model, loader)
        return preds
        
    def update(self, X, y, **kwargs):
        pass

In [5]:
import json
from pathlib import Path
from torch.utils.data import DataLoader

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    def __init__(self, cfg_path):
        self.cfg =load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)
        
    def generator(self, is_train=False, with_attributes=True):
        for basin in self.basins:
            ds_test = CamelsTXT(camels_root=self.cfg["camels_root"],
                                basin=basin,
                                dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                                is_train=is_train,
                                seq_length=self.cfg["seq_length"],
                                with_attributes=True,
                                attribute_means=self.attributes.mean(),
                                attribute_stds=self.attributes.std(),
                                concat_static=self.cfg["concat_static"],
                                db_path=self.db_path,
                                reshape=False,
                                torchify=False
                               )
            X = np.concatenate((ds_test.x, np.tile(np.array(ds_test.attributes), (ds_test.x.shape[0], 1))), axis=1)
            yield X, ds_test.y, basin

In [8]:
ea = pickle.load(open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb"))

In [746]:
cfg_path = "/home/dsuo/src/toy_flood/ealstm/runs/run_2503_0429_seed283956/cfg.json"
run_cfg = load_cfg(cfg_path)

In [747]:
data = FloodData(cfg_path)
flood_lstm = FloodLSTM(cfg_path)

In [721]:
results = {}
date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"], end=GLOBAL_SETTINGS["val_end"])
for X, y, basin in tqdm.tqdm(data.generator(), total=len(data.basins)):
    pred = flood_lstm.predict(X)
    true = y[run_cfg["seq_length"] - 1:]
    df = pd.DataFrame(data={"qobs": true.ravel(), "qsim": pred.ravel()}, index=date_range)
    results[basin] = df

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))




In [737]:
for basin in tqdm.tqdm(data.basins):
    np.testing.assert_array_almost_equal(results[basin], ea[basin], decimal=4)

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))




# Replicating `flood_prediction`

In [271]:
from timecast.optim import SGD
from timecast.learners import BaseLearner

REG = 0.0
TRAINING_STEPS = 1e6
BATCH_SIZE = 1
SEQUENCE_LENGTH = 270
HIDDEN_DIM = 256
DP_RATE = 0.0
LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1

In [80]:
# import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    def __init__(self):
        pass

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        return np.mean(np.mean((y_pred - y_true) ** 2, axis=tuple(range(1, y_true.ndim))))

In [260]:
import jax.numpy as np
import jax

def batch_window(X, window_size, offset=0):
    num_windows = X.shape[0] - window_size + 1
    return np.swapaxes(np.stack([np.roll(X, shift=-(i + offset), axis=0) for i in range(window_size)]), 0, 1)[:num_windows]

class ARStateless(BaseLearner):
    def __init__(self, input_dim: int, output_dim: int, window_size: int, optimizer=None, loss=None):
        self._input_dim = input_dim
        self._output_dim = output_dim
        self._window_size = window_size
        self._optimizer = optimizer or SGD()
        self._loss = loss or BatchedMeanSquareError()
        
        W = np.zeros((window_size, input_dim, output_dim))
        b = np.zeros((output_dim, 1))
        self._params = {"W": W, "b": b}
        
        def _predict(params, x):
            print(params["W"].shape, params["b"].shape, x.shape)
            return np.tensordot(params["W"], x, ([0, 1], [0, 1])) + params["b"]
        
        self._predict_jit = jax.jit(lambda params, X: _predict(params, X))
        
        self._grad = jax.jit(jax.grad(lambda params, X, y: self._loss.compute(self._predict_jit(params, X), y)))

    def predict(self, X):
        return jax.vmap(self._predict_jit, in_axes=(None, 0))(self._params, X).reshape(-1, 1)

    def update(self, X, y):
        gradients = jax.vmap(self._grad, in_axes=({"W": None, "b": None}, 0, 0), out_axes=0)(self._params, X, y)
        gradients["W"] = gradients["W"].mean(axis=0)
        gradients["b"] = gradients["b"].mean(axis=0)
        self._params = self._optimizer.update(self._params, gradients)

In [251]:
data = FloodData(cfg_path)
lstm_pred = pickle.load(open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb"))

In [263]:
X, y, basin = next(data.generator())

In [265]:
ar._predict_jit(ar._params, batch_window(X, data.cfg["seq_length"])[0])

(270, 32, 1) (1, 1) (270, 32)


DeviceArray([[0.]], dtype=float32)

In [266]:
lstm_pred[basin].shape

(3652, 2)

In [None]:
results = {}
mses = []
nses = []
for X, y, basin in tqdm.tqdm(data.generator(), total=len(data.basins)):
    # We don't need a new SGD each basin; there is no state
    sgd = SGD(learning_rate=LR_AR, online=False)
    
    # Q: why are we starting with new AR each basin?
    ar = ARStateless(input_dim=AR_INPUT_DIM, output_dim=AR_OUTPUT_DIM, window_size=data.cfg["seq_length"], optimizer=sgd)
    
    # Batch data
    X = batch_window(X, data.cfg["seq_length"])
    Y = np.array(lstm_pred[basin].qobs).reshape(-1, 1)
    
    Y_hat = np.array(lstm_pred[basin].qsim).reshape(-1, 1)
    Y_target = Y - Y_hat

    for i in range(0, X.shape[0], BATCH_SIZE):
        x = X[i : i + BATCH_SIZE, :, :]
        y = Y_target[i : i + BATCH_SIZE, :]

        # TODO: we essentially run predict twice
        y_ar = ar.predict(x)
        ar.update(x, y)
        
        Y_hat = jax.ops.index_add(Y_hat, jax.ops.index[i : i + BATCH_SIZE, :], y_ar)
    
    mse = ((Y - Y_hat) ** 2).mean()
    nse = 1 - ((Y - Y_hat) ** 2).sum() / ((Y - Y.mean()) ** 2).sum()
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0]
    }
    mses.append(mse)
    nses.append(nse)
    print(Y.sum(), Y_hat.sum())
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
7344.714 7338.269
01022500 0.59074104 0.8800156 0.59074104 0.8800156
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
7141.21 7144.5474
01031500 1.0027689 0.8932544 0.79675496 0.886635
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
7770.477 7771.0195
01047000 1.424904 0.86648476 1.006138 0.8799183
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
8346.02 8345.097
01052500 1.7676423 0.8650273 1.1965141 0.87619555
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
9908.52 9890.953
01054200 7.349815 0.7506206 2.4271743 0.85108054
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
8030.0264 8018.157
01055000 4.5261536 0.7442876 2.7770042 0.8332817
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
6163.7393 6163.4155
01057000 0.9297006 0.8734561 2.5131037 0.83902085
(270, 32, 1) (1, 1) (270, 32)
(270, 32, 1) (1, 1) (270, 32)
5748.851 5747.0854
01073000 0.8854102 0.865513

# Questions
- Why `batched_mse` for SGD loss? Is it because of windows?
- Are we sure truncating end for batches the right thing to do?
- New AR for each basin?