In [1]:
import logging
import tensorflow as tf
import math
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score

import matplotlib.pyplot as plt
import time
import numpy as np
import pickle
import itertools as it
from tensorflow.keras import backend as K, initializers, regularizers, constraints
from tensorflow.keras.layers import Layer
import tensorflow_addons as tfa
from sklearn.metrics import mean_absolute_error
from functools import partial
import itertools
import random
import os
import absl.logging
import json
from tqdm.auto import tqdm
import logging

# Silence Python `logging` records at WARNING level and below.
logging.disable(logging.WARNING)
# NOTE(review): tensorflow has already been imported at this point; TF_CPP_MIN_LOG_LEVEL is
# usually expected to be set before that import to affect native logging - confirm it has
# any effect here.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

absl.logging.set_verbosity(absl.logging.ERROR)
# Report the TensorFlow version and the GPUs it can see.
print(tf.__version__)
print(tf.config.list_physical_devices('GPU'))
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so assigning it here
# presumably only affects child processes - confirm. No random / numpy / tensorflow seed is
# set in this cell; seeding may happen elsewhere.
os.environ['PYTHONHASHSEED']=str(123)
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

2.6.0
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
def keras_rmse(pred_index, y_true, y_pred):
    """Root-mean-squared error between ``y_pred`` and ``y_true`` over the last axis.

    ``pred_index`` is accepted for signature compatibility but is not used by the
    computation.
    """
    backend = tf.keras.backend
    squared_difference = backend.square(y_pred - y_true)
    mean_squared = backend.mean(squared_difference, axis=-1)
    return backend.sqrt(mean_squared)


# Small constant added to the denominators of the percentage / relative error helpers
# so that a zero denominator does not produce a division by zero.
EPSILON = 1e-10


def _error(actual: np.ndarray, predicted: np.ndarray):
    """ Element-wise residual: actual minus predicted """
    residual = actual - predicted
    return residual


def _percentage_error(actual: np.ndarray, predicted: np.ndarray):
    """
    Error expressed as a fraction of the actual value.

    EPSILON is added to the actual values before dividing.

    Note: result is NOT multiplied by 100
    """
    numerator = _error(actual, predicted)
    denominator = actual + EPSILON
    return numerator / denominator


def _relative_error(actual: np.ndarray, predicted: np.ndarray, benchmark: np.ndarray = None):
    """
    Relative Error: the forecast error divided by the error of a benchmark forecast.

    Args:
        actual: observed values.
        predicted: forecast values, aligned with ``actual``.
        benchmark: one of
            - an array of benchmark predictions aligned with ``actual``;
            - an ``int`` seasonal lag: the benchmark is the naive forecast that
              repeats the value observed ``benchmark`` samples earlier;
            - ``None``: naive forecast with a lag of 1.

    Returns:
        Array of relative errors. With a naive benchmark the first ``lag`` samples
        are dropped because they have no naive forecast.

    Raises:
        ValueError: if the seasonal lag is smaller than 1.
    """
    if benchmark is None or isinstance(benchmark, int):
        # The lag is used as a slice index below, so it has to be an integer.
        seasonality = 1 if benchmark is None else benchmark
        if seasonality < 1:
            raise ValueError('seasonality must be >= 1, got {0}'.format(seasonality))
        # Naive forecast: the value observed `seasonality` samples earlier.
        naive_forecast = actual[:-seasonality]
        return _error(actual[seasonality:], predicted[seasonality:]) /\
               (_error(actual[seasonality:], naive_forecast) + EPSILON)

    return _error(actual, predicted) / (_error(actual, benchmark) + EPSILON)


def mse(actual: np.ndarray, predicted: np.ndarray):
    """ Mean of the squared residuals between actual and predicted """
    squared_residuals = np.square(_error(actual, predicted))
    return np.mean(squared_residuals)


def rmse(actual: np.ndarray, predicted: np.ndarray):
    """
    Root Mean Squared Error

    Args:
        actual: observed values.
        predicted: forecast values, aligned with ``actual``.

    Returns:
        Square root of ``mse(actual, predicted)``.
    """
    return np.sqrt(mse(actual, predicted))


def mae(actual: np.ndarray, predicted: np.ndarray):
    """ Mean of the absolute residuals between actual and predicted """
    absolute_residuals = np.abs(_error(actual, predicted))
    return np.mean(absolute_residuals)


def mape(actual: np.ndarray, predicted: np.ndarray):
    """
    Mean Absolute Percentage Error

    Averages the absolute value of the per-sample percentage error.

    Properties:
        + Easy to interpret
        + Scale independent
        - Biased, not symmetric
        - Undefined when actual[t] == 0

    Note: result is NOT multiplied by 100
    """
    absolute_fractions = np.abs(_percentage_error(actual, predicted))
    return np.mean(absolute_fractions)





# Registry mapping a metric name to the function that computes it; read by `evaluate`.
METRICS = dict(
    mse=mse,
    rmse=rmse,
    mae=mae,
    mape=mape,
)


def evaluate(actual: np.ndarray, predicted: np.ndarray, metrics=('mae', 'rmse')):
    """Compute the requested metrics and return them as a ``{name: value}`` dict.

    Every name in ``metrics`` is looked up in ``METRICS`` and called with
    ``(actual, predicted)``. Any exception raised by the lookup or the computation
    is caught: a message is printed and NaN is stored for that metric.
    """
    scores = {}
    for metric_name in metrics:
        try:
            metric_fn = METRICS[metric_name]
            scores[metric_name] = metric_fn(actual, predicted)
        except Exception as exc:
            scores[metric_name] = np.nan
            print('Unable to compute metric {0}: {1}'.format(metric_name, exc))
    return scores


# Register the name "rmse" in the Keras custom-object registry.
# NOTE(review): the object registered here is the NumPy-based `rmse` defined in this cell,
# not `keras_rmse` - confirm that Keras is never expected to call it on tensors.
tf.keras.utils.get_custom_objects().update({"rmse": rmse})

In [3]:
def get_model(model_type, window, horizon, d_model, n_features, batch_size, lamb1, lamb2, grad_smooth_alpha, mode, overall):
    if model_type == 'encdecluong':
        inp = tf.keras.layers.Input(shape=(window, n_features))
        embedding = tf.keras.layers.Conv1D(d_model, n_features, activation='relu', input_shape=(window, 3), padding='causal')(inp)

        out = tf.keras.layers.LSTM(d_model, activation='relu', return_sequences=True, return_state=True)(embedding)
        out = tf.keras.layers.Attention()(out)
        out = tf.keras.layers.LSTM(d_model, activation='relu')(out)
        out = tf.keras.layers.Dense(horizon)(out)

        model = tf.keras.Model(inputs=inp, outputs=out)

        loss = 'mse'

    elif model_type == 'lstm':
        inp = tf.keras.layers.Input(shape=(window, n_features))
        embedding = tf.keras.layers.Conv1D(d_model//2, n_features, activation='relu', input_shape=(window, 3), padding='causal')(inp)
        embedding = tf.keras.layers.Conv1D(d_model, n_features, activation='relu', padding='causal')(embedding)
        out = tf.keras.layers.LSTM(d_model, activation='relu', return_sequences=True)(embedding)
        out = tf.keras.layers.LSTM(d_model, activation='relu')(out)
        out = tf.keras.layers.Dense(horizon)(out)

        model = tf.keras.Model(inputs=inp, outputs=out)

        loss = 'mse'

    elif model_type == 'lstm_1layer':
        inp = tf.keras.layers.Input(shape=(window, n_features))
        embedding = tf.keras.layers.Conv1D(d_model//2, n_features, activation='relu', input_shape=(window, 3), padding='causal')(inp)
        embedding = tf.keras.layers.Conv1D(d_model, n_features, activation='relu', padding='causal')(embedding)
        out = tf.keras.layers.LSTM(d_model, activation='relu')(embedding)
        out = tf.keras.layers.Dense(horizon)(out)

        model = tf.keras.Model(inputs=inp, outputs=out)

        loss = 'mse'

    elif model_type == 'hlnet':
        pred_index = -1
        inputs = tf.keras.layers.Input(shape=(window, 3))
        group_factors = [6, 1]
        group_steps = [1, 1]
        dense_layers = []
        outputs = []
        recurrent_units = [d_model]
        recurrent_dropout = 0
        return_sequences = False
        prev_hidden = None

        for gf, ge in zip(group_factors, group_steps):
            return_sequences_tmp = return_sequences if len(recurrent_units) == 1 else True
            inputs_gr = tf.math.reduce_mean(tf.signal.frame(inputs, gf, ge, axis=1), axis=2)
            x = tf.keras.layers.Conv1D(recurrent_units[0], 3, activation='relu', input_shape=(window, 3), padding='causal')(inputs_gr)

            x = tf.keras.layers.LSTM(
                recurrent_units[0],
                return_sequences=return_sequences_tmp,
            )(x)

            for i, u in enumerate(recurrent_units[1:]):
                return_sequences_tmp = (
                    return_sequences if i == len(recurrent_units) - 2 else True
                )
                x = tf.keras.layers.LSTM(
                    u, return_sequences=return_sequences_tmp, dropout=recurrent_dropout
                )(x)

            # Dense layers
            if return_sequences:
                x = tf.keras.layers.Flatten()(x)

            if prev_hidden is not None:
                x = tf.keras.layers.Concatenate(axis=1)([x, prev_hidden])
            else:
                x_hidden = tf.identity(x)
                prev_hidden = tf.stop_gradient(x_hidden)

            for hidden_units in dense_layers:
                x = tf.keras.layers.Dense(hidden_units)(x)
                if dense_dropout > 0:
                    x = tf.keras.layers.Dropout(dense_dropout)(dense_dropout)

            layer_out = tf.keras.layers.Dense(horizon - gf + 1, name=f'level_{gf}_out')(x)
            outputs.append(layer_out)

        model = tf.keras.Model(inputs=inputs, outputs=outputs)

        loss = {
            f"level_{gf}_out": hierarchical_loss(base_criterion='mse', gf=gf, ge=ge, reduction=tf.keras.losses.Reduction.SUM)
            for i, (gf, ge) in enumerate(zip(group_factors, group_steps))
        }

    
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):

            def __init__(self, d_model, n_features, window, horizon, batch_size):
                """Create every layer used by the multi-task network.

                Args:
                    d_model: number of filters / units of the convolutional and LSTM layers
                        (the first convolution and the second decoder layers use ``d_model//2``).
                    n_features: output width of the two reconstruction heads.
                    window: stored as ``self.window``; also used in the ``input_shape`` hint.
                    horizon: width of the final prediction head; also sizes the smoothed
                        forecasting heads and selects the mid / last step heads.
                    batch_size: stored as ``self.batch_size``; not otherwise used here.
                """
                super(MultitaskHLNet, self).__init__()
                self.d_model = d_model
                self.n_features = n_features
                self.window = window
                self.horizon = horizon
                self.batch_size = batch_size

                # Shared convolutional encoder applied to the raw window.
                # NOTE(review): input_shape hard-codes 3 features instead of n_features - confirm.
                self.global_embedding_1 = tf.keras.layers.Conv1D(d_model//2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = tf.keras.layers.Conv1D(d_model, 3, activation='relu', padding='causal')
                self.first_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                """
                    First Level tasks - Sequence featurization
                """
                # Transposed-convolution decoder and output head of the gap-filling task.
                self.gap_embedding_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.gap_embedding_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gap_filling_output = tf.keras.layers.Dense(n_features)
                #self.gap_filling_lstm_layer = tf.keras.layers.LSTM(d_model, name='gap_filling_lstm', activation='relu')
                #self.gap_filling_reshape = tf.keras.layers.Reshape((-1, self.window, self.n_features), name='gap_filling_task')

                # NOTE(review): noise_reduction_task decodes with gap_embedding_1/2, so these two
                # noise_embedding layers appear to be unused - confirm.
                self.noise_embedding_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.noise_embedding_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
                #self.noise_reduction_lstm_layer = tf.keras.layers.LSTM(d_model, name='noise_reduction_lstm', activation='relu')
                self.noise_reduction_output = tf.keras.layers.Dense(n_features)
                #self.noise_reduction_reshape = tf.keras.layers.Reshape((-1, window, n_features), name='noise_reduction_task')

                # Every ordering of the four segment indices 0..3; the swap head has one class per ordering.
                self.swap_combinations = list(itertools.permutations(np.arange(0, 4), 4))
                self.swap_lstm_layer = tf.keras.layers.LSTM(d_model, name='swap_lstm', activation='relu')
                self.swap_output = tf.keras.layers.Dense(len(self.swap_combinations), name='swap_task', activation='softmax')

                #self.quarter_sequence = tf.keras.layers.Dense(4, activation='softmax', name='quarter_sequence')
                #self.month_sequence = tf.keras.layers.Dense(12, activation='softmax', name='month_sequence')
                #self.day_of_week_sequence = tf.keras.layers.Dense(7, activation='softmax', name='day_of_week_sequence')


                """
                    Second Level tasks - Forecasting helpers
                """
                self.second_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                # One LSTM + Dense head per smoothing width, keyed by the width as a string.
                self.smooth_lstm_layers = {'8': tf.keras.layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                                '6':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                                '3':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu')}

                # A head of width k outputs horizon - k + 1 values.
                self.smooth_layers = {'8': tf.keras.layers.Dense(horizon-8+1, name='smooth_forecasting_8'),
                                '6': tf.keras.layers.Dense(horizon-6+1, name='smooth_forecasting_6'),
                                '3': tf.keras.layers.Dense(horizon-3+1, name='smooth_forecasting_3')}


                # Single-step heads keyed by the step index as a string: 0, horizon//2 and horizon-1.
                # NOTE(review): for small horizons these keys can coincide, in which case later
                # dict entries replace earlier ones - confirm the horizons used.
                self.one_step_forecast_lstm_layers = {'0': tf.keras.layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                                str(horizon//2):  tf.keras.layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                                str(horizon-1):  tf.keras.layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu')}

                self.one_step_forecast = {'0': tf.keras.layers.Dense(1, name='next_step_forecasting'),
                                    str(horizon//2): tf.keras.layers.Dense(1, name='mid_step_forecasting'),
                                    str(horizon-1): tf.keras.layers.Dense(1, name='last_step_forecasting')}

                #self.quarter_forecasting = tf.keras.layers.Dense(4, activation='softmax', name='quarter_forecasting')
                #self.month_forecasting = tf.keras.layers.Dense(12, activation='softmax', name='month_forecasting')
                #self.day_of_week_forecasting = tf.keras.layers.Dense(7, activation='softmax', name='day_of_week_forecasting')

                # self.mean_forecasting = tf.keras.layers.Dense(1, name='mean_forecasting')
                """
                    Last Level tasks
                """
                # Final forecasting head: one value per step of the horizon.
                self.prediction_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = tf.keras.layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent value of an integer tensor.

                Values are shifted by the global minimum so that counting starts at zero, the
                occurrences are counted along ``axis``, and the value with the largest count is
                shifted back.

                Args:
                    x: tensor of integer values (assumed integer dtype, because the same dtype
                        is passed as ``output_type`` to ``argmax`` - TODO confirm).
                    axis: axis of ``x`` along which occurrences are counted.

                Returns:
                    Tensor of modes with the dtype of ``x``.
                """
                # NOTE(review): `tfp` is not imported in the import cell visible in this notebook;
                # this method presumably needs `tensorflow_probability` imported as `tfp` - confirm.
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # Find the values with largest counts
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Get the modes by shifting by the subtracted minimum
                return idx + m

            def gap_filling_task(self, inputs):
                """Auxiliary gap-filling task: corrupt one entry per sample and reconstruct the window.

                One (time step, feature) entry of every sample is overwritten with the sentinel
                value -100. The corrupted window goes through the shared convolutional encoder
                and the transposed-convolution decoder, and the reconstruction is scored against
                the uncorrupted input with mean squared error.

                Args:
                    inputs: float tensor indexed as (batch, time, feature); the time axis is
                        expected to have length ``self.window``.

                Returns:
                    Tuple ``(gap_filling_output, gap_loss)``: the reconstruction and its MSE,
                    which is also reported as the ``gap_loss`` metric.
                """
                batch = tf.shape(inputs)[0]

                # Candidate coordinates (batch, time, feature) for every time step of every sample.
                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                # The feature to corrupt is drawn among the features actually present in `inputs`.
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0,
                                               maxval=tf.shape(inputs)[-1], dtype=tf.int32)

                # Concatenating on the last axis yields (batch, time, 3) directly; an axis-less
                # squeeze would also drop a batch or window dimension of size 1.
                idx = tf.concat([batch_indexes, head_indexes, feat_index], axis=-1)
                idx = tf.transpose(idx, perm=(1, 2, 0))

                # tf.random.shuffle permutes the leading (time) axis, so the selected time step is
                # shared by the whole batch while the feature is drawn per sample.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Use the dynamic shape: the static batch dimension may be unknown in graph mode.
                gap_values = tf.fill(tf.shape(gap_index)[:1], tf.cast(-100.0, x_gap.dtype))
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=gap_values)

                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.gap_embedding_1(x_gap_embedding)
                x_gap_embedding = self.gap_embedding_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)

                # The reconstruction is compared with the whole uncorrupted window.
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss

            def noise_reduction_task(self, inputs):
                """Auxiliary denoising task: reconstruct the window from a noisy copy.

                The input is passed through the ``gaussian_noise`` layer, the shared
                convolutional encoder and the ``gap_embedding`` decoder layers (the same decoder
                layers the gap-filling task calls), then projected by ``noise_reduction_output``.

                Returns:
                    Tuple ``(noise_output, noise_loss)``; the MSE against the clean input is
                    also reported as the ``noise_loss`` metric.
                """
                clean = tf.identity(inputs)
                hidden = self.gaussian_noise(clean)

                # Encoder followed by decoder, applied in this exact order.
                for layer in (self.global_embedding_1, self.global_embedding_2,
                              self.gap_embedding_1, self.gap_embedding_2):
                    hidden = layer(hidden)

                noise_output = self.noise_reduction_output(hidden)
                noise_loss = tf.keras.losses.MeanSquaredError()(noise_output, clean)

                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return noise_output, noise_loss

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast the moving average of the labels.

                Args:
                    second_level_state: sequence fed to the LSTM registered for this width.
                    labels: target tensor; frames of length ``smooth`` (step 1) are taken along
                        axis 1 and averaged.
                    smooth: moving-average width; must be a key of ``smooth_lstm_layers``.

                Returns:
                    Tuple ``(smooth_forecasting, smooth_state, smooth_loss)``; the loss is also
                    reported as the ``smooth_<width>_loss`` metric.
                """
                key = str(smooth)
                smooth_state = self.smooth_lstm_layers[key](second_level_state)
                smooth_forecasting = self.smooth_layers[key](smooth_state)

                # Sliding windows over the label axis, then their mean.
                label_frames = tf.signal.frame(labels, smooth, 1, axis=1)
                labels_smooth = tf.math.reduce_mean(label_frames, axis=2)

                mse_loss = tf.keras.losses.MeanSquaredError()
                smooth_loss = mse_loss(smooth_forecasting, labels_smooth)

                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self, second_level_state, labels, step=0):
                """Auxiliary task: forecast the single label at position ``step``.

                Args:
                    second_level_state: sequence fed to the LSTM registered for this step.
                    labels: target tensor, indexed as ``labels[:, step, :]``.
                    step: label position; must be a key of ``one_step_forecast_lstm_layers``.

                Returns:
                    Tuple ``(step_forecast, step_state, step_loss)``; the loss is also reported
                    as the ``step_<step>_loss`` metric.
                """
                key = str(step)
                step_state = self.one_step_forecast_lstm_layers[key](second_level_state)
                step_forecast = self.one_step_forecast[key](step_state)

                step_target = labels[:, step, :]
                mse_loss = tf.keras.losses.MeanSquaredError()
                step_loss = mse_loss(step_forecast, step_target)

                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Auxiliary head scored against ``get_mode(labels_extra[:, :, 0])``.

                The prediction is taken from the last time step of ``second_level_state``. The
                MSE is reported as the ``week_forecasting_loss`` metric; nothing is returned.
                """
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.day_of_week_forecasting` looks like it resolves to this method itself and
                # the one-argument call below would raise - confirm before enabling this task.
                week_forecast = self.day_of_week_forecasting(second_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(labels_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Auxiliary head scored against ``get_mode(inputs_extra[:, :, 0])``.

                The prediction is taken from the last time step of ``first_level_state``. The
                MSE is reported as the ``week_sequence_loss`` metric; nothing is returned.
                """
                # TODO: ONEHOT
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.day_of_week_sequence` looks like it resolves to this method itself and the
                # one-argument call below would raise - confirm before enabling this task.
                week_forecast = self.day_of_week_sequence(first_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(inputs_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_sequence_loss', aggregation='mean')

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Auxiliary head scored against ``get_mode(labels_extra[:, :, 2])``.

                The prediction is taken from the last time step of ``second_level_state``. The
                MSE is reported as the ``quarter_forecasting_loss`` metric; nothing is returned.
                """
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.quarter_forecasting` looks like it resolves to this method itself and the
                # one-argument call below would raise - confirm before enabling this task.
                quarter_forecast = self.quarter_forecasting(second_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(labels_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Auxiliary head scored against ``get_mode(inputs_extra[:, :, 2])``.

                The prediction is taken from the last time step of ``first_level_state``. The
                MSE is reported as the ``quarter_sequence_loss`` metric; nothing is returned.
                """
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.quarter_sequence` looks like it resolves to this method itself and the
                # one-argument call below would raise - confirm before enabling this task.
                quarter_forecast = self.quarter_sequence(first_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(inputs_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_sequence_loss', aggregation='mean')


            def month_forecasting(self, second_level_state, labels_extra):
                """Auxiliary head scored against ``get_mode(labels_extra[:, :, 1])``.

                The prediction is taken from the last time step of ``second_level_state``. The
                MSE is reported as the ``month_forecasting_loss`` metric; nothing is returned.
                """
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.month_forecasting` looks like it resolves to this method itself and the
                # one-argument call below would raise - confirm before enabling this task.
                month_forecast = self.month_forecasting(second_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(labels_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_forecasting_loss', aggregation='mean')


            def month_sequence(self, first_level_state, inputs_extra):
                """Auxiliary head scored against ``get_mode(inputs_extra[:, :, 1])``.

                The prediction is taken from the last time step of ``first_level_state``. The
                MSE is reported as the ``month_sequence_loss`` metric; nothing is returned.
                """
                # NOTE(review): the layer assignment of this name in __init__ is commented out, so
                # `self.month_sequence` looks like it resolves to this method itself and the
                # one-argument call below would raise - confirm before enabling this task.
                month_forecast = self.month_sequence(first_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(inputs_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_sequence_loss', aggregation='mean')

            def swap_task(self, inputs):
                """Auxiliary task: reorder four time segments of each window and classify the order.

                The time axis is split into four equal segments, one permutation from
                ``self.swap_combinations`` is drawn per sample, the segments are reordered
                accordingly and the network predicts which permutation was applied.

                Args:
                    inputs: tensor indexed as (batch, time, feature); the time axis is expected
                        to have length ``self.window`` (which must be divisible by 4 for the
                        segment reshape to be meaningful).

                Returns:
                    Tuple ``(swap_output, swap_state)``: the softmax over permutations and the
                    LSTM state it was computed from. The cross-entropy is only reported as the
                    ``swap_loss`` metric.
                """
                x_swap = tf.identity(inputs)
                n_combinations = len(self.swap_combinations)

                idx = tf.tile([tf.range(0, self.window)], (tf.shape(x_swap)[0], 1))

                # One permutation id per sample, plus its one-hot target.
                swap_index = tf.random.uniform((tf.shape(x_swap)[0], 1), 0, n_combinations, dtype=tf.int64)
                swap_index_onehot = tf.one_hot(swap_index, n_combinations)

                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                # The NumPy indexing below needs concrete values, so this method relies on
                # `.numpy()` being available on the tensors (eager execution).
                swap_numpy = swap.numpy()
                idx_numpy = idx.numpy()
                # View the time indices as 4 segments and pick them in the permuted order.
                segments = idx_numpy.reshape(idx_numpy.shape[0], 4, -1)
                rows = np.arange(idx_numpy.shape[0])
                idx_numpy = np.concatenate(
                    [segments[rows, swap_numpy[..., position].squeeze()] for position in range(4)],
                    axis=1)

                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                # `self.embedding` is never created in __init__; use the shared two-layer
                # convolutional encoder that `call` and the other first-level tasks apply.
                x_swap_embedding = self.global_embedding_1(x_swap)
                x_swap_embedding = self.global_embedding_2(x_swap_embedding)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                swap_loss = tf.keras.losses.CategoricalCrossentropy()(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state


            def call(self, inputs):
                """Forward pass: run the auxiliary tasks and return the horizon forecast.

                Args:
                    inputs: a pair ``((inputs, inputs_extra), (labels, labels_extra))``. The
                        labels are consumed only by the auxiliary losses; ``inputs_extra`` and
                        ``labels_extra`` are unpacked but not used in this method.

                Returns:
                    Output of the ``prediction_output`` head applied to the concatenation of the
                    prediction-level LSTM state and the six auxiliary task states.
                """
                # NOTE(review): labels are unpacked unconditionally, so they must be supplied on
                # every call, including inference - confirm this is intended.
                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                # Shared convolutional encoder.
                x = self.global_embedding_1(inputs)
                x = self.global_embedding_2(x)
                #first_level_state = self.first_level_lstm_layer(x)

                # First level
                # Both reconstruction tasks start from the raw inputs and re-apply the encoder.
                gap_filling_outputs, gap_loss = self.gap_filling_task(inputs)
                noise_reduction_outputs, noise_loss = self.noise_reduction_task(inputs)
                #swap_output, swap_state = self.swap_task(inputs)

                #Second level
                second_level_state = self.second_level_lstm_layer(x)
                smooth_forecasting_8, smooth_state_8, smooth_loss_8 = self.smooth_forecast_task(second_level_state, labels, 8)
                smooth_forecasting_6, smooth_state_6, smooth_loss_6 = self.smooth_forecast_task(second_level_state, labels, 6)
                smooth_forecasting_3, smooth_state_3, smooth_loss_3 = self.smooth_forecast_task(second_level_state, labels, 3)

                first_step_forecast, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                mid_step_forecast, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                last_step_forecast, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)

                # The unweighted mean of the eight auxiliary losses is added to the model loss.
                self.add_loss(tf.reduce_mean([gap_loss, noise_loss, 
                                                             smooth_loss_8, smooth_loss_6,
                                                             smooth_loss_3, first_step_loss,
                                                             mid_step_loss, last_step_loss]))
                # Last level
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                # The final head sees its own state plus the state of every forecasting helper.
                prediction_state = tf.keras.layers.concatenate([prediction_level_state, last_step_state, mid_step_state, 
                                         first_step_state, smooth_state_3, smooth_state_6, smooth_state_8])

                prediction_output = self.prediction_output(prediction_state)

                return prediction_output


        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size)

        loss = 'mse'
        
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Create every layer and the bookkeeping state of the gradient-projected variant.

                Args:
                    d_model: number of filters / units of the convolutional and LSTM layers
                        (the first convolution and the second decoder layer use ``d_model//2``).
                    n_features: output width of the two reconstruction heads.
                    window: stored as ``self.window``; also used in the ``input_shape`` hint.
                    horizon: width of the final prediction head; also sizes the smoothed
                        forecasting heads and selects the mid / last step heads.
                    batch_size: stored as ``self.batch_size``; not otherwise used here.
                    grad_smooth_alpha, lamb1, lamb2, mode, overall: stored unchanged on the
                        instance; they are not read in this method (presumably consumed by
                        training code outside this view - TODO confirm).
                """
                super(MultitaskHLNet, self).__init__()
                self.d_model = d_model
                self.n_features = n_features
                self.window = window
                self.horizon = horizon
                self.batch_size = batch_size
                # Running gradient averages, initially unset.
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1 = lamb1
                self.lamb2 = lamb2
                self.mode = mode
                self.overall = overall

                # One weight per auxiliary loss, initialised to 1 (attribute name spelled
                # `loss_weigts` in this class).
                self.loss_weigts = tf.Variable(tf.ones([8]))

                # Shared convolutional encoder applied to the raw window.
                # NOTE(review): input_shape hard-codes 3 features instead of n_features - confirm.
                self.global_embedding_1 = tf.keras.layers.Conv1D(d_model//2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = tf.keras.layers.Conv1D(d_model, 3, activation='relu', padding='causal')

                # presumably indices into a layer / variable list used by code outside this view - TODO confirm
                self.global_first_level_layers = [0,1]
                """
                    First Level tasks - Sequence featurization
                """
                # Transposed-convolution decoder called by both reconstruction tasks.
                self.decode_cnn_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gap_filling_output = tf.keras.layers.Dense(n_features)

                self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
                self.noise_reduction_output = tf.keras.layers.Dense(n_features)

                #self.swap_combinations = list(itertools.permutations(np.arange(0, 4), 4))
                #self.swap_lstm_layer = tf.keras.layers.LSTM(d_model, name='swap_lstm', activation='relu')
                #self.swap_output = tf.keras.layers.Dense(len(self.swap_combinations), name='swap_task', activation='softmax')


                """
                    Second Level tasks - Forecasting helpers
                """
                self.second_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                # One LSTM + Dense head per smoothing width, keyed by the width as a string.
                self.smooth_lstm_layers = {'8': tf.keras.layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                                '6':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                                '3':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu')}

                # A head of width k outputs horizon - k + 1 values.
                self.smooth_layers = {'8': tf.keras.layers.Dense(horizon-8+1, name='smooth_forecasting_8'),
                                '6': tf.keras.layers.Dense(horizon-6+1, name='smooth_forecasting_6'),
                                '3': tf.keras.layers.Dense(horizon-3+1, name='smooth_forecasting_3')}


                # Single-step heads keyed by the step index as a string: 0, horizon//2 and horizon-1.
                self.one_step_forecast_lstm_layers = {'0': tf.keras.layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                                str(horizon//2):  tf.keras.layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                                str(horizon-1):  tf.keras.layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu')}

                self.one_step_forecast = {'0': tf.keras.layers.Dense(1, name='next_step_forecasting'),
                                    str(horizon//2): tf.keras.layers.Dense(1, name='mid_step_forecasting'),
                                    str(horizon-1): tf.keras.layers.Dense(1, name='last_step_forecasting')}
                # presumably an index into a layer / variable list used by code outside this view - TODO confirm
                self.global_second_level_layers = [7]
                """
                    Last Level tasks
                """
                # Final forecasting head: one value per step of the horizon.
                self.prediction_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = tf.keras.layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent value of an integer tensor.

                Values are shifted by the global minimum so that counting starts at zero, the
                occurrences are counted along ``axis``, and the value with the largest count is
                shifted back.

                Args:
                    x: tensor of integer values (assumed integer dtype, because the same dtype
                        is passed as ``output_type`` to ``argmax`` - TODO confirm).
                    axis: axis of ``x`` along which occurrences are counted.

                Returns:
                    Tensor of modes with the dtype of ``x``.
                """
                # NOTE(review): `tfp` is not imported in the import cell visible in this notebook;
                # this method presumably needs `tensorflow_probability` imported as `tfp` - confirm.
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # Find the values with largest counts
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Get the modes by shifting by the subtracted minimum
                return idx + m

            def gap_filling_task(self, inputs):
                """Auxiliary gap-filling task: corrupt one entry per sample and reconstruct the window.

                One (time step, feature) entry of every sample is overwritten with the sentinel
                value -100. The corrupted window goes through the shared convolutional encoder
                and the ``decode_cnn`` decoder, and the reconstruction is scored against the
                uncorrupted input with mean squared error.

                Args:
                    inputs: float tensor indexed as (batch, time, feature); the time axis is
                        expected to have length ``self.window``.

                Returns:
                    Tuple ``(gap_filling_output, gap_loss)``: the reconstruction and its MSE,
                    which is also reported as the ``gap_loss`` metric.
                """
                batch = tf.shape(inputs)[0]

                # Candidate coordinates (batch, time, feature) for every time step of every sample.
                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                # The feature to corrupt is drawn among the features actually present in `inputs`.
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0,
                                               maxval=tf.shape(inputs)[-1], dtype=tf.int32)

                # Concatenating on the last axis yields (batch, time, 3) directly; an axis-less
                # squeeze would also drop a batch or window dimension of size 1.
                idx = tf.concat([batch_indexes, head_indexes, feat_index], axis=-1)
                idx = tf.transpose(idx, perm=(1, 2, 0))

                # tf.random.shuffle permutes the leading (time) axis, so the selected time step is
                # shared by the whole batch while the feature is drawn per sample.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Use the dynamic shape: the static batch dimension may be unknown in graph mode.
                gap_values = tf.fill(tf.shape(gap_index)[:1], tf.cast(-100.0, x_gap.dtype))
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=gap_values)

                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)

                # The reconstruction is compared with the whole uncorrupted window.
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss

            def noise_reduction_task(self, inputs):
                """Auxiliary denoising task: reconstruct the window from a noisy copy.

                The input is passed through the ``gaussian_noise`` layer, the shared
                convolutional encoder and the ``decode_cnn`` decoder layers, then projected by
                ``noise_reduction_output``.

                Returns:
                    Tuple ``(noise_output, noise_loss)``; the MSE against the clean input is
                    also reported as the ``noise_loss`` metric.
                """
                clean = tf.identity(inputs)
                hidden = self.gaussian_noise(clean)

                # Encoder followed by decoder, applied in this exact order.
                for layer in (self.global_embedding_1, self.global_embedding_2,
                              self.decode_cnn_1, self.decode_cnn_2):
                    hidden = layer(hidden)

                noise_output = self.noise_reduction_output(hidden)
                noise_loss = tf.keras.losses.MeanSquaredError()(noise_output, clean)

                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return noise_output, noise_loss

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast the moving average of the labels.

                Args:
                    second_level_state: sequence fed to the LSTM registered for this width.
                    labels: target tensor; frames of length ``smooth`` (step 1) are taken along
                        axis 1 and averaged.
                    smooth: moving-average width; must be a key of ``smooth_lstm_layers``.

                Returns:
                    Tuple ``(smooth_forecasting, smooth_state, smooth_loss)``; the loss is also
                    reported as the ``smooth_<width>_loss`` metric.
                """
                key = str(smooth)
                smooth_state = self.smooth_lstm_layers[key](second_level_state)
                smooth_forecasting = self.smooth_layers[key](smooth_state)

                # Sliding windows over the label axis, then their mean.
                label_frames = tf.signal.frame(labels, smooth, 1, axis=1)
                labels_smooth = tf.math.reduce_mean(label_frames, axis=2)

                mse_loss = tf.keras.losses.MeanSquaredError()
                smooth_loss = mse_loss(smooth_forecasting, labels_smooth)

                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self, second_level_state, labels, step=0):
                """Auxiliary task: forecast the single position `step` of `labels`.

                Args:
                    second_level_state: tensor fed to the LSTM stored under
                        `str(step)` in `self.one_step_forecast_lstm_layers`.
                    labels: targets; only `labels[:, step, :]` is used.
                    step: index along axis 1 of `labels`; also selects the layers.

                Returns:
                    Tuple `(forecast, lstm_state, loss)`. The loss is also logged as
                    the `step_{step}_loss` metric.
                """
                key = str(step)
                step_state = self.one_step_forecast_lstm_layers[key](second_level_state)
                step_forecast = self.one_step_forecast[key](step_state)

                mse = tf.keras.losses.MeanSquaredError()
                step_target = labels[:, step, :]
                step_loss = mse(step_forecast, step_target)
                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Log the MSE between a prediction made from the last time step of
                `second_level_state` and the per-sample mode of `labels_extra[:, :, 0]`.

                Returns nothing; the loss is only recorded as the
                `week_forecasting_loss` metric.

                NOTE(review): the prediction comes from calling
                `self.day_of_week_forecasting` with one argument, i.e. the name of this
                very method. That only works if `__init__` binds a callable
                (presumably a layer) to the same attribute name - confirm.
                """
                last_step = second_level_state[:, -1, :]
                week_forecast = self.day_of_week_forecasting(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                week_target = self.get_mode(labels_extra[:, :, 0], axis=1)
                week_loss = mse(week_forecast, week_target)

                self.add_metric(week_loss, name='week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Log the MSE between a prediction made from the last time step of
                `first_level_state` and the per-sample mode of `inputs_extra[:, :, 0]`.

                Returns nothing; the loss is only recorded as the
                `week_sequence_loss` metric.

                NOTE(review): the prediction comes from calling
                `self.day_of_week_sequence` with one argument, i.e. the name of this
                very method. That only works if `__init__` binds a callable
                (presumably a layer) to the same attribute name - confirm.
                """
                # TODO: ONEHOT
                last_step = first_level_state[:, -1, :]
                week_forecast = self.day_of_week_sequence(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                week_target = self.get_mode(inputs_extra[:, :, 0], axis=1)
                week_loss = mse(week_forecast, week_target)

                self.add_metric(week_loss, name='week_sequence_loss', aggregation='mean')

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Log the MSE between a prediction made from the last time step of
                `second_level_state` and the per-sample mode of `labels_extra[:, :, 2]`.

                Returns nothing; the loss is only recorded as the
                `quarter_forecasting_loss` metric.

                NOTE(review): the prediction comes from calling
                `self.quarter_forecasting` with one argument, i.e. the name of this
                very method. That only works if `__init__` binds a callable
                (presumably a layer) to the same attribute name - confirm.
                """
                last_step = second_level_state[:, -1, :]
                quarter_forecast = self.quarter_forecasting(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                quarter_target = self.get_mode(labels_extra[:, :, 2], axis=1)
                quarter_loss = mse(quarter_forecast, quarter_target)

                self.add_metric(quarter_loss, name='quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Log the MSE between a prediction made from the last time step of
                `first_level_state` and the per-sample mode of `inputs_extra[:, :, 2]`.

                Returns nothing; the loss is only recorded as the
                `quarter_sequence_loss` metric.

                NOTE(review): the prediction comes from calling
                `self.quarter_sequence` with one argument, i.e. the name of this very
                method. That only works if `__init__` binds a callable (presumably a
                layer) to the same attribute name - confirm.
                """
                last_step = first_level_state[:, -1, :]
                quarter_forecast = self.quarter_sequence(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                quarter_target = self.get_mode(inputs_extra[:, :, 2], axis=1)
                quarter_loss = mse(quarter_forecast, quarter_target)

                self.add_metric(quarter_loss, name='quarter_sequence_loss', aggregation='mean')


            def month_forecasting(self, second_level_state, labels_extra):
                """Log the MSE between a prediction made from the last time step of
                `second_level_state` and the per-sample mode of `labels_extra[:, :, 1]`.

                Returns nothing; the loss is only recorded as the
                `month_forecasting_loss` metric.

                NOTE(review): the prediction comes from calling
                `self.month_forecasting` with one argument, i.e. the name of this very
                method. That only works if `__init__` binds a callable (presumably a
                layer) to the same attribute name - confirm.
                """
                last_step = second_level_state[:, -1, :]
                month_forecast = self.month_forecasting(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                month_target = self.get_mode(labels_extra[:, :, 1], axis=1)
                month_loss = mse(month_forecast, month_target)

                self.add_metric(month_loss, name='month_forecasting_loss', aggregation='mean')


            def month_sequence(self, first_level_state, inputs_extra):
                """Log the MSE between a prediction made from the last time step of
                `first_level_state` and the per-sample mode of `inputs_extra[:, :, 1]`.

                Returns nothing; the loss is only recorded as the
                `month_sequence_loss` metric.

                NOTE(review): the prediction comes from calling `self.month_sequence`
                with one argument, i.e. the name of this very method. That only works
                if `__init__` binds a callable (presumably a layer) to the same
                attribute name - confirm.
                """
                last_step = first_level_state[:, -1, :]
                month_forecast = self.month_sequence(last_step)

                mse = tf.keras.losses.MeanSquaredError()
                month_target = self.get_mode(inputs_extra[:, :, 1], axis=1)
                month_loss = mse(month_forecast, month_target)

                self.add_metric(month_loss, name='month_sequence_loss', aggregation='mean')

            def swap_task(self, inputs):
                """Auxiliary task: classify which permutation was applied to the input.

                The time axis is cut into 4 equal segments, the segments are
                reordered according to a randomly drawn entry of
                `self.swap_combinations`, and the network has to predict the index of
                that entry. The categorical cross-entropy is logged as `swap_loss`.

                The sequence length is read from `window` in the enclosing scope, and
                the index shuffling is done in NumPy via `.numpy()`.

                Returns:
                    Tuple `(swap_output, swap_state)`.
                """
                n_combinations = len(self.swap_combinations)

                x_swap = tf.identity(inputs)
                batch = tf.shape(x_swap)[0]

                positions = tf.tile([tf.range(0, window)], (batch, 1))

                # One permutation id per sample, plus its one-hot target.
                swap_index = tf.random.uniform((batch, 1), 0, n_combinations, dtype=tf.int64)
                swap_index_onehot = tf.one_hot(swap_index, n_combinations)
                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                segments = positions.numpy()
                rows = np.arange(segments.shape[0])
                segments = segments.reshape(segments.shape[0], 4, -1)

                # For each output slot pick the segment the permutation assigns to it.
                reordered = np.concatenate(
                    [segments[rows, swap_numpy[..., slot].squeeze()] for slot in range(4)],
                    axis=1)

                gather_idx = tf.convert_to_tensor(reordered)
                x_swap = tf.gather(x_swap, gather_idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                cross_entropy = tf.keras.losses.CategoricalCrossentropy()
                swap_loss = cross_entropy(tf.squeeze(swap_index_onehot), swap_output)
                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state

            def train_step(self, data):
                """Custom training step mixing main-loss and auxiliary-task gradients.

                Steps:
                  1. Forward pass under a persistent tape; `self(inputs, True)` returns
                     the prediction and the list of auxiliary losses.
                  2. First-level shared layers (`self.global_first_level_layers`): the
                     gradient of each of the first two auxiliary losses is combined
                     with the main-loss gradient through `combined_grads`.
                  3. The two combined first-level gradients are merged and used as the
                     reference for the remaining auxiliary losses on the first-level
                     layers.
                  4. Second-level shared layers (`self.global_second_level_layers`):
                     the remaining auxiliary losses are combined with the main-loss
                     gradient.
                  5. Every other layer receives the raw gradient of each auxiliary
                     loss and of the main loss.
                  6. All collected `(gradient, variable)` pairs are applied at once.

                `update_smooth_grad`, `combined_grads`, `grad_smooth_alpha`, `mode`,
                `overall`, `lamb1` and `lamb2` are resolved from the enclosing scope.

                Args:
                    data: `(inputs, labels)` pair as delivered by `fit`.

                Returns:
                    Dict mapping each metric name to its current result.
                """
                inputs, labels = data
                gradients = []

                # persistent=True: the tape is queried once per (loss, weight group).
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                model_layers = np.array(self.layers)

                # First level loss
                # Use `self` rather than the global `model`, so the step always acts
                # on the instance being trained.
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(main_grad_first_level, self.main_grad_first_level_average, grad_smooth_alpha)

                first_level_gradients = []
                for l in losses[:2]:
                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average, aux_grad, mode, overall, lamb1)

                    first_level_gradients.append(aux_grad)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1], 'Multitask', overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(first_level_aux_grad, self.first_level_aux_grad_average, grad_smooth_alpha)

                # Second level loss: the first-level weight list computed above is reused.
                for l in losses[2:]:
                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average, aux_grad, mode, overall, lamb2)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(main_grad_second_level, self.main_grad_second_level_average, grad_smooth_alpha)

                for l in losses[2:]:
                    aux_grad = tape.gradient(l, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average, aux_grad, mode, overall, lamb2)

                    gradients.extend(list(zip(aux_grad, second_level_trainable_weights)))

                # All losses: layers that belong to neither shared level.
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for l in losses:
                    grad = tape.gradient(l, all_level_trainable_weights)
                    gradients.extend(list(zip(grad, all_level_trainable_weights)))
                gradients.extend(list(zip(main_grad_rest_layers, all_level_trainable_weights)))

                # A persistent tape keeps its resources until it is dropped.
                del tape

                # Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}
            
            def call(self, inputs, training):
                """Forward pass through every task and the final prediction head.

                Args:
                    inputs: pair `((series, series_extra), (labels, labels_extra))`.
                        The labels are needed because the auxiliary losses are
                        computed inside the forward pass; the two `*_extra` entries are
                        unpacked but not used here.
                    training: when truthy the auxiliary losses are returned too.

                Returns:
                    `(prediction, losses)` if `training` is truthy, else `prediction`.
                    `losses` is ordered: gap, noise, smooth 8/6/3, first/mid/last step.
                """
                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                embedded = self.global_embedding_2(self.global_embedding_1(inputs))

                # First level: reconstruction tasks on the raw input.
                _, gap_loss = self.gap_filling_task(inputs)
                _, noise_loss = self.noise_reduction_task(inputs)

                # Second level: forecasting helpers on the shared LSTM state.
                second_level_state = self.second_level_lstm_layer(embedded)

                smooth_states, smooth_losses = [], []
                for width in (8, 6, 3):
                    _, task_state, task_loss = self.smooth_forecast_task(second_level_state, labels, width)
                    smooth_states.append(task_state)
                    smooth_losses.append(task_loss)

                step_states, step_losses = [], []
                for step in (0, self.horizon//2, self.horizon-1):
                    _, task_state, task_loss = self.one_step_forecast_task(second_level_state, labels, step)
                    step_states.append(task_state)
                    step_losses.append(task_loss)

                losses = [gap_loss, noise_loss, *smooth_losses, *step_losses]

                # Disabled experiment (learned per-task loss weighting):
                #   losses[i] = (0.5/self.loss_weigts[i]**2)*loss + tf.math.log(1+self.loss_weigts[i]**2)
                #   self.add_loss(tasks_loss)

                # Last level: the prediction state is followed by the task states in
                # the order last, mid, first step, then smooth 3, 6, 8.
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                prediction_state = tf.keras.layers.concatenate(
                    [prediction_level_state, *reversed(step_states), *reversed(smooth_states)])

                prediction_output = self.prediction_output(prediction_state)

                if not training:
                    return prediction_output
                return prediction_output, losses


        # Instantiate the variant defined in this branch; all hyper-parameters are
        # taken from the enclosing scope.
        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall)

        # Keras loss identifier (mean squared error) - presumably consumed by a later
        # `compile` call; confirm against the code after this if/elif chain.
        loss = 'mse'
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_nosmooth':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):
            """Hierarchical multitask forecaster, variant without smoothed-forecast tasks.

            Two shared convolutions embed the input. On top of them sit two
            first-level reconstruction tasks (gap filling, noise reduction) and a
            second-level LSTM whose state feeds three single-step forecasting tasks
            and the final horizon prediction.

            `global_first_level_layers` and `global_second_level_layers` are positions
            into `self.layers` and are used by `train_step` to select the shared
            weights. NOTE(review): this presumably relies on `self.layers` following
            the order in which layers are assigned in `__init__` - do not reorder those
            assignments without re-checking the indices.
            """

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Create all layers and store the hyper-parameters.

                Args:
                    d_model: number of filters / units of the shared layers.
                    n_features: output size of the two reconstruction heads.
                    window: input sequence length.
                    horizon: number of forecast steps (size of the prediction head).
                    batch_size: stored only; not used inside this class.
                    grad_smooth_alpha: forwarded to `update_smooth_grad` in `train_step`.
                    lamb1: forwarded to `combined_grads` for the first-level tasks.
                    lamb2: forwarded to `combined_grads` for the second-level tasks.
                    mode: forwarded to `combined_grads`.
                    overall: forwarded to `combined_grads`.
                """
                super(MultitaskHLNet, self).__init__()
                self.d_model = d_model
                self.n_features = n_features
                self.window = window
                self.horizon = horizon
                self.batch_size = batch_size
                # Running gradient averages, filled in by `train_step`.
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1 = lamb1
                self.lamb2 = lamb2
                self.mode = mode
                self.overall = overall

                # Only referenced by the disabled loss-weighting experiment in `call`.
                self.loss_weigts = tf.Variable(tf.ones([8]))

                self.global_embedding_1 = tf.keras.layers.Conv1D(d_model//2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = tf.keras.layers.Conv1D(d_model, 3, activation='relu', padding='causal')

                # Positions of the two embedding layers inside `self.layers`.
                self.global_first_level_layers = [0,1]

                # First level tasks - sequence featurization
                self.decode_cnn_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gap_filling_output = tf.keras.layers.Dense(n_features)

                self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
                self.noise_reduction_output = tf.keras.layers.Dense(n_features)

                #self.swap_combinations = list(itertools.permutations(np.arange(0, 4), 4))
                #self.swap_lstm_layer = tf.keras.layers.LSTM(d_model, name='swap_lstm', activation='relu')
                #self.swap_output = tf.keras.layers.Dense(len(self.swap_combinations), name='swap_task', activation='softmax')

                # Second level tasks - forecasting helpers
                self.second_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                # Kept for `smooth_forecast_task`; `call` of this variant does not use them.
                self.smooth_lstm_layers = {'8': tf.keras.layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                                '6':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                                '3':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu')}

                self.smooth_layers = {'8': tf.keras.layers.Dense(horizon-8+1, name='smooth_forecasting_8'),
                                '6': tf.keras.layers.Dense(horizon-6+1, name='smooth_forecasting_6'),
                                '3': tf.keras.layers.Dense(horizon-3+1, name='smooth_forecasting_3')}

                # Keyed by the forecast step (first, middle, last position of the horizon).
                self.one_step_forecast_lstm_layers = {'0': tf.keras.layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                                str(horizon//2):  tf.keras.layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                                str(horizon-1):  tf.keras.layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu')}

                self.one_step_forecast = {'0': tf.keras.layers.Dense(1, name='next_step_forecasting'),
                                    str(horizon//2): tf.keras.layers.Dense(1, name='mid_step_forecasting'),
                                    str(horizon-1): tf.keras.layers.Dense(1, name='last_step_forecasting')}
                # Position of `second_level_lstm_layer` inside `self.layers`
                # (presumably the 8th tracked layer - see the class docstring).
                self.global_second_level_layers = [7]

                # Last level tasks
                self.prediction_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = tf.keras.layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent integer value of `x`, counted along `axis`.

                Needs `tfp` (TensorFlow Probability) to be available in the enclosing
                scope. NOTE(review): it is not among the imports visible at the top of
                the notebook - confirm before using this helper.
                """
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # Find the values with largest counts
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Get the modes by shifting by the subtracted minimum
                modes = idx + m

                return modes

            def gap_filling_task(self, inputs):
                """Auxiliary task: reconstruct `inputs` after masking one value per sample.

                For every sample one `(time step, feature)` entry is overwritten with
                -100; the network has to rebuild the full, unmasked input. The mean
                squared error is logged as the `gap_loss` metric.

                `tf.random.shuffle` permutes only the leading axis, which is the time
                axis here, so all samples of a batch are masked at the same time step
                while the feature is drawn per sample.

                Returns:
                    Tuple `(reconstruction, gap_loss)`.
                """
                batch = tf.shape(inputs)[0]
                # Use the real channel count instead of a hard-coded 3.
                n_channels = tf.shape(inputs)[-1]

                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0, maxval=n_channels, dtype=tf.int32)

                # (batch, window, 3) table of candidate [sample, time step, feature]
                # coordinates. Concatenating on the last axis avoids an axis-less
                # squeeze, which would also drop the batch axis for a batch of one.
                idx = tf.concat([batch_indexes, head_indexes, feat_index], axis=-1)
                idx = tf.transpose(idx, perm=(1, 2, 0))

                # Shuffle over time, then keep the first coordinate triple per sample.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Dynamic length: the static shape is unknown in graph mode when the
                # batch size is not fixed.
                gap_values = tf.fill([tf.shape(gap_index)[0]], tf.cast(-100.0, x_gap.dtype))
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=gap_values)

                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)

                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss

            def noise_reduction_task(self, inputs):
                """Auxiliary denoising task: rebuild `inputs` from a perturbed copy.

                The copy goes through `gaussian_noise`, the two embedding
                convolutions, the two transposed convolutions and the
                `noise_reduction_output` head. The MSE against the clean copy is
                logged as the `noise_loss` metric.

                Returns:
                    Tuple `(reconstruction, noise_loss)`.
                """
                input_copy = tf.identity(inputs)
                x_noise = self.gaussian_noise(input_copy)

                x_noise_embedding = self.global_embedding_1(x_noise)
                x_noise_embedding = self.global_embedding_2(x_noise_embedding)
                x_noise_embedding = self.decode_cnn_1(x_noise_embedding)
                x_noise_embedding = self.decode_cnn_2(x_noise_embedding)
                noise_output = self.noise_reduction_output(x_noise_embedding)

                noise_loss = tf.keras.losses.MeanSquaredError()(noise_output, input_copy)

                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return noise_output, noise_loss

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast a moving average (window `smooth`) of `labels`.

                Not used by `call` in this variant.

                Returns:
                    Tuple `(forecast, lstm_state, loss)`; the loss is also logged as
                    the `smooth_{smooth}_loss` metric.
                """
                smooth_state = self.smooth_lstm_layers[str(smooth)](second_level_state)
                smooth_forecasting = self.smooth_layers[str(smooth)](smooth_state)
                # Sliding windows of `smooth` consecutive labels, one mean per window.
                labels_smooth = tf.math.reduce_mean(tf.signal.frame(labels, smooth, 1, axis=1), axis=2)

                smooth_loss = tf.keras.losses.MeanSquaredError()(smooth_forecasting, labels_smooth)

                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self,second_level_state, labels, step=0):
                """Auxiliary task: forecast the single position `step` of `labels`.

                Returns:
                    Tuple `(forecast, lstm_state, loss)`; the loss is also logged as
                    the `step_{step}_loss` metric.
                """
                step_state = self.one_step_forecast_lstm_layers[str(step)](second_level_state)
                step_forecast = self.one_step_forecast[str(step)](step_state)
                step_loss = tf.keras.losses.MeanSquaredError()(step_forecast, labels[:, step, :])

                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            # NOTE(review): the six calendar helpers below are not used by `call`.
            # Each one calls an attribute with its own name and a single argument, and
            # `__init__` of this class creates no such attribute, so the call resolves
            # to the method itself and fails with a TypeError. They are kept unchanged
            # for interface compatibility until the intended output layers are added.

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Log `week_forecasting_loss` against the mode of `labels_extra[:, :, 0]` (see note above)."""
                week_forecast = self.day_of_week_forecasting(second_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(labels_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Log `week_sequence_loss` against the mode of `inputs_extra[:, :, 0]` (see note above)."""
                # TODO: ONEHOT
                week_forecast = self.day_of_week_sequence(first_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(inputs_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_sequence_loss', aggregation='mean')

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Log `quarter_forecasting_loss` against the mode of `labels_extra[:, :, 2]` (see note above)."""
                quarter_forecast = self.quarter_forecasting(second_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(labels_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Log `quarter_sequence_loss` against the mode of `inputs_extra[:, :, 2]` (see note above)."""
                quarter_forecast = self.quarter_sequence(first_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(inputs_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_sequence_loss', aggregation='mean')

            def month_forecasting(self, second_level_state, labels_extra):
                """Log `month_forecasting_loss` against the mode of `labels_extra[:, :, 1]` (see note above)."""
                month_forecast = self.month_forecasting(second_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(labels_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_forecasting_loss', aggregation='mean')

            def month_sequence(self, first_level_state, inputs_extra):
                """Log `month_sequence_loss` against the mode of `inputs_extra[:, :, 1]` (see note above)."""
                month_forecast = self.month_sequence(first_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(inputs_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_sequence_loss', aggregation='mean')

            def swap_task(self, inputs):
                """Auxiliary task: classify which segment permutation was applied.

                NOTE(review): not usable in this variant. `swap_combinations`,
                `swap_lstm_layer` and `swap_output` are commented out in `__init__`,
                and `embedding` / `first_level_lstm_layer` are never created there.
                The call in `call` is commented out as well. The index shuffling also
                goes through `.numpy()`.

                Returns:
                    Tuple `(swap_output, swap_state)`.
                """
                x_swap = tf.identity(inputs)

                idx = tf.tile([tf.range(0, self.window)], (tf.shape(x_swap)[0], 1))

                swap_index = tf.random.uniform((tf.shape(x_swap)[0], 1), 0, len(self.swap_combinations), dtype=tf.int64)

                swap_index_onehot = tf.one_hot(swap_index, len(self.swap_combinations))

                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                idx_numpy = idx.numpy()
                rows = np.arange(idx_numpy.shape[0])
                # The time axis is cut into 4 segments; slot k receives the segment
                # that the drawn permutation assigns to it.
                segments = idx_numpy.reshape(idx_numpy.shape[0], 4, -1)
                idx_numpy = np.concatenate(
                    [segments[rows, swap_numpy[..., slot].squeeze()] for slot in range(4)],
                    axis=1)

                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                swap_loss = tf.keras.losses.CategoricalCrossentropy()(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state

            def train_step(self, data):
                """Custom training step mixing main-loss and auxiliary-task gradients.

                Steps:
                  1. Forward pass under a persistent tape; `self(inputs, True)` returns
                     the prediction and the list of auxiliary losses.
                  2. First-level shared layers: the gradient of each of the first two
                     auxiliary losses is combined with the main-loss gradient.
                  3. The two combined first-level gradients are merged and used as the
                     reference for the remaining auxiliary losses on those layers.
                  4. Second-level shared layers: the remaining auxiliary losses are
                     combined with the main-loss gradient.
                  5. Every other layer receives the raw gradient of each auxiliary
                     loss and of the main loss.
                  6. All collected `(gradient, variable)` pairs are applied at once.

                Hyper-parameters are read from the instance (`self.mode`, `self.lamb1`,
                ...), not from the enclosing scope. `update_smooth_grad` and
                `combined_grads` are helpers defined elsewhere in the notebook.

                Args:
                    data: `(inputs, labels)` pair as delivered by `fit`.

                Returns:
                    Dict mapping each metric name to its current result.
                """
                inputs, labels = data
                gradients = []

                # persistent=True: the tape is queried once per (loss, weight group).
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                model_layers = np.array(self.layers)

                # First level loss
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(main_grad_first_level, self.main_grad_first_level_average, self.grad_smooth_alpha)

                first_level_gradients = []
                for l in losses[:2]:
                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average, aux_grad, self.mode, self.overall, self.lamb1)

                    first_level_gradients.append(aux_grad)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1], 'Multitask', self.overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(first_level_aux_grad, self.first_level_aux_grad_average, self.grad_smooth_alpha)

                # Second level loss: the first-level weight list computed above is reused.
                for l in losses[2:]:
                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average, aux_grad, self.mode, self.overall, self.lamb2)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(main_grad_second_level, self.main_grad_second_level_average, self.grad_smooth_alpha)

                for l in losses[2:]:
                    aux_grad = tape.gradient(l, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average, aux_grad, self.mode, self.overall, self.lamb2)

                    gradients.extend(list(zip(aux_grad, second_level_trainable_weights)))

                # All losses: layers that belong to neither shared level.
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for l in losses:
                    grad = tape.gradient(l, all_level_trainable_weights)
                    gradients.extend(list(zip(grad, all_level_trainable_weights)))
                gradients.extend(list(zip(main_grad_rest_layers, all_level_trainable_weights)))

                # A persistent tape keeps its resources until it is dropped.
                del tape

                # Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}

            def call(self, inputs, training):
                """Forward pass through the active tasks and the final prediction head.

                Args:
                    inputs: pair `((series, series_extra), (labels, labels_extra))`.
                        The labels are needed because the auxiliary losses are
                        computed inside the forward pass; the two `*_extra` entries are
                        unpacked but not used here.
                    training: when truthy the auxiliary losses are returned too.

                Returns:
                    `(prediction, losses)` if `training` is truthy, else `prediction`.
                    `losses` is ordered: gap, noise, first step, mid step, last step.
                """
                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                x = self.global_embedding_1(inputs)
                x = self.global_embedding_2(x)

                # First level
                _, gap_loss = self.gap_filling_task(inputs)
                _, noise_loss = self.noise_reduction_task(inputs)
                #swap_output, swap_state = self.swap_task(inputs)

                # Second level (the smooth_forecast_task calls for windows 8, 6 and 3
                # are intentionally left out of this variant)
                second_level_state = self.second_level_lstm_layer(x)

                _, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                _, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                _, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)

                losses = [gap_loss, noise_loss, first_step_loss, mid_step_loss, last_step_loss]

                # Disabled experiment (learned per-task loss weighting):
                #   losses[i] = (0.5/self.loss_weigts[i]**2)*loss + tf.math.log(1+self.loss_weigts[i]**2)
                #   self.add_loss(tasks_loss)

                # Last level
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                prediction_state = tf.keras.layers.concatenate([prediction_level_state, last_step_state, mid_step_state,
                                         first_step_state])

                prediction_output = self.prediction_output(prediction_state)

                if training:
                    return prediction_output, losses
                else:
                    return prediction_output


        # Build the "no smooth tasks" variant defined in this branch; every
        # hyper-parameter comes from the enclosing scope.
        model = MultitaskHLNet(
            d_model=d_model,
            n_features=n_features,
            window=window,
            horizon=horizon,
            batch_size=batch_size,
            grad_smooth_alpha=grad_smooth_alpha,
            lamb1=lamb1,
            lamb2=lamb2,
            mode=mode,
            overall=overall,
        )

        # Keras identifier of the main loss (mean squared error).
        loss = 'mse'
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Store the hyper-parameters and instantiate every sub-layer of the network.

                All constructor arguments are kept on the instance unchanged. The three
                ``*_average`` attributes start as ``None`` and are overwritten by
                ``train_step``.

                Layer attributes are assigned in a fixed order on purpose:
                ``global_first_level_layers`` and ``global_second_level_layers`` hold
                positions into ``self.layers`` that ``train_step`` uses to pick the
                weights of each level, so reordering the assignments below would
                silently select different layers.
                """
                super().__init__()
                keras_layers = tf.keras.layers

                # --- hyper-parameters and running gradient state --------------------
                self.d_model, self.n_features = d_model, n_features
                self.window, self.horizon = window, horizon
                self.batch_size = batch_size
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1, self.lamb2 = lamb1, lamb2
                self.mode, self.overall = mode, overall

                # One weight per auxiliary loss; only referenced by disabled code in `call`.
                self.loss_weigts = tf.Variable(tf.ones([8]))

                # --- shared causal convolutional embedding ---------------------------
                self.global_embedding_1 = keras_layers.Conv1D(
                    d_model // 2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = keras_layers.Conv1D(
                    d_model, 3, activation='relu', padding='causal')

                # Positions of the two embedding layers inside `self.layers`.
                self.global_first_level_layers = [0, 1]

                # --- first level tasks: sequence featurization ------------------------
                self.decode_cnn_1 = keras_layers.Conv1DTranspose(
                    d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = keras_layers.Conv1DTranspose(
                    d_model // 2, 3, activation='relu', padding='same')
                self.gap_filling_output = keras_layers.Dense(n_features)

                self.gaussian_noise = keras_layers.GaussianNoise(0.01)
                self.noise_reduction_output = keras_layers.Dense(n_features)

                # (The swap task's combinations / LSTM / softmax head are intentionally
                # not created here; see `swap_task`.)

                # --- second level tasks: forecasting helpers --------------------------
                self.second_level_lstm_layer = keras_layers.LSTM(d_model, return_sequences=True)

                smooth_keys = ('8', '6', '3')
                self.smooth_lstm_layers = dict(zip(smooth_keys, (
                    keras_layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu'),
                )))
                # A moving average of width k over `horizon` steps has horizon-k+1 values.
                self.smooth_layers = dict(zip(smooth_keys, (
                    keras_layers.Dense(horizon - 8 + 1, name='smooth_forecasting_8'),
                    keras_layers.Dense(horizon - 6 + 1, name='smooth_forecasting_6'),
                    keras_layers.Dense(horizon - 3 + 1, name='smooth_forecasting_3'),
                )))

                # Single-step heads are keyed by the forecast step they target.
                step_keys = ('0', str(horizon // 2), str(horizon - 1))
                self.one_step_forecast_lstm_layers = dict(zip(step_keys, (
                    keras_layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu'),
                )))
                self.one_step_forecast = dict(zip(step_keys, (
                    keras_layers.Dense(1, name='next_step_forecasting'),
                    keras_layers.Dense(1, name='mid_step_forecasting'),
                    keras_layers.Dense(1, name='last_step_forecasting'),
                )))

                # Position of `second_level_lstm_layer` inside `self.layers`.
                self.global_second_level_layers = [7]

                # --- last level: the main forecast -------------------------------------
                self.prediction_level_lstm_layer = keras_layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = keras_layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent integer value of ``x`` along ``axis``.

                Args:
                    x: integer tensor. Its dtype is reused both for the counts and for
                        the ``argmax`` output, so it is presumably int32 or int64 --
                        TODO confirm against the callers.
                    axis: axis over which occurrences are counted.

                Returns:
                    Tensor of modes, in the dtype of ``x``.

                NOTE(review): ``tfp`` (tensorflow_probability) is not among the imports
                in the notebook's first cell; confirm it is imported before this
                method is used, otherwise the call below raises ``NameError``.
                """
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # The value with the largest count is read off axis 0 of the count tensor
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Undo the shift to recover the original values
                return idx + m

            def gap_filling_task(self, inputs):
                """Auxiliary task: reconstruct the input after masking one value per sample.

                One time step is drawn at random (the same step for every sample of the
                batch, because the shuffle acts on the time axis of the whole index
                tensor) and, per sample, one random feature at that step is overwritten
                with the sentinel -100. The corrupted series goes through the shared
                embedding and the transposed-convolution decoder, and the Dense head is
                trained with MSE to reproduce the uncorrupted input.

                Args:
                    inputs: tensor indexed as (batch, time, feature); the time axis is
                        expected to have length ``self.window``.

                Returns:
                    Tuple ``(gap_filling_output, gap_loss, x_gap_embedding)`` where
                    ``x_gap_embedding`` is the decoder output that feeds the Dense head.
                """
                batch = tf.shape(inputs)[0]

                # Build a (batch, window, 3) tensor of [batch, time, feature] scatter indices.
                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                # NOTE(review): the feature index is drawn from [0, 3); this matches the
                # hard-coded `input_shape=(window, 3)` of the embedding -- confirm n_features == 3.
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0, maxval=3, dtype=tf.int32)

                # Squeeze only the singleton axis created by the stack. An axis-less
                # squeeze would also drop the batch (or time) axis when its size is 1
                # and break the rank-3 transpose below.
                idx = tf.squeeze(tf.stack(values=[batch_indexes, head_indexes, feat_index], axis=-1), axis=2)
                idx = tf.transpose(idx, perm=(1, 2, 0))

                # Shuffle along time, move batch back to the front, keep the first time slot.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Size the sentinel vector from the dynamic shape: the static
                # `gap_index.shape[0]` is None whenever the batch size is unknown at trace time.
                sentinel = tf.fill(tf.shape(gap_index)[:1], tf.cast(-100, x_gap.dtype))
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=sentinel)

                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)

                # The loss covers the whole window, not only the masked position.
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss, x_gap_embedding

            def noise_reduction_task(self, inputs):
                """Auxiliary task: reconstruct the clean input from a noise-perturbed copy.

                The input is passed through the ``GaussianNoise`` layer, the shared
                embedding and the transposed-convolution decoder; a Dense head is then
                scored with MSE against the untouched input and the loss is recorded
                under the metric name ``noise_loss``.

                Returns:
                    Tuple ``(noise_output, noise_loss, decoder_features)`` where
                    ``decoder_features`` is the decoder output that feeds the Dense head.
                """
                clean = tf.identity(inputs)

                hidden = self.gaussian_noise(clean)
                encoder_decoder = (
                    self.global_embedding_1,
                    self.global_embedding_2,
                    self.decode_cnn_1,
                    self.decode_cnn_2,
                )
                for stage in encoder_decoder:
                    hidden = stage(hidden)

                reconstruction = self.noise_reduction_output(hidden)

                mse = tf.keras.losses.MeanSquaredError()
                noise_loss = mse(reconstruction, clean)
                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return reconstruction, noise_loss, hidden

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast the moving average of the labels.

                Args:
                    second_level_state: sequence output of the second-level LSTM.
                    labels: forecast targets; sliding windows are taken along axis 1.
                    smooth: moving-average width; also selects the task's LSTM and
                        Dense head through the key ``str(smooth)``.

                Returns:
                    Tuple ``(smooth_forecasting, smooth_state, smooth_loss)``.
                """
                key = str(smooth)
                task_lstm = self.smooth_lstm_layers[key]
                task_head = self.smooth_layers[key]

                smooth_state = task_lstm(second_level_state)
                smooth_forecasting = task_head(smooth_state)

                # Windows of length `smooth` with stride 1 along axis 1, averaged per window.
                label_windows = tf.signal.frame(labels, smooth, 1, axis=1)
                smoothed_labels = tf.math.reduce_mean(label_windows, axis=2)

                mse = tf.keras.losses.MeanSquaredError()
                smooth_loss = mse(smooth_forecasting, smoothed_labels)
                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self,second_level_state, labels, step=0):
                """Auxiliary task: forecast the single label at position ``step``.

                ``str(step)`` selects the task-specific LSTM and Dense head; the head's
                output is scored with MSE against ``labels[:, step, :]`` and the loss is
                recorded under the metric name ``step_{step}_loss``.

                Returns:
                    Tuple ``(step_forecast, step_state, step_loss)``.
                """
                key = str(step)
                task_lstm = self.one_step_forecast_lstm_layers[key]
                task_head = self.one_step_forecast[key]

                step_state = task_lstm(second_level_state)
                step_forecast = task_head(step_state)

                target = labels[:, step, :]
                mse = tf.keras.losses.MeanSquaredError()
                step_loss = mse(step_forecast, target)
                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Auxiliary day-of-week forecasting loss (not invoked by ``call``).

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 0 of
                ``labels_extra`` (presumably the day-of-week feature) along axis 1, and
                record the loss as the metric ``week_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``day_of_week_forecasting``, so the call on the next line resolves to this
                method itself and omits ``labels_extra``; it raises ``TypeError``. A
                dedicated output layer stored under a different attribute name is
                presumably what was intended -- confirm before enabling this task.
                """
                week_forecast = self.day_of_week_forecasting(second_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(labels_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Auxiliary day-of-week loss on the input window (not invoked by ``call``).

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 0 of
                ``inputs_extra`` (presumably the day-of-week feature) along axis 1, and
                record the loss as the metric ``week_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``day_of_week_sequence``, so the call below resolves to this method itself
                and omits ``inputs_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                # TODO: ONEHOT
                week_forecast = self.day_of_week_sequence(first_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(inputs_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_sequence_loss', aggregation='mean') 

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Auxiliary quarter forecasting loss (not invoked by ``call``).

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 2 of
                ``labels_extra`` (presumably the quarter feature) along axis 1, and
                record the loss as the metric ``quarter_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``quarter_forecasting``, so the call below resolves to this method itself
                and omits ``labels_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                quarter_forecast = self.quarter_forecasting(second_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(labels_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Auxiliary quarter loss on the input window (not invoked by ``call``).

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 2 of
                ``inputs_extra`` (presumably the quarter feature) along axis 1, and
                record the loss as the metric ``quarter_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``quarter_sequence``, so the call below resolves to this method itself and
                omits ``inputs_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                quarter_forecast = self.quarter_sequence(first_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(inputs_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_sequence_loss', aggregation='mean') 


            def month_forecasting(self, second_level_state, labels_extra):
                """Auxiliary month forecasting loss (not invoked by ``call``).

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 1 of
                ``labels_extra`` (presumably the month feature) along axis 1, and record
                the loss as the metric ``month_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``month_forecasting``, so the call below resolves to this method itself
                and omits ``labels_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                month_forecast = self.month_forecasting(second_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(labels_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_forecasting_loss', aggregation='mean')


            def month_sequence(self, first_level_state, inputs_extra):
                """Auxiliary month loss on the input window (not invoked by ``call``).

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 1 of
                ``inputs_extra`` (presumably the month feature) along axis 1, and record
                the loss as the metric ``month_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named ``month_sequence``,
                so the call below resolves to this method itself and omits
                ``inputs_extra``; it raises ``TypeError``. Confirm the intended output
                layer before enabling this task.
                """
                month_forecast = self.month_sequence(first_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(inputs_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_sequence_loss', aggregation='mean')                

            def swap_task(self, inputs):
                """Auxiliary task: classify which permutation of four window chunks was applied.

                The time indices of each sample are reshaped into 4 chunks, the chunks are
                reordered according to a randomly drawn entry of ``self.swap_combinations``,
                and a softmax head is trained with categorical cross-entropy to recover
                the index of that entry.

                NOTE(review): this method is currently disabled (its call in ``call`` is
                commented out) and cannot run as-is:
                  * ``self.swap_combinations``, ``self.swap_lstm_layer`` and
                    ``self.swap_output`` are only present as commented-out lines in
                    ``__init__``; ``self.embedding`` and ``self.first_level_lstm_layer``
                    are not assigned there at all.
                  * ``.numpy()`` is only available on eager tensors, so the method cannot
                    be traced into a graph.
                  * ``window`` is read from the enclosing scope rather than ``self.window``.

                Returns:
                    Tuple ``(swap_output, swap_state)``.
                """

                x_swap = tf.identity(inputs)

                # One row of time indices [0, window) per sample.
                idx = tf.tile([tf.range(0, window)], (tf.shape(x_swap)[0], 1))

                # Draw, per sample, the index of the permutation to apply.
                swap_index = tf.random.uniform((tf.shape(x_swap)[0], 1), 0, len(self.swap_combinations), dtype=tf.int64)

                swap_index_onehot = tf.one_hot(swap_index, len(self.swap_combinations))

                # Look up the drawn permutation for every sample.
                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                idx_numpy = idx.numpy()
                # Split each row of indices into 4 chunks and concatenate the chunks in the
                # order given by the sample's permutation (the reshape needs a window length
                # divisible by 4).
                idx_numpy = np.concatenate((idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 0].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 1].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 2].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 3].squeeze()]), axis=1)

                # Reorder the time axis of every sample with its own permuted index row.
                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                # Arguments are (y_true, y_pred): the one-hot permutation index vs. the softmax output.
                swap_loss = tf.keras.losses.CategoricalCrossentropy()(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state

            def train_step(self, data):
                """Run one optimisation step with per-level gradient combination.

                The forward pass yields the main forecast plus a list of auxiliary
                losses (the first two belong to the first-level tasks, the rest to the
                second-level tasks). Gradients are then computed per weight group:

                * first-level weights (``self.layers`` at ``global_first_level_layers``):
                  every auxiliary gradient is passed through ``combined_grads`` together
                  with a reference gradient and its running average;
                * second-level weights (``global_second_level_layers``): the same, with
                  the main-loss gradient as reference;
                * every remaining layer: raw gradients of each auxiliary loss and of the
                  main loss.

                ``update_smooth_grad`` and ``combined_grads`` are defined elsewhere in the
                notebook; presumably they maintain an exponential average and
                project/weight conflicting gradients -- verify against their definitions.

                All ``(gradient, variable)`` pairs are handed to the optimizer in a single
                ``apply_gradients`` call; a variable can appear several times in that list.

                Args:
                    data: ``(inputs, labels)`` as delivered by ``fit``.

                Returns:
                    Dict mapping each metric name to its current result.
                """
                inputs, labels = data
                gradients = []

                # Persistent: the tape is queried once per (loss, weight group) pair below.
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                model_layers = np.array(self.layers)

                # ---- First level: auxiliary losses of the featurization tasks ----------
                # Use this instance's own layer indices and hyper-parameters instead of
                # the enclosing scope's `model` / closure variables.
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(
                    main_grad_first_level, self.main_grad_first_level_average, self.grad_smooth_alpha)

                first_level_gradients = []
                for aux_loss in losses[:2]:
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average,
                                              aux_grad, self.mode, self.overall, self.lamb1)

                    first_level_gradients.append(aux_grad)
                    gradients.extend(zip(aux_grad, first_level_trainable_weights))

                # Merge the two first-level task gradients; the merged gradient is the
                # reference for the second-level losses on the first-level weights.
                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1],
                                                      'Multitask', self.overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(
                    first_level_aux_grad, self.first_level_aux_grad_average, self.grad_smooth_alpha)

                # ---- Second level: forecasting-helper losses ----------------------------
                # Their effect on the first-level weights ...
                for aux_loss in losses[2:]:
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average,
                                              aux_grad, self.mode, self.overall, self.lamb2)
                    gradients.extend(zip(aux_grad, first_level_trainable_weights))

                # ... and on the second-level weights, with the main loss as reference.
                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(
                    main_grad_second_level, self.main_grad_second_level_average, self.grad_smooth_alpha)

                for aux_loss in losses[2:]:
                    aux_grad = tape.gradient(aux_loss, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average,
                                              aux_grad, self.mode, self.overall, self.lamb2)
                    gradients.extend(zip(aux_grad, second_level_trainable_weights))

                # ---- Every other layer: plain gradients of all losses -------------------
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for aux_loss in losses:
                    grad = tape.gradient(aux_loss, all_level_trainable_weights)
                    gradients.extend(zip(grad, all_level_trainable_weights))

                gradients.extend(zip(main_grad_rest_layers, all_level_trainable_weights))

                # A persistent tape holds its resources until it is dropped explicitly.
                del tape

                # Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}
            
            def call(self, inputs, training):
                """Forward pass through the three task levels.

                Args:
                    inputs: pair ``((series, series_extra), (labels, labels_extra))``.
                        The labels are part of the model input because the auxiliary
                        losses are computed inside the forward pass; the ``*_extra``
                        tensors are unpacked but not used here.
                    training: truthy to additionally return the auxiliary losses.

                Returns:
                    ``(prediction_output, losses)`` when ``training`` is truthy, otherwise
                    only ``prediction_output``. ``losses`` is ordered as
                    ``[gap, noise, first step, mid step, last step, smooth 8, smooth 6,
                    smooth 3]``; ``train_step`` slices it by position.
                """
                labels, labels_extra = inputs[1]
                series, series_extra = inputs[0]

                # Shared embedding of the raw window.
                embedded = self.global_embedding_2(self.global_embedding_1(series))

                # First level: only the losses and decoder features are used downstream.
                _, gap_loss, gap_features = self.gap_filling_task(series)
                _, noise_loss, noise_features = self.noise_reduction_task(series)

                fused = tf.keras.layers.concatenate([embedded, gap_features, noise_features])

                # Second level: one recurrent state shared by all forecasting helpers.
                second_level_state = self.second_level_lstm_layer(fused)

                _, smooth_state_8, smooth_loss_8 = self.smooth_forecast_task(second_level_state, labels, 8)
                _, smooth_state_6, smooth_loss_6 = self.smooth_forecast_task(second_level_state, labels, 6)
                _, smooth_state_3, smooth_loss_3 = self.smooth_forecast_task(second_level_state, labels, 3)

                _, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                _, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                _, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)

                # Uncertainty-style re-weighting with `self.loss_weigts` is disabled;
                # the raw task losses are returned as they are.
                losses = [
                    gap_loss, noise_loss,
                    first_step_loss, mid_step_loss, last_step_loss,
                    smooth_loss_8, smooth_loss_6, smooth_loss_3,
                ]

                # Last level: the main forecast sees its own state plus every helper state.
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                helper_states = [last_step_state, mid_step_state, first_step_state,
                                 smooth_state_3, smooth_state_6, smooth_state_8]
                prediction_state = tf.keras.layers.concatenate([prediction_level_state] + helper_states)

                prediction_output = self.prediction_output(prediction_state)

                if not training:
                    return prediction_output
                return prediction_output, losses


        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall)

        loss = 'mse'
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended_isolated':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Store the hyper-parameters and instantiate every sub-layer of the network.

                All constructor arguments are kept on the instance unchanged. The three
                ``*_average`` attributes start as ``None`` and are overwritten by
                ``train_step``.

                Layer attributes are assigned in a fixed order on purpose:
                ``global_first_level_layers`` and ``global_second_level_layers`` hold
                positions into ``self.layers`` that ``train_step`` uses to pick the
                weights of each level, so reordering the assignments below would
                silently select different layers.
                """
                super().__init__()
                keras_layers = tf.keras.layers

                # --- hyper-parameters and running gradient state --------------------
                self.d_model, self.n_features = d_model, n_features
                self.window, self.horizon = window, horizon
                self.batch_size = batch_size
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1, self.lamb2 = lamb1, lamb2
                self.mode, self.overall = mode, overall

                # One weight per auxiliary loss.
                self.loss_weigts = tf.Variable(tf.ones([8]))

                # --- shared causal convolutional embedding ---------------------------
                self.global_embedding_1 = keras_layers.Conv1D(
                    d_model // 2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = keras_layers.Conv1D(
                    d_model, 3, activation='relu', padding='causal')

                # Positions of the two embedding layers inside `self.layers`.
                self.global_first_level_layers = [0, 1]

                # --- first level tasks: sequence featurization ------------------------
                self.decode_cnn_1 = keras_layers.Conv1DTranspose(
                    d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = keras_layers.Conv1DTranspose(
                    d_model // 2, 3, activation='relu', padding='same')
                self.gap_filling_output = keras_layers.Dense(n_features)

                self.gaussian_noise = keras_layers.GaussianNoise(0.01)
                self.noise_reduction_output = keras_layers.Dense(n_features)

                # (The swap task's combinations / LSTM / softmax head are intentionally
                # not created here; see `swap_task`.)

                # --- second level tasks: forecasting helpers --------------------------
                self.second_level_lstm_layer = keras_layers.LSTM(d_model, return_sequences=True)

                smooth_keys = ('8', '6', '3')
                self.smooth_lstm_layers = dict(zip(smooth_keys, (
                    keras_layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu'),
                )))
                # A moving average of width k over `horizon` steps has horizon-k+1 values.
                self.smooth_layers = dict(zip(smooth_keys, (
                    keras_layers.Dense(horizon - 8 + 1, name='smooth_forecasting_8'),
                    keras_layers.Dense(horizon - 6 + 1, name='smooth_forecasting_6'),
                    keras_layers.Dense(horizon - 3 + 1, name='smooth_forecasting_3'),
                )))

                # Single-step heads are keyed by the forecast step they target.
                step_keys = ('0', str(horizon // 2), str(horizon - 1))
                self.one_step_forecast_lstm_layers = dict(zip(step_keys, (
                    keras_layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                    keras_layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu'),
                )))
                self.one_step_forecast = dict(zip(step_keys, (
                    keras_layers.Dense(1, name='next_step_forecasting'),
                    keras_layers.Dense(1, name='mid_step_forecasting'),
                    keras_layers.Dense(1, name='last_step_forecasting'),
                )))

                # Position of `second_level_lstm_layer` inside `self.layers`.
                self.global_second_level_layers = [7]

                # --- last level: the main forecast -------------------------------------
                self.prediction_level_lstm_layer = keras_layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = keras_layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent integer value of ``x`` along ``axis``.

                Args:
                    x: integer tensor. Its dtype is reused both for the counts and for
                        the ``argmax`` output, so it is presumably int32 or int64 --
                        TODO confirm against the callers.
                    axis: axis over which occurrences are counted.

                Returns:
                    Tensor of modes, in the dtype of ``x``.

                NOTE(review): ``tfp`` (tensorflow_probability) is not among the imports
                in the notebook's first cell; confirm it is imported before this
                method is used, otherwise the call below raises ``NameError``.
                """
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # The value with the largest count is read off axis 0 of the count tensor
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Undo the shift to recover the original values
                return idx + m

            def gap_filling_task(self, inputs):
                """Auxiliary task: reconstruct the input after masking one value per sample.

                One time step is drawn at random (the same step for every sample of the
                batch, because the shuffle acts on the time axis of the whole index
                tensor) and, per sample, one random feature at that step is overwritten
                with the sentinel -100. The corrupted series goes through the shared
                embedding and the transposed-convolution decoder, and the Dense head is
                trained with MSE to reproduce the uncorrupted input.

                Args:
                    inputs: tensor indexed as (batch, time, feature); the time axis is
                        expected to have length ``self.window``.

                Returns:
                    Tuple ``(gap_filling_output, gap_loss, x_gap_embedding)`` where
                    ``x_gap_embedding`` is the decoder output that feeds the Dense head.
                """
                batch = tf.shape(inputs)[0]

                # Build a (batch, window, 3) tensor of [batch, time, feature] scatter indices.
                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                # NOTE(review): the feature index is drawn from [0, 3); this matches the
                # hard-coded `input_shape=(window, 3)` of the embedding -- confirm n_features == 3.
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0, maxval=3, dtype=tf.int32)

                # Squeeze only the singleton axis created by the stack. An axis-less
                # squeeze would also drop the batch (or time) axis when its size is 1
                # and break the rank-3 transpose below.
                idx = tf.squeeze(tf.stack(values=[batch_indexes, head_indexes, feat_index], axis=-1), axis=2)
                idx = tf.transpose(idx, perm=(1, 2, 0))

                # Shuffle along time, move batch back to the front, keep the first time slot.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Size the sentinel vector from the dynamic shape: the static
                # `gap_index.shape[0]` is None whenever the batch size is unknown at trace time.
                sentinel = tf.fill(tf.shape(gap_index)[:1], tf.cast(-100, x_gap.dtype))
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=sentinel)

                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)

                # The loss covers the whole window, not only the masked position.
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss, x_gap_embedding

            def noise_reduction_task(self, inputs):
                """Auxiliary task: reconstruct the clean input from a noise-perturbed copy.

                The input is passed through the ``GaussianNoise`` layer, the shared
                embedding and the transposed-convolution decoder; a Dense head is then
                scored with MSE against the untouched input and the loss is recorded
                under the metric name ``noise_loss``.

                Returns:
                    Tuple ``(noise_output, noise_loss, decoder_features)`` where
                    ``decoder_features`` is the decoder output that feeds the Dense head.
                """
                clean = tf.identity(inputs)

                hidden = self.gaussian_noise(clean)
                encoder_decoder = (
                    self.global_embedding_1,
                    self.global_embedding_2,
                    self.decode_cnn_1,
                    self.decode_cnn_2,
                )
                for stage in encoder_decoder:
                    hidden = stage(hidden)

                reconstruction = self.noise_reduction_output(hidden)

                mse = tf.keras.losses.MeanSquaredError()
                noise_loss = mse(reconstruction, clean)
                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return reconstruction, noise_loss, hidden

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast the moving average of the labels.

                Args:
                    second_level_state: sequence output of the second-level LSTM.
                    labels: forecast targets; sliding windows are taken along axis 1.
                    smooth: moving-average width; also selects the task's LSTM and
                        Dense head through the key ``str(smooth)``.

                Returns:
                    Tuple ``(smooth_forecasting, smooth_state, smooth_loss)``.
                """
                key = str(smooth)
                task_lstm = self.smooth_lstm_layers[key]
                task_head = self.smooth_layers[key]

                smooth_state = task_lstm(second_level_state)
                smooth_forecasting = task_head(smooth_state)

                # Windows of length `smooth` with stride 1 along axis 1, averaged per window.
                label_windows = tf.signal.frame(labels, smooth, 1, axis=1)
                smoothed_labels = tf.math.reduce_mean(label_windows, axis=2)

                mse = tf.keras.losses.MeanSquaredError()
                smooth_loss = mse(smooth_forecasting, smoothed_labels)
                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self,second_level_state, labels, step=0):
                """Auxiliary task: forecast the single label at position ``step``.

                ``str(step)`` selects the task-specific LSTM and Dense head; the head's
                output is scored with MSE against ``labels[:, step, :]`` and the loss is
                recorded under the metric name ``step_{step}_loss``.

                Returns:
                    Tuple ``(step_forecast, step_state, step_loss)``.
                """
                key = str(step)
                task_lstm = self.one_step_forecast_lstm_layers[key]
                task_head = self.one_step_forecast[key]

                step_state = task_lstm(second_level_state)
                step_forecast = task_head(step_state)

                target = labels[:, step, :]
                mse = tf.keras.losses.MeanSquaredError()
                step_loss = mse(step_forecast, target)
                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Auxiliary day-of-week forecasting loss.

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 0 of
                ``labels_extra`` (presumably the day-of-week feature) along axis 1, and
                record the loss as the metric ``week_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``day_of_week_forecasting``, so the call on the next line resolves to this
                method itself and omits ``labels_extra``; it raises ``TypeError``. A
                dedicated output layer stored under a different attribute name is
                presumably what was intended -- confirm before enabling this task.
                """
                week_forecast = self.day_of_week_forecasting(second_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(labels_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Auxiliary day-of-week loss on the input window.

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 0 of
                ``inputs_extra`` (presumably the day-of-week feature) along axis 1, and
                record the loss as the metric ``week_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``day_of_week_sequence``, so the call below resolves to this method itself
                and omits ``inputs_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                # TODO: ONEHOT
                week_forecast = self.day_of_week_sequence(first_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(inputs_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_sequence_loss', aggregation='mean') 

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Auxiliary quarter forecasting loss.

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 2 of
                ``labels_extra`` (presumably the quarter feature) along axis 1, and
                record the loss as the metric ``quarter_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``quarter_forecasting``, so the call below resolves to this method itself
                and omits ``labels_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                quarter_forecast = self.quarter_forecasting(second_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(labels_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Auxiliary quarter loss on the input window.

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 2 of
                ``inputs_extra`` (presumably the quarter feature) along axis 1, and
                record the loss as the metric ``quarter_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``quarter_sequence``, so the call below resolves to this method itself and
                omits ``inputs_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                quarter_forecast = self.quarter_sequence(first_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(inputs_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_sequence_loss', aggregation='mean') 


            def month_forecasting(self, second_level_state, labels_extra):
                """Auxiliary month forecasting loss.

                As written: take the last time step of ``second_level_state``, map it to
                a forecast, score it with MSE against ``get_mode`` of channel 1 of
                ``labels_extra`` (presumably the month feature) along axis 1, and record
                the loss as the metric ``month_forecasting_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named
                ``month_forecasting``, so the call below resolves to this method itself
                and omits ``labels_extra``; it raises ``TypeError``. Confirm the intended
                output layer before enabling this task.
                """
                month_forecast = self.month_forecasting(second_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(labels_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_forecasting_loss', aggregation='mean')


            def month_sequence(self, first_level_state, inputs_extra):
                """Auxiliary month loss on the input window.

                As written: take the last time step of ``first_level_state``, map it to a
                prediction, score it with MSE against ``get_mode`` of channel 1 of
                ``inputs_extra`` (presumably the month feature) along axis 1, and record
                the loss as the metric ``month_sequence_loss``. Returns None.

                NOTE(review): ``__init__`` assigns no attribute named ``month_sequence``,
                so the call below resolves to this method itself and omits
                ``inputs_extra``; it raises ``TypeError``. Confirm the intended output
                layer before enabling this task.
                """
                month_forecast = self.month_sequence(first_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(inputs_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_sequence_loss', aggregation='mean')                

            def swap_task(self, inputs):
                """Auxiliary task: classify which permutation of four window chunks was applied.

                The time indices of each sample are reshaped into 4 chunks, the chunks are
                reordered according to a randomly drawn entry of ``self.swap_combinations``,
                and a softmax head is trained with categorical cross-entropy to recover
                the index of that entry.

                NOTE(review): this method cannot run as-is:
                  * ``self.swap_combinations``, ``self.swap_lstm_layer`` and
                    ``self.swap_output`` are only present as commented-out lines in
                    ``__init__``; ``self.embedding`` and ``self.first_level_lstm_layer``
                    are not assigned there at all.
                  * ``.numpy()`` is only available on eager tensors, so the method cannot
                    be traced into a graph.
                  * ``window`` is read from the enclosing scope rather than ``self.window``.

                Returns:
                    Tuple ``(swap_output, swap_state)``.
                """

                x_swap = tf.identity(inputs)

                # One row of time indices [0, window) per sample.
                idx = tf.tile([tf.range(0, window)], (tf.shape(x_swap)[0], 1))

                # Draw, per sample, the index of the permutation to apply.
                swap_index = tf.random.uniform((tf.shape(x_swap)[0], 1), 0, len(self.swap_combinations), dtype=tf.int64)

                swap_index_onehot = tf.one_hot(swap_index, len(self.swap_combinations))

                # Look up the drawn permutation for every sample.
                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                idx_numpy = idx.numpy()
                # Split each row of indices into 4 chunks and concatenate the chunks in the
                # order given by the sample's permutation (the reshape needs a window length
                # divisible by 4).
                idx_numpy = np.concatenate((idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 0].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 1].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 2].squeeze()],
                            idx_numpy.reshape(idx_numpy.shape[0], 4, -1)[np.arange(idx_numpy.shape[0]), swap_numpy[..., 3].squeeze()]), axis=1)

                # Reorder the time axis of every sample with its own permuted index row.
                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                # Arguments are (y_true, y_pred): the one-hot permutation index vs. the softmax output.
                swap_loss = tf.keras.losses.CategoricalCrossentropy()(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state

            def train_step(self, data):
                """Custom Keras training step combining main and auxiliary gradients per layer group.

                ``data`` is ``(inputs, labels)``. The forward pass (``self(inputs, True)``)
                returns the main prediction and the list of auxiliary task losses; the
                first two entries are treated as first-level losses, the rest as
                second-level losses.

                Gradients are assembled in three groups:
                  * first-level layers (``self.global_first_level_layers``): every
                    auxiliary gradient is passed through ``combined_grads`` together
                    with a reference gradient and its running average;
                  * second-level layers (``self.global_second_level_layers``): same
                    treatment for the second-level losses;
                  * every remaining layer: raw gradients of each auxiliary loss and of
                    the main loss.
                ``update_smooth_grad`` / ``combined_grads`` are defined elsewhere in the
                notebook; ``grad_smooth_alpha``, ``mode``, ``overall``, ``lamb1`` and
                ``lamb2`` are read from the enclosing scope.

                Returns:
                    dict mapping metric name to its current result.
                """
                inputs, labels = data
                gradients = []

                # Persistent tape: it is queried once per (loss, weight group) pair below.
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                # Object array so that layers can be selected with index lists / boolean masks.
                model_layers = np.array(self.layers)

                # First level loss
                # Use the instance's own layer indices (previously read from the outer-scope `model`).
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(main_grad_first_level, self.main_grad_first_level_average, grad_smooth_alpha)

                first_level_gradients = []
                for aux_loss in losses[:2]:
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average, aux_grad, mode, overall, lamb1)

                    first_level_gradients.append(aux_grad)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                # Merge the two first-level task gradients; this is the reference for the second-level losses.
                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1], 'Multitask', overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(first_level_aux_grad, self.first_level_aux_grad_average, grad_smooth_alpha)

                # Second level loss, first applied to the first-level weights ...
                for aux_loss in losses[2:]:
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average, aux_grad, mode, overall, lamb2)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                # ... then to the second-level weights, against the main-loss gradient.
                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(main_grad_second_level, self.main_grad_second_level_average, grad_smooth_alpha)

                for aux_loss in losses[2:]:
                    aux_grad = tape.gradient(aux_loss, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average, aux_grad, mode, overall, lamb2)

                    gradients.extend(list(zip(aux_grad, second_level_trainable_weights)))

                # All losses: layers that belong to neither shared group get plain gradients.
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for aux_loss in losses:
                    grad = tape.gradient(aux_loss, all_level_trainable_weights)
                    gradients.extend(list(zip(grad, all_level_trainable_weights)))

                gradients.extend(list(zip(main_grad_rest_layers, all_level_trainable_weights)))

                # A persistent tape keeps its resources until the reference is dropped.
                del tape

                #Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}
            
            def call(self, inputs, training):
                """Forward pass through the three levels of the hierarchy.

                ``inputs`` is ``((x, x_extra), (labels, labels_extra))``: the labels
                travel with the inputs because the auxiliary losses are computed here.

                Level 1 embeds ``x`` and runs the gap-filling and noise-reduction tasks.
                Level 2 runs an LSTM over the concatenated embeddings and feeds the
                smoothed / single-step forecasting tasks. Level 3 produces the final
                forecast from the (gradient-stopped) states of level 2. Every auxiliary
                loss is re-weighted with ``self.loss_weigts`` as
                ``0.5 / w**2 * loss + log(1 + w**2)``.

                Returns:
                    ``(prediction_output, losses)`` when ``training`` is truthy,
                    otherwise only ``prediction_output``.
                """
                def _frozen(tensor):
                    # Copy of `tensor` through which no gradient flows back.
                    return tf.stop_gradient(tf.identity(tensor))

                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                x = self.global_embedding_2(self.global_embedding_1(inputs))

                # First level
                gap_filling_outputs, gap_loss, x_gap_embedding = self.gap_filling_task(inputs)
                noise_reduction_outputs, noise_loss, x_noise_embedding = self.noise_reduction_task(inputs)

                x = _frozen(tf.keras.layers.concatenate([x, x_gap_embedding, x_noise_embedding]))

                # Second level
                second_level_state = self.second_level_lstm_layer(x)
                smooth_forecasting_8, smooth_state_8, smooth_loss_8 = self.smooth_forecast_task(second_level_state, labels, 8)
                smooth_forecasting_6, smooth_state_6, smooth_loss_6 = self.smooth_forecast_task(second_level_state, labels, 6)
                smooth_forecasting_3, smooth_state_3, smooth_loss_3 = self.smooth_forecast_task(second_level_state, labels, 3)

                first_step_forecast, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                mid_step_forecast, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                last_step_forecast, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)

                raw_losses = [gap_loss, noise_loss, first_step_loss, mid_step_loss, last_step_loss,
                              smooth_loss_8, smooth_loss_6, smooth_loss_3]

                # Weight every task loss with its own entry of self.loss_weigts.
                losses = [
                    (0.5/self.loss_weigts[i]**2)*task_loss + tf.math.log(1+self.loss_weigts[i]**2)
                    for i, task_loss in enumerate(raw_losses)
                ]

                # Last level
                prediction_level_state = self.prediction_level_lstm_layer(_frozen(second_level_state))
                prediction_state = tf.keras.layers.concatenate([
                    prediction_level_state,
                    _frozen(last_step_state),
                    _frozen(mid_step_state),
                    _frozen(first_step_state),
                    _frozen(smooth_state_3),
                    _frozen(smooth_state_6),
                    _frozen(smooth_state_8),
                ])

                prediction_output = self.prediction_output(prediction_state)

                if not training:
                    return prediction_output
                return prediction_output, losses


        # Build the network variant selected by this `model_type` branch.
        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall)

        # Loss identifier for the main forecast (presumably handed to `compile` further down - not visible here).
        loss = 'mse'
    
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended_lossweighted':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):
            """Hierarchical multitask forecaster with a CNN first level and LSTM task heads.

            Level 1 embeds the input window with two causal convolutions and trains
            them through gap-filling and noise-reduction reconstruction tasks.
            Level 2 runs an LSTM over the concatenated embeddings and feeds six
            auxiliary forecasting tasks (three smoothed horizons, three single steps).
            Level 3 produces the final ``horizon``-step forecast.

            ``train_step`` is overridden: auxiliary gradients are combined with the
            main-loss gradients through ``combined_grads`` / ``update_smooth_grad``
            (defined elsewhere in the notebook), weighted by ``self.loss_weights``.
            """

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Create all layers.

                The order of the attribute assignments matters: ``train_step`` selects
                layers by their position in ``self.layers`` through
                ``global_first_level_layers`` / ``global_second_level_layers``.

                Args:
                    d_model: width of the embeddings and of every LSTM.
                    n_features: number of input features reconstructed by level 1.
                    window: length of the input sequence.
                    horizon: number of forecast steps (must be >= 8 for the smooth-8 head).
                    batch_size: stored only; not used by the layers created here.
                    grad_smooth_alpha: forwarded to ``update_smooth_grad``.
                    lamb1, lamb2: stored only in this variant (``loss_weights`` is used instead).
                    mode, overall: forwarded to ``combined_grads``.
                """
                super(MultitaskHLNet, self).__init__()
                self.d_model = d_model
                self.n_features = n_features
                self.window = window
                self.horizon = horizon
                self.batch_size = batch_size
                # Running averages of gradients; filled in by train_step.
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1 = lamb1
                self.lamb2 = lamb2
                self.mode = mode
                self.overall = overall

                # One weight per auxiliary loss, in the order returned by `call`.
                self.loss_weights = tf.Variable(tf.ones([8]))

                self.global_embedding_1 = tf.keras.layers.Conv1D(d_model//2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = tf.keras.layers.Conv1D(d_model, 3, activation='relu', padding='causal')

                # Positions (in self.layers) of the layers treated as "first level" by train_step.
                # NOTE(review): the sibling variant '..._swapappended' uses [0, 1] here and 9
                # (two extra swap layers) for the second level, which suggests these indices
                # and the [8] below may be off by one - confirm against `model.layers`.
                self.global_first_level_layers = [1,2]

                # --- First Level tasks - Sequence featurization ---
                self.decode_cnn_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gap_filling_output = tf.keras.layers.Dense(n_features)

                self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
                self.noise_reduction_output = tf.keras.layers.Dense(n_features)

                # The swap task is disabled in this variant:
                #self.swap_combinations = list(itertools.permutations(np.arange(0, 4), 4))
                #self.swap_lstm_layer = tf.keras.layers.LSTM(d_model, name='swap_lstm', activation='relu')
                #self.swap_output = tf.keras.layers.Dense(len(self.swap_combinations), name='swap_task', activation='softmax')

                # --- Second Level tasks - Forecasting helpers ---
                self.second_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                # Keyed by the smoothing width (as a string).
                self.smooth_lstm_layers = {'8': tf.keras.layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                                '6':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                                '3':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu')}

                # A moving average of width k over `horizon` steps has horizon-k+1 values.
                self.smooth_layers = {'8': tf.keras.layers.Dense(horizon-8+1, name='smooth_forecasting_8'),
                                '6': tf.keras.layers.Dense(horizon-6+1, name='smooth_forecasting_6'),
                                '3': tf.keras.layers.Dense(horizon-3+1, name='smooth_forecasting_3')}

                # Keyed by the forecast step index (as a string).
                self.one_step_forecast_lstm_layers = {'0': tf.keras.layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                                str(horizon//2):  tf.keras.layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                                str(horizon-1):  tf.keras.layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu')}

                self.one_step_forecast = {'0': tf.keras.layers.Dense(1, name='next_step_forecasting'),
                                    str(horizon//2): tf.keras.layers.Dense(1, name='mid_step_forecasting'),
                                    str(horizon-1): tf.keras.layers.Dense(1, name='last_step_forecasting')}
                # Position (in self.layers) of the layer treated as "second level" by train_step.
                self.global_second_level_layers = [8]

                # --- Last Level tasks ---
                self.prediction_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = tf.keras.layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent integer value of ``x`` along ``axis``.

                ``x`` must have an integer dtype: it is used both for the counts and
                as the ``output_type`` of ``argmax``. Relies on ``tfp``
                (tensorflow_probability) being in scope; it is not among the imports
                in the first notebook cell - TODO confirm it is imported elsewhere.
                """
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # Find the values with largest counts (counts are laid out along axis 0)
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Get the modes by shifting by the subtracted minimum
                return idx + m

            def gap_filling_task(self, inputs):
                """Mask one cell per sample with -100 and reconstruct the full window.

                One time step is picked for the whole batch (the shuffle permutes whole
                time slices) and, for every sample, one random feature at that step is
                overwritten with -100. The masked window goes through the shared
                embeddings and the transposed-convolution decoder; the reconstruction
                is compared with the unmasked input over all positions.

                Returns:
                    (gap_filling_output, gap_loss, x_gap_embedding): the reconstruction,
                    its MSE (also registered as the ``gap_loss`` metric) and the decoder
                    output it was computed from.
                """
                batch = tf.shape(inputs)[0]

                #Gap filling task
                batch_indexes = tf.tile(tf.range(batch)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (batch, 1, 1))
                # Feature index drawn from [0, 3): hard-coded to the three input features
                # also assumed by global_embedding_1 (input_shape=(window, 3)).
                feat_index = tf.random.uniform((batch, self.window, 1), minval=0, maxval=3, dtype=tf.int32)

                # (batch, window, 1, 3) -> (batch, window, 3). Only the stacked singleton axis is
                # squeezed so that a batch (or window) of size 1 keeps its dimension.
                idx = tf.squeeze(tf.stack(values=[batch_indexes, head_indexes, feat_index], axis=-1), axis=2)
                idx = tf.transpose(idx, perm=(1,2,0))

                # Shuffle along the time axis and keep the first slice: one (batch, step, feature) triple per sample.
                gap_index = tf.reshape(tf.transpose(tf.random.shuffle(idx), perm=(2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # Dynamic shape: the static row count is unknown when the batch size is not fixed.
                gap_values = -tf.ones(tf.shape(gap_index)[:1]) * 100
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices=gap_index, updates=gap_values)
                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss, x_gap_embedding

            def noise_reduction_task(self, inputs):
                """Denoising task: reconstruct the clean window from a noised copy.

                The input is passed through ``self.gaussian_noise``, the shared
                embeddings and the decoder; the reconstruction is compared with the
                original input.

                Returns:
                    (noise_output, noise_loss, x_noise_embedding); the loss is also
                    registered as the ``noise_loss`` metric.
                """
                input_copy = tf.identity(inputs)
                x_noise = self.gaussian_noise(input_copy)

                x_noise_embedding = self.global_embedding_1(x_noise)
                x_noise_embedding = self.global_embedding_2(x_noise_embedding)
                x_noise_embedding = self.decode_cnn_1(x_noise_embedding)
                x_noise_embedding = self.decode_cnn_2(x_noise_embedding)
                noise_output = self.noise_reduction_output(x_noise_embedding)

                noise_loss = tf.keras.losses.MeanSquaredError()(noise_output, input_copy)

                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return noise_output, noise_loss, x_noise_embedding

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Forecast the moving average (width ``smooth``) of the label sequence.

                ``smooth`` must be one of the keys created in ``__init__`` (8, 6 or 3).

                Returns:
                    (smooth_forecasting, smooth_state, smooth_loss); the loss is also
                    registered as the ``smooth_<smooth>_loss`` metric.
                """
                smooth_state = self.smooth_lstm_layers[str(smooth)](second_level_state)
                smooth_forecasting = self.smooth_layers[str(smooth)](smooth_state)
                # Sliding windows of width `smooth` with stride 1 along the time axis, averaged.
                labels_smooth = tf.math.reduce_mean(tf.signal.frame(labels, smooth, 1, axis=1), axis=2)

                smooth_loss = tf.keras.losses.MeanSquaredError()(smooth_forecasting, labels_smooth)

                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self,second_level_state, labels, step=0):
                """Forecast the single label at position ``step`` of the horizon.

                ``step`` must be one of the keys created in ``__init__``
                (0, horizon//2 or horizon-1).

                Returns:
                    (step_forecast, step_state, step_loss); the loss is also
                    registered as the ``step_<step>_loss`` metric.
                """
                step_state = self.one_step_forecast_lstm_layers[str(step)](second_level_state)
                step_forecast = self.one_step_forecast[str(step)](step_state)
                step_loss = tf.keras.losses.MeanSquaredError()(step_forecast, labels[:, step, :])

                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            # ------------------------------------------------------------------
            # Calendar tasks (unused by `call`).
            # NOTE(review): in each of the six methods below `self.<name>` resolves
            # to the method itself - this class creates no layer attribute with
            # that name - so the one-argument call would raise TypeError. They are
            # kept unchanged until the intended output layers are defined.
            # ------------------------------------------------------------------
            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Register an MSE metric against the modal value of column 0 of ``labels_extra`` (non-functional, see note above)."""
                week_forecast = self.day_of_week_forecasting(second_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(labels_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Register an MSE metric against the modal value of column 0 of ``inputs_extra`` (non-functional, see note above)."""
                # TODO: ONEHOT
                week_forecast = self.day_of_week_sequence(first_level_state[:, -1, :])
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, self.get_mode(inputs_extra[:, :, 0], axis=1))

                self.add_metric(week_loss, name=f'week_sequence_loss', aggregation='mean')

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Register an MSE metric against the modal value of column 2 of ``labels_extra`` (non-functional, see note above)."""
                quarter_forecast = self.quarter_forecasting(second_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(labels_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Register an MSE metric against the modal value of column 2 of ``inputs_extra`` (non-functional, see note above)."""
                quarter_forecast = self.quarter_sequence(first_level_state[:, -1, :])
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, self.get_mode(inputs_extra[:, :, 2], axis=1))

                self.add_metric(quarter_loss, name=f'quarter_sequence_loss', aggregation='mean')

            def month_forecasting(self, second_level_state, labels_extra):
                """Register an MSE metric against the modal value of column 1 of ``labels_extra`` (non-functional, see note above)."""
                month_forecast = self.month_forecasting(second_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(labels_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_forecasting_loss', aggregation='mean')

            def month_sequence(self, first_level_state, inputs_extra):
                """Register an MSE metric against the modal value of column 1 of ``inputs_extra`` (non-functional, see note above)."""
                month_forecast = self.month_sequence(first_level_state[:, -1, :])
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, self.get_mode(inputs_extra[:, :, 1], axis=1))

                self.add_metric(month_loss, name=f'month_sequence_loss', aggregation='mean')

            def swap_task(self, inputs):
                """Chunk-permutation task (disabled in this variant).

                NOTE(review): ``self.swap_combinations``, ``self.swap_lstm_layer``,
                ``self.swap_output``, ``self.embedding`` and
                ``self.first_level_lstm_layer`` are not created by this class's
                ``__init__``, so calling this method raises ``AttributeError``.
                The body also relies on ``.numpy()`` and therefore on eager execution.
                """
                x_swap = tf.identity(inputs)

                idx = tf.tile([tf.range(0, self.window)], (tf.shape(x_swap)[0], 1))

                swap_index = tf.random.uniform((tf.shape(x_swap)[0], 1), 0, len(self.swap_combinations), dtype=tf.int64)

                swap_index_onehot = tf.one_hot(swap_index, len(self.swap_combinations))

                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                idx_numpy = idx.numpy()
                # Cut every row into 4 chunks and re-order them with the drawn permutation.
                n_rows = idx_numpy.shape[0]
                rows = np.arange(n_rows)
                chunks = idx_numpy.reshape(n_rows, 4, -1)
                idx_numpy = np.concatenate(
                    [chunks[rows, swap_numpy[..., k].squeeze()] for k in range(4)],
                    axis=1)

                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                swap_loss = tf.keras.losses.CategoricalCrossentropy()(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state

            def train_step(self, data):
                """Custom Keras training step combining main and auxiliary gradients per layer group.

                ``data`` is ``(inputs, labels)``. The forward pass returns the main
                prediction and the eight auxiliary losses (two first-level, six
                second-level). Auxiliary gradients on the first- and second-level
                layers are merged with a reference gradient through ``combined_grads``,
                each weighted by the entry of ``self.loss_weights`` that belongs to the
                loss; all remaining layers receive the raw gradients of every loss.
                ``update_smooth_grad`` / ``combined_grads`` are defined elsewhere in
                the notebook.

                Returns:
                    dict mapping metric name to its current result.
                """
                inputs, labels = data
                gradients = []

                # Persistent tape: it is queried once per (loss, weight group) pair below.
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                # Object array so that layers can be selected with index lists / boolean masks.
                model_layers = np.array(self.layers)

                # First level loss
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(main_grad_first_level, self.main_grad_first_level_average, self.grad_smooth_alpha)

                first_level_gradients = []
                for i, aux_loss in enumerate(losses[:2]):
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average, aux_grad, self.mode, self.overall, self.loss_weights[i])

                    first_level_gradients.append(aux_grad)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                # Merge the two first-level task gradients; this is the reference for the second-level losses.
                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1], 'Multitask', self.overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(first_level_aux_grad, self.first_level_aux_grad_average, self.grad_smooth_alpha)

                # Second level loss on the first-level weights.
                # `i` is the position of the loss in `losses`, so each loss gets its own weight
                # (the index used to be left over from the loop above).
                for i, aux_loss in enumerate(losses[2:], start=2):
                    aux_grad = tape.gradient(aux_loss, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average, aux_grad, self.mode, self.overall, self.loss_weights[i])
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(main_grad_second_level, self.main_grad_second_level_average, self.grad_smooth_alpha)

                for i, aux_loss in enumerate(losses[2:], start=2):
                    aux_grad = tape.gradient(aux_loss, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average, aux_grad, self.mode, self.overall, self.loss_weights[i])

                    gradients.extend(list(zip(aux_grad, second_level_trainable_weights)))

                # All losses: layers that belong to neither shared group get plain gradients.
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for aux_loss in losses:
                    grad = tape.gradient(aux_loss, all_level_trainable_weights)
                    gradients.extend(list(zip(grad, all_level_trainable_weights)))

                gradients.extend(list(zip(main_grad_rest_layers, all_level_trainable_weights)))

                # A persistent tape keeps its resources until the reference is dropped.
                del tape

                #Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}

            def call(self, inputs, training):
                """Forward pass through the three levels of the hierarchy.

                ``inputs`` is ``((x, x_extra), (labels, labels_extra))``: the labels
                travel with the inputs because the auxiliary losses are computed here.
                Unlike the gradient-stopped variant, no ``stop_gradient`` is applied
                between the levels, and the task losses are returned unweighted
                (the weighting happens in ``train_step``).

                Returns:
                    ``(prediction_output, losses)`` when ``training`` is truthy,
                    otherwise only ``prediction_output``.
                """
                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                x = self.global_embedding_1(inputs)
                x = self.global_embedding_2(x)

                # First level
                gap_filling_outputs, gap_loss, x_gap_embedding = self.gap_filling_task(inputs)
                noise_reduction_outputs, noise_loss, x_noise_embedding = self.noise_reduction_task(inputs)
                #swap_output, swap_state = self.swap_task(inputs)

                x = tf.keras.layers.concatenate([x, x_gap_embedding, x_noise_embedding])

                #Second level
                second_level_state = self.second_level_lstm_layer(x)
                smooth_forecasting_8, smooth_state_8, smooth_loss_8 = self.smooth_forecast_task(second_level_state, labels, 8)
                smooth_forecasting_6, smooth_state_6, smooth_loss_6 = self.smooth_forecast_task(second_level_state, labels, 6)
                smooth_forecasting_3, smooth_state_3, smooth_loss_3 = self.smooth_forecast_task(second_level_state, labels, 3)

                first_step_forecast, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                mid_step_forecast, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                last_step_forecast, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)

                # Order matters: train_step treats the first two as first-level losses and
                # indexes self.loss_weights by position.
                losses = [gap_loss, noise_loss, first_step_loss, mid_step_loss, last_step_loss, smooth_loss_8, smooth_loss_6, smooth_loss_3]

                # Last level
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                prediction_state = tf.keras.layers.concatenate([prediction_level_state, last_step_state, mid_step_state,
                                         first_step_state, smooth_state_3, smooth_state_6, smooth_state_8])

                prediction_output = self.prediction_output(prediction_state)

                if training:
                    return prediction_output, losses
                else:
                    return prediction_output


        # Build the network variant selected by this `model_type` branch.
        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall)

        # Loss identifier for the main forecast (presumably handed to `compile` further down - not visible here).
        loss = 'mse'
    
    elif model_type == 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended_swapappended':
    #pred_index = -1

        class MultitaskHLNet(tf.keras.Model):

            def __init__(self, d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall):
                """Create all layers of the swap-enabled variant.

                Keep the order of the attribute assignments: the hard-coded indices in
                ``global_first_level_layers`` / ``global_second_level_layers`` are
                presumably positions in ``self.layers`` (that is how the sibling
                variants' ``train_step`` uses them - this class's ``train_step`` is not
                visible from here, TODO confirm).

                Args:
                    d_model: width of the embeddings and of every LSTM.
                    n_features: output size of the two reconstruction heads.
                    window: length of the input sequence.
                    horizon: number of forecast steps (``horizon - 8 + 1`` must be positive).
                    batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall:
                        stored on the instance unchanged.
                """
                super(MultitaskHLNet, self).__init__()
                self.d_model = d_model
                self.n_features = n_features
                self.window = window
                self.horizon = horizon
                self.batch_size = batch_size
                # Gradient running averages start empty.
                self.main_grad_first_level_average = None
                self.main_grad_second_level_average = None
                self.first_level_aux_grad_average = None
                self.grad_smooth_alpha = grad_smooth_alpha
                self.lamb1 = lamb1
                self.lamb2 = lamb2
                self.mode = mode
                self.overall = overall

                # Eight loss weights initialised to 1. The attribute name is spelled
                # `loss_weigts`; code outside this view may read it under that name.
                self.loss_weigts = tf.Variable(tf.ones([8]))

                # Shared causal-convolution embedding; the input is declared with 3 features.
                self.global_embedding_1 = tf.keras.layers.Conv1D(d_model//2, 3, activation='relu', input_shape=(window, 3), padding='causal')
                self.global_embedding_2 = tf.keras.layers.Conv1D(d_model, 3, activation='relu', padding='causal')

                # Presumably the positions of the two embedding layers above - TODO confirm.
                self.global_first_level_layers = [0,1]
                """
                    First Level tasks - Sequence featurization
                """
                # Decoder shared by the gap-filling and noise-reduction heads.
                self.decode_cnn_1 = tf.keras.layers.Conv1DTranspose(d_model, 3, activation='relu', padding='same')
                self.decode_cnn_2 = tf.keras.layers.Conv1DTranspose(d_model//2, 3, activation='relu', padding='same')
                self.gap_filling_output = tf.keras.layers.Dense(n_features)

                self.gaussian_noise = tf.keras.layers.GaussianNoise(0.01)
                self.noise_reduction_output = tf.keras.layers.Dense(n_features)

                # All 24 orderings of four items; the swap head classifies among them.
                self.swap_combinations = list(itertools.permutations(np.arange(0, 4), 4))
                self.swap_lstm_layer = tf.keras.layers.LSTM(d_model, name='swap_lstm', activation='relu')
                self.swap_output = tf.keras.layers.Dense(len(self.swap_combinations), name='swap_task', activation='softmax')


                """
                    Second Level tasks - Forecasting helpers
                """
                self.second_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=True)

                # Keyed by the smoothing width (as a string).
                self.smooth_lstm_layers = {'8': tf.keras.layers.LSTM(d_model, name='smooth_forecasting_8_lstm', activation='relu'),
                                '6':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_6_lstm', activation='relu'),
                                '3':  tf.keras.layers.LSTM(d_model, name='smooth_forecasting_3_lstm', activation='relu')}

                # Output sizes are horizon - width + 1.
                self.smooth_layers = {'8': tf.keras.layers.Dense(horizon-8+1, name='smooth_forecasting_8'),
                                '6': tf.keras.layers.Dense(horizon-6+1, name='smooth_forecasting_6'),
                                '3': tf.keras.layers.Dense(horizon-3+1, name='smooth_forecasting_3')}


                # Keyed by the forecast step index (as a string): first, middle and last step.
                self.one_step_forecast_lstm_layers = {'0': tf.keras.layers.LSTM(d_model, name='next_step_forecasting_lstm', activation='relu'),
                                str(horizon//2):  tf.keras.layers.LSTM(d_model, name='mid_step_forecasting_lstm', activation='relu'),
                                str(horizon-1):  tf.keras.layers.LSTM(d_model, name='last_step_forecasting_lstm', activation='relu')}

                self.one_step_forecast = {'0': tf.keras.layers.Dense(1, name='next_step_forecasting'),
                                    str(horizon//2): tf.keras.layers.Dense(1, name='mid_step_forecasting'),
                                    str(horizon-1): tf.keras.layers.Dense(1, name='last_step_forecasting')}
                # Presumably the position of second_level_lstm_layer (it is the tenth layer
                # attribute assigned above) - TODO confirm against `model.layers`.
                self.global_second_level_layers = [9]
                """
                    Last Level tasks
                """
                self.prediction_level_lstm_layer = tf.keras.layers.LSTM(d_model, return_sequences=False)
                self.prediction_output = tf.keras.layers.Dense(horizon, name='prediction')

            def get_mode(self, x, axis=1):
                """Return the most frequent integer value of ``x`` along ``axis``.

                ``x`` must have an integer dtype: it is used both for the counts and
                as the ``output_type`` of ``argmax``. Relies on ``tfp``
                (tensorflow_probability) being in scope; it is not among the imports
                in the first notebook cell - TODO confirm it is imported elsewhere.

                Args:
                    x: integer tensor.
                    axis: axis (or axes) over which values are counted.

                Returns:
                    Tensor of modes with the counted axis removed.
                """
                dt = x.dtype
                # Shift input in case it has negative values
                m = tf.math.reduce_min(x)
                x2 = x - m
                # minlength should not be necessary but may fail without it
                # (reported here https://github.com/tensorflow/probability/issues/962)
                c = tfp.stats.count_integers(x2, axis=axis, dtype=dt,
                                             minlength=tf.math.reduce_max(x2) + 1)
                # Find the values with largest counts (counts are laid out along axis 0)
                idx = tf.math.argmax(c, axis=0, output_type=dt)
                # Get the modes by shifting by the subtracted minimum
                return idx + m

            def gap_filling_task(self, inputs):
                """Auxiliary task: reconstruct the window after masking one value per sample.

                For every sample one (time step, feature) position is overwritten
                with -100, the corrupted window is passed through the two global
                embedding layers and the two decoding CNN layers, and
                `self.gap_filling_output` is trained to reproduce the original,
                unmasked `inputs` (mean squared error).

                Args:
                    inputs: Batch of input windows; indexed here as
                        (sample, time step, feature).

                Returns:
                    Tuple `(gap_filling_output, gap_loss, x_gap_embedding)`:
                    the reconstruction, its MSE against `inputs`, and the
                    intermediate embedding produced by `self.decode_cnn_2`.
                """
                n_samples = tf.shape(inputs)[0]

                # One candidate index triple (sample, time step, random feature)
                # per sample and time step.
                batch_indexes = tf.tile(tf.range(n_samples)[:, tf.newaxis, tf.newaxis], (1, self.window, 1))
                head_indexes = tf.tile(tf.range(self.window)[tf.newaxis, :, tf.newaxis], (n_samples, 1, 1))
                # NOTE(review): maxval=3 presumably equals the number of input features -- confirm.
                feat_index = tf.random.uniform((n_samples, self.window, 1), minval=0, maxval=3, dtype=tf.int32)

                # Only the singleton axis created by the trailing dimension of the
                # three index tensors is squeezed; a bare tf.squeeze would also
                # drop the batch axis for a single-sample batch and break the
                # transpose below.
                idx = tf.squeeze(tf.stack(values=[batch_indexes, head_indexes, feat_index], axis=-1), axis=2)
                idx = tf.transpose(idx, perm=(1,2,0))

                # tf.random.shuffle permutes the leading (time) axis, so after
                # keeping the first entry every sample is masked at the same time
                # step while the feature is drawn independently per sample.
                gap_index =  tf.reshape(tf.transpose(tf.random.shuffle(idx), perm= (2, 0, 1))[:, :1, :], (-1, 3))

                x_gap = tf.identity(inputs)
                # -100 is the sentinel written into the masked positions.
                x_gap_updated = tf.tensor_scatter_nd_update(x_gap, indices = gap_index, updates = -tf.ones(gap_index.shape[0])*100)
                x_gap_embedding = self.global_embedding_1(x_gap_updated)
                x_gap_embedding = self.global_embedding_2(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_1(x_gap_embedding)
                x_gap_embedding = self.decode_cnn_2(x_gap_embedding)
                gap_filling_output = self.gap_filling_output(x_gap_embedding)
                # The target is the original window, not the masked copy.
                gap_loss = tf.keras.losses.MeanSquaredError()(gap_filling_output, x_gap)

                self.add_metric(gap_loss, name='gap_loss', aggregation='mean')

                return gap_filling_output, gap_loss, x_gap_embedding

            def noise_reduction_task(self, inputs):
                """Auxiliary task: reconstruct the window from a noise-corrupted copy.

                A copy of `inputs` is passed through `self.gaussian_noise`, then
                through the two global embedding layers and the two decoding CNN
                layers; `self.noise_reduction_output` is trained (MSE) to
                reproduce the uncorrupted copy.

                Args:
                    inputs: Batch of input windows.

                Returns:
                    Tuple `(noise_output, noise_loss, x_noise_embedding)`: the
                    reconstruction, its MSE against the clean copy, and the
                    embedding produced by `self.decode_cnn_2`.
                """
                clean = tf.identity(inputs)

                # Corrupt first, then run the encoder/decoder stack shared with
                # the other first-level tasks.
                hidden = self.gaussian_noise(clean)
                for layer in (self.global_embedding_1, self.global_embedding_2,
                              self.decode_cnn_1, self.decode_cnn_2):
                    hidden = layer(hidden)

                denoised = self.noise_reduction_output(hidden)

                criterion = tf.keras.losses.MeanSquaredError()
                noise_loss = criterion(denoised, clean)
                self.add_metric(noise_loss, name='noise_loss', aggregation='mean')

                return denoised, noise_loss, hidden

            def smooth_forecast_task(self, second_level_state, labels, smooth=6):
                """Auxiliary task: forecast the moving average of the labels.

                Args:
                    second_level_state: Output of the second-level LSTM.
                    labels: Target sequence; framed along axis 1.
                    smooth: Moving-average length; also the key (as a string)
                        into `self.smooth_lstm_layers` / `self.smooth_layers`.

                Returns:
                    Tuple `(smooth_forecasting, smooth_state, smooth_loss)`.
                """
                key = str(smooth)
                smooth_state = self.smooth_lstm_layers[key](second_level_state)
                smooth_forecasting = self.smooth_layers[key](smooth_state)

                # Sliding frames of length `smooth` with step 1 along the time
                # axis, averaged over the frame axis.
                frames = tf.signal.frame(labels, smooth, 1, axis=1)
                labels_smooth = tf.math.reduce_mean(frames, axis=2)

                criterion = tf.keras.losses.MeanSquaredError()
                smooth_loss = criterion(smooth_forecasting, labels_smooth)
                self.add_metric(smooth_loss, name=f'smooth_{smooth}_loss', aggregation='mean')

                return smooth_forecasting, smooth_state, smooth_loss

            def one_step_forecast_task(self,second_level_state, labels, step=0):
                """Auxiliary task: forecast the label at a single horizon position.

                Args:
                    second_level_state: Output of the second-level LSTM.
                    labels: Target sequence; `labels[:, step, :]` is the target.
                    step: Horizon position; also the key (as a string) into
                        `self.one_step_forecast_lstm_layers` / `self.one_step_forecast`.

                Returns:
                    Tuple `(step_forecast, step_state, step_loss)`.
                """
                key = str(step)
                step_state = self.one_step_forecast_lstm_layers[key](second_level_state)
                step_forecast = self.one_step_forecast[key](step_state)

                criterion = tf.keras.losses.MeanSquaredError()
                step_loss = criterion(step_forecast, labels[:, step, :])
                self.add_metric(step_loss, name=f'step_{step}_loss', aggregation='mean')

                return step_forecast, step_state, step_loss

            def day_of_week_forecasting(self, second_level_state, labels_extra):
                """Record an MSE metric for predicting the modal value of `labels_extra[:, :, 0]`.

                The prediction is made from the last time step of
                `second_level_state`. Nothing is returned.

                NOTE(review): the callable used here, `self.day_of_week_forecasting`,
                has the same name as this method. Unless `__init__` stores a layer
                under that name (which would make this method unreachable from the
                instance), the call resolves to this method with a missing
                argument. One of the two needs renaming -- verify.
                """
                last_state = second_level_state[:, -1, :]
                week_forecast = self.day_of_week_forecasting(last_state)
                target = self.get_mode(labels_extra[:, :, 0], axis=1)
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, target)

                self.add_metric(week_loss, name='week_forecasting_loss', aggregation='mean')

            def day_of_week_sequence(self, first_level_state, inputs_extra):
                """Record an MSE metric for predicting the modal value of `inputs_extra[:, :, 0]`.

                The prediction is made from the last time step of
                `first_level_state`. Nothing is returned.

                NOTE(review): `self.day_of_week_sequence` has the same name as this
                method; unless `__init__` stores a layer under that name (hiding
                this method), the call resolves to this method with a missing
                argument -- verify.
                """
                # TODO: ONEHOT
                last_state = first_level_state[:, -1, :]
                week_forecast = self.day_of_week_sequence(last_state)
                target = self.get_mode(inputs_extra[:, :, 0], axis=1)
                week_loss = tf.keras.losses.MeanSquaredError()(week_forecast, target)

                self.add_metric(week_loss, name='week_sequence_loss', aggregation='mean')

            def quarter_forecasting(self, second_level_state, labels_extra):
                """Record an MSE metric for predicting the modal value of `labels_extra[:, :, 2]`.

                The prediction is made from the last time step of
                `second_level_state`. Nothing is returned.

                NOTE(review): `self.quarter_forecasting` has the same name as this
                method; unless `__init__` stores a layer under that name (hiding
                this method), the call resolves to this method with a missing
                argument -- verify.
                """
                last_state = second_level_state[:, -1, :]
                quarter_forecast = self.quarter_forecasting(last_state)
                target = self.get_mode(labels_extra[:, :, 2], axis=1)
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, target)

                self.add_metric(quarter_loss, name='quarter_forecasting_loss', aggregation='mean')

            def quarter_sequence(self, first_level_state, inputs_extra):
                """Record an MSE metric for predicting the modal value of `inputs_extra[:, :, 2]`.

                The prediction is made from the last time step of
                `first_level_state`. Nothing is returned.

                NOTE(review): `self.quarter_sequence` has the same name as this
                method; unless `__init__` stores a layer under that name (hiding
                this method), the call resolves to this method with a missing
                argument -- verify.
                """
                last_state = first_level_state[:, -1, :]
                quarter_forecast = self.quarter_sequence(last_state)
                target = self.get_mode(inputs_extra[:, :, 2], axis=1)
                quarter_loss = tf.keras.losses.MeanSquaredError()(quarter_forecast, target)

                self.add_metric(quarter_loss, name='quarter_sequence_loss', aggregation='mean')


            def month_forecasting(self, second_level_state, labels_extra):
                """Record an MSE metric for predicting the modal value of `labels_extra[:, :, 1]`.

                The prediction is made from the last time step of
                `second_level_state`. Nothing is returned.

                NOTE(review): `self.month_forecasting` has the same name as this
                method; unless `__init__` stores a layer under that name (hiding
                this method), the call resolves to this method with a missing
                argument -- verify.
                """
                last_state = second_level_state[:, -1, :]
                month_forecast = self.month_forecasting(last_state)
                target = self.get_mode(labels_extra[:, :, 1], axis=1)
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, target)

                self.add_metric(month_loss, name='month_forecasting_loss', aggregation='mean')


            def month_sequence(self, first_level_state, inputs_extra):
                """Record an MSE metric for predicting the modal value of `inputs_extra[:, :, 1]`.

                The prediction is made from the last time step of
                `first_level_state`. Nothing is returned.

                NOTE(review): `self.month_sequence` has the same name as this
                method; unless `__init__` stores a layer under that name (hiding
                this method), the call resolves to this method with a missing
                argument -- verify.
                """
                last_state = first_level_state[:, -1, :]
                month_forecast = self.month_sequence(last_state)
                target = self.get_mode(inputs_extra[:, :, 1], axis=1)
                month_loss = tf.keras.losses.MeanSquaredError()(month_forecast, target)

                self.add_metric(month_loss, name='month_sequence_loss', aggregation='mean')

            def swap_task(self, inputs):
                """Auxiliary task: recognise how the chunks of a window were reordered.

                The time indices of each window are split into 4 equal chunks,
                the chunks are re-assembled in the order given by a randomly drawn
                entry of `self.swap_combinations`, and a classifier head is
                trained (categorical cross-entropy) to identify which entry was
                drawn.

                Uses `.numpy()`, so it only works when the model runs eagerly.
                `window` is read from the enclosing scope.

                Args:
                    inputs: Batch of input windows (time on axis 1).

                Returns:
                    Tuple `(swap_output, swap_state, swap_loss)`.
                """
                x_swap = tf.identity(inputs)
                n_samples = tf.shape(x_swap)[0]
                n_combinations = len(self.swap_combinations)

                # Row i holds the time indices 0..window-1 for sample i.
                idx = tf.tile([tf.range(0, window)], (n_samples, 1))

                # One combination id per sample, plus its one-hot classification target.
                swap_index = tf.random.uniform((n_samples, 1), 0, n_combinations, dtype=tf.int64)
                swap_index_onehot = tf.one_hot(swap_index, n_combinations)

                swap = tf.gather(tf.convert_to_tensor(self.swap_combinations), swap_index)

                swap_numpy = swap.numpy()
                chunks = idx.numpy()
                chunks = chunks.reshape(chunks.shape[0], 4, -1)
                rows = np.arange(chunks.shape[0])
                # Output position k receives the chunk named by the k-th
                # component of the sample's combination.
                idx_numpy = np.concatenate(
                    [chunks[rows, swap_numpy[..., position].squeeze()] for position in range(4)],
                    axis=1)

                idx = tf.convert_to_tensor(idx_numpy)
                x_swap = tf.gather(x_swap, idx, axis=1, batch_dims=1)

                x_swap_embedding = self.embedding(x_swap)
                swap_state = self.first_level_lstm_layer(x_swap_embedding)
                swap_state = self.swap_lstm_layer(swap_state)
                swap_output = self.swap_output(swap_state)

                criterion = tf.keras.losses.CategoricalCrossentropy()
                swap_loss = criterion(tf.squeeze(swap_index_onehot), swap_output)

                self.add_metric(swap_loss, name='swap_loss', aggregation='mean')

                return swap_output, swap_state, swap_loss

            def train_step(self, data):
                """Run one optimisation step with level-wise gradient combination.

                The forward pass returns the main prediction and the list of
                auxiliary losses. Gradients are then assembled in three groups:

                * first-level layers (`self.global_first_level_layers`): each of
                  the first two auxiliary losses is combined with the main-loss
                  gradient (weight `lamb1`); the remaining auxiliary losses are
                  combined with the sum of those two results (weight `lamb2`);
                * second-level layers (`self.global_second_level_layers`): the
                  remaining auxiliary losses are combined with the main-loss
                  gradient (weight `lamb2`);
                * every other layer: raw gradients of all auxiliary losses and of
                  the main loss.

                All (gradient, variable) pairs are handed to the optimizer in one
                call, so a variable can appear several times in that list.

                `grad_smooth_alpha`, `mode`, `overall`, `lamb1` and `lamb2` are
                read from the enclosing scope.

                Args:
                    data: Pair `(inputs, labels)` as produced by the dataset.

                Returns:
                    Dict mapping metric names to their current results.
                """
                inputs, labels = data
                gradients = []

                # Persistent because several gradients are taken from the same tape.
                with tf.GradientTape(persistent=True) as tape:
                    prediction, losses = self(inputs, True)
                    loss = self.compiled_loss(labels, prediction, regularization_losses=self.losses)

                self.compiled_metrics.update_state(labels, prediction)

                model_layers = np.array(self.layers)

                # First level loss
                # Layers are selected by position in `self.layers`; `self` is used
                # rather than the outer `model` variable so the method does not
                # depend on a name bound outside the class.
                first_level_trainable_weights = [w for layer in model_layers[self.global_first_level_layers] for w in layer.trainable_weights]
                main_grad_first_level = tape.gradient(loss, first_level_trainable_weights)
                self.main_grad_first_level_average = update_smooth_grad(main_grad_first_level, self.main_grad_first_level_average, grad_smooth_alpha)

                first_level_gradients = []
                for l in losses[:2]:

                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_first_level, self.main_grad_first_level_average, aux_grad, mode, overall, lamb1)

                    first_level_gradients.append(aux_grad)

                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                # Plain sum ('Multitask', weight 1) of the two combined first-level
                # gradients; used as the reference for the remaining losses.
                first_level_aux_grad = combined_grads(first_level_gradients[0], None, first_level_gradients[1], 'Multitask', overall, 1)
                self.first_level_aux_grad_average = update_smooth_grad(first_level_aux_grad, self.first_level_aux_grad_average, grad_smooth_alpha)

                #Second level loss
                # Effect of the remaining auxiliary losses on the first-level weights.
                for l in losses[2:]:
                    aux_grad = tape.gradient(l, first_level_trainable_weights)
                    aux_grad = combined_grads(first_level_aux_grad, self.first_level_aux_grad_average, aux_grad, mode, overall, lamb2)
                    gradients.extend(list(zip(aux_grad, first_level_trainable_weights)))

                second_level_trainable_weights = [w for layer in model_layers[self.global_second_level_layers] for w in layer.trainable_weights]
                main_grad_second_level = tape.gradient(loss, second_level_trainable_weights)
                self.main_grad_second_level_average = update_smooth_grad(main_grad_second_level, self.main_grad_second_level_average, grad_smooth_alpha)

                for l in losses[2:]:
                    aux_grad = tape.gradient(l, second_level_trainable_weights)
                    aux_grad = combined_grads(main_grad_second_level, self.main_grad_second_level_average, aux_grad, mode, overall, lamb2)

                    gradients.extend(list(zip(aux_grad, second_level_trainable_weights)))

                # All losses
                # Layers that belong to neither the first nor the second level.
                rest_layers = sum([self.global_second_level_layers, self.global_first_level_layers], [])

                mask = np.ones(len(model_layers), bool)
                mask[rest_layers] = False
                all_level_trainable_weights = [w for layer in model_layers[mask] for w in layer.trainable_weights]
                main_grad_rest_layers = tape.gradient(loss, all_level_trainable_weights)

                for l in losses:
                    grad = tape.gradient(l, all_level_trainable_weights)
                    gradients.extend(list(zip(grad, all_level_trainable_weights)))

                gradients.extend(list(zip(main_grad_rest_layers, all_level_trainable_weights)))

                # Release the resources held by the persistent tape.
                del tape

                #Apply all gradients
                self.optimizer.apply_gradients(gradients)

                return {m.name: m.result() for m in self.metrics}
            
            def call(self, inputs, training):
                """Forward pass: auxiliary tasks, shared states and the final forecast.

                Args:
                    inputs: Nested pair `((inputs, inputs_extra), (labels, labels_extra))`.
                        The labels are part of the model input because the
                        auxiliary losses are computed inside the forward pass.
                    training: When truthy the auxiliary losses are returned
                        together with the prediction.

                Returns:
                    `(prediction_output, losses)` when `training` is truthy,
                    otherwise only `prediction_output`. `losses` is ordered as
                    gap, noise, swap, first/mid/last step, smooth 8/6/3.
                """
                # `labels_extra` and `inputs_extra` are unpacked but not used below.
                labels, labels_extra = inputs[1]
                inputs, inputs_extra = inputs[0]

                x = self.global_embedding_1(inputs)
                x = self.global_embedding_2(x)
                #first_level_state = self.first_level_lstm_layer(x)

                # First level
                # These tasks run (and register their metrics) regardless of `training`.
                gap_filling_outputs, gap_loss, x_gap_embedding = self.gap_filling_task(inputs)
                noise_reduction_outputs, noise_loss, x_noise_embedding = self.noise_reduction_task(inputs)
                swap_output, swap_state, swap_loss = self.swap_task(inputs)
                
                # Join the plain embedding with the embeddings produced by the
                # gap-filling and noise-reduction tasks.
                x = tf.keras.layers.concatenate([x, x_gap_embedding, x_noise_embedding])
                
                #Second level
                second_level_state = self.second_level_lstm_layer(x)
                smooth_forecasting_8, smooth_state_8, smooth_loss_8 = self.smooth_forecast_task(second_level_state, labels, 8)
                smooth_forecasting_6, smooth_state_6, smooth_loss_6 = self.smooth_forecast_task(second_level_state, labels, 6)
                smooth_forecasting_3, smooth_state_3, smooth_loss_3 = self.smooth_forecast_task(second_level_state, labels, 3)
                
                first_step_forecast, first_step_state, first_step_loss = self.one_step_forecast_task(second_level_state, labels, 0)
                mid_step_forecast, mid_step_state, mid_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon//2)
                last_step_forecast, last_step_state, last_step_loss = self.one_step_forecast_task(second_level_state, labels, self.horizon-1)
                
                # The order matters: train_step slices this list by position
                # (`losses[:2]` and `losses[2:]`).
                losses = [gap_loss, noise_loss, swap_loss, first_step_loss, mid_step_loss, last_step_loss, smooth_loss_8, smooth_loss_6, smooth_loss_3]
                # smooth_loss_8, smooth_loss_6, smooth_loss_3
                
                """tasks_loss = 0
                for i, loss in enumerate(losses):
                    losses[i] = (0.5/self.loss_weigts[i]**2)*loss + tf.math.log(1+self.loss_weigts[i]**2)"""
                    
                #self.add_loss(tasks_loss)
                # Last level
                # The final head sees its own LSTM state plus the states of all
                # second-level auxiliary tasks.
                prediction_level_state = self.prediction_level_lstm_layer(second_level_state)
                prediction_state = tf.keras.layers.concatenate([prediction_level_state, last_step_state, mid_step_state, 
                                         first_step_state, smooth_state_3, smooth_state_6, smooth_state_8])
                
                prediction_output = self.prediction_output(prediction_state)
                
                if training:
                    return prediction_output, losses
                else:
                    return prediction_output


        model = MultitaskHLNet(d_model, n_features, window, horizon, batch_size, grad_smooth_alpha, lamb1, lamb2, mode, overall)

        loss = 'mse'
        
    return model, loss

In [4]:
class hierarchical_loss(tf.keras.losses.Loss):
    """Loss computed against a framed-and-averaged version of the targets.

    `y_true` is cut into frames of length `gf` with step `ge` along axis 1,
    each frame is averaged, and the base criterion is evaluated between those
    averages and `y_pred`.

    Args:
        base_criterion: Identifier or callable accepted by `tf.keras.losses.get`.
        gf: Frame length passed to `tf.signal.frame`.
        ge: Frame step passed to `tf.signal.frame`.
        reduction: Stored on the instance; `__call__` returns the criterion's
            value as-is and does not apply it.
    """

    def __init__(self, base_criterion, gf, ge,  reduction=tf.keras.losses.Reduction.SUM):
        # Initialise the Keras base class so the attributes it manages
        # (reduction, name and its internal state) exist on the instance.
        super().__init__(reduction=reduction, name='HierarchicalLoss')
        self.gf = gf
        self.ge = ge
        self.base_criterion = base_criterion

    def __call__(self, y_true, y_pred, sample_weight=None):
        """Return the base criterion between the averaged frames and `y_pred`.

        `sample_weight` is accepted for interface compatibility but not used.
        """
        criterion = tf.keras.losses.get(self.base_criterion)

        y_h_true = tf.math.reduce_mean(tf.signal.frame(y_true, self.gf, self.ge, axis=1), axis=2)
        # Keras criteria take (y_true, y_pred); the order only matters for
        # asymmetric criteria.
        loss = criterion(tf.squeeze(y_h_true), y_pred)

        return loss

In [5]:
def split_window(window_size, horizon_size, label_indexes, return_input_labels, n_features, features):
    """Split a batch of sequences into model inputs and labels.

    The first `window_size` steps form the inputs and the remaining steps the
    labels. Only the first `n_features` columns are treated as signal; any
    further columns are returned separately as "extra" features.

    Args:
        window_size: Number of leading time steps used as input.
        horizon_size: Static length set on the label time axis.
        label_indexes: Columns (among the first `n_features`) kept as labels.
        return_input_labels: When true, the labels are also packed into the
            returned inputs.
        n_features: Number of leading signal columns.
        features: Batch indexed as (sample, time step, column).

    Returns:
        `(inputs, labels)` where `inputs` is, depending on the flags, the input
        window, `(window, window_extra)`, `(window, labels)` or
        `((window, window_extra), (labels, labels_extra))`.
    """
    past = features[:, :window_size, :n_features]
    future = features[:, window_size:, :n_features]

    labels = tf.stack([future[:, :, index] for index in label_indexes], axis=-1)

    # Slicing loses the static time dimension; restore it.
    past.set_shape([None, window_size, None])
    labels.set_shape([None, horizon_size, None])

    has_extra = features.shape[2] > n_features

    if not has_extra:
        if return_input_labels:
            return (past, labels), labels
        return past, labels

    past_extra = features[:, :window_size, n_features:]
    future_extra = features[:, window_size:, n_features:]
    past_extra.set_shape([None, window_size, None])
    future_extra.set_shape([None, horizon_size, None])

    inputs = (past, past_extra)
    if return_input_labels:
        inputs = (inputs, (labels, future_extra))

    return inputs, labels

In [6]:
def censored_vector(u, v, mode='Projection'):
    """Adjusts the auxiliary loss gradient

    Adjusts the auxiliary loss gradient before adding it to the primary loss
    gradient and using a gradient descent-based method. Calls `.numpy()`, so
    it must run eagerly.

    Args:
    u: A tensor representing the auxiliary loss gradient
    v: A tensor representing the primary loss gradient
    mode: The method used for the adjustment:
      - Single task: the auxiliary loss gradient is ignored
      - Multitask: the auxiliary loss gradient is kept as it is
      - Unweighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Weighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Projection: cf. https://github.com/vivien000/auxiliary-learning
      - Parameter-wise: same as projection but at the level of each parameter

    Returns:
    The adjusted auxiliary loss gradient (the integer 0 for 'Single task')

    Raises:
    ValueError: if `mode` is not one of the values listed above and both
    gradients have a non-zero norm
    """
    if mode == 'Single task':
        return 0
    if mode == 'Multitask':
        return u
    l_u, l_v = tf.norm(u), tf.norm(v)
    # With a zero-norm gradient there is no direction to compare against
    # (and the divisions below would be undefined): keep u untouched.
    if l_u.numpy() == 0 or l_v.numpy() == 0:
        return u
    u_dot_v = tf.math.reduce_sum(u*v)
    if mode == 'Unweighted cosine':
        # Keep u only when it points in the same half-space as v.
        return u if u_dot_v > 0 else tf.zeros_like(u)
    if mode == 'Weighted cosine':
        # Scale u by max(cos(u, v), 0).
        return tf.math.maximum(u_dot_v, 0)*u/l_u/l_v
    if mode == 'Projection':
        # Remove from u its component along v when that component is negative.
        return u - tf.math.minimum(u_dot_v, 0)*v/l_v/l_v
    if mode == 'Parameter-wise':
        # Per-element factor: 1 where u and v agree in sign, 0 where they
        # disagree, 0.5 where their product is zero.
        return u*((tf.math.sign(u*v)+1)/2)
    # Previously an unknown mode fell through and returned None, which only
    # surfaced later as an unrelated TypeError in the caller.
    raise ValueError(f"Unknown gradient adjustment mode: {mode!r}")

def combined_grads(primary_grad,
                   average_primary_grad,
                   auxiliary_grad,
                   mode,
                   overall=False,
                   lam=1):
    """Combines auxiliary loss gradients and primary loss gradients

    Combines a sequence of auxiliary loss gradients and a sequence of primary
    loss gradients before performing a gradient descent step

    Args:
    primary_grad: A list of tensors corresponding to the primary loss gradient
    for the network's Keras variables (entries may be None)
    average_primary_grad: A list of tensors corresponding to exponential
    moving averages of the elements above, or None to compare against
    primary_grad itself
    auxiliary_grad: A list of tensors corresponding to the auxiliary loss
    gradient for the network's Keras variables (entries may be None)
    mode: The method used for the adjustment:
      - Single task: the auxiliary loss gradient is ignored
      - Multitask: the auxiliary loss gradient is kept as it is
      - Unweighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Weighted cosine: cf. https://arxiv.org/abs/1812.02224
      - Projection: cf. https://github.com/vivien000/auxiliary-learning
      - Parameter-wise: same as projection but at the level of each parameter
    overall: True if the transformation takes place at the level of the whole
    parameter vector, i.e. the concatenation of all the Keras variables of the
    network
    lam: Float balancing the primary loss and the auxiliary loss

    Returns:
    A list combining the primary loss gradients and the auxiliary loss
    gradients and that can directly be used for the next gradient descent step
    """
    result = [0]*len(primary_grad)
    # Flattened pieces collected for the `overall` path; they are concatenated
    # once after the loop instead of growing a tensor at every iteration.
    primary_parts, average_parts, auxiliary_parts = [], [], []
    shapes = []
    for i in range(len(primary_grad)):
        if auxiliary_grad[i] is None or mode == 'Single task':
            result[i] = primary_grad[i]
        elif primary_grad[i] is None:
            result[i] = lam*auxiliary_grad[i]
        elif mode == 'Multitask':
            result[i] = primary_grad[i] + lam*auxiliary_grad[i]
        elif not overall:
            # Variable-wise adjustment, against the moving average if given.
            if average_primary_grad is None:
                reference = primary_grad[i]
            else:
                reference = average_primary_grad[i]
            result[i] = (primary_grad[i]
                         + lam*censored_vector(auxiliary_grad[i],
                                               reference,
                                               mode))
        else:
            primary_parts.append(tf.reshape(primary_grad[i], [-1]))
            if average_primary_grad is not None:
                average_parts.append(tf.reshape(average_primary_grad[i], [-1]))
            auxiliary_parts.append(tf.reshape(auxiliary_grad[i], [-1]))
            # int(): np.prod of an empty (scalar) shape is the float 1.0,
            # which cannot be used as a slice bound.
            shapes.append((primary_grad[i].shape,
                           int(np.prod(primary_grad[i].shape.as_list())),
                           i))

    # Whole-vector adjustment, done once all pieces are known. This used to
    # sit inside the loop above and was recomputed at every iteration with the
    # same final outcome.
    if shapes:
        a = tf.concat(primary_parts, axis=0)
        b = tf.concat(auxiliary_parts, axis=0)
        if average_primary_grad is None:
            c = a + lam*censored_vector(b, a, mode)
        else:
            aa = tf.concat(average_parts, axis=0)
            c = a + lam*censored_vector(b, aa, mode)
        # Cut the flat result back into the original variable shapes.
        start = 0
        for shape, length, index in shapes:
            result[index] = tf.reshape(c[start:start+length], shape)
            start += length

    return result

def update_smooth_grad(main_grad, average_grad, alpha):
    """Update an exponential moving average of a list of gradients.

    Args:
        main_grad: Current gradients, one entry per variable; entries may be
            None.
        average_grad: Running averages from the previous call, or None on the
            first call. Updated in place when it is a list.
        alpha: Weight of the current gradient. With `alpha == 1` smoothing is
            disabled and `average_grad` is returned untouched.

    Returns:
        The running averages: `average_grad` itself after the in-place update,
        or a new list seeded from `main_grad` on the first call.
    """
    if alpha == 1:
        return average_grad

    if average_grad is None:
        # Copy so that later in-place updates never modify the caller's list.
        return list(main_grad)

    for i in range(len(average_grad)):
        current = main_grad[i]
        if current is None:
            continue
        previous = average_grad[i]
        if previous is None:
            # No history for this entry yet: start the average here instead
            # of failing on `None * float`.
            average_grad[i] = current
        else:
            average_grad[i] = (1 - alpha)*previous + alpha*current
    return average_grad

In [7]:
# Experiment configuration: dataset, search grid and data loading.
dataset = "BERMEJALES"
# Name of the column to forecast.
target = dataset+"-O3-AT_IN"
# Input window and forecast horizon lengths (in time steps) to try.
windows = [24]
horizons = [24]

batch_size=32
# NOTE(review): `lr` is defined here but the training cell builds its
# optimizer without passing it -- confirm whether that is intended.
lr = 3e-3
epochs = 100

# Hyper-parameter grid explored by the training loop.
d_models = [256]
# (lamb1, lamb2) pairs: weights of the auxiliary gradients.
lambs = [(0.2, 0.2), (0.5, 0.5), (1, 1), (2, 2)]
# Gradient smoothing factors.
alphas = [1, 0.9]
# Gradient adjustment methods (see censored_vector).
modes = ['Projection', 'Weighted cosine'] #, 'Unweighted cosine'
overall = True

model_types = ['multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended_lossweighted', 'multitaskhlnet_taskslstm_firstlevelcnn_gradsprojected_embeddingappended_swapappended']
# Number of leading signal columns fed to the model.
n_features = 3
# Year held out for testing; the other years are used for training/validation.
test_year = 2015

# Load the data and index it by timestamp.
df = pd.read_csv("../data/"+dataset.lower()+"_d.csv")
df = df.set_index(pd.to_datetime(df['FECHA_HORA']))

# Keep only the three years used in the experiments.
df = df[df.index.year.isin([2015, 2014, 2013])]

In [1]:
pred_index = None

# Grid search: one full train/evaluate cycle per hyper-parameter combination.
for window, horizon, d_model, model_type, lamb, grad_smooth_alpha, mode in itertools.product(windows, horizons, d_models, model_types, lambs, alphas, modes):
    lamb1, lamb2 = lamb
    # Re-seed every run so the combinations are comparable.
    np.random.seed(123)
    tf.random.set_seed(123)
    random.seed(123)
    
    selected_columns = [dataset+"-O3-AT_IN", dataset+"-PM10-AT_IN", dataset+"-TMP Media-AT_IN"]
    # Copy so the calendar columns added below are written to an independent
    # frame rather than to a slice of the previous one.
    df = df[selected_columns].copy()
    
    #for test_year in df.index.year.unique():
    return_input_labels = 'multitaskhlnet' in model_type
    
    if return_input_labels:
        df['day_of_week'] = df.index.dayofweek
        df['month'] = df.index.month
        df['season'] = df.index.quarter
        # Same order as the columns were added above, so that the labels of
        # df_train/df_valid/df_test match the data they are attached to.
        selected_columns.extend(['day_of_week', 'month', 'season'])
    
    # ---- Data preparation ----
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler_o3 = MinMaxScaler(feature_range=(0, 1))

    train = df[df.index.year!=test_year]

    # Last 10% of the non-test data is used for validation.
    validation_row = int(len(train) * 0.9)
    valid = train.iloc[validation_row:, :]
    train = train.iloc[:validation_row, :]

    test =  df[df.index.year==test_year]

    test_dates = test.index

    # Only the signal columns are scaled; calendar columns are appended raw.
    scaler = scaler.fit(np.concatenate((train.values[:, :n_features], valid.values[:, :n_features])))
    data_train =  scaler.transform(train.values[:, :n_features])
    data_train = np.concatenate((data_train, train.values[:, n_features:]), axis=1)

    data_valid = scaler.transform(valid.values[:, :n_features])
    data_valid = np.concatenate((data_valid, valid.values[:, n_features:]), axis=1)

    data_test = scaler.transform(test.values[:, :n_features])
    data_test = np.concatenate((data_test, test.values[:, n_features:]), axis=1)
    # Fits the target-only scaler that is used later to undo the scaling.
    scaled_o3 = scaler_o3.fit_transform(np.concatenate((train[[target]].values, valid[[target]].values)))

    df_train = pd.DataFrame(data = data_train, columns = selected_columns)
    df_valid = pd.DataFrame(data = data_valid, columns = selected_columns)
    df_test = pd.DataFrame(data = data_test, columns = selected_columns)

    window_splitter = partial(split_window, window, horizon, [0], return_input_labels, n_features)

    training_generator = tf.keras.preprocessing.timeseries_dataset_from_array(
          data=df_train.values,
          targets=None,
          sequence_length=window+horizon,
          sequence_stride=1,
          shuffle=True,
          batch_size=batch_size,
          seed = 123)
    training_generator = training_generator.map(window_splitter, num_parallel_calls=tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)

    valid_generator = tf.keras.preprocessing.timeseries_dataset_from_array(
          data=df_valid.values,
          targets=None,
          sequence_length=window+horizon,
          sequence_stride=1,
          shuffle=True,
          batch_size=batch_size,
          seed = 123)
    valid_generator = valid_generator.map(window_splitter, num_parallel_calls=tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)

    # The test set is not shuffled so predictions line up with the dates.
    test_generator = tf.keras.preprocessing.timeseries_dataset_from_array(
          data=df_test.values,
          targets=None,
          sequence_length=window+horizon,
          sequence_stride=1,
          shuffle=False,
          batch_size=batch_size)

    test_generator = test_generator.map(window_splitter, num_parallel_calls=tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)

    model, loss = get_model(model_type, window, horizon, d_model, n_features, batch_size, lamb1, lamb2, grad_smooth_alpha, mode, overall)

    # NOTE(review): the optimizer uses its default learning rate; the `lr`
    # value from the configuration cell is not passed -- confirm intent.
    radam = tfa.optimizers.RectifiedAdam()
    ranger = tfa.optimizers.Lookahead(radam, sync_period=6, slow_step_size=0.5)
    model.compile(optimizer=ranger, loss=loss,run_eagerly=True, metrics='mae')

    # ---- Callbacks ----
    base_path = f'../results/{dataset.lower()}/{model_type}/w{window}_h{horizon}_dmodel{d_model}_lamb1{lamb1}_lamb2{lamb2}_a{grad_smooth_alpha}_m{mode}'
    print(base_path)
    # Always make sure this year's checkpoint folder exists, even when
    # base_path was already created by an earlier run.
    checkpoint_dir = base_path+'/checkpoints_'+str(test_year)
    os.makedirs(checkpoint_dir, exist_ok=True)

    early_stopping = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True) # Restore best weights
    checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_dir+'/cp-{epoch:04d}.ckpt', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto')

    # ---- Training ----
    start_time = time.time()
    
    main_grad_first_level_average, main_grad_second_level_average, first_level_aux_grad_average = None, None, None
    
    
    history = model.fit(training_generator,
                        validation_data=valid_generator,
                        use_multiprocessing=True,
                        callbacks=[checkpoint, early_stopping],
                        epochs=epochs,
                        workers=12)
    
    # Training time in minutes.
    train_time = (time.time() - start_time)/60

    history_dict = history.history
    with open(checkpoint_dir+'/history.json', 'w') as history_file:
        json.dump(history_dict, history_file)

    # ---- Inference ----

    #model.load_weights(checkpoint._get_most_recently_modified_file_matching_pattern(base_path+'/checkpoints_'+str(test_year)+'/cp-{epoch:04d}.ckpt'))

    # Build the matrix of true horizons: row k holds the `horizon` target
    # values that follow the k-th input window.
    reals_raw = df_test[target].values[window:]
    indexer = np.arange(horizon)[None, :] + np.arange(len(reals_raw)-horizon+1)[:, None]
    reals = reals_raw[indexer]
    predictions = model.predict(test_generator)

    if pred_index is not None:
        predictions = predictions[pred_index]

    # Back to the original units of the target.
    real_scaled = scaler_o3.inverse_transform(reals)
    predictions_scaled = scaler_o3.inverse_transform(predictions)

    # Explicit end bound: `-horizon+1` would be 0 (an empty slice) for horizon == 1.
    dates = pd.Series(test_dates.values[window:len(test_dates)-horizon+1])
    columns = [f'real_{i}' for i in range(horizon)]
    columns.extend([f'pred_{i}' for i in range(horizon)])
    predictions_df = pd.DataFrame(data=np.concatenate([real_scaled, predictions_scaled], axis=1), columns= columns, index = dates)
    metrics_dict = evaluate(real_scaled, predictions_scaled)
    metrics = pd.DataFrame(metrics_dict, index = [test_year])
    metrics['time'] = train_time

    # Append to existing result files, otherwise create them with a header.
    if os.path.exists(base_path+'/metrics.csv') and os.path.exists(base_path+'/predictions.csv'):
        metrics.to_csv(base_path+'/metrics.csv', mode='a', header=False)
        predictions_df.to_csv(base_path+'/predictions.csv', mode='a', header=False)
    else:
        metrics.to_csv(base_path+'/metrics.csv')
        predictions_df.to_csv(base_path+'/predictions.csv')


    print(metrics)

NameError: name 'itertools' is not defined

* Hacer que los pesos de las tareas se aprendan
* Añadir attention
* Hacer que los gradientes de un mismo nivel sean coherentes entre ellos
* Probar aislar niveles
* Aislar solo primer nivel

In [None]:
# Show the layer/parameter summary of `model`, i.e. the model left over from
# the last iteration of the training loop (the loop must have run first).
model.summary()