In [1]:
import tensorflow as tf
import numpy as np

In [8]:
def head_mul(data, transform):
    '''
    Apply every head's linear projection to every feature vector of every sample.

    data: batch, F, FD
    transform: heads, FD, HFD
    result: (batch, heads, F, HFD)
    '''
    assert tf.rank(data) == 3
    assert tf.rank(transform) == 3
    # With both ranks fixed at 3, the last axis of data is its feature dim and
    # the second-to-last axis of transform is the projection's input dim.
    assert data.shape[-1] == transform.shape[-2]

    # Insert a singleton heads axis -> (batch, 1, F, FD); matmul then broadcasts
    # the leading (batch, 1) axes against transform's leading (heads,) axis.
    data_per_head = tf.expand_dims(data, axis=1)
    return tf.matmul(data_per_head, transform)


def head_mul_simple(data, transform):
    '''
    Loop-based reference for head_mul: one explicit dot product per output element.

    data: batch, F, FD
    transform: heads, FD, HFD
    result: (batch, heads, F, HFD)
    '''
    assert tf.rank(data) == 3
    assert tf.rank(transform) == 3
    assert data.shape[2] == transform.shape[1]

    n_batch, n_feat, _ = data.shape
    n_heads, _, head_dim = transform.shape

    out = np.zeros((n_batch, n_heads, n_feat, head_dim))

    # np.ndindex yields index tuples in C order (last index fastest), the same
    # visiting order as four nested range() loops.
    for b, h, f, d in np.ndindex(n_batch, n_heads, n_feat, head_dim):
        out[b, h, f, d] = tf.tensordot(data[b, f, :], transform[h, :, d], 1).numpy()

    return tf.convert_to_tensor(out, dtype=tf.float32)


In [9]:
batch, F, FD = 8, 5, 4
heads, HFD = 3, 2

data = tf.random.normal((batch, F, FD))
transform = tf.random.normal((heads,  FD, HFD))

result = head_mul(data, transform)

# assert_near only requires its second argument to be broadcastable to the
# first, so it does not catch a wrong-but-broadcastable shape; pin the shape.
assert result.shape == (batch, heads, F, HFD)
# The vectorised implementation must agree with the naive loop reference.
tf.debugging.assert_near(result, head_mul_simple(data, transform))


In [10]:
# Fixed seed so the tensors printed in the following cells are reproducible
# under Restart Kernel -> Run All.
tf.random.set_seed(0)

batch, F, FD = 8, 5, 4
heads, HFD = 3, 2

data = tf.random.normal((batch, F, FD))
# One projection per role, each (heads, FD, HFD).
key_transform = tf.random.normal((heads,  FD, HFD))
value_transform = tf.random.normal((heads,  FD, HFD))
query_transform = tf.random.normal((heads,  FD, HFD))

key = head_mul(data, key_transform)
value = head_mul(data, value_transform)
query = head_mul(data, query_transform)

# Each projection is (batch, heads, F, HFD).
assert key.shape == value.shape == query.shape == (batch, heads, F, HFD)

key.shape

TensorShape([8, 3, 5, 2])

In [11]:
def attention(query, key):
    '''
    Raw (unscaled) dot-product attention scores for every batch element and head.

    query: (batch, heads, F_q, FD)
    key: (batch, heads, F_k, FD)
    result: (batch, heads, F_q (query), F_k (key))

    F_q and F_k may differ (e.g. cross-attention); batch, heads and FD must
    match. No 1/sqrt(FD) scaling is applied.
    '''
    assert (tf.rank(query) == 4)
    assert (tf.rank(key) == 4)
    # Every axis except the sequence axis (2) must agree.
    assert (query.shape[:2] == key.shape[:2])
    assert (query.shape[3] == key.shape[3])

    # transpose_b makes matmul contract over the last axis of both inputs,
    # i.e. query @ key^T, without a separate tf.transpose call.
    return tf.matmul(query, key, transpose_b=True)


def attention_simple(query, key):
    '''
    Loop-based reference for attention: one explicit dot product per
    (query position, key position) pair.

    query: (batch, heads, F_q, FD)
    key: (batch, heads, F_k, FD)
    result: (batch, heads, F_q (query), F_k (key))

    F_q and F_k may differ; batch, heads and FD must match.
    '''
    assert (tf.rank(query) == 4)
    assert (tf.rank(key) == 4)
    # Every axis except the sequence axis (2) must agree.
    assert (query.shape[:2] == key.shape[:2])
    assert (query.shape[3] == key.shape[3])

    B, H, F_q, _ = query.shape
    # Key length comes from key itself, not from query.
    F_k = key.shape[2]
    result = np.zeros((B, H, F_q, F_k))

    for b in range(B):
        for h in range(H):
            for q in range(F_q):
                for k in range(F_k):
                    result[b, h, q, k] = tf.tensordot(query[b, h, q, :], key[b, h, k, :], 1).numpy()

    return tf.convert_to_tensor(result, dtype=tf.float32)

# Scores of the first batch element and first head: (F (query), F (key)).
attention_simple(query, key)[0, 0]

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[ 5.8921366 , -8.357531  ,  6.0643506 ,  5.5240784 ,  7.334341  ],
       [-5.014019  ,  8.240324  , -4.897439  , -4.9543586 , -7.053102  ],
       [ 2.7879987 , -2.0333219 ,  3.3175197 ,  2.1821284 ,  2.0881238 ],
       [-0.5884373 ,  2.2399423 , -0.27792072, -0.86745894, -1.7435443 ],
       [ 1.438715  , -1.8663764 ,  1.5214187 ,  1.3096718 ,  1.6654413 ]],
      dtype=float32)>

In [12]:
att_scores = attention(query, key)

# The vectorised scores must match the naive loop reference.
tf.debugging.assert_near(att_scores, attention_simple(query, key))

print('attention')
print(att_scores[0, 0, :, :].numpy())

# Causal mask: query position q may attend only to key positions k <= q.
# Sized from the score tensor rather than the notebook-global F, so this
# cell stays correct even if F is rebound by another cell.
n_query, n_key = att_scores.shape[-2:]
lower_diag_mask = tf.linalg.band_part(tf.ones((n_query, n_key), dtype=tf.bool), -1, 0)

print('\nlower diagonal mask')
print(lower_diag_mask.numpy())

# Masked-out scores become -inf so they contribute exactly 0 after softmax.
att_masked = tf.where(lower_diag_mask, att_scores, float('-inf'))

print('\nmasked attention')
print(att_masked[0, 0, :, :].numpy())

# Normalise over the key axis so every query row is a probability distribution.
att = tf.keras.activations.softmax(att_masked, axis = 3)

print('\nattention normalised with softmax')
print(att[0, 0, :, :].numpy())


print('\nsum across rows (should be 1)')
print(tf.reduce_sum(att[0, 0, :, :], 1).numpy())

# Enforce the invariant for every batch element and head, not just by eye.
tf.debugging.assert_near(tf.reduce_sum(att, axis=-1), tf.ones_like(att[..., 0]))

attention
[[ 5.8921366  -8.357531    6.0643506   5.5240784   7.334341  ]
 [-5.014019    8.240324   -4.897439   -4.9543586  -7.053102  ]
 [ 2.7879987  -2.0333219   3.3175197   2.1821284   2.0881238 ]
 [-0.5884373   2.2399423  -0.27792072 -0.86745894 -1.7435443 ]
 [ 1.438715   -1.8663764   1.5214187   1.3096718   1.6654413 ]]

lower diagional mask
[[ True False False False False]
 [ True  True False False False]
 [ True  True  True False False]
 [ True  True  True  True False]
 [ True  True  True  True  True]]

masked attention
[[ 5.8921366         -inf        -inf        -inf        -inf]
 [-5.014019    8.240324          -inf        -inf        -inf]
 [ 2.7879987  -2.0333219   3.3175197         -inf        -inf]
 [-0.5884373   2.2399423  -0.27792072 -0.86745894        -inf]
 [ 1.438715   -1.8663764   1.5214187   1.3096718   1.6654413 ]]

attention normalised with softmax
[[1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [1.7527145e-06 9.9999821e-01 0.0000000e+00 

In [13]:
print('att')
print(att[0, 0, :, :].numpy())

print('\nvalue')
print(value[0, 0, :, :].numpy())

# Attention-weighted sum of value vectors, per head: (batch, heads, F, HFD).
res_heads = att @ value

print('\nres = att @ value')
print(res_heads[0, 0, :, :].numpy())

print('\nfeature 0 dims for all heads')
print(res_heads[0, :, 0, :].numpy())

# Concatenate the heads for each feature:
# (batch, heads, F, HFD) -> (batch, F, heads, HFD) -> (batch, F, heads * HFD).
# Distinct local names keep the notebook-level F / FD (input feature dim)
# from being silently rebound to the per-head dim.
res_by_feature = tf.transpose(res_heads, (0, 2, 1, 3))
n_batch, n_feat, n_heads, head_dim = res_by_feature.shape
res = tf.reshape(res_by_feature, (n_batch, n_feat, n_heads * head_dim))

print('\nfeature 0 dims for all heads stacked')
print(res[0, 0, :].numpy())

att
[[1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [1.7527145e-06 9.9999821e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [3.6952525e-01 2.9769482e-03 6.2749779e-01 0.0000000e+00 0.0000000e+00]
 [4.9903493e-02 8.4426850e-01 6.8074830e-02 3.7753161e-02 0.0000000e+00]
 [2.3494373e-01 8.6214626e-03 2.5520056e-01 2.0650052e-01 2.9473373e-01]]

value
[[-0.15064985  0.65298265]
 [-1.7147651  -3.2876022 ]
 [-0.8068951   1.0620738 ]
 [-0.07752123  2.2077055 ]
 [ 0.94131666  3.1736104 ]]

res = att @ value
[[-0.15064985  0.65298265]
 [-1.7147622  -3.2875953 ]
 [-0.56709856  0.89795554]
 [-1.5130961  -2.5873847 ]
 [ 0.00533149  1.7873744 ]]

feature 0 dims for all heads
[[-0.15064985  0.65298265]
 [-0.16928649 -0.9094292 ]
 [ 1.8130186   2.3197675 ]]

feature 0 dims for all heads stacked
[-0.15064985  0.65298265 -0.16928649 -0.9094292   1.8130186   2.3197675 ]
