# Temporal Fusion Transformers for Multi-horizon Time Series Forecasting

## Introduction
This notebook demonstrates the use of the Temporal Fusion Transformer (TFT) for high-performance multi-horizon time series prediction, using a traffic forecasting example with data from the [UCI PEMS-SF Repository](https://archive.ics.uci.edu/ml/datasets/PEMS-SF).

We also show how to use TFT for two interpretability cases:
1. Analyzing variable importance weights to identify significant features for the prediction problem.
2. Visualizing persistent temporal patterns learnt by the TFT using temporal self-attention weights.

A third use case is also presented in our companion notebook "Temporal Fusion Transformers for Regime Identification in Time Series Data".

### Reference Paper
> **Bryan Lim, Sercan Arik, Nicolas Loeff and Tomas Pfister**. "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting". *Submitted*, 2019.

#### Abstract
Multi-horizon forecasting problems often contain a complex mix of inputs -- including static (i.e. time-invariant) covariates, known future inputs, and other exogenous time series that are only observed historically -- without any prior information on how they interact with the target. While several deep learning models have been proposed for multi-step prediction, they typically comprise black-box models which do not account for the full range of inputs present in common scenarios. In this paper, we introduce the Temporal Fusion Transformer (TFT) -- a novel attention-based architecture which combines high-performance multi-horizon forecasting with interpretable insights into temporal dynamics. To learn temporal relationships at different scales, the TFT utilizes recurrent layers for local processing and interpretable self-attention layers for learning long-term dependencies. The TFT also uses specialized components for the judicious selection of relevant features and a series of gating layers to suppress unnecessary components, enabling high performance in a wide range of regimes. On a variety of real-world datasets, we demonstrate significant performance improvements over existing benchmarks, and showcase three practical interpretability use-cases of TFT.

## Preliminary Setup

### Package Installation

In [1]:
# Uses pip3 to install necessary packages (gitpython provides the `git` module imported below)
!pip3 install pyunpack wget patool plotly cufflinks gitpython --user

# Restarts the IPython kernel so that the installed packages can be imported.
import IPython
app = IPython.Application.instance()
app.kernel.do_shutdown(True)



{'status': 'ok', 'restart': True}

### Code Download
First, we clone the [Google Research GitHub Repository](https://github.com/google-research/google-research.git) into a `repo` subdirectory of the current working directory.

In [1]:
import os
from git import Repo

# Folder for the cloned repository, inside the current working directory
repo_dir = os.path.join(os.getcwd(), 'repo')

if not os.path.exists(repo_dir):
    os.makedirs(repo_dir)

# Clones the GitHub repository if the folder is still empty
if not os.listdir(repo_dir):
    git_url = "https://github.com/google-research/google-research.git"
    Repo.clone_from(git_url, repo_dir)

# Changes into the TFT project directory
tft_dir = os.path.join(repo_dir, 'tft')
os.chdir(tft_dir)

# Suppress warnings in cells to improve readability
import warnings  
warnings.filterwarnings('ignore') 

### Data Download
Next, we download data for the default traffic experiment using the scripts provided. As the raw data is supplied in 10-min intervals, the scripts also aggregate them into hourly buckets as described in the paper.
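As a rough illustration of the aggregation step, the sketch below averages each run of six 10-minute readings into one hourly value. The mean as the aggregation rule, the gap-free layout starting on the hour, and the helper name `to_hourly` are assumptions for illustration; `script_download_data.py` remains the authority on what the scripts actually do.

```python
import numpy as np


def to_hourly(readings_10min: np.ndarray) -> np.ndarray:
    """Hypothetical helper: averages a (num_steps, num_sensors) array of 10-min readings by hour.

    Assumes readings are contiguous, start on the hour and contain no gaps.
    """
    steps_per_hour = 6  # 60 min / 10 min
    num_steps, num_sensors = readings_10min.shape
    if num_steps % steps_per_hour != 0:
        raise ValueError("Expected a whole number of hours of 10-min readings.")
    # Group every 6 consecutive rows into one hour, then average within each group.
    grouped = readings_10min.reshape(num_steps // steps_per_hour, steps_per_hour, num_sensors)
    return grouped.mean(axis=1)


# Two hours of synthetic readings for 2 sensors (rows are 10-min steps).
readings = np.arange(24, dtype=float).reshape(12, 2)
hourly = to_hourly(readings)
print(hourly)  # [[ 5.  6.] [17. 18.]]
```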

In [3]:
import pandas as pd
from script_download_data import main as download_data

In [6]:
# Download parameters
expt_name = 'traffic'                                  # Name of default experiment
output_folder = os.path.join(os.getcwd(), 'outputs')   # Root folder to save experiment outputs
force_download = False                                 # Set to True to re-download even if the data is already present

if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# Downloads and processes the main csv file if not present (or if a re-download is forced)
csv_file = os.path.join(output_folder, 'data', 'traffic', 'hourly_data.csv')
if force_download or not os.path.exists(csv_file):
    download_data(expt_name, force_download=True, output_folder=output_folder)

# Load the downloaded data
df = pd.read_csv(csv_file, index_col=0)

## Data Definitions

Now that the csv file has been saved, we need to define certain data manipulation methods to:
1. Supply the TFT with information on the input type (e.g. static, observed or known inputs) along with the numerical type (e.g. real-valued or categorical). This information is used by the TFT to batch the data for training.
2. Specify how to split the dataset into train, validation and test sets.
3. Implement any pre-training data transformations to perform (e.g. standardization).
4. Implement any post-processing steps required for the outputs (e.g. converting predictions back to the original scale).

This requires the definition of a new class which inherits from ``GenericDataFormatter`` and implements all abstract functions.
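Steps 3 and 4 are usually inverses of each other. As a minimal plain-NumPy sketch (all names illustrative), z-score statistics are fitted on the training split only, then reused both to standardise new data and to map predictions back to the original scale. Like `sklearn.preprocessing.StandardScaler`, it uses the population standard deviation.

```python
import numpy as np

# Illustrative data: statistics come from the training split only.
train_values = np.array([10.0, 12.0, 14.0, 16.0])
test_values = np.array([11.0, 20.0])

mean, std = train_values.mean(), train_values.std()  # population std (ddof=0)


def standardise(x):
    """Pre-training transformation (step 3)."""
    return (x - mean) / std


def unstandardise(z):
    """Post-processing step (step 4): returns values to the original scale."""
    return z * std + mean


z_test = standardise(test_values)
restored = unstandardise(z_test)
print(mean, std)  # 13.0 and sqrt(5)
```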


In [7]:
from data_formatters.base import GenericDataFormatter, DataTypes, InputTypes

# View available inputs and data types.
print("Available data types:")
for option in DataTypes:
    print(option)

print()
print("Available input types:")
for option in InputTypes:
    print(option)

Available data types:
DataTypes.REAL_VALUED
DataTypes.CATEGORICAL
DataTypes.DATE

Available input types:
InputTypes.TARGET
InputTypes.OBSERVED_INPUT
InputTypes.KNOWN_INPUT
InputTypes.STATIC_INPUT
InputTypes.ID
InputTypes.TIME
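`TrafficFormatter` lists `hours_from_start` twice, once as the `TIME` column and once as a `KNOWN_INPUT`. The sketch below filters a column definition by data type while skipping `ID` and `TIME` entries, which is how `set_scalers` appears to use `utils.extract_cols_from_data_type`. It uses hypothetical stand-in enums, deliberately named so they do not shadow the imported `DataTypes` and `InputTypes`. The result shows that `hours_from_start` is still scaled as a known input, while `id` is not scaled at all.

```python
from enum import Enum, auto


# Hypothetical stand-ins for the repo's enums; different names avoid shadowing the real imports.
class DemoDataType(Enum):
    REAL_VALUED = auto()
    CATEGORICAL = auto()


class DemoInputType(Enum):
    TARGET = auto()
    KNOWN_INPUT = auto()
    STATIC_INPUT = auto()
    ID = auto()
    TIME = auto()


demo_columns = [
    ('id', DemoDataType.REAL_VALUED, DemoInputType.ID),
    ('hours_from_start', DemoDataType.REAL_VALUED, DemoInputType.TIME),
    ('values', DemoDataType.REAL_VALUED, DemoInputType.TARGET),
    ('time_on_day', DemoDataType.REAL_VALUED, DemoInputType.KNOWN_INPUT),
    ('day_of_week', DemoDataType.REAL_VALUED, DemoInputType.KNOWN_INPUT),
    ('hours_from_start', DemoDataType.REAL_VALUED, DemoInputType.KNOWN_INPUT),
    ('categorical_id', DemoDataType.CATEGORICAL, DemoInputType.STATIC_INPUT),
]


def cols_of_data_type(data_type, column_definition, excluded_input_types):
    """Returns column names with the given data type, skipping excluded input types."""
    return [name for name, dtype, itype in column_definition
            if dtype == data_type and itype not in excluded_input_types]


excluded = {DemoInputType.ID, DemoInputType.TIME}
real_cols = cols_of_data_type(DemoDataType.REAL_VALUED, demo_columns, excluded)
cat_cols = cols_of_data_type(DemoDataType.CATEGORICAL, demo_columns, excluded)
print(real_cols)  # ['values', 'time_on_day', 'day_of_week', 'hours_from_start']
print(cat_cols)   # ['categorical_id']
```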


In [8]:
from libs import utils        # Load TFT helper functions
import sklearn.preprocessing  # Used for data standardization

# Implement formatting functions
class TrafficFormatter(GenericDataFormatter):
    """Defines and formats data for the traffic dataset.

    This also performs z-score normalization across the entire dataset, hence
    re-uses most of the same functions as volatility.

    Attributes:
    column_definition: Defines input and data type of column used in the
      experiment.
    identifiers: Entity identifiers used in experiments.
    """
    
    # This defines the types used by each column
    _column_definition = [
      ('id', DataTypes.REAL_VALUED, InputTypes.ID),   
      ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.TIME),
      ('values', DataTypes.REAL_VALUED, InputTypes.TARGET),
      ('time_on_day', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      ('day_of_week', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      ('hours_from_start', DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT),
      ('categorical_id', DataTypes.CATEGORICAL, InputTypes.STATIC_INPUT),
    ]

    def split_data(self, df, valid_boundary=151, test_boundary=166):
        """Splits data frame into training-validation-test data frames.

        This also calibrates the scaling objects, and transforms data for each split.

        Args:
          df: Source data frame to split.
          valid_boundary: Starting day index (sensor_day) for validation data
          test_boundary: Starting day index (sensor_day) for test data

        Returns:
          Tuple of transformed (train, valid, test) data.
        """

        print('Formatting train-valid-test splits.')

        index = df['sensor_day']
        train = df.loc[index < valid_boundary]
        # Validation and test sets start 7 days early so that their first windows
        # have a full week (num_encoder_steps = 7*24 hours) of encoder history.
        valid = df.loc[(index >= valid_boundary - 7) & (index < test_boundary)]
        test = df.loc[index >= test_boundary - 7]

        self.set_scalers(train)

        return tuple(self.transform_inputs(data) for data in [train, valid, test])

    def set_scalers(self, df):
        """Calibrates scalers using the data supplied.

        Args:
          df: Data to use to calibrate scalers.
        """
        print('Setting scalers with training data...')

        column_definitions = self.get_column_definition()
        id_column = utils.get_single_col_by_input_type(InputTypes.ID,
                                                       column_definitions)
        target_column = utils.get_single_col_by_input_type(InputTypes.TARGET,
                                                           column_definitions)

        # Extract identifiers in case required
        self.identifiers = list(df[id_column].unique())

        # Format real scalers
        real_inputs = utils.extract_cols_from_data_type(
            DataTypes.REAL_VALUED, column_definitions,
            {InputTypes.ID, InputTypes.TIME})

        data = df[real_inputs].values
        self._real_scalers = sklearn.preprocessing.StandardScaler().fit(data)
        self._target_scaler = sklearn.preprocessing.StandardScaler().fit(
            df[[target_column]].values)  # used for predictions

        # Format categorical scalers
        categorical_inputs = utils.extract_cols_from_data_type(
            DataTypes.CATEGORICAL, column_definitions,
            {InputTypes.ID, InputTypes.TIME})

        categorical_scalers = {}
        num_classes = []
        for col in categorical_inputs:
            # Set all to str so that we don't have mixed integer/string columns
            srs = df[col].apply(str)
            categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(
              srs.values)
            num_classes.append(srs.nunique())

        # Set categorical scaler outputs
        self._cat_scalers = categorical_scalers
        self._num_classes_per_cat_input = num_classes

    def transform_inputs(self, df):
        """Performs feature transformations.

        This includes feature engineering, preprocessing and normalisation.

        Args:
          df: Data frame to transform.

        Returns:
          Transformed data frame.

        """
        output = df.copy()

        if self._real_scalers is None and self._cat_scalers is None:
            raise ValueError('Scalers have not been set!')

        column_definitions = self.get_column_definition()

        real_inputs = utils.extract_cols_from_data_type(
            DataTypes.REAL_VALUED, column_definitions,
            {InputTypes.ID, InputTypes.TIME})
        categorical_inputs = utils.extract_cols_from_data_type(
            DataTypes.CATEGORICAL, column_definitions,
            {InputTypes.ID, InputTypes.TIME})

        # Format real inputs
        output[real_inputs] = self._real_scalers.transform(df[real_inputs].values)

        # Format categorical inputs
        for col in categorical_inputs:
            string_df = df[col].apply(str)
            output[col] = self._cat_scalers[col].transform(string_df)

        return output

    def format_predictions(self, predictions):
        """Reverts any normalisation to give predictions in original scale.

        Args:
          predictions: Dataframe of model predictions.

        Returns:
          Data frame of unnormalised predictions.
        """
        output = predictions.copy()

        column_names = predictions.columns

        for col in column_names:
            if col not in {'forecast_time', 'identifier'}:
                # The scaler expects a 2D array: pass a single-column frame, then flatten
                output[col] = self._target_scaler.inverse_transform(
                    predictions[[col]].values).flatten()

        return output
    
    
    def get_fixed_params(self):
        """Returns fixed model parameters for experiments."""

        fixed_params = {
            'total_time_steps': 8*24,     # Total width of the Temporal Fusion Decoder (history + forecast horizon)
            'num_encoder_steps': 7*24,    # Length of LSTM encoder (i.e. number of historical time steps)
            'num_epochs': 100,            # Max number of epochs for training
            'early_stopping_patience': 5, # Early stopping threshold: number of epochs with no loss improvement
            'multiprocessing_workers': 5  # Number of multi-processing workers
        }

        return fixed_params
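The 7-day offsets in `split_data` can be checked on a toy range of `sensor_day` values (the range below is synthetic). Assuming every window must lie entirely within one split, a 192-hour window that starts on day 144 has its 24 forecast hours on day 151, so the days shared between splits only ever serve as encoder history.

```python
import numpy as np

# Synthetic sensor_day values; boundaries match the split_data defaults above.
days = np.arange(140, 180)
valid_boundary, test_boundary, lookback_days = 151, 166, 7

train_days = days[days < valid_boundary]
valid_days = days[(days >= valid_boundary - lookback_days) & (days < test_boundary)]
test_days = days[days >= test_boundary - lookback_days]

print(train_days.min(), train_days.max())  # 140 150
print(valid_days.min(), valid_days.max())  # 144 165
print(test_days.min(), test_days.max())    # 159 179
```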

## Training and Evaluating the TFT

Using the data formatting definitions in ``TrafficFormatter``, we next walk through the procedure for training the TFT.

First, we get all data-related parameters from the data formatter, and define TFT model parameters.

In [9]:
# Create a data formatter 
data_formatter = TrafficFormatter()

# Split data 
train, valid, test = data_formatter.split_data(df)

data_params = data_formatter.get_experiment_params()

# Model parameters for calibration
model_params = {'dropout_rate': 0.3,      # Dropout discard rate
                'hidden_layer_size': 320, # Internal state size of TFT
                'learning_rate': 0.001,   # ADAM initial learning rate
                'minibatch_size': 128,    # Minibatch size for training
                'max_gradient_norm': 100.,# Max norm for gradient clipping
                'num_heads': 4,           # Number of heads for multi-head attention
                'stack_size': 1           # Number of stacks (default 1 for interpretability)
               }

# Folder to save network weights during training.
model_folder = os.path.join(output_folder, 'saved_models', 'traffic', 'fixed')
model_params['model_folder'] = model_folder

model_params.update(data_params)

Formatting train-valid-test splits.
Setting scalers with training data...


Next, we train a TFT model by subsampling fixed-length windows from the training and validation sets (up to 450,000 and 50,000 windows respectively).
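A hypothetical sketch of window sampling for a single series: each sample is a contiguous slice of `total_time_steps` (192) hours. Its first `num_encoder_steps` (168) hours feed the encoder and its last 24 hours form the forecast horizon. The helper `sample_windows` and the uniform choice of start positions are assumptions; the repo's `cache_batched_data` may sample differently, for example by pooling candidate windows across all sensors.

```python
import numpy as np


def sample_windows(series, total_time_steps, num_samples, rng):
    """Hypothetical helper: draws num_samples windows with uniformly chosen start positions."""
    max_start = len(series) - total_time_steps
    if max_start < 0:
        raise ValueError("Series is shorter than one window.")
    starts = rng.integers(0, max_start + 1, size=num_samples)
    # Each row is a contiguous slice: encoder history followed by the forecast horizon.
    return np.stack([series[s:s + total_time_steps] for s in starts])


rng = np.random.default_rng(0)
hourly_series = np.arange(500, dtype=float)  # synthetic hourly values
windows = sample_windows(hourly_series, total_time_steps=8 * 24, num_samples=4, rng=rng)
encoder, decoder = windows[:, :7 * 24], windows[:, 7 * 24:]
print(windows.shape, encoder.shape, decoder.shape)  # (4, 192) (4, 168) (4, 24)
```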

In [10]:
import tensorflow as tf
from libs.tft_model import TemporalFusionTransformer

# Specify GPU usage
tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=0)




Selecting GPU ID=0


In [11]:
tf.reset_default_graph()
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:

    tf.keras.backend.set_session(sess)
    
    # Create a TFT model
    model = TemporalFusionTransformer(model_params,
                                      use_cudnn=True)  # Run model on GPU using CuDNNLSTM cells

    # Sample data into minibatches for training
    if not model.training_data_cached():
        model.cache_batched_data(train, "train", num_samples=450000)
        model.cache_batched_data(valid, "valid", num_samples=50000)

    # Train and save model
    model.fit()
    model.save(model_folder)


Resetting temp folder...
*** TemporalFusionTransformer params ***
# known_categorical_inputs = [0]
# num_epochs = 100
# max_gradient_norm = 100.0
# static_input_loc = [4]
# input_obs_loc = [0]
# hidden_layer_size = 320
# column_definition = [('id', <DataTypes.REAL_VALUED: 0>, <InputTypes.ID: 4>), ('hours_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('values', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('time_on_day', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('hours_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('categorical_id', <DataTypes.CATEGORICAL: 1>, <InputTypes.STATIC_INPUT: 3>)]
# early_stopping_patience = 5
# output_size = 1
# input_size = 5
# learning_rate = 0.001
# num_encoder_steps = 168
# model_folder = /home/bryanlim/research/prod_code/notebooks/tft/outputs/saved_models/traffic/fixed
# multiprocessing_workers = 5
# known_regular_




Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 192, 5)]     0                                            
__________________________________________________________________________________________________
tf_op_layer_TemporalFusionTrans [(None, 192, 1)]     0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_TemporalFusionTrans [(None, 192)]        0           tf_op_layer_TemporalFusionTransfo
__________________________________________________________________________________________________
sequential (Sequential)         (None, 192, 320)     308160      tf_op_layer_TemporalFusionTransfo
______________________________________________________________________________________________



Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Loading model from /home/bryanlim/research/prod_code/notebooks/tft/outputs/saved_models/traffic/fixed/tmp/TemporalFusionTransformer.check





Model saved to: /home/bryanlim/research/prod_code/notebooks/tft/outputs/saved_models/traffic/fixed/TemporalFusionTransformer.ckpt


To evaluate model performance, we reload the saved model, generate forecasts for the test set, and compute the normalised P50 and P90 quantile losses.
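The normalised quantile loss (q-Risk in the paper) divides twice the summed quantile loss by the summed absolute targets. Below is a minimal NumPy sketch, assuming this definition matches `utils.numpy_normalised_quantile_loss`. Under-predictions are weighted by q and over-predictions by 1 - q, so the P90 loss penalises forecasts that fall below the target more heavily.

```python
import numpy as np


def normalised_quantile_loss(y, y_pred, quantile):
    """q-Risk: 2 * sum of quantile losses / sum of |y| (assumed definition)."""
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    error = y - y_pred
    # Under-predictions (error > 0) are weighted by q, over-predictions by (1 - q).
    weighted = quantile * np.maximum(error, 0.0) + (1.0 - quantile) * np.maximum(-error, 0.0)
    return 2.0 * weighted.sum() / np.abs(y).sum()


targets = np.array([1.0, 2.0, 3.0])
forecast = np.array([1.0, 1.0, 1.0])  # under-predicts the last two points
print(normalised_quantile_loss(targets, forecast, 0.5))  # 0.5
print(normalised_quantile_loss(targets, forecast, 0.9))  # approximately 0.9
```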

In [12]:
tf.reset_default_graph()
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:

    tf.keras.backend.set_session(sess)
    
    # Create a new model & load weights
    model = TemporalFusionTransformer(model_params,
                                      use_cudnn=True)
    model.load(model_folder)
    
    # Make forecasts
    output_map = model.predict(test, return_targets=True)

    targets = data_formatter.format_predictions(output_map["targets"])

    # Format predictions
    p50_forecast = data_formatter.format_predictions(output_map["p50"])
    p90_forecast = data_formatter.format_predictions(output_map["p90"])

    def extract_numerical_data(data):
        """Strips out forecast time and identifier columns."""
        return data[[
          col for col in data.columns
          if col not in {"forecast_time", "identifier"}
        ]]

    # Compute Losses
    p50_loss = utils.numpy_normalised_quantile_loss(
        extract_numerical_data(targets), extract_numerical_data(p50_forecast),
        0.5)
    p90_loss = utils.numpy_normalised_quantile_loss(
        extract_numerical_data(targets), extract_numerical_data(p90_forecast),
        0.9)

Resetting temp folder...








Done.


In [13]:
print("Normalised quantile losses: P50={}, P90={}".format(p50_loss.mean(), p90_loss.mean()))

Normalised quantile losses: P50=0.09363400457506142, P90=0.0696119659978696


## Interpretability Use Cases
The relationships learnt by the TFT can also be studied using the trained model, through:

1. Analyzing the variable selection weights to identify significant features for the prediction problem.
2. Visualizing distributions of self-attention weights to determine the presence of any persistent temporal relationships.

In the remainder of this section, we demonstrate two interpretability use cases to showcase the above.

### Generate Weights for Interpretability
First, we generate all necessary variable selection and attention weights required for analysis.

In [32]:
# Store outputs in maps
counts = 0
interpretability_weights = {k: None for k in ['decoder_self_attn', 
                                              'static_flags', 'historical_flags', 'future_flags']}

tf.reset_default_graph()
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:

    tf.keras.backend.set_session(sess)
    
    # Create a new model & load weights
    model = TemporalFusionTransformer(model_params,
                                      use_cudnn=True)
    model.load(model_folder)
    for identifier, sliced in test.groupby('id'):
        
        print("Getting attention weights for {}".format(identifier))
        weights = model.get_attention(sliced)
        
        for k in interpretability_weights:
            w = weights[k]

            # Average attention across heads if necessary
            if k == 'decoder_self_attn':
                w = w.mean(axis=0)

                # Count samples once per batch (all weight types share the batch size)
                batch_size, _, _ = w.shape
                counts += batch_size

            # Store a single summed matrix per weight type to reduce memory footprint
            if interpretability_weights[k] is None:
                interpretability_weights[k] = w.sum(axis=0)
            else:
                interpretability_weights[k] += w.sum(axis=0)

# Convert the running sums into averages over all test samples
interpretability_weights = {k: interpretability_weights[k] / counts for k in interpretability_weights}

print('Done.')

Resetting temp folder...
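The cell above keeps only a running sum per weight type and divides by the total sample count at the end. The sketch below, on hypothetical random arrays, checks that this equals the mean over all samples pooled together, so every test window (rather than every sensor) carries equal weight.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical per-sensor batches of (batch, T, T) attention matrices with different batch sizes.
per_entity = [rng.random((n, 3, 3)) for n in (5, 2, 7)]

running_sum, count = None, 0
for w in per_entity:
    # Keep only one summed T x T matrix, as in the notebook, instead of every sample.
    running_sum = w.sum(axis=0) if running_sum is None else running_sum + w.sum(axis=0)
    count += w.shape[0]

streaming_mean = running_sum / count
pooled_mean = np.concatenate(per_entity).mean(axis=0)
print(count)  # 14
```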

### Use Case 1: Analyzing Variable Importance 

Next, we analyze the distribution of variable selection weights on the input layer, using them to quantify the relative importance of each feature for the prediction problem in general. Variable importance is reported separately for static covariates, time-varying historical input variables and known future inputs, as shown below.
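In the TFT, variable selection weights come from a softmax across the candidate inputs, so for each sample and time step they are non-negative and sum to one. The sketch below uses random logits (an assumption standing in for the network's outputs) to show this. It also explains why a model with a single static input always assigns that input a weight of 1.0. The weights are then summarised over time steps in the same way as `get_range`.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(2)
# Hypothetical logits: 168 encoder steps x 4 historical inputs.
hist_weights = softmax(rng.normal(size=(168, 4)))
# A single static input: softmax over one logit is always 1.
static_weights = softmax(rng.normal(size=(10, 1)))

summary = {
    'Mean': hist_weights.mean(axis=0),
    '10%': np.quantile(hist_weights, 0.1, axis=0),
    '50%': np.quantile(hist_weights, 0.5, axis=0),
    '90%': np.quantile(hist_weights, 0.9, axis=0),
}
print(summary['Mean'].sum())  # approximately 1.0
```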

In [39]:
import numpy as np

def get_range(weights, axis=None):
    """Returns the mean, 10th, 50th and 90th percentile of variable importance weights."""
    return {'Mean': weights.mean(axis=axis),
            '10%': np.quantile(weights, 0.1, axis=axis),
            '50%': np.quantile(weights, 0.5, axis=axis),
            '90%': np.quantile(weights, 0.9, axis=axis)}

#### Static Variable Importance


In [58]:
def flatten(x):
    """Collapses all leading axes so that each row holds the weights for one time step or sample."""
    return x.reshape([-1, x.shape[-1]])

static_attn = flatten(interpretability_weights['static_flags'])
m = get_range(static_attn, axis=0)
pd.DataFrame({k: pd.Series(m[k], index=['ID']) for k in m})

|    | 10% | 50% | 90% | Mean |
|----|-----|-----|-----|------|
| ID | 1.0 | 1.0 | 1.0 | 1.0  |


#### Temporal Variable Importance -- Past Inputs

In [59]:
x = flatten(interpretability_weights['historical_flags'])
m = get_range(x, axis=0)
pd.DataFrame({k: pd.Series(m[k], index=['Hour of Day', 'Day of Week', 'Time Index', 'Target']) for k in m})

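|             | 10%      | 50%      | 90%      | Mean     |
|-------------|----------|----------|----------|----------|
| Hour of Day | 0.321299 | 0.321317 | 0.321358 | 0.321323 |
| Day of Week | 0.07715  | 0.077151 | 0.077153 | 0.077151 |
| Time Index  | 0.214888 | 0.214898 | 0.214916 | 0.2149   |
| Target      | 0.386603 | 0.386621 | 0.38665  | 0.386625 |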


#### Temporal Variable Importance -- Future Inputs

In [60]:
x = flatten(interpretability_weights['future_flags'])
m = get_range(x, axis=0)
pd.DataFrame({k: pd.Series(m[k], index=['Hour of Day', 'Day of Week', 'Time Index']) for k in m})

|             | 10%      | 50%      | 90%      | Mean     |
|-------------|----------|----------|----------|----------|
| Hour of Day | 0.680779 | 0.680784 | 0.680789 | 0.680784 |
| Day of Week | 0.208171 | 0.208177 | 0.208183 | 0.208177 |
| Time Index  | 0.111037 | 0.111038 | 0.111039 | 0.111038 |


### Use Case 2: Visualizing Persistent Temporal Patterns
Lastly, we analyse the distribution of self-attention weights across various horizons to see if any persistent temporal patterns exist within the dataset. This allows us to identify any seasonal patterns or lagged relationships in the dataset, based on which past time steps are consistently important for predictions at a given horizon. 

We visualize this using the average attention pattern for various prediction horizons, as shown below:
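When reading the plot, note that horizon k (counting from 1) corresponds to the query at row `num_encoder_steps + k - 1` of the averaged T x T attention matrix. Subtracting `num_encoder_steps` from the column index places the first forecast step at position 0. The sketch uses a random, causally masked, row-normalised matrix as a stand-in for `interpretability_weights['decoder_self_attn']`; the mask reflects the TFT decoder's masking of future positions.

```python
import numpy as np

num_encoder_steps, total_time_steps = 168, 192  # values from get_fixed_params

# Hypothetical stand-in: lower-triangular (causal) attention, rows normalised to sum to one.
rng = np.random.default_rng(3)
self_attn = np.tril(rng.random((total_time_steps, total_time_steps)))
self_attn /= self_attn.sum(axis=1, keepdims=True)

horizons = [1, 5, 10, 15, 20]
# Horizon k (1-based) is the query at row num_encoder_steps + k - 1.
rows = {k: self_attn[num_encoder_steps + k - 1, :] for k in horizons}
# Relative positions: 0 is the first forecast step, negative values are history.
positions = np.arange(total_time_steps) - num_encoder_steps
print(positions[0], positions[-1])  # -168 23
```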

#### Mean Attention Weights for Various Prediction Horizons

In [1]:
# Plotting libraries & functions
import os
import plotly.offline
import cufflinks as cf  # Adds the DataFrame.iplot method used below
from IPython.display import HTML

# Saves a plotly figure to an HTML file and displays it in the notebook
def iplot(fig, s='plot.html'):
    filename = os.path.join(output_folder, s)
    plotly.offline.plot(fig, filename=filename, auto_open=False)
    return HTML(filename=filename)

def plotly_chart(df, title=None, kind='scatter', x_label=None, y_label=None, secondary_y=None, fill=None,
                shape=None, subplots=False):
    
    fig = df.iplot(asFigure=True, title=title, kind=kind, xTitle=x_label, yTitle=y_label, secondary_y=secondary_y,
                  fill=fill, subplots=subplots, shape=shape)

    return iplot(fig)

In [72]:
self_attn = interpretability_weights['decoder_self_attn']

means = pd.DataFrame({"horizon={}".format(k): self_attn[model.num_encoder_steps+k-1, :] 
                      for k in [1, 5, 10, 15, 20]})
means.index -= model.num_encoder_steps

plotly_chart(means,
             x_label="Positiion Index (n)",
             y_label="Mean Attention Weight",
             title="Average Attention Pattern at Various Prediction Horizons")

Copyright 2019 The Google Research Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at:

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.