# Attention Is All You Need

https://arxiv.org/pdf/1706.03762.pdf

## 1. Overview of the architecture

<div>
<img src="https://github.com/Clemthefou445/E-BART/blob/main/transformer.png" width="350"/>
</div>

The Transformer is a **sequence-to-sequence architecture built on attention**.<br />
It is composed of an **encoder** stack (left), built from what the code below calls a Transformer block, and a **decoder** stack (right).<br /><br />
(Most Large Language Models (LLMs) are built from these Transformer encoder and/or decoder blocks.
A model is first pre-trained, then typically fine-tuned for a specific task.)
<br /><br />
- **Multi-Head Attention**:
    - Takes three inputs: **Values, Keys & Queries**. In encoder self-attention all three are the same tensor.
- **Skip connections** (Residual connections)

- The decoder block is a Transformer block preceded by a **Masked Multi-Head Attention** layer. The mask stops each position from attending to later positions, so the decoder cannot simply copy the token it is supposed to predict and has to rely on the preceding context.
- Both the encoder and decoder blocks are stacked N times (Nx in the figure)
- Without positional information, the Transformer is permutation-equivariant: if the words of an input sentence are reordered, the per-token outputs are reordered in the same way, so word order carries no signal. That is why a **positional encoding** is added to the embeddings before the encoder step.
- During training, the Transformer processes all positions of a sequence in parallel, in contrast to recurrent sequence models like RNNs, GRUs or LSTMs.
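
The positional encoding mentioned above can take several forms. As a minimal sketch, the fixed sinusoidal variant described in the paper is shown here in plain Python. The function name `sinusoidal_positional_encoding` and its arguments are illustrative, and the implementation in section 3 uses a learned `nn.Embedding` over positions instead.

```python
import math


def sinusoidal_positional_encoding(max_len, d_model):
    """Return a max_len x d_model table (list of lists) of position encodings.

    Even columns hold sin(pos / 10000^(2i / d_model)), odd columns hold the
    matching cosine, following the formula in the paper.
    """
    assert d_model % 2 == 0, "this sketch assumes an even model dimension"
    table = []
    for pos in range(max_len):
        row = []
        # even_index plays the role of 2i in the paper's formula
        for even_index in range(0, d_model, 2):
            angle = pos / (10000 ** (even_index / d_model))
            row.append(math.sin(angle))
            row.append(math.cos(angle))
        table.append(row)
    return table


pe = sinusoidal_positional_encoding(max_len=4, d_model=8)
print(len(pe), len(pe[0]))  # number of positions, encoding width
```

Each row would be added to the token embedding at the same position, so two identical tokens at different positions no longer produce identical inputs.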



## 2. Attention Mechanism

- The input embedding is n-dimensional
- It is split into h parts (each n/h-dimensional) $\rightarrow$ h heads (multi-head)
<br />
- The **Scaled Dot-Product Attention** (SDPA) follows the equation:

$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

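- Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head dimension $d_k$; very large scores would push the softmax into regions with extremely small gradients
- Concatenating the h head outputs restores the original embedding dimension n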
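
The steps listed above (split into heads, apply scaled dot-product attention per head, concatenate) can be sketched in a few lines of PyTorch. This is a simplified illustration rather than the full layer: the helper names `split_heads`, `merge_heads` and `scaled_dot_product_attention` are made up for this example, and the learned linear projections of Q, K and V are left out.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., q_len, d_k), k: (..., k_len, d_k), v: (..., k_len, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., q_len, k_len)
    if mask is not None:
        # Positions where mask == 0 get zero weight after the softmax
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


def split_heads(x, h):
    """(N, seq_len, n) -> (N, h, seq_len, n // h)."""
    N, seq_len, n = x.shape
    return x.reshape(N, seq_len, h, n // h).transpose(1, 2)


def merge_heads(x):
    """(N, h, seq_len, d) -> (N, seq_len, h * d)."""
    N, h, seq_len, d = x.shape
    return x.transpose(1, 2).reshape(N, seq_len, h * d)


torch.manual_seed(0)
x = torch.randn(2, 5, 16)        # batch of 2, 5 tokens, n = 16
heads = split_heads(x, h=4)      # 4 heads of dimension 16 / 4 = 4
out, weights = scaled_dot_product_attention(heads, heads, heads)
merged = merge_heads(out)        # back to the embedding dimension n
print(heads.shape, weights.shape, merged.shape)
```

Every row of `weights` sums to 1, and `merged` has the same shape as `x`, which is what lets attention layers be stacked.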

<div>
<img src="attention_mechanism.png" width="550"/>
</div>

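The role of the $\sqrt{d_k}$ factor can be checked numerically. The sketch below assumes query and key components are independent with zero mean and unit variance (the argument used in the paper) and estimates the variance of their dot product. The function name `dot_product_variance` and the sample count are arbitrary choices for illustration.

```python
import random
import statistics


def dot_product_variance(d_k, samples=2000, seed=0):
    """Estimate Var(q . k) for q, k with i.i.d. standard normal components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pvariance(dots)


for d_k in (4, 64):
    raw = dot_product_variance(d_k)
    # Dividing the scores by sqrt(d_k) divides their variance by d_k
    print(f"d_k={d_k}: raw variance {raw:.1f}, after scaling {raw / d_k:.2f}")
```

Under these assumptions the raw variance grows roughly linearly with `d_k`, while the scaled variance stays close to 1 regardless of the head dimension.
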
## 3. Technical implementation

In [1]:
import torch
import torch.nn as nn

In [25]:

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.values(values)  # (N, value_len, embed_size)
        keys = self.keys(keys)  # (N, key_len, embed_size)
        queries = self.queries(query)  # (N, query_len, embed_size)

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        # Einsum computes the dot product of every query position with every
        # key position, separately for each example and each head. It is
        # equivalent to a batched matrix multiplication (torch.bmm).

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), where d_k is the per-head dimension (head_dim),
        # as in the SDPA equation, then normalize over the key dimension so
        # that the weights for each query sum to 1.
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out
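
For readers less used to `einsum`, the two contractions in `SelfAttention.forward` can be compared against explicit permute-and-matmul versions. This is a standalone check on random tensors with arbitrary small sizes, not part of the model code.

```python
import torch

torch.manual_seed(0)
N, q_len, k_len, heads, head_dim = 2, 3, 5, 4, 8
queries = torch.randn(N, q_len, heads, head_dim)
keys = torch.randn(N, k_len, heads, head_dim)
values = torch.randn(N, k_len, heads, head_dim)

# Score einsum: contract over d, keep n (batch) and h (head) as batch dims.
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
# Same thing with matmul: (N, h, q, d) @ (N, h, d, k) -> (N, h, q, k)
energy_mm = queries.permute(0, 2, 1, 3) @ keys.permute(0, 2, 3, 1)

attention = torch.softmax(energy / head_dim ** 0.5, dim=-1)

# Output einsum: contract over the key/value position l.
out = torch.einsum("nhql,nlhd->nqhd", attention, values)
# Same thing with matmul: (N, h, q, l) @ (N, h, l, d), then back to (N, q, h, d)
out_mm = (attention @ values.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

print(torch.allclose(energy, energy_mm, atol=1e-5))
print(torch.allclose(out, out_mm, atol=1e-5))
```

The einsum form avoids the explicit permutes, at the cost of having to read the index string carefully.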


In [26]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out
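
The "add and normalize" step around the feed-forward sub-layer can be looked at in isolation. The sketch below rebuilds only that part with small, arbitrary sizes and leaves out attention and dropout. It relies on `nn.LayerNorm` starting with weight 1 and bias 0, so each token vector comes out with mean close to 0 and variance close to 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_size, forward_expansion = 16, 4

norm = nn.LayerNorm(embed_size)
feed_forward = nn.Sequential(
    nn.Linear(embed_size, forward_expansion * embed_size),
    nn.ReLU(),
    nn.Linear(forward_expansion * embed_size, embed_size),
)

x = torch.randn(2, 5, embed_size)    # (N, seq_len, embed_size)
out = norm(feed_forward(x) + x)      # skip connection, then LayerNorm

print(out.shape)                                   # same shape as x
print(out.mean(dim=-1).abs().max().item())         # per-token mean, near 0
print(out.var(dim=-1, unbiased=False).mean().item())  # per-token variance, near 1
```

Because the sub-layer preserves the shape `(N, seq_len, embed_size)`, blocks can be stacked without any adapter layers in between.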


In [27]:
class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )

        # In the Encoder the query, key, value are all the same, it's in the
        # decoder this will change. This might look a bit odd in this case.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out



In [28]:
class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out


In [29]:
class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out


In [30]:
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cpu",
        max_length=100,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )

        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
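
To see what the two masks look like, the same construction as `make_src_mask` and `make_trg_mask` can be run on a tiny example. The functions are copied here as free functions so the snippet runs on its own; the token ids are made up, with 0 assumed to be the padding index.

```python
import torch


def make_src_mask(src, pad_idx):
    # (N, src_len) -> (N, 1, 1, src_len); True where the token is not padding
    return (src != pad_idx).unsqueeze(1).unsqueeze(2)


def make_trg_mask(trg):
    # Lower-triangular matrix: position i may attend to positions <= i
    N, trg_len = trg.shape
    return torch.tril(torch.ones(trg_len, trg_len)).expand(N, 1, trg_len, trg_len)


src = torch.tensor([[1, 5, 6, 0, 0]])   # last two tokens are padding
trg = torch.tensor([[1, 7, 4]])

src_mask = make_src_mask(src, pad_idx=0)
trg_mask = make_trg_mask(trg)
print(src_mask.shape, trg_mask.shape)
print(trg_mask[0, 0])

# Effect on attention weights: start from equal scores, mask, then softmax
energy = torch.zeros(1, 1, 3, 3)
weights = torch.softmax(energy.masked_fill(trg_mask == 0, float("-1e20")), dim=-1)
print(weights[0, 0])
```

In the last print, the first target position puts all its weight on itself, the second splits it over two positions, and the third over all three. Both masks broadcast against the `(N, heads, query_len, key_len)` energy tensor.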


In [33]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device :",device)

    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
        device
    )
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
        device
    )
    out = model(x, trg[:, :-1])
    print(out.shape)

Device : cpu
torch.Size([2, 7, 10])


In [36]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device :",device)
    
    src_pad_idx = 0
    trg_pad_idx = 0
    src_vocab_size = 10
    trg_vocab_size = 10

    # Pass the padding indices (not the vocabulary sizes) and the device
    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
    
    # Generate random sample data (token ids start at 1, so no padding appears)
    max_seq_length = 100
    
    src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)).to(device)  # (batch_size, seq_length)
    trg_data = torch.randint(1, trg_vocab_size, (64, max_seq_length)).to(device)  # (batch_size, seq_length)

Device : cpu


In [38]:
import torch.optim as optim

In [43]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

model.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = model(src_data, trg_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, trg_vocab_size), trg_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 2.40474009513855
Epoch: 2, Loss: 3.303234100341797
Epoch: 3, Loss: 2.8076043128967285
Epoch: 4, Loss: 2.489821672439575
Epoch: 5, Loss: 2.4566433429718018
Epoch: 6, Loss: 2.456268310546875
Epoch: 7, Loss: 2.2751410007476807
Epoch: 8, Loss: 2.2004480361938477
Epoch: 9, Loss: 2.2295851707458496
Epoch: 10, Loss: 2.2655255794525146
Epoch: 11, Loss: 2.284621477127075
Epoch: 12, Loss: 2.2826709747314453
Epoch: 13, Loss: 2.2617950439453125
Epoch: 14, Loss: 2.2333261966705322
Epoch: 15, Loss: 2.209059476852417
Epoch: 16, Loss: 2.196061134338379
Epoch: 17, Loss: 2.1966922283172607
Epoch: 18, Loss: 2.208064556121826
Epoch: 19, Loss: 2.2212185859680176
Epoch: 20, Loss: 2.226254940032959
Epoch: 21, Loss: 2.2208449840545654
Epoch: 22, Loss: 2.2102911472320557
Epoch: 23, Loss: 2.200763702392578
Epoch: 24, Loss: 2.195277214050293
Epoch: 25, Loss: 2.194127082824707
Epoch: 26, Loss: 2.1963071823120117
Epoch: 27, Loss: 2.200092315673828
Epoch: 28, Loss: 2.20339298248291
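Epoch: 29, Loss: 2.2045345306396484
Epoch: 30, Loss: 2.2030699253082275
Epoch: 31, Loss: 2.1998441219329834
Epoch: 32, Loss: 2.1962740421295166
Epoch: 33, Loss: 2.193582534790039
Epoch: 34, Loss: 2.1924350261688232
Epoch: 35, Loss: 2.1928277015686035
Epoch: 36, Loss: 2.1941514015197754
Epoch: 37, Loss: 2.19547438621521
Epoch: 38, Loss: 2.196024179458618
Epoch: 39, Loss: 2.195533514022827
Epoch: 40, Loss: 2.1942436695098877
Epoch: 41, Loss: 2.192678451538086
Epoch: 42, Loss: 2.1913933753967285
Epoch: 43, Loss: 2.190768241882324
Epoch: 44, Loss: 2.190857172012329
Epoch: 45, Loss: 2.191364288330078
Epoch: 46, Loss: 2.191812753677368
Epoch: 47, Loss: 2.1918089389801025
Epoch: 48, Loss: 2.1912412643432617
Epoch: 49, Loss: 2.1903069019317627
Epoch: 50, Loss: 2.1893622875213623
Epoch: 51, Loss: 2.188716411590576
Epoch: 52, Loss: 2.1884663105010986
Epoch: 53, Loss: 2.188469886779785
Epoch: 54, Loss: 2.1884515285491943
Epoch: 55, Loss: 2.1881794929504395
Epoch: 56, Loss: 2.187596559524536
Epoch: 57, Loss: 2.1868247985839844
Epoch: 58, Loss: 2.186051845550537
Epoch: 59, Loss: 2.1854019165039062
Epoch: 60, Loss: 2.1848814487457275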
Epoch: 61, Loss: 2.1844048500061035
Epoch: 62, Loss: 2.1838538646698
Epoch: 63, Loss: 2.183134078979492
Epoch: 64, Loss: 2.182218074798584
Epoch: 65, Loss: 2.181161880493164
Epoch: 66, Loss: 2.1800527572631836
Epoch: 67, Loss: 2.178938150405884
Epoch: 68, Loss: 2.1777894496917725
Epoch: 69, Loss: 2.176527738571167
Epoch: 70, Loss: 2.1750638484954834
Epoch: 71, Loss: 2.1733386516571045
Epoch: 72, Loss: 2.171377658843994
Epoch: 73, Loss: 2.169236183166504
Epoch: 74, Loss: 2.1669399738311768
Epoch: 75, Loss: 2.1644351482391357
Epoch: 76, Loss: 2.1616265773773193
Epoch: 77, Loss: 2.158419370651245
Epoch: 78, Loss: 2.1547579765319824
Epoch: 79, Loss: 2.1505491733551025
Epoch: 80, Loss: 2.1456782817840576
Epoch: 81, Loss: 2.1398048400878906
Epoch: 82, Loss: 2.1325161457061768
Epoch: 83, Loss: 2.12361741065979
Epoch: 84, Loss: 2.113696575164795
Epoch: 85, Loss: 2.103590250015259
Epoch: 86, Loss: 2.094468593597412
Epoch: 87, Loss: 2.1004445552825928
Epoch: 88, Loss: 2.2671828269958496
Epoch: 89, Loss: 2.194218873977661
Epoch: 90, Loss: 2.2461740970611572
Epoch: 91, Loss: 2.152465343475342
Epoch: 92, Loss: 2.1159706115722656
Epoch: 93, Loss: 2.1534335613250732
Epoch: 94, Loss: 2.1752917766571045
Epoch: 95, Loss: 2.144052743911743
Epoch: 96, Loss: 2.1259284019470215
Epoch: 97, Loss: 2.1351141929626465
Epoch: 98, Loss: 2.1453800201416016
Epoch: 99, Loss: 2.139561176300049
Epoch: 100, Loss: 2.120004415512085
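
The training loop feeds `trg_data[:, :-1]` to the decoder and scores the output against `trg_data[:, 1:]`. The toy example below illustrates that one-token shift and the effect of `ignore_index`. The token ids (1 as start token, 2 as end token, 0 as padding) and the hand-written logits are assumptions made only for this sketch.

```python
import math

import torch
import torch.nn as nn

trg = torch.tensor([[1, 7, 4, 2, 0, 0]])  # assumed ids: 1 = start, 2 = end, 0 = pad
decoder_input = trg[:, :-1]               # what the decoder sees
labels = trg[:, 1:]                       # what it must predict, one step ahead
print(decoder_input.tolist(), labels.tolist())

vocab_size = 10
# Uniform logits everywhere, except a confident wrong guess at padded positions
logits = torch.zeros(1, labels.shape[1], vocab_size)
logits[0, 3:, 5] = 10.0

masked = nn.CrossEntropyLoss(ignore_index=0)
plain = nn.CrossEntropyLoss()
loss_masked = masked(logits.reshape(-1, vocab_size), labels.reshape(-1))
loss_plain = plain(logits.reshape(-1, vocab_size), labels.reshape(-1))

# With ignore_index=0 only the three real tokens count, each costing ln(10)
print(loss_masked.item(), math.log(vocab_size))
print(loss_plain.item())
```

As a rough sanity check on the log above: the synthetic targets are drawn uniformly from the ids 1 to 9, so a uniform guess over those nine tokens has a cross-entropy of ln 9 ≈ 2.197, which is close to where the logged loss hovers for most of the run.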