In [182]:
import os
import pickle

import jax
import jax.numpy as jnp

import tqdm

In [183]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [184]:
basins = pickle.load(open("../data/flood/meta.pkl", "rb"))["basins"]

In [185]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
basin_to_yhats_LSTM = pickle.load(open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb"))

In [206]:
# TODO:
# - Apply to hierarchy

class SGD:
    def __init__(self,
                 loss_fn=lambda pred, true: jnp.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold={}):
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.project_threshold = project_threshold
        
    def update(self, module, params, x, y):
        grad = jax.jit(jax.grad(lambda params, x, y: self.loss_fn(module(params, x), y)))(params, x, y)
        new_params = {k:w - self.learning_rate * grad[k] for (k, w) in params.items()}
        
        for k, param in new_params.items():
            norm = jnp.linalg.norm(new_params[k])
            new_params[k] = jax.lax.cond(norm > self.project_threshold[k],
                                          new_params[k],
                                          lambda x : (self.project_threshold[k]/norm) * x,
                                          new_params[k],
                                          lambda x : x)
        return new_params

In [199]:
class MW:
    def __init__(self, eta=0.008):
        self.eta = eta
        self.grad = jax.jit(jax.grad(lambda W, preds, y: jnp.square(jnp.dot(W, preds) - y).sum()))
        
    def update(self, params, x, y):
        grad = self.grad(params, x, y)
        new_params = params * jnp.exp(-1 * self.eta * grad)
        return new_params / new_params.sum()

In [193]:
# TODO
# - Optimizers should apply to all children unless children have specified version
# - Tree flatten is very crude (only applies to params)
# - Pass class directly to jax
# - How to handle buffers vs parameters
# - Users can do bad things with naming

@jax.tree_util.register_pytree_node_class
class Module:
    # So users don't need to call super
    def __new__(cls, *args, **kwargs):
        obj = object.__new__(cls)
        object.__setattr__(obj, "modules", {})
        object.__setattr__(obj, "params", {})        
        return obj
    
    def __setattr__(self, name, value):
        if name in self.__dict__:
            del self.__dict__[name]
            
        if isinstance(value, Module):
            self.add_module(value, name)
        elif isinstance(value, jnp.ndarray):
            self.add_param(value, name)
        else:
            object.__setattr__(self, name, value)
        
    def add_module(self, module, name=None):
        counter = 0
        while name is None or name in self.__dict__["modules"]:
            name = "{}_{}".format(type(module).__name__, counter)
            counter += 1
        self.__dict__["modules"][name] = module
        
        self.add_param(module.params, name)
        
    def add_param(self, param, name):
        counter = 0
        while name is None or name in self.__dict__["params"]:
            name = "{}_{}".format(name, counter)
            counter += 1
        self.__dict__["params"][name] = param
        
    def tree_flatten(self):
        return jax.tree_util.tree_flatten(self.params)
    
    @classmethod
    def tree_unflatten(cls, treedef, leaves):
        obj = object.__new__(cls)
        obj.params = jax.tree_util.tree_unflatten(treedef, leaves)
        return obj

In [194]:
class AR(Module):
    def __init__(self, input_dim=1, output_dim=1, history_len=32):
        self.kernel = jnp.zeros((history_len, input_dim, output_dim))
        self.bias = jnp.zeros((output_dim, 1))
        
    def __call__(self, params, x):
        return jnp.tensordot(params["kernel"], x, ([0,1],[0,1])) + params["bias"]

In [195]:
class GradientBoosting(Module):
    def __init__(self, N, input_dim=32, output_dim=1, history_len=270):
        for i in range(N):
            self.add_module(AR(input_dim=input_dim, output_dim=output_dim, history_len=history_len))
            
        self.W = jnp.ones(N) / N
        
    def __call__(self, params, x):
        pred, preds = 0, []
        for i, (name, module) in enumerate(self.modules.items()):
            pred_i = module(params[name], x).squeeze()
            preds.append(pred_i)
            pred += params["W"][i] * pred_i

        return preds

In [208]:
for basin in tqdm.tqdm(basins):
    bias_threshold = 1e-4
    eta = 0.008

    optimizers = [SGD(
        learning_rate=lr,
        project_threshold={"kernel": kernel_threshold, "bias": bias_threshold})
                  for kernel_threshold, lr in [
            (0.03, 2e-5),
            (0.05, 2e-5),
            (0.07, 2e-5),
            (0.09, 2e-5),
        ]]
    
    N = len(optimizers)

    method_boosting = GradientBoosting(N)

    yhats_LSTM = jnp.array(basin_to_yhats_LSTM[basin])
    X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
    Y = pickle.load(open("../data/flood/qobs/{}.pkl".format(basin), "rb"))
    
    def loop(params, xy):
        x, y = xy
        preds = jnp.asarray(method_boosting(params, x))
        
        pred = 0
        for i, (name, module) in enumerate(method_boosting.modules.items()):
            params[name] = optimizers[i].update(module, params[name], x, y - pred)
            pred += params["W"][i] * preds[i]
        
        print(pred, params[name]["kernel"].sum(), params["W"], preds, y, jnp.dot(params["W"], preds) - y)
        params["W"] = MW().update(params["W"], preds, y)
        print(params["W"])
        
        return params, pred
        
    y_true = Y - yhats_LSTM
    method_boosting.params, y_pred_ar = jax.lax.scan(loop, method_boosting.params, (X, y_true))
#     for x, y in zip(X, y_true):
#         method_boosting.params, pred = loop(method_boosting.params, (x, y))
    
    y_pred_ar = jnp.asarray(y_pred_ar)
    yhats = yhats_LSTM + y_pred_ar.squeeze()
    loss = ((Y - yhats) ** 2).mean()
    
    print(basin, loss)
#     break


  0%|          | 0/531 [00:00<?, ?it/s]

Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>
Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>
01022500 

  0%|          | 1/531 [00:02<18:49,  2.13s/it]

0.484344
Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)> Traced<ShapedArray(float32[]):JaxprTrace(level=0/0)>
Traced<ShapedArray(float32[4]):JaxprTrace(level=0/0)>


  0%|          | 1/531 [00:04<35:38,  4.04s/it]


KeyboardInterrupt: 

In [None]:
y_pred_ar

In [None]:
yhats_LSTM.sum()

In [None]:
y_true.sum()

In [None]:
print(y_true)