In [1]:
import os
import pickle

import jax
import jax.numpy as np

import tqdm

In [2]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

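# One-time data export, kept for reference (not re-run here): writes each basin's
# inputs to ../data/flood/test/<basin>.pkl and its targets to ../data/flood/qobs/<basin>.pkl,
# the files the evaluation loop below reads. NOTE: every batch overwrites the same
# file, so only the last batch per basin is kept.
# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))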

In [3]:
# Basin IDs to evaluate, read from the dataset metadata pickle.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted local files.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [4]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
# Per-basin LSTM forecasts; the AR ensemble below is fit to their residuals.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted local files.
with open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)



In [14]:
class SGD:
    """Online projected gradient descent over a dict of parameter arrays.

    Each ``update`` takes one gradient step on ``loss_fn(pred(params, x), y)``
    and then, for every parameter that has an entry in ``project_threshold``,
    rescales it onto the L2 ball of that radius if its norm exceeds it.
    """

    def __init__(self,
                 pred=None,
                 loss_fn=lambda pred, true: np.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold=None):
        """
        Args:
            pred: callable ``pred(params, x)`` producing the model prediction.
            loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar;
                defaults to mean squared error.
            learning_rate: gradient step size.
            project_threshold: optional dict mapping parameter name -> maximum
                L2 norm. Parameters without an entry are left unprojected.
        """
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        # A fresh dict per instance: never share a mutable default between optimizers.
        self.project_threshold = {} if project_threshold is None else project_threshold
        self.grad = jax.jit(jax.grad(lambda params, x, y: self.loss_fn(pred(params, x), y)))

    def update(self, params, x, y):
        """Return a new params dict after one projected gradient step on (x, y)."""
        grad = self.grad(params, x, y)
        new_params = {}
        for k, w in params.items():
            w = w - self.learning_rate * grad[k]
            threshold = self.project_threshold.get(k)
            if threshold is not None:
                norm = np.linalg.norm(w)
                # Rescale onto the radius-`threshold` ball only when outside it.
                w = jax.lax.cond(norm > threshold,
                                 lambda v: (threshold / norm) * v,
                                 lambda v: v,
                                 w)
            new_params[k] = w
        return new_params

In [73]:
class MW:
    """Multiplicative-weights (exponentiated-gradient) update for mixture weights.

    The loss is ``sum((W . preds - y) ** 2)``. Each step multiplies every weight
    by ``exp(-eta * grad)`` and renormalizes so the weights sum to one.
    """

    def __init__(self, eta=0.008):
        self.eta = eta
        self.grad = jax.jit(jax.grad(lambda W, preds, y: np.square(np.dot(W, preds) - y).sum()))

    def update(self, params, x, y):
        """Return updated, normalized weights.

        Args:
            params: current weight vector.
            x: per-learner predictions that ``params`` mixes.
            y: target value.
        """
        grad = self.grad(params, x, y)
        exponent = -self.eta * grad
        # Subtracting the max exponent cancels in the normalization below but keeps
        # exp() from overflowing to inf or underflowing every weight to 0 (-> NaN).
        exponent = exponent - exponent.max()
        new_params = params * np.exp(exponent)
        return new_params / new_params.sum()
        

In [15]:
class AR:
    """Linear autoregressive model over a fixed-length window of input features.

    Args:
        input_dim: number of features per time step.
        output_dim: number of outputs.
        history_len: number of time steps in each input window.
    """

    def __init__(self, input_dim=1, output_dim=1, history_len=32):
        # Zero-initialized, so the model initially predicts 0.
        self.params = {'W_lnm': np.zeros((history_len, input_dim, output_dim)),
                       'b': np.zeros((output_dim, 1))}

    def __call__(self, params, x):
        """Predict from the window ``x``.

        ``x``'s first two axes are contracted against W_lnm's (history_len,
        input_dim) axes. A 2-D window gives shape (output_dim,); any trailing
        axes of ``x`` are kept after the output axis.
        """
        out = np.tensordot(params['W_lnm'], x, ([0, 1], [0, 1]))
        # Reshape the bias to broadcast along the output axis only. Adding the
        # stored (output_dim, 1) bias to a (output_dim,) result would otherwise
        # produce an (output_dim, output_dim) outer sum.
        bias = np.reshape(params['b'], (-1,) + (1,) * (out.ndim - 1))
        return out + bias

In [85]:
from functools import partial

class Gradient_boosting:
    """Ensemble container: evaluates child models and holds their mixture weights.

    ``params`` is ``{"W": weight vector (uniform init), "children": [child params]}``.
    It is a plain pytree, so it can be carried through ``jax.lax.scan``.
    """

    def __init__(self, methods):
        self.N = len(methods)
        self.methods = methods
        self.params = {"W": np.ones(self.N) / self.N, "children": [method.params for method in self.methods]}
        # NOTE(review): not used in the visible notebook; MW.update computes the same gradient.
        self.loss_W_grad = jax.jit(jax.grad(lambda W, preds, y: np.square(np.dot(W, preds) - y).sum()))

    def __call__(self, params, x):
        """Return a list with each child's squeezed prediction for ``x``.

        The weighted combination is left to the caller; only ``params["children"]``
        is read.
        """
        return [self.methods[i](params["children"][i], x).squeeze() for i in range(self.N)]

In [87]:
# Online boosting of projected-SGD AR learners on top of the LSTM forecasts: the
# ensemble is fit step by step to the LSTM residual (Y - yhats_LSTM), and its
# output is added back to the LSTM forecast before scoring.
SEQUENCE_LENGTH = 270
INPUT_DIM = 32

b_threshold = 1e-4
eta = 0.008

# (max L2 norm of W_lnm, learning rate) for each AR learner in the ensemble.
W_lr_best_pairs = [
        (0.03, 2e-5),
        (0.05, 2e-5),
        (0.07, 2e-5),
        (0.09, 2e-5),
    ]

# Number of basins to evaluate; set to None to run every basin.
MAX_BASINS = 1

# MW only holds eta and a jitted gradient, so one instance serves every basin and step.
mw_optimizer = MW(eta=eta)

# TODO
# - Figure out hierarchy
# - Register nodes
# - Pass in directly to jax
# - Figure out how to automatically do hierarchy/figure out nodes
# - Optimizers should apply to all children unless children have specified version

for basin_idx, basin in enumerate(tqdm.tqdm(basins)):
    methods, optimizers = [], []
    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        method_ar = AR(
            input_dim=INPUT_DIM,
            output_dim=1,
            history_len=SEQUENCE_LENGTH
        )
        optimizers.append(SGD(pred=method_ar, learning_rate=lr, project_threshold=project_threshold))
        methods.append(method_ar)

    method_boosting = Gradient_boosting(methods)

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted local files.
    with open("../data/flood/test/{}.pkl".format(basin), "rb") as data_file:
        X = pickle.load(data_file)
    with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as data_file:
        Y = pickle.load(data_file)

    def loop(params, xy):
        """One online step for ``jax.lax.scan``: predict, then update the ensemble.

        Returns the updated params (same pytree structure) and the ensemble's
        prediction, made with the pre-update parameters.
        """
        x, y = xy
        preds = method_boosting(params, x)

        pred = 0
        children = []
        for i in range(method_boosting.N):
            # Learner i is fit to whatever the earlier learners have not explained yet.
            children.append(optimizers[i].update(params["children"][i], x, y - pred))
            pred += params["W"][i] * preds[i]

        new_params = {
            "W": mw_optimizer.update(params["W"], np.asarray(preds), y),
            "children": children,
        }
        return new_params, pred

    y_true = Y - yhats_LSTM
    method_boosting.params, y_pred_ar = jax.lax.scan(loop, method_boosting.params, (X, y_true))

    y_pred_ar = np.asarray(y_pred_ar)
    yhats = yhats_LSTM + y_pred_ar.squeeze()
    loss = ((Y - yhats) ** 2).mean()

    print(basin, loss)
    if MAX_BASINS is not None and basin_idx + 1 >= MAX_BASINS:
        break


  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 0/531 [00:02<?, ?it/s]

0.484344





In [None]:
# Ensemble (residual) predictions for the last basin evaluated by the loop above.
y_pred_ar

In [None]:
# Total of the LSTM forecasts for the last evaluated basin.
np.sum(yhats_LSTM)

In [None]:
# Total of the LSTM residual targets (Y - yhats_LSTM) for the last evaluated basin.
np.sum(y_true)

In [None]:
# Residual targets for the last evaluated basin (bare expression -> rich notebook display).
y_true