# Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting
- AI Hub Notebook: [AI Hub - Interpretable Multi-Horizon Time Series Forecasting with TFT](https://aihub.cloud.google.com/p/products%2F9f39ad8d-ad81-4fd9-8238-5186d36db2ec)
- Arxiv Paper: [Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/pdf/1912.09363.pdf)
- UCI ML Repository Data: [PEMS-SF Dataset](https://archive.ics.uci.edu/ml/datasets/PEMS-SF)

**FAQs**  
**What is multi-horizon forecasting?**  
Multi-horizon forecasting (MHF) is the prediction of variables of interest at multiple future time steps. It often involves a `complex mix of inputs`, including `static covariates`, known `future inputs` and other `exogenous time series` that are only observed in the past, without any prior information on how they interact with the target.

**What is new in this paper?**  
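Many DL architectures have been published for multi-horizon forecasting, but most of them are `black box` models that do not account for the full range of inputs present in practical scenarios. This paper introduces an architecture that handles this full range of inputs while also providing interpretable insights into how they drive the forecasts.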

**What is a `Temporal Fusion Transformer(TFT)`?**  
A TFT is a novel attention-based architecture that combines high-performance multi-horizon forecasting with interpretable insights into temporal dynamics.

**What powers the TFT architecture?**  
To learn temporal relationships at different scales, TFT uses `Recurrent Neural Networks (RNNs)` for local processing and interpretable self-attention layers for long-term dependencies.
It uses specialized components to select relevant features and a series of gating layers to suppress unnecessary components, enabling high performance in a wide range of scenarios.

**What are the key DL concepts used?**
- Recurrent Neural Networks (RNN)
- Long Short Term Memory Networks (LSTM)
- Attention Based Models
- Transformer Based Models

**Which challenges does this paper address that were not considered before?**
- Handling exogenous inputs that are known into the future
- Incorporating static covariates instead of neglecting them
- Designing networks with suitable inductive biases
- Accounting for the heterogeneity of forecasting inputs

**Does this paper come with source code?**  
Yes. The authors applied their implementation to 3 real-world datasets and demonstrated the effectiveness of this architecture.

## 1. Introduction
MHF, the prediction of variables of interest at multiple future time steps, is an important problem in time series ML. Most time series problems are framed as predicting a single step ahead. For example,
- Given a series of observations of events collected over a period of time, `what could be the next event?`
- Given the daily closing prices of a stock, `what is the price at the end of the next day?`
- From observations of the monthly purchasing power of a household, `what could be the purchase amount for the next month?`
- Given the units of electricity consumed on an hourly basis, `what is the power consumption by 10AM?`

In the case of MHF, the questions highlighted above extend to multiple time steps:
- what could be the next event? $\Longrightarrow$ `what could be the next 4 events?`
- what is the price at the end of the next day? $\Longrightarrow$ `what is the end-of-day price for tomorrow, the day after tomorrow and the day after that?`
- what could be the purchase amount for the next month? $\Longrightarrow$ `what could be the purchase amount for each of the next 6 months?`
- what is the power consumption by 10AM? $\Longrightarrow$ `what is the power consumption by 10AM, 11AM, Noon and 1PM?`
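As a minimal sketch of how a direct multi-horizon dataset differs from a one-step one, the hypothetical helper below (not from the paper's code) turns a univariate series into training pairs in which the last `k` observations map to the next `tau_max` values, rather than to a single next value. The prices are made-up numbers.

```python
from typing import List, Tuple

Sample = Tuple[List[float], List[float]]


def make_multi_horizon_windows(series: List[float], k: int, tau_max: int) -> List[Sample]:
    """Split a series into (past window, future targets) pairs.

    Each sample uses the k most recent observations as input and the next
    tau_max observations as targets, so all horizons are predicted at once.
    """
    if k <= 0 or tau_max <= 0:
        raise ValueError("k and tau_max must be positive")
    samples: List[Sample] = []
    # t is the index of the first future step of each sample.
    for t in range(k, len(series) - tau_max + 1):
        past = series[t - k:t]
        future = series[t:t + tau_max]
        samples.append((past, future))
    return samples


# Daily closing prices (illustrative values only).
prices = [10.0, 10.5, 10.2, 10.8, 11.0, 11.3, 11.1]
for past, future in make_multi_horizon_windows(prices, k=3, tau_max=3):
    print(past, "->", future)
```

With `tau_max=1` the same helper reduces to the usual one-step-ahead setup.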

#### Data Sources for MHF
MHF models have access to a variety of data sources. The heterogeneity of these sources, with little information about how they interact, makes MHF a challenging problem. Typical sources include:
- Information about the future (e.g. upcoming holidays)
- Exogenous time series (e.g. historical customer foot traffic)
- Static metadata (e.g. geo-location of the store)

#### Usual Challenges 

- Many architectures assume that all exogenous inputs are known into the future, a common problem with autoregressive models
- Important static covariates are often neglected, or simply concatenated with the time-dependent features at each step
- Most architectures are `black box` models, where forecasts are controlled by complex nonlinear interactions between many parameters
- This opacity makes it hard to explain how the model arrives at its predictions, so users find it difficult to trust the model

Furthermore, commonly used explainability methods for DNNs (LIME and SHAP) are not well suited to temporal data:
- LIME and SHAP do not consider the time ordering of input features
- In LIME, surrogate models are constructed independently for each data point
- In SHAP, features at neighboring time steps are considered independently

Such post-hoc approaches may lead to poor explanation quality, as dependencies between time steps are typically significant in time series.

- Attention-based architectures have inherent interpretability for sequential data (e.g. language or speech)
- Language and speech data are usually univariate sequences, whereas forecasting data are multivariate. Applying attention-based models to such data is novel, but the heterogeneity of the inputs remains a challenge

### 1.1 Temporal Fusion Transformers (TFT)
The TFT proposed in this paper is an attention-based DNN architecture for MHF that achieves high performance while enabling interpretability.
To account for the full range of potential inputs and temporal relationships, it incorporates the following novel ideas:
1. Static covariate encoders that encode context vectors for use in other parts of the network
2. Gating mechanisms throughout and sample-dependent variable selection to minimize the contributions of irrelevant inputs
3. A sequence to sequence layer to locally process known and observed inputs
4. A temporal self-attention decoder to learn any long-term dependencies present in the dataset. This facilitates interpretability by identifying:  
    a. Globally important variables for the prediction problem  
    b. Persistent temporal patterns  
    c. Significant events  
    
Whereas conventional methods assume that only the target is fed into the prediction recursion loop, ignoring numerous useful time-varying inputs from the second time step onwards, TFT explicitly accounts for the diversity of inputs by naturally handling both static covariates and time-varying inputs.

#### Time Series Interpretability with Attention
Attention mechanisms are used in 
- Translation[17]
- Image Classification[22]
- Tabular Learning[23]  

In these works, the magnitude of the attention weights is used to identify the salience of inputs for each instance. Motivated by interpretability, time series studies [7, 12, 24] have used attention with LSTM-based [25] and Transformer-based architectures, but without considering the importance of static covariates.

TFT addresses this by using separate encoder-decoder attention for static features at each time step, on top of the self-attention, to determine the contribution of time-varying inputs.

Post-hoc interpretability methods are applied to pre-trained black-box models and are often based on distilling them into a surrogate interpretable model or decomposing their predictions into feature attributions. They are not designed to take the time ordering of inputs into account, which limits their use for complex time series data.

**Feature Selection Methods**
- Inherently interpretable modeling approaches build components for feature selection directly into the architecture
- For time series forecasting, they are based on explicitly quantifying time-dependent variable contributions
- Interpretable Multi-Variable LSTMs [27] partition the hidden state such that each variable contributes uniquely to its own memory segment, and weight the memory segments to determine variable contributions
- Methods combining temporal importance and variable selection [24] compute a single contribution coefficient based on the attention weights from each

TFT is designed to analyze global temporal relationships in the input data and to allow users to interpret the global behavior of the model on the whole dataset, specifically by identifying any persistent patterns (e.g. seasonality or lag effects) and regimes present.

### 1.2 Related Work
- Traditional multi-horizon forecasting methods [18, 19]
- Iterated approaches using autoregressive models [9, 6, 12]
    - These are one-step-ahead prediction models, with multi-step predictions obtained by recursively feeding predictions into future inputs
    - DeepAR [9] uses stacked LSTM layers to generate the parameters of one-step-ahead Gaussian predictive distributions, while Deep State-Space Models [6] use LSTMs to generate the parameters of a predefined linear state-space model, with predictive distributions produced via Kalman filtering
    - Transformer-based architectures [12] use convolutional layers for local processing and a sparse attention mechanism to increase the receptive field during forecasting
    - These methods assume that only the target is fed recursively into future inputs
- Direct methods based on sequence-to-sequence models[10, 11]
    - Direct methods explicitly generate forecasts for multiple predefined horizons at each time step, typically relying on seq2seq architectures
    - They use LSTM encoders to summarize past inputs and a variety of methods to generate future predictions
    - MQRNN [10] uses LSTM or convolutional encoders to generate context vectors, which are fed into MLPs for each horizon
    - In [11], a multi-modal attention mechanism is used with LSTM encoders to construct context vectors for a bi-directional LSTM decoder
    - Yet interpretability remains challenging
    
**Others**  
- Post-hoc explanation methods (LIME, SHAP, RL-LIM) [15, 16, 26]
- Inherently interpretable models[27, 24]
- Methods combining temporal importance and variable selection[24]

## 2 Multi-Horizon Forecasting

### 2.1 Inputs and Targets
Consider a time series dataset in which each entity $i$ is observed at time steps $t \in [0, T_i]$:
- $I$ is the number of unique entities (e.g. different stores)
- Each entity $i$ is associated with a set of static covariates, i.e. $s_i \in \mathbb{R}^{m_s}$
- Time-dependent inputs are $\mathcal{X}_{i, t} \in \mathbb{R}^{m_{\mathcal{X}}}$
- Targets are $y_{i,t} \in \mathbb{R}$

Time-dependent inputs are divided into 2 categories:

$$\mathcal{X}_{i,t} = [z^T_{i,t}, x^T_{i,t}]^T$$

**Observed Inputs**  
Measured at each step and unknown beforehand
$$z_{i,t} \in \mathbb{R}^{m_z}$$
**Known Inputs**  
Inputs that are predetermined (e.g. the day of the week at time $t$)
$$x_{i,t} \in \mathbb{R}^{m_x}$$
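To make the observed/known distinction concrete, here is a minimal sketch of what is available when forecasting from time $t$, using a look-back window $k$ and horizon $\tau_{max}$ with the same index ranges as Eq. 1 below. The `EntitySeries` container and `split_inputs` function are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EntitySeries:
    """Hypothetical container for the data of one entity i."""
    y: List[float]        # targets y_{i,t}
    z: List[List[float]]  # observed inputs z_{i,t}: only known up to time t
    x: List[List[float]]  # known inputs x_{i,t}: also available in the future
    s: List[float]        # static covariates s_i


def split_inputs(e: EntitySeries, t: int, k: int, tau_max: int):
    """Slice the inputs available when forecasting from time t.

    Targets and observed inputs stop at t, while known inputs extend
    tau_max steps past t.
    """
    if t - k < 0 or t + tau_max >= len(e.y):
        raise ValueError("window falls outside the series")
    past_y = e.y[t - k:t + 1]             # y_{t-k}, ..., y_t
    past_z = e.z[t - k:t + 1]             # z_{t-k}, ..., z_t
    known_x = e.x[t - k:t + tau_max + 1]  # x_{t-k}, ..., x_{t+tau_max}
    targets = e.y[t + 1:t + tau_max + 1]  # y_{t+1}, ..., y_{t+tau_max}
    return past_y, past_z, known_x, e.s, targets


T = 8
entity = EntitySeries(
    y=[float(v) for v in range(T)],
    z=[[0.1 * v] for v in range(T)],       # e.g. an observed exogenous series
    x=[[float(v % 7)] for v in range(T)],  # e.g. day of week, known in advance
    s=[1.0],                               # e.g. an encoded store ID
)
past_y, past_z, known_x, s, targets = split_inputs(entity, t=4, k=2, tau_max=3)
print(past_y, len(known_x), targets)
```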

#### Quantile Forecasting
Prediction intervals can be useful for optimizing decisions and for risk management, as they indicate the likely best- and worst-case values of the target. TFT therefore adopts quantile regression for the MHF setting (e.g. outputting the 10th, 50th and 90th percentiles at each time step). Quantile forecasts take the form

$$\large{\hat{y}_i(q, t, \tau) = f_q(\tau, y_{i, t-k:t}, z_{i, t-k:t}, x_{i,t-k:t + \tau}, s_i)}  \tag{1}$$

where,
- $\hat{y}_i(q, t, \tau)$ is the predicted $q^{th}$ sample quantile of the $\tau$-step-ahead forecast made at time $t$
- $f_q(\cdots)$ is the prediction model
- $k$ is the size of the finite look-back window over past observations
- $t$ is the forecast start time; the target and observed inputs are only available up to time $t$, i.e.
$$y_{i, t-k:t} = \{y_{i,t-k}, \cdots, y_{i,t}\}$$
- Known inputs are available across the entire range, i.e.
$$x_{i,t-k:t+\tau} = \{x_{i, t-k}, \cdots, x_{i, t + \tau}\}$$

The model outputs forecasts for $\tau_{max}$ time steps, i.e. $\tau \in \{1, \cdots, \tau_{max}\}$
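Quantile regression is usually trained with the pinball (quantile) loss. The sketch below is the standard formulation, assuming (as is common for quantile forecasters) that the total loss simply sums it over the chosen quantiles and all $\tau_{max}$ horizons; the function names and numbers are illustrative.

```python
from typing import Dict, Sequence


def quantile_loss(y: float, y_hat: float, q: float) -> float:
    """Pinball loss for a single prediction of the q-th quantile, 0 < q < 1."""
    diff = y - y_hat
    # Under-prediction (diff > 0) is weighted by q, over-prediction by (1 - q).
    return max(q * diff, (q - 1.0) * diff)


def total_quantile_loss(
    targets: Sequence[float],
    preds: Dict[float, Sequence[float]],
) -> float:
    """Sum pinball losses over quantiles q and horizons tau = 1..tau_max.

    targets[tau - 1] is the true value y_{t+tau};
    preds[q][tau - 1] is the forecast y_hat(q, t, tau).
    """
    return sum(
        quantile_loss(y, preds[q][i], q)
        for q in preds
        for i, y in enumerate(targets)
    )


targets = [10.0, 12.0]  # tau_max = 2
preds = {0.1: [8.0, 9.0], 0.5: [10.0, 11.0], 0.9: [12.0, 14.0]}
print(total_quantile_loss(targets, preds))  # approximately 1.4
```

The asymmetric weights are what push the $q = 0.1$ forecast below and the $q = 0.9$ forecast above the median, yielding a prediction interval.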



## 3. Model Architecture
The TFT architecture takes three kinds of inputs:
- Static metadata
- Time-varying past inputs
- Time-varying a priori known future inputs

1. Variable selection networks are used for judicious selection of the most salient features based on the input.  
2. Gated Residual Network (GRN) blocks enable efficient information flow with skip connections and gating layers.  
3. Time-dependent processing is based on LSTMs for local processing and multi-head attention for integrating information from any time step.

The building blocks of TFT are:  
1. **Gating Mechanisms**:  
    - Gating mechanisms skip over any unused components of the architecture
    - They provide adaptive depth and network complexity to accommodate a wide range of datasets and scenarios


2. **Variable Selection Networks**:  
    - VSNs select relevant input variables at each time step


3. **Static Covariate Encoders**:  
    - SCEs integrate static features into the network, through encoding of context vectors to condition temporal dynamics


4. **Temporal Processing**:  
    - Temporal processing learns both long- and short-term temporal relationships from both observed and known time-varying inputs
    - A seq2seq layer is employed for local processing
    - Long-term dependencies are captured using a novel interpretable multi-head attention block


5. **Prediction Intervals**:
    - Quantile forecasts determine the range of likely target values at each prediction horizon


### 3.1 Gating Mechanism
- The relationship between exogenous inputs and targets is often unknown in advance, which makes it difficult to anticipate which variables are relevant
- It is also difficult to determine the extent of non-linear processing required, and a simpler model may work better in some cases (e.g. small or noisy datasets)
- To make the model flexible, non-linear processing is applied only where needed

The paper proposes Gated Residual Network(GRN) as a building block of TFT. The GRN takes primary input $a$ and an optional context vector $c$ and yields:

$$\large{GRN_{\omega}(a, c) = LayerNorm(a + GLU_{\omega}(\eta_1))} \tag{2}$$
$$\eta_1 = \mathbf{W}_{1,\omega} \eta_2 + b_{1,\omega} \tag{3}$$
$$\eta_2 = ELU(\mathbf{W}_{2,\omega} a + \mathbf{W}_{3,\omega} c + b_{2,\omega})  \tag{4}$$

where,
- $ELU$ is the Exponential Linear Unit activation function [28]
- $\eta_1 \in \mathbb{R}^{d_{model}}$ is an intermediate layer
- $\eta_2 \in \mathbb{R}^{d_{model}}$ is an intermediate layer
- $LayerNorm$ is standard layer normalization [29]
- $\omega$ is an index to denote weight sharing

The ELU activation makes this layer behave linearly in two regimes:
- it acts as an identity function when $\mathbf{W}_{2,\omega} a + \mathbf{W}_{3,\omega} c + b_{2,\omega} \gg 0$
- it generates a constant output when $\mathbf{W}_{2,\omega} a + \mathbf{W}_{3,\omega} c + b_{2,\omega} \ll 0$
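A quick numerical check of these two regimes, using the standard ELU definition with $\alpha = 1$ (an assumption; the scale is not stated above). The argument stands in for the pre-activation $\mathbf{W}_{2,\omega} a + \mathbf{W}_{3,\omega} c + b_{2,\omega}$.

```python
import math


def elu(x: float, alpha: float = 1.0) -> float:
    """Exponential Linear Unit: x for x > 0, alpha * (exp(x) - 1) otherwise."""
    return x if x > 0 else alpha * (math.exp(x) - 1.0)


# Large positive pre-activation: ELU passes it through unchanged (identity).
print(elu(50.0))
# Large negative pre-activation: ELU saturates at -alpha (a constant).
print(elu(-50.0))
# In between, the output is genuinely non-linear.
print(elu(-0.5))
```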



**Gated Linear Units(GLU)**  
GLUs provide the flexibility to suppress any parts of architecture that are not required for a given dataset. 

Let $\gamma \in \mathbb{R}^{d_{model}}$ be the input, then

$$GLU_{\omega}(\gamma) = \sigma(\mathbf{W}_{4, \omega} \gamma + b_{4, \omega}) \odot (\mathbf{W}_{5,\omega} \gamma + b_{5, \omega}) \tag{5}$$

where,
- $\sigma(\cdot)$ is the sigmoid activation function
- $\mathbf{W}_{(\cdot)} \in \mathbb{R}^{d_{model} \times d_{model}}$ are the weights
- $b_{(\cdot)} \in \mathbb{R}^{d_{model}}$ are the biases
- $\odot$ is the element-wise Hadamard product
- $d_{model}$ is the hidden state size

1. The GLU allows TFT to control the extent to which the GRN contributes to the original input $a$.
2. It can skip over the layer entirely if necessary, as the GLU outputs could all be close to 0 in order to suppress the nonlinear contribution.
3. For instances without a context vector, the GRN simply treats the context input as zero, i.e. $c = 0$.
4. During training, dropout is applied before the gating layer and layer normalization, i.e. to $\eta_1$ in Eq. 3.
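The following is a minimal pure-Python sketch of a GRN following Eqs. 2 to 5, written only to show the data flow. It assumes random untrained weights, zero biases, no dropout, and a plain layer normalization without learned gain and bias; the class and helper names are hypothetical, not the authors' code. The last lines illustrate point 2: if the GLU output is zero, the block reduces to $LayerNorm(a)$.

```python
import math
import random
from typing import List, Optional

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def vadd(*vs: Vector) -> Vector:
    return [sum(items) for items in zip(*vs)]


def elu(x: float) -> float:
    return x if x > 0 else math.exp(x) - 1.0


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    # Plain layer normalization without a learned gain and bias (a simplification).
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


class GRN:
    """Gated Residual Network sketch (Eqs. 2-5), untrained random weights."""

    def __init__(self, d_model: int, d_context: int = 0, seed: int = 0):
        rng = random.Random(seed)

        def mat(rows: int, cols: int) -> Matrix:
            return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

        self.W2, self.b2 = mat(d_model, d_model), [0.0] * d_model  # acts on a
        self.W3 = mat(d_model, d_context) if d_context else None   # acts on c
        self.W1, self.b1 = mat(d_model, d_model), [0.0] * d_model
        self.W4, self.b4 = mat(d_model, d_model), [0.0] * d_model  # GLU gate
        self.W5, self.b5 = mat(d_model, d_model), [0.0] * d_model  # GLU value

    def glu(self, gamma: Vector) -> Vector:
        gate = [sigmoid(g) for g in vadd(matvec(self.W4, gamma), self.b4)]
        value = vadd(matvec(self.W5, gamma), self.b5)
        return [g * v for g, v in zip(gate, value)]  # Hadamard product, Eq. 5

    def __call__(self, a: Vector, c: Optional[Vector] = None) -> Vector:
        pre = vadd(matvec(self.W2, a), self.b2)
        if c is not None and self.W3 is not None:
            pre = vadd(pre, matvec(self.W3, c))  # without c, this acts as c = 0
        eta2 = [elu(p) for p in pre]                 # Eq. 4
        eta1 = vadd(matvec(self.W1, eta2), self.b1)  # Eq. 3
        return layer_norm(vadd(a, self.glu(eta1)))   # Eq. 2


a = [1.0, 2.0, 3.0, 4.0]
grn = GRN(d_model=4, d_context=2, seed=0)
print(grn(a, c=[0.5, -0.5]))

# Zeroing the GLU value branch suppresses the non-linear path entirely.
grn.W5 = [[0.0] * 4 for _ in range(4)]
print(grn(a) == layer_norm(a))
```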

### 3.2 Variable Selection Networks
The relevance of a specific variable and its contribution to the output are typically unknown when training a deep neural network. TFT is designed to provide instance-wise variable selection through
- The use of variable selection networks, applied to both static covariates and time-dependent covariates, to pick the most salient variables from the dataset
- Variable selection also allows TFT to remove any unnecessary noisy inputs that may negatively impact performance

This is accomplished using entity embeddings [31] for categorical variables as feature representations and linear transformations for continuous variables.
- These transformations map each input variable into a $d_{model}$-dimensional vector that matches the dimensions in the subsequent layers for skip connections
- All static, past and future inputs make use of separate variable selection networks
- Without loss of generality, the variable selection network for past inputs is presented below; the others are analogous

Let,
- $\xi^{(j)}_t \in \mathbb{R}^{d_{model}}$ denote the transformed input of the $j$-th variable at time $t$
- $\Xi_t = [\xi^{(1)T}_t, \cdots, \xi^{(m_{\mathcal{X}})T}_t]^T$ be the flattened vector of all past inputs at time $t$, where $m_{\mathcal{X}}$ is the number of past input variables

Variable selection weights are generated by feeding $\Xi_t$ and an external context vector $c_s$ through a GRN, followed by a Softmax layer

$$\large{v_{\mathcal{X}t} = Softmax(GRN_{v_{\mathcal{X}}}(\Xi_t, c_s))}\tag{6}$$

where,
- $v_{\mathcal{X}t} \in \mathbb{R}^{m_{\mathcal{X}}}$ is a vector of variable selection weights
- $c_s$ is obtained from a static covariate encoder
- For static variables, the context vector $c_s$ is omitted, given that they already have access to static information

At each time step, an additional layer of non-linear processing is employed by feeding each $\xi^{(j)}_t$ through its own GRN

$$\large{\tilde{\xi}^{(j)}_t = GRN_{\tilde{\xi}(j)}(\xi^{(j)}_t)}\tag{7}$$

Where,
- $\tilde{\xi}^{(j)}_t$ is the processed feature vector for variable $j$
- Note that each variable has its own GRN, with weights shared across all time steps $t$
- The processed features are then weighted by their variable selection weights and combined

$$\large{\tilde{\xi}_t = \sum^{m_{\mathcal{X}}}_{j=1} v^{(j)}_{\mathcal{X}t} \tilde{\xi}^{(j)}_t}  \tag{8}$$

Where,
- $v^{(j)}_{\mathcal{X}t}$ is the $j^{th}$ element of vector $v_{\mathcal{X}t}$
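Eqs. 6 to 8 can be sketched for a single time step as below. To keep the example short, the GRNs are passed in as plain callables, and the ones used in the demo are hypothetical stand-ins for trained GRNs (the weight network scores each variable by the sum of its embedding; every per-variable GRN is the identity). Only the softmax weighting and the weighted sum follow the equations directly.

```python
import math
from typing import Callable, List, Optional, Tuple

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def variable_selection(
    xis: List[Vector],                                         # xi_t^(j), one per variable
    weight_grn: Callable[[Vector, Optional[Vector]], Vector],  # GRN in Eq. 6
    feature_grns: List[Callable[[Vector], Vector]],            # GRN_{xi~(j)} in Eq. 7
    c_s: Optional[Vector] = None,
) -> Tuple[Vector, Vector]:
    flat = [v for xi in xis for v in xi]                     # Xi_t: flattened inputs
    weights = softmax(weight_grn(flat, c_s))                 # Eq. 6: v_Xt
    processed = [g(xi) for g, xi in zip(feature_grns, xis)]  # Eq. 7
    d_model = len(processed[0])
    combined = [                                             # Eq. 8: weighted sum
        sum(w * p[d] for w, p in zip(weights, processed))
        for d in range(d_model)
    ]
    return combined, weights


d_model = 2


def toy_weight_grn(flat: Vector, c_s: Optional[Vector]) -> Vector:
    # Hypothetical stand-in: one logit per variable (context ignored).
    chunks = [flat[i:i + d_model] for i in range(0, len(flat), d_model)]
    return [sum(chunk) for chunk in chunks]


def identity(xi: Vector) -> Vector:
    # Hypothetical stand-in for a trained per-variable GRN.
    return list(xi)


xis = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # m_X = 3 variables
combined, weights = variable_selection(xis, toy_weight_grn, [identity] * 3)
print(weights)   # the third variable receives the largest weight
print(combined)
```

Because the weights come from a softmax, they sum to one and can be read directly as per-instance variable importances.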

### 3.3 Static Covariate Encoders
TFT uses separate GRN encoders to integrate information from static metadata, producing 4 different context vectors
$$[c_s, c_e, c_c, c_h]$$
These context vectors are wired into various locations in the Temporal Fusion Decoder (TFD):
- $c_s$ for temporal variable selection
- $(c_c, c_h)$ for local processing of temporal features
- $c_e$ for enriching temporal features with static information

Let $\zeta$ be the output of the static variable selection network. The context for temporal variable selection is then encoded according to
$$c_s = GRN_{c_s}(\zeta)$$
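A minimal sketch of the static covariate encoders: four independent GRNs (separate weights, same input $\zeta$), one per context vector. It assumes a context-free GRN with random untrained weights and omitted biases; `make_grn` and the example $\zeta$ are hypothetical.

```python
import math
import random
from typing import Callable, Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def elu(x: float) -> float:
    return x if x > 0 else math.exp(x) - 1.0


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def layer_norm(v: Vector, eps: float = 1e-5) -> Vector:
    mean = sum(v) / len(v)
    var = sum((x - mean) ** 2 for x in v) / len(v)
    return [(x - mean) / math.sqrt(var + eps) for x in v]


def make_grn(d_model: int, seed: int) -> Callable[[Vector], Vector]:
    """Context-free GRN (Eqs. 2-5) with its own random weights; biases omitted."""
    rng = random.Random(seed)

    def mat() -> Matrix:
        return [[rng.gauss(0.0, 0.5) for _ in range(d_model)] for _ in range(d_model)]

    W1, W2, W4, W5 = mat(), mat(), mat(), mat()

    def grn(a: Vector) -> Vector:
        eta2 = [elu(v) for v in matvec(W2, a)]                 # Eq. 4 with c = 0
        eta1 = matvec(W1, eta2)                                # Eq. 3
        gate = [sigmoid(v) for v in matvec(W4, eta1)]          # Eq. 5, gate
        glu = [g * v for g, v in zip(gate, matvec(W5, eta1))]  # Eq. 5, gated value
        return layer_norm([x + y for x, y in zip(a, glu)])     # Eq. 2

    return grn


# One separate encoder (separate weights) per context vector.
names = ["c_s", "c_e", "c_c", "c_h"]
encoders: Dict[str, Callable[[Vector], Vector]] = {
    name: make_grn(d_model=4, seed=i) for i, name in enumerate(names)
}

zeta = [0.3, -1.2, 0.8, 0.1]  # hypothetical output of the static VSN
contexts = {name: encoder(zeta) for name, encoder in encoders.items()}
for name, vec in contexts.items():
    print(name, [round(v, 3) for v in vec])
```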

### 3.4 Interpretable Multi-Head Attention
TFT employs a `self-attention mechanism` to learn long-term relationships across different time steps, modified from the `multi-head attention` of Transformer-based architectures [17, 12] to enhance explainability.

**The Q, K and V**
- Attention mechanisms `scale values` $V \in \mathbb{R}^{N \times d_V}$ based on the relationships between keys ($K$) and queries ($Q$)
- $K \in \mathbb{R}^{N \times d_{attn}}$ is the key
- $Q \in \mathbb{R}^{N \times d_{attn}}$ is the query

$$Attention(Q, K, V) = A(Q, K)V\tag{9}$$

Where,
- $A(\cdot)$ is a normalization function. A common choice is scaled dot-product attention

$$A(Q, K) = Softmax\left(\frac{QK^T}{\sqrt{d_{attn}}}\right)\tag{10}$$
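Eqs. 9 and 10 as a small pure-Python sketch (no masking; the helper names are hypothetical). With identical keys every query attends uniformly, so each output row is the mean of the value rows.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def softmax_rows(M: Matrix) -> Matrix:
    out = []
    for row in M:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def scaled_dot_product_attention(Q: Matrix, K: Matrix, V: Matrix) -> Tuple[Matrix, Matrix]:
    """Eqs. 9-10: returns (A(Q, K) V, A(Q, K))."""
    d_attn = len(Q[0])
    scores = matmul(Q, transpose(K))  # N x N
    A = softmax_rows([[s / math.sqrt(d_attn) for s in row] for row in scores])
    return matmul(A, V), A


# N = 2 time steps, d_attn = d_V = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 1.0], [1.0, 1.0]]  # identical keys -> uniform attention
V = [[1.0, 2.0], [3.0, 4.0]]
out, A = scaled_dot_product_attention(Q, K, V)
print(A)    # [[0.5, 0.5], [0.5, 0.5]]
print(out)  # [[2.0, 3.0], [2.0, 3.0]]
```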

Multi-head attention employs different heads for different representation subspaces to increase the learning capacity
$$MultiHead(Q, K, V) = [H_1, \cdots, H_{m_H}]W_H\tag{11}$$
$$H_h = Attention(QW^{(h)}_Q, KW^{(h)}_K, VW^{(h)}_V) \tag{12}$$

Where,
- $W^{(h)}_K \in \mathbb{R}^{d_{model} \times d_{attn}}$ are head-specific weights for keys
- $W^{(h)}_Q \in \mathbb{R}^{d_{model} \times d_{attn}}$ are head-specific weights for queries
- $W^{(h)}_V \in \mathbb{R}^{d_{model} \times d_V}$ are head-specific weights for values

$W_H \in \mathbb{R}^{(m_H \cdot d_V) \times d_{model}}$ linearly combines the outputs concatenated from all heads $H_h$
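A minimal sketch of standard multi-head self-attention (Eqs. 11 and 12) with random, untrained weights. It assumes $d_{attn} = d_V = d_{model} / m_H$, a common choice that is not stated above; all names are hypothetical. Note that each head projects the values with its own $W^{(h)}_V$.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def softmax_rows(M: Matrix) -> Matrix:
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Scaled dot-product attention, Eqs. 9-10."""
    d_attn = len(Q[0])
    scores = matmul(Q, transpose(K))
    A = softmax_rows([[s / math.sqrt(d_attn) for s in row] for row in scores])
    return matmul(A, V)


def multi_head(Q: Matrix, K: Matrix, V: Matrix, WQ: List[Matrix],
               WK: List[Matrix], WV: List[Matrix], WH: Matrix) -> Matrix:
    """Eqs. 11-12: head h uses its own W_Q^(h), W_K^(h) and W_V^(h)."""
    heads = [attention(matmul(Q, wq), matmul(K, wk), matmul(V, wv))
             for wq, wk, wv in zip(WQ, WK, WV)]                       # each H_h: N x d_V
    concat = [[v for h in heads for v in h[i]] for i in range(len(Q))]  # N x (m_H * d_V)
    return matmul(concat, WH)                                         # N x d_model


def rand_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
N, d_model, m_H = 3, 4, 2
d_attn = d_V = d_model // m_H  # assumed head size
X = rand_matrix(rng, N, d_model)  # N time steps of d_model features
WQ = [rand_matrix(rng, d_model, d_attn) for _ in range(m_H)]
WK = [rand_matrix(rng, d_model, d_attn) for _ in range(m_H)]
WV = [rand_matrix(rng, d_model, d_V) for _ in range(m_H)]
WH = rand_matrix(rng, m_H * d_V, d_model)
out = multi_head(X, X, X, WQ, WK, WV, WH)  # self-attention: Q = K = V = X
print(len(out), len(out[0]))  # 3 4
```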

Because each head uses different values, the attention weights alone are not indicative of a particular feature's importance. TFT therefore modifies multi-head attention to share values across heads, and employs `additive aggregation` of all heads
$$InterpretableMultiHead(Q, K, V) = \tilde H \tilde{W}_H \tag{13}$$
$$\tilde H = \tilde A(Q, K)V W_V \tag{14}$$

$$\tilde H = \left\{ \frac{1}{m_H} \sum^{m_H}_{h=1} A(QW^{(h)}_Q, KW^{(h)}_K) \right\} VW_V \tag{15}$$
$$\tilde H = \frac{1}{m_H} \sum^{m_H}_{h=1} Attention(QW^{(h)}_Q, KW^{(h)}_K, VW_V)\tag{16}$$

Where,
- $W_V \in \mathbb{R}^{d_{model} \times d_V}$ are value weights shared across all heads
- $\tilde{W}_H \in \mathbb{R}^{d_{attn} \times d_{model}}$ is used for the final linear mapping (with $d_V = d_{attn}$)

1. Through this, each head can learn different temporal patterns while attending to a common set of input features.
2. This can be interpreted as a simple ensemble over attention weights, combined into the matrix $\tilde A(Q, K)$ in Eq. 14.
3. Compared to $A(Q, K)$ in Eq. 10, $\tilde A(Q, K)$ yields an increased representation capacity in an efficient way.
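A minimal sketch of interpretable multi-head attention (Eqs. 13 to 16) with random, untrained weights and no masking, assuming $d_V = d_{attn} = d_{model} / m_H$; the function names are hypothetical. It averages the per-head attention matrices into a single $\tilde A(Q, K)$, applies the shared value projection $W_V$, and then checks numerically that Eq. 15 (average the weights, then apply) matches Eq. 16 (apply per head, then average), which holds because the value projection is shared.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]


def softmax_rows(M: Matrix) -> Matrix:
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_weights(Q: Matrix, K: Matrix) -> Matrix:
    """A(Q, K) from Eq. 10."""
    d_attn = len(Q[0])
    scores = matmul(Q, transpose(K))
    return softmax_rows([[s / math.sqrt(d_attn) for s in row] for row in scores])


def interpretable_multi_head(Q: Matrix, K: Matrix, V: Matrix, WQ: List[Matrix],
                             WK: List[Matrix], WV: Matrix,
                             WH_tilde: Matrix) -> Tuple[Matrix, Matrix]:
    """Eqs. 13-15: per-head attention weights, one W_V shared by all heads."""
    m_H = len(WQ)
    per_head = [attention_weights(matmul(Q, wq), matmul(K, wk)) for wq, wk in zip(WQ, WK)]
    # Eq. 15: average the per-head attention matrices into A~(Q, K).
    A_tilde = [[sum(A[i][j] for A in per_head) / m_H for j in range(len(K))]
               for i in range(len(Q))]
    H_tilde = matmul(A_tilde, matmul(V, WV))   # Eq. 14
    return matmul(H_tilde, WH_tilde), A_tilde  # Eq. 13


def rand_matrix(rng: random.Random, rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
N, d_model, m_H = 4, 4, 2
d_attn = d_model // m_H  # assumed, with d_V = d_attn
X = rand_matrix(rng, N, d_model)
WQ = [rand_matrix(rng, d_model, d_attn) for _ in range(m_H)]
WK = [rand_matrix(rng, d_model, d_attn) for _ in range(m_H)]
WV = rand_matrix(rng, d_model, d_attn)        # shared across heads
WH_tilde = rand_matrix(rng, d_attn, d_model)
out, A_tilde = interpretable_multi_head(X, X, X, WQ, WK, WV, WH_tilde)

# A~ is one N x N matrix whose rows sum to 1, so row i can be read as the
# importance of each time step when producing output i.
print([round(sum(row), 6) for row in A_tilde])

# Eq. 16 check: averaging per-head outputs A_h (X W_V) gives the same H~.
VW = matmul(X, WV)
head_outputs = [matmul(attention_weights(matmul(X, wq), matmul(X, wk)), VW)
                for wq, wk in zip(WQ, WK)]
H16 = [[sum(H[i][d] for H in head_outputs) / m_H for d in range(d_attn)] for i in range(N)]
H14 = matmul(A_tilde, VW)
max_diff = max(abs(a - b) for r1, r2 in zip(H14, H16) for a, b in zip(r1, r2))
print(max_diff < 1e-12)
```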