In [407]:
import os
import pickle

import jax
import jax.numpy as np

import tqdm

In [408]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

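# One-off preprocessing, kept for provenance: dumps each basin's CAMELS inputs
# and targets to ../data/flood/test/<basin>.pkl and ../data/flood/qobs/<basin>.pkl,
# which the cells below load. NOTE: both pickles are rewritten for every batch,
# so only the LAST batch of each basin survives; batch_size=5000 presumably
# covers the whole evaluation period in a single batch — confirm before re-running.
# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))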

In [409]:
# Basin IDs to evaluate, read from the dataset metadata pickle.
# The file handle is closed deterministically. Only load trusted, locally
# generated pickles: unpickling can execute arbitrary code.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [410]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir

# Pre-computed LSTM forecasts keyed by basin ID; the AR ensemble further down is
# trained on the residual (observations - these forecasts).
# Only load trusted pickles: unpickling can execute arbitrary code.
with open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)

In [411]:
class SGD:
    """Projected online gradient descent for one boosting member.

    Each ``update`` call:
      1. takes one gradient step on the squared residual
         ``sum((pred(params, x) + prev - true) ** 2)``;
      2. projects every parameter that has an entry in ``project_threshold``
         back onto the L2 ball of that radius.

    Args:
        pred: callable ``pred(params, x)`` returning this member's prediction.
            It is required; the loss calls it unconditionally.
        learning_rate: SGD step size.
        project_threshold: mapping from parameter name to the maximum allowed
            L2 norm of that parameter (flattened). Parameters without an entry
            are not projected. Defaults to no projection.
    """

    def __init__(self, pred=None, learning_rate=0.0001, project_threshold=None):
        self.lr = learning_rate
        # Copy, so a default or caller-owned dict is never shared or mutated.
        self.project_threshold = dict(project_threshold) if project_threshold is not None else {}

        def _loss(params, x, prev, true):
            # `prev` is the prediction already accumulated by earlier ensemble
            # members, so this member is fit to what they left unexplained.
            return np.square(pred(params, x) + prev - true).sum()

        self.grad = jax.jit(jax.grad(_loss))

    def update(self, params, x, prev, true):
        """Return new params after one SGD step and norm projection.

        Args:
            params: dict of parameter arrays (not modified).
            x: input passed to ``pred``.
            prev: prediction of the preceding ensemble members.
            true: regression target.

        Returns:
            A new dict with the same keys as ``params``.
        """
        grads = self.grad(params, x, prev, true)
        new_params = {k: w - self.lr * grads[k] for k, w in params.items()}

        for k, threshold in self.project_threshold.items():
            if k in new_params:
                new_params[k] = self._project(new_params[k], threshold)
        return new_params

    @staticmethod
    def _project(w, threshold):
        """Scale ``w`` onto the L2 ball of radius ``threshold`` if it lies outside.

        Works under tracing (no Python branching on array values). For a
        zero-norm ``w``, ``threshold / norm`` is inf, but `where` selects 1.0,
        so no NaN is produced.
        """
        norm = np.linalg.norm(w)
        scale = np.where(norm > threshold, threshold / norm, 1.0)
        return w * scale

In [412]:
class AR:
    """Linear model over a window of past feature vectors.

    ``self.params``:
        ``W_lnm``: weights, shape ``(history_len, input_dim, output_dim)``.
        ``b``: bias, shape ``(output_dim, 1)``.

    ``self.predict(params, x)`` contracts the first two axes of ``W_lnm`` with
    the first two axes of ``x``, then adds the bias. ``x`` must therefore start
    with ``(history_len, input_dim)``. Any trailing axes of ``x`` are kept after
    the output axis.
    """

    def __init__(self, input_dim=1, output_dim=1, history_len=32):
        self.params = {
            'W_lnm': np.zeros((history_len, input_dim, output_dim)),
            'b': np.zeros((output_dim, 1)),
        }

        def _predict(params, x):
            out = np.tensordot(params['W_lnm'], x, ([0, 1], [0, 1]))
            # For a 2-D window the contraction yields (output_dim,). Turn it into
            # a column so it lines up with the (output_dim, 1) bias instead of
            # broadcasting to (output_dim, output_dim).
            if out.ndim == 1:
                out = out[:, None]
            return out + params['b']

        self.predict = jax.jit(_predict)

In [413]:
class Gradient_boosting:
    """Online boosting of several learners mixed with exponentiated-gradient weights.

    At every time step, for each member ``idx`` in order:
      1. the member predicts with its current params;
      2. its optimizer updates those params, receiving the weighted prediction
         accumulated from members ``0..idx-1`` as ``prev``;
      3. it adds ``weights[idx] * prediction`` to the ensemble output.

    Afterwards the mixing weights are updated multiplicatively,
    ``w * exp(-eta * grad)``, where ``grad`` is the gradient of the squared error
    of the weighted combination. The weights are then renormalised to sum to one.

    Args:
        methods: learners exposing ``.params`` and ``.predict(params, x)``.
        optimizers: one per method, exposing ``.update(params, x, prev, true)``.
        eta: step size of the weight update.
    """

    def __init__(self,
                 methods,
                 optimizers,
                 eta=1.0):
        self.methods = methods
        self.optimizers = optimizers
        self.N = len(methods)

        def _weights_loss(weights, member_preds, target):
            combined = np.dot(weights, member_preds).sum()
            return np.square(combined - target).sum()

        weights_grad = jax.jit(jax.grad(_weights_loss))

        def _step(state, sample):
            """One `lax.scan` step: state is (weights, per-member params); sample is (x, y)."""
            x, y = sample
            weights, params_list = state
            params_list = list(params_list)

            ensemble_pred = 0
            member_preds = []
            for idx in range(self.N):
                member_pred = self.methods[idx].predict(params_list[idx], x).squeeze()
                member_preds.append(member_pred)
                # Member idx is updated against the partial ensemble built so far.
                params_list[idx] = self.optimizers[idx].update(params_list[idx], x, ensemble_pred, y)
                ensemble_pred = ensemble_pred + weights[idx] * member_pred

            # Multiplicative (exponentiated-gradient) update of the mixing weights.
            step_grads = weights_grad(weights, np.asarray(member_preds), y)
            unnormalised = weights * np.exp(-1 * eta * step_grads)
            weights = unnormalised / unnormalised.sum()

            return (weights, params_list), ensemble_pred

        self._predict_and_update = _step

    def predict_and_update(self, X, Y):
        """Run the online protocol over paired sequences ``X``, ``Y``.

        Returns the ensemble prediction made at each step, computed before that
        step's updates. The final member params and weights from the scan are
        discarded. ``self.methods[i].params`` is never modified, so every call
        starts again from the members' initial params and uniform weights.
        """
        uniform_weights = np.ones(self.N) / self.N
        initial_params = [method.params for method in self.methods]
        _, preds = jax.lax.scan(self._predict_and_update,
                                (uniform_weights, initial_params),
                                (X, Y))
        return preds

In [414]:
# Per-basin evaluation: an online-boosted ensemble of AR models is fit to the
# residual of the LSTM forecast, and the MSE of (LSTM + AR correction) against
# the targets is reported. The hyper-parameters do not change between basins,
# so they are defined once here rather than on every iteration.
SEQUENCE_LENGTH = 270   # passed to AR as history_len
INPUT_DIM = 32          # passed to AR as input_dim

b_threshold = 1e-4      # projection radius applied to every member's bias
eta = 0.008             # step size of the ensemble-weight update

# One ensemble member per (W_lnm projection radius, SGD learning rate) pair.
W_lr_best_pairs = [
        (0.03, 2e-5),
        (0.05, 2e-5),
        (0.07, 2e-5),
        (0.09, 2e-5),
    ]


def _load_pickle(path):
    """Unpickle ``path`` and close the file handle afterwards.

    Only use on trusted, locally generated files: unpickling can execute
    arbitrary code.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


basin_to_loss = {}  # results collected as we go, so they survive an interrupted run
for basin in tqdm.tqdm(basins):
    # Fresh members (AR params start at zero) for every basin.
    methods, optimizers = [], []

    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}

        method_ar = AR(
            input_dim=INPUT_DIM,
            output_dim=1,
            history_len=SEQUENCE_LENGTH
        )
        optim_ar = SGD(pred=method_ar.predict, learning_rate=lr, project_threshold=project_threshold)
        optimizers.append(optim_ar)
        methods.append(method_ar)

    method_boosting = Gradient_boosting(
        methods,
        optimizers,
        eta=eta,
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    X = _load_pickle("../data/flood/test/{}.pkl".format(basin))
    Y = _load_pickle("../data/flood/qobs/{}.pkl".format(basin))

    # Residual of the LSTM forecast; this is the regression target of the AR
    # ensemble, not the observed series itself.
    # NOTE(review): assumes Y and yhats_LSTM have matching shapes; e.g. (T, 1)
    # vs (T,) would silently broadcast to (T, T) — confirm.
    y_true = Y - yhats_LSTM
    y_pred_ar = method_boosting.predict_and_update(X, y_true)
    yhats = yhats_LSTM + y_pred_ar.squeeze()

    loss = ((Y - yhats) ** 2).mean()
    basin_to_loss[basin] = loss

    # tqdm.write keeps the progress bar intact, unlike print.
    tqdm.tqdm.write("{} {}".format(basin, loss))


  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 1/531 [00:02<20:07,  2.28s/it]

0.48434398
01031500 

  0%|          | 2/531 [00:04<20:12,  2.29s/it]

0.93345135
01047000 

  1%|          | 3/531 [00:06<20:00,  2.27s/it]

1.2708772
01052500 

  1%|          | 4/531 [00:09<19:52,  2.26s/it]

1.5983098
01054200 

  1%|          | 5/531 [00:11<19:48,  2.26s/it]

8.211933
01055000 

  1%|          | 6/531 [00:13<19:41,  2.25s/it]

5.189535
01057000 

  1%|▏         | 7/531 [00:15<19:39,  2.25s/it]

0.9874269
01073000 

  2%|▏         | 8/531 [00:18<19:45,  2.27s/it]

0.8832621
01078000 

  2%|▏         | 9/531 [00:20<19:47,  2.27s/it]

0.66164553


  2%|▏         | 9/531 [00:22<21:18,  2.45s/it]


KeyboardInterrupt: 