## Transformers from scratch mini-project
I was curious how transformers worked internally so I wrote one in tensorflow based on [this explanation](https://peterbloem.nl/blog/transformers). In order to check correctness of my implementation I ran this NLP pipeline from [Tensorflow](https://github.com/tensorflow/text/blob/master/docs/tutorials/transformer.ipynb) on my local GPU, and substituted my MultiHeadAttention layer for the one in the tutorial.

In [1]:
import tensorflow as tf
from tensorflow import keras

tf.keras.utils.set_random_seed(42)  # sets seeds for base-python, numpy and tf
tf.config.experimental.enable_op_determinism()

In [22]:
x = tf.random.uniform((1,9,4))
x

<tf.Tensor: shape=(1, 9, 4), dtype=float32, numpy=
array([[[0.803156  , 0.49777734, 0.37054038, 0.9118674 ],
        [0.637642  , 0.18209696, 0.63791955, 0.27701473],
        [0.04227114, 0.84219384, 0.90637195, 0.222556  ],
        [0.9198462 , 0.68789077, 0.42705178, 0.878158  ],
        [0.6943959 , 0.46567595, 0.52925766, 0.33019018],
        [0.12754858, 0.16153514, 0.5085137 , 0.44301772],
        [0.35205877, 0.8969147 , 0.24940813, 0.76328313],
        [0.85935795, 0.08480155, 0.20418596, 0.28848922],
        [0.65142167, 0.7106751 , 0.8695041 , 0.23745108]]], dtype=float32)>

In [23]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.heads = num_heads
        self.dim = d_model
        assert self.dim % self.heads == 0
        self.depth = self.dim // self.heads
      
    def build(self, batch_input_shape):
        dim = batch_input_shape[-1]
        self.Wq = self.add_weight(name="Wq", shape=[dim, dim], initializer="glorot_normal")
        self.Wk = self.add_weight(name="Wk", shape=[dim, dim], initializer="glorot_normal")
        self.Wv = self.add_weight(name="Wv", shape=[dim, dim], initializer="glorot_normal")
        self.Wo = self.add_weight(name="Wo", shape=[dim, dim], initializer="glorot_normal")

    def split_into_heads(self, M, batch_size): 
        M = tf.reshape(M, (batch_size, -1, self.heads, self.depth))
        return tf.transpose(M, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        batch = tf.shape(q)[0]
        seq_len = tf.shape(q)[-1]

        q = tf.transpose(q, perm=[0, 2, 1])
        k = tf.transpose(k, perm=[0, 2, 1]) 
        v = tf.transpose(v, perm=[0, 2, 1]) 

        Q = self.Wq @ q # (dim, dim) * (batch, dim, seq_len)
        Q = tf.transpose(Q, perm=[0, 2, 1]) #transpose to (batch, seq_len, dim)
        Q = self.split_into_heads(Q, batch) #split then transpose -> (batch, heads, seq_len, dim // heads)
        
        K = self.Wk @ k
        K = tf.transpose(K, perm=[0, 2, 1])
        K = self.split_into_heads(K, batch)
        
        V = self.Wv @ v
        V = tf.transpose(V, perm=[0, 2, 1])
        V = self.split_into_heads(V, batch)
        
        d = tf.cast(q.shape[1], tf.float32)

        scaled_attention_logits = tf.matmul(Q, K, transpose_b=True) / tf.sqrt(d)
        
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        unrolled = attention_weights @ V

        unrolled = tf.transpose(unrolled, [0, 2, 1, 3])
        concatenated = tf.reshape(unrolled, [batch, -1, self.dim])
        return concatenated @ self.Wo, attention_weights

In [25]:
m1 = MultiHeadAttention(4, 2)
m1(x,x,x)[0]

<tf.Tensor: shape=(1, 9, 4), dtype=float32, numpy=
array([[[0.35597828, 0.0591852 , 0.06663257, 0.3346224 ],
        [0.36085227, 0.03652559, 0.07470929, 0.34728718],
        [0.36287993, 0.03878205, 0.07710293, 0.34061083],
        [0.35561445, 0.06234346, 0.06531179, 0.33334184],
        [0.35947052, 0.04296065, 0.07254462, 0.34330994],
        [0.36129224, 0.03525817, 0.07694381, 0.34424722],
        [0.35690615, 0.0572939 , 0.06944344, 0.3314358 ],
        [0.35885248, 0.03900723, 0.07306662, 0.34727803],
        [0.36076185, 0.04316544, 0.0729384 , 0.34268153]]], dtype=float32)>