In [708]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
from numba import njit

from ealstm.main import get_basin_list, load_attributes, Model, GLOBAL_SETTINGS, evaluate

import tqdm.notebook as tqdm

%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [676]:
def load_cfg(cfg_path):
    cfg = json.load(open(cfg_path, "r"))
    cfg["camels_root"] = Path(cfg["camels_root"])
    cfg["run_dir"] = Path(cfg["run_dir"])
    cfg.update(GLOBAL_SETTINGS)
    return cfg

In [716]:
from torch.utils.data import TensorDataset

from timecast.learners import BaseLearner

from ealstm.main import DEVICE
from ealstm.main import evaluate_basin
from ealstm.main import Model
from ealstm.papercode.datautils import reshape_data
        
class FloodLSTM(BaseLearner):
    def __init__(self, cfg_path, input_dim=5, output_dim=1):
        self._input_dim = input_dim,
        self._output_dim = output_dim
        
        self.cfg = load_cfg(cfg_path)
        self.model = Model(input_size_dyn=(5 if (self.cfg["no_static"] or not self.cfg["concat_static"]) else 32),
                           input_size_stat=(0 if self.cfg["no_static"] else 27),
                           hidden_size=self.cfg["hidden_size"],
                           dropout=self.cfg["dropout"],
                           concat_static=self.cfg["concat_static"],
                           no_static=self.cfg["no_static"]).to(DEVICE)
        
        weight_file = os.path.join(self.cfg["run_dir"], "model_epoch30.pt")
        self.model.load_state_dict(torch.load(weight_file, map_location=DEVICE))
        
    def predict(self, X):
        """Assumes we get one basin's data at a time
        """
        y = np.ones((X.shape[0], 1))
        X, y = reshape_data(X, y, self.cfg["seq_length"])
        
        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(y.astype(np.float32))
        
        loader = DataLoader(TensorDataset(X, y), batch_size=1024, shuffle=False)
        preds, obs = evaluate_basin(self.model, loader)
        return preds
        
    def update(self, X, y, **kwargs):
        pass

In [710]:
import json
from pathlib import Path
from torch.utils.data import DataLoader

from ealstm.main import GLOBAL_SETTINGS
from ealstm.main import get_basin_list
from ealstm.main import load_attributes
from ealstm.papercode.datasets import CamelsTXT

class FloodData():
    def __init__(self, cfg_path):
        self.cfg =load_cfg(cfg_path)
        self.basins = get_basin_list()
        self.db_path = os.path.join(self.cfg["run_dir"], "attributes.db")
        self.attributes = load_attributes(db_path=self.db_path,
                                          basins=self.basins,
                                          drop_lat_lon=True)
        
    def generator(self, is_train=False, with_attributes=True):
        for basin in self.basins:
            ds_test = CamelsTXT(camels_root=self.cfg["camels_root"],
                                basin=basin,
                                dates=[GLOBAL_SETTINGS["val_start"], GLOBAL_SETTINGS["val_end"]],
                                is_train=is_train,
                                seq_length=self.cfg["seq_length"],
                                with_attributes=True,
                                attribute_means=self.attributes.mean(),
                                attribute_stds=self.attributes.std(),
                                concat_static=self.cfg["concat_static"],
                                db_path=self.db_path,
                                reshape=False,
                                torchify=False
                               )
            X = np.concatenate((ds_test.x, np.tile(np.array(ds_test.attributes), (ds_test.x.shape[0], 1))), axis=1)
            yield X, ds_test.y, basin

In [733]:
ea = pickle.load(open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb"))

In [718]:
cfg_path = "/home/dsuo/src/toy_flood/ealstm/runs/run_2503_0429_seed283956/cfg.json"
run_cfg = load_cfg(cfg_path)

In [719]:
data = FloodData(cfg_path)
learner = FloodLSTM(cfg_path)

In [721]:
results = {}
date_range = pd.date_range(start=GLOBAL_SETTINGS["val_start"], end=GLOBAL_SETTINGS["val_end"])
for X, y, basin in tqdm.tqdm(data.generator(), total=len(data.basins)):
    pred = learner.predict(X)
    true = y[run_cfg["seq_length"] - 1:]
    df = pd.DataFrame(data={"qobs": true.ravel(), "qsim": pred.ravel()}, index=date_range)
    results[basin] = df

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))




In [736]:
for basin in data.basins:
    print(basin)
    if basin not in results:
        print("Basin {} not in results".format(basin))
    if basin not in ea:
        print("Basin {} not in ea".format(basin))
        continue
    np.testing.assert_array_almost_equal(results[basin], ea[basin], decimal=4)

01022500
01031500
01047000
01052500
01054200
01055000
01057000
01073000
01078000
01123000
01134500
01137500
01139000
01139800
01142500
01144000
01162500
01169000
01170100
01181000
01187300
01195100
04296000
01333000
01350000
01350080
01350140
01365000
01411300
01413500
01414500
01415000
01423000
01434025
01435000
01439500
01440000
01440400
01451800
01466500
01484100
01487000
01491000
01510000
01516500
01518862
01532000
01539000
01542810
01543000
01543500
01544500
01545600
01547700
01548500
01549500
01550000
01552000
01552500
01557500
01567500
01568000
01580000
01583500
01586610
01591400
01594950
01596500
01605500
01606500
01632000
01632900
01634500
01638480
01639500
01644000
01664000
01666500
01667500
01669000
01669520
02011400
02013000
02014000
02015700
02016000
02017500
02018000
02027000
02027500
02028500
02038850
02046000
02051500
02053200
02053800
02055100
02056900
02059500
02064000
02065500
02069700
02070000
02074500
02077200
02081500
02082950
02092500
02096846
02102908
02108000
0