In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle

import jax
import jax.numpy as jnp
import timecast as tc
import tqdm

In [3]:
class SGD:
    """Projected stochastic gradient descent for the parameters of one module.

    A gradient step is taken on ``loss_fn(module(x), y)``; afterwards every
    parameter that has an entry in ``project_threshold`` is rescaled so that
    its norm does not exceed that threshold.

    Args:
        loss_fn: Callable ``(pred, true) -> scalar``. Defaults to mean squared
            error.
        learning_rate: Step size applied to the gradient.
        project_threshold: Optional dict mapping a parameter name to the
            maximum norm allowed for that parameter. Parameters without an
            entry are left unprojected. ``None`` (the default) means no
            projection at all.
    """

    def __init__(self,
                 loss_fn=lambda pred, true: jnp.square(pred - true).mean(),
                 learning_rate=0.0001,
                 project_threshold=None):
        self.loss_fn = loss_fn
        self.learning_rate = learning_rate
        # A fresh dict per instance: a `{}` default would be shared by every
        # optimizer created without an explicit threshold mapping.
        self.project_threshold = project_threshold if project_threshold is not None else {}
        # Built once so repeated updates reuse the compiled gradient instead
        # of re-tracing a brand-new lambda on every step.
        self._grad_fn = jax.jit(jax.grad(lambda module, x, y: self.loss_fn(module(x), y)))

    def update(self, module, params, x, y=None):
        """Return the parameters of ``module`` after one projected SGD step.

        Args:
            module: Callable model. It is differentiated with ``jax.grad``, so
                it must be a valid JAX pytree, and the resulting gradient
                object must expose a ``params`` mapping.
            params: Dict of current parameter values keyed by name.
            x: Input passed to ``module``.
            y: Target passed to ``loss_fn``.

        The three-argument form ``update(module, x, y)`` is also accepted; in
        that case the current parameters are read from ``module.params``.

        Returns:
            A new dict with the same keys as ``params`` holding the updated
            (and, where a threshold is configured, projected) values.
        """
        if y is None:
            # Three-argument call: the positional values arrived one slot early.
            params, x, y = module.params, params, x

        grad = self._grad_fn(module, x, y)
        new_params = {k: w - self.learning_rate * grad.params[k] for k, w in params.items()}

        for k in list(new_params):
            threshold = self.project_threshold.get(k)
            if threshold is None:
                continue
            norm = jnp.linalg.norm(new_params[k])
            # Guard the division so a zero-norm parameter never produces inf/nan
            # in the branch that `where` discards.
            safe_norm = jnp.where(norm > 0, norm, 1.0)
            scale = jnp.where(norm > threshold, threshold / safe_norm, 1.0)
            new_params[k] = scale * new_params[k]
        return new_params

class MultiplicativeWeights:
    """Multiplicative-weights update for a vector of mixing weights.

    The weights are scored with the summed squared error of the weighted
    combination ``dot(W, preds)`` against the target.

    Args:
        eta: Step size used inside the exponential re-weighting.
    """

    def __init__(self, eta=0.008):
        def squared_error(weights, preds, target):
            combined = jnp.dot(weights, preds)
            return jnp.sum(jnp.square(combined - target))

        self.eta = eta
        # Compiled gradient of the squared error with respect to the weights.
        self.grad = jax.jit(jax.grad(squared_error))

    def update(self, module, params, x, y):
        """Return the re-weighted, renormalised weights.

        Args:
            module: Unused by this method.
            params: Current weight vector.
            x: Per-learner predictions combined by the weights.
            y: Target value.

        Returns:
            ``params * exp(-eta * grad)`` divided by its sum, so the result
            sums to one.
        """
        gradient = self.grad(params, x, y)
        reweighted = jnp.exp(-self.eta * gradient) * params
        return reweighted / jnp.sum(reweighted)

class AR(tc.Module):
    """Linear autoregressive predictor with zero-initialised parameters.

    Args:
        input_dim: Size of the second axis of the kernel.
        output_dim: Size of the last axis of the kernel and first axis of the bias.
        history_len: Size of the first axis of the kernel.
    """

    def __init__(self, input_dim=32, output_dim=1, history_len=270):
        kernel_shape = (history_len, input_dim, output_dim)
        bias_shape = (output_dim, 1)
        self.kernel = jnp.zeros(kernel_shape)
        self.bias = jnp.zeros(bias_shape)

    def __call__(self, x):
        """Contract ``x`` with the kernel over its first two axes and add the bias.

        The contraction pairs kernel axes (0, 1) with ``x`` axes (0, 1), so
        ``x`` must lead with ``(history_len, input_dim)``. The current
        parameters and the input are printed on every call.
        """
        print("*** ENTER AR PRED ***")
        traced = (("self.kernel", self.kernel), ("self.bias", self.bias), ("x", x))
        for label, value in traced:
            print(label + " = " + str(value))
        contracted = jnp.tensordot(self.kernel, x, axes=((0, 1), (0, 1)))
        return contracted + self.bias

class GradientBoosting(tc.Module):
    """Ensemble of ``N`` AR learners plus a vector of mixing weights ``W``.

    Args:
        N: Number of AR submodules to register.
        input_dim: Forwarded to every AR submodule.
        output_dim: Forwarded to every AR submodule.
        history_len: Forwarded to every AR submodule.
    """

    def __init__(self, N, input_dim=32, output_dim=1, history_len=270):
        for _ in range(N):
            self.add_module(AR(input_dim=input_dim, output_dim=output_dim, history_len=history_len))
        # Start from uniform mixing weights.
        self.W = jnp.ones(N) / N

    def __call__(self, x):
        """Return the list of squeezed predictions, one per submodule.

        The predictions are returned individually; they are not combined with
        ``self.W`` here. Inputs, submodule parameters and each prediction are
        printed as they are computed.
        """
        preds = []
        print("----- ENTER GRADIENT_BOOSTING PRED -----")
        print("x = " + str(x))
        for _name, submodule in self.modules.items():
            print("submodule.params = " + str(submodule.params))
            print("submodule.kernel = " + str(submodule.kernel))
            pred_i = submodule(x).squeeze()
            print("pred_i = " + str(pred_i))
            preds.append(pred_i)

        return preds

In [38]:
def ND_loss(Yhats, Y):
    """Return the normalised deviation of ``Yhats`` from ``Y`` as a Python float.

    This is the summed absolute error divided by the summed absolute value of
    ``Y``. The two arguments are not interchangeable: only ``Y`` appears in the
    denominator.
    """
    total_abs_error = jnp.sum(jnp.abs(Y - Yhats))
    total_abs_target = jnp.sum(jnp.abs(Y))
    return float(total_abs_error / total_abs_target)

def predict_nips(SGDs, MW):
    """Run a short boosting pass on synthetic data and return its ND loss.

    A ``GradientBoosting`` model with one AR learner per optimizer is trained
    online on constant all-ones inputs. Its targets are the residuals between
    the all-ones ground truth and a fixed base forecast of 1.001. Only the
    first ``Z`` samples are used. Intermediate values are printed throughout.

    Args:
        SGDs: Sequence of per-learner optimizers; ``SGDs[i]`` updates the
            i-th submodule.
        MW: Optimizer whose ``update`` produces the new mixing weights.

    Returns:
        The normalised deviation of ``Yhats + boosted residual`` from ``Y``
        over the first ``Z`` samples, as a float.
    """
    N = len(SGDs)
    model = GradientBoosting(N, 1, 1, 7)
    Yhats = 1.001*jnp.ones((20,1))
    print("Yhats.shape = " + str(Yhats.shape))
    X = jnp.ones((20,7,1))
    print("X.shape = " + str(X.shape))
    Y = jnp.ones((20,1))
    print("Y.shape = " + str(Y.shape))

    def loop(model, xy):
        """Process one (x, y) sample: update every learner, then the weights."""
        x, y = xy
        # Predictions are taken before any parameter update for this sample.
        preds = jnp.asarray(model(x))
        pred = 0
        print("================= ENTERED LOOP ===================")
        print("x = " + str(x))
        print("y = " + str(y))
        print("preds = " + str(preds))
        print("model.W = " + str(model.W))
        for i, (_name, module) in enumerate(model.modules.items()):
            print("module.params BEFORE = " + str(module.params))
            print("module.kernel BEFORE = " + str(module.kernel))
            # Each learner is fit to what the earlier learners left unexplained.
            module.update_params(SGDs[i].update(module, module.params, x, y - pred))
            print("module.params AFTER = " + str(module.params))
            print("module.kernel AFTER = " + str(module.kernel))
            pred += model.W[i] * preds[i]
            print("pred = " + str(pred))

        model.W = MW.update(model, model.W, preds, y)
        return model, pred

    Y_RESID = Y - Yhats
    Y_RESID = jnp.expand_dims(Y_RESID, -1)
    Z = 2  # number of leading samples used for this run
    print("Y_RESID.shape = " + str(Y_RESID.shape))
    print("Y_RESID[:Z] = " + str(Y_RESID[:Z]))
    print("X[:Z] = " + str(X[:Z]))

    Y_BOOST = []
    for x, y in zip(X[:Z], Y_RESID[:Z]):
        model, y_boost = loop(model, (x, y))
        Y_BOOST.append(y_boost)

    Y_BOOST = jnp.expand_dims(jnp.array(Y_BOOST), -1)
    print("Y_BOOST.shape = " + str(Y_BOOST.shape))

    # ND_loss takes (predictions, targets) and normalises by the targets, so
    # the corrected forecast goes first and the ground truth second.
    loss = ND_loss(Yhats[:Z] + Y_BOOST, Y[:Z])
    print("loss = " + str(loss))
    print("Y_BOOST[:Z] = " + str(Y_BOOST[:Z]))

    return loss

def run_train_loop_nips(dataset_name, learning_rate):
    """Build the optimizers, run ``predict_nips`` and print the resulting loss.

    Args:
        dataset_name: Accepted for the caller's convenience; not read here.
        learning_rate: Step size given to the single SGD optimizer.
    """
    bias_threshold = 1e-4
    eta = 0.008
    # One (kernel_threshold, learning_rate) pair per boosted learner.
    learner_settings = [(0.03, learning_rate)]

    MW = MultiplicativeWeights(eta=eta)

    SGDs = []
    for kernel_threshold, lr in learner_settings:
        thresholds = {"kernel": kernel_threshold, "bias": bias_threshold}
        SGDs.append(SGD(learning_rate=lr, project_threshold=thresholds))

    loss = predict_nips(SGDs, MW)
    print("loss = " + str(loss))

In [39]:
# Run the boosting harness once with a learning rate of 2e-5.
# NOTE(review): run_train_loop_nips does not read its dataset-name argument and
# predict_nips builds its own all-ones data, so the name here is only a label.
run_train_loop_nips("electricity_nbeats_last7days", 2e-5)

Yhats.shape = (20, 1)
X.shape = (20, 7, 1)
Y.shape = (20, 1)
Y_RESID.shape = (20, 1, 1)
Y_RESID[:Z] = [[[-0.00100005]]

 [[-0.00100005]]]
X[:Z] = [[[1.]
  [1.]
  [1.]
  [1.]
  [1.]
  [1.]
  [1.]]

 [[1.]
  [1.]
  [1.]
  [1.]
  [1.]
  [1.]
  [1.]]]
----- ENTER GRADIENT_BOOSTING PRED -----
x = [[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
submodule.params = {'kernel': DeviceArray([[[0.]],

             [[0.]],

             [[0.]],

             [[0.]],

             [[0.]],

             [[0.]],

             [[0.]]], dtype=float32), 'bias': DeviceArray([[0.]], dtype=float32)}
submodule.kernel = [[[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]]
*** ENTER AR PRED ***
self.kernel = [[[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]

 [[0.]]]
self.bias = [[0.]]
x = [[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
pred_i = 0.0
x = [[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
y = [[-0.00100005]]
preds = [0.]
model.W = [1.]
module.params BEFORE = {'kernel': DeviceArray([[[0.]],

         