https://keras.io/api/layers/attention_layers/multi_head_attention/

In [1]:
import tensorflow as tf

In [2]:
from tensorflow import keras

In [43]:
target = tf.keras.Input(
    shape = (6, 16),
    )

source = tf.keras.Input(
    shape = (8, 16),
    )

In [44]:
multi_head_attention_layer = tf.keras.layers.MultiHeadAttention(
    num_heads = 4,
    key_dim = 10,
    )

In [45]:
output_tensor, weights = multi_head_attention_layer(
    target,
    source,
    return_attention_scores = True,
    )

In [47]:
multi_head_attention_layer

<keras.layers.attention.multi_head_attention.MultiHeadAttention at 0x24863409000>

In [48]:
model = tf.keras.Model(
    inputs = [
        target,
        source,
    ],
    outputs = [
        output_tensor,
        weights
    ]
    )

In [49]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_5 (InputLayer)           [(None, 6, 16)]      0           []                               
                                                                                                  
 input_6 (InputLayer)           [(None, 8, 16)]      0           []                               
                                                                                                  
 multi_head_attention_1 (MultiH  ((None, 6, 16),     2696        ['input_5[0][0]',                
 eadAttention)                   (None, 4, 6, 8))                 'input_6[0][0]']                
                                                                                                  
Total params: 2,696
Trainable params: 2,696
Non-trainable params: 0
________________________

# test with dat

In [50]:
import numpy as np

In [51]:
model.inputs

[<KerasTensor: shape=(None, 6, 16) dtype=float32 (created by layer 'input_5')>,
 <KerasTensor: shape=(None, 8, 16) dtype=float32 (created by layer 'input_6')>]

In [63]:
target_data = np.random.rand(
    3, 6, 16
    )
target_data.shape

(3, 6, 16)

In [64]:
source_data = np.random.rand(
    3, 8, 16
    )
source_data.shape

(3, 8, 16)

In [65]:
x = [
    target_data,
    source_data,
]

In [66]:
y = model.predict(x)



In [68]:
y[0].shape

(3, 6, 16)

In [69]:
y[1].shape

(3, 4, 6, 8)

In [71]:
y[1][0]

array([[[0.12772563, 0.12616478, 0.13036045, 0.12200306, 0.126748  ,
         0.12362163, 0.12640627, 0.11697017],
        [0.12910034, 0.12442674, 0.13401498, 0.12428208, 0.12229104,
         0.12355483, 0.12636365, 0.11596631],
        [0.12704283, 0.12630887, 0.13037072, 0.12476128, 0.12713385,
         0.12290147, 0.12376611, 0.11771494],
        [0.12777916, 0.12544581, 0.12852891, 0.12460414, 0.12698376,
         0.12330268, 0.12477107, 0.11858448],
        [0.12791145, 0.12790315, 0.127364  , 0.12279636, 0.12804875,
         0.12131897, 0.12502044, 0.11963688],
        [0.12939617, 0.1254393 , 0.12879187, 0.12205814, 0.12641588,
         0.12347578, 0.12615529, 0.1182676 ]],

       [[0.12477323, 0.12493065, 0.12220493, 0.12577006, 0.12733994,
         0.12542522, 0.12597567, 0.12358028],
        [0.12615266, 0.12568286, 0.12069038, 0.12493178, 0.1294339 ,
         0.12497088, 0.12619652, 0.12194103],
        [0.12298717, 0.12503068, 0.11959112, 0.12537137, 0.12853411,
         

In [74]:
multi_head_attention_layer = tf.keras.layers.MultiHeadAttention(
    num_heads = 2,
    key_dim = 2,
    attention_axes = (2,3),
    )

In [80]:
input_layer = tf.keras.layers.Input(
    shape = (5, 3, 4, 16),
    )

In [82]:
output_layer = multi_head_attention_layer(
    input_layer,
    input_layer,
    )

In [83]:
model = tf.keras.Model(
    inputs = input_layer,
    outputs = output_layer
    )

In [84]:
input_data = np.random.rand(
    100, 5, 3, 4, 16
    )
input_data.shape

(100, 5, 3, 4, 16)

In [85]:
y = model.predict(input_data)



# end