<a href="https://colab.research.google.com/github/mhamedLmarbouh/Attention-is-all-you-need/blob/main/Transformers_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Attention Is all you need**

A simplistic implementation of Attention is all you need with readability as it's main goal

https://arxiv.org/pdf/1706.03762.pdf

In [2]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchtext import data, datasets

import numpy as  np
import random
import tqdm

SEED = 47

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
device

device(type='cuda')

In [26]:
# Readable implementation but less efficient 
class ScaledDotProduct(nn.Module):
    def __init__(self, embedding_size, attention_dropout=0.2):
        super(ScaledDotProduct, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.k = embedding_size

    def forward(self, keys, queries, mask=None):
        
        attention = torch.matmul(queries, keys.transpose(2,3)) # (batch, seq_len, seq_len)
        
        attention = attention / math.sqrt(self.k)

        if mask is not None:
            # we fill the masked tokens with -inf so when the softmax is applied
            # the probability will be 0
            attention = torch.masked_fill(mask == 0, float('-inf'))

        attention = self.dropout(attention)
        attention = F.softmax(attention, dim=2) # row wise self-attention probabilities
        return attention

**A less readable but more efficient implementation of** ```ScaledDotProduct```



In [None]:
class BMMScaledDotProduct(nn.Module):
    def __init__(self, embed_dim, attention_dropout=0.2):
        super(BMMScaledDotProduct, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.embed_dim = embed_dim

    def forward(self, keys, queries, mask=None):
        batch, h, seq_len, embed_dim = keys.size()
        keys_ = keys.view(batch*h, seq_len, embed_dim) 
        queries_ = queries.view(batch*h, seq_len, embed_dim) 
        attention = torch.bmm(queries_, keys_.transpose(1,2))
        attention.resize_(batch, h, seq_len, embed_dim)
                
        attention = attention / math.sqrt(self.embed_dim)

        if mask is not None:
            # we fill the masked tokens with -inf so when the softmax is applied
            # the probability will be 0
            attention = torch.masked_fill(mask == 0, float('-inf'))

        attention = self.dropout(attention)
        attention = F.softmax(attention, dim=2) # row wise self-attention probabilities
        return attention

* Speed comparaison between the two implementation of `ScaledDotProduct`

In [63]:
attention = ScaledDotProduct(128)
bmm_attention = BMMScaledDotProduct(128)
#batch, h, seq_len, embed_dim
queries = torch.randn((32, 8, 512, 128))
keys = torch.randn((32, 8, 512, 128))

In [34]:
%%timeit
res = attention(keys, queries)

1 loop, best of 3: 2.24 s per loop


In [64]:
%%timeit
res = bmm_attention(keys, queries)

1 loop, best of 3: 538 ms per loop


torch.Size([256, 512, 128])

1. embed_dim: Embedding dimention/Hidden size usually refered to as **k** </br>

2. num_heads: number of attention heads defines the number of times to 
calculate the set (Key, Query, values) kinda like the number of filters in a CNN. In practice we accelerate the calculation, by doing it in one matrix multiplication

In [None]:
class SelfAttention(nn.Module):

    def __init__(self, embed_dim, heads, attention_dropout):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim # refered to as k in the paper
        self.num_heads = heads # generally refered to as h

        # Linear layer weights (embed_dim, embed_dim*num_heads)
        self.calculate_h_keys = nn.Linear(self.embed_dim, 
                                          self.embed_dim * self.num_heads,
                                          bias=False)
        self.calculate_h_queries = nn.Linear(self.embed_dim, 
                                             self.embed_dim * self.num_heads,
                                             bias=False)
        self.calculate_h_values = nn.Linear(self.embed_dim, 
                                            self.embed_dim * self.num_heads,
                                            bias=False)
        
        self.scaled_dot_producte = ScaledDotProduct(embed_dim, attention_dropout=attention_dropout)

        self.reduce_h_dimension = nn.Linear(self.num_heads*self.embed_dim, 
                                            self.embed_dim, bias=False)

    def forward(self, X):
        batch, seq_len, embed_dim = X.size()

        assert embed_dim == self.embed_dim, f"Input embedding dim ({embed_dim}) should match layer embedding dim ({self.embed_dim})"


        # We calculating h keys for each value in the sequence 
        # the same goes for the queries and values
        keys = self.calculate_h_keys(X) # shape (batch, seq_len, embed_dim*h)
        queries = self.calculate_h_queries(X) # shape (batch, seq_len, embed_dim*h)
        values = self.calculate_h_values(X) # shape (batch, seq_len, embed_dim*h)

        # We separate the embed_dim*h:  (batch, seq_len, embed_dim*h) => (batch, seq_len, h, embed_dim)
        # We transpose axis 1 and 2: (batch, seq_len, h, embed_dim) => (batch, h, seq_len, embed_dim)
        keys = keys.view(batch, seq_len, self.num_heads, self.embed_dim).transpose(1,2)
        queries = queries.view(batch, seq_len, self.num_heads, self.embed_dim).transpose(1,2)
        values = values.view(batch, seq_len, self.num_heads, self.embed_dim).transpose(1,2)

        attention = self.scaled_dot_producte(keys=keys, queries=queries) # (batch, h, seq_len, seq_len)

        # apply the self-attention to the values
        out = torch.matmul(attention, values) # (batch, h, seq_len, embed_dim)
        out = out.transpose(1,2).contiguous().view(batch, seq_len, self.num_heads*self.embed_dim) # (batch, seq_len, embed_dim*h)

        return self.reduce_h_dimension(out) # (batch, seq_len, embed_dim)

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, heads):
        super(TransformerBlock, self).__init__()

        self.embed_dim = embed_dim
        self.self_attention = SelfAttention(embed_dim=self.embed_dim, heads=heads)

        self.norm1 = nn.LayerNorm(normalized_shape=self.embed_dim)
        self.norm2 = nn.LayerNorm(normalized_shape=self.embed_dim)
        
        self.feed_forward = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim*4),
            nn.ReLU(),
            nn.Linear(4*self.embed_dim, self.embed_dim)
        )
    
    def forward(self, X):
        X_with_attention = self.self_attention(X)
        X = self.norm1(X_with_attention+X)

        X_prime = self.feed_forward(X)
        X = self.norm2(X_prime+X)

        return X

In [None]:
class Transformer(nn.Module):
 
    def __init__(self, seq_len, embed_dim, heads, depth, num_tokens, num_classes):
        super(Transformer, self).__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.heads = heads
        self.depth = depth
        self.num_tokens = num_tokens
        self.num_classes = num_classes

        self.token_embed = nn.Embedding(self.num_tokens, self.embed_dim)
        self.pos_embed = nn.Embedding(self.seq_len, self.embed_dim)

        transformers_blocks = []
        for i in range(self.depth):
            transformers_blocks.append(TransformerBlock(self.embed_dim, self.heads))
        self.apply_transform = nn.Sequential(*transformers_blocks)

        self.to_classes = nn.Linear(self.embed_dim, self.num_classes)

        self.do = nn.Dropout()
    def forward(self, X):

        tokens = self.token_embed(X)
        batch, seq_len, _ = tokens.size()
        assert seq_len == self.seq_len, f"sequence length expected to be {self.seq_len} but got {seq_len}"
        position = self.pos_embed(torch.arange(seq_len, device=device))
        position = position.expand(batch, seq_len, self.embed_dim)

        X = tokens + position
        X = self.do(X)
        X = self.apply_transform(X)
        X = torch.max(X, 1).values

        X = self.to_classes(X)

        return F.log_softmax(X, dim=1)

# Test the implementation

In [None]:
VOCAB_SIZE = 10_000
BATCH_SIZE = 256
SEQ_SIZE = 128


EMBED_SIZE = 128

HEADS = 3
DEPTH = 4
NUM_CLS = 2 # positive/negative

EPOCHS = 12
LOGS_FOLDER = './logs'

In [None]:
TEXT = data.ReversibleField(tokenize="spacy", batch_first=True, 
                  lower=True, include_lengths=True, 
                  fix_length=MAX_LEN)
LABEL = data.LabelField(sequential=False)


In [None]:
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

In [None]:
train_data, val_data = train_data.split(split_ratio=0.8)

In [None]:
TEXT.build_vocab(train_data, max_size=VOCAB_SIZE - 2)
LABEL.build_vocab(train_data)

In [None]:
train_iterator, val_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, val_data, test_data), 
    batch_size = BATCH_SIZE, 
    device = device)

In [None]:
torch.cuda.empty_cache()
model = Transformer(seq_len=SEQ_SIZE, embed_dim=EMBED_SIZE, heads=HEADS, depth=DEPTH, num_tokens=VOCAB_SIZE, num_classes=NUM_CLS)
model.cuda()

In [None]:
opt = torch.optim.Adam(params=model.parameters())

In [None]:
tbw = SummaryWriter(LOGS_FOLDER)

seen = 0
for epoch in range(EPOCHS):  
    model.train(True)  
    for batch in tqdm.tqdm_notebook(train_iterator):
        opt.zero_grad()

        X = batch.text[0]
        Y = batch.label

        if X.size(1) > MAX_LEN:
            X = X[:, :MAX_LEN]
        Y_hat = model(X)
        loss = F.nll_loss(Y_hat, Y)
        
        loss.backward()
        opt.step()
        seen += int(X.size(0))
        tbw.add_scalar('classification/train-loss', float(loss.item()), seen)
    
    with torch.no_grad():
        total = 0.0
        correcte = 0.0
        model.train(False)
        for batch in tqdm.tqdm_notebook(val_iterator):
            X = batch.text[0]
            Y = batch.label
            if X.size(1) > MAX_LEN:
                X = X[:, :MAX_LEN]

            Y_hat = model(X)
            val_loss = F.nll_loss(Y_hat, Y)
            Y_hat = Y_hat.argmax(dim=1)

            total += float(X.size(0))
            correcte += float((Y == Y_hat).sum().item())

        acc = correcte / total
        print(f'-- "{epoch}: validation"  accuracy {acc:.3}')
        tbw.add_scalar('classification/val-loss', float(val_loss.item()), epoch)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "0: validation"  accuracy 0.611


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "1: validation"  accuracy 0.676


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "2: validation"  accuracy 0.722


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "3: validation"  accuracy 0.726


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "4: validation"  accuracy 0.755


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "5: validation"  accuracy 0.749


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "6: validation"  accuracy 0.776


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "7: validation"  accuracy 0.782


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "8: validation"  accuracy 0.788


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "9: validation"  accuracy 0.8


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "10: validation"  accuracy 0.793


HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))


-- "11: validation"  accuracy 0.811


In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir LOGS_FOLDER

In [65]:
def reverse(encoded_sentence, TEXT):
    sent = [TEXT.vocab.itos[idx] for idx in encoded_sentence]
    return ' '.join(sent)

In [None]:
model.train(False)

In [None]:
total = 0
correcte = 0
model.train(False)
for batch in tqdm.tqdm_notebook(test_iterator):
    X = batch.text[0]
    Y = batch.label
    if X.size(1) > MAX_LEN:
        X = X[:, :MAX_LEN]

    Y_hat = model(X)
    Y_hat = Y_hat.argmax(dim=1)

    total += float(X.size(0))
    correcte += float((Y == Y_hat).sum().item())

acc = correcte / total