# Aiayn - Attention is all you Need
![alt text](images/aiayn/aiayn.png "Architecture of AIAYN")

## Positional Encoding
Since our model contains no recurrence and no convolution, in order for the model to make use of the
order of the sequence, we must inject some information about the relative or absolute position of the
tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the
bottoms of the encoder and decoder stacks. The positional encodings have the same dimension dmodel
as the embeddings, so that the two can be summed. There are many choices of positional encodings,
learned and fixed [8].
In this work, we use sine and cosine functions of different frequencies:
P E(pos,2i) = sin(pos/100002i/dmodel)
P E(pos,2i+1) = cos(pos/100002i/dmodel)

where pos is the position and i is the dimension. That is, each dimension of the positional encoding
corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We
chose this function because we hypothesized it would allow the model to easily learn to attend by
relative positions, since for any fixed offset k, P Epos+k can be represented as a linear function of
P Epos.

In [23]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn

def positional_encoding(length: int, depth: int) -> torch.Tensor:
    depth = depth // 2
    positions = torch.arange(length, dtype=torch.float32).unsqueeze(1)  # (seq, 1)
    depths = torch.arange(depth, dtype=torch.float32).unsqueeze(0) / depth  # (1, depth)

    angle_rates = 1 / (10000 ** depths)  # (1, depth)
    angle_rads = positions * angle_rates  # (pos, depth)

    pos_encoding = torch.cat((torch.sin(angle_rads), torch.cos(angle_rads)), dim=-1)  # (pos, depth*2)
    return pos_encoding

class PositionalEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super(PositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        length = x.size(1)
        x = self.embedding(x)
        x *= torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        x = x + self.pos_encoding[:length, :].unsqueeze(0).to(x.device)
        return x
    
    def compute_mask(self, *args, **kwargs):
        # Implement this if needed
        # TODO: Implement this
        pass

class CommonAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super(CommonAttention, self).__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, **kwargs)
        self.layernorm = nn.LayerNorm(embed_dim)
        self.add = nn.Identity()

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> torch.Tensor:
        attn_output, _ = self.mha(query, key, value, **kwargs)
        output = self.add(query + attn_output)
        output = self.layernorm(output)
        return output
    
class CrossAttention(CommonAttention):
    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super(CrossAttention, self).__init__(embed_dim, num_heads, **kwargs)
        self.last_attn_scores = None

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        attn_output, attn_scores = self.mha(query=x, key=context, value=context, need_weights=True)
        self.last_attn_scores = attn_scores
        x = x + attn_output
        x = self.layernorm(x)
        return x
    
class GlobalSelfAttention(CommonAttention):
    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super(GlobalSelfAttention, self).__init__(embed_dim, num_heads, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_output, _ = self.mha(query=x, key=x, value=x)
        x = x + attn_output
        x = self.layernorm(x)
        return x
    
class CausalSelfAttention(CommonAttention):
    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        self.num_heads = num_heads
        super(CausalSelfAttention, self).__init__(embed_dim, num_heads, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()
        attn_mask = torch.tril(torch.ones((seq_len, seq_len), device=x.device)).unsqueeze(0).unsqueeze(0)
        attn_mask = attn_mask.expand(batch_size * self.num_heads, seq_len, seq_len)
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
        attn_output, _ = self.mha(query=x, key=x, value=x, attn_mask=attn_mask)
        x = x + attn_output
        x = self.layernorm(x)
        return x

class FeedForward(nn.Module):
    def __init__(self, d_model: int, dff: int, dropout_rate: float = 0.1):
        super(FeedForward, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
            nn.Dropout(dropout_rate)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.seq(x)
        x = self.layer_norm(x)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, *, d_model: int, num_heads: int, dff: int, dropout_rate: float = 0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = GlobalSelfAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout_rate)
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.self_attention(x)
        x = self.ffn(x)
        return x
    
class Encoder(nn.Module):
    def __init__(self, *, num_layers: int, d_model: int, num_heads: int,
                 dff: int, vocab_size: int, dropout_rate: float = 0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model)
        self.enc_layers = nn.ModuleList([
            EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x)
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, *, d_model: int, num_heads: int, dff: int, dropout_rate: float = 0.1):
        super(DecoderLayer, self).__init__()
        self.causal_self_attention = CausalSelfAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout_rate)
        self.cross_attention = CrossAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout_rate)
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        x = self.causal_self_attention(x)
        x = self.cross_attention(x, context)
        self.last_attn_scores = self.cross_attention.last_attn_scores
        x = self.ffn(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self, *, num_layers: int, d_model: int, num_heads: int, dff: int, vocab_size: int, dropout_rate: float = 0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.dec_layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)
        ])
        self.last_attn_scores = None

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.dec_layers[i](x, context)
        self.last_attn_scores = self.dec_layers[-1].last_attn_scores
        return x
    
class Transformer(nn.Module):
    def __init__(self, *, num_layers: int, d_model: int, num_heads: int, dff: int,
                 input_vocab_size: int, target_vocab_size: int, dropout_rate: float = 0.1):
        super(Transformer, self).__init__()
        self.encoder = Encoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                               dff=dff, vocab_size=input_vocab_size, dropout_rate=dropout_rate)
        self.decoder = Decoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                               dff=dff, vocab_size=target_vocab_size, dropout_rate=dropout_rate)
        self.final_layer = nn.Linear(d_model, target_vocab_size)

    def forward(self, inputs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        context, x = inputs
        context = self.encoder(context)
        x = self.decoder(x, context)
        logits = self.final_layer(x)
        return logits

In [24]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

# Load data
data = pd.read_csv('spa.csv', delimiter='\t', header=None, names=['input', 'output', 'metadata'])

# Tokenizer (using BERT tokenizer for simplicity)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize data
data['input_ids'] = data['input'].apply(lambda x: tokenizer.encode(x, add_special_tokens=True))
data['output_ids'] = data['output'].apply(lambda x: tokenizer.encode(x, add_special_tokens=True))

class TranslationDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.df.iloc[idx]['input_ids'])
        output_ids = torch.tensor(self.df.iloc[idx]['output_ids'])
        return input_ids, output_ids

dataset = TranslationDataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)

In [25]:
import torch.optim as optim

model = Transformer(
    num_layers=2,
    d_model=512,
    num_heads=8,
    dff=2048,
    input_vocab_size=len(tokenizer.vocab),
    target_vocab_size=len(tokenizer.vocab),
    dropout_rate=0.1
)

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [26]:
def train_step(batch, model, criterion, optimizer):
    model.train()
    inputs, targets = zip(*batch)
    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)
    targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    
    inputs, targets = inputs.to(device), targets.to(device)
    
    optimizer.zero_grad()
    output = model((inputs, targets[:, :-1]))  # inputs and target sequences shifted for teacher forcing
    loss = criterion(output.view(-1, output.size(-1)), targets[:, 1:].contiguous().view(-1))
    
    loss.backward()
    optimizer.step()
    
    return loss.item()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(10):  # 10 epochs for example
    total_loss = 0
    for batch in dataloader:
        loss = train_step(batch, model, criterion, optimizer)
        total_loss += loss
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}')

RuntimeError: expand(torch.cuda.FloatTensor{[1, 1, 29, 29]}, size=[256, 29, 29]): the number of sizes provided (3) must be greater or equal to the number of dimensions in the tensor (4)