In [1]:
import keras
import tensorflow as tf

2022-06-26 01:50:30.122068: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-06-26 01:50:30.124795: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-06-26 01:50:30.124805: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
inp = keras.Input(shape=(100,))
emb = keras.layers.Embedding(100,16)(inp)
emb

2022-06-26 01:50:31.261727: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-06-26 01:50:31.261756: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-06-26 01:50:31.261775: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (b76fc4b7beb9): /proc/driver/nvidia/version does not exist
2022-06-26 01:50:31.262023: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


<KerasTensor: shape=(None, 100, 16) dtype=float32 (created by layer 'embedding')>

In [3]:
mha = keras.layers.MultiHeadAttention(2, 64)
mha(emb, emb)

<KerasTensor: shape=(None, 100, 16) dtype=float32 (created by layer 'multi_head_attention')>

Let's look at the einsum equations the layer uses internally. These are private attributes (note the leading underscore), so they may differ between Keras versions.

In [4]:
print(mha._dot_product_equation)
print(mha._combine_equation)
print(mha._output_dense.get_config()['equation'])

aecd,abcd->acbe
acbe,aecd->abcd
abcd,cde->abe


So what's going into these einsums?

In [5]:
q = mha._query_dense(emb)
k = mha._key_dense(emb)
v = mha._value_dense(emb)

In [6]:
q

<KerasTensor: shape=(None, 100, 2, 64) dtype=float32 (created by layer 'query')>

They're all tensors of shape `(None, 100, 2, 64)`, corresponding to `(batch, sequence_len, num_heads, key_dim)`. (`value_dim` defaults to `key_dim`, so `v` has the same shape.)

Our input shape is `(None, 100, 16)`.

The first thing that happens is a call to `tf.einsum('aecd,abcd->acbe', key, query)`. So let's follow the shapes here.

a = None (batch)
b = 100 (query sequence length)
c = 2 (num_heads)
d = 64 (key_dim)
e = 100 (key sequence length)

So the output shape should be `(None, 2, 100, 100)`.
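
The same shape walk can be reproduced outside Keras with NumPy. This is a minimal sketch that assumes a concrete batch size of 3 in place of `None`; it also shows that the einsum is just a per-head `Q @ K^T` once the heads axis is moved in front of the sequence axis.

```python
import numpy as np

# Concrete stand-ins for the symbolic shapes; batch (a) is 3 instead of None.
batch, seq_len, num_heads, key_dim = 3, 100, 2, 64

rng = np.random.default_rng(0)
query = rng.standard_normal((batch, seq_len, num_heads, key_dim))  # abcd
key = rng.standard_normal((batch, seq_len, num_heads, key_dim))    # aecd

# The equation printed for mha._dot_product_equation
scores = np.einsum("aecd,abcd->acbe", key, query)
print(scores.shape)  # (3, 2, 100, 100)

# Same result as a batched Q @ K^T per head, after moving heads ahead of sequence
q_t = query.transpose(0, 2, 1, 3)   # (a, c, b, d)
k_t = key.transpose(0, 2, 1, 3)     # (a, c, e, d)
alt = q_t @ k_t.swapaxes(-1, -2)    # (a, c, b, e)
print(np.allclose(scores, alt))     # True
```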

Then, there is a call to `tf.einsum('acbe,aecd->abcd', attention, value)`. Again, let's follow the shapes.

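a = None (batch)
b = 100 (query sequence length)
c = 2 (num_heads)
d = 64 (value_dim, which defaults to key_dim)
e = 100 (key/value sequence length)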

So the output shape there should be `(None, 100, 2, 64)`.
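
Again as a NumPy sketch (batch of 3 assumed in place of `None`, random values standing in for the attention weights): the combine einsum is a per-head `weights @ V`, with the result moved back to `(batch, seq, heads, dim)` order.

```python
import numpy as np

batch, seq_len, num_heads, key_dim = 3, 100, 2, 64
rng = np.random.default_rng(1)
attention = rng.random((batch, num_heads, seq_len, seq_len))       # acbe
value = rng.standard_normal((batch, seq_len, num_heads, key_dim))  # aecd

# The equation printed for mha._combine_equation
combined = np.einsum("acbe,aecd->abcd", attention, value)
print(combined.shape)  # (3, 100, 2, 64)

# Equivalent per-head matmul, then swap heads and sequence back
v_t = value.transpose(0, 2, 1, 3)               # (a, c, e, d)
alt = (attention @ v_t).transpose(0, 2, 1, 3)   # (a, c, b, d) -> (a, b, c, d)
print(np.allclose(combined, alt))               # True
```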

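Finally, the output projection (`_output_dense`) applies `'abcd,cde->abe'` with a kernel of shape `(2, 64, 16)`, i.e. `(num_heads, key_dim, output_dim)`, which contracts the heads and per-head axes and brings us back to `(None, 100, 16)`. Let's prove to ourselves that this works.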

In [7]:
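# Shape check only: the real layer also scales q by 1/sqrt(key_dim) and applies
# a softmax over the key axis (plus dropout in training) between these einsums.
att = tf.einsum('aecd,abcd->acbe', k, q)
ctx = tf.einsum('acbe,aecd->abcd', att, v)  # named `ctx` to avoid shadowing the builtin `int`
out = mha._output_dense(ctx)
out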

<KerasTensor: shape=(None, 100, 16) dtype=float32 (created by layer 'attention_output')>
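
The cell above only checks shapes. For reference, here is a minimal NumPy sketch of the full forward pass using the same three einsum equations, including the `1/sqrt(key_dim)` query scaling and the softmax over the key axis. It is a simplification, not the Keras implementation: biases, masking and dropout are omitted, and the projection equation `"abf,fcd->abcd"` and the `mha_forward` name are our own illustrative choices.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha_forward(x, wq, wk, wv, wo):
    """Self-attention sketch.
    x: (batch, seq, model_dim); wq/wk/wv: (model_dim, heads, key_dim);
    wo: (heads, key_dim, model_dim). No biases, masks or dropout."""
    q = np.einsum("abf,fcd->abcd", x, wq)  # (batch, seq, heads, key_dim)
    k = np.einsum("abf,fcd->abcd", x, wk)
    v = np.einsum("abf,fcd->abcd", x, wv)
    q = q / np.sqrt(q.shape[-1])                      # scale queries by 1/sqrt(key_dim)
    scores = np.einsum("aecd,abcd->acbe", k, q)       # (batch, heads, query, key)
    weights = softmax(scores, axis=-1)                # normalize over the key axis (e)
    ctx = np.einsum("acbe,aecd->abcd", weights, v)    # (batch, seq, heads, key_dim)
    return np.einsum("abcd,cde->abe", ctx, wo)        # (batch, seq, model_dim)

rng = np.random.default_rng(0)
model_dim, heads, key_dim = 16, 2, 64
x = rng.standard_normal((2, 10, model_dim))
wq, wk, wv = (rng.standard_normal((model_dim, heads, key_dim)) * 0.1 for _ in range(3))
wo = rng.standard_normal((heads, key_dim, model_dim)) * 0.1
out = mha_forward(x, wq, wk, wv, wo)
print(out.shape)  # (2, 10, 16)
```

A side effect worth noting: if every position has the same embedding, the scores are all equal, the softmax is uniform, and every output row is identical. That is consistent with the identical rows in the `predict` outputs later in this notebook, where the input is all 1s.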

OK, the shapes line up. So what happens when we project the sequence length of the key and value tensors down to a fixed size (here 32) before the einsums? This is the core idea behind Linformer.

In [8]:
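# Learned projection along the sequence axis: the kernel 'se' has shape (100, 32).
# None lets EinsumDense infer the num_heads and key_dim axes from the input.
k_transformed = keras.layers.EinsumDense('bsnh,se->benh', output_shape=(32, None, None))(k)
v_transformed = keras.layers.EinsumDense('bsnh,sf->bfnh', output_shape=(32, None, None))(v)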

In [9]:
k_transformed

<KerasTensor: shape=(None, 32, 2, 64) dtype=float32 (created by layer 'einsum_dense')>
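
What the `EinsumDense` layer computes here can be sketched in NumPy: a single `(seq_len, 32)` matrix, called `E` below as in the Linformer paper (a random stand-in for the learned kernel, with no bias), mixes the 100 original positions into 32 new ones. Because the kernel's first dimension is the sequence length, the layer is tied to inputs of length 100.

```python
import numpy as np

batch, seq_len, num_heads, key_dim, proj_len = 3, 100, 2, 64, 32
rng = np.random.default_rng(2)
k = rng.standard_normal((batch, seq_len, num_heads, key_dim))
E = rng.standard_normal((seq_len, proj_len))  # stand-in for the EinsumDense kernel 'se'

k_proj = np.einsum("bsnh,se->benh", k, E)
print(k_proj.shape)  # (3, 32, 2, 64)

# Each projected position is a weighted sum over all 100 original positions
first = np.tensordot(E[:, 0], k, axes=([0], [1]))  # (batch, heads, key_dim)
print(np.allclose(k_proj[:, 0], first))  # True
```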

In [10]:
att = tf.einsum('aecd,abcd->acbe', k_transformed, q)  # (None, 2, 100, 32)
ctx = tf.einsum('acbe,aecd->abcd', att, v_transformed)  # named `ctx` to avoid shadowing `int`
out = mha._output_dense(ctx)
out

<KerasTensor: shape=(None, 100, 16) dtype=float32 (created by layer 'attention_output')>
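
The reason the output shape is unchanged is that only the key axis `e` shrinks: the score tensor becomes `(batch, heads, 100, 32)` instead of `(batch, heads, 100, 100)`, and the combine einsum contracts that axis away again. A quick NumPy check, with random arrays standing in for the projected keys and values:

```python
import numpy as np

batch, seq_len, num_heads, key_dim, proj_len = 3, 100, 2, 64, 32
rng = np.random.default_rng(3)
q = rng.standard_normal((batch, seq_len, num_heads, key_dim))
k_proj = rng.standard_normal((batch, proj_len, num_heads, key_dim))
v_proj = rng.standard_normal((batch, proj_len, num_heads, key_dim))

scores = np.einsum("aecd,abcd->acbe", k_proj, q)
print(scores.shape)  # (3, 2, 100, 32): seq_len x proj_len per head

ctx = np.einsum("acbe,aecd->abcd", scores, v_proj)
print(ctx.shape)     # (3, 100, 2, 64): same as without the projection
```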

That seems to work, so let's try it out. First, though, let's get a baseline timing for the standard `MultiHeadAttention` layer on a longer sequence.

In [11]:
inp = keras.Input(shape=(2000,))
emb = keras.layers.Embedding(100,256)(inp)
mha_out = keras.layers.MultiHeadAttention(8, 128)(emb, emb)  # output tensor, not the layer
model = keras.Model(inp, mha_out)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 2000)]       0           []                               
                                                                                                  
 embedding_1 (Embedding)        (None, 2000, 256)    25600       ['input_2[0][0]']                
                                                                                                  
 multi_head_attention_1 (MultiH  (None, 2000, 256)   1051904     ['embedding_1[0][0]',            
 eadAttention)                                                    'embedding_1[0][0]']            
                                                                                                  
Total params: 1,077,504
Trainable params: 1,077,504
Non-trainable params: 0
__________________
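
The parameter count in the summary can be reproduced by hand: query, key and value each have a `(model_dim, heads, key_dim)` kernel plus a `(heads, key_dim)` bias, and the output projection has a `(heads, value_dim, model_dim)` kernel plus a `(model_dim,)` bias. The helper below is a hypothetical sketch that assumes query, key and value all come from the same `model_dim`-wide tensor, as they do here.

```python
def mha_param_count(model_dim, num_heads, key_dim, value_dim=None, use_bias=True):
    """Parameter count for a self-attention MultiHeadAttention layer (sketch)."""
    value_dim = key_dim if value_dim is None else value_dim
    b = 1 if use_bias else 0
    q = model_dim * num_heads * key_dim + b * num_heads * key_dim
    k = model_dim * num_heads * key_dim + b * num_heads * key_dim
    v = model_dim * num_heads * value_dim + b * num_heads * value_dim
    out = num_heads * value_dim * model_dim + b * model_dim
    return q + k + v + out

print(mha_param_count(256, 8, 128))              # 1051904
print(mha_param_count(256, 8, 128) + 100 * 256)  # 1077504 with the Embedding(100, 256)
```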

In [12]:
model.predict([[1]*2000]*32)



array([[[-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        ...,
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405]],

       [[-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        [-0.00235837, -0.00011801, -0.00334762, ...,  0.00549526,
         -0.00023075,  0.00329405],
        ...,
        [-0.00235837, -0.00011801, -0.00334762, ...,  

OK, so predicting on a batch of 32 sequences of length 2000 takes about 5 to 6 seconds. Let's try the Linformer layer.
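
Some back-of-the-envelope arithmetic suggests why the baseline is slow: the score tensor alone is `(batch, heads, 2000, 2000)`. The sketch below compares it with a projected key length of 64, assuming that `projection_dim=64` in the next cell is the projected key/value length (the Linformer paper's `k`); that is our reading of the parameter, not something checked against the `linformer` package. It counts a single float32 copy of the scores, so real peak memory is higher.

```python
def score_tensor_bytes(batch, num_heads, query_len, key_len, bytes_per_elem=4):
    """Size of one (batch, heads, query_len, key_len) float32 score tensor."""
    return batch * num_heads * query_len * key_len * bytes_per_elem

full = score_tensor_bytes(32, 8, 2000, 2000)  # standard attention
proj = score_tensor_bytes(32, 8, 2000, 64)    # keys/values projected to length 64
print(full / 2**30)  # ~3.81 GiB
print(proj / 2**20)  # 125.0 MiB
print(full / proj)   # 31.25, i.e. 2000 / 64
```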

In [13]:
from linformer.attention import LinearMultiHeadAttention

lmha = LinearMultiHeadAttention(8, 128, projection_dim=64)(emb, emb)
fast_model = keras.Model(inp, lmha)
fast_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 2000)]       0           []                               
                                                                                                  
 embedding_1 (Embedding)        (None, 2000, 256)    25600       ['input_2[0][0]']                
                                                                                                  
 linear_multi_head_attention (L  (None, 2000, 256)   1179904     ['embedding_1[0][0]',            
 inearMultiHeadAttention)                                         'embedding_1[0][0]']            
                                                                                                  
Total params: 1,205,504
Trainable params: 1,205,504
Non-trainable params: 0
________________

In [14]:
fast_model.predict([[1]*2000]*25)



array([[[ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        ...,
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442]],

       [[ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        ...,
        [ 0.00037992,  0.00090282, -0.00024732, ...,  

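This reports a predict time of under 1 second for our expected maximum batch size of 25. (This batch is smaller than the 32 used for the baseline, so the comparison is not exactly like-for-like.)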

Runtime should scale roughly linearly with batch size, so we can time a prediction on a batch size of 1 and extrapolate (more or less) from there.
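
Rather than extrapolating from a single point, one option is to time a few batch sizes and fit `seconds ≈ overhead + per_item * batch_size`, which separates fixed per-call overhead from per-sequence cost. A minimal sketch; `time_call` and `fit_linear_timing` are hypothetical helpers, not part of Keras, and the commented usage shows how they might be applied to `fast_model`.

```python
import time

import numpy as np

def time_call(fn, *args, repeats=3):
    """Best-of-`repeats` wall-clock time for fn(*args), after one warm-up call."""
    fn(*args)  # warm-up (graph tracing, allocation, etc.)
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best

def fit_linear_timing(batch_sizes, seconds):
    """Least-squares fit of seconds = overhead + per_item * batch_size."""
    per_item, overhead = np.polyfit(batch_sizes, seconds, deg=1)
    return overhead, per_item

# Possible usage in this notebook:
# sizes = [1, 5, 25]
# times = [time_call(fast_model.predict, [[1] * 2000] * bs) for bs in sizes]
# overhead, per_item = fit_linear_timing(sizes, times)
# print(overhead + per_item * 32)  # estimated seconds for a batch of 32
```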

In [15]:
fast_model.predict([[1]*2000])



array([[[ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        ...,
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442],
        [ 0.00037992,  0.00090282, -0.00024732, ...,  0.00072621,
         -0.00168245, -0.00192442]]], dtype=float32)