This notebook describes the temporal fusion transformers [@lim2021temporal] architecture, and ports it over to keras 3 while making some punctual improvements, including making the notation closer to the paper math.

The original repository is: https://github.com/google-research/google-research/tree/master/tft.

In [None]:
from __future__ import annotations

import numpy as np
import keras_core as keras
from keras_core import layers
from fastcore import docments
from nbdev.showdoc import show_doc


# Introduction

The main characteristics of TFT that make it interesting for nowcasting or forecasting purposes are:

- **multi-horizon forecasting**: the ability to output, at each point in time $t$, a sequence of forecasts for $t+h, h > 1$
- **quantile prediction**: each forecast is accompanied by a quantile band that communicates the amount of uncertainty around a prediction
- **flexible use of different types of inputs**: static inputs (akin to fixed effects), historical input and known future input (eg, important holidays, years that are known to have major sports events such as Olympic games, etc)
- **interpretability**: the model learns to select variables from the space of all input variables to retain only those that are globally meaningful, to assign attention to different parts of the time series, and to identify events of significance


# Notation

* unique entities: $i \in (1, \dots\, I)$
* time periods $t \in [0, T_i]$
  * $k \geq 1$ lags
  * $h \geq 1$ forecasting period
* set of entity-level static covariates: $s_i \in \mathbf{R}^{m_s}$
* set of temporal inputs: $\chi_{i, t} \in \mathbf{R}^{m_\chi}$
  * $\chi_{i,t} = [z_{i,t}, x_{i,t}]$
    * $z_{i,t} \in \mathbf{R}^{m_z}$ are observed inputs
    * $x_{i,t} \in \mathbf{R}^{m_z}$ are a priori known inputs (eg, years that have major sports events)
    * $m_\chi = m_z + m_x$
* target scalars: $y_{i,t}$
  * $\hat{y}_{i,t,q} = f_q(y_{i,t-k:t}, z_{i,t-k:t}, x_{i,t-k:t+h}, s_i)$
* hidden unit size (common across all the TFT architecture): $d_{\text{model}}$
* transformed input of $j$-th variable at time $t$: $\xi_t^{(j)} \in \mathbf{R}^{d_{\text{model}}}$
  * $\Xi_t = [\xi_t^{(1)}, \dots, \xi_t^{(m_\chi)}]$

# Architecture

## Concepts

* **time distributed**: 
  * applies same layer to each of the time steps in the data
    * in other words, a layer with the exact same weights
  * indices:
    * index 0: batch
    * index 1: time
    * indices 2...: data
  * More info: https://www.tensorflow.org/api_docs/python/tf/keras/layers/TimeDistributed

## Components

### Linear layer

> dedicated implementation to better control use of time distribution on vanilla linear layer

$$
\mathbb{Y} = \phi(\mathbf{W} x + \mathbf{b}),
$$ {#eq-dense}

where $x$ is the input to `linear_layer()(x)`, $\mathbb{Y}$ is the output of `linear_layer()(x)`, $\phi$ is an activation function (or no activation function is `activation` is `None`), $\mathbf{W} \in \mathbf{R}^{(d_{\text{size}} \times d_{\text{inputs}})}$ is a matrix of weights and $\mathbf{b} \in \mathbf{R}^{d_{size}}$ is a vector of biases. Importantly, $\mathbf{W}$ and $\mathbf{b}$ are indexed with $_{\omega}$ to denote weight-sharing when the layer is time-distributed.

In [63]:
def linear_layer(size:int, # Output size
                 activation:str|callable|None=None, # Activation function
                 use_time_distributed:bool=False, # Apply the layer across all time steps?
                 use_bias:bool=True # Include bias in the layer?
)->layers.Dense: # Dense layer
    "Linear layer."

    linear = layers.Dense(size, activation=activation, use_bias=use_bias)
    if use_time_distributed:
        linear = layers.TimeDistributed(linear)
    return linear

In [64]:
#| output: asis
#| echo: false
show_doc(linear_layer, title_level=4)

---

#### linear_layer



Linear layer.

#### Example usage

In [65]:
#| code-fold: show

batch_size = 2
n_timesteps = 5
n_features = 100
layer_size = 8

# input dimensions: batches / timesteps / features
x = np.ones(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# dense layer
linear_td_true = linear_layer(size=layer_size, use_time_distributed=True)
linear_td_false = linear_layer(size=layer_size, use_time_distributed=False)

# output dimensions: batches / timesteps / layer size
assert linear_td_true(x).shape == [batch_size, n_timesteps, layer_size]
assert linear_td_false(x).shape == [batch_size, n_timesteps, layer_size]

Note that the time-distributed linear layer results in the same weights being applied to each time step:

In [66]:
#| code-fold: show

x = np.ones((1, n_timesteps, n_features))
timesteps_equal = []
for i in range(n_timesteps-1):
    timesteps_equal.append((np.array_equal(linear_td_true(x)[0,0,:], linear_td_true(x)[0,i+1,:])))

assert np.all(timesteps_equal)

linear_td_true(x)

<tf.Tensor: shape=(1, 5, 8), dtype=float32, numpy=
array([[[-0.27810854,  1.139082  , -1.7507869 ,  2.0346196 ,
          1.0295256 ,  1.7452098 ,  0.4395818 ,  0.54309237],
        [-0.27810854,  1.139082  , -1.7507869 ,  2.0346196 ,
          1.0295256 ,  1.7452098 ,  0.4395818 ,  0.54309237],
        [-0.27810854,  1.139082  , -1.7507869 ,  2.0346196 ,
          1.0295256 ,  1.7452098 ,  0.4395818 ,  0.54309237],
        [-0.27810854,  1.139082  , -1.7507869 ,  2.0346196 ,
          1.0295256 ,  1.7452098 ,  0.4395818 ,  0.54309237],
        [-0.27810854,  1.139082  , -1.7507869 ,  2.0346196 ,
          1.0295256 ,  1.7452098 ,  0.4395818 ,  0.54309237]]],
      dtype=float32)>

### Skip connection

> Adds inputs to layer and then implements layer normalisation

$$
\text{LayerNorm}(a + b),
$$ {#eq-skip}

for $a$ and $b$ tensors of the same dimension and $\text{LayerNorm}(\cdot)$ being the layer normalisation (@ba2016layer), ie subtracting $\mu^l$ and dividing by $\sigma^l$ defined as:

$$
\mu^l = \frac{1}{H} \sum_{i=1}^H n_i^l \quad \sigma^l = \sqrt{\frac{1}{H} \sum_{i=1}^H (n_i^l - \mu^l)^2},
$$ {#eq-layernorm}

with $H$ denoting the number of $n$ hidden units in a layer $l$.

* Adding a layer's inputs to its outputs is also called "skip connection"
* The layer is then normalised [@ba2016layer] to avoid having the numbers grow too big, which is detrimental for gradient transmission
  * Layer normalisation uses the same computation both during training and inference times, and is particularly suitable for time series

In [67]:
def add_and_norm(
    x_list # List of input tensors (of the same dimension) for skip connection
    ):
    "Adds tensors with same dimensions and then normalises layer"
    tmp = layers.Add()(x_list)
    return layers.LayerNormalization()(tmp)

In [68]:
#| output: asis
#| echo: false
show_doc(add_and_norm, title_level=4)

---

#### add_and_norm

>      add_and_norm (x_list)

Adds tensors with same dimensions and then normalises layer

|    | **Details** |
| -- | ----------- |
| x_list | List of input tensors (of the same dimension) for skip connection |

#### Example usage

In [None]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100

# input dimensions: batches / timesteps / features
x1 = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 
x2 = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# output dimensions: batches / timesteps / features
x1x2 = add_and_norm(x_list=[x1, x2])
assert x1.shape == x1x2.shape

Mean values of the normalised layer should be around 0:

In [69]:
print("mean of x1 units at each batch X time step: ", x1.mean(axis=-1), "\n")
print("mean of x2 units at each batch X time step: ", x2.mean(axis=-1), "\n")
print("mean of sum of x1 and x2:"), x1x2.numpy().mean(axis=-1)

mean of x1 units at each batch X time step:  [[ 0.0891     -0.00987941 -0.10854553 -0.00655235  0.07156264]
 [-0.05450076 -0.06417484  0.07080917 -0.13651035  0.08508652]
 [ 0.0993336  -0.09902981 -0.02505386  0.04154724 -0.11811789]] 

mean of x2 units at each batch X time step:  [[ 0.00124595 -0.08214826  0.01206193 -0.01212819  0.11724552]
 [ 0.03407577 -0.17638668  0.11328357 -0.20265001  0.00658767]
 [-0.00163865  0.04814776 -0.04867956 -0.08390731 -0.02609097]] 

mean of sum of x1 and x2:


(None,
 array([[-1.22189521e-08,  1.10268594e-08,  2.08616258e-09,
          1.81794171e-08,  4.76837148e-09],
        [ 9.53674295e-09,  4.61936001e-09, -3.79979603e-09,
          1.07288365e-08,  1.78813941e-08],
        [-2.02655794e-08, -2.20537189e-08,  7.00354574e-09,
         -8.34465030e-09, -1.49011609e-10]], dtype=float32))

Standard deviation (for the normalised output it should be around 1):

In [70]:
print("std of x1 units at each batch X time step: ", x1.std(axis=-1), "\n")
print("std of x2 units at each batch X time step: ", x2.std(axis=-1), "\n")
print("std of normalised sum of x1 and x2:"), x1x2.numpy().std(axis=-1)

std of x1 units at each batch X time step:  [[0.94563709 0.96309248 0.97245555 1.01093958 0.8727528 ]
 [1.03880269 1.08403902 0.94582936 0.96116156 1.036372  ]
 [0.95825912 1.02954164 0.98003212 1.01333929 0.95075193]] 

std of x2 units at each batch X time step:  [[0.92853231 0.89098436 0.92638462 1.168471   1.06304328]
 [0.99139127 0.89403001 1.06474063 1.01128404 1.0370899 ]
 [1.16739561 0.96358718 0.97307712 1.02691728 0.88659434]] 

std of normalised sum of x1 and x2:


(None,
 array([[0.9997464 , 0.99971324, 0.9997391 , 0.99975777, 0.99971575],
        [0.99971294, 0.99970776, 0.9997577 , 0.99974173, 0.9997703 ],
        [0.99977344, 0.99974054, 0.9997541 , 0.99974614, 0.9996765 ]],
       dtype=float32))

### Gated linear unit (GLU)

> Linear layer that learns how much to gate vs let pass through

Using input $\gamma \in \mathbb{R}^{d_{\text{model}}}$ and the subscript $\omega$ to index weights, 

$$
\text{GLU}_{\omega}(\gamma) = \sigma(W_{4, \omega} \gamma + b_{4, \omega}) \odot (W_{5, \omega} \gamma + b_{5, \omega}).
$$ {#eq-GLU}

* Introduced by @dauphin2017language
* The intuition is to train two versions of @eq-dense in the same data, but one of them having a sigmoid activation (which outputs values between zero and one), then multiplying each hidden unit
* The result could be zero or very close to zero through the Hadamard multipliciation, which in practice means that the network would not be affected by that data (ie, it would be gated out)
  * The first term (with the sigmoid) is the gate that determines what percentage of the linear layer passes through
* *"GLUs reduce the vanishing gradient problem for deep architectures by providing a linear path for gradients while retaining non-linear capabilities"*
* *"provide flexibility to suppress any parts of the architecture that are not required for a given dataset"*
* The GLU is part of @sec-GRN

In [109]:
def apply_gating_layer(
    x, # Input tensors (batch first)
    hidden_layer_size:int, # Dimension of the GLU
    dropout_rate:float|None=None, # Dropout rate
    use_time_distributed:bool=True, # Apply the GLU across all time steps?
    activation:str|callable=None # Activation function
)->(keras.KerasTensor, keras.KerasTensor): # $\text{GLU}(\gamma)$, $\sigma(W \gamma + b)$, both with dimension (batch_size, num_time_steps, hidden_layer_size)
    "Gated Linear Unit (GLU) layer"
    
    if dropout_rate is not None:
        x = layers.Dropout(dropout_rate)(x)

    activation_layer = linear_layer(
        size=hidden_layer_size,
        activation=activation,
        use_time_distributed=use_time_distributed
    )(x)
    
    gate_layer = linear_layer(
        size=hidden_layer_size,
        activation='sigmoid',
        use_time_distributed=use_time_distributed
    )(x)

    return layers.Multiply()([activation_layer, gate_layer]), gate_layer

In [110]:
#| output: asis
#| echo: false

show_doc(apply_gating_layer, title_level=4)

---

#### apply_gating_layer



#### Example usage

In [73]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100
hidden_layer_size = 16

# input dimensions: batches / timesteps / features
x = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# output dimensions: batches / timesteps / hidden_layer_size
assert apply_gating_layer(x=x, hidden_layer_size=hidden_layer_size)[0].shape == [batch_size, n_timesteps, hidden_layer_size]

In [74]:
#| code-fold: show

[i.shape for i in apply_gating_layer(x=x, hidden_layer_size=hidden_layer_size)]

[TensorShape([3, 5, 16]), TensorShape([3, 5, 16])]

### Gated residual network (GRN) {#sec-GRN}

$$
\text{GRN}_{\omega}(a, c) = \text{LayerNorm}(a + \text{GLU}_{\omega}(W_{1, \omega} \text{ELU}(W_{2, \omega} a + b_{2, \omega} + W_{3, \omega} c) + b_{1, w}))
$$ {#eq-GRN}

* Breaking down $\text{GRN}_{\omega}(a, c)$:
    * *1st step*: $\eta_{2} = \text{ELU}(W_{2, \omega} a + b_{2, \omega} + W_{3, \omega} c)$ (where the additional context $c$ might be zero) as in @eq-dense but adapted for the added context if any and with $\text{ELU}(\cdot)$ as the activation function,
    * *2nd step*: $\eta_{1} = W_{1, \omega} \eta_{2} + b_{1, w}$ as in @eq-dense,
    * *3rd step*: $\text{LayerNorm}(a + \text{GLU}_{\omega}(\eta_{1}))$ as in @eq-skip and @eq-GLU
* $\text{ELU}(\cdot)$ is the Exponential Linear Unit activation function (@clevert2015fast)
    * Unlike ReLUs, ELUs allow for negative values, which pushes unit activations closer to zero at a lower computation complexity, and producing more accurate results
* The GRN is a key building block of the TFT
    * Helps keep information only from relevant input variables
    * Also keeps the model as simple as possible by only applying non-linearities when relevant

In [75]:
def gated_residual_network(
    x, # Network inputs
    hidden_layer_size:int, # Dimension of the GRN
    output_size:int|None=None, # Size of output layer (if None, same as `hidden_layer_size`)
    dropout_rate:float|None=None, # Dropout rate
    use_time_distributed:bool=True, # Apply the GRN across all time steps?
    additional_context=None, # Additional context vector to use if relevant
    return_gate:bool=False #Whether to return GRN gate for diagnostic purposes
):
    "Applies the gated residual network (GRN) as defined in the paper"
    
    # Setup skip connection
    if output_size is None:
        output_size = hidden_layer_size
        skip = x
    else:
        linear = keras.layers.Dense(output_size)
        if use_time_distributed:
            linear = keras.layers.TimeDistributed(linear)
        skip = linear(x)

    # 1st step: eta2
    hidden = linear_layer(
        size=hidden_layer_size, # W2
        activation=None,
        use_time_distributed=use_time_distributed,
        use_bias=True # b2
    )(x)

    # "For instances without a context vector, the GRN simply treates the context input as zero - ie, $c = 0$ in Eq. 4"
    if additional_context is not None: # if c is != 0...
        hidden += linear_layer(
            size=hidden_layer_size, # W3
            activation=None,
            use_time_distributed=use_time_distributed,
            use_bias=False # no bias for additional context, since there already is bias from the "main" calculation of eta2
        )(additional_context)

    hidden = keras.layers.Activation('elu')(hidden)

    # 2nd step: eta1
    hidden = linear_layer(
        size=hidden_layer_size, # W1
        activation=None,
        use_time_distributed=use_time_distributed,
        use_bias=True # b1
    )(hidden)

    # 3rd step: concluding the GRN calculation
    gating_layer, gate = apply_gating_layer(
        x=hidden,
        hidden_layer_size=output_size,
        dropout_rate=dropout_rate,
        use_time_distributed=use_time_distributed,
        activation=None
    )

    GRN = add_and_norm([skip, gating_layer])

    if return_gate:
        return GRN, gate
    else:
        return GRN

In [76]:
#| output: asis
#| echo: false

show_doc(gated_residual_network, title_level=4)

---

#### gated_residual_network

>      gated_residual_network (x, hidden_layer_size:int,
>                              output_size:int|None=None,
>                              dropout_rate:float|None=None,
>                              use_time_distributed:bool=True,
>                              additional_context=None, return_gate:bool=False)

Applies the gated residual network (GRN) as defined in the paper

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| x |  |  | Network inputs |
| hidden_layer_size | int |  | Dimension of the GRN |
| output_size | int \| None | None | Size of output layer (if None, same as `hidden_layer_size`) |
| dropout_rate | float \| None | None | Dropout rate |
| use_time_distributed | bool | True | Apply the GRN across all time steps? |
| additional_context | NoneType | None | Additional context vector to use if relevant |
| return_gate | bool | False | Whether to return GRN gate for diagnostic purposes |

#### Example usage

In [77]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100
hidden_layer_size = 16
output_size = 17

# input dimensions: batches / timesteps / features
x = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

grn, gate = gated_residual_network(
    x=x,
    hidden_layer_size=hidden_layer_size,
    output_size=output_size,
    dropout_rate=0,
    use_time_distributed=True,
    additional_context=None,
    return_gate=True
)

# output dimensions: batches / timesteps / hidden_layer_size
assert grn.shape == [batch_size, n_timesteps, output_size]

In [79]:
#| code-fold: show

gate[0] # first batch

<tf.Tensor: shape=(5, 17), dtype=float32, numpy=
array([[0.46408975, 0.51650125, 0.3576629 , 0.32816413, 0.28808403,
        0.25894374, 0.80485815, 0.6035425 , 0.18347855, 0.20668997,
        0.40690467, 0.77098   , 0.5859963 , 0.77419156, 0.5004397 ,
        0.6925783 , 0.8398583 ],
       [0.5860845 , 0.4493843 , 0.35765216, 0.4549006 , 0.4134464 ,
        0.4788437 , 0.7112705 , 0.42405283, 0.8090347 , 0.63934696,
        0.12561484, 0.19756034, 0.29987946, 0.82742363, 0.40238598,
        0.19599892, 0.2811166 ],
       [0.49277538, 0.47769853, 0.51799566, 0.5151028 , 0.49813136,
        0.37231475, 0.45661026, 0.51294667, 0.3605565 , 0.3776353 ,
        0.6286485 , 0.7145511 , 0.4172555 , 0.62817055, 0.37960175,
        0.67153347, 0.7933012 ],
       [0.6383022 , 0.67520565, 0.36656207, 0.6027976 , 0.52373886,
        0.41002965, 0.12543976, 0.6682144 , 0.41265413, 0.43105012,
        0.5055955 , 0.6043989 , 0.4804591 , 0.7942763 , 0.5494791 ,
        0.27742058, 0.6879719 ],
   

### Variable selection networks

$$
\sum_{j=1}^{m_{\chi}} \upsilon_{\chi_t}^{(j)} \tilde{\xi}_t^{(j)},
$$ {#eq-VSN}

with $j$ indexing the input variable, $\upsilon_{\chi_t}^{(j)}$ standing for variable $j$'s selection weight, $m$ being the number of features and $\tilde{\xi}_t^{(j)}$ defined as:

$$
\tilde{\xi}_t^{(j)} = \text{GRN}(\xi_t^{(j)}).
$$ {#eq-embed}

* In the paper, they are represented in the bottom right of Fig. 2
* Note there are separate variable selection networks for different input groups:
  * `static_variable_selection`
    * does not have static context as input, as it already contains static information
  * `temporal_variable_selection`
    * used for both historical and known future inputs
    * includes static contexts

In [80]:
def static_variable_selection( 
    embedding, # Embedded static inputs, $\xi_t$
    hidden_layer_size, # Dimension of the GRN
    dropout_rate # Dropout rate
)->(keras.KerasTensor, keras.KerasTensor): # $\tilde{\xi}_t$ with dimension (batch_size, ), $\upsilon_{\chi t}$ with dimension (batch_size, num_static_vars, 1)
    "Filter contribution of different static variables"

    # Add temporal features
    _, num_static, _ = embedding.get_shape().as_list() # (embeddings are $\xi_t^(1, \dots, \m_{\chi})$)
    flattened = layers.Flatten()(embedding) # $\Xi_t$, with dimensions (batch_size, num_entities)

    # Nonlinear transformation with the GRN
    mlp_outputs = gated_residual_network(
        x=flattened, # Network inputs
        hidden_layer_size=hidden_layer_size, # Dimension of the GRN
        output_size=num_static, # Size of output layer (if None, same as `hidden_layer_size`)
        dropout_rate=dropout_rate, # Dropout rate
        use_time_distributed=False, # Apply the GRN across all time steps?
        additional_context=None, # Additional context vector to use if relevant
    ) 
    sparse_weights = layers.Activation('softmax')(mlp_outputs)
    sparse_weights = keras.ops.expand_dims(sparse_weights, axis=-1) # $\upsilon_{\chi t}$
    # it's the sparse weights above that determine how much each variable will be influencing the model

    transformed_embeddings = []
    for i in range(num_static):
        transformed_embeddings.append(gated_residual_network(
            x=embedding[:, i:i+1, :], # Selects instances of the same static variable across all batches and individuals
            hidden_layer_size=hidden_layer_size, # Dimension of the GRN
            output_size=hidden_layer_size, # Size of output layer (if None, same as `hidden_layer_size`)
            dropout_rate=dropout_rate, # Dropout rate
            use_time_distributed=False, # Does not make sense to apply the GRN across all time steps because static variables do not have a time dimension
        ))
    transformed_embedding = keras.ops.concatenate(transformed_embeddings, axis=1) # $\tilde{\xi_t^(1, \dots, \m_{\chi})}$

    combined = layers.Multiply()(
        [sparse_weights, transformed_embedding]
    )
    
    static_vec = keras.ops.sum(combined, axis=1)

    return static_vec, sparse_weights

In [81]:
#| output: asis
#| echo: false

show_doc(static_variable_selection, title_level=4)

---

#### static_variable_selection

>      static_variable_selection (embedding, hidden_layer_size, dropout_rate)

Filter contribution of different static variables

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| embedding |  | Embedded static inputs, $\xi_t$ |
| hidden_layer_size |  | Dimension of the GRN |
| dropout_rate |  | Dropout rate |
| **Returns** | **(keras.KerasTensor, keras.KerasTensor)** | **$\tilde{\xi}_t$ with dimension (batch_size, ), $\upsilon_{\chi t}$ with dimension (batch_size, num_static_vars, 1)** |

In [82]:
def temporal_variable_selection(#__lstm_combine_and_mask(
        embedding, # Embedded temporal inputs, $\xi_t^(v), v \in (1, \dots, m_\chi)$
        context, # Static context variable selection, $c_s$
        hidden_layer_size, # Dimension of the GRN, $d_{\text{model}}$
        dropout_rate # Dropout rate
    )->(keras.KerasTensor, keras.KerasTensor, ): # $\tilde{\xi}_t$ with dimension (batch_size, ), $\upsilon_{\chi t}$ with dimension (batch_size, num_static_vars, 1):
        "Filter contribution of different temporal variables"

        # Add temporal features
        _, time_steps, embedding_dim, num_inputs = embedding.get_shape().as_list()

        flattened = keras.ops.reshape(
            embedding,
            [1, time_steps, embedding_dim * num_inputs]
        )
        expanded_static_context_c_s = keras.ops.expand_dims(
            context,
            axis=1
        )

        # Variable selection weights \upsilon
        mlp_outputs, upsilon = gated_residual_network(
            x=flattened,
            hidden_layer=hidden_layer_size,
            output_size=num_inputs,
            dropout_rate=dropout_rate,
            use_time_distributed=True,
            additional_context=expanded_static_context_c_s,
            return_gate=True
        )
        sparse_weights = keras.layers.Activation('softmax')(mlp_outputs)
        sparse_weights = keras.ops.expand_dims(sparse_weights, axis=2)

        # Nonlinear processing and application of weights
        transformed_embeddings = []
        for i in range(num_inputs):
            transformed_embeddings.append(
                gated_residual_network(
                    embedding[Ellipsis, i],
                    hidde_layer=hidden_layer_size,
                    dropout_rate=dropout_rate,
                    use_time_distributed=True
                )
            )
        transformed_embeddings = keras.ops.stack(transformed_embeddings, axis=-1)

        combined = layers.Multiply()([
            sparse_weights, transformed_embeddings
        ])
        temporal_ctx = keras.ops.sum(combined, axis=-1)

        return temporal_ctx, sparse_weights, upsilon

In [None]:
#| output: asis
#| echo: false

show_doc(temporal_variable_selection, title_level=4)

#### Example usage

In [86]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_static_vars = 8
n_temporal_vars = 11
n_indiv = 7
hidden_layer_size = 16

# input dimensions: batches / timesteps / features
static_vars = np.random.randn(batch_size*n_static_vars*n_indiv).reshape([batch_size, n_static_vars, n_indiv]) 

# transformed variables (\xi_t^{(1)}, \xi_t^{(2)})
xi1 = linear_layer(
        size=hidden_layer_size,
        activation=None,
        use_time_distributed=True
        )(static_vars)

In [88]:
#| code-fold: show

static_selected_vars, static_selection_weights = static_variable_selection(xi1, hidden_layer_size=hidden_layer_size, dropout_rate=0.)

static_selected_vars.shape, static_selection_weights.shape

(TensorShape([3, 16]), TensorShape([3, 8, 1]))

### Attention components

* Attention mechanisms use relationships between keys $K \in \mathbf{R}^{N \times d_{attention}}$ and queries $Q \in \mathbf{R}^{N \times d_{attention}}$ to scale a vector of values $V \in \mathbf{R}^{N \times d_V}$: $\text{Attention}(Q, K, V) = A(Q, K) V$
    * $N$ is the number of timesteps going into the attention layer (the number of lags plus the number of periods to be forecasted)
    * $A(\cdot)$ is a normalisation function
        * After @vaswani2017attention, the canonical choice for $A(\cdot)$ is the scaled dot-product: $A(Q, K) = \text{Softmax}(\frac{Q K^{T}}{\sqrt{d_{attention}}} )$
    
* The TFT uses a modified attention head to enhance the explainability of the model
* Specifically, the transformer block (multi-head attention) is modified to:
    * share values in each head, and
    * employ additive aggregation of all heads
* More formally, compare the interpretable multi-head attention (used in this paper) with the canonical multi-head attention:
    * $\text{InterpretableMultiHead}(Q, K, V) = \tilde{H} W_{H}$, with:
        * $\begin{aligned}\tilde{H} &= \tilde{A}(Q, K) V W_V \\
        &= \{\frac{1}{m_H} \sum^{m_{H}}_{h=1} A(Q W^{(h)}_Q, K W^{(h)}_K) \} V W_V \\
        &= \frac{1}{m_H} \sum^{m_{H}}_{h=1} \text{Attention}(Q W^{(h)}_Q, K W^{(h)}_K, V W_V)
        \end{aligned}$
    * $\text{MultiHead}(Q, K, V) = [H_1, \dots, H_{m_H}] W_H$, with:
        * $H_h = \text{Attention}(Q W^{(h)}_Q, K W^{(h)}_K, V W_V^{(h)}) $

### Decoder mask for self-attention layer

In [89]:
def get_decoder_mask(
    self_attention_inputs # Inputs to the self-attention layer
):
    "Determines shape of decoder mask"
    len_s = keras.ops.shape(self_attention_inputs)[1] # length of inputs
    bs = keras.ops.shape(self_attention_inputs)[0] # batch shape
    mask = keras.ops.cumsum(keras.ops.eye(len_s), 1) #keras.backend.cumsum(np.eye(len_s, bs))

    ### warning: I had to manually implement some batch-wise shape here 
    ### because the new keras `eye` function does not have a batch_size arg.
    ### inspired by: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/linalg_ops_impl.py#L30
    ### <hack>
    mask = keras.ops.expand_dims(mask, axis=0)    
    mask = keras.ops.tile(mask, (bs, 1, 1))
    ### </hack>

    return mask


In [90]:
#| output: asis
#| echo: false

show_doc(get_decoder_mask, title_level=4)

---

#### get_decoder_mask

>      get_decoder_mask (self_attention_inputs)

Determines shape of decoder mask

|    | **Details** |
| -- | ----------- |
| self_attention_inputs | Inputs to the self-attention layer |

#### Example usage

In [91]:
#| code-fold: show

dec = get_decoder_mask(grn)

assert dec.shape == (batch_size, n_timesteps, n_timesteps)

Note that it produces an upper-triangular matrix of ones:

In [92]:
#| code-fold: show

dec[0]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.]], dtype=float32)>

### Scaled dot product attention layer

* This is the same as Eq. (1) of @vaswani2017attention 
    * except that in this case the dimension of the value vector is the same $d_{\text{model}}$ as for the query and key vectors
* As discussed in the paper, additive attention outperforms dot product attention for larger $d_{\text{model}}$ values, so the attention is scaled back to smaller values

In [93]:
class ScaledDotProductAttention():
    def __init__(
        self,
        training:bool=True, # Whether the layer is being trained or used in inference
        attention_dropout:float=0.0 # Will be ignored if `training=False`
    ):
        self.training = training
        self.dropout = keras.layers.Dropout(rate=attention_dropout)
        self.activation = keras.layers.Activation('softmax')

    def __call__(
        self,
        q, # Queries, tensor of shape (?, time, D_model)
        k, # Keys, tensor of shape (?, time, D_model)
        v, # Values, tensor of shape (?, time, D_model)
        mask # Masking if required (sets Softmax to very large value), tensor of shape (?, time, time)
    ):
        # returns Tuple (layer outputs, attention weights)
        scale = keras.ops.sqrt(keras.ops.cast(keras.ops.shape(k)[-1], dtype='float32'))
        attention = keras.ops.einsum("bij,bjk->bik", q, keras.ops.transpose(k, axes=(0, 2, 1))) / scale
        if mask is not None:
            mmask = keras.layers.Lambda(lambda x: (-1e9) * (1. - keras.ops.cast(x, 'float32')))(mask)
            attention = keras.layers.Add()([attention, mmask])
        attention = self.activation(attention)
        if self.training:
            attention = self.dropout(attention)
        output = keras.ops.einsum("btt,btd->bt", attention, v)
        return output, attention

In [94]:
#| output: asis
#| echo: false

show_doc(ScaledDotProductAttention, title_level=4)

---

#### ScaledDotProductAttention

>      ScaledDotProductAttention (training:bool=True,
>                                 attention_dropout:float=0.0)

Initialize self.  See help(type(self)) for accurate signature.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| training | bool | True | Whether the layer is being trained or used in inference |
| attention_dropout | float | 0.0 | Will be ignored if `training=False` |

#### Example usage
Below is an example of how the `ScaledDotProductAttention` layer works:

In [95]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 13

# input dimensions: batches / timesteps / features
x_btf = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# using the same vector for q, k and v just to simplify
q=keras.ops.cast(x_btf, 'float32')
k=keras.ops.cast(x_btf, 'float32')
v=keras.ops.cast(x_btf, 'float32')

Testing without masking:

In [96]:
#| code-fold: show

output, attention = ScaledDotProductAttention()(q=q, k=k, v=v, mask=None)
output, attention # both have shape (batch_size, n_timesteps)

(<tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[ 1.0461457 ,  1.0315393 , -0.34412605, -0.45953444,  2.8253996 ],
        [-1.744076  ,  4.6008143 , -1.4990771 , -5.1597986 ,  3.4189878 ],
        [-1.9091827 ,  0.29079467, -0.11268348,  0.05857382,  1.3043103 ]],
       dtype=float32)>,
 <tf.Tensor: shape=(3, 5, 5), dtype=float32, numpy=
 array([[[3.86870772e-01, 1.28998399e-01, 2.67757922e-01, 7.61292577e-02,
          1.40243649e-01],
         [3.64807099e-02, 7.74610400e-01, 3.52141671e-02, 1.08375907e-01,
          4.53188010e-02],
         [8.30628444e-03, 3.86280683e-03, 9.78443503e-01, 2.39209924e-03,
          6.99525373e-03],
         [1.49935763e-02, 7.54757524e-02, 1.51868779e-02, 8.91941488e-01,
          2.40225298e-03],
         [1.44322505e-02, 1.64911319e-02, 2.32054479e-02, 1.25520967e-03,
          9.44615960e-01]],
 
        [[8.43910575e-01, 3.76040526e-02, 2.52377056e-02, 5.36176935e-02,
          3.96299511e-02],
         [1.51728827e-03, 9.95497584e-01

... and with masking:

In [97]:
#| code-fold: show

output, attention = ScaledDotProductAttention()(q=q, k=k, v=v, mask=get_decoder_mask(q))
output, attention # both have shape (batch_size, n_timesteps)

(<tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[ 1.0461457 ,  1.0705955 , -0.3483654 , -0.51382303,  2.9910564 ],
        [-1.744076  ,  4.607806  , -1.6109525 , -5.775865  ,  3.6350102 ],
        [-1.9091827 ,  0.29375416, -0.134687  ,  0.06258623,  1.6628706 ]],
       dtype=float32)>,
 <tf.Tensor: shape=(3, 5, 5), dtype=float32, numpy=
 array([[[3.86870772e-01, 1.28998399e-01, 2.67757922e-01, 7.61292577e-02,
          1.40243649e-01],
         [0.00000000e+00, 8.03938687e-01, 3.65474448e-02, 1.12479225e-01,
          4.70346585e-02],
         [0.00000000e+00, 0.00000000e+00, 9.90497112e-01, 2.42156792e-03,
          7.08142901e-03],
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.97313917e-01,
          2.68605119e-03],
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          1.00000000e+00]],
 
        [[8.43910575e-01, 3.76040526e-02, 2.52377056e-02, 5.36176935e-02,
          3.96299511e-02],
         [0.00000000e+00, 9.97010350e-01

### Softmax

A small detour to illustrate the softmax function. 

The $i^{\text{th}}$ element of $\text{Softmax}(x)$, with $x \in \mathbf{R}^K$ is:

$$
\text{Softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^K e^{x_j}}
$$

For example, see the values below for an input vector $x$ ($K=5$ in this example):

In [98]:
#| code-fold: show

x = np.array([-np.Inf, -1., 0., 1., 3.])
keras.layers.Activation('softmax')(x)
print("x = ", x)
print("exp(x) = ", np.exp(x))
print("denominator (sum of exp(x_j), j=1,...,K) = ", sum(np.exp(x)))
print("softmax(x) = ", np.exp(x) / sum(np.exp(x)))
print("sum of softmax(x)_j, j=1,...,K = ", sum(np.exp(x) / sum(np.exp(x))))

x =  [-inf  -1.   0.   1.   3.]
exp(x) =  [ 0.          0.36787944  1.          2.71828183 20.08553692]
denominator (sum of exp(x_j), j=1,...,K) =  24.171698192818155
softmax(x) =  [0.         0.01521943 0.0413707  0.11245721 0.83095266]
sum of softmax(x)_j, j=1,...,K =  1.0


As can be seen above, the softmax function really makes the largest numbers stand out from the rest.

Note also that $-\infty$ results in 0.

### Interpretable Multi-head attention

* When values are shared in each head and then are aggregated additively, each head still lcan learn different temporal patterns (from their own unique queries and keys), but with the same input values.
    * In other words, they can be interpreted as an ensemble over the attention weights
    * the paper doesn't mention this explicitly, but the ensemble is equally-weighted - maybe there is some performance to be gained by having some way to weight the different attention heads 🤔, such as having a linear layer combining them... will explore in the future

In [99]:
class InterpretableMultiHeadAttention():
    def __init__(
        self,
        n_head:int,
        d_model:int,
        training:bool=True, # Whether the layer is being trained or used in inference
        dropout:float=0.0 # Will be ignored if `training=False`
    ):
        self.n_head = n_head
        self.d_k = self.d_v = d_k = d_v = d_model // n_head # the original model divides by number of heads
        self.training = training
        self.dropout = dropout

        # using the same value layer facilitates interpretability
        vs_layer = keras.layers.Dense(d_v, use_bias=False, name="Shared value")

        # creates list of queries, keys and values across heads
        self.qs_layers = self._build_layers(d_k, n_head)
        self.ks_layers = self._build_layers(d_k, n_head)
        self.vs_layers = [vs_layer for _ in range(n_head)]

        self.attention = ScaledDotProductAttention()
        self.w_o = keras.layers.Dense(d_v, use_bias=False, name="W_v") # W_v in Eqs. (14)-(16), output weight matrix to project internal state to the original TFT

    def __call__(
        self,
        q, # Queries, tensor of shape (?, time, D_model)
        k, # Keys, tensor of shape (?, time, D_model)
        v, # Values, tensor of shape (?, time, D_model)
        mask=None # Masking if required (sets Softmax to very large value), tensor of shape (?, time, time)
    ):
        heads = []
        attns = []
        for i in range(self.n_head):
            qs = self.qs_layers[i](q)
            ks = self.ks_layers[i](q)
            vs = self.vs_layers[i](v)
           
            head, attn = self.attention(qs, ks, vs, mask)
            if self.training:
                head = keras.layers.Dropout(self.dropout)(head)
            heads.append(head)
            attns.append(attn)

        outputs = keras.ops.mean(heads, axis=0) if self.n_head > 1 else head # H_tilde
        outputs = self.w_o(outputs)
        if self.training:
            outputs = keras.layers.Dropout(self.dropout)(outputs)

        return outputs, attn

    def _build_layers(self, d:int, n_head:int):
            return [keras.layers.Dense(d) for _ in range(n_head)]

In [100]:
#| output: asis
#| echo: false

show_doc(InterpretableMultiHeadAttention, title_level=4)

---

#### InterpretableMultiHeadAttention

>      InterpretableMultiHeadAttention (n_head:int, d_model:int,
>                                       training:bool=True, dropout:float=0.0)

Initialize self.  See help(type(self)) for accurate signature.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| n_head | int |  |  |
| d_model | int |  |  |
| training | bool | True | Whether the layer is being trained or used in inference |
| dropout | float | 0.0 | Will be ignored if `training=False` |

#### Example usage

In [101]:
#| code-fold: show

imha = InterpretableMultiHeadAttention(n_head=8, d_model=16)

In [102]:
#| code-fold: show

grn.shape # B, T, F

TensorShape([3, 5, 17])

In [103]:
#| code-fold: show

mask = get_decoder_mask(grn)
mask.shape # B, T, T

TensorShape([3, 5, 5])

In [104]:
#| code-fold: show

imha(grn, grn, grn, mask)

(<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
 array([[ 0.53399706, -0.34997872],
        [-1.0885253 ,  1.4489189 ],
        [ 1.0975373 , -1.2146299 ]], dtype=float32)>,
 <tf.Tensor: shape=(3, 5, 5), dtype=float32, numpy=
 array([[[2.15666671e-03, 9.89674389e-01, 8.07477813e-03, 9.13843178e-05,
          2.73770206e-06],
         [0.00000000e+00, 6.24240518e-01, 2.27672309e-01, 9.62720364e-02,
          5.18151261e-02],
         [0.00000000e+00, 0.00000000e+00, 7.74442792e-01, 1.72618762e-01,
          5.29384129e-02],
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 5.69353878e-01,
          4.30646122e-01],
         [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
          1.00000000e+00]],
 
        [[3.52688372e-01, 2.05759197e-01, 1.31885797e-01, 1.62114590e-01,
          1.47552103e-01],
         [0.00000000e+00, 1.14922486e-01, 1.40755579e-01, 1.44064263e-01,
          6.00257635e-01],
         [0.00000000e+00, 0.00000000e+00, 3.83377731e-01, 2.123

## Putting it all together: TFT

In [None]:
class TemporalFusionTransformer():
    def __init__(
        self,
        # Data params
        time_steps:int,
        input_size:int,
        output_size:int,
        category_counts:int,
        n_workers:int, # Number of multiprocessing workers

        # TFT params
        input_obs_loc,
        static_input_loc,
        known_regular_input_idx,
        known_categorical_input_idx,
        column_definition,

        # Network params
        quantile:list=[0.1, 0.5, 0.9], # List of quantiles the model should forecast
        hidden_layer_size:int=30, # Size of hidden layer
        dropout_rate:float=0.0, # Dropout ratio (between 0.0, inclusive, and less than 1.0)
        num_encoder_steps:int=4,
        num_stacks:int=4,
        num_heads:int=4,
        
        # Training params
        max_gradient_norm:float=1.0, # 
        learning_rate:float=0.001,
        minibatch_size:int=64,
        num_epochs:int=100,
        early_stopping_patience:int=5,
        use_gpu:bool=True
    ):
        self.time_steps = time_steps
        self.input_size = input_size
        self.output_size = output_size # Number of periods to be forecasted
        self.category_counts = category_counts
        self.n_workers = n_workers # Number of multiprocessing workers
        
        self.input_obs_loc = input_obs_loc
        self.static_input_loc = static_input_loc
        self.known_regular_input_idx = known_regular_input_idx
        self.known_categorical_input_idx = known_categorical_input_idx
        self.column_definition = column_definition

        self.quantile = quantile # List of quantiles the model should forecast
        self.hidden_layer_size = hidden_layer_size # Size of hidden layer
        self.dropout_rate = dropout_rate # Dropout ratio (between 0.0, inclusive, and less than 1.0)
        self.num_encoder_steps = num_encoder_steps
        self.num_stacks = num_stacks
        self.num_heads = num_heads
        
        self.max_gradient_norm = max_gradient_norm
        self.learning_rate = learning_rate
        self.minibatch_size = minibatch_size
        self.num_epochs = num_epochs
        self.early_stopping_patience = early_stopping_patience
        self.use_gpu = use_gpu

        self.model = self.build_model()

    def __get_tft_embeddings(
        self,
        all_inputs # Input tensor of dimensions (batch, time steps, num variables)
    ):
        # Transform raw inputs to embeddings
        # For continuous variables: linear transformation
        # For categorical variables: embeddings
        
        num_categorical_variables = len(self.category_counts)
        num_regular_variables = self.input_size - num_categorical_variables

        embedding_sizes = [
            self.hidden_layer_size
            for i, size in enumerate(self.category_counts)
        ]

        embeddings = [
            keras.Sequential([
                layers.InputLayer([self.time_steps]),
                layers.Embedding(
                    self.category_counts[i],
                    embedding_sizes[i],
                    input_length=self.time_steps,
                    dtype='float32'
                )
            ])
            for i in range(num_categorical_variables)
        ]

        regular_inputs, categorical_inputs = \
            all_inputs[:, :, :num_regular_variables], \
            all_inputs[:, :, num_regular_variables:]

        embedded_inputs = [
            embeddings[i](categorical_inputs[Ellipsis, i])
            for i in range(num_categorical_variables)
        ]

        # static inputs
        if self._static_input_loc:
            st_inp_dense = [
                layers.Dense(self.hidden_layer_size)(
                    regular_inputs[:, 0, i:i + 1]
                )
                for i in range(num_regular_variables)
                if i in self._static_input_loc
            ]
            st_inp_embed = [
                embedded_inputs[i][:, 0, :]
                for i in range(num_categorical_variables)
                if  i + num_regular_variables in self._static_input_loc
            ]
            static_inputs = st_inp_dense + st_inp_embed
        else:
            static_inputs = None

        # Targets
        past_inputs = keras.ops.stack([
            linear_layer(
                size=self.hidden_layer_size,
                activation=None,
                use_time_distributed=True
            )(regular_inputs[Ellipsis, i:i + 1])
        ], axis=-1)

        # past inputs: observed but not known a priori
        wired_embeddings = [
            embeddings[i](categorical_inputs[:,:,i])
            for i in range(num_categorical_variables)
            if i not in self._known_categorical_input_idx \
                and i + num_regular_variables not in self._input_obs_loc    
        ]
        unknown_inputs = [
            linear_layer(
                size=self.hidden_layer_size,
                activation=None,
                use_time_distributed=True
            )(regular_inputs[Ellipsis, i:i + 1])
            for i in range(regular_inputs.shape[-1])
            if i not in self._known_categorical_input_idx \
                and i + num_regular_variables not in self._input_obs_loc    
        ]
        if wired_embeddings + unknown_inputs:
            unknown_inputs = keras.ops.stack(wired_embeddings + unknown_inputs, axis=-1)
        else:
            unkown_inputs = None

        # a priori known inputs
        known_regular_inputs = [
            linear_layer(
                size=self.hidden_layer_size,
                activation=None,
                use_time_distributed=True
            )(regular_inputs[Ellipsis, i:i + 1])
            for i in self._known_regular_input_idx
            if i not in self._static_input_loc
        ]
        known_categorical_inputs = [
            embedded_inputs[i]
            for i in self._known_categorical_input_idx
            if i + num_regular_variables not in self._static_input_loc
        ]
        known_combined_layer = keras.ops.stack(
            known_regular_inputs + known_categorical_inputs,
            axis=-1
        )

        return unknown_inputs, known_combined_layer, past_inputs, static_inputs

    def _build_base_graph(self):
        # Build the graph, defining the layers of the TFT
        

        ### <TFTInputs>
        all_inputs = layers.Input(
            shape=(self.time_steps, self.input_size) # Argument `shape` does not include batch size
        )
        unknown_inputs, known_combined_layer, past_inputs, static_inputs \
            = self.__get_tft_embeddings(all_inputs)
        ### </TFTInputs>

        # first we isolate the known future inputs and observed past inputs
        if unknown_inputs is not None:
            historical_inputs = keras.ops.concatenate([
                unknown_inputs[:, :self.num_encoder_steps, :],
                known_combined_layer[:, :self.num_encoder_steps, :],
                past_inputs[:, :self.num_encoder_steps, :]
            ], axis=1)
        else:
            historical_inputs = keras.ops.concatenate([
                known_combined_layer[:, :self.num_encoder_steps, :],
                past_inputs[:, :self.num_encoder_steps, :]
            ])
        
        # and then we isolate the known future inputs
        future_inputs = known_combined_layer[:, :self.num_encoder_steps, :]

        # static vars
        static_encoder, static_weights = static_variable_selection(static_inputs)

        # Static covariate encoders
        # These integrate static features into the network through encoding of context vectors
        # that condition the time-varying dynamics
        self.static_context_variable_selection = gated_residual_network( # c_s
            x=static_encoder, # Network inputs
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GRN
            output_size=self.hidden_layer_size, # Size of output layer (if None, same as `hidden_layer_size`)
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=False, # Apply the GRN across all time steps?
        )
        self.static_context_enrichment = gated_residual_network( # c_3
            x=static_encoder, # Network inputs
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GRN
            output_size=self.hidden_layer_size, # Size of output layer (if None, same as `hidden_layer_size`)
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=False, # Apply the GRN across all time steps?
        )
        self.static_context_state_h = gated_residual_network( # c_h
            x=static_encoder, # Network inputs
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GRN
            output_size=self.hidden_layer_size, # Size of output layer (if None, same as `hidden_layer_size`)
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=False, # Apply the GRN across all time steps?
        )
        self.static_context_state_c = gated_residual_network( # c_c
            x=static_encoder, # Network inputs
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GRN
            output_size=self.hidden_layer_size, # Size of output layer (if None, same as `hidden_layer_size`)
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=False, # Apply the GRN across all time steps?
        )

        historical_features, historical_flags, _ = temporal_variable_selection(
            embedding=historical_inputs,
            context=self.static_context_variable_selection,
            hidden_layer_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate)
        future_features, future_flags, _ = temporal_variable_selection(
            embedding=future_inputs,
            context=self.static_context_variable_selection,
            hidden_layer_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate)

        # Locality enhancement (Section 4.5.1 in paper) with seq-to-seq layer

        # LSTM layers: LSTM Encoder for encoding past inputs
        history_lstm, state_h, state_c = layers.LSTM(
            units=self.hidden_layer_size,
            return_sequences=True,
            return_state=True,
            stateful=False,
            activation='tanh',
            recurrent_activation='sigmoid',
            recurrent_dropout=0,
            unroll=False,
            use_bias=True)(
                inputs=historical_features,
                initial_state=[
                    self.static_context_state_h, # short-term state
                    self.static_context_state_c  # long-term state
                ]
            )

        # LSTM layers: LSTM Decoder for decoding future inputs
        future_lstm = layers.LSTM(
            units=self.hidden_layer_size,
            return_sequences=True,
            return_state=False,
            stateful=False,
            activation='tanh',
            recurrent_activation='sigmoid',
            recurrent_dropout=0,
            unroll=False,
            use_bias=True)(
                inputs=future_features,
                initial_stage=[
                    state_h, # short-term state
                    state_c  # long-term state
                ]
            )

        lstm_layer = keras.ops.concatenate([history_lstm, future_lstm], axis=1)

        # Apply gated skip connection (Gate followed by Add & Norm)

        input_embeddings = keras.ops.concatenate(
            [historical_features, future_features],
            axis=1
        )
        lstm_layer, _ = apply_gating_layer(
            x=lstm_layer, # Input tensors (batch first)
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GLU
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=True, # Apply the GLU across all time steps?
            activation=None # Activation function
        )
        temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings])

        # Temporal Fusion Decoder (TFT, Purple box in Fig. 2)
        # contains three steps
        # TFT 1st step: Static enrichment
        #   - enhances the temporal features with static metadata (Eq. 18)
        
        expanded_static_context_c_e = keras.ops.expand_dims(
            self.static_context_enrichment,
            axis=1
        )
        enriched, _ = gated_residual_network( # $\theta(t, n) = \text{GRN}_{\theta}(\tilde{\theta}(t, n), c_e)
            x=temporal_feature_layer,
            hidden_layer_size=self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            use_time_distributed=True,
            additional_context=expanded_static_context_c_e,
            return_gate=True
        )

        # TFT 2nd step: Temporal self-attention

        self_attention_layer = InterpretableMultiHeadAttention(
            n_head=self.num_heads,
            d_model=self.hidden_layer_size,
            dropout=self.dropout_rate # Will be ignored if `training=False`
        )
        mask = get_decoder_mask(enriched)
        post_attn, self_attention = self_attention_layer( # $B(t) = \text{IMHA}(\Theta(t), \Theta(t), \Theta(t))$
            q=enriched,
            k=enriched,
            v=enriched,
            mask=mask
        )
        post_attn, _ = apply_gating_layer( # $\text{GLU}_{\delta}(\beta(t, n))$
            x=post_attn, # Input tensors (batch first)
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GLU
            dropout_rate=self.dropout_rate, # Dropout rate
            use_time_distributed=True, # Apply the GLU across all time steps?
            activation=None # Activation function
        )
        post_attn = add_and_norm([post_attn, enriched]) # \delta(t, n) = \text{LayerNorm}(\theta(t, n) + $\text{GLU}_{\delta}(\beta(t, n)))$

        # TFT 3rd step: Position-wise feed-forward
        decoder = gated_residual_network(
            x=post_attn,
            hidden_layer_size=self.hidden_layer_size,
            output_size=self.hidden_layer_size,
            dropout_rate=self.dropout_rate,
            use_time_distributed=True,
            additional_context=None,
            return_gate=False
        )

        # final skip connection
        decoder, _ = apply_gating_layer(
            x=decoder, # Input tensors (batch first)
            hidden_layer_size=self.hidden_layer_size, # Dimension of the GLU
            dropout_rate=self.dropout_rate, # Dropout rate
            activation=None # Activation function
        )
        transformer_layer = add_and_norm([decoder, temporal_feature_layer])

        # the function also returns the attention components
        # for explainability analyses
        attention_components = {
            "temporal_attention_weights": self_attention,
            "variable_selection_weights_static_inputs": static_weights[Ellipsis, 0],
            "variable_selection_weights_past_inputs": historical_flags[Ellipsis, 0, :],
            "variable_selection_weights_future_inputs": future_flags[Ellipsis, 0, :]
        }

        return transformer_layer, all_inputs, attention_components

    def build_model(self):
        # Build model and define training losses

        transformer_layer, all_inputs, self._attention_components = self._build_base_graph()
        outputs = keras.layers.TimeDistributed(
            keras.layers.Dense(self.output_size * len(self.quantiles))
        )(transformer_layer[Ellipsis, self.num_encoder_steps:, :])
        model = keras.Model(inputs=all_inputs, outputs=outputs)

# Using the TFT model

## Data

In this example, we will use a simple inflation panel dataset.

In [54]:
from gingado.utils import list_all_dataflows, load_SDMX_data

In [55]:
dflows = list_all_dataflows()
dflows

2023-10-31 08:54:00,309 pandasdmx.reader.sdmxml - DEBUG: Truncate sub-microsecond time in <Prepared>
2023-10-31 08:54:19,798 pandasdmx.reader.sdmxml - DEBUG: Truncate sub-microsecond time in <Prepared>
2023-10-31 08:54:39,258 pandasdmx.reader.sdmxml - DEBUG: Truncate sub-microsecond time in <Prepared>
2023-10-31 08:54:40,137 pandasdmx.reader.sdmxml - DEBUG: Truncate sub-microsecond time in <Prepared>
2023-10-31 08:54:43,320 pandasdmx.reader.sdmxml - DEBUG: Truncate sub-microsecond time in <Prepared>


ABS_XML  ABORIGINAL_POP_PROJ                 Projected population, Aboriginal and Torres St...
         ABORIGINAL_POP_PROJ_REMOTE          Projected population, Aboriginal and Torres St...
         ABS_ABORIGINAL_POPPROJ_INDREGION    Projected population, Aboriginal and Torres St...
         ABS_ACLD_LFSTATUS                   Australian Census Longitudinal Dataset (ACLD):...
         ABS_ACLD_TENURE                     Australian Census Longitudinal Dataset (ACLD):...
                                                                   ...                        
UNSD     DF_UNData_UNFCC                                                       SDMX_GHG_UNDATA
WB       DF_WITS_Tariff_TRAINS                                WITS - UNCTAD TRAINS Tariff Data
         DF_WITS_TradeStats_Development                             WITS TradeStats Devlopment
         DF_WITS_TradeStats_Tariff                                      WITS TradeStats Tariff
         DF_WITS_TradeStats_Trade                 

## Creating a TFT model

In [None]:
tft = TemporalFusionTransformer(
    time_steps=12,
    input_size=20,
    output_size=4,
    category_counts=5,
    n_workers=2, # Number of multiprocessing workers

    # TFT params
    input_obs_loc=24,
    static_input_loc=24,
    known_regular_input_idx=24,
    known_categorical_input_idx=24,
    column_definition=None,
)

In [None]:
tft.time_steps

# References {.unnumbered}