In [1]:
import tensorflow as tf
import numpy as np


In [2]:
time_steps = 4
channels = 2

# Toy (time_steps, channels) sequence: uniform samples from [0, 9) rounded
# to the nearest whole number, so the values stay float32 but are easy to
# follow by eye.
data = tf.math.round(
    tf.random.uniform(shape=(time_steps, channels), minval=0, maxval=9)
)
data

<tf.Tensor: shape=(4, 2), dtype=float32, numpy=
array([[7., 4.],
       [7., 6.],
       [2., 8.],
       [1., 3.]], dtype=float32)>

In [3]:
# Lower-triangular mask of ones: num_lower=-1 keeps the whole lower
# triangle, num_upper=0 drops everything above the main diagonal.
lower_diag = tf.linalg.band_part(
    tf.ones((time_steps, time_steps)),
    num_lower=-1,
    num_upper=0,
)
lower_diag

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)>

In [4]:
# Divide every row of the mask by the number of ones it contains, so each
# row sums to 1.
lower_diag / tf.math.reduce_sum(lower_diag, axis=-1, keepdims=True)


<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)>

In [5]:
# Per-row count of ones in the mask, kept as a column (keepdims) so it
# broadcasts against lower_diag.
tf.math.reduce_sum(lower_diag, axis=-1, keepdims=True)

<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[1.],
       [2.],
       [3.],
       [4.]], dtype=float32)>

In [6]:
# Same row-normalised weights via softmax: entries where the mask is 1 keep
# their value, every other entry becomes -inf and therefore gets weight 0.
tf.keras.activations.softmax(
    tf.where(tf.equal(lower_diag, 1), lower_diag, float('-inf')),
    axis=1,
)

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)>

In [7]:
# Literal (2, 3, 4) integer tensor: each innermost row repeats one value
# four times, counting 1..3 in the first slab and 4..6 in the second.
tf.convert_to_tensor(
    [
        [[value] * 4 for value in (1, 2, 3)],
        [[value] * 4 for value in (4, 5, 6)],
    ]
)

<tf.Tensor: shape=(2, 3, 4), dtype=int32, numpy=
array([[[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]],

       [[4, 4, 4, 4],
        [5, 5, 5, 5],
        [6, 6, 6, 6]]])>

In [16]:
def head_mul(data, transform):
    '''
    Apply every head's matrix to every row of a batch.

    data = (batch, F)
    transform = (heads, HF, F)
    result = (batch, heads, HF)

    result[b, h, :] is transform[h] @ data[b].
    '''
    assert (tf.rank(data) == 2)
    assert (tf.rank(transform) == 3)
    assert (data.shape[1] == transform.shape[2])

    # View each row of data as a column vector: (batch, F) -> (batch, 1, F, 1).
    columns = data[:, tf.newaxis, :, tf.newaxis]

    # (heads, HF, F) @ (batch, 1, F, 1) broadcasts to (batch, heads, HF, 1).
    product = transform @ columns

    # Drop only the trailing matmul axis. A bare tf.squeeze would also remove
    # batch, heads or HF whenever one of them is 1 and break the documented
    # (batch, heads, HF) result shape.
    return tf.squeeze(product, axis=-1)


def head_mul_simple(data, transform):
    '''
    Loop-based reference for the per-head product: each output entry is the
    dot product of one row of data with one row of one head's matrix.

    data = (batch, F)
    transform = (heads, HF, F)
    result = (batch, heads, HF), returned as a float32 tensor
    '''
    assert (tf.rank(data) == 2)
    assert (tf.rank(transform) == 3)
    assert (data.shape[1] == transform.shape[2])

    n_batch, _ = data.shape
    n_heads, n_head_feats, _ = transform.shape

    out = np.zeros((n_batch, n_heads, n_head_feats))

    # Walk every (batch row, head, head feature) index in row-major order.
    for row, head, feat in np.ndindex(n_batch, n_heads, n_head_feats):
        dot = tf.tensordot(data[row, :],  transform[head, feat, :], 1)
        out[row, head, feat] = dot.numpy()

    return tf.convert_to_tensor(out, dtype=tf.float32)



In [19]:
# Sizes for the consistency check.
batch = 5
features = 4
heads = 3
head_features = 2

# Random inputs: one feature row per batch item, one matrix per head.
data = tf.random.normal(shape=(batch, features))
transform = tf.random.normal(shape=(heads, head_features, features))

# The broadcasted matmul must agree with the loop-based reference.
result = head_mul(data, transform)
tf.debugging.assert_near(
    result,
    head_mul_simple(data, transform),
)

result

<tf.Tensor: shape=(5, 3, 2), dtype=float32, numpy=
array([[[-1.3783665 , -0.24565655],
        [-2.4135888 ,  2.6722853 ],
        [ 3.0694003 ,  0.18428665]],

       [[-0.8693649 , -3.979703  ],
        [ 0.48059523,  2.8107116 ],
        [-0.10500497, -0.44696766]],

       [[ 0.22860071, -1.7451897 ],
        [-4.4174647 , -1.5649848 ],
        [ 2.5846443 , -0.6335654 ]],

       [[-1.823693  ,  0.27538848],
        [ 0.6374525 ,  3.9810205 ],
        [ 1.1162341 ,  0.6864896 ]],

       [[-0.26496747, -2.0837076 ],
        [ 0.04458502,  0.51541185],
        [-0.37498823, -0.21276651]]], dtype=float32)>