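# Self-Attention from Scratch in TensorFlow

A step-by-step walkthrough of single-head self-attention on three toy inputs: project the inputs to keys, queries, and values, score every query against every key, normalize the scores with a softmax, and use them to take a weighted sum of the values.

<img src="https://github.com/martin-fabbri/colab-notebooks/raw/master/bert/images/attention-zoom-in.png" width="1000" alt="Big picture of self-attention"/>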

In [1]:
import tensorflow as tf

## 1. Inputs

Three inputs, each a 4-dimensional vector (one row of `x` per input).

In [2]:
# Three toy inputs, one per row, each with 4 features.
x = tf.constant(
    [
        [1.0, 0.0, 1.0, 0.0],  # input 1
        [0.0, 2.0, 0.0, 2.0],  # input 2
        [1.0, 1.0, 1.0, 1.0],  # input 3
    ]
)

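## 2. Initialize the key, query, and value weights

Each weight matrix is $4 \times 3$: it projects a 4-dimensional input to a 3-dimensional key, query, or value. Multiplying the inputs by each matrix yields one key, one query, and one value per input.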

In [3]:
# Projection weights, each 4 x 3: they map a 4-feature input row to a
# 3-dimensional key, query, or value vector.
w_key = tf.constant([
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0],
])
w_query = tf.constant([
    [1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
])
w_value = tf.constant([
    [0.0, 2.0, 0.0],
    [0.0, 3.0, 0.0],
    [1.0, 0.0, 3.0],
    [1.0, 1.0, 0.0],
])

In [4]:
# One key per input: (3 x 4) inputs times (4 x 3) weights -> (3 x 3) keys.
keys = tf.matmul(x, w_key)
keys

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [4., 4., 0.],
       [2., 3., 1.]], dtype=float32)>

In [5]:
# One query per input, projected with the query weights -> (3 x 3) queries.
queries = tf.matmul(x, w_query)
queries

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 2.],
       [2., 2., 2.],
       [2., 1., 3.]], dtype=float32)>

In [6]:
# One value per input, projected with the value weights -> (3 x 3) values.
values = tf.matmul(x, w_value)
values

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 2., 3.],
       [2., 8., 0.],
       [2., 6., 3.]], dtype=float32)>

<img src="https://github.com/martin-fabbri/colab-notebooks/raw/master/bert/images/attention-nn.png" alt="self-attention block" width="800">

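## 3. Calculate attention scores

Each score is the dot product of one query with one key, so the full score matrix is $QK^\top$: entry $(i, j)$ measures how strongly input $i$ attends to input $j$. (The Transformer's scaled dot-product attention also divides the scores by $\sqrt{d_k}$; this walkthrough omits that scaling, so the scores below are unscaled.)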

<img src="https://github.com/martin-fabbri/colab-notebooks/raw/master/bert/images/multi-head-attention.png" alt="multihead-attention" width="700px">

In [15]:
# Raw dot-product scores: entry [i, j] = queries[i] . keys[j]  (Q K^T).
# Note: scaled dot-product attention would divide by sqrt(d_k) here;
# this notebook does not apply that scaling.
attention_scores = tf.matmul(a=queries, b=keys, transpose_b=True)
attention_scores

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[ 2.,  4.,  4.],
       [ 4., 16., 12.],
       [ 4., 12., 10.]], dtype=float32)>

## 4. Softmax

Apply a softmax across each row, so that every query's attention weights are positive and sum to 1.

In [16]:
# Normalize each row (one query's scores against every key) into
# weights that are positive and sum to 1; axis=-1 is the last (key) axis.
attention_scores_softmax = tf.nn.softmax(attention_scores, axis=-1)
attention_scores_softmax

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[6.3378938e-02, 4.6831051e-01, 4.6831051e-01],
       [6.0336647e-06, 9.8200780e-01, 1.7986100e-02],
       [2.9538720e-04, 8.8053685e-01, 1.1916770e-01]], dtype=float32)>

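## 5. Multiply scores with values and sum

Scale each value vector by its attention weight, then sum over the inputs: output $i$ is $\sum_j \mathrm{softmax}(QK^\top)_{ij}\, v_j$.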

In [17]:
# Scale every value vector by every attention weight, without summing yet:
#   weighted_values[j, i, :] = attention_scores_softmax[i, j] * values[j, :]
# Slice j holds value j, weighted by how strongly each query i attends to it.
# expand_dims(values, 1) has shape (3, 1, 3); the transposed weights expanded
# on axis 2 have shape (3, 3, 1); broadcasting gives (3, 3, 3).
weighted_values = tf.expand_dims(values, axis=1) * tf.expand_dims(
    tf.transpose(attention_scores_softmax), axis=2
)
weighted_values

<tf.Tensor: shape=(3, 3, 3), dtype=float32, numpy=
array([[[6.3378938e-02, 1.2675788e-01, 1.9013682e-01],
        [6.0336647e-06, 1.2067329e-05, 1.8100995e-05],
        [2.9538720e-04, 5.9077441e-04, 8.8616158e-04]],

       [[9.3662101e-01, 3.7464840e+00, 0.0000000e+00],
        [1.9640156e+00, 7.8560624e+00, 0.0000000e+00],
        [1.7610737e+00, 7.0442948e+00, 0.0000000e+00]],

       [[9.3662101e-01, 2.8098631e+00, 1.4049315e+00],
        [3.5972200e-02, 1.0791660e-01, 5.3958301e-02],
        [2.3833540e-01, 7.1500623e-01, 3.5750312e-01]]], dtype=float32)>

In [19]:
# Collapse the input axis j:
#   outputs[i] = sum_j attention_scores_softmax[i, j] * values[j]
# This is the attention output for input i; mathematically the same as
# attention_scores_softmax @ values.
outputs = tf.math.reduce_sum(weighted_values, axis=0)
outputs

<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1.936621 , 6.683105 , 1.5950683],
       [1.9999939, 7.963991 , 0.0539764],
       [1.9997045, 7.759892 , 0.3583893]], dtype=float32)>