# Load data

In [1]:
import pickle

# Load the basin list and the replicated LSTM results.
# Context managers make sure the file handles are closed right after reading.
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# these paths with trusted data.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]
with open("../data/flood/replicated.pkl", "rb") as lstm_file:
    lstm = pickle.load(lstm_file)

In [238]:
import flax
import jax
import jax.numpy as jnp

from timecast.learners import Linear

class BoostedFlood(flax.nn.Module):
    """Stack of zero-initialised ``Linear`` learners applied to one input.

    One ``Linear`` submodule is created per entry of ``W``; every learner
    receives the same ``x`` and starts with zero kernel and zero bias.
    """

    def apply(self, x, W):
        """Return the per-learner outputs for ``x`` as a single array.

        Args:
            x: input passed unchanged to every learner.
            W: only its length is used here; it fixes the number of learners.

        Returns:
            ``jnp.asarray`` of the ``len(W)`` learner outputs, in creation order.
        """
        zeros = flax.nn.initializers.zeros
        n_learners = len(W)

        outputs = []
        for _ in range(n_learners):
            # Each call instantiates a fresh submodule with its own parameters.
            outputs.append(
                Linear(
                    x,
                    input_axes=(0, 1),
                    output_shape=(1,),
                    kernel_init=zeros,
                    bias_init=zeros,
                )
            )

        return jnp.asarray(outputs)

In [239]:
from timecast.optim import ProjectedSGD, MultiplicativeWeights

# Per-learner projection thresholds for kernels (W_*) and biases (b_*).
W_thresholds = [0.03, 0.05, 0.07, 0.09]
b_thresholds = [1e-4, 1e-4, 1e-4, 1e-4]
learning_rate = 2e-5
eta = 0.008  # step size of the multiplicative-weights update

# One kernel threshold and one bias threshold are required per learner.
assert len(W_thresholds) == len(b_thresholds), "threshold lists must have equal length"


def _param_filter(index, kind):
    """Return a traversal filter selecting paths containing str(index) and kind."""
    tag = str(index)
    # NOTE(review): plain substring match, as before; with 10+ learners "1"
    # would also match e.g. "..._10" -- tighten if the ensemble grows.
    return lambda path, _: tag in path and kind in path


# Build (traversal, optimizer) pairs: one for the kernel and one for the bias
# of every learner. The learner count is taken from the threshold lists so
# this cell does not depend on `N`, which is only defined in a later cell.
pairs = []
for i, (W_threshold, b_threshold) in enumerate(zip(W_thresholds, b_thresholds)):
    kernel_traversal = flax.optim.ModelParamTraversal(_param_filter(i, 'kernel'))
    kernel_optimizer = ProjectedSGD(learning_rate=learning_rate, projection_threshold=W_threshold)
    pairs.append((kernel_traversal, kernel_optimizer))

    bias_traversal = flax.optim.ModelParamTraversal(_param_filter(i, 'bias'))
    bias_optimizer = ProjectedSGD(learning_rate=learning_rate, projection_threshold=b_threshold)
    pairs.append((bias_traversal, bias_optimizer))
    
# W_traversal = flax.optim.ModelParamTraversal(lambda path, _: "W" in path)
# W_optimizer = MultiplicativeWeights(eta=eta)
# pairs.append((W_traversal, W_optimizer))

In [277]:
import tqdm.notebook as tqdm
from functools import partial

SEQUENCE_LENGTH = 270
INPUT_DIM = 32
OUTPUT_DIM = 1

# Number of boosted learners (also the length of the weight vector W).
N = 4

# Initialise the model from input shapes: one (SEQUENCE_LENGTH, INPUT_DIM)
# input and a length-N weight vector.
model_def = BoostedFlood.partial()
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(SEQUENCE_LENGTH, INPUT_DIM), (N,)])
model = flax.nn.Model(model_def, params)

# One optimizer per (traversal, optimizer) pair built earlier.
optim_def = flax.optim.MultiOptimizer(*pairs)
optimizer = optim_def.create(model)

# Data for the first basin only.
basin = basins[0]
Y_LSTM = jnp.array(lstm[basin])
# Context managers close the pickle files as soon as they are read.
with open("../data/flood/test/{}.pkl".format(basin), "rb") as x_file:
    X = pickle.load(x_file)
with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
    Y = pickle.load(y_file)

def loss_fun(model, x, W, y):
    """Accumulated squared-error loss over the stacked learners.

    Args:
        model: callable ``model(x, W)`` returning the per-learner predictions.
        x: model input.
        W: per-learner weights, indexed like the predictions.
        y: target; it is never modified by this function.

    Returns:
        ``(loss, y_hats)`` where ``y_hats`` is the raw output of ``model(x, W)``.
        With a single learner the loop does not run and ``loss`` is ``0``.
    """
    y_hats = model(x, W)
    target, y_hat, loss = y, y_hats[0] * W[0], 0
    for i in range(len(y_hats) - 1):
        # Rebind instead of `-=`: an in-place update would write through to the
        # caller's `y` when it is a NumPy array (and fails under JAX tracing).
        target = target - y_hats[i]
        loss = loss + jnp.square(target - y_hat).mean()
        # NOTE(review): on the final iteration this update is never read, so the
        # last learner's prediction does not contribute to the loss -- confirm
        # that this is intended.
        y_hat = y_hat + y_hats[i + 1] * W[i + 1]
    return loss, y_hats

# Residual that the LSTM leaves unexplained; this is the boosting target.
Y_RESID = Y - Y_LSTM
# Container for the boosted predictions (starts out empty).
Y_BOOST = list()

def tscan(carry, xy):
    """Run one online step: predict, update the learners, reweight the ensemble.

    Written in the ``(carry, x) -> (carry, out)`` form so it can be driven by a
    plain Python loop or by ``jax.lax.scan``.

    Args:
        carry: ``(optimizer, W)`` -- the current optimizer state and weights.
        xy: ``(x, y)`` -- the input and its target for this step.

    Returns:
        ``((new_optimizer, new_W), y_hat)`` where ``y_hat`` is the weighted
        prediction ``dot(W, y_hats)`` made with the weights *before* the update.
    """
    x, y = xy
    optimizer, W = carry

    # Gradient of the boosting loss w.r.t. the model; predictions come back as aux.
    (_, y_hats), grad = jax.value_and_grad(loss_fun, has_aux=True)(optimizer.target, x, W, y)
    y_hat = jnp.dot(W, y_hats)

    def ensemble_loss(weights):
        # Squared error of the weighted prediction, as a function of the weights.
        return jnp.square(jnp.dot(weights, y_hats) - y).sum()

    # Multiplicative-weights step on W (step size `eta`), then renormalise to sum to 1.
    nums = W * jnp.exp(-1 * eta * jax.grad(ensemble_loss)(W))
    W = nums / nums.sum()
    return (optimizer.apply_gradient(grad), W), y_hat

# Equivalent scan form (kept for reference):
# (optimizer, W), Y_BOOST = jax.lax.scan(tscan, (optimizer, jnp.ones(N) / N), (X, Y_RESID))

# Start from uniform ensemble weights and step through the series one sample
# at a time, keeping every prediction so Y_BOOST is populated for the cells below.
W = jnp.ones(N) / N
boosted_steps = []
for x, y in zip(X, Y_RESID):
    (optimizer, W), y_hat = tscan((optimizer, W), (x, y))
    boosted_steps.append(y_hat)
Y_BOOST = jnp.asarray(boosted_steps)

Traced<ShapedArray(float32[4]):JaxprTrace(level=1/0)>


In [278]:
# Display the boosted predictions held in Y_BOOST.
Y_BOOST

DeviceArray([[ 0.        ],
             [ 0.04192694],
             [-0.04051012],
             ...,
             [ 0.37542355],
             [ 0.09966948],
             [ 0.22868551]], dtype=float32)

In [279]:
# Mean squared error of the combined (LSTM + boosted) prediction against Y.
jnp.mean(jnp.square(Y - (Y_BOOST + Y_LSTM)))

DeviceArray(1.1566014, dtype=float32)

In [249]:
# Sum of the LSTM predictions for this basin.
jnp.sum(Y_LSTM)

DeviceArray(7102.1475, dtype=float32)

In [250]:
# Sum of the residuals Y - Y_LSTM.
jnp.sum(Y_RESID)

DeviceArray(242.56683, dtype=float32)