In [20]:
import os
import pickle

import jax
import jax.numpy as np
import tqdm

from tigerforecast.utils.optimizers.losses import batched_mse, mse

In [2]:
# https://ttic.uchicago.edu/~tewari/lectures/lecture4.pdf

In [3]:
# Load the list of basin identifiers stored under the "basins" key of the metadata pickle.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
with open("../data/flood/meta.pkl", "rb") as meta_file:
    basins = pickle.load(meta_file)["basins"]

In [16]:
from tigerforecast.utils.download_tools import get_tigerforecast_dir
# Load the per-basin LSTM predictions (indexed by basin id further down in this notebook).
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
with open(os.path.join(get_tigerforecast_dir(), "flood_prediction", "basin_to_yhats_LSTM"), "rb") as yhats_file:
    basin_to_yhats_LSTM = pickle.load(yhats_file)

In [5]:
# from tigerforecast.batch.camels_dataloader import CamelsTXT
# for basin in tqdm.tqdm(basins):
#     usgs_val = CamelsTXT(basin=basin, concat_static=True)
#     for data, targets in usgs_val.sequential_batches(batch_size=5000):
#         pickle.dump(data, open("../data/flood/test/{}.pkl".format(basin), "wb"))
#         pickle.dump(targets, open("../data/flood/qobs/{}.pkl".format(basin), "wb"))

In [25]:
'''
SGD optimizer
'''
from tigerforecast.utils.optimizers.core import Optimizer

class SGD(Optimizer):
    """
    Description: Stochastic Gradient Descent optimizer with optional per-parameter
                 gradient-norm clipping, exposing a functional update_scan step.
    Args:
        pred (function): a prediction function implemented with jax.numpy
        loss (function): specifies loss function to be used; defaults to MSE
        learning_rate (float): learning rate
        include_x_loss (bool): stored as self.include_x_loss (not read in this class)
        hyperparameters (dict): stored as self.hyperparameters; None means a fresh empty dict
        clip_grad (bool): if True, update_scan rescales every gradient whose norm
                          exceeds its threshold down to that threshold
        clip_threshold (dict): parameter name -> maximum gradient norm; must hold an
                               entry for every parameter key when clip_grad is True
    """
    def __init__(self, pred=None, loss=mse, learning_rate=0.0001, include_x_loss=False, hyperparameters=None, clip_grad=False, clip_threshold=None):
        self.initialized = False
        self.lr = learning_rate
        self.include_x_loss = include_x_loss
        # None sentinels: a literal {} default would be one dict shared by every instance
        self.hyperparameters = {} if hyperparameters is None else hyperparameters
        self.pred = pred
        self.loss = loss
        if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(loss, raise_error=False):
            self.set_predict(pred, loss=loss)
        self.t = 0
        self.grad_norm_sum = {}
        self.clip_grad = clip_grad
        self.clip_threshold = {} if clip_threshold is None else clip_threshold

    def update_scan(self, params, metadata, x, y, loss=None):
        """
        Description: One functional SGD step: params - lr * grad, with optional clipping.
        Args:
            params (dict): parameter name -> array
            metadata (dict): optimizer metadata; returned unchanged
            x: input forwarded to self.gradient
            y: target forwarded to self.gradient
            loss (function): optional loss override forwarded to self.gradient
        Returns:
            (new_params, metadata)
        """
        assert self.initialized
        assert isinstance(params, dict)
        grad = self.gradient(params, x, y, loss=loss)
        if self.clip_grad:
            for k in list(grad):
                norm = np.linalg.norm(grad[k])
                threshold = self.clip_threshold[k]
                # Scale factor is threshold/norm when the norm is too large and exactly 1.0
                # otherwise. np.where keeps this traceable under jit/scan without relying on
                # the legacy operand-interleaved jax.lax.cond call form.
                scale = np.where(norm > threshold, threshold / norm, 1.0)
                grad[k] = scale * grad[k]

        new_params = {k: w - self.lr * grad[k] for (k, w) in params.items()}
        return new_params, metadata

    def initialize_metadata(self, params):
        """
        Description: Create self.metadata with a step counter 't' and a zeroed
                     'grad_norm_sum' entry for every key in params.
        """
        self.metadata = {'t': 0, 'grad_norm_sum': {k: 0.0 for k in params}}

In [26]:
"""
Stateless AR method
"""
import jax.experimental.stax as stax
import tigerforecast
from tigerforecast.utils.random import generate_key
from tigerforecast.methods import Method


class ARStateless_scan(Method):
    """
    Description: Stateless linear autoregressive predictor over a fixed window of l
                 inputs. Parameters and metadata are passed to and returned from
                 predict_scan / update_scan explicitly, so the method can be driven
                 from inside jax.lax.scan.
    """

    compatibles = set(['TimeSeries'])

    def __init__(self, n=1, m=1, l=32, optimizer=None, activation=None, project_threshold=None, scan_mode=True):
        """
        Description: Initialize the Stateless AR with all-zero weights and bias.
        Args:
            n (int): Input dimension per time step.
            m (int): Observation/output dimension.
            l (int): Window length, i.e. number of time steps in one input.
            optimizer (instance of Optimizer Class): optimizer choice; None falls back
                to OGD(loss=batched_mse)
            activation: optional activation function to compose with predict
            project_threshold (dict): parameter name -> maximum norm. After each update
                a parameter whose norm exceeds its threshold is rescaled onto it.
                None, an empty dict, or a missing key means no projection.
            scan_mode (bool): if False, a user-supplied optimizer is given the vmapped
                predict function instead of the single-example one
        """
        self.T = 0
        self.initialized = True
        self.n, self.m, self.l = n, m, l

        # initialize parameters: zero weights (l, n, m) and zero bias (m, 1)
        W_lnm = np.zeros((l, n, m))
        b = np.zeros((m, 1))
        self.params = {'W_lnm': W_lnm, 'b': b}
        self.metadata = {'x': np.zeros((self.l, self.n))}
        self.activation = activation
        self.project_threshold = project_threshold
        self.params_norm_sum = {'W_lnm': 0.0, 'b': 0.0}
        self.t = 0

        # private helper methods

        @jax.jit
        def _predict(params, x):
            # contract the (l, n) input against the first two axes of W_lnm, then add bias
            y = np.tensordot(params['W_lnm'], x, ([0, 1], [0, 1])) + params['b']
            if self.activation:
                y = self.activation(y)
            return y

        self.transform = lambda x: float(x) if (self.m == 1) else x
        self._predict = _predict
        self._predict_vmap = jax.vmap(_predict, in_axes=(None, 0))
        if optimizer is None:
            # Imported here so this class does not depend on a later star-import
            # having already put OGD into the notebook namespace.
            from tigerforecast.utils.optimizers import OGD
            optimizer_instance = OGD(loss=batched_mse)
            self._store_optimizer(optimizer_instance, self._predict)
        else:
            if scan_mode == False:
                self._store_optimizer(optimizer, self._predict_vmap)
            else:
                self._store_optimizer(optimizer, self._predict)

        self.optimizer.initialize_metadata(self.params) # initialize state variables in optimizer to proper jax type

        @jax.jit
        def _predict_and_update(carry, xy):
            # scan body: predict with the current params, then take one optimizer step
            params, metadata_opt, cnt = carry
            pred = self._predict(params, xy[0])
            metadata_opt['t'] = cnt
            next_params, next_metadata_opt = self.optimizer.update_scan(params, metadata_opt, xy[0], xy[1])
            next_params = self._project(next_params)
            return (next_params, next_metadata_opt, cnt+1), pred

        self._predict_and_update = _predict_and_update

    def _project(self, params):
        """
        Description: Rescale, in place, every parameter whose norm exceeds its entry in
                     self.project_threshold so that its norm equals that threshold.
                     Parameters without a configured threshold are left untouched.
        Args:
            params (dict): parameter name -> array
        Returns:
            The same params dict.
        """
        thresholds = self.project_threshold
        if not thresholds:
            return params
        for k in list(params):
            if k not in thresholds:
                continue
            norm = np.linalg.norm(params[k])
            # threshold/norm when outside the ball, exactly 1.0 otherwise; np.where keeps
            # this traceable under jit/scan without the legacy jax.lax.cond call form
            scale = np.where(norm > thresholds[k], thresholds[k] / norm, 1.0)
            params[k] = scale * params[k]
        return params

    def predict_scan(self, x, params, metadata):
        """
        Description: Predict from an explicit parameter dict.
        Args:
            x: input window; asserted to have shape[0] == l and shape[1] == n
            params (dict): parameters with keys 'W_lnm' and 'b'
            metadata: accepted for interface symmetry; not read here
        Returns:
            Output of the jitted predict function.
        """
        assert self.initialized
        assert(x.shape[0] == self.l)
        assert(x.shape[1] == self.n)
        return self._predict(params, x)

    def update_scan(self, params, metadata, metadata_opt, y):
        """
        Description: Take one optimizer step on the input stored in metadata['x'],
                     then project the new parameters.
        Args:
            params (dict): current parameters
            metadata (dict): method metadata; must contain the last input under 'x'
            metadata_opt (dict): optimizer metadata
            y: target handed to the optimizer's update_scan
        Returns:
            (next_params, metadata, next_metadata_opt)
        """
        assert self.initialized
        ####TODO:Catch the error better
        next_params, next_metadata_opt = self.optimizer.update_scan(params, metadata_opt, metadata['x'], y)
        next_params = self._project(next_params)
        return next_params, metadata, next_metadata_opt

In [27]:
import jax
import jax.numpy as np
import tigerforecast
from tigerforecast.utils.random import generate_key
from tigerforecast.methods import Method
from tigerforecast.utils.optimizers import *
from tigerforecast.utils.optimizers.losses import *

class Gradient_boosting(Method):
    """
    Description: Boosting over a list of weak learners. The prediction is the weighted
                 sum of the learners' outputs; the mixing weights W are updated with a
                 normalized exponentiated-gradient step, and each learner is updated
                 against the partial sum of the learners before it.
    """
    compatibles = set(['TimeSeries'])

    def __init__(self, method_id, X_shape=(), Y_shape=(), loss=mse, eta=1.0, proxy_loss='original', W_update_rule='uniform', T=100):
        """
        Description: Initializes the boosting method and re-targets each learner's
                     optimizer to the proxy loss.
        Args:
            method_id: list of instances of methods; each must expose params, metadata,
                       optimizer, predict_scan and update_scan
            X_shape: shape of input (exclude batch dim)
            Y_shape: shape of output (exclude batch dim)
            loss(function): loss function for boosting method
            eta (float): step size of the exponentiated-gradient update of W
            proxy_loss: accepted but not read; the 'original' proxy loss is always used
            W_update_rule: stored as self.W_update_rule but not read by the update
            T: time horizon; accepted but not read
        """
        self.X_shape = X_shape
        self.Y_shape = Y_shape
        self.N = len(method_id)
        self.eta = eta
        self.W_update_rule = W_update_rule
        self.methods = []
        self.all_params = [None for i in range(self.N)]  # params of each method
        self.all_metadata = [None for i in range(self.N)] # metadata of each method
        self.all_metadata_opt = [None for i in range(self.N)] # metadata of optimizer of each method
        self.ys = []
        self.Z = np.array([1.0/self.N for i in range(self.N)])
        self.W = [1.0/self.N for i in range(self.N)]
        self.preds = [0.0 for i in range(self.N)]

        # proxy loss options
        def original_loss(y_pred, yprev_ytrue):
            # the target handed to a learner is the pair (partial sum so far, true y);
            # the learner is charged for the loss of partial sum + its own prediction
            y_prev, y_true = yprev_ytrue
            y_true = np.reshape(y_true, self.Y_shape)
            y_pred = np.reshape(y_pred, self.Y_shape)
            return loss(y_prev + y_pred, y_true)

        self._proxy_loss = original_loss

        for i in range(self.N):
            new_method = method_id[i]
            new_method.optimizer.set_loss(self._proxy_loss)
            self.methods.append(new_method)
            self.all_params[i] = self.methods[i].params
            self.all_metadata[i] = self.methods[i].metadata
            self.all_metadata_opt[i] = self.methods[i].optimizer.metadata

        def loss_W(W, preds, y):
            # loss of the W-weighted sum of the learners' predictions
            pred = 0
            for j in range(self.N):
                pred += W[j] * preds[j]
            return loss(pred, y)
        self.loss_W = jax.jit(loss_W)

        def _predict_scan(x, W, all_params, all_metadata):
            # ys[i] is the partial weighted sum over learners 0..i-1 (ys[0] is zero);
            # ys[-1] is the full boosted prediction
            ys = [np.zeros(self.Y_shape)]
            method_preds = []
            for i in range(self.N):
                pred_i = np.reshape(self.methods[i].predict_scan(x, all_params[i], all_metadata[i]), self.Y_shape)
                method_preds.append(pred_i)
                y_i = ys[-1] + W[i] * pred_i
                y_i = np.reshape(y_i, self.Y_shape)
                ys.append(y_i)
            return ys, method_preds

        self._predict_scan = jax.jit(_predict_scan)

        def _predict_and_update(carry, xy):
            # scan body: predict, update W, then update every learner
            x, y = xy[0], xy[1]
            Z, W, all_params, all_metadata, all_metadata_opt, cnt = carry
            ys, method_yhats = self._predict_scan(x, W, all_params, all_metadata)
            pred = ys[-1]

            # The gradient w.r.t. W does not depend on the learner index, so evaluate it
            # once per step instead of once per learner.
            grad_W = jax.grad(self.loss_W)(W, method_yhats, y)
            nums = [W[j] * np.exp(-1 * self.eta * grad_W[j]) for j in range(self.N)]
            den = np.sum(np.array(nums))
            W = [nums[i]/den for i in range(self.N)]

            for i in range(self.N):
                # update method's self.x with proper shape (often needed for update step)
                all_metadata[i]['x'] = np.reshape(x, all_metadata[i]['x'].shape)
                all_metadata_opt[i]['t'] = cnt
                # learner i is trained against (partial sum before it, true y)
                all_params[i], all_metadata[i], all_metadata_opt[i] = self.methods[i].update_scan(all_params[i],
                                                                                                  all_metadata[i],
                                                                                                  all_metadata_opt[i],
                                                                                                  (ys[i], y))
            # Z is carried through unchanged
            return (Z, W, all_params, all_metadata, all_metadata_opt, cnt+1), pred
        self._predict_and_update = _predict_and_update

    def predict_and_update(self, X, Y):
        """
        Description: Run predict-then-update over the whole sequence with jax.lax.scan,
                     starting from the state stored on this instance.
        Args:
            X: inputs; X.shape[1:] must equal X_shape
            Y: targets; Y.shape[1:] must equal Y_shape
        Returns:
            (preds, W): predictions reshaped to Y.shape and the final mixing weights as
            an array. The final scan carry is not written back to self.
        """
        assert(X.shape[1:] == self.X_shape)
        assert(Y.shape[1:] == self.Y_shape)
        carry, preds = jax.lax.scan(self._predict_and_update,
                                (self.Z, self.W, self.all_params, self.all_metadata, self.all_metadata_opt, 0),
                                (X, Y)
                               )
        preds = np.reshape(preds, Y.shape)
        W = carry[1]
        return preds, np.array(W)

In [28]:
# Hyperparameters shared by every basin (loop-invariant, so defined once up front).
SEQUENCE_LENGTH = 270
INPUT_DIM = 32

b_threshold = 1e-4
eta = 0.008

# (projection threshold for W_lnm, SGD learning rate): one AR weak learner per pair
W_lr_best_pairs = [
        (0.03, 2e-5),
        (0.05, 2e-5),
        (0.07, 2e-5),
        (0.09, 2e-5),
    ]

for basin in tqdm.tqdm(basins):
    # Build fresh weak learners for each basin so no state carries over between basins.
    method_ids = []

    for W_threshold, lr in W_lr_best_pairs:
        project_threshold = {"W_lnm": W_threshold, "b": b_threshold}
        optim_ar = SGD(loss=batched_mse, learning_rate=lr, clip_grad=False)
        method_ar = ARStateless_scan(
            n=INPUT_DIM,
            m=1,
            l=SEQUENCE_LENGTH,
            optimizer=optim_ar,
            project_threshold=project_threshold,
            scan_mode=True,
        )
        method_ids.append(method_ar)

    method_boosting = Gradient_boosting(
        method_ids,
        X_shape=(SEQUENCE_LENGTH, INPUT_DIM),
        Y_shape=(),
        loss=batched_mse,
        eta=eta,
        proxy_loss="original",
        W_update_rule="GECO",
    )

    yhats_LSTM = np.array(basin_to_yhats_LSTM[basin])
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open("../data/flood/test/{}.pkl".format(basin), "rb") as x_file:
        X = pickle.load(x_file)
    with open("../data/flood/qobs/{}.pkl".format(basin), "rb") as y_file:
        Y = pickle.load(y_file)

    # The boosted AR model is fit online to the residual of the LSTM prediction,
    # and its output is added back onto the LSTM prediction.
    y_true = Y - yhats_LSTM
    y_pred_ar, W = method_boosting.predict_and_update(X, y_true)
    yhats = yhats_LSTM + y_pred_ar.squeeze()

#     W_entropy = float(-1 * np.sum(W * np.log2(W)))

    # mean squared error of the combined prediction on this basin
    loss = ((Y - yhats) ** 2).mean()
#     ys_mean = Y.mean()

#     nse = 1 - ((Y - yhats) ** 2).sum() / ((Y - ys_mean) ** 2).sum()
#     loss, nse = float(loss), float(nse)

    print(basin, loss)


  0%|          | 0/531 [00:00<?, ?it/s]

01022500 

  0%|          | 1/531 [00:02<20:06,  2.28s/it]

0.48434398
01031500 

  0%|          | 1/531 [00:04<40:23,  4.57s/it]


KeyboardInterrupt: 