In [414]:
import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np

import tqdm

In [415]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [416]:
# List of basin identifiers, read from the "basins" entry of the metadata pickle.
# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [417]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
# Mapping indexed by basin id (see the training loop) holding the LSTM predictions.
# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
with open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)

In [426]:
# Display the resolved path of the LSTM predictions file (same path components
# as the pickle loaded into basin_to_yhats_LSTM). The result is machine-specific.
os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM")

'/home/dsuo/src/TigerForecast/tigerforecast/flood_prediction/basin_to_yhats_LSTM'

In [418]:
# TODO
# - Optimizers should apply to all children unless children have specified version
# - hierarchical parameters?
# - Tree flatten is very crude (only applies to params)
# - How to identify params (right now just ndarray)
# - Pass class directly to jax
# - How to handle buffers vs parameters
# - Users can do bad things with naming
import inspect

def tree_flatten(module):
    """Flatten a module into (leaves, aux) for JAX pytree registration.

    The leaves come from flattening ``module.get_param_tree()``. The auxiliary
    dict carries what ``tree_unflatten`` needs to rebuild the module: the
    treedef of the parameter tree, the bound constructor arguments, the set of
    attribute names, and the module's class.
    """
    leaves, treedef = jax.tree_util.tree_flatten(module.get_param_tree())
    aux = dict(treedef=treedef, arguments=module.arguments, attrs=module.attrs)
    aux["class"] = module.__class__
    return leaves, aux

def tree_unflatten(aux, leaves):
    """Rebuild a module from the (aux, leaves) pair produced by ``tree_flatten``.

    A fresh instance is created by calling the recorded class with the recorded
    constructor arguments; the parameter tree is then restored from ``leaves``,
    and every recorded attribute that is also a registered parameter is
    re-pointed at the restored value.
    """
    bound = aux["arguments"]
    rebuilt = aux["class"](*bound.args, **bound.kwargs)
    rebuilt.set_param_tree(jax.tree_util.tree_unflatten(aux["treedef"], leaves))

    state = rebuilt.__dict__
    for attr_name in aux["attrs"]:
        if attr_name in state["params"]:
            state[attr_name] = state["params"][attr_name]
    return rebuilt

In [419]:
class Module:
    """Base class whose array attributes and sub-modules form a JAX pytree.

    Every instance carries four bookkeeping attributes, created in ``__new__``
    so subclasses do not need to call ``super().__init__()``:

    - ``attrs``: set of every attribute name assigned through ``__setattr__``.
    - ``modules``: dict of name -> child ``Module``.
    - ``params``: dict of name -> array owned directly by this module.
    - ``arguments``: the ``inspect.BoundArguments`` of the constructor call,
      used to re-create the instance when the pytree is unflattened.

    Each subclass is registered as a pytree node on definition.
    """

    def __new__(cls, *args, **kwargs):
        """Create the instance and record the constructor call."""
        obj = object.__new__(cls)
        obj.__setattr__("attrs", set())
        obj.__setattr__("modules", {})
        obj.__setattr__("params", {})
        # Bind keyword arguments as well as positional ones; otherwise a module
        # built with keywords would be re-created with defaults on unflatten.
        obj.__setattr__("arguments", inspect.signature(obj.__init__).bind(*args, **kwargs))
        obj.arguments.apply_defaults()

        return obj
    
    @classmethod
    def __init_subclass__(cls, *args, **kwargs):
        """Register each subclass with JAX as a custom pytree node."""
        super().__init_subclass__(*args, **kwargs)
        jax.tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten)
    
    def __setattr__(self, name, value):
        """Set the attribute and track it as a sub-module or parameter.

        ``Module`` values are recorded in ``modules``; ``jnp.ndarray`` values
        are recorded in ``params``. Every name is added to ``attrs``.
        """
        self.__dict__[name] = value
        self.attrs.add(name)

        if isinstance(value, Module):
            self.__dict__["modules"][name] = value
        elif isinstance(value, jnp.ndarray):
            self.__dict__["params"][name] = value
    
    def get_param_tree(self):
        """Return a nested dict of this module's params and its children's trees.

        A new dict is built on each call so that sub-module trees are never
        written into ``self.params``, which must hold only directly-owned arrays.
        """
        tree = dict(self.params)
        for name, module in self.modules.items():
            tree[name] = module.get_param_tree()
        return tree
    
    def set_param_tree(self, tree):
        """Load values from ``tree`` (same layout as ``get_param_tree``).

        Each registered parameter is updated both in ``params`` and as an
        instance attribute; each child module receives its own sub-tree.
        Raises ``KeyError`` if ``tree`` lacks a registered name.
        """
        for param in self.params:
            self.params[param] = tree[param]
            self.__dict__[param] = tree[param]
        for name, module in self.modules.items():
            module.set_param_tree(tree[name])
            
    def add_module(self, module, name=None):
        """Register ``module`` as a child.

        If ``name`` is None or already registered, the first free name of the
        form ``"<TypeName>_<k>"`` is used instead.
        """
        counter = 0
        while name is None or name in self.__dict__["modules"]:
            name = "{}_{}".format(type(module).__name__, counter)
            counter += 1      
        self.__dict__["modules"][name] = module
        
    def add_param(self, param, name=None):
        """Register ``param`` under ``name``, de-duplicating on collision.

        On collision the first free ``"<name>_<k>"`` is used; suffixes are
        always appended to the requested name so they do not pile up. A
        ``None`` name falls back to the base name ``"param"``.
        """
        base = "param" if name is None else name
        candidate, counter = name, 0
        while candidate is None or candidate in self.__dict__["params"]:
            candidate = "{}_{}".format(base, counter)
            counter += 1
        self.__dict__["params"][candidate] = param

In [420]:
class SGD:
    """Gradient-descent step followed by projection onto per-parameter norm balls.

    Args:
        loss_fn: callable ``(pred, true) -> scalar``; defaults to mean squared error.
        learning_rate: step size applied to every parameter.
        project_threshold: optional dict mapping a parameter name to the maximum
            norm allowed after the step. Parameters without an entry are not
            projected.
    """

    def __init__(self,
                 loss_fn=lambda pred, true: jnp.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold=None):
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        # None instead of a shared mutable default dict.
        self.project_threshold = {} if project_threshold is None else project_threshold
        # Built once so repeated update() calls reuse the same jitted function
        # instead of re-tracing a fresh lambda every time.
        self._grad_fn = jax.jit(jax.grad(lambda module, x, y: self.loss_fn(module(x), y)))
        
    def update(self, module, params, x, y):
        """Return new parameter values after one projected gradient step.

        Args:
            module: callable pytree; the gradient of ``loss_fn(module(x), y)`` is
                taken with respect to it and must expose a ``params`` mapping.
            params: dict of name -> current value for the parameters to update.
            x, y: input and target passed to the loss.

        Returns:
            dict with the same keys as ``params`` holding the updated values.
        """
        grad = self._grad_fn(module, x, y)

        new_params = {}
        for k, w in params.items():
            stepped = w - self.learning_rate * grad.params[k]
            threshold = self.project_threshold.get(k)
            if threshold is not None:
                norm = jnp.linalg.norm(stepped)
                # Rescale onto the ball only when the norm exceeds the threshold.
                # jnp.where keeps this traceable and avoids the legacy
                # five-argument lax.cond call form.
                scale = jnp.where(norm > threshold, threshold / norm, 1.0)
                stepped = scale * stepped
            new_params[k] = stepped
        return new_params

In [421]:
class MW:
    """Multiplicative-weights update for a weight vector combining predictions.

    The loss differentiated is the summed squared error of ``dot(W, preds)``
    against the target; ``eta`` scales the exponent of the multiplicative step.
    """

    def __init__(self, eta=0.008):
        def squared_error(W, preds, y):
            return jnp.square(jnp.dot(W, preds) - y).sum()

        self.eta = eta
        # Gradient is taken with respect to the first argument (the weights).
        self.grad = jax.jit(jax.grad(squared_error))

    def update(self, params, x, y):
        """Return the re-weighted parameters, normalized to sum to one."""
        step = self.grad(params, x, y)
        unnormalized = params * jnp.exp(-1 * self.eta * step)
        return unnormalized / unnormalized.sum()

In [422]:
class AR(Module):
    """Linear model contracting a kernel with the input over its first two axes.

    ``kernel`` has shape (history_len, input_dim, output_dim) and ``bias`` has
    shape (output_dim, 1); both start at zero.
    """

    def __init__(self, input_dim=32, output_dim=1, history_len=270):
        kernel_shape = (history_len, input_dim, output_dim)
        bias_shape = (output_dim, 1)
        self.kernel = jnp.zeros(kernel_shape)
        self.bias = jnp.zeros(bias_shape)

    def __call__(self, x):
        # Axes 0 and 1 of x are contracted against the kernel's first two axes,
        # so they must match (history_len, input_dim).
        # NOTE(review): the (output_dim, 1) bias broadcasts against the
        # contraction result; confirm the intended shape for output_dim > 1.
        contracted = jnp.tensordot(self.kernel, x, axes=([0, 1], [0, 1]))
        return contracted + self.bias

In [423]:
class GradientBoosting(Module):
    """Ensemble of ``N`` AR sub-modules plus a weight vector ``W``.

    Args:
        N: number of AR learners to create.
        input_dim, output_dim, history_len: forwarded to every ``AR``.

    ``W`` starts uniform (``1 / N`` each). It is stored here but not applied in
    ``__call__``; callers combine the per-learner predictions themselves.
    """

    def __init__(self, N, input_dim=32, output_dim=1, history_len=270):
        for _ in range(N):
            self.add_module(AR(input_dim=input_dim, output_dim=output_dim, history_len=history_len))
            
        self.W = jnp.ones(N) / N
        
    def __call__(self, x):
        """Return the list of squeezed predictions, one per sub-module.

        Order follows registration order in ``self.modules``. The previous
        implementation also accumulated a ``W``-weighted sum that was never
        returned; that dead computation has been removed.
        """
        return [submodule(x).squeeze() for submodule in self.modules.values()]

In [425]:
# Hyperparameters shared by every basin.
bias_threshold = 1e-4
eta = 0.008

# One projected-SGD optimizer per weak learner; they differ only in the norm
# bound placed on the learner's kernel.
SGDs = [SGD(
    learning_rate=lr,
    project_threshold={"kernel": kernel_threshold, "bias": bias_threshold})
              for kernel_threshold, lr in [
        (0.03, 2e-5),
        (0.05, 2e-5),
        (0.07, 2e-5),
        (0.09, 2e-5),
    ]]

N = len(SGDs)

# A single multiplicative-weights updater, built once and using the configured eta.
mw = MW(eta=eta)


def load_pickle(path):
    """Load one pickle file, closing the handle afterwards.

    NOTE: pickle can execute arbitrary code while loading; only open trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def loop(module, xy):
    """One online boosting step, used as the body of ``jax.lax.scan``.

    Returns the updated module (the scan carry) and the ensemble prediction
    made before this step's updates were applied.
    """
    x, y = xy

    preds = jnp.asarray(module(x))
    pred = 0

    for i, submodule in enumerate(module.modules.values()):
        # Learner i is trained on the target minus the weighted prediction of
        # the learners before it.
        submodule.params = SGDs[i].update(submodule, submodule.params, x, y - pred)
        pred += module.W[i] * preds[i]

    module.W = mw.update(module.W, preds, y)

    return module, pred


for basin in tqdm.tqdm(basins):
    module = GradientBoosting(N)

    Y_LSTM = jnp.array(basin_to_yhats_LSTM[basin])
    X = load_pickle("../data/flood/test/{}.pkl".format(basin))
    Y = load_pickle("../data/flood/qobs/{}.pkl".format(basin))

    # Boost on the residual left by the LSTM predictions.
    Y_RESID = Y - Y_LSTM
    module, Y_BOOST = jax.lax.scan(loop, module, (X, Y_RESID))

    # Eager equivalent of the scan above, kept for step-by-step debugging:
    #     for x, y in zip(X, Y_RESID):
    #         module, y_hat = loop(module, (x, y))

    Y_BOOST = jnp.asarray(Y_BOOST).squeeze()
    loss = ((Y - (Y_LSTM + Y_BOOST)) ** 2).mean()

    print(basin, loss)

  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 1/531 [00:02<19:10,  2.17s/it]

0.484344
01031500 

  0%|          | 1/531 [00:04<38:40,  4.38s/it]


KeyboardInterrupt: 