### Transformers are complex: they have several parts that need to be implemented. Let's go step by step.

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import math
import torchtext

### This first part creates our embeddings. Remember the lesson on data representation: the transformer cannot work with raw text, so we need to create embeddings for it to encode.

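### We need to turn the words into numbers. We will never forget data representation. The input will be words or word fragments, also known as "tokens". Each token is mapped to a learned vector of weights. This is equivalent to multiplying a one-hot encoding of the token by a weight matrix; `nn.Embedding` implements it as a table lookup, with no activation function. We use the same embedding network for every token in the phrase.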

In [None]:
class Embedding(nn.Module):
    """Map integer token ids to dense, learnable vectors.

    A thin wrapper around ``nn.Embedding``: a trainable lookup table holding
    one row of length ``embed_dims`` for each of the ``vocab_size`` token ids.
    """

    def __init__(self, vocab_size, embed_dims):
        super().__init__()
        # Attribute name kept as ``embed`` so state_dict keys remain "embed.weight".
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dims)

    def forward(self, x):
        """Look up the embedding row for every token id in ``x``.

        The result has the shape of ``x`` with a trailing ``embed_dims`` axis.
        """
        return self.embed(x)

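### Positional encoding tells the transformer about the position of each word in the input; the embedding itself carries the word's meaning. We use a series of sine and cosine curves with different frequencies. Each word's position is the x-coordinate, and the y-values of the sine and cosine curves at that x-coordinate form the word's positional encoding, which is added to its embedding.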

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    Implements the encoding from "Attention Is All You Need":

        PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dims))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dims))

    The table is computed once for ``max_seq_len`` positions and registered as
    a (non-trainable) buffer, so it follows the module across devices.

    Args:
        embed_dims: size of each embedding vector. Odd sizes are supported;
            the last (even-indexed) column then holds a sine with no cosine
            partner.
        max_seq_len: the longest sequence ``forward`` accepts.
    """

    def __init__(self, embed_dims, max_seq_len=512):
        super(PositionalEncoding, self).__init__()

        # (max_seq_len, 1): one row per position.
        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        # (ceil(embed_dims / 2),): one frequency per sin/cos column pair.
        div_term = torch.exp(
            torch.arange(0, embed_dims, 2, dtype=torch.float32)
            * -(math.log(10000.0) / embed_dims)
        )
        angles = position * div_term  # (max_seq_len, ceil(embed_dims / 2))

        pos_enc = torch.zeros(max_seq_len, embed_dims)
        pos_enc[:, 0::2] = torch.sin(angles)
        # There are only floor(embed_dims / 2) odd columns -- one fewer than
        # the even columns when embed_dims is odd -- so drop the surplus
        # frequency instead of failing with a shape mismatch.
        pos_enc[:, 1::2] = torch.cos(angles[:, : embed_dims // 2])
        # Leading batch axis of size 1 so the table broadcasts over the batch.
        pos_enc = pos_enc.unsqueeze(0)

        # Register the positional encodings as a buffer so they can be
        # moved to the same device as the model's parameters
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: embeddings laid out as (batch, seq_len, embed_dims); the table
                is sliced along dim 1 and broadcast over dim 0.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        max_len = self.pos_enc.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        return x + self.pos_enc[:, :seq_len]

### Self-attention works by comparing every word in the sentence with every other word. In the transformer, we calculate the similarity between each query and every key as a scaled dot product; larger dot products indicate stronger similarity. We then use a softmax function to determine what percentage of each word's value should be used to encode the query.

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``query``, ``key`` and ``value`` are each linearly projected to
    ``embed_dims`` features and split into ``nhead`` heads of
    ``head_dim = embed_dims // nhead`` features. Each head attends
    independently. The per-head results are then concatenated and projected
    back to ``embed_dims``.

    Args:
        embed_dims: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``embed_dims`` is not divisible by ``nhead``.
    """

    def __init__(self, embed_dims, nhead, dropout=0.1):
        super(MultiheadAttention, self).__init__()

        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        if embed_dims % nhead != 0:
            raise ValueError("embed_dims must be divisible by nhead")
        self.nhead = nhead
        self.head_dim = embed_dims // nhead
        self.embed_dims = embed_dims

        # Projections for query, key and value; each covers all heads at once.
        self.q_linear = nn.Linear(embed_dims, embed_dims)
        self.k_linear = nn.Linear(embed_dims, embed_dims)
        self.v_linear = nn.Linear(embed_dims, embed_dims)

        # Final projection applied after concatenating the heads.
        self.out_linear = nn.Linear(embed_dims, embed_dims)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, len, embed_dims) into (batch, nhead, len, head_dim)."""
        return x.view(batch_size, -1, self.nhead, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (batch, q_len, embed_dims), batch-first.
            key: (batch, k_len, embed_dims).
            value: (batch, k_len, embed_dims).
            mask: optional tensor broadcastable to
                (batch, nhead, q_len, k_len); positions where ``mask == 0``
                receive (effectively) zero attention weight.

        Returns:
            ``(output, attention_weights)``: ``output`` is
            (batch, q_len, embed_dims); ``attention_weights`` is
            (batch, nhead, q_len, k_len), returned after dropout.
        """
        batch_size = query.size(0)

        q = self._split_heads(self.q_linear(query), batch_size)
        k = self._split_heads(self.k_linear(key), batch_size)
        v = self._split_heads(self.v_linear(value), batch_size)

        # Scaled dot-product scores per head: (batch, nhead, q_len, k_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the dtype's most negative finite value rather than a fixed
            # -1e9, which cannot be represented in float16 and makes
            # masked_fill raise an overflow error.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values: (batch, nhead, q_len, head_dim).
        context = torch.matmul(attention_weights, v)

        # Merge heads back: (batch, q_len, nhead, head_dim) -> (batch, q_len, embed_dims).
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.embed_dims)

        output = self.out_linear(context)

        return output, attention_weights