In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle

import jax
import jax.numpy as jnp
import timecast as tc
import tqdm

# Download all data

In [None]:
# Create the local data directory (-p: no error if it already exists).
!mkdir -p data

In [None]:
# Copy the flood dataset from the gs://skgaip bucket into ./data so that it
# lands under data/flood (-m: parallel transfer, -r: recursive copy).
# NOTE(review): needs gsutil on PATH and presumably read access to the bucket.
!gsutil -m cp -r gs://skgaip/data/flood data

In [None]:
# Load the basin list and the per-basin LSTM predictions.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only run
# this against data from a source you trust.
# `with` closes each file handle as soon as it has been read.
with open("data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]
with open("data/flood/tigerforecast/lstm.pkl", "rb") as lstm_file:
    basin_to_yhats_LSTM = pickle.load(lstm_file)

# Define optimizers

In [None]:
class SGD:
    """Projected stochastic gradient descent over a module's parameters.

    Each call to :meth:`update` takes one gradient step on ``loss_fn`` and then
    rescales any parameter whose norm exceeds its configured threshold so that
    its norm equals that threshold.

    Args:
        loss_fn: Callable ``(pred, true) -> scalar``. Defaults to mean squared
            error.
        learning_rate: Step size applied to every parameter.
        project_threshold: Mapping from parameter name to the maximum allowed
            norm of that parameter. Parameters without an entry are left
            unprojected; ``None`` (the default) disables projection entirely.
    """

    def __init__(self,
                 loss_fn=lambda pred, true: jnp.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold=None):
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        # A fresh dict per instance: a mutable default argument would be shared
        # by every optimizer constructed without an explicit mapping.
        self.project_threshold = {} if project_threshold is None else project_threshold
        # Build the jitted gradient once. Wrapping a brand-new lambda in
        # jax.jit on every update() call can never hit jit's cache.
        self._grad_fn = jax.jit(
            jax.grad(lambda module, x, y: self.loss_fn(module(x), y)))

    def update(self, module, params, x, y):
        """Return the parameters after one projected gradient step.

        Args:
            module: Callable pytree differentiated by ``jax.grad``; the
                resulting gradient must expose a ``params`` mapping keyed like
                ``params``.
            params: Mapping from parameter name to current value.
            x: Input passed to ``module``.
            y: Target passed to ``loss_fn``.

        Returns:
            A new ``dict`` with the same keys as ``params``.
        """
        grad = self._grad_fn(module, x, y)

        new_params = {}
        for name, value in params.items():
            stepped = value - self.learning_rate * grad.params[name]

            threshold = self.project_threshold.get(name)
            if threshold is not None:
                norm = jnp.linalg.norm(stepped)
                exceeds = norm > threshold
                # jnp.where avoids depending on the legacy five-argument
                # lax.cond call form. The inner where guards the divisor so
                # the unselected branch never evaluates threshold / 0.
                scale = jnp.where(exceeds,
                                  threshold / jnp.where(exceeds, norm, 1.0),
                                  1.0)
                stepped = scale * stepped

            new_params[name] = stepped
        return new_params

In [None]:
class MultiplicativeWeights:
    """Multiplicative-weights optimizer for a vector of mixing weights.

    The weights are scaled by ``exp(-eta * gradient)`` of the summed squared
    error of the weighted prediction, then renormalised to sum to one.
    """

    def __init__(self, eta=0.008):
        def squared_error(W, preds, y):
            # Summed squared error of the W-weighted combination of preds.
            residual = jnp.dot(W, preds) - y
            return jnp.square(residual).sum()

        self.eta = eta
        # Gradient is taken with respect to the first argument (the weights).
        self.grad = jax.jit(jax.grad(squared_error))

    def update(self, module, params, x, y):
        """Return the renormalised weights after one multiplicative step.

        ``module`` is not used; it is accepted so the call signature matches
        ``SGD.update``. ``params`` are the current weights, ``x`` the
        predictions being mixed and ``y`` the target.
        """
        step = self.grad(params, x, y)
        unnormalised = params * jnp.exp(-1 * self.eta * step)
        total = unnormalised.sum()
        return unnormalised / total

# Define modules

In [None]:
class AR(tc.Module):
    """Linear map with a zero-initialised kernel and bias.

    The kernel has shape ``(history_len, input_dim, output_dim)`` and the bias
    has shape ``(output_dim, 1)``.
    """

    def __init__(self, input_dim=32, output_dim=1, history_len=270):
        kernel_shape = (history_len, input_dim, output_dim)
        bias_shape = (output_dim, 1)
        self.kernel = jnp.zeros(kernel_shape)
        self.bias = jnp.zeros(bias_shape)

    def __call__(self, x):
        """Contract the kernel with ``x`` and add the bias.

        The first two axes of ``x`` are contracted against the kernel's
        ``history_len`` and ``input_dim`` axes, so their sizes must match.
        The bias is added with ordinary broadcasting.
        """
        contracted_axes = ([0, 1], [0, 1])
        projection = jnp.tensordot(self.kernel, x, contracted_axes)
        return projection + self.bias

In [None]:
class GradientBoosting(tc.Module):
    """Ensemble of ``N`` ``AR`` sub-modules plus a vector of mixing weights.

    Args:
        N: Number of ``AR`` sub-modules to register.
        input_dim, output_dim, history_len: Forwarded unchanged to each ``AR``.

    Attributes set here:
        W: Length-``N`` weight vector, initialised uniformly to ``1 / N``.
    """

    def __init__(self, N, input_dim=32, output_dim=1, history_len=270):
        for _ in range(N):
            self.add_module(AR(input_dim=input_dim, output_dim=output_dim, history_len=history_len))

        self.W = jnp.ones(N) / N

    def __call__(self, x):
        """Return a list with each sub-module's squeezed prediction for ``x``.

        The predictions are returned individually, in ``self.modules`` order;
        ``self.W`` is not applied here, so combining them is left to the caller.
        """
        # The previous implementation also accumulated a W-weighted sum that
        # was never returned or used; that dead computation is removed.
        return [submodule(x).squeeze() for _, submodule in self.modules.items()]

# Initialize optimizers

In [None]:
# Settings shared by the whole ensemble: the projection threshold used for
# every learner's bias, and the multiplicative-weights step size.
bias_threshold = 1e-4
eta = 0.008

# One projected-SGD optimizer per (kernel threshold, learning rate) pair; the
# i-th optimizer is used for the i-th sub-module of the ensemble.
SGDs = [
    SGD(
        learning_rate=step_size,
        project_threshold=dict(kernel=kernel_radius, bias=bias_threshold),
    )
    for kernel_radius, step_size in (
        (0.03, 2e-5),
        (0.05, 2e-5),
        (0.07, 2e-5),
        (0.09, 2e-5),
    )
]

# Optimizer for the ensemble's mixing weights.
MW = MultiplicativeWeights(eta=eta)

# Predict basin

In [None]:
def predict(basin):
    """Fit a boosted correction to the LSTM residuals of one basin online.

    A fresh ``GradientBoosting`` ensemble (one learner per entry in ``SGDs``)
    is scanned over the basin's inputs. At each step the ensemble predicts the
    residual ``Y - Y_LSTM``, then its learners and mixing weights are updated.

    Args:
        basin: Key into ``basin_to_yhats_LSTM`` and stem of the per-basin
            pickle files under ``data/flood/test`` and ``data/flood/qobs``.

    Returns:
        Mean squared error between ``Y`` and ``Y_LSTM + Y_BOOST``.
    """
    N = len(SGDs)

    model = GradientBoosting(N)

    Y_LSTM = jnp.array(basin_to_yhats_LSTM[basin])
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # use it on data from a trusted source. `with` closes each handle promptly.
    with open("data/flood/test/{}.pkl".format(basin), "rb") as x_file:
        X = pickle.load(x_file)
    with open("data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
        Y = pickle.load(y_file)

    def loop(model, xy):
        """One scan step: predict for ``x``, then update learners and weights."""
        x, y = xy

        # Predictions are taken before any parameter update in this step.
        preds = jnp.asarray(model(x))
        pred = 0

        for i, (name, module) in enumerate(model.modules.items()):
            # Learner i is trained on what the weighted sum of the earlier
            # learners has left unexplained.
            module.params = SGDs[i].update(module, module.params, x, y - pred)
            pred += model.W[i] * preds[i]

        # Mixing weights are updated after `pred` has been formed, so the
        # emitted prediction uses the weights from before this step.
        model.W = MW.update(model, model.W, preds, y)

        return model, pred

    Y_RESID = Y - Y_LSTM
    # The final carried model is not needed, only the per-step predictions.
    _, Y_BOOST = jax.lax.scan(loop, model, (X, Y_RESID))

    Y_BOOST = jnp.asarray(Y_BOOST).squeeze()
    loss = ((Y - (Y_LSTM + Y_BOOST)) ** 2).mean()

    return loss

# Run!

In [None]:
# Evaluate every basin in turn (with a tqdm progress bar) and print the basin
# identifier next to the loss returned by predict().
for basin in tqdm.tqdm(basins):
    loss = predict(basin)
    print(basin, loss)