In [1]:
from typing import *

import jax
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn

from flax.linen.attention import make_causal_mask

In [2]:
def dot_product_attention_weights(q: jnp.ndarray,
                                  k: jnp.ndarray,
                                  mask: Optional[jnp.ndarray] = None,
                                  dtype = jnp.float32):
  """Return softmax-normalised scaled dot-product attention weights.

  Args:
    q: query matrix of shape [num_q, qk_depth].
    k: key matrix of shape [num_kv, qk_depth].
    mask: optional mask of shape [num_q, num_kv] (anything `jnp.where` can
      broadcast against the logits). Entries where it is `False` are replaced
      by the most negative finite value of `dtype` before the softmax, so they
      receive (near-)zero weight. A lower-triangular mask gives causal attention.
    dtype: floating dtype used for the 1/sqrt(depth) scale, the masking value
      and the returned weights.

  Returns:
    Array of shape [num_q, num_kv]; each row is a softmax over the key axis.
  """
  assert q.ndim == k.ndim == 2, 'q, k must have rank 2.'
  assert q.shape[-1] == k.shape[-1], 'q, k depths must match.'

  scale = jnp.sqrt(q.shape[-1]).astype(dtype)
  logits = jnp.einsum('...qd,...kd->...qk', q / scale, k, precision=None)

  if mask is not None:
    # The most negative *finite* float (not -inf) keeps a fully masked row
    # finite after the softmax instead of turning it into NaN.
    logits = jnp.where(mask, logits, jnp.finfo(dtype).min)

  return jax.nn.softmax(logits).astype(dtype)


In [3]:
def layernorm(x, eps: float = 1e-6):
    """Normalise `x` to zero mean and unit variance along its last axis.

    Parameter-free layer normalisation (no learned scale or bias).

    Args:
      x: input array; mean and variance are taken over the last axis.
      eps: small constant added to the variance before the square root, so
        rows with zero variance (e.g. an all-constant or all-zero row) map to
        0 instead of producing 0/0 = NaN.

    Returns:
      Array with the same shape as `x`.
    """
    centered = x - jnp.mean(x, axis=-1, keepdims=True)
    variance = jnp.mean(centered * centered, axis=-1, keepdims=True)
    return centered / jnp.sqrt(variance + eps)

class DecoderLayer(nn.Module):
  """Single-head, post-LayerNorm Transformer decoder layer.

  The Q/K/V and output projections of both attention blocks are omitted for
  simplicity (queries, keys and values are the raw inputs), and dropout is
  left out. Only the FFN has learned parameters (`fc1`, `fc2`).

  Attributes:
    embed_dim: size of the last axis of `x` and `encoder_kv`.
    ffn_dim: hidden size of the position-wise feed-forward network.
  """
  embed_dim: int
  ffn_dim: int

  @nn.compact
  def __call__(
    self,
    x: jnp.ndarray,
    attention_mask: jnp.ndarray,
    encoder_kv: jnp.ndarray,
  ):
    """Apply self-attention, cross-attention and FFN sub-layers.

    Args:
      x: decoder input of shape [seq_len, embed_dim].
      attention_mask: rank-1 array of length seq_len. Only its length is used
        (to build the causal mask); its values are not consulted, so it does
        not act as a padding mask.
      encoder_kv: encoder output of shape [enc_len, embed_dim], used as both
        keys and values of the cross-attention. No encoder-side mask is
        applied: every encoder position is attended to.

    Returns:
      Array of shape [seq_len, embed_dim].
    """
    assert x.ndim == 2 and attention_mask.ndim == 1
    assert encoder_kv.ndim == 2
    # Compare against the module attribute, not a notebook-global `embed_dim`.
    assert x.shape[0] == attention_mask.shape[0] and x.shape[1] == self.embed_dim
    assert encoder_kv.shape[1] == self.embed_dim

    # 1. Self-attention: causal (lower-triangular) mask, residual, then norm.
    residual = x
    q, k, v = x, x, x
    # make_causal_mask prepends a head axis ([1, L, L]); [0] drops it and the
    # comparison turns the float mask into a boolean one.
    causal_mask = make_causal_mask(attention_mask)[0] == 1.0
    attn_weights = dot_product_attention_weights(q, k, causal_mask)
    x = attn_weights @ v
    # dropout omitted
    x = layernorm(x + residual)

    # 2. Cross-attention over the (static) encoder output, residual, norm.
    residual = x
    q = x
    k, v = encoder_kv, encoder_kv
    attn_weights = dot_product_attention_weights(q, k)
    x = attn_weights @ v
    # dropout omitted
    x = layernorm(x + residual)

    # 3. Position-wise FFN, residual, norm. fc1 is created before fc2 so the
    # parameter initialisation order is unchanged.
    residual = x
    fc1_weight = self.param("fc1", nn.initializers.normal(), (self.embed_dim, self.ffn_dim))
    x = nn.gelu(x @ fc1_weight)
    # dropout omitted
    fc2_weight = self.param("fc2", nn.initializers.normal(), (self.ffn_dim, self.embed_dim))
    x = x @ fc2_weight
    x = layernorm(x + residual)

    return x

In [4]:
seq_len = 5
embed_dim = 8
ffn_dim = 64

# Encoder output: static across decoding steps.
encoder_kv = random.normal(key=random.PRNGKey(0), shape=(seq_len, embed_dim))

# Decoder input at step t+1: seq_len positions.
input_feature_2 = random.normal(key=random.PRNGKey(1), shape=(seq_len, embed_dim))
attention_mask_2 = jnp.ones((seq_len,), dtype=jnp.int32)

# Decoder input at step t: the same sequence without its last position.
input_feature_1 = input_feature_2[..., :-1, :]
attention_mask_1 = attention_mask_2[:-1]

decoder_layer = DecoderLayer(embed_dim=embed_dim, ffn_dim=ffn_dim)
# Parameter shapes depend only on embed_dim/ffn_dim, so dummy inputs suffice;
# the encoder dummy mirrors `encoder_kv` (not the decoder input).
params = decoder_layer.init(
    random.PRNGKey(2),
    x=jnp.zeros_like(input_feature_2),
    attention_mask=jnp.ones_like(attention_mask_2),
    encoder_kv=jnp.zeros_like(encoder_kv),
)



In [5]:
# Run the same layer and params at decoding step t (prefix) and at step t+1
# (full sequence); both calls share the same encoder output.
output_1, output_2 = (
    decoder_layer.apply(params, x=features, attention_mask=mask, encoder_kv=encoder_kv)
    for features, mask in (
        (input_feature_1, attention_mask_1),  # decoding step t
        (input_feature_2, attention_mask_2),  # decoding step t+1
    )
)

In [6]:
def pad_to_shape(a, shape):
  """Zero-pad `a` at the end of each axis so that its shape becomes `shape`.

  Args:
    a: array to pad.
    shape: target shape. It must have the same rank as `a` and be at least
      as large as `a.shape` along every axis.

  Returns:
    A new array of shape `shape`. Its leading `a.shape` region equals `a` and
    all remaining entries are 0.

  Raises:
    ValueError: if the ranks differ, or `shape` is smaller than `a.shape`
      along some axis. Without the rank check, a shorter `shape` would make
      `jnp.pad` broadcast the pad widths to every axis. That would silently
      return the wrong shape.
  """
  target = tuple(shape)
  if len(target) != a.ndim:
    raise ValueError(
        f"pad_to_shape: target rank {len(target)} != array rank {a.ndim}")
  pad_width = []
  for axis, (current, wanted) in enumerate(zip(a.shape, target)):
    if wanted < current:
      raise ValueError(
          f"pad_to_shape: cannot shrink axis {axis} from {current} to {wanted}")
    pad_width.append((0, wanted - current))
  return jnp.pad(a, pad_width=pad_width)

In [7]:
print("output shape at t  :", output_1.shape)
print("output shape at t+1:", output_2.shape)
print("diff of output t+1 and t:\n", output_2 - pad_to_shape(output_1, output_2.shape))

# Incremental invariant: appending position t+1 must not change rows 1..t.
assert bool(jnp.allclose(output_2[:-1], output_1, atol=1e-6)), \
    "rows 1..t of the step t+1 output differ from the step t output"

output shape at t  : (4, 8)
output shape at t+1: (5, 8)
diff of output t+1 and t:
 [[ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.        ]
 [-0.2758741  -1.9854703  -1.0166407   0.7927936   0.82887286  0.33593562
   0.0909228   1.2294604 ]]


As you can see, the outputs of a single decoder layer at decoding steps $t$ and $t+1$ are **incremental**. The first $t$ rows are identical, and step $t+1$ only appends one new row.

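`dot_product_attention_weights` implements the most basic form
$$W_{\mathrm{attn}} = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right), \qquad M_{ij} = \begin{cases} 0 & \text{if } \mathrm{mask}_{ij} \text{ is true} \\ -\infty & \text{otherwise} \end{cases}$$
The softmax is taken row-wise, and the code uses the most negative finite float in place of $-\infty$.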

When a **lower-triangular** causal mask is present, the final $W_{\mathrm{attn}}$ has the following form:
$$
W_{\mathrm{attn}} = 
\begin{pmatrix}
w_{1,1} &         &        &           &   0    \\
w_{2,1} & w_{2,2} &        &           &        \\
w_{3,1} & w_{3,2} & \ddots &           &        \\
\vdots  & \vdots  & \ddots & \ddots    &        \\
w_{n,1} & w_{n,2} & \ldots & w_{n,m-1} &w_{n,m}
\end{pmatrix}
$$
Here $n$ is the number of query entries and $m$ is the number of key/value entries (`num_q` and `num_kv` in the code, respectively).

In decoder self-attention, `num_q` and `num_kv` are both equal to the current decoding step $t$:
$$
W_{\mathrm{attn}}^t = 
\begin{pmatrix}
w_{1,1} &        &           &   0    \\
w_{2,1} & \ddots &           &        \\
\vdots  & \ddots & \ddots    &        \\
w_{t,1} & \ldots & w_{t,t-1} &w_{t,t}
\end{pmatrix}
$$
and
$$
W_{\mathrm{attn}}^{t+1} =
\begin{pmatrix}
W_{\mathrm{attn}}^t & \mathbf{0} \\
w_{t+1,1:t} & w_{t+1,t+1}
\end{pmatrix}
\tag{1}
$$
Here $w_{t+1,1:t}$ is the row vector $(w_{t+1,1}, \ldots, w_{t+1,t})$ and $\mathbf{0}$ is a column of $t$ zeros.



Let $V^t$ and $V^{t+1}$ denote the values at decoding steps $t$ and $t+1$, respectively. Going from $t$ to $t+1$ appends one row $v^{t+1}$:
$$
\newcommand{\horzbar}{\rule[.5ex]{2.5ex}{0.5pt}}
V^{t+1} = 
\begin{pmatrix}
V^t \\
\horzbar\, v^{t+1}\,\horzbar
\end{pmatrix}
$$

Then the attention output at step $t+1$ is
$$
Y^{t+1} = W_{\mathrm{attn}}^{t+1} V^{t+1} = 
\begin{pmatrix}
W_{\mathrm{attn}}^t V^t  \\
w^{t+1} V^{t+1}
\end{pmatrix} \overset{?}{=} \begin{pmatrix}
Y^{t}  \\
w^{t+1} V^{t+1}
\end{pmatrix}
$$

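Notice the recursive relationship between $W_{\mathrm{attn}}^{t+1} V^{t+1}$ and $W_{\mathrm{attn}}^t V^t$. Because of the zero column in Eq. (1), the first $t$ rows of $Y^{t+1}$ depend only on $W_{\mathrm{attn}}^t$ and $V^t$. This is where the attention cache comes from. We call this property the **incremental invariant**, or the **invariant** for short.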

### Why and when can self-attention be cached?

To be able to use the cache, the $\overset{?}{=}$ must hold for the whole decoder layer, because there is computation both before and after the attention.

Fortunately, it does hold. The residual connections, layer normalization and FFN all operate along the embedding (depth) dimension, so they are applied to each position separately and identically. When a new position is appended ($t \rightarrow t+1$), the outputs at positions $1..t$ are therefore not affected by the new $v^{t+1}$, thanks to the zero column in Eq. (1).

The preceding discussion shows that when a lower-triangular mask is present, you can use caching for the self-attention in the first decoder layer.

### Why and when can cross-attention be cached?

For the cross-attention, $K$ and $V$ are **static**: they are produced by the encoder and do not change when the decoding step advances from $t \rightarrow t+1$. Since the self-attention preserves the **invariant**
$$
Y^{t+1} = \begin{bmatrix} Y^t \\ y^{t+1} \end{bmatrix} = \mathrm{SelfAttn}(X^{t+1}) = \mathrm{SelfAttn}(\begin{bmatrix} X^t \\ x^{t+1} \end{bmatrix})
$$
the input to the cross-attention has the same incremental form as the decoder input. A new embedding row is appended to the self-attention input (and hence to its values $V$), and likewise a new row $y^{t+1}$ is appended to $Y$. Then (the $\frac{1}{\sqrt{d_k}}$ scaling is omitted for brevity)
$$
\begin{aligned}
Z^{t+1} &= \mathrm{CrossAttn}(Y^{t+1}) \\
        &= \mathrm{softmax}(Y^{t+1} K^\top) V \\
        &= \mathrm{softmax}\left(\begin{bmatrix} Y^t \\ y^{t+1} \end{bmatrix} K^\top\right) V \\
        &= \begin{bmatrix} \mathrm{softmax}(Y^t K^\top) \\ \mathrm{softmax}(y^{t+1} K^\top) \end{bmatrix} V \\
        &= \begin{bmatrix} \mathrm{softmax}(Y^t K^\top) V \\ \mathrm{softmax}(y^{t+1} K^\top) V \end{bmatrix} \\
        &= \begin{bmatrix} Z^t \\ z^{t+1} \end{bmatrix}
\end{aligned}
$$
The split into row blocks in the fourth line is valid because the softmax is applied to each row independently.
So the **incremental invariant** is also preserved for cross-attention when no mask is present. It also holds when the mask is static, i.e. the mask rows for positions $1..t$ do not change as the decoding step advances.

It is easy to show that the FFN part preserves the **invariant**, because it is applied to each position independently. Hence the whole (first) decoder layer preserves the **invariant**. The program output above confirms this: the first $t$ rows of the difference are zero.

Each decoder layer's output is the next layer's input, so the **invariant** also holds for the later decoder layers by induction. That is, all attention computation can be cached as long as Eq. (1) holds. This requires the self-attention mask to be causal (lower-triangular).

### Final note

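This shows that `fc1` and `fc2` in the `DecoderLayer` implementation can also be made incremental with a proper cache. At step $t+1$ they only need to be evaluated on the newly appended position, because the outputs at positions $1..t$ are unchanged. Every GEMM in a decoding step then reduces to a GEMV, which lowers the per-step computational complexity by one order.

The current [transformers decoding](https://github.com/huggingface/transformers) implementation only keeps an explicit attention (key/value) cache. Note, however, that when that cache is used, only the newest token is fed through each layer. The FFN is therefore in effect already evaluated incrementally there.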