## Generalized matrix product Attension adapted to use TensorFlow from 
### https://machinelearningmastery.com/the-attention-mechanism-from-scratch/

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import tensorflow as tf

In [None]:
word_0 = tf.convert_to_tensor([1, 0, 1], dtype=tf.float32)[tf.newaxis, :]

In [4]:
word_0

<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[1., 0., 1.]], dtype=float32)>

In [5]:
word_1 = tf.convert_to_tensor([0, 1, 0], dtype=tf.float32)[tf.newaxis, :]
word_2 = tf.convert_to_tensor([1, 1, 0], dtype=tf.float32)[tf.newaxis, :]
word_3 = tf.convert_to_tensor([0, 0, 1], dtype=tf.float32)[tf.newaxis, :]

In [6]:
words = tf.stack([word_0, word_1, word_2, word_3], axis=1)

In [7]:
words

<tf.Tensor: shape=(1, 4, 3), dtype=float32, numpy=
array([[[1., 0., 1.],
        [0., 1., 0.],
        [1., 1., 0.],
        [0., 0., 1.]]], dtype=float32)>

In [8]:
words = tf.concat([words, words], axis=0)

In [9]:
words

<tf.Tensor: shape=(2, 4, 3), dtype=float32, numpy=
array([[[1., 0., 1.],
        [0., 1., 0.],
        [1., 1., 0.],
        [0., 0., 1.]],

       [[1., 0., 1.],
        [0., 1., 0.],
        [1., 1., 0.],
        [0., 0., 1.]]], dtype=float32)>

In [10]:
tf.random.set_seed(42)

In [11]:
W_Q = tf.random.normal((3, 3))
W_K = tf.random.normal((3, 3))
W_V = tf.random.normal((3, 3))

In [12]:
W_Q

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[ 0.3274685, -0.8426258,  0.3194337],
       [-1.4075519, -2.3880599, -1.0392479],
       [-0.5573232,  0.539707 ,  1.6994323]], dtype=float32)>

In [13]:
words @ W_Q

<tf.Tensor: shape=(2, 4, 3), dtype=float32, numpy=
array([[[-0.2298547, -0.3029188,  2.018866 ],
        [-1.4075519, -2.3880599, -1.0392479],
        [-1.0800834, -3.2306857, -0.7198142],
        [-0.5573232,  0.539707 ,  1.6994323]],

       [[-0.2298547, -0.3029188,  2.018866 ],
        [-1.4075519, -2.3880599, -1.0392479],
        [-1.0800834, -3.2306857, -0.7198142],
        [-0.5573232,  0.539707 ,  1.6994323]]], dtype=float32)>

In [14]:
tf.tensordot(words[0][0], W_Q, 1)

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.2298547, -0.3029188,  2.018866 ], dtype=float32)>

In [15]:
tf.tensordot(words[0][1], W_Q, 1)

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-1.4075519, -2.3880599, -1.0392479], dtype=float32)>

In [16]:
Q = words @ W_Q
K = words @ W_K
V = words @ W_V

In [17]:
K.shape

TensorShape([2, 4, 3])

In [18]:
scores = Q @ tf.transpose(K, perm=(0, 2, 1))

In [19]:
scores

<tf.Tensor: shape=(2, 4, 4), dtype=float32, numpy=
array([[[ 0.22580178,  1.3982916 ,  2.403096  , -0.7790025 ],
        [ 2.4536927 ,  0.5462186 ,  2.0905945 ,  0.909317  ],
        [ 3.0849564 ,  1.1585747 ,  3.5767365 ,  0.666795  ],
        [-0.40546206,  0.7859355 ,  0.91695386, -0.5364804 ]],

       [[ 0.22580178,  1.3982916 ,  2.403096  , -0.7790025 ],
        [ 2.4536927 ,  0.5462186 ,  2.0905945 ,  0.909317  ],
        [ 3.0849564 ,  1.1585747 ,  3.5767365 ,  0.666795  ],
        [-0.40546206,  0.7859355 ,  0.91695386, -0.5364804 ]]],
      dtype=float32)>

In [20]:
Q[0][0]

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.2298547, -0.3029188,  2.018866 ], dtype=float32)>

In [21]:
K[0]

<tf.Tensor: shape=(4, 3), dtype=float32, numpy=
array([[-0.24659589, -0.8622878 , -0.04561105],
       [-0.00519627, -0.49453196,  0.6178192 ],
       [ 0.07902831, -1.3554357 ,  0.99594223],
       [-0.33082047, -0.00138408, -0.4237341 ]], dtype=float32)>

In [22]:
tf.transpose(K, perm=(0, 2, 1))

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.24659589, -0.00519627,  0.07902831, -0.33082047],
        [-0.8622878 , -0.49453196, -1.3554357 , -0.00138408],
        [-0.04561105,  0.6178192 ,  0.99594223, -0.4237341 ]],

       [[-0.24659589, -0.00519627,  0.07902831, -0.33082047],
        [-0.8622878 , -0.49453196, -1.3554357 , -0.00138408],
        [-0.04561105,  0.6178192 ,  0.99594223, -0.4237341 ]]],
      dtype=float32)>

In [23]:
tf.tensordot(Q[0][0], tf.transpose(K, perm=(0, 2, 1))[0], 1)

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([ 0.22580181,  1.3982916 ,  2.403096  , -0.7790025 ], dtype=float32)>

In [24]:
weights = tf.nn.softmax(scores, axis=1)

In [25]:
weights

<tf.Tensor: shape=(2, 4, 4), dtype=float32, numpy=
array([[[0.035387  , 0.36291677, 0.19261877, 0.08382176],
        [0.32840103, 0.15479483, 0.1409227 , 0.45350763],
        [0.61738896, 0.28556126, 0.6228797 , 0.35584316],
        [0.01882302, 0.19672711, 0.0435788 , 0.10682744]],

       [[0.035387  , 0.36291677, 0.19261877, 0.08382176],
        [0.32840103, 0.15479483, 0.1409227 , 0.45350763],
        [0.61738896, 0.28556126, 0.6228797 , 0.35584316],
        [0.01882302, 0.19672711, 0.0435788 , 0.10682744]]], dtype=float32)>

In [26]:
tf.reduce_sum(weights, axis=1)

<tf.Tensor: shape=(2, 4), dtype=float32, numpy=
array([[1.        , 0.99999994, 0.99999994, 0.99999994],
       [1.        , 0.99999994, 0.99999994, 0.99999994]], dtype=float32)>

In [27]:
attention = weights @ V

In [28]:
attention

<tf.Tensor: shape=(2, 4, 3), dtype=float32, numpy=
array([[[-0.9646175 ,  0.5957003 ,  0.11186798],
        [-0.4880892 ,  1.7590883 ,  1.0904744 ],
        [-1.8241206 ,  2.273948  ,  2.4636662 ],
        [-0.37432897,  0.44493032, -0.01641299]],

       [[-0.9646175 ,  0.5957003 ,  0.11186798],
        [-0.4880892 ,  1.7590883 ,  1.0904744 ],
        [-1.8241206 ,  2.273948  ,  2.4636662 ],
        [-0.37432897,  0.44493032, -0.01641299]]], dtype=float32)>

In [29]:
attention.shape

TensorShape([2, 4, 3])