In [178]:
import os
import pickle

import jax
import jax.numpy as np

import tqdm

from tigerforecast.utils.optimizers.losses import batched_mse, mse

In [2]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

In [3]:
# Load the list of basin ids from the dataset metadata. The context manager
# closes the file handle (a bare open() inside pickle.load leaks it).
# NOTE: pickle.load executes arbitrary code on untrusted files; only use on trusted data.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [4]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir

# Per-basin LSTM predictions (keyed by basin id); the AR ensemble below is
# trained on the residual Y - yhats_LSTM. The file handle is closed explicitly.
lstm_yhats_path = os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM")
with open(lstm_yhats_path, "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)

In [5]:
# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [256]:
def proxy_loss(y_pred, yprev_ytrue):
    """Loss of one boosting stage.

    Args:
        y_pred: this learner's prediction.
        yprev_ytrue: pair ``(y_prev, y_true)``. ``y_prev`` is the prediction already
            accumulated before this learner, so the learner is scored on how well
            ``y_prev + y_pred`` matches ``y_true``.
    Returns:
        ``batched_mse`` of the combined prediction against ``y_true``.
    """
    (y_prev, y_true) = yprev_ytrue
    boosted = y_prev + y_pred
    return batched_mse(boosted, y_true)

class SGD:
    """Projected gradient descent on the boosting proxy loss.

    Args:
        pred: callable ``pred(params, x)`` producing this learner's prediction.
        learning_rate: SGD step size.
        project_threshold: dict mapping a parameter name to the radius of the L2
            (Frobenius) ball the parameter is projected onto after each step.
            Parameters with no entry are left unprojected. Defaults to no projection.
    """
    def __init__(self, pred=None, learning_rate=0.0001, project_threshold=None):
        self.lr = learning_rate
        # A fresh dict per instance: a `{}` default would be shared by every SGD.
        self.project_threshold = {} if project_threshold is None else project_threshold
        self._grad = jax.jit(jax.grad(lambda params, x, y: proxy_loss(pred(params, x), y)))

    def update_scan(self, params, x, y):
        """One projected SGD step; safe to call inside ``jax.lax.scan``.

        Args:
            params: dict of parameter arrays.
            x: model input.
            y: ``(y_prev, y_true)`` pair forwarded to ``proxy_loss``.
        Returns:
            New dict of parameters with the same keys.
        """
        grad = self._grad(params, x, y)
        new_params = {}
        for k, w in params.items():
            w_new = w - self.lr * grad[k]
            if k in self.project_threshold:
                w_new = self._project(w_new, self.project_threshold[k])
            new_params[k] = w_new
        return new_params

    @staticmethod
    def _project(w, radius):
        """Scale ``w`` onto the L2 ball of ``radius`` if it lies outside it."""
        norm = np.linalg.norm(w)
        # Branchless (traceable) projection. The scale in the unselected branch may be
        # inf when norm == 0, but it is discarded, so no NaN reaches the result.
        scale = np.where(norm > radius, radius / norm, 1.0)
        return scale * w

In [257]:
class AR:
    """Linear autoregressive model over a window of past inputs.

    Parameters (``self.params``):
        W_lnm: weights of shape (history_len, input_dim, output_dim), zero-initialised.
        b: bias of shape (output_dim, 1), zero-initialised.

    ``self.predict(params, x)`` contracts the first two axes of ``x``
    (history_len, input_dim) with W and adds b. For a single window it returns
    shape (output_dim, 1).
    """
    def __init__(self, input_dim=1, output_dim=1, history_len=32):
        # initialize parameters
        self.params = {'W_lnm': np.zeros((history_len, input_dim, output_dim)),
                       'b': np.zeros((output_dim, 1))}

        def _predict(params, x):
            out = np.tensordot(params['W_lnm'], x, ([0, 1], [0, 1]))
            # A single window gives a 1-D (output_dim,) result. Adding b of shape
            # (output_dim, 1) to it would broadcast to (output_dim, output_dim), so
            # turn it into a column first. ndim is static, so this is jit-safe.
            if out.ndim == 1:
                out = out[:, None]
            return out + params['b']

        self.predict = jax.jit(_predict)

In [266]:
class Gradient_boosting:
    def __init__(self, 
                 methods,
                 optimizers,
                 X_shape=(),
                 Y_shape=(),
                 eta=1.0):
        """
        Description: Online boosting over a fixed list of base learners.
            The ensemble prediction is sum_i W[i] * method_i(x). After each step,
            learner i is updated by its optimizer against the partial prediction of
            learners 0..i-1. The mixing weights W get a multiplicative
            (exponentiated-gradient) update with step size eta, then renormalised.
        Args:
            methods: list of method instances, each exposing `params` and
                `predict(params, x)`
            optimizers: list of optimizers (one per method) exposing
                `update_scan(params, x, (y_prev, y_true))`
            X_shape: shape of input (exclude batch dim); currently unused
            Y_shape: shape of output (exclude batch dim)
            eta: step size of the multiplicative weight update
        """
        self.Y_shape = Y_shape
        self.N = len(methods)
        self.eta = eta
        self.methods = methods
        self.optimizers = optimizers
        # Per-method parameters; threaded through lax.scan as part of the carry.
        self.all_params = [method.params for method in methods]
        # Carried through scan unchanged (kept for interface compatibility).
        self.Z = np.array([1.0/self.N for i in range(self.N)])
        # Mixing weights, initialised uniformly.
        self.W = [1.0/self.N for i in range(self.N)]

        def loss_W(W, preds, y):
            pred = 0
            for j in range(self.N):
                pred += W[j] * preds[j]
            return batched_mse(pred, y)

        grad_loss_W = jax.grad(loss_W)

        def _predict_and_update(carry, xy):
            x, y = xy[0], xy[1]
            Z, W, all_params = carry

            # ys[i] is the weighted prediction of learners 0..i-1; ys[-1] is the ensemble output.
            ys = [np.zeros(self.Y_shape)]
            method_yhats = []
            for i in range(self.N):
                pred_i = np.reshape(self.methods[i].predict(all_params[i], x), self.Y_shape)
                method_yhats.append(pred_i)
                ys.append(np.reshape(ys[-1] + W[i] * pred_i, self.Y_shape))
            pred = ys[-1]

            # Exponentiated-gradient weight update; the gradient is computed once
            # rather than once per weight.
            grad_W = grad_loss_W(W, method_yhats, y)
            nums = [W[j] * np.exp(-1 * self.eta * grad_W[j]) for j in range(self.N)]
            den = np.sum(np.array(nums))
            new_W = [num / den for num in nums]

            # Learner i fits the target given the partial prediction ys[i] of the learners before it.
            new_params = [
                self.optimizers[i].update_scan(all_params[i], x, (ys[i], y))
                for i in range(self.N)
            ]
            return (Z, new_W, new_params), pred
        self._predict_and_update = _predict_and_update

    def predict_and_update(self, X, Y):
        """Run one online pass (predict, then update) over the sequence (X, Y).

        The final weights and parameters are stored back on `self`, so a later
        call continues learning instead of restarting from the initial state.

        Returns:
            Per-step ensemble predictions reshaped to `Y.shape`.
        """
        (self.Z, self.W, self.all_params), preds = jax.lax.scan(
            self._predict_and_update,
            (self.Z, self.W, self.all_params),
            (X, Y),
        )
        return np.reshape(preds, Y.shape)

In [267]:
# Hyper-parameters shared by every basin (loop-invariant, so defined once).
SEQUENCE_LENGTH = 270  # history window fed to each AR learner
INPUT_DIM = 32         # features per time step
b_threshold = 1e-4     # projection radius for each AR bias
eta = 0.008            # step size of the ensemble-weight update

# (projection radius for W_lnm, SGD learning rate) for each AR learner.
W_lr_best_pairs = [
    (0.03, 2e-5),
    (0.05, 2e-5),
    (0.07, 2e-5),
    (0.09, 2e-5),
]

basin_to_loss = {}
for basin in tqdm.tqdm(basins):
    methods, optimizers = [], []
    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        method_ar = AR(
            input_dim=INPUT_DIM,
            output_dim=1,
            history_len=SEQUENCE_LENGTH
        )
        optimizers.append(SGD(pred=method_ar.predict, learning_rate=lr, project_threshold=project_threshold))
        methods.append(method_ar)

    method_boosting = Gradient_boosting(
        methods,
        optimizers,
        X_shape=(SEQUENCE_LENGTH, INPUT_DIM),
        Y_shape=(),
        eta=eta,
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    with open("../data/flood/test/{}.pkl".format(basin), "rb") as x_file:
        X = pickle.load(x_file)
    with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
        Y = pickle.load(y_file)

    # The ensemble is trained on the LSTM's residual; its output is added back as a correction.
    y_true = Y - yhats_LSTM
    y_pred_ar = method_boosting.predict_and_update(X, y_true)
    yhats = yhats_LSTM + y_pred_ar.squeeze()

    loss = ((Y - yhats) ** 2).mean()
    basin_to_loss[basin] = loss
    # tqdm.write keeps the progress bar intact (plain print interleaves with it).
    tqdm.tqdm.write("{} {}".format(basin, loss))


  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 1/531 [00:02<20:30,  2.32s/it]

0.48434398
01031500 

  0%|          | 2/531 [00:04<20:30,  2.33s/it]

0.93345135


  0%|          | 2/531 [00:06<28:13,  3.20s/it]


KeyboardInterrupt: 

In [63]:
class A:
    """Scratch experiment: what does ``locals()`` report inside a jitted function?"""
    def __init__(self):
        def a():
            print(locals())
            return 5
        # Same as decorating `a` with @jax.jit.
        self._a = jax.jit(a)

In [64]:
# Instantiate the jit / locals() experiment defined above.
a = A()

In [65]:
# Call the jitted function: it prints its locals() (captured output: {}) and returns 5.
a._a()

{}


5