This notebook describes the temporal fusion transformers [@lim2021temporal] architecture, and ports it over to keras 3 while making some punctual improvements.

The original repository is: https://github.com/google-research/google-research/tree/master/tft.

In [2]:
from __future__ import annotations

import numpy as np
import keras_core as keras
from keras_core import layers
from fastcore import docments
from nbdev.showdoc import show_doc


Using TensorFlow backend


# Architecture

## Concepts

* **time distributed**: 
  * applies same layer to each of the timesteps in the data
    * in other words, a layer with the exact same weights
  * indices:
    * index 0: batch
    * index 1: time
    * indices 2...: data
  * More info: https://www.tensorflow.org/api_docs/python/tf/keras/layers/TimeDistributed

# Gated residual network

## Linear layer

* dedicated implementation to better control use of time distribution

In [3]:
def linear_layer(size:int, # Output size
                 activation:str|callable|None=None, # Activation function
                 use_time_distributed:bool=False, # Apply the layer across all timesteps?
                 use_bias:bool=True # Include bias in the layer?
)->keras.src.layers.core.dense.Dense: # Dense layer
    "Linear layer."

    linear = keras.layers.Dense(size, activation=activation, use_bias=use_bias)
    if use_time_distributed:
        linear = keras.layers.TimeDistributed(linear)
    return linear

## Dense layer

* dedicated implementation to better control use of time distribution

In [4]:
def dense_layer(
    size:int, # Output size
    activation:str|callable|None=None, # Activation function
    use_time_distributed:bool=False, # Apply the layer across all timesteps?
    use_bias:bool=True # Include bias in the layer?
)->keras.src.layers.core.dense.Dense: # Dense layer
    "Dense layer"

    dense = layers.Dense(size, activation=activation, use_bias=use_bias)
    if use_time_distributed:
        dense = layers.TimeDistributed(dense)
    return dense

Example usage of dense layer:

In [5]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100
layer_size = 16

# input dimensions: batches / timesteps / features
x = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# dense layer
dense = dense_layer(size=layer_size, use_time_distributed=True)

# output dimensions: batches / timesteps / layer size
assert dense(x).shape == [batch_size, n_timesteps, layer_size]

Now showing that the time-distributed layer applies the same weights at all timesteps:

In [6]:
#| code-fold: show

x = np.ones((1, n_timesteps, n_features))
timesteps_equal = []
for i in range(n_timesteps-1):
    timesteps_equal.append((np.array_equal(dense(x)[0,0,:], dense(x)[0,i+1,:])))

assert np.all(timesteps_equal)

In [7]:
#| output: asis
show_doc(dense_layer)


---

### dense_layer

>      dense_layer (size:int, activation:Union[str,<built-
>                   infunctioncallable>,NoneType]=None,
>                   use_time_distributed:bool=False, use_bias:bool=True)

Dense layer

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| size | int |  | Output size |
| activation | str \| callable \| None | None | Activation function |
| use_time_distributed | bool | False | Apply the layer across all timesteps? |
| use_bias | bool | True | Include bias in the layer? |
| **Returns** | **keras.src.layers.core.dense.Dense** |  | **Dense layer** |

## Gated linear unit (GLU)

* Introduced by @dauphin2017language
* The GLU is part of the Gated Residual Network (GRN) block
* Using input $\gamma \in \mathbb{R}^{d_{\text{model}}}$ and the subscript $\omega$ to index weights, $\text{GLU}_{\omega}(\gamma) = \sigma(W_{4, \omega} \gamma + b_{4, \omega}) \odot (W_{5, \omega} \gamma + b_{5, \omega})$
* As can be seen above, the result could be very close to zero through the Hadamard multipliciation, which in practice means that the network would not be affected by that data (ie, it would be gated out)
* *"GLUs reduce the vanishing gradient problem for deep architectures by providing a linear path for gradients while retaining non-linear capabilities"*
* *"provide flexibility to suppress any parts of the architecture that are not required for a given dataset"*


In [8]:
#| output: false

def apply_gating_layer(
    x, # Input tensors (batch first)
    hidden_layer_size:int, # Dimension of the GLU
    dropout_rate:float|None=None, # Dropout rate
    use_time_distributed:bool=True, # Apply the GLU across all timesteps?
    activation:str|callable=None # Activation function
): # Tuple of (GLU output tensors, gated_layer)
    "Gated Linear Unit (GLU) layer"
    
    if dropout_rate is not None:
        x = layers.Dropout(dropout_rate)(x)

    activation_layer = layers.Dense(
                hidden_layer_size,
                activation=activation
            )

    gated_layer = layers.Dense(
                hidden_layer_size,
                activation='sigmoid'
            )

    if use_time_distributed:
        activation_layer = layers.TimeDistributed(activation_layer)(x)
        gated_layer = layers.TimeDistributed(gated_layer)(x)
    else:
        activation_layer = activation_layer(x)
        gated_layer = gated_layer(x)

    return layers.Multiply()([activation_layer, gated_layer]), gated_layer

Example usage of GLU:

In [9]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100
hidden_layer_size = 16

# input dimensions: batches / timesteps / features
x = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# output dimensions: batches / timesteps / hidden_layer_size
assert apply_gating_layer(x=x, hidden_layer_size=hidden_layer_size)[0].shape == [batch_size, n_timesteps, hidden_layer_size]

In [10]:
#| output: asis
show_doc(apply_gating_layer)


---

### apply_gating_layer

>      apply_gating_layer (x, hidden_layer_size:int,
>                          dropout_rate:Optional[float]=None,
>                          use_time_distributed:bool=True,
>                          activation:Union[str,<built-
>                          infunctioncallable>]=None)

Gated Linear Unit (GLU) layer

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| x |  |  | Input tensors (batch first) |
| hidden_layer_size | int |  | Dimension of the GLU |
| dropout_rate | float \| None | None | Dropout rate |
| use_time_distributed | bool | True | Apply the GLU across all timesteps? |
| activation | str \| callable | None | Activation function |

## Skip connection

Adds inputs to layer, ie "skip connection", and then implements layer normalisation [@ba2016layer].

In [11]:
def add_and_norm(
    x_list # List of input tensors (of the same dimension) for skip connection
    ):
    "Adds tensors with same dimensions and then normalises layer"
    tmp = layers.Add()(x_list)
    return layers.LayerNormalization()(tmp)


Example usage of skip connections + layer normalization:

In [12]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100

# input dimensions: batches / timesteps / features
x1 = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 
x2 = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# output dimensions: batches / timesteps / features
x1x2 = add_and_norm(x_list=[x1, x2])
assert x1.shape == x1x2.shape

Mean values (normalised should be around 0):

In [13]:
x1.mean(axis=-1), x2.mean(axis=-1), x1x2.numpy().mean(axis=-1)

(array([[ 0.0878034 ,  0.15579257,  0.02132437, -0.01665834,  0.09834406],
        [-0.08403851, -0.09840068, -0.08497837,  0.08315674,  0.04145419],
        [ 0.03267195, -0.15613629, -0.06508664, -0.07177194, -0.00536668]]),
 array([[ 0.08824437,  0.07657852, -0.11315156, -0.14589392, -0.10032865],
        [ 0.00629648, -0.23405682,  0.13411307, -0.04240066,  0.02123513],
        [-0.17501666, -0.04459926,  0.08562279,  0.02482085, -0.12458415]]),
 array([[-2.80141830e-08,  9.53674295e-09,  1.78813930e-09,
          8.34465030e-09, -5.96046434e-10],
        [ 9.53674295e-09,  1.54972071e-08, -2.38418574e-09,
         -4.76837148e-09, -5.96046457e-09],
        [ 3.03983683e-08,  1.23679635e-08, -1.19209291e-08,
         -3.31550822e-08, -9.83476678e-09]], dtype=float32))

Standard deviation (normalised should be around 1):

In [14]:
x1.std(axis=-1), x2.std(axis=-1), x1x2.numpy().std(axis=-1)

(array([[0.94964034, 1.05664469, 0.92412482, 0.96766552, 1.0508922 ],
        [0.99617178, 1.09495702, 1.04056727, 0.90764699, 1.05762491],
        [0.99494172, 0.9119524 , 1.03807059, 0.94555555, 1.07363401]]),
 array([[1.01199589, 1.0563091 , 1.07754409, 1.01300685, 0.9101479 ],
        [1.09136647, 1.02580908, 0.99869893, 0.98858294, 1.0134959 ],
        [0.99310298, 1.01879404, 1.01631997, 1.12171694, 1.09240828]]),
 array([[0.9997355 , 0.99979895, 0.9997791 , 0.9997775 , 0.99976087],
        [0.9997847 , 0.99976426, 0.99978065, 0.99963933, 0.9997341 ],
        [0.99973655, 0.9997171 , 0.9997928 , 0.9997747 , 0.9998052 ]],
       dtype=float32))

In [15]:
#| output: asis

show_doc(add_and_norm)

---

### add_and_norm

>      add_and_norm (x_list)

Adds tensors with same dimensions and then normalises layer

|    | **Details** |
| -- | ----------- |
| x_list | List of input tensors (of the same dimension) for skip connection |

## Gated residual network (GRN)

* The GRN is a key building block of the TFT
    * Helps keep information only from relevant input variables
    * Also keeps the model as simple as possible by only applying non-linearities when relevant
* $\text{GRN}_{\omega}(a, c)$:
    * *1st step*: $\eta_{2} = \text{ELU}(W_{2, \omega} a + b_{2, \omega} + W_{3, \omega} c)$, (where the additional context $c$ might be zero),
    * *2nd step*: $\eta_{1} = W_{1, \omega} \eta_{2} + b_{1, w}$,
    * *3rd step*: $\text{LayerNorm}(a + \text{GLU}_{\omega}(\eta_{1}))$
* $\text{ELU}(\cdot)$ is the Exponential Linear Unit activation function (@clevert2015fast)
    * Unlike ReLUs, ELUs allow for negative values, which pushes unit activations closer to zero at a lower computation complexity, and producing more accurate results
* $\text{LayerNorm}(\cdot)$ is the layer normalisation (@ba2016layer)

In [16]:
def gated_residual_network(
    x, # Network inputs
    hidden_layer_size:int, # Dimension of the GRN
    output_size:int|None=None, # Size of output layer (if None, same as `hidden_layer_size`)
    dropout_rate:float|None=None, # Dropout rate
    use_time_distributed:bool=True, # Apply the GLU across all timesteps?
    additional_context=None, # Additional context vector to use if relevant
    return_gate:bool=False #Whether to return GLU gate for diagnostic purposes
):
    "Applies the gated residual network (GRN) as defined in the paper"
    
    # Setup skip connection
    if output_size is None:
        output_size = hidden_layer_size
        skip = x
    else:
        linear = keras.layers.Dense(output_size)
        if use_time_distributed:
            linear = keras.layers.TimeDistributed(linear)
        skip = linear(x)

    # 1st step: eta2
    hidden = linear_layer(
        size=hidden_layer_size, # W2
        activation=None,
        use_time_distributed=use_time_distributed,
        use_bias=True # b2
    )(x)

    # "For instances without a context vector, the GRN simply treates the context input as zero - ie, $c = 0$ in Eq. 4"
    if additional_context is not None: # if c is != 0...
        hidden += linear_layer(
            size=hidden_layer_size, # W3
            activation=None,
            use_time_distributed=use_time_distributed,
            use_bias=False # no bias for additional context, since there already is bias from the "main" calculation of eta2
        )(additional_context)

    hidden = keras.layers.Activation('elu')(hidden)

    # 2nd step: eta1
    hidden = linear_layer(
        size=hidden_layer_size, # W1
        activation=None,
        use_time_distributed=use_time_distributed,
        use_bias=True # b1
    )(hidden)

    # 3rd step: concluding the GRN calculation
    gating_layer, gate = apply_gating_layer(
        x=hidden,
        hidden_layer_size=output_size,
        dropout_rate=dropout_rate,
        use_time_distributed=use_time_distributed,
        activation=None
    )

    GRN = add_and_norm([skip, gating_layer])

    if return_gate:
        return GRN, gate
    else:
        return GRN

Example usage of GRN:

In [17]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 100
hidden_layer_size = 16
output_size = 17

# input dimensions: batches / timesteps / features
x = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

grn = gated_residual_network(
    x=x,
    hidden_layer_size=hidden_layer_size,
    output_size=output_size,
    dropout_rate=0,
    use_time_distributed=True,
    additional_context=None,
    return_gate=False
)

# output dimensions: batches / timesteps / hidden_layer_size
assert grn.shape == [batch_size, n_timesteps, output_size]

In [18]:
#| output: asis

show_doc(gated_residual_network)

---

### gated_residual_network

>      gated_residual_network (x, hidden_layer_size:int,
>                              output_size:Optional[int]=None,
>                              dropout_rate:Optional[float]=None,
>                              use_time_distributed:bool=True,
>                              additional_context=None, return_gate:bool=False)

Applies the gated residual network (GRN) as defined in the paper

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| x |  |  | Network inputs |
| hidden_layer_size | int |  | Dimension of the GRN |
| output_size | int \| None | None | Size of output layer (if None, same as `hidden_layer_size`) |
| dropout_rate | float \| None | None | Dropout rate |
| use_time_distributed | bool | True | Apply the GLU across all timesteps? |
| additional_context | NoneType | None | Additional context vector to use if relevant |
| return_gate | bool | False | Whether to return GLU gate for diagnostic purposes |

# Attention components

* Attention mechanisms use relationships between keys $K \in \mathbf{R}^{N \times d_{attention}}$ and queries $Q \in \mathbf{R}^{N \times d_{attention}}$ to scale a vector of values $V \in \mathbf{R}^{N \times d_V}$: $\text{Attention}(Q, K, V) = A(Q, K) V$
    * $N$ is the number of timesteps going into the attention layer (the number of lags plus the number of periods to be forecasted)
    * $A(\cdot)$ is a normalisation function
        * After @vaswani2017attention, the canonical choice for $A(\cdot)$ is the scaled dot-product: $A(Q, K) = \text{Softmax}(\frac{Q K^{T}}{\sqrt{d_{attention}}} )$
    
* The TFT uses a modified attention head to enhance the explainability of the model
* Specifically, the transformer block (multi-head attention) is modified to:
    * share values in each head, and
    * employ additive aggregation of all heads
* More formally, compare the interpretable multi-head attention (used in this paper) with the canonical multi-head attention:
    * $\text{InterpretableMultiHead}(Q, K, V) = \tilde{H} W_{H}$, with:
        * $\begin{aligned}\tilde{H} &= \tilde{A}(Q, K) V W_V \\
        &= \{\frac{1}{m_H} \sum^{m_{H}}_{h=1} A(Q W^{(h)}_Q, K W^{(h)}_K) \} V W_V \\
        &= \frac{1}{m_H} \sum^{m_{H}}_{h=1} \text{Attention}(Q W^{(h)}_Q, K W^{(h)}_K, V W_V)
        \end{aligned}$
    * $\text{MultiHead}(Q, K, V) = [H_1, \dots, H_{m_H}] W_H$, with:
        * $H_h = \text{Attention}(Q W^{(h)}_Q, K W^{(h)}_K, V W_V^{(h)}) $

## Decoder mask for self-attention layer

In [19]:
def get_decoder_mask(
    self_attention_inputs # Inputs to the self-attention layer
):
    "Determines shape of decoder mask"
    len_s = keras.ops.shape(self_attention_inputs)[1] # length of inputs
    bs = keras.ops.shape(self_attention_inputs)[0] # batch shape
    mask = keras.ops.cumsum(keras.ops.eye(len_s), 1) #keras.backend.cumsum(np.eye(len_s, bs))

    ### warning: I had to manually implement some batch-wise shape here 
    ### because the new keras `eye` function does not have a batch_size arg.
    ### inspired by: https://github.com/tensorflow/tensorflow/blob/v2.14.0/tensorflow/python/ops/linalg_ops_impl.py#L30
    ### <hack>
    mask = keras.ops.expand_dims(mask, axis=0)    
    mask = keras.ops.tile(mask, (bs, 1, 1))
    ### </hack>

    return mask


Example usage of the decoder mask:

In [20]:
dec = get_decoder_mask(grn)

assert dec.shape == (batch_size, n_timesteps, n_timesteps)

Note that it produces an upper-triangular matrix of ones:

In [21]:
dec[0]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.]], dtype=float32)>

In [22]:
show_doc(get_decoder_mask)

---

### get_decoder_mask

>      get_decoder_mask (self_attention_inputs)

Determines shape of decoder mask

|    | **Details** |
| -- | ----------- |
| self_attention_inputs | Inputs to the self-attention layer |

## Scaled dot product attention layer

* This is the same as Eq. (1) of @vaswani2017attention 
    * except that in this case the dimension of the value vector is the same $d_{\text{model}}$ as for the query and key vectors
* As discussed in the paper, additive attention outperforms dot product attention for larger $d_{\text{model}}$ values, so the attention is scaled back to smaller values

In [69]:
class ScaledDotProductAttention():
    def __init__(
        self,
        training:bool=True, # Whether the layer is being trained or used in inference
        attention_dropout:float=0.0 # Will be ignored if `training=False`
    ):
        self.training = training
        self.dropout = keras.layers.Dropout(rate=attention_dropout)
        self.activation = keras.layers.Activation('softmax')

    def __call__(
        self,
        q, # Queries, tensor of shape (?, time, D_model)
        k, # Keys, tensor of shape (?, time, D_model)
        v, # Values, tensor of shape (?, time, D_model)
        mask # Masking if required (sets Softmax to very large value), tensor of shape (?, time, time)
    ):
        # returns Tuple (layer outputs, attention weights)
        scale = keras.ops.sqrt(keras.ops.cast(keras.ops.shape(k)[-1], dtype='float32'))
        attention = keras.ops.einsum("btd,btd->bt", q, k) / scale
        if mask is not None:
            mmask = keras.layers.Lambda(lambda x: (-1e9) * (1. - keras.ops.cast(x, 'float32')))(mask)
            attention = keras.layers.Add()([attention, mmask])
        attention = self.activation(attention)
        if self.training:
            attention = self.dropout(attention)
        output = keras.ops.einsum("bt,btd->bt", attention, v)
        return output, attention

Below is an example of how the `ScaledDotProductAttention` layer works:

In [61]:
#| code-fold: show

batch_size = 3
n_timesteps = 5
n_features = 13

# input dimensions: batches / timesteps / features
x_btf = np.random.randn(batch_size*n_timesteps*n_features).reshape([batch_size, n_timesteps, n_features]) 

# using the same vector for q, k and v just to simplify
q=keras.ops.cast(x_btf, 'float32')
k=keras.ops.cast(x_btf, 'float32')
v=keras.ops.cast(x_btf, 'float32')

output, attention = ScaledDotProductAttention()(q=q, k=k, v=v, mask=None)
output, attention # both have shape (batch_size, n_timesteps)

attention 1 shape:  (3, 5)
attention 2 shape:  (3, 5)
attention 3 shape:  (3, 5)


(<tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[-0.06796159, -0.14928389,  0.35569158, -0.46120387,  0.30438477],
        [ 0.29471743, -2.9135132 ,  0.03538444,  0.09945958, -0.01185272],
        [ 0.01331887,  0.29213437, -0.04500379,  0.00928352,  0.3467226 ]],
       dtype=float32)>,
 <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[0.14248843, 0.06488948, 0.14144121, 0.5226409 , 0.12854004],
        [0.17114735, 0.7915294 , 0.01030869, 0.02321733, 0.00379721],
        [0.34609836, 0.46975043, 0.1107455 , 0.00888185, 0.0645239 ]],
       dtype=float32)>)

In [62]:
show_doc(ScaledDotProductAttention)

---

### ScaledDotProductAttention

>      ScaledDotProductAttention (training:bool=True,
>                                 attention_dropout:float=0.0)

Initialize self.  See help(type(self)) for accurate signature.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| training | bool | True | Whether the layer is being trained or used in inference |
| attention_dropout | float | 0.0 | Will be ignored if `training=False` |

In [73]:
activ = keras.layers.Activation('softmax')
attention, activ(attention)

(<tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[0.14248843, 0.06488948, 0.14144121, 0.5226409 , 0.12854004],
        [0.17114735, 0.7915294 , 0.01030869, 0.02321733, 0.00379721],
        [0.34609836, 0.46975043, 0.1107455 , 0.00888185, 0.0645239 ]],
       dtype=float32)>,
 <tf.Tensor: shape=(3, 5), dtype=float32, numpy=
 array([[0.18611768, 0.17222127, 0.18592288, 0.2721985 , 0.18353964],
        [0.18452014, 0.34314075, 0.15710588, 0.15914705, 0.15608622],
        [0.22777675, 0.25775722, 0.18001015, 0.16257665, 0.17187914]],
       dtype=float32)>)

### Softmax

A small detour to illustrate the softmax function. 

The $i^{\text{th}}$ element of $\text{Softmax}(x)$, with $x \in \mathbf{R}^K$ is:

$$
\text{Softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^K e^{x_j}}
$$

For example, see the values below for an input vector $x$ ($K=5$ in this example):

In [63]:
#| code-fold: show

x = np.array([-np.Inf, -1., 0., 1., 3.])
keras.layers.Activation('softmax')(x)
print("x = ", x)
print("exp(x) = ", np.exp(x))
print("denominator (sum of exp(x_j), j=1,...,K) = ", sum(np.exp(x)))
print("softmax(x) = ", np.exp(x) / sum(np.exp(x)))
print("sum of softmax(x)_j, j=1,...,K = ", sum(np.exp(x) / sum(np.exp(x))))

x =  [-inf  -1.   0.   1.   3.]
exp(x) =  [ 0.          0.36787944  1.          2.71828183 20.08553692]
denominator (sum of exp(x_j), j=1,...,K) =  24.171698192818155
softmax(x) =  [0.         0.01521943 0.0413707  0.11245721 0.83095266]
sum of softmax(x)_j, j=1,...,K =  1.0


As can be seen above, the softmax function really makes the largest numbers stand out from the rest.

Note also that $-\infty$ results in 0.

## Interpretable Multi-head attention

* When values are shared in each head and then are aggregated additively, each head still lcan learn different temporal patterns (from their own unique queries and keys), but with the same input values.
    * In other words, they can be interpreted as an ensemble over the attention weights
    * the paper doesn't mention this explicitly, but the ensemble is equally-weighted - maybe there is some performance to be gained by having some way to weight the different attention heads 🤔, such as having a linear layer combining them... will explore in the future

In [64]:
class InterpretableMultiHeadAttention():
    def __init__(
        self,
        n_head:int,
        d_model:int,
        training:bool=True, # Whether the layer is being trained or used in inference
        dropout:float=0.0 # Will be ignored if `training=False`
    ):
        self.n_head = n_head
        self.d_k = self.d_v = d_k = d_v = d_model # // n_head - the original model divides by number of heads
        self.training = training
        self.dropout = dropout

        # using the same value layer facilitates interpretability
        vs_layer = keras.layers.Dense(d_v, use_bias=False)

        # creates list of queries, keys and values across heads
        self.qs_layers = self._build_layers(d_k, n_head)
        self.ks_layers = self._build_layers(d_k, n_head)
        self.vs_layers = [vs_layer for _ in range(n_head)]

        self.attention = ScaledDotProductAttention()
        self.w_o = keras.layers.Dense(d_v, use_bias=False) # W_v in Eqs. (14)-(16), output weight matrix to project internal state to the original TFT

    def __call__(
        self,
        q, # Queries, tensor of shape (?, time, D_model)
        k, # Keys, tensor of shape (?, time, D_model)
        v, # Values, tensor of shape (?, time, D_model)
        mask=None # Masking if required (sets Softmax to very large value), tensor of shape (?, time, time)
    ):
        heads = []
        attns = []
        for i in range(self.n_head):
            qs = self.qs_layers[i](q)
            ks = self.ks_layers[i](q)
            vs = self.vs_layers[i](v)
           
            head, attn = self.attention(qs, ks, vs, mask)
            if self.training:
                head = keras.layers.Dropout(self.dropout)(head)
            heads.append(head)
            attns.append(attn)

        outputs = keras.ops.mean(head, axis=0) if self.n_head > 1 else head # H_tilde
        outputs = self.w_o(outputs)
        if self.training:
            outputs = keras.layers.Dropout(self.dropout)(outputs)

        return outputs, attn

    def _build_layers(self, d:int, n_head:int):
            return [keras.layers.Dense(d) for _ in range(n_head)]

InterpretableMultiHeadAttention(n_head=8, d_model=32)

<__main__.InterpretableMultiHeadAttention at 0x23463a89ca0>

In [65]:
imha = InterpretableMultiHeadAttention(n_head=8, d_model=16)

In [66]:
grn.shape # B, T, F

TensorShape([3, 5, 17])

In [67]:
mask = get_decoder_mask(grn)
mask.shape # B, T, T

TensorShape([3, 5, 5])

In [68]:
imha(grn, grn, grn, mask)

(3, 5, 16) (3, 5, 16) (3, 5, 16)
attention 1 shape:  (3, 5)
attention 2 shape:  (3, 5, 5)
attention 3 shape:  (3, 5, 5)


InvalidArgumentError: {{function_node __wrapped__Einsum_N_2_device_/job:localhost/replica:0/task:0/device:CPU:0}} Expected input 0 to have rank 2 but got: 3 [Op:Einsum] name: 

# References {.unnumbered}