In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Attention

In [2]:
# Show the Keras Attention layer's init signature and docstring (IPython help escape).
?Attention

Init signature:
Attention(
    use_scale=False,
    score_mode='dot',
    dropout=0.0,
    seed=None,
    **kwargs,
)
Docstring:
Dot-product attention layer, a.k.a. Luong-style attention.

Inputs are a list with 2 or 3 elements:
1. A `query` tensor of shape `(batch_size, Tq, dim)`.
2. A `value` tensor of shape `(batch_size, Tv, dim)`.
3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none
    supplied, `value` will be used as a `key`.

The calculation follows the steps:
1. Calculate attention scores using `query` and `key` with shape
    `(batch_size, Tq, Tv)`.
2. Use sc

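### Cross-Attention in Keras

In cross-attention the queries come from one sequence while the keys and values come from another. The layer takes its inputs as `[query, value, key]`. Here `q` has `Tq = 4` rows while `k` and `v` have `Tv = 2`.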

In [3]:
# Dot-product (Luong-style) attention. The defaults are written out because the
# hand-written TensorFlow reproductions further down assume unscaled 'dot'
# scores and no dropout.
a = Attention(use_scale=False, score_mode='dot', dropout=0.0)

In [4]:
# Values: a batch of one sequence with Tv = 2 timesteps of dim = 2.
v = tf.constant([[[1, 2], [3, 4]]], dtype=tf.float32)
v.shape

I0000 00:00:1728170802.137145   86957 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1728170802.190515   86957 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1728170802.191631   86957 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1728170802.194559   86957 cuda_executor.cc:1015] successful NUMA node read from SysFS ha

TensorShape([1, 2, 2])

In [5]:
# Keys: the 2x2 identity (one batch), so each key picks out one query component.
k = tf.constant([[[1, 0], [0, 1]]], dtype=tf.float32)
k.shape

TensorShape([1, 2, 2])

In [6]:
# Queries: Tq = 4 probes. Against the identity keys, [1, 0] and [0, 1] favour
# one key each, while [1, 1] and [0, 0] score both keys equally.
q = tf.constant([[[1, 0], [0, 1], [1, 1], [0, 0]]], dtype=tf.float32)
q.shape

TensorShape([1, 4, 2])

In [7]:
# The layer expects its inputs ordered as [query, value, key]; the output has
# one row per query.
out = a([q, v, k])
out.shape

TensorShape([1, 4, 2])

In [8]:
# Attention output for the single batch element.
out.numpy()[0]

array([[1.5378828, 2.5378828],
       [2.462117 , 3.462117 ],
       [2.       , 3.       ],
       [2.       , 3.       ]], dtype=float32)

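### Cross-Attention in TensorFlow

The same result by hand. With the layer's defaults (`use_scale=False`, `score_mode='dot'`), the scores are $QK^\top$ with shape `(batch_size, Tq, Tv)`. They are softmax-normalised over the key axis, and the resulting weights are applied to $V$.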

In [9]:
# Batched dot-product scores: (batch, Tq, dim) @ (batch, dim, Tv) -> (batch, Tq, Tv).
# transpose_b transposes only the last two axes of k, so no Python loop over the
# batch is needed.
similar_logit = tf.matmul(q, k, transpose_b=True)
# Normalise over the key axis so each query's weights sum to 1.
similar_weight = tf.math.softmax(similar_logit, axis=-1)
# Weighted sum of values: (batch, Tq, Tv) @ (batch, Tv, dim) -> (batch, Tq, dim).
out = similar_weight @ v
# Inline check: the manual computation must match the Keras layer built above.
tf.debugging.assert_near(out, a([q, v, k]), rtol=1e-5, atol=1e-5)
out.numpy()

array([[[1.5378828, 2.5378828],
        [2.462117 , 3.462117 ],
        [2.       , 3.       ],
        [2.       , 3.       ]]], dtype=float32)

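### Self-Attention in Keras

Self-attention passes the same tensor as query, value and key; here that tensor is the identity-valued `k`.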

In [10]:
# Self-attention: the same tensor serves as query, value and key.
a([k] * 3).numpy()

array([[[0.73105854, 0.26894143],
        [0.26894143, 0.73105854]]], dtype=float32)

### Self-Attention in TensorFlow

In [11]:
# Self-attention by hand: `k` is query, key and value at once.
# Batched scores k @ k^T, shape (batch, Tv, Tv). This replaces the list
# comprehension that bound the same loop variable twice.
similar_logit = tf.matmul(k, k, transpose_b=True)
# Normalise over the key axis so each row of weights sums to 1.
similar_weight = tf.math.softmax(similar_logit, axis=-1)
# Weighted sum of the values (again `k`).
out = similar_weight @ k
# Inline check: the manual computation must match the Keras layer built above.
tf.debugging.assert_near(out, a([k, k, k]), rtol=1e-5, atol=1e-5)
out.numpy()

array([[[0.73105854, 0.26894143],
        [0.26894143, 0.73105854]]], dtype=float32)

### Multi-Head Attention

In [12]:
# NOTE(review): this import would sit better in the top imports cell alongside Attention.
from tensorflow.keras.layers import MultiHeadAttention
# Show the MultiHeadAttention init signature and docstring (IPython help escape).
?MultiHeadAttention

Init signature:
MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer='glorot_uniform',
    bias_initializer='zeros',
    kernel_regularizer=None,
    bias_regularizer=None,
[0;34m[0m    