In [2]:
import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np

import tqdm

In [3]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [4]:
basins = pickle.load(open("../data/flood/meta.pkl", "rb"))["basins"]

In [5]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
basin_to_yhats_LSTM = pickle.load(open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb"))



In [6]:
# TODO:
# - Apply to hierarchy

class SGD:
    def __init__(self,
                 loss_fn=lambda pred, true: jnp.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold={}):
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        self.project_threshold = project_threshold
        
    def update(self, module, params, x, y):
        grad = jax.jit(jax.grad(lambda module, x, y: self.loss_fn(module(x), y)))(module, x, y)
        new_params = {k:w - self.learning_rate * grad.params[k] - 2 for (k, w) in params.items()}
        
        for k, param in new_params.items():
            norm = jnp.linalg.norm(new_params[k])
            new_params[k] = jax.lax.cond(norm > self.project_threshold[k],
                                          new_params[k],
                                          lambda x : (self.project_threshold[k]/norm) * x,
                                          new_params[k],
                                          lambda x : x)
        return new_params

In [64]:
class MW:
    def __init__(self, eta=0.008):
        self.eta = eta
        self.grad = jax.jit(jax.grad(lambda W, preds, y: jnp.square(jnp.dot(W, preds) - y).sum()))
        
    def update(self, params, x, y):
        grad = self.grad(params, x, y)
        new_params = params * jnp.exp(-1 * self.eta * grad)
        return new_params / new_params.sum()

In [71]:
# TODO
# - Optimizers should apply to all children unless children have specified version
# - hierarchical parameters?
# - Tree flatten is very crude (only applies to params)
# - Pass class directly to jax
# - How to handle buffers vs parameters
# - Users can do bad things with naming
# - Should we keep copies in __dict__ and params/modules?
import inspect

def pytree(cls):
    def tree_flatten(module):
        leaves, aux = jax.tree_util.tree_flatten(module.params)
        aux = {
            "treedef": aux,
            "arguments": module.arguments,
            "attrs": module.attrs
        }
        return leaves, aux
    
    def tree_unflatten(aux, leaves):
        module = cls(*aux["arguments"].args, **aux["arguments"].kwargs)
        module.params = jax.tree_util.tree_unflatten(aux["treedef"], leaves)
        
        for attr in aux["attrs"]:
            if attr in module.params:
                module.__dict__[attr] = module.params[attr]
            
        return module
    
    jax.tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten)
    
    return cls

class Module:
    def __new__(cls, *args, **kwargs):
        obj = object.__new__(cls)
        obj.__setattr__("attrs", set())
        obj.__setattr__("modules", {})
        obj.__setattr__("params", {})
        obj.__setattr__("arguments", inspect.signature(obj.__init__).bind(*args))
        obj.arguments.apply_defaults()

        return obj
    
    def __setattr__(self, name, value):
        self.__dict__[name] = value
        self.attrs.add(name)
        
        if isinstance(value, Module):
            self.add_module(value, name)
        elif isinstance(value, jnp.ndarray):
            self.add_param(value, name)
        
    def add_module(self, module, name=None):
        counter = 0
        while name is None or name in self.__dict__["modules"]:
            name = "{}_{}".format(type(module).__name__, counter)
            counter += 1
            
        self.__dict__["modules"][name] = module
        
    def add_param(self, param, name):
        counter = 0
        while name is None or name in self.__dict__["params"]:
            name = "{}_{}".format(name, counter)
            counter += 1 
        self.__dict__["params"][name] = param

In [72]:
@pytree
class AR(Module):
    def __init__(self, input_dim=32, output_dim=1, history_len=270):
        self.kernel = jnp.zeros((history_len, input_dim, output_dim))
        self.bias = jnp.zeros((output_dim, 1))
        
    def __call__(self, x):
#         return jnp.tensordot(self.kernel, x, ([0,1],[0,1])) + self.bias
        return jnp.tensordot(self.params["kernel"], x, ([0,1],[0,1])) + self.params["bias"]

In [73]:
@pytree
class GradientBoosting(Module):
    def __init__(self, N, input_dim=32, output_dim=1, history_len=270):
        for i in range(N):
            self.add_module(AR(input_dim=input_dim, output_dim=output_dim, history_len=history_len))
            
        self.W = jnp.ones(N) / N
        
    def __call__(self, x):
        pred, preds = 0, []
        for i, module in enumerate(self.modules.values()):
            pred_i = module(x).squeeze()
            preds.append(pred_i)
            pred += self.params["W"][i] * pred_i

        return preds

In [74]:
for basin in tqdm.tqdm(basins):
    bias_threshold = 1e-4
    eta = 0.008

    optimizers = [SGD(
        learning_rate=lr,
        project_threshold={"kernel": kernel_threshold, "bias": bias_threshold})
                  for kernel_threshold, lr in [
            (0.03, 2e-5),
            (0.05, 2e-5),
            (0.07, 2e-5),
            (0.09, 2e-5),
        ]]
    
    N = len(optimizers)

    module = GradientBoosting(N)

    yhats_LSTM = jnp.array(basin_to_yhats_LSTM[basin])
    X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
    Y = pickle.load(open("../data/flood/qobs/{}.pkl".format(basin), "rb"))
    
    def loop(module, xy):
        x, y = xy
#         print(module.modules["AR_0"].params["kernel"].sum())
        
        preds = jnp.asarray(module(x))
        pred = 0
        
        for i, name in enumerate(module.modules):
            module.modules[name].params = optimizers[i].update(module.modules[name], module.modules[name].params, x, y - pred)
            pred += module.params["W"][i] * preds[i]
            
        print(module.params["W"], preds, y, jnp.dot(module.params["W"], preds) - y)
        module.params["W"] = MW().update(module.params["W"], preds, y)
        
        return module, pred
        
    y_true = Y - yhats_LSTM
#     module, y_pred_ar = jax.lax.scan(loop, module, (X, y_true))
    for x, y in zip(X, y_true):
        module, pred = loop(module, (x, y))
    
    y_pred_ar = jnp.asarray(y_pred_ar)
    yhats = yhats_LSTM + y_pred_ar.squeeze()
    loss = ((Y - yhats) ** 2).mean()
    
    print(basin, loss)
#     break


  0%|          | 0/531 [00:00<?, ?it/s]

[0.25 0.25 0.25 0.25] [0. 0. 0. 0.] 0.118166566 -0.118166566
[0.25 0.25 0.25 0.25] [0.13430952 0.22391267 0.3135231  0.40314493] -0.04906863 0.3177912
[0.2501709  0.25005695 0.24994303 0.24982914] [0.13174988 0.2196353  0.30750763 0.39540604] -0.06850159 0.33202624
[0.25034603 0.25011522 0.24988459 0.24965413] [0.12980895 0.21640135 0.30297926 0.38958448] -0.06486517 0.3244588
[0.25051472 0.25017127 0.24982826 0.24948569] [0.12886204 0.21482384 0.30076975 0.38674304] -0.010355771 0.268008
[0.2506531  0.25021723 0.24978209 0.24934763] [0.12838292 0.21402366 0.29965568 0.38530204] 0.05162567 0.20502907
[0.2507586  0.25025222 0.24974684 0.24924241] [0.1271712  0.21200448 0.2968281  0.38166642] 0.1418643 0.1123389
[0.25081584 0.25027117 0.24972768 0.2491853 ] [0.1259012  0.20989317 0.29386944 0.37787443] 0.110736966 0.14091936


  0%|          | 0/531 [00:05<?, ?it/s]


KeyboardInterrupt: 

In [None]:
y_pred_ar

In [None]:
yhats_LSTM.sum()

In [None]:
y_true.sum()

In [None]:
print(y_true)