In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Attention

In [2]:
# Show the Attention layer's init signature and docstring (IPython help syntax).
?Attention

[0;31mInit signature:[0m [0mAttention[0m[0;34m([0m[0muse_scale[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Dot-product attention layer, a.k.a. Luong-style attention.

Inputs are `query` tensor of shape `[batch_size, Tq, dim]`, `value` tensor of
shape `[batch_size, Tv, dim]` and `key` tensor of shape
`[batch_size, Tv, dim]`. The calculation follows the steps:

1. Calculate scores with shape `[batch_size, Tq, Tv]` as a `query`-`key` dot
   product: `scores = tf.matmul(query, key, transpose_b=True)`.
2. Use scores to calculate a distribution with shape
   `[batch_size, Tq, Tv]`: `distribution = tf.nn.softmax(scores)`.
3. Use `distribution` to create a linear combination of `value` with
   shape `[batch_size, Tq, dim]`:
   `return tf.matmul(distribution, value)`.

Args:
  use_scale: If `True`, will create a scalar variable to scale the attention
    scores.
  causal: Boolean. Set to `True` f

### Cross-Attention in Keras

In [4]:
# Plain dot-product attention. use_scale=False is the default shown in the
# signature above, spelled out here so it is clear no scaling variable is created.
a = Attention(use_scale=False)

In [28]:
# Values: a single batch holding Tv = 2 rows of dim = 2 -> shape [1, 2, 2].
v = tf.constant([[[1, 2], [3, 4]]], dtype=tf.float32)
v.shape

TensorShape([1, 2, 2])

In [53]:
# Keys: a 2x2 identity matrix inside a single batch -> shape [1, 2, 2].
# With identity keys, each query's scores are simply its own components.
k = tf.constant([[[1, 0], [0, 1]]], dtype=tf.float32)
k.shape

TensorShape([1, 2, 2])

In [54]:
# Queries: a single batch of Tq = 4 rows -> shape [1, 4, 2].
# Rows 1-2 each favour one key; rows 3-4 score both keys equally.
q = tf.constant(
    [[[1, 0], [0, 1], [1, 1], [0, 0]]],
    dtype=tf.float32,
)
q.shape

TensorShape([1, 4, 2])

In [55]:
# The layer takes its inputs as a list ordered [query, value, key].
out = a([q, v, k])
# One output row per query, each a weighted mix of the value rows:
# [batch_size, Tq, dim].
out.shape

TensorShape([1, 4, 2])

In [56]:
# Drop the batch axis to read one output row per query.
# The last two queries score both keys equally, so their softmax weights are
# uniform and the result is the plain average of the value rows.
out[0].numpy()

array([[1.5378828, 2.5378828],
       [2.462117 , 3.462117 ],
       [2.       , 3.       ],
       [2.       , 3.       ]], dtype=float32)

### Cross-Attention in TensorFlow

In [70]:
# Reproduce the Keras layer by hand, following the three steps in its docstring.
# A single batched matmul replaces the per-example Python loop: transpose_b
# swaps the last two axes of k, giving scores of shape [batch_size, Tq, Tv].
similar_logit = tf.matmul(q, k, transpose_b=True)
# Normalise over the key axis (the last one) so each query's weights sum to 1.
similar_weight = tf.math.softmax(similar_logit, axis=-1)
# Weighted sum of the value rows -> [batch_size, Tq, dim].
out = similar_weight @ v
out.numpy()

array([[[1.5378828, 2.5378828],
        [2.462117 , 3.462117 ],
        [2.       , 3.       ],
        [2.       , 3.       ]]], dtype=float32)

### Self-Attention in Keras

In [75]:
# Self-attention: the same tensor serves as query, value and key.
a([k, k, k]).numpy()

array([[[0.73105854, 0.26894143],
        [0.26894143, 0.73105854]]], dtype=float32)

### Self-Attention in TensorFlow

In [76]:
# Self-attention by hand: k plays the role of query, key and value.
# The batched matmul avoids the Python loop over the batch axis (the original
# loop also bound the same name to both loop targets).
# Scores have shape [batch_size, Tv, Tv].
similar_logit = tf.matmul(k, k, transpose_b=True)
# Normalise over the key axis (the last one) so each row of weights sums to 1.
similar_weight = tf.math.softmax(similar_logit, axis=-1)
# Weighted sum of the rows of k -> [batch_size, Tv, dim].
out = similar_weight @ k
out.numpy()

array([[[0.73105854, 0.26894143],
        [0.26894143, 0.73105854]]], dtype=float32)

### Multi-head Attention

In [80]:
from tensorflow.keras.layers import MultiHeadAttention
# Show the MultiHeadAttention init signature and docstring (IPython help syntax).
?MultiHeadAttention

[0;31mInit signature:[0m
[0mMultiHeadAttention[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_heads[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mkey_dim[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mvalue_dim[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdropout[0m[0;34m=[0m[0;36m0.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0muse_bias[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0moutput_shape[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mattention_axes[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mkernel_initializer[0m[0;34m=[0m[0;34m'glorot_uniform'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbias_initializer[0m[0;34m=[0m[0;34m'zeros'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mkernel_regularizer[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbias_regularizer[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    