In [1]:
import tensorflow as tf
import numpy as np


In [2]:
# Seed TensorFlow's global RNG so "Restart Kernel -> Run All" reproduces every
# random tensor in this notebook (this cell and the tf.random.normal inputs below).
tf.random.set_seed(0)

time_steps, channels = 4, 2

# Small integer-valued toy tensor of shape (time_steps, channels).
# NOTE: rounding samples from uniform [0, 9) yields the integers 0..9, with 0 and 9
# only half as likely as the interior values.
data = tf.math.round(tf.random.uniform((time_steps, channels), minval=0, maxval=9))
data

<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[4., 1.],
       [1., 3.],
       [5., 1.],
       [4., 8.]], dtype=float32)>

In [3]:
# Lower-triangular matrix of ones: row t has ones in columns 0..t and zeros after.
lower_diag = tf.linalg.band_part(
    tf.ones([time_steps, time_steps]), num_lower=-1, num_upper=0
)
lower_diag

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)>

In [4]:
# Divide each row by its number of ones -> equal weights over the visible prefix.
tf.math.divide(lower_diag, tf.reduce_sum(lower_diag, axis=-1, keepdims=True))


<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)>

In [5]:
# Count of ones in each row (the divisor used in the cell above), kept as a column.
tf.math.reduce_sum(lower_diag, axis=-1, keepdims=True)

<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[1.],
       [2.],
       [3.],
       [4.]], dtype=float32)>

In [6]:
# Same weights via softmax: zeros of the mask become -inf (probability 0), and the
# remaining equal logits share the probability mass evenly.
tf.keras.activations.softmax(
    tf.where(tf.equal(lower_diag, 1), lower_diag, float('-inf')),
    axis=1,
)

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)>

In [7]:
# A (2, 3, 4) int32 tensor whose r-th row (counting across both outer blocks) is
# filled with r + 1:  [[1 1 1 1], [2 2 2 2], [3 3 3 3]] and [[4 ...], [5 ...], [6 ...]].
tf.tile(tf.reshape(tf.range(1, 7), (2, 3, 1)), (1, 1, 4))

<tf.Tensor: shape=(2, 3, 4), dtype=int32, numpy=
array([[[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]],

       [[4, 4, 4, 4],
        [5, 5, 5, 5],
        [6, 6, 6, 6]]])>

In [8]:
def head_mul(data, transform):
    '''
    Apply every head's projection matrix to every feature vector.

    data: (batch, F, FD)
    transform: (heads, FD, HFD)
    result: (batch, heads, F, HFD)
    '''
    assert tf.rank(data) == 3
    assert tf.rank(transform) == 3
    assert data.shape[2] == transform.shape[1]

    # A singleton heads axis lets the batched matmul broadcast:
    # (batch, 1, F, FD) @ (heads, FD, HFD) -> (batch, heads, F, HFD).
    data_per_head = tf.expand_dims(data, axis=1)
    return data_per_head @ transform


def head_mul_simple(data, transform):
    '''
    Loop-based reference for head_mul: one explicit dot product per output element.

    data: (batch, F, FD)
    transform: (heads, FD, HFD)
    result: (batch, heads, F, HFD), returned as a float32 tensor
    '''
    assert tf.rank(data) == 3
    assert tf.rank(transform) == 3
    assert data.shape[2] == transform.shape[1]

    n_batch, n_feat, _ = data.shape
    n_heads, _, n_out = transform.shape

    out = np.zeros((n_batch, n_heads, n_feat, n_out))

    # np.ndindex walks the indices in C order, exactly like four nested range loops.
    # out[b, h, f, o] = <data[b, f, :], transform[h, :, o]>
    for b, h, f, o in np.ndindex(n_batch, n_heads, n_feat, n_out):
        out[b, h, f, o] = tf.tensordot(data[b, f, :], transform[h, :, o], 1).numpy()

    return tf.convert_to_tensor(out, dtype=tf.float32)


In [9]:
# Verify the broadcasted matmul in head_mul against the explicit loop version.
batch, F, FD, heads, HFD = 8, 5, 4, 3, 2

data = tf.random.normal(shape=(batch, F, FD))
transform = tf.random.normal(shape=(heads, FD, HFD))

result = head_mul(data, transform)
tf.debugging.assert_near(result, head_mul_simple(data, transform))


In [10]:
# One shared input sequence and independent (heads, FD, HFD) projections for K, V, Q.
batch, F, FD, heads, HFD = 8, 5, 4, 3, 2

data = tf.random.normal(shape=(batch, F, FD))

# The generator is consumed left to right, so the draws happen in the order
# key, value, query.
key_transform, value_transform, query_transform = (
    tf.random.normal(shape=(heads, FD, HFD)) for _ in range(3)
)

key, value, query = (
    head_mul(data, projection)
    for projection in (key_transform, value_transform, query_transform)
)



In [11]:
def attention(query, key):
    '''
    Unscaled dot-product scores between every query and every key position.

    query: (batch, heads, F_q, FD)
    key: (batch, heads, F_k, FD)   -- F_k may differ from F_q (e.g. cross-attention)
    result: (batch, heads, F_q, F_k),
            result[b, h, q, k] = <query[b, h, q, :], key[b, h, k, :]>

    No 1/sqrt(FD) scaling and no masking are applied here.
    '''
    assert tf.rank(query) == 4
    assert tf.rank(key) == 4
    # Batch, heads and feature dims must agree; the sequence lengths need not.
    assert query.shape[:2] == key.shape[:2]
    assert query.shape[3] == key.shape[3]

    # transpose_b=True multiplies by key with its last two axes swapped.
    return tf.matmul(query, key, transpose_b=True)


def attention_simple(query, key):
    '''
    Loop-based reference for attention: one explicit dot product per score.

    query: (batch, heads, F_q, FD)
    key: (batch, heads, F_k, FD)   -- F_k may differ from F_q
    result: (batch, heads, F_q, F_k), returned as a float32 tensor
    '''
    assert tf.rank(query) == 4
    assert tf.rank(key) == 4
    # Same shape contract as attention(): only the sequence lengths may differ.
    assert query.shape[:2] == key.shape[:2]
    assert query.shape[3] == key.shape[3]

    B, H, F_q, _ = query.shape
    F_k = key.shape[2]
    result = np.zeros((B, H, F_q, F_k))

    for b in range(B):
        for h in range(H):
            for q in range(F_q):
                for k in range(F_k):
                    result[b, h, q, k] = tf.tensordot(query[b, h, q, :], key[b, h, k, :], 1).numpy()

    return tf.convert_to_tensor(result, dtype=tf.float32)



In [12]:
# Raw attention scores, checked against the loop-based reference.
att_scores = attention(query, key)
tf.debugging.assert_near(att_scores, attention_simple(query, key))

print(att_scores[0, 0, :, :].numpy())

# Lower-triangular boolean mask: query position q may only see key positions 0..q.
# Sized from the scores themselves (not a global F from another cell) so it always matches.
n_q, n_k = att_scores.shape[-2], att_scores.shape[-1]
lower_diag_mask = tf.linalg.band_part(tf.ones((n_q, n_k), dtype=tf.bool), -1, 0)

print(lower_diag_mask.numpy())

# Hidden positions get -inf so the softmax gives them zero weight.
att_masked = tf.where(lower_diag_mask, att_scores, float('-inf'))

print(att_masked[0, 0, :, :].numpy())

# Final attention weights; kept under the name `att` for later cells.
att = tf.keras.activations.softmax(att_masked, axis=3)

# Every row keeps at least column 0 unmasked, so each row of weights must sum to 1.
tf.debugging.assert_near(tf.reduce_sum(att, axis=3), tf.ones_like(att[..., 0]), atol=1e-5)

print(att[0, 0, :, :].numpy())

[[ -1.0842866    0.20300126   0.08299351   0.91089964   2.7099497 ]
 [ -3.238945    -4.5375476  -12.838415     3.9008617   -4.7610292 ]
 [ -8.057059    -9.966704   -28.576384     9.400688    -8.54254   ]
 [  1.7756287    1.2854252    3.9799676   -1.862774    -0.39434886]
 [ -2.5942352   -3.2000687   -9.17811      3.0247874   -2.7279596 ]]
[[ True False False False False]
 [ True  True False False False]
 [ True  True  True False False]
 [ True  True  True  True False]
 [ True  True  True  True  True]]
[[ -1.0842866        -inf        -inf        -inf        -inf]
 [ -3.238945   -4.5375476        -inf        -inf        -inf]
 [ -8.057059   -9.966704  -28.576384         -inf        -inf]
 [  1.7756287   1.2854252   3.9799676  -1.862774         -inf]
 [ -2.5942352  -3.2000687  -9.17811     3.0247874  -2.7279596]]
[[1.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [7.8559971e-01 2.1440029e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00]
 [8.7097931e-01 1.2902074e-01 1