In [94]:
from tensorflow.keras.layers import MultiHeadAttention
import tensorflow as tf

### Reference: torch.nn.MultiheadAttention (the PyTorch counterpart of this Keras layer)

In [95]:
# Build the attention layer. The per-head projection width is configured
# separately for query/key (key_dim) and for value (value_dim).
layer = MultiHeadAttention(
    num_heads=5,   # number of attention heads
    key_dim=200,   # per-head size of the projected query and key
    value_dim=50,  # per-head size of the projected value
    dropout=0.1,   # dropout probability
)

# Boolean attention mask. Target rows 0-3 may attend to the first 2 source
# positions, target rows 4-7 to the first 3.
mask = tf.sequence_mask([2, 2, 2, 2, 3, 3, 3, 3], maxlen=4)  # (T, S) = (8, 4)
# Replicate the same (T, S) mask for each of the 10 batch items -> (B, T, S).
mask = tf.tile(mask[tf.newaxis, ...], [10, 1, 1])

# Random inputs; the last dimension of query, key and value differs on purpose.
# NOTE: all three tensors are drawn with the same stateless seed.
query = tf.random.stateless_uniform(seed=(1, 1), shape=(10, 8, 40))  # (B, T, dim_q)
key = tf.random.stateless_uniform(seed=(1, 1), shape=(10, 4, 20))    # (B, S, dim_k)
value = tf.random.stateless_uniform(seed=(1, 1), shape=(10, 4, 30))  # (B, S, dim_v)

attention_output, attention_scores = layer(
    query=query,
    value=value,
    # Optional: when omitted, `value` serves as both key and value
    # (the most common case).
    key=key,
    # Shape (B, T, S); 1 means "attend", 0 means "do not attend".
    # This is the opposite of the PyTorch convention.
    attention_mask=mask,
    # Defaults to False; enable it to also get the attention weights back.
    return_attention_scores=True,
    # Run the layer in training mode rather than inference mode.
    training=True,
)

# attention_output has shape (B, T, E), where T is the target sequence length
# and E is the last dimension of the query input.
print(attention_output.shape)

# attention_scores has shape (B, num_heads, T, S).
print(attention_scores.shape)

(10, 8, 40)
(10, 5, 8, 4)


In [96]:
# For batch item 0, show the boolean mask followed by head 0's attention
# weights so the two (T, S) grids can be compared row by row.
print(mask[0], attention_scores[0, 0], sep="\n")

tf.Tensor(
[[ True  True False False]
 [ True  True False False]
 [ True  True False False]
 [ True  True False False]
 [ True  True  True False]
 [ True  True  True False]
 [ True  True  True False]
 [ True  True  True False]], shape=(8, 4), dtype=bool)
tf.Tensor(
[[0.49942783 0.50057214 0.         0.        ]
 [0.49945465 0.5005454  0.         0.        ]
 [0.5003021  0.49969792 0.         0.        ]
 [0.50053257 0.49946746 0.         0.        ]
 [0.3331651  0.33336708 0.33346784 0.        ]
 [0.33350733 0.33324602 0.33324662 0.        ]
 [0.33371982 0.33332738 0.33295283 0.        ]
 [0.3335802  0.33322385 0.33319595 0.        ]], shape=(8, 4), dtype=float32)


In [97]:
# List each weight variable of the attention layer together with its shape.
for weight in layer.weights:
    print(weight.name, weight.shape)


multi_head_attention_25/query/kernel:0 (40, 5, 200)
multi_head_attention_25/query/bias:0 (5, 200)
multi_head_attention_25/key/kernel:0 (20, 5, 200)
multi_head_attention_25/key/bias:0 (5, 200)
multi_head_attention_25/value/kernel:0 (30, 5, 50)
multi_head_attention_25/value/bias:0 (5, 50)
multi_head_attention_25/attention_output/kernel:0 (5, 50, 40)
multi_head_attention_25/attention_output/bias:0 (40,)
