# Attention is all you need

Paper: [Attention Is All You Need. Vaswani et al 2017](https://arxiv.org/abs/1706.03762)

In [1]:
import torch
import torch.nn as nn

## Attention

$$ 
Attention (Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V 
$$

### Scaled Dot-Product Attention

<img src="assets/scaled_dotptoduct_attention.png" width="400" height="400">

In [41]:
batch_size = 1
num_heads = 2
seq_leng_q = 10
d_k = 10

query = torch.rand(batch_size, num_heads, seq_leng_q, d_k)
key = torch.rand(batch_size, num_heads, seq_leng_q, d_k)
value = torch.rand(batch_size, num_heads, seq_leng_q, d_k)

print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape{value.shape}")

print(f"Key transposed shape {key.transpose(-2, -1).shape}")

scores = torch.matmul(query, key.transpose(-2, -1)) / d_k ** 0.5
print(f"QK^t shape: {scores.shape}")

attention_weights = nn.Softmax(dim=-1)(scores)
print(f"Attention weights shape: {attention_weights.shape}")

Query shape: torch.Size([1, 2, 10, 10])
Key shape: torch.Size([1, 2, 10, 10])
Value shapetorch.Size([1, 2, 10, 10])
Key transposed shape torch.Size([1, 2, 10, 10])
QK^t shape: torch.Size([1, 2, 10, 10])
Attention weights shape: torch.Size([1, 2, 10, 10])


In [42]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: Tensor of shape (batch_size, num_heads, seq_len_q, d_k)
            key: Tensor of shape (batch_size, num_heads, seq_len_k, d_k)
            value: Tensor of shape (batch_size, num_heads, seq_len_v, d_v) 
                   Typically, seq_len_k = seq_len_v
            mask: Tensor of shape (batch_size, 1, 1, seq_len_k)
        Returns:
            output: Attention values of shape (batch_size, num_heads, seq_len_q, d_v)
            attention_weights: Tensor of shape (batch_size, num_heads, seq_len_q, seq_len_k)
        """

        # Compute the dot products
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply the mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Compute the attention weights
        attention_weights = self.softmax(scores)
        attention_weights = self.dropout(attention_weights)

        # Compute the output
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

scaled_dotproduct_attention = ScaledDotProductAttention(d_k)
output, attention_weights = scaled_dotproduct_attention.forward(query, key, value)

print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

Output shape: torch.Size([1, 2, 10, 10])
Attention weights shape: torch.Size([1, 2, 10, 10])


### 3.2.2 Multi-Head Attention

<img src="assets/multi_head_attention.png" width="300" height="400">

Instead of performing a single attention function with dmodel-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding dv -dimensional


In [44]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention mechanism as described in the "Attention Is All You Need" paper.
    
    Args:
        d_model (int): The dimension of the input and output of the multi-head attention layer. 
                       It should be a multiple of num_heads.
        num_heads (int): The number of parallel attention layers, or "heads".
        dropout (float, optional): Dropout rate for the attention weights. Default is 0.1.
    
    Attributes:
        d_model (int): The dimension of the input and output.
        num_heads (int): Number of attention heads.
        d_k (int): Dimension of the key, query, and value for each head.
        W_q (nn.Linear): Linear transformation for the query.
        W_k (nn.Linear): Linear transformation for the key.
        W_v (nn.Linear): Linear transformation for the value.
        W_o (nn.Linear): Linear transformation for the output.
        attention (ScaledDotProductAttention): The scaled dot product attention mechanism.
        dropout (nn.Dropout): Dropout layer for the attention weights.
    """
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0 # d_model must be divisible by num_heads

        self.d_model = d_model          # d_model is the dimension of the model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads # d_k is the dimension of the keys and values

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k, dropout)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x, batch_size):
        """
        Split the last dimension of tensor x into (num_heads, d_k).
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            batch_size (int): Batch size.
            
        Returns:
            torch.Tensor: Reshaped tensor of shape (batch_size, num_heads, seq_len, d_k).
        """
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for the multi-head attention mechanism.
        
        Args:
            query (torch.Tensor): Query tensor of shape (batch_size, seq_len_q, d_model).
            key (torch.Tensor): Key tensor of shape (batch_size, seq_len_k, d_model).
            value (torch.Tensor): Value tensor of shape (batch_size, seq_len_v, d_model).
            mask (torch.Tensor, optional): Mask tensor of shape (batch_size, 1, 1, seq_len_k).
            
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len_q, d_model).
            torch.Tensor: Attention weights tensor of shape (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = query.size(0)

        # Linear transformations
        query = self.W_q(query)
        key = self.W_k(key)
        value = self.W_v(value)

        # Split into multiple heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        # Scale the dot product attention
        output, attention_weights = self.attention.forward(query, key, value, mask)

        # Concatenate heads and transform
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(output)

        return output, attention_weights

multihead_attention = MultiHeadAttention(d_model=10, num_heads=num_heads)
output, attention_weights = multihead_attention.forward(query, key, value)

print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attention_weights.shape}")

Output shape: torch.Size([1, 20, 10])
Attention weights shape: torch.Size([1, 2, 20, 20])


### Encoder

<img src="assets/encoder.png" width="250" height="400">


In [45]:
class EncoderLayer(nn.Module):
    """
    Represents a single encoder layer in the Transformer architecture.
    
    Each layer consists of:
    1. Multi-head self-attention mechanism.
    2. Position-wise feed-forward network.
    
    Additionally, each of these components is surrounded by a residual connection 
    followed by layer normalization.
    
    Args:
        d_model (int): The dimension of the input and output of the encoder layer.
        num_heads (int): Number of attention heads for the multi-head attention mechanism.
        d_ff (int): Dimension of the feed-forward network's hidden layer.
        dropout (float, optional): Dropout rate for the attention weights and feed-forward network. Default is 0.1.
    
    Attributes:
        multihead_attention (MultiHeadAttention): The multi-head attention mechanism.
        feed_forward (nn.Sequential): The position-wise feed-forward network.
        norm1 (nn.LayerNorm): Layer normalization for the attention mechanism's output.
        norm2 (nn.LayerNorm): Layer normalization for the feed-forward network's output.
        dropout (nn.Dropout): Dropout layer.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()

        self.multihead_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Forward pass for the encoder layer.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            mask (torch.Tensor, optional): Mask tensor of shape (batch_size, 1, 1, seq_len).
            
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """

        # Multi-head attention
        attention_output, _ = self.multihead_attention.forward(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention_output))

        # Feed-forward network
        feed_forward_output = self.feed_forward.forward(x)
        return self.norm2(x + self.dropout(feed_forward_output))

In [46]:
class TransformerEncoder(nn.Module):
    """
    Represents the encoder component of the Transformer architecture.
    
    The encoder consists of a stack of identical layers, where each layer has 
    a multi-head self-attention mechanism and a position-wise feed-forward network.
    
    Args:
        d_model (int): The dimension of the input and output of the encoder.
        num_heads (int): Number of attention heads for the multi-head attention mechanism.
        d_ff (int): Dimension of the feed-forward network's hidden layer.
        num_layers (int): Number of identical layers in the encoder.
        dropout (float, optional): Dropout rate for the attention weights and feed-forward network. Default is 0.1.
    
    Attributes:
        layers (nn.ModuleList): List of encoder layers.
    """
    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """
        Forward pass for the Transformer encoder.
        
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            mask (torch.Tensor, optional): Mask tensor of shape (batch_size, 1, 1, seq_len).
            
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """
        for layer in self.layers:
            x = layer(x, mask)
        return x

In [59]:
batch_size = 1
num_heads = 1
seq_leng_q = 10    # This is the length of the query (q) sequence.
d_k = 10           # This is the dimension of the keys (k), queries (q), and values (v) in the scaled dot-product attention mechanism.
d_model = 10       # This is the dimension of the embeddings and the input and output size of the Transformer's encoder and decoder layers.
d_ff = 10          # This is the dimension of the hidden layer in the feed-forward network.

# The relationship between them is often d_k = d_model / num_heads
assert d_model % num_heads == 0 # d_model must be divisible by num_heads

query = torch.rand(batch_size, num_heads, seq_leng_q, d_k)
key = torch.rand(batch_size, num_heads, seq_leng_q, d_k)
value = torch.rand(batch_size, num_heads, seq_leng_q, d_k)

print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape{value.shape}")

encoder = TransformerEncoder(d_model=d_model, num_heads=num_heads, d_ff=d_ff, num_layers=1)
output = encoder.forward(query)

print(f"Output shape: {output.shape}")

Query shape: torch.Size([1, 1, 10, 10])
Key shape: torch.Size([1, 1, 10, 10])
Value shapetorch.Size([1, 1, 10, 10])
Output shape: torch.Size([1, 1, 10, 10])


### Example of a classifier using transformers

In [69]:
class TransformerClassifier(nn.Module):
    """
    Transformer-based classifier.
    
    Args:
        d_model (int): The dimension of the input and output of the Transformer encoder.
        num_heads (int): Number of attention heads for the multi-head attention mechanism.
        d_ff (int): Dimension of the feed-forward network's hidden layer.
        num_layers (int): Number of identical layers in the encoder.
        num_classes (int): Number of target classes for classification.
        dropout (float, optional): Dropout rate for the attention weights and feed-forward network. Default is 0.1.
    
    Attributes:
        encoder (TransformerEncoder): The Transformer encoder.
        classifier (nn.Linear): Linear layer for classification.
    """
    def __init__(self, d_model, num_heads, d_ff, num_layers, num_classes, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        
        self.encoder = TransformerEncoder(d_model, num_heads, d_ff, num_layers, dropout)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x, mask=None):
       """
       Forward pass for the Transformer classifier.
       
       Args:
           x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
           mask (torch.Tensor, optional): Mask tensor of shape (batch_size, 1, 1, seq_len).
           
       Returns:
           torch.Tensor: Output tensor of shape (batch_size, num_classes).
       """
       # Get the encoder's output
       encoder_output = self.encoder(x, mask)   
       return self.classifier(encoder_output)

In [71]:
imput_tensor = torch.rand(batch_size, seq_leng_q, d_model)

model = TransformerClassifier(d_model=d_model, num_heads=num_heads, d_ff=d_ff, num_layers=1, num_classes=2)
output = model.forward(imput_tensor)

output.shape

torch.Size([1, 10, 2])