## Attention Is All You Need

Welcome!

In this project, we will replicate the famous paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762), which introduced the Transformer architecture that modern LLMs are built on.

In [1]:
import torch ## torch lets us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module and nn.Linear
import torch.nn.functional as F # This gives us softmax()

### Encoder

First, we need the attention class

This is the step which takes the query, key and value matrices, and calculates the scaled dot product attention (SDPA).

$$Attention(Q, K, V) = Softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Note that the KV cache is purely an inference-time optimisation: by reusing the keys and values already computed for earlier tokens, it reduces the cost of generating each new token from quadratic to linear in the sequence length, at the cost of extra memory. It has nothing to do with the model's architecture.

In [2]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Evaluates softmax(Q K^T / sqrt(d_k)) V, where Q, K and V are produced by
    passing the supplied encodings through three bias-free linear layers.
    """

    def __init__(self, d_k, d_v, d_model=2,
                 row_dim=0,
                 col_dim=1):
        """
        Create the query, key and value projections.

        input:
           d_k: width of the projected queries and keys (they must match so
                that their dot product is defined)
           d_v: width of the projected values
           d_model: width of the encodings fed into the projections
           row_dim: axis swapped with col_dim when the keys are transposed
           col_dim: axis swapped with row_dim; also the softmax axis and the
                    axis of the keys whose size provides the scaling factor
        """
        super().__init__()

        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_v, bias=False)

        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """
        Attend from the query encodings over the key/value encodings.

        input:
           encodings_for_q: encodings projected into queries
           encodings_for_k: encodings projected into keys
           encodings_for_v: encodings projected into values
           mask: optional boolean tensor; positions set to True are filled
                 with -1e9 before the softmax, so they receive ~0 weight

        returns: the attention-weighted combination of the projected values
        """
        queries = self.W_q(encodings_for_q)
        keys = self.W_k(encodings_for_k)
        values = self.W_v(encodings_for_v)

        # Divide the raw similarities by sqrt(d_k)
        scale = torch.tensor(keys.size(self.col_dim) ** 0.5)
        scores = (queries @ keys.transpose(self.row_dim, self.col_dim)) / scale

        if mask is not None:
            scores = scores.masked_fill(mask=mask, value=-1e9)

        weights = F.softmax(scores, dim=self.col_dim)
        return weights @ values

We also need the Multi-head attention

At one layer, we can use multiple attention heads in parallel. This encourages the model to learn different, independent attention patterns, leading to richer contextual information.

In [3]:
class MultiHeadAttention(nn.Module):
    """Several independent attention heads followed by an output projection."""

    def __init__(self, d_k, d_v,
                 d_model=2,
                 row_dim=0,
                 col_dim=1,
                 num_heads=1):
        """
        Build num_heads attention heads and the projection that merges them.

        input:
           d_k: width of each head's queries and keys
           d_v: width of each head's values
           d_model: width of the input encodings and of the merged output
           row_dim: row axis forwarded to every head
           col_dim: column axis forwarded to every head; the head outputs are
                    concatenated along this axis
           num_heads: number of parallel attention heads
        """
        super().__init__()

        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Attention(d_k, d_v, d_model, row_dim, col_dim))
        self.heads = nn.ModuleList(head_modules)

        # The concatenated heads are num_heads * d_v wide; map them back to d_model
        self.out = nn.Linear(num_heads * d_v, d_model, bias=False)

        self.col_dim = col_dim

    def forward(self,
                encodings_for_q,
                encodings_for_k,
                encodings_for_v,
                mask=None):
        """Run every head on the same inputs, concatenate and project the results."""
        per_head = [
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask)
            for head in self.heads
        ]
        merged = torch.cat(per_head, dim=self.col_dim)
        return self.out(merged)

Encoder

This is where we combine the above blocks and create the encoder layer.

We can have multiple encoder layers connected in series. For this, we will first make an encoder layer and then repeat that several times to make an encoder.

In [4]:
class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by a feed-forward block.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(Sublayer(x))).
    """

    def __init__(self, d_k=64, d_v=64, d_model=512, row_dim=0, col_dim=1, num_heads=8, p_drop=0.1):
        """
        input:
           d_k: width of each head's queries and keys
           d_v: width of each head's values
           d_model: width of the layer's input and output encodings
           row_dim: row axis forwarded to the attention heads
           col_dim: column axis forwarded to the attention heads
           num_heads: number of attention heads
           p_drop: dropout probability applied to each sub-layer's output
        """
        super().__init__()

        self.multihead = MultiHeadAttention(d_k, d_v, d_model, row_dim, col_dim, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        # NOTE: the feed-forward block keeps its d_model -> d_model widths and the
        # bias-free second layer so parameter shapes (and existing checkpoints)
        # are unchanged; the paper uses a wider inner layer (d_ff = 2048).
        self.ff1 = nn.Linear(in_features=d_model, out_features=d_model, bias=True)
        self.relu = nn.ReLU()
        self.ff2 = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, encodings, mask=None):
        """
        input:
           encodings: input encodings; used as queries, keys and values
           mask: optional boolean attention mask forwarded to every head

        returns: encodings of the same width as the input
        """
        # Dropout goes on the sub-layer output *before* the residual add and
        # the normalisation, so the normalised activations are never zeroed.
        attended = self.multihead(encodings, encodings, encodings, mask)
        x = self.norm1(encodings + self.dropout(attended))

        fed_forward = self.ff2(self.relu(self.ff1(x)))
        return self.norm2(x + self.dropout(fed_forward))

In [5]:
class Encoder(nn.Module):
    """A stack of N encoder layers followed by a final layer normalisation."""

    def __init__(self, d_k=64, d_v=64, d_model=512, row_dim=0, col_dim=1, num_heads=8, p_drop=0.1, N = 6):
        """
        Build N identically configured encoder layers.

        Every argument except N is forwarded unchanged to each EncoderLayer;
        N is the number of layers in the stack.
        """
        super().__init__()

        stack = []
        for _ in range(N):
            stack.append(
                EncoderLayer(d_k, d_v, d_model, row_dim, col_dim, num_heads, p_drop)
            )
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, encodings, mask = None):
        """Feed the encodings through each layer in order, then normalise the result."""
        hidden = encodings
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        normalised = self.norm(hidden)
        return normalised

Decoder

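Now the decoder. Each decoder layer is the same as an encoder layer, but with an additional 'masked' multi-head self-attention sub-layer in front; the second multi-head attention then attends over the encoder output (queries come from the decoder, keys and values from the encoder), and the feed-forward network follows as before.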

In [6]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(Sublayer(x))).
    """

    def __init__(self, d_k=64, d_v=64, d_model=512, row_dim=0, col_dim=1, num_heads=8, p_drop=0.1):
        """
        input:
           d_k: width of each head's queries and keys
           d_v: width of each head's values
           d_model: width of the layer's input and output encodings
           row_dim: row axis forwarded to the attention heads
           col_dim: column axis forwarded to the attention heads
           num_heads: number of attention heads in each attention block
           p_drop: dropout probability applied to each sub-layer's output
        """
        super().__init__()

        self.masked_multihead = MultiHeadAttention(d_k, d_v, d_model, row_dim, col_dim, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.multihead = MultiHeadAttention(d_k, d_v, d_model, row_dim, col_dim, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff1 = nn.Linear(in_features=d_model, out_features=d_model, bias=True)
        self.relu = nn.ReLU()
        self.ff2 = nn.Linear(in_features=d_model, out_features=d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=p_drop)
        # Kept as attributes for callers that inspect the layer's configuration
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.num_heads = num_heads

    def forward(self, encoder_encodings, decoder_encodings, mask=None):
        """
        input:
           encoder_encodings: encoder output; keys and values of the second attention block
           decoder_encodings: decoder-side input encodings
           mask: optional boolean mask for the decoder self-attention
                 (True marks positions that must not be attended to)

        returns: encodings of the same width as decoder_encodings
        """
        # Dropout goes on each sub-layer output *before* the residual add and
        # the normalisation, so the normalised activations are never zeroed.
        self_attended = self.masked_multihead(decoder_encodings, decoder_encodings, decoder_encodings, mask)
        x = self.norm1(decoder_encodings + self.dropout(self_attended))

        # Queries come from the decoder, keys/values from the encoder; no mask is applied here
        cross_attended = self.multihead(x, encoder_encodings, encoder_encodings, None)
        y = self.norm2(x + self.dropout(cross_attended))

        fed_forward = self.ff2(self.relu(self.ff1(y)))
        return self.norm3(y + self.dropout(fed_forward))

In [7]:
class Decoder(nn.Module):
    """A stack of N decoder layers followed by a final layer normalisation."""

    def __init__(self, d_k=64, d_v=64, d_model=512, row_dim=0, col_dim=1, num_heads=8, p_drop=0.1, N = 6):
        """
        Build N identically configured decoder layers.

        Every argument except N is forwarded unchanged to each DecoderLayer;
        N is the number of layers in the stack.
        """
        super().__init__()

        stack = []
        for _ in range(N):
            stack.append(
                DecoderLayer(d_k, d_v, d_model, row_dim, col_dim, num_heads, p_drop)
            )
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, encoder_encodings, decoder_encodings, mask = None):
        """
        Pass the decoder encodings through every layer in order.

        Each layer receives the same encoder_encodings and mask; the output
        of the last layer is normalised before being returned.
        """
        hidden = decoder_encodings
        for decoder_layer in self.layers:
            hidden = decoder_layer(encoder_encodings, hidden, mask)

        normalised = self.norm(hidden)
        return normalised

That's the complete code for the attention-based Encoder-Decoder architecture. The only things left are the embedding layer and the positional encoding. We can add the sinusoidal positional vectors to the input vectors. The input embedding can be generated with a feed-forward (linear) layer that maps the one-hot encoded vocabulary to a vector of size d_model.

Let's test the above framework

In [8]:
# Build both stacks with their default hyperparameters
# (d_k = d_v = 64, d_model = 512, 8 heads, 6 layers, dropout 0.1).
my_encoder = Encoder()
my_decoder = Decoder()

In [9]:
# Random stand-in for the encoder input: 25 encodings of width 512 (= default d_model)
x = torch.randn((25, 512))
x.shape

torch.Size([25, 512])

In [10]:
# Run the encoder stack on the random input (no attention mask is passed)
y = my_encoder(x)

In [11]:
y.shape

torch.Size([25, 512])

In [12]:
# Random stand-in for the decoder input: 7 encodings of width 512
w = torch.randn((7, 512))
w.shape

torch.Size([7, 512])

In [13]:
# Causal mask for the 7 decoder positions: True strictly above the diagonal
# marks the later positions that each row must not attend to
mask = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)
mask

tensor([[False,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True],
        [False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False]])

In [14]:
# Decoder argument order: encoder output, decoder input, mask for the decoder self-attention
z = my_decoder(y, w, mask)

In [15]:
z.shape

torch.Size([7, 512])

In [16]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [17]:
from torchinfo import summary

In [18]:
# Show the per-module parameter counts of the encoder stack
summary(my_encoder)

Layer (type:depth-idx)                        Param #
Encoder                                       --
├─ModuleList: 1-1                             --
│    └─EncoderLayer: 2-1                      --
│    │    └─MultiHeadAttention: 3-1           1,048,576
│    │    └─LayerNorm: 3-2                    1,024
│    │    └─Linear: 3-3                       262,656
│    │    └─ReLU: 3-4                         --
│    │    └─Linear: 3-5                       262,144
│    │    └─LayerNorm: 3-6                    1,024
│    │    └─Dropout: 3-7                      --
│    └─EncoderLayer: 2-2                      --
│    │    └─MultiHeadAttention: 3-8           1,048,576
│    │    └─LayerNorm: 3-9                    1,024
│    │    └─Linear: 3-10                      262,656
│    │    └─ReLU: 3-11                        --
│    │    └─Linear: 3-12                      262,144
│    │    └─LayerNorm: 3-13                   1,024
│    │    └─Dropout: 3-14                     --
│    └─EncoderLaye

In [19]:
# Show the per-module parameter counts of the decoder stack
summary(my_decoder)

Layer (type:depth-idx)                        Param #
Decoder                                       --
├─ModuleList: 1-1                             --
│    └─DecoderLayer: 2-1                      --
│    │    └─MultiHeadAttention: 3-1           1,048,576
│    │    └─LayerNorm: 3-2                    1,024
│    │    └─MultiHeadAttention: 3-3           1,048,576
│    │    └─LayerNorm: 3-4                    1,024
│    │    └─Linear: 3-5                       262,656
│    │    └─ReLU: 3-6                         --
│    │    └─Linear: 3-7                       262,656
│    │    └─LayerNorm: 3-8                    1,024
│    │    └─Dropout: 3-9                      --
│    └─DecoderLayer: 2-2                      --
│    │    └─MultiHeadAttention: 3-10          1,048,576
│    │    └─LayerNorm: 3-11                   1,024
│    │    └─MultiHeadAttention: 3-12          1,048,576
│    │    └─LayerNorm: 3-13                   1,024
│    │    └─Linear: 3-14                      262,656
│    │

Positional Encodings

In [20]:
def pos_encodings(token_len, d_model):
    """
    Build the sinusoidal positional-encoding table.

    For position pos and even index 2k the entry is sin(pos / 10000^(2k / d_model));
    the following odd index 2k + 1 holds cos(pos / 10000^(2k / d_model)).

    input:
       token_len: number of positions (rows of the table)
       d_model: width of each encoding (columns of the table)

    returns: float32 tensor of shape (token_len, d_model)
    """
    positions = torch.arange(token_len, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(d_model)
    is_even = dims % 2 == 0

    # Each (sin, cos) pair shares the exponent of its even index:
    # 2k is used for both column 2k and column 2k + 1.
    pair_index = (dims - dims % 2).to(torch.float64)
    divisors = torch.pow(10000.0, pair_index / d_model)

    # Form the angles in double precision, then evaluate sin/cos in float32
    angles = (positions / divisors).to(torch.float32)
    return torch.where(is_even, torch.sin(angles), torch.cos(angles))

In [21]:
# Random stand-in for 24 input encodings of width 512
x = torch.randn((24, 512))
x.shape

torch.Size([24, 512])

In [22]:
# Add the positional encodings element-wise; the table is built with the same (24, 512) shape as x
y = x + pos_encodings(24, 512)
y.shape

torch.Size([24, 512])

In [23]:
from tqdm import tqdm
from time import time

In [24]:
# Module-level causal mask for 10 decoder positions (kept for any cell that
# still reads the global `mask`; `testing` builds its own mask).
mask = torch.tril(torch.ones(10, 10))
mask = mask == 0

def testing(d_k=64, d_v=64, d_model=512, row_dim=0, col_dim=1, num_heads=8, p_drop=0.1, N = 6):
  """
  Build an encoder/decoder pair and time one forward pass on random data.

  The measured time covers model construction, input preparation
  (including positional encodings) and a single encoder + decoder pass
  over 5 source and 10 target positions.

  input:
     d_k, d_v, d_model, row_dim, col_dim, num_heads, p_drop, N:
        forwarded unchanged to both Encoder and Decoder

  returns: elapsed wall-clock time in seconds
  """
  src_len, tgt_len = 5, 10

  # Build the causal mask here instead of reading the global `mask`: that
  # name is also bound to a different-sized mask elsewhere in the notebook,
  # so relying on it breaks when cells run in a different order.
  causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)) == 0

  start = time()
  my_encoder = Encoder(d_k, d_v, d_model, row_dim, col_dim, num_heads, p_drop, N)
  my_decoder = Decoder(d_k, d_v, d_model, row_dim, col_dim, num_heads, p_drop, N)

  x = torch.randn((src_len, d_model))
  x = x + pos_encodings(src_len, d_model)

  y = my_encoder(x)
  w = torch.randn((tgt_len, d_model))
  w = w + pos_encodings(tgt_len, d_model)

  # Only the elapsed time matters; the decoder output is discarded
  my_decoder(y, w, causal_mask)

  return time() - start

In [25]:
# One timed run with the default hyperparameters; the value shown is the elapsed time in seconds
testing()

1.3299205303192139

In [26]:
# Candidate values for every hyperparameter swept by the grid test
d_k, d_v, d_model = [3, 4, 12], [5, 10, 13], [8, 9, 17]
num_heads, N = [1, 2, 6, 11], [1, 5, 12]

# Keyword name -> candidate values; row_dim and col_dim stay fixed at 0 and 1
params = dict(
    d_k=d_k,
    d_v=d_v,
    d_model=d_model,
    row_dim=[0],
    col_dim=[1],
    num_heads=num_heads,
    N=N,
)

In [27]:
# Number of combinations in the grid (row_dim and col_dim have a single value each)
len(d_k) * len(d_v) * len(d_model) * len(num_heads) * len(N)

324

In [28]:
from itertools import product

# Argument names and their candidate lists, in matching order
keys = list(params)
values = list(params.values())

# Accumulated wall-clock time over every combination in the grid
total_time = 0

combos = product(*values)
for combo in tqdm(combos, desc='Testing architecture for different values'):
    # Pair each candidate value with its keyword name and time one run
    arg_dict = dict(zip(keys, combo))
    result = testing(**arg_dict)
    total_time = total_time + result

print(f"Total time -> {total_time} seconds")

Testing architecture for different values: 324it [00:24, 13.20it/s]

Total time -> 23.192572832107544 seconds



