In [1]:
import os
import pickle

import jax.numpy as np
import tqdm

from timecast.utils.ar import historify

from tigerforecast.batch.camels_dataloader import get_basin_list
from tigerforecast.batch.camels_dataloader import CamelsTXT
from tigerforecast.methods.ARStateless_scan import ARStateless_scan
from tigerforecast.methods.Gradient_boosting import Gradient_boosting
from tigerforecast.utils.download_tools import get_tigerforecast_dir
from tigerforecast.utils.optimizers import SGD
from tigerforecast.utils.optimizers.losses import batched_mse, mse



In [None]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

In [2]:
# List of basin ids to evaluate, read from the flood-data metadata pickle.
# `with` closes the file deterministically (the bare `open(...)` leaked it).
# pickle can execute arbitrary code on load — only use with trusted local files.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [7]:
# Per-basin LSTM predictions cached under the tigerforecast data directory.
# The boosted AR ensemble later in the notebook is fit to the residual on top
# of these predictions. `with` closes the file instead of leaking the handle.
# pickle can execute arbitrary code on load — only use with trusted local files.
yhats_lstm_path = os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM")
with open(yhats_lstm_path, "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)

In [None]:
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [None]:
# def scan_with_W(yhats_LSTM, method_ar, data, targets, dynamic=False):
#     yhats_boosted, ys = None, None
# #     print(yhats_LSTM.shape)
# #     for data, targets in usgs_val.sequential_batches(batch_size=5000):
# #         print(data.shape, targets.shape)
# #         print(data[0], targets[0])
#     data = np.array(data)
#     ys = np.array(targets)
#     y_true = targets - yhats_LSTM
#     y_pred_ar, W = method_ar.predict_and_update(data, y_true)
#     yhats_boosted = yhats_LSTM + y_pred_ar.squeeze()
#     return yhats_boosted, ys, W

In [9]:
# Experiment configuration: identical for every basin, so it is set once
# instead of on every loop iteration.
SEQUENCE_LENGTH = 270  # `l` passed to each ARStateless_scan learner; X_shape[0]
INPUT_DIM = 32         # `n` passed to each ARStateless_scan learner; X_shape[1]

b_threshold = 1e-4  # projection threshold applied to the learners' "b" parameter
eta = 0.008         # step size passed to Gradient_boosting

# (W_lnm projection threshold, SGD learning rate) — one weak learner per pair.
W_lr_best_pairs = [
    (0.03, 2e-5),
    (0.05, 2e-5),
    (0.07, 2e-5),
    (0.09, 2e-5),
]

for basin in tqdm.tqdm(basins):
    # Fresh learners for every basin (presumably stateful via
    # predict_and_update), so nothing carries over between basins.
    method_ids = []
    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        optim_ar = SGD(loss=batched_mse, learning_rate=lr, clip_grad=False)
        method_ar = ARStateless_scan()
        method_ar.initialize(
            n=INPUT_DIM,
            m=1,
            l=SEQUENCE_LENGTH,
            optimizer=optim_ar,
            project_threshold=project_threshold,
            scan_mode=True,
        )
        method_ids.append(method_ar)

    method_boosting = Gradient_boosting()
    method_boosting.initialize(
        method_ids,
        X_shape=(SEQUENCE_LENGTH, INPUT_DIM),  # was the magic literal (270, 32)
        Y_shape=(),
        loss=batched_mse,
        eta=eta,
        proxy_loss="original",
        W_update_rule="GECO",
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    # Cached inputs / observed targets for this basin. `with` closes each file
    # (two handles per basin were previously leaked). Trusted local pickles only.
    with open("../data/flood/test/{}.pkl".format(basin), "rb") as x_file:
        X = pickle.load(x_file)
    with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
        Y = pickle.load(y_file)

    # The ensemble is fit to the LSTM residual; its prediction is then added
    # back onto the LSTM forecast to form the boosted forecast.
    y_true = Y - yhats_LSTM
    y_pred_ar, W = method_boosting.predict_and_update(X, y_true)
    yhats = yhats_LSTM + y_pred_ar.squeeze()

    loss = ((Y - yhats) ** 2).mean()
    print(basin, loss)


  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 1/531 [00:02<20:02,  2.27s/it]

0.48434398
01031500 

  0%|          | 2/531 [00:04<20:21,  2.31s/it]

0.9334514
01047000 

  1%|          | 3/531 [00:06<20:09,  2.29s/it]

1.2708772
01052500 

  1%|          | 4/531 [00:09<20:05,  2.29s/it]

1.5983096
01054200 

  1%|          | 4/531 [00:11<25:17,  2.88s/it]


KeyboardInterrupt: 

In [None]:
# Per-basin metrics (MSE, NSE, entropy of the ensemble weights).
# This cell is self-contained so it survives Restart & Run All:
# - `eta` is defined here. It was previously read from state leaked by
#   another cell.
# - It no longer calls `scan_with_W`, which only existed in commented-out code
#   (and with a different signature), so the cell raised NameError on a fresh
#   kernel.
# - It reads the same cached per-basin pickles used elsewhere in this notebook.
SEQUENCE_LENGTH = 270  # `l` passed to each ARStateless_scan learner; X_shape[0]
INPUT_DIM = 32         # `n` passed to each ARStateless_scan learner; X_shape[1]

b_threshold = 1e-4  # projection threshold applied to the learners' "b" parameter
eta = 0.008         # step size passed to Gradient_boosting

# (W_lnm projection threshold, SGD learning rate) — one weak learner per pair.
W_lr_best_pairs = [
    (0.03, 2e-5),
    (0.05, 2e-5),
    (0.07, 2e-5),
    (0.09, 2e-5),
]

# basin id -> {"loss", "nse", "W_entropy"}. Previously every value was
# overwritten on the next iteration and lost.
basin_metrics = {}

for basin in tqdm.tqdm(basins):
    # Fresh learners for every basin so nothing carries over between basins.
    method_ids = []
    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        optim_ar = SGD(loss=batched_mse, learning_rate=lr, clip_grad=False)
        method_ar = ARStateless_scan()
        method_ar.initialize(
            n=INPUT_DIM,
            m=1,
            l=SEQUENCE_LENGTH,
            optimizer=optim_ar,
            project_threshold=project_threshold,
            scan_mode=True,
        )
        method_ids.append(method_ar)

    method_boosting = Gradient_boosting()
    method_boosting.initialize(
        method_ids,
        X_shape=(SEQUENCE_LENGTH, INPUT_DIM),
        Y_shape=(),
        loss=batched_mse,
        eta=eta,
        proxy_loss="original",
        W_update_rule="GECO",
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    # Trusted local pickles only (pickle can execute code on load).
    with open("../data/flood/test/{}.pkl".format(basin), "rb") as x_file:
        X = pickle.load(x_file)
    with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
        ys = np.array(pickle.load(y_file))

    # Fit the ensemble to the LSTM residual and add its prediction back.
    y_pred_ar, W = method_boosting.predict_and_update(X, ys - yhats_LSTM)
    yhats = yhats_LSTM + y_pred_ar.squeeze()

    # Entropy (bits) of the ensemble weights. This assumes W is a nonnegative
    # weight vector (TODO confirm). A zero weight contributes 0 instead of
    # 0 * log2(0) = nan.
    W_entropy = float(-1 * np.sum(W * np.log2(np.where(W > 0, W, 1.0))))

    loss = ((ys - yhats) ** 2).mean()
    # Nash-Sutcliffe efficiency: 1 - SSE / (sum of squares around the mean).
    ys_mean = ys.mean()
    nse = 1 - ((ys - yhats) ** 2).sum() / ((ys - ys_mean) ** 2).sum()
    loss, nse = float(loss), float(nse)

    basin_metrics[basin] = {"loss": loss, "nse": nse, "W_entropy": W_entropy}
