In [407]:
import os
import pickle

import jax
import jax.numpy as np

import tqdm

In [408]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [409]:
basins = pickle.load(open("../data/flood/meta.pkl", "rb"))["basins"]

In [410]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
basin_to_yhats_LSTM = pickle.load(open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb"))

In [473]:
class SGD:
    def __init__(self, pred=None, learning_rate=0.0001, project_threshold={}):
        self.lr = learning_rate
        self.project_threshold = project_threshold
        self.grad = jax.jit(jax.grad(lambda params, x, y: np.square(pred(params, x) - y).sum()))

    def update(self, params, x, y):
        grad = self.grad(params, x, y)
        new_params = {k:w - self.lr * grad[k] for (k, w) in params.items()}
        
        for k, param in new_params.items():
            norm = np.linalg.norm(new_params[k])
            new_params[k] = jax.lax.cond(norm > self.project_threshold[k],
                                          new_params[k],
                                          lambda x : (self.project_threshold[k]/norm) * x,
                                          new_params[k],
                                          lambda x : x)
        return new_params

In [474]:
class AR:
    def __init__(self, input_dim=1, output_dim=1, history_len=32):
        self.params = {'W_lnm': np.zeros((history_len, input_dim, output_dim)), 'b': np.zeros((output_dim, 1))}
        self.predict = jax.jit(lambda params, x: np.tensordot(params['W_lnm'], x, ([0,1],[0,1])) + params['b'])

In [475]:
class Gradient_boosting:
    def __init__(self, 
                 methods,
                 optimizers,
                 eta=1.0):
        
        self.N = len(methods)
        self.methods = methods
        self.optimizers = optimizers
        
        loss_W_grad = jax.jit(jax.grad(lambda W, preds, y: np.square(np.dot(W, preds) - y).sum()))

        def _predict_and_update(carry, xy):
            x, y = xy
            W, all_params = carry
            
            pred, yhats = 0, []
            for i in range(self.N):
#                 if i == 0:
#                     print(all_params[i])
                pred_i = self.methods[i].predict(all_params[i], x).squeeze()
                yhats.append(pred_i)
                all_params[i] = self.optimizers[i].update(all_params[i], x, y - pred)
                pred += W[i] * pred_i
                
            loss_W_grads = loss_W_grad(W, np.asarray(yhats), y)
            nums = W * np.exp(-1 * eta * loss_W_grads)
            W = nums / nums.sum()

            return (W, all_params), pred
        self._predict_and_update = _predict_and_update

    def predict_and_update(self, X, Y):
        # NOTE: we don't update method params after this scan
        _, preds = jax.lax.scan(self._predict_and_update, 
                                (np.ones(self.N) / self.N,
                                 [method.params for method in self.methods]), 
                                (X, Y))
        return preds

In [476]:
for basin in tqdm.tqdm(basins):
    SEQUENCE_LENGTH = 270
    INPUT_DIM = 32

    b_threshold = 1e-4
    eta = 0.008

    W_lr_best_pairs = [
            (0.03, 2e-5),
            (0.05, 2e-5),
            (0.07, 2e-5),
            (0.09, 2e-5),
        ]

    methods, optimizers = [], []

    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        
        method_ar = AR(
            input_dim=INPUT_DIM,
            output_dim=1,
            history_len=SEQUENCE_LENGTH
        )
        optim_ar = SGD(pred = method_ar.predict, learning_rate=lr, project_threshold=project_threshold)
        optimizers.append(optim_ar)
        methods.append(method_ar)

    method_boosting = Gradient_boosting(
        methods,
        optimizers,
        eta=eta,
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
    Y = pickle.load(open("../data/flood/qobs/{}.pkl".format(basin), "rb"))
    
    y_true = Y - yhats_LSTM
#     y_pred_ar = method_boosting.predict_and_update(X, y_true)

    W = np.ones(4) / 4
    params = [method.params for method in method_boosting.methods]
    y_pred_ar = []
    for x, y in zip(X, y_true):
        print(W)
        (W, params), y_hat = method_boosting._predict_and_update((W, params), (x, y))
        y_pred_ar.append(y_hat)

    y_pred_ar = np.asarray(y_pred_ar)
    yhats = yhats_LSTM + y_pred_ar.squeeze()
    loss = ((Y - yhats) ** 2).mean()
    
    print(basin, loss)
    break


  0%|          | 0/531 [00:00<?, ?it/s]

[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.25 0.25 0.25 0.25]
[0.24999952 0.24999984 0.25000018 0.25000048]
[0.24999914 0.2499997  0.25000033 0.2500009 ]
[0.24999915 0.2499997  0.2500003  0.25000086]
[0.24999905 0.24999966 0.25000036 0.250001  ]
[0.24999827 0.24999936 0.2500006  0.25000182]
[0.24999785 0.2499992  0.25000072 0.2500022 ]
[0.24999797 0.24999923 0.2500007  0.2500021 ]
[0.24999775 0.24999915 0.25000077 0.25000232]
[0.24999465 0.24999817 0.25000185 0.2500054 ]
[0.24999695 0.24999884 0.25000104 0.25000316]
[0.24999739 0.24999896 0.25000086 0.25000274]
[0.24999851 0.24999928 0.25000048 0.25000173]
[0.2500007  0.24999988 0.2499997  0.24999976]


  0%|          | 0/531 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [458]:
y_pred_ar

DeviceArray([0.        , 0.02236109, 0.006649  , ..., 0.5759724 ,
             0.5599743 , 0.53088856], dtype=float32)

In [456]:
yhats_LSTM.sum()

DeviceArray(7102.1475, dtype=float32)

In [457]:
y_true.sum()

DeviceArray(242.56683, dtype=float32)

In [429]:
print(y_true)

[ 0.11816657 -0.04906863 -0.06850159 ...  0.6123748   0.54948944
  0.45159483]
