In [1]:
import tensorflow as tf
import numpy as np

# Symbolic input whose per-sample shape is (1, 3).
x = tf.keras.Input(shape=(1, 3))

# One attention head, key_dim of 2, and no bias terms so the layer only
# carries projection kernels.
layer = tf.keras.layers.MultiHeadAttention(
    num_heads=1,
    key_dim=2,
    use_bias=False,
)

# Self-attention: the same tensor is passed as both query and value.
output_tensor = layer(x, value=x)
print(output_tensor.shape)

(None, 1, 3)


In [2]:
# Pull the layer's current parameters and report how many there are.
weights = layer.get_weights()
num_weights = len(weights)
print(num_weights)

4


In [3]:
# Show the shape of each of the first four entries returned by get_weights().
for idx in range(4):
    print(weights[idx].shape)

(3, 1, 2)
(3, 1, 2)
(3, 1, 2)
(1, 2, 3)


In [4]:
# Hand-picked kernels. q, k and v are written as 3x2 matrices and then given
# a singleton middle axis, yielding shape (3, 1, 2).
q = np.array([[ 0.4,  0.3],
              [-0.1, -0.1],
              [ 0.2, -0.1]]).reshape(3, 1, 2)
k = np.array([[ 0.1,  0.2],
              [-0.3, -0.4],
              [-0.1,  0.2]]).reshape(3, 1, 2)
v = np.array([[-0.2,  0.1],
              [-0.4,  0.2],
              [ 0.4, -0.6]]).reshape(3, 1, 2)

# o is a 2x3 matrix given a singleton leading axis, yielding shape (1, 2, 3).
o = np.array([[0.1, -0.1, 0.6],
              [0.9,  0.3, 0.1]]).reshape(1, 2, 3)

In [5]:
# Replace the layer's parameters with the hand-written arrays, in the order
# q, k, v, o. Their shapes match what get_weights() reported earlier:
# three (3, 1, 2) arrays followed by one (1, 2, 3) array.
manual_kernels = [q, k, v, o]
layer.set_weights(manual_kernels)

In [6]:
# Read the parameters back from the layer and print the first four so they
# can be compared with the arrays that were just assigned.
weights = layer.get_weights()
for idx in range(4):
    print(weights[idx])

[[[ 0.4  0.3]]

 [[-0.1 -0.1]]

 [[ 0.2 -0.1]]]
[[[ 0.1  0.2]]

 [[-0.3 -0.4]]

 [[-0.1  0.2]]]
[[[-0.2  0.1]]

 [[-0.4  0.2]]

 [[ 0.4 -0.6]]]
[[[ 0.1 -0.1  0.6]
  [ 0.9  0.3  0.1]]]


In [7]:
# A single three-component vector, written directly with shape (1, 1, 3).
data = np.array([[[1., 3., 2.]]])
print(data.shape)
print(data)

(1, 1, 3)
[[[1. 3. 2.]]]


In [8]:
# Run self-attention on the concrete input and also request the attention
# scores. NOTE: `weights` is rebound here - from this point on it holds the
# attention scores, not the list returned by layer.get_weights().
output_tensor, weights = layer(
    data,
    value=data,
    return_attention_scores=True,
)
for result in (output_tensor, weights):
    print(result.shape)

(1, 1, 3)
(1, 1, 1, 1)


In [9]:
# Display the layer output followed by the attention scores.
for result in (output_tensor, weights):
    print(result)

tf.Tensor([[[-0.51      -0.09      -0.4100001]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[[[1.]]]], shape=(1, 1, 1, 1), dtype=float32)


## Verify: reproduce the layer output by hand with NumPy

In [10]:
# The same q, k and v values as before, written directly as 3x2 matrices so
# they can be used in ordinary matrix products.
q = np.array([[ 0.4,  0.3],
              [-0.1, -0.1],
              [ 0.2, -0.1]])
k = np.array([[ 0.1,  0.2],
              [-0.3, -0.4],
              [-0.1,  0.2]])
v = np.array([[-0.2,  0.1],
              [-0.4,  0.2],
              [ 0.4, -0.6]])

for kernel in (q, k, v):
    print(kernel)

[[ 0.4  0.3]
 [-0.1 -0.1]
 [ 0.2 -0.1]]
[[ 0.1  0.2]
 [-0.3 -0.4]
 [-0.1  0.2]]
[[-0.2  0.1]
 [-0.4  0.2]
 [ 0.4 -0.6]]


In [11]:
# The same input vector as a single row, shape (1, 3).
data = np.array([[1., 3., 2.]])
print(data.shape)
print(data)

(1, 3)
[[1. 3. 2.]]


In [12]:
# Multiply the input by each of the q, k and v matrices.
q_value = data @ q
k_value = data @ k
v_value = data @ v

for projected in (q_value, k_value, v_value):
    print(projected)

[[ 0.5 -0.2]]
[[-1.  -0.6]]
[[-0.6 -0.5]]


In [29]:
def softmax(x, axis=None):
    """Return the softmax of ``x``.

    Parameters
    ----------
    x : array_like
        Input scores (a list, a scalar or an ndarray).
    axis : int or None, optional
        Axis along which to normalise. With the default ``None`` the
        softmax is taken over all elements of ``x`` at once, which is the
        original behaviour of this function. Pass ``axis=-1`` to normalise
        each row of a (queries, keys) score matrix independently.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is ``None``).
    """
    x = np.asarray(x)
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing on large scores.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=axis, keepdims=True)

# Scaled dot-product attention scores: query . key^T / sqrt(key_dim).
# q_value and k_value are row vectors produced by `data @ q` / `data @ k`,
# so the key has to be transposed before the product; multiplying the two
# row vectors directly raises a shape-mismatch error when `data` is 2-D.
key_dim = q_value.shape[-1]
scores = np.dot(q_value, k_value.T) / np.sqrt(key_dim)

# With one query and one key there is a single score, so its softmax is 1.
alpha = softmax(scores)
print(alpha)

[1.]


In [30]:
# Weight the projected value by the attention weight (element-wise product
# with broadcasting).
context_vector = np.multiply(alpha, v_value)
print(context_vector)

[-0.6 -0.5]


In [31]:
# The same o values as before, written directly as a 2x3 matrix.
o = np.array([[0.1, -0.1, 0.6],
              [0.9,  0.3, 0.1]])
print(o)

[[ 0.1 -0.1  0.6]
 [ 0.9  0.3  0.1]]


In [32]:
# Multiply the context vector by the 2x3 matrix o to obtain the
# three-component result that is compared with the layer's output.
output = context_vector @ o
print(output)

[-0.51 -0.09 -0.41]


In [6]:
import tensorflow as tf
import numpy as np

# No weights are set by hand in this cell, so the layer runs with whatever
# parameters it was initialised with. Without a fixed seed the printed
# output changes from run to run; seeding makes the cell reproducible.
tf.keras.utils.set_random_seed(0)

data   = np.array([1., 3., 2.]).reshape((1, 1, 3))
layer  = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=2)

# Self-attention: the same array is used as query, value and key.
output = layer(query=data, value=data, key=data)
print(output)

tf.Tensor([[[ 2.5561461 -0.7888286  0.6175448]]], shape=(1, 1, 3), dtype=float32)


tf.Tensor([[[ 1.1402409 -1.725717  -1.1650776]]], shape=(1, 1, 3), dtype=float32)
