# AIAYN - Attention Is All You Need
![alt text](images/aiayn/aiayn.png "Architecture of AIAYN")

## Positional Encoding
Since our model contains no recurrence and no convolution, in order for the model to make use of the
order of the sequence, we must inject some information about the relative or absolute position of the
tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the
bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $d_{\text{model}}$
as the embeddings, so that the two can be summed. There are many choices of positional encodings,
learned and fixed [8].
In this work, we use sine and cosine functions of different frequencies:
$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right)$$

$$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right)$$

where pos is the position and i is the dimension. That is, each dimension of the positional encoding
corresponds to a sinusoid. The wavelengths form a geometric progression from 2π to 10000 · 2π. We
chose this function because we hypothesized it would allow the model to easily learn to attend by
relative positions, since for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of
$PE_{pos}$.

In [3]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn

def positional_encoding(length: int, depth: int) -> torch.Tensor:
    """Build a sinusoidal positional-encoding table.

    The first half of the feature axis holds the sine components and the
    second half holds the cosine components (the two halves are concatenated,
    not interleaved).

    Args:
        length: Number of positions (rows) to encode.
        depth: Requested feature size. Only ``depth // 2`` frequencies are
            generated, so the result has ``2 * (depth // 2)`` columns.

    Returns:
        A float32 tensor of shape ``(length, 2 * (depth // 2))``.
    """
    half_depth = depth // 2

    # Exponent i / half_depth, one per frequency: shape (1, half_depth).
    exponents = torch.arange(half_depth, dtype=torch.float32)[None, :] / half_depth
    inverse_freqs = 1 / (10000 ** exponents)

    # Column vector of position indices: shape (length, 1).
    pos = torch.arange(length, dtype=torch.float32)[:, None]

    # Broadcasting gives one angle per (position, frequency) pair.
    phases = pos * inverse_freqs

    return torch.cat([phases.sin(), phases.cos()], dim=-1)

# Precompute a module-level encoding table for 2048 positions at depth 512.
pos_encoding = positional_encoding(length=2048, depth=512)

# TODO: Tricky look at the code again
class PositionalEmbedding(nn.Module):
    """Token embedding scaled by sqrt(d_model) plus sinusoidal positional encoding.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding width. It is also passed as ``depth`` to
            ``positional_encoding``, which only produces ``2 * (d_model // 2)``
            columns, so an even value is needed for the sum in ``forward``.
        max_length: Number of positions precomputed at construction time.
            Longer inputs still work; their table is computed on demand.
    """

    def __init__(self, vocab_size: int, d_model: int, max_length: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # Register as a non-persistent buffer: the table follows the module
        # through .to()/.cuda() but stays out of state_dict, so checkpoints
        # keep the same keys as before.
        self.register_buffer(
            "pos_encoding",
            positional_encoding(length=max_length, depth=d_model),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add positional information.

        Args:
            x: Integer token ids; dimension 1 is read as the sequence length.

        Returns:
            The scaled embeddings with the positional encoding added,
            broadcast over dimension 0.
        """
        length = x.size(1)
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional encoding.
        x = x * (self.d_model ** 0.5)

        table = self.pos_encoding
        if length > table.size(0):
            # Input is longer than the cached table: build one that fits
            # instead of failing on a shape mismatch in the addition below.
            table = positional_encoding(length=length, depth=self.d_model)

        x = x + table[:length, :].unsqueeze(0).to(x.device)
        return x

    def compute_mask(self, *args, **kwargs):
        """Placeholder; currently ignores its arguments and returns None."""
        # TODO: Implement this when needed
        return None

class CommonAttention(nn.Module):
    """Multi-head attention followed by a residual connection and layer norm.

    Args:
        embed_dim: Feature size of the attention layer and the layer norm.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to ``nn.MultiheadAttention``
            (for example ``batch_first`` or ``dropout``).
    """

    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, **kwargs)
        self.layernorm = nn.LayerNorm(embed_dim)
        # Retained only so existing code that inspects `.add` keeps working;
        # the residual sum itself is a plain tensor addition in forward().
        self.add = nn.Identity()

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> torch.Tensor:
        """Attend from ``query`` to ``key``/``value``, add the residual, normalize.

        Args:
            query: Query tensor; also used as the residual input.
            key: Key tensor.
            value: Value tensor.
            **kwargs: Forwarded unchanged to the attention layer's call
                (for example ``attn_mask`` or ``key_padding_mask``).

        Returns:
            ``layernorm(query + attention_output)``.
        """
        attn_output, _ = self.mha(query, key, value, **kwargs)
        # nn.Identity's forward takes a single input, so routing the residual
        # through self.add(query, attn_output) raised a TypeError.
        output = query + attn_output
        return self.layernorm(output)
    
class CrossAttention(CommonAttention):
    """Attention from a sequence ``x`` to a separate ``context`` sequence.

    The attention weights of the most recent forward pass are kept in
    ``last_attn_scores`` (``None`` until the first call).

    Args:
        embed_dim: Feature size of the attention layer and the layer norm.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super().__init__(embed_dim, num_heads, **kwargs)
        self.last_attn_scores = None

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` (queries) to ``context`` (keys and values).

        Args:
            x: Query tensor; also used as the residual input.
            context: Tensor used as both key and value.

        Returns:
            ``layernorm(x + attention_output)``.
        """
        attn_output, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            need_weights=True
        )

        # Cache the attention scores for plotting later. Detach them so the
        # cache does not keep the autograd graph alive between steps and can
        # be converted to NumPy directly by plotting code.
        self.last_attn_scores = attn_scores.detach()

        x = x + attn_output  # Residual connection
        x = self.layernorm(x)

        return x
    
class GlobalSelfAttention(CommonAttention):
    """Unmasked self-attention: the input serves as query, key and value."""

    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super().__init__(embed_dim, num_heads, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``layernorm(x + self_attention(x))``."""
        # The attention weights returned alongside the output are not needed here.
        attended, _weights = self.mha(x, x, x)
        residual_sum = x + attended
        return self.layernorm(residual_sum)
    
class CausalSelfAttention(CommonAttention):
    """Self-attention in which each position may attend only to itself and
    earlier positions.

    Args:
        embed_dim: Feature size of the attention layer and the layer norm.
        num_heads: Number of attention heads.
        **kwargs: Forwarded unchanged to the parent constructor. Pass
            ``batch_first=True`` when inputs are laid out as
            (batch, sequence, feature).
    """

    def __init__(self, embed_dim: int, num_heads: int, **kwargs):
        super().__init__(embed_dim, num_heads, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``layernorm(x + causal_self_attention(x))``.

        Args:
            x: Input tensor used as query, key and value.
        """
        # The sequence axis follows the attention layer's layout: dim 1 for
        # batched batch-first input, dim 0 otherwise (sequence-first or
        # unbatched input).
        batch_first = getattr(self.mha, "batch_first", False)
        seq_dim = 1 if (x.dim() == 3 and batch_first) else 0
        seq_len = x.size(seq_dim)

        # nn.MultiheadAttention accepts a 2-D (L, S) or 3-D mask, not the
        # 4-D one built previously. In a boolean mask True means "may NOT
        # attend", so negate the lower triangle to block future positions.
        allowed = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=x.device))
        attn_mask = ~allowed

        attn_output, _ = self.mha(
            query=x,
            key=x,
            value=x,
            attn_mask=attn_mask
        )
        x = x + attn_output  # Residual connection
        x = self.layernorm(x)
        return x
    