## Attention/Transformer Implementation
1. Core Implementation:
   - Implement the QKV (query/key/value) attention mechanism
   - Build a basic transformer architecture
   - Handle attention masking
   - Implement an attention-free transformer and explain how it works

2. Follow-up Questions:
   - How would you optimize memory usage?
   - What are the tradeoffs in different attention mechanisms?
   - How would you handle different sequence lengths?

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys and values with separate linear layers and splits
    each projection into ``num_heads`` heads of size ``d_model // num_heads``.
    It then applies scaled dot-product attention per head, concatenates the
    heads and applies an output projection.

    Args:
        d_model: embedding dimension of the inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        # explicit check instead of `assert`, which is stripped under `python -O`
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model  # embedding dimension
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # linear layers for q, k, v projections and output
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        # attention scores are divided by sqrt(head_dim)
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, num_heads, seq, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, shape (batch, q_len, d_model).
            k: keys, shape (batch, kv_len, d_model).
            v: values, shape (batch, kv_len, d_model).
            mask: optional tensor in which 0/False marks positions that must not
                be attended to. 2-D (q_len, kv_len) and 4-D masks must broadcast
                to (batch, num_heads, q_len, kv_len). A 3-D mask is read as
                (batch, q_len, kv_len), or (batch, 1, kv_len) for padding, and
                is shared across heads.

        Returns:
            ``(out, attn)``. ``out`` has shape (batch, q_len, d_model). ``attn``
            holds the post-dropout attention weights, shape
            (batch, num_heads, q_len, kv_len). A query row whose keys are all
            masked gets all-zero weights instead of NaN.
        """
        batch_size = q.size(0)

        # project and split q, k, v into heads
        q = self._split_heads(self.q_proj(q), batch_size)
        k = self._split_heads(self.k_proj(k), batch_size)
        v = self._split_heads(self.v_proj(v), batch_size)

        # scaled dot-product attention scores: (batch, heads, q_len, kv_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        blocked = None
        if mask is not None:
            if mask.dim() == 3:
                # insert a head axis so the mask is shared by every head; without
                # it, right-aligned broadcasting would pair the mask's batch axis
                # with the head axis
                mask = mask.unsqueeze(1)
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        if blocked is not None:
            # a fully-masked row is softmax(-inf, ..., -inf) = NaN; zeroing the
            # masked positions removes those NaNs and leaves every other row
            # unchanged, since its masked weights are already exactly 0
            attn = attn.masked_fill(blocked, 0.0)
        attn = self.dropout(attn)

        # weighted sum of values per head
        out = torch.matmul(attn, v)

        # merge heads back to (batch, q_len, d_model) and project
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.out_proj(out)

        return out, attn

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q: queries, shape (..., q_len, d_k).
        k: keys, shape (..., kv_len, d_k).
        v: values, shape (..., kv_len, d_v).
        mask: optional tensor broadcastable to (..., q_len, kv_len); positions
            where it is 0/False are excluded from attention.

    Returns:
        ``(output, attn)`` with shapes (..., q_len, d_v) and (..., q_len, kv_len).
        A query row whose keys are all masked gets all-zero weights, and hence
        a zero output row, instead of NaN.
    """
    # compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    blocked = None
    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))

    attn = F.softmax(scores, dim=-1)
    if blocked is not None:
        # softmax over an all -inf row is NaN; zero the masked positions so such
        # rows become 0 while partially-masked rows are unchanged
        attn = attn.masked_fill(blocked, 0.0)
    return torch.matmul(attn, v), attn

In [None]:
class AttentionFreeTransformer(nn.Module):
    """Single post-LayerNorm transformer-style block with no attention sublayer.

    Adds fixed sinusoidal positional encodings to the input, then applies two
    residual sublayers, each followed by LayerNorm:

    1. two stacked linear projections (``input_proj`` then ``pos_proj``) with
       dropout, used in place of self-attention;
    2. a position-wise feed-forward network (Linear -> ReLU -> Dropout -> Linear).

    NOTE(review): every operation after the positional-encoding add acts on
    each position independently, so no information is exchanged between tokens.
    The Attention Free Transformer of Zhai et al. (2021) does mix tokens: it
    averages the values with weights built from keys plus learned position
    biases. This block does not implement that mechanism. ``input_proj`` and
    ``pos_proj`` are also applied back to back with no nonlinearity, so
    together they are equivalent to a single linear layer.

    Args:
        d_model: model/embedding dimension.
        nhead: stored as ``self.nhead``; not otherwise used by this block.
        num_layers: accepted for signature compatibility but unused; exactly one
            block is built.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability.
        max_len: number of positions with precomputed positional encodings;
            longer inputs raise ``ValueError``.
    """

    def __init__(self, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048,
                 dropout=0.1, max_len=5000):
        super().__init__()
        # main model dimensions
        self.d_model = d_model
        self.nhead = nhead  # kept for API compatibility; unused below

        # linear projections instead of attention
        self.input_proj = nn.Linear(d_model, d_model)
        self.pos_proj = nn.Linear(d_model, d_model)

        # feed forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

        # layer norm and dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # A buffer follows .to()/.cuda()/.half() together with the module.
        # It is non-persistent so state_dict keys are unchanged and
        # existing checkpoints still load.
        self.register_buffer(
            "pos_encoding",
            self._create_positional_encoding(max_len),
            persistent=False,
        )

    def _create_positional_encoding(self, max_len=5000):
        """Return sinusoidal positional encodings of shape (1, max_len, d_model).

        Even feature columns hold sin(pos * f_i) and odd columns cos(pos * f_i),
        with frequencies f_i = 10000^(-2i / d_model).
        """
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # for odd d_model there is one fewer odd column than frequencies
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 3-D, or ``seq_len`` exceeds the number of
                precomputed positional encodings.
        """
        _, seq_len, _ = x.shape
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} positions "
                "with precomputed positional encodings"
            )

        # add positional encoding
        x = x + self.pos_encoding[:, :seq_len, :]

        # first sublayer: linear projections instead of attention
        residual = x
        x = self.input_proj(x)
        x = self.pos_proj(x)
        x = self.dropout(x)
        x = self.norm1(x + residual)

        # second sublayer: feedforward network
        residual = x
        x = self.ffn(x)
        x = self.norm2(x + residual)

        return x

class SimpleAttentionFreeTransformer(nn.Module):
    """Token-level model wrapping a single ``AttentionFreeTransformer`` block.

    Embeds integer token ids, passes them through one
    ``AttentionFreeTransformer`` and projects the result to vocabulary logits.

    Args:
        vocab_size: number of distinct token ids; sets the embedding table size
            and the number of output logits.
        d_model: embedding / model dimension.
        **transformer_kwargs: extra keyword arguments forwarded unchanged to
            ``AttentionFreeTransformer`` (e.g. ``dim_feedforward``, ``dropout``).
            With none given, that class's own defaults apply, as before.
    """

    def __init__(self, vocab_size, d_model=512, **transformer_kwargs):
        super().__init__()
        # embedding layer converts token ids to vectors
        self.embedding = nn.Embedding(vocab_size, d_model)

        # main transformer layer
        self.transformer = AttentionFreeTransformer(d_model=d_model, **transformer_kwargs)

        # output projection back to vocabulary size
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids to vocabulary logits.

        Args:
            x: integer token ids of shape (batch, seq_len).
               ``AttentionFreeTransformer.forward`` unpacks a 3-D tensor after
               the embedding.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        embedded = self.embedding(x)
        hidden = self.transformer(embedded)
        return self.out_proj(hidden)