# Attention mechanisms and positional encoding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Hands-on positional encoding

In [None]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to a batch of embeddings.

  Even feature indices get sin(pos / 10000^(2i/d_model)); odd indices get the
  matching cos. The table is precomputed once for ``max_seq_length`` positions
  and registered as a buffer named ``pe`` of shape (1, max_seq_length, d_model),
  so it moves with the module between devices but is not a trainable parameter.

  Args:
    d_model: embedding width. Odd widths are supported (the last column is a sine).
    max_seq_length: largest sequence length ``forward`` can encode.
  """
  def __init__(self, d_model, max_seq_length=512):
    super().__init__()
    self.d_model = d_model
    self.max_seq_length = max_seq_length
    pe = torch.zeros(max_seq_length, d_model)
    position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
    # 1 / 10000^(2i/d_model) for each even index 2i, computed in log space.
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # With an odd d_model there is one fewer odd column than even column, so the
    # frequencies are trimmed to fit; for even d_model this slice is the full tensor.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    Assumes ``x`` is (batch, seq_len, d_model) with seq_len <= max_seq_length.
    """
    return x + self.pe[:, :x.size(1)]

## Implementing multi-headed self-attention

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are linearly projected, split into ``num_heads``
  heads of width ``d_model // num_heads``, attended independently, then
  concatenated and projected back to ``d_model``. Internally the heads are
  folded into the batch dimension, so attention scores have shape
  (batch_size * num_heads, query_len, key_len).

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """
  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    self.d_model = d_model
    self.head_dim = d_model // num_heads
    self.query_linear = nn.Linear(d_model, d_model)
    self.key_linear = nn.Linear(d_model, d_model)
    self.value_linear = nn.Linear(d_model, d_model)
    self.output_linear = nn.Linear(d_model, d_model)

  def split_heads(self, x, batch_size):
    """Reshape (batch, seq, d_model) -> (batch * num_heads, seq, head_dim)."""
    x = x.view(batch_size, -1, self.num_heads, self.head_dim)
    return x.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.head_dim)

  def compute_attention(self, query, key, mask=None):
    """Return softmax(Q K^T / sqrt(head_dim)) over the key axis.

    ``query`` is (N, Lq, head_dim) and ``key`` is (N, Lk, head_dim). ``mask``
    must broadcast to (N, Lq, Lk); positions where ``mask == 0`` are blocked.
    """
    # Transpose only the last two dims so every (batch*head) slice is paired
    # with its own keys; scaling keeps the softmax out of its saturated regime.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float("-1e9"))
    return F.softmax(scores, dim=-1)

  def forward(self, query, key, value, mask=None):
    """Attend from ``query`` (B, Lq, d_model) over ``key``/``value`` (B, Lk, d_model).

    Returns a tensor of shape (B, Lq, d_model).
    """
    batch_size = query.size(0)
    query = self.split_heads(self.query_linear(query), batch_size)
    key = self.split_heads(self.key_linear(key), batch_size)
    value = self.split_heads(self.value_linear(value), batch_size)
    attention_weights = self.compute_attention(query, key, mask)
    output = torch.matmul(attention_weights, value)
    # Undo the head split: (B*H, Lq, hd) -> (B, Lq, H*hd).
    output = output.view(batch_size, self.num_heads, -1, self.head_dim).permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
    return self.output_linear(output)

# Building an encoder transformer

## Post-attention feed-forward layer

In [None]:
class FeedForwardSubLayer(nn.Module):
  """Position-wise feed-forward network: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

  Both linear layers act on the last dimension, so the same transform is
  applied at every position of the input.
  """
  def __init__(self, d_model, d_ff):
    super().__init__()
    # Expand to the hidden width, then project back to the model width.
    self.fc1 = nn.Linear(d_model, d_ff)
    self.fc2 = nn.Linear(d_ff, d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    hidden = self.fc1(x)
    hidden = self.relu(hidden)
    return self.fc2(hidden)

## Time for an encoder layer

In [None]:
class EncoderLayer(nn.Module):
  """Post-norm transformer encoder layer.

  Two sub-layers, each followed by dropout, a residual connection and a
  LayerNorm: multi-head self-attention, then the position-wise feed-forward net.
  """
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super().__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    # A single Dropout module is shared by both residual branches.
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    # Sub-layer 1: self-attention, with queries, keys and values all taken from x.
    residual = x
    x = self.self_attn(x, x, x, mask)
    x = self.norm1(residual + self.dropout(x))
    # Sub-layer 2: position-wise feed-forward.
    residual = x
    x = self.feed_forward(x)
    return self.norm2(residual + self.dropout(x))

## Encoder transformer body and head

In [None]:
class TransformerEncoder(nn.Module):
  """Encoder body: token embedding + positional encoding, then ``num_layers`` EncoderLayers.

  Returns the final hidden states; no classification head is applied here.
  """
  def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_sequence_length)
    # Stack of identically configured, independently parameterised encoder layers.
    self.layers = nn.ModuleList(
      EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
    )

  def forward(self, x, mask):
    hidden = self.positional_encoding(self.embedding(x))
    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, mask)
    return hidden

class ClassifierHead(nn.Module):
  """Sequence-classification head that reads the hidden state at position 0.

  Returns log-probabilities over ``num_classes`` (log-softmax on the last dim).
  """
  def __init__(self, d_model, num_classes):
    super().__init__()
    self.fc = nn.Linear(d_model, num_classes)

  def forward(self, x):
    # x is indexed as (batch, seq, d_model); only the first position is used.
    first_position = x[:, 0, :]
    logits = self.fc(first_position)
    return logits.log_softmax(dim=-1)

## Testing the encoder transformer

In [None]:
# Hyperparameters. These module-level names are reused by the decoder and
# encoder-decoder cells further down.
num_classes = 3
vocab_size = 10000
batch_size = 8
d_model = 512
num_heads = 8  # d_model must split evenly across heads: 512 // 8 = 64 dims per head
num_layers = 6
d_ff = 2048
sequence_length = 256
dropout = 0.1

# Random token ids, shape (batch_size, sequence_length).
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))
# Random 0/1 attention mask, shape (sequence_length, sequence_length).
# MultiHeadAttention blocks every score where mask == 0, and this 2-D mask
# broadcasts over its flattened (batch * num_heads) dimension. It is a demo
# stand-in, not a padding mask derived from the tokens.
mask = torch.randint(0, 2, (sequence_length, sequence_length))

# Instantiate the encoder transformer's body and head.
# NOTE(review): this positional order matches the first TransformerEncoder
# definition (..., num_layers, num_heads, d_ff, dropout, max_sequence_length).
# The later redefinition in this notebook takes num_heads before num_layers and
# has no max_sequence_length parameter, so re-running this cell after that one
# raises a TypeError.
encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)
classifier = ClassifierHead(d_model, num_classes)

# Forward pass: encoder hidden states, then per-sequence log-probabilities of
# shape (batch_size, num_classes).
output = encoder(input_sequence, mask)
classification = classifier(output)
print("Classification outputs for a batch of ", batch_size, "sequences:")
print(classification)

# Building a decoder transformer

## Building a decoder body and head

In [None]:
class TransformerDecoder(nn.Module):
  """Decoder body plus a next-token prediction head.

  Tokens are embedded, position-encoded and passed through ``num_layers``
  DecoderLayers. A final linear layer maps to ``vocab_size`` logits, which are
  returned as log-probabilities of shape (batch, seq_len, vocab_size).

  When ``encoder_output`` is given, each layer also cross-attends to it
  (encoder-decoder mode). Otherwise the layers are called with the causal mask
  only (decoder-only mode).
  """
  def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_sequence_length)
    self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    # Linear head for next-word prediction.
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, x, self_mask, encoder_output=None, cross_mask=None):
    """Run the decoder.

    Args:
      x: target token ids.
      self_mask: causal mask for self-attention (positions where mask == 0 are blocked).
      encoder_output: optional encoder hidden states to cross-attend to.
      cross_mask: optional mask for cross-attention; only used with ``encoder_output``.
    """
    x = self.embedding(x)
    x = self.positional_encoding(x)
    for layer in self.layers:
      if encoder_output is None:
        x = layer(x, self_mask)
      else:
        # DecoderLayer.forward takes (x, causal_mask, encoder_output, cross_mask).
        x = layer(x, self_mask, encoder_output, cross_mask)

    x = self.fc(x)
    return F.log_softmax(x, dim=-1)

## Testing the decoder transformer

In [None]:
# Fresh random token ids, shape (batch_size, sequence_length).
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))

# Causal attention mask, shape (1, sequence_length, sequence_length).
# 1 - triu(ones, diagonal=1) is lower-triangular: True on and below the
# diagonal, False above it. MultiHeadAttention blocks entries where
# mask == 0, so each position attends only to itself and earlier positions.
self_attention_mask = (1 - torch.triu(torch.ones(1, sequence_length, sequence_length), diagonal=1)).bool()  # Lower-triangular keep-mask

# Instantiate the decoder transformer
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)

# Decoder-only pass: no encoder output is supplied. The result holds
# log-probabilities over the vocabulary, shape (batch_size, sequence_length, vocab_size).
# NOTE(review): this relies on the DecoderLayer in scope accepting just
# (x, mask); confirm against the DecoderLayer definition active when this runs.
output = decoder(input_sequence, self_attention_mask)
print(output.shape)
print(output)

# Building an encoder-decoder transformer

## Incorporating cross-attention in a decoder

In [None]:
class DecoderLayer(nn.Module):
  """Post-norm transformer decoder layer.

  Sub-layers, each followed by dropout, a residual connection and a LayerNorm:
    1. causal (masked) self-attention;
    2. cross-attention over ``encoder_output``, skipped when it is None;
    3. the position-wise feed-forward network.
  Skipping cross-attention lets the same layer serve a decoder-only stack.
  """
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super().__init__()

    # Causal self-attention and cross-attention each get their own projections.
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.cross_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, causal_mask, encoder_output=None, cross_mask=None):
    """Apply the layer.

    Args:
      x: decoder hidden states.
      causal_mask: self-attention mask (positions where mask == 0 are blocked).
      encoder_output: optional encoder hidden states used as cross-attention keys/values.
      cross_mask: optional cross-attention mask; only used with ``encoder_output``.
    """
    self_attn_output = self.self_attn(x, x, x, causal_mask)
    x = self.norm1(x + self.dropout(self_attn_output))
    if encoder_output is not None:
      # Queries come from the decoder; keys and values come from the encoder.
      cross_attn_output = self.cross_attn(x, encoder_output, encoder_output, cross_mask)
      x = self.norm2(x + self.dropout(cross_attn_output))
    ff_output = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_output))
    return x

## Trying out an encoder-decoder transformer

In [None]:
# Create a batch of random input sequences
input_sequence = torch.randint(0, vocab_size, (batch_size, sequence_length))
# Random 0/1 stand-in for a padding mask (1 = may attend, 0 = blocked). Its
# (sequence_length, sequence_length) shape broadcasts over the flattened
# batch*heads dimension inside MultiHeadAttention. It is also reused as the
# cross-attention mask below, which only works because source and target
# sequences have the same length in this demo.
padding_mask = torch.randint(0, 2, (sequence_length, sequence_length))
# Causal mask in the same "mask == 0 -> blocked" convention: ones on and below
# the diagonal, so each target position attends only to itself and earlier
# positions. (torch.triu(..., diagonal=1) would do the opposite: block the past
# and the current token, expose only the future, and fully mask the last row.)
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))

# Instantiate the two transformer bodies
encoder = TransformerEncoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)
decoder = TransformerDecoder(vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length=sequence_length)

# Pass the necessary masks as arguments to the encoder and the decoder
encoder_output = encoder(input_sequence, padding_mask)
decoder_output = decoder(input_sequence, causal_mask, encoder_output, padding_mask)
print("Batch's output shape: ", decoder_output.shape)

## Transformer assembly bottom-up

In [None]:
# Initialize positional encoding layer and stack of EncoderLayer modules
class TransformerEncoder(nn.Module):
  """Encoder body with dropout on the embedded, position-encoded input.

  NOTE: the positional argument order here is (vocab_size, d_model, num_heads,
  num_layers, d_ff, max_seq_len, dropout). This differs from the earlier
  TransformerEncoder in this notebook, which takes num_layers before num_heads
  and dropout before max_sequence_length. Prefer keyword arguments when calling.
  """
  def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
    self.layers = nn.ModuleList([
      EncoderLayer(d_model, num_heads, d_ff, dropout)
      for _ in range(num_layers)
    ])
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    hidden = self.dropout(self.positional_encoding(self.embedding(x)))
    # Each layer refines the hidden states under the same attention mask.
    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, mask)
    return hidden

class Transformer(nn.Module):
  """Transformer wrapper that currently consists of the encoder stack only.

  ``forward`` returns the encoder's hidden states for ``src`` under ``src_mask``.
  """
  def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout):
    super().__init__()
    self.encoder = TransformerEncoder(
      vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout
    )

  def forward(self, src, src_mask):
    return self.encoder(src, src_mask)