# Attention mechanisms and positional encoding

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Hands-on positional encoding

In [6]:
class PositionalEncoder(nn.Module):
  """Add fixed sinusoidal positional encodings to a batch of embeddings.

  For position ``pos`` and frequency index ``i``, even feature columns
  ``2i`` hold ``sin(pos / 10000^(2i/d_model))`` and odd columns ``2i+1``
  hold ``cos(pos / 10000^(2i/d_model))``. The table is precomputed once
  for ``max_seq_length`` positions.

  Args:
    d_model: Size of the feature (last) dimension of the inputs. May be odd.
    max_seq_length: Number of positions to precompute; longer inputs are
      rejected by ``forward``.
  """

  def __init__(self, d_model, max_seq_length=512):
    super(PositionalEncoder, self).__init__()
    self.d_model = d_model
    self.max_seq_length = max_seq_length
    pe = torch.zeros(max_seq_length, d_model)
    position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
    # 10000^(-2i/d_model), written as exp(-2i * ln(10000) / d_model).
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # For odd d_model there is one fewer odd column than there are
    # frequencies, so the cosine side uses only the first d_model // 2.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    # Leading axis of size 1 lets the table broadcast over the batch.
    # A buffer is saved with the module state but is not a trainable parameter.
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    Args:
      x: Tensor whose dim 1 is the position axis and whose last dim is
        ``d_model`` (e.g. ``(batch, seq, d_model)``).

    Raises:
      ValueError: If ``x.size(1)`` exceeds ``max_seq_length``.
    """
    seq_len = x.size(1)
    if seq_len > self.max_seq_length:
      raise ValueError(
          f"sequence length {seq_len} exceeds max_seq_length {self.max_seq_length}")
    return x + self.pe[:, :seq_len]

## Implementing multi-head self-attention

In [7]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Query, key and value are each projected by their own linear layer, split
  into ``num_heads`` heads of size ``d_model // num_heads``, attended per
  head with softmax(Q K^T / sqrt(head_dim)) V, merged back to ``d_model``
  features and passed through an output projection.

  Args:
    d_model: Size of the last dimension of the inputs and of the output.
    num_heads: Number of attention heads; must divide ``d_model``.

  Raises:
    ValueError: If ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    if d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    self.d_model = d_model
    self.head_dim = d_model // num_heads
    self.query_linear = nn.Linear(d_model, d_model)
    self.key_linear = nn.Linear(d_model, d_model)
    self.value_linear = nn.Linear(d_model, d_model)
    self.output_linear = nn.Linear(d_model, d_model)

  def split_heads(self, x, batch_size):
    """Reshape ``(batch, seq, d_model)`` into ``(batch * num_heads, seq, head_dim)``."""
    x = x.view(batch_size, -1, self.num_heads, self.head_dim)
    # Move the head axis next to the batch axis and fold the two together so
    # all heads are processed by a single batched matmul.
    return x.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.head_dim)

  def compute_attention(self, query, key, mask=None):
    """Return softmax attention weights of shape ``(batch * num_heads, q_len, k_len)``.

    Args:
      query: ``(batch * num_heads, q_len, head_dim)`` tensor from ``split_heads``.
      key: ``(batch * num_heads, k_len, head_dim)`` tensor from ``split_heads``.
      mask: Optional tensor broadcastable to ``(batch * num_heads, q_len, k_len)``.
        Positions where ``mask == 0`` receive (effectively) zero weight.
    """
    # Transpose only the last two axes: (B*H, k_len, hd) -> (B*H, hd, k_len).
    # Dividing by sqrt(head_dim) keeps the logits' scale independent of the head size.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
    if mask is not None:
      # The most negative value of the scores' dtype, unlike a fixed -1e9,
      # does not overflow in float16.
      scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    return F.softmax(scores, dim=-1)

  def forward(self, query, key, value, mask=None):
    """Attend from ``query`` over ``key``/``value`` and return ``(batch, q_len, d_model)``.

    Args:
      query: ``(batch, q_len, d_model)`` tensor.
      key: ``(batch, k_len, d_model)`` tensor.
      value: ``(batch, k_len, d_model)`` tensor.
      mask: Optional mask, forwarded to ``compute_attention``.
    """
    batch_size = query.size(0)
    query = self.split_heads(self.query_linear(query), batch_size)
    key = self.split_heads(self.key_linear(key), batch_size)
    value = self.split_heads(self.value_linear(value), batch_size)
    attention_weights = self.compute_attention(query, key, mask)
    output = torch.matmul(attention_weights, value)
    # Undo split_heads: (B*H, q_len, hd) -> (B, H, q_len, hd) -> (B, q_len, d_model).
    output = output.view(batch_size, self.num_heads, -1, self.head_dim).permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
    return self.output_linear(output)