# Self-study Try-it 20.1: Build an attention mechanism using PyTorch

Attention mechanisms enable models to focus on the most relevant parts of input data, improving understanding and prediction in tasks like translation and speech recognition. By assigning varying importance to tokens, attention enhances context-awareness. Variable-length sequences are managed through masking, which prevents padded or future tokens from influencing the output. Together, attention and masking form a core part of transformer architectures.


The basic attention mechanism is the foundational building block of transformer models: it lets them weigh the importance of different parts of the input sequence dynamically and thereby capture contextual relationships effectively. Often extended to multi-head attention, this mechanism is central to the power and success of modern transformer-based AI systems.

In [None]:
# import the necessary libraries
import torch
import torch.nn.functional as F



In the code below, a simple attention mechanism is defined in PyTorch. It works in three steps:
- It computes attention scores as the dot product of the query and key tensors.
- It divides these scores by $\sqrt{d_k}$, the square root of the key dimension. This keeps their magnitude from growing with the embedding size and stops the softmax from saturating.
- It applies a softmax to obtain attention weights.

These weights are then used to compute a weighted sum of the value tensor, producing an output that highlights the most relevant information in the input sequence.


In [None]:
def simple_attention(query, key, value):
    """Single-head scaled dot-product attention.

    Args:
        query: Tensor laid out as (..., seq_len_q, d_k).
        key: Tensor laid out as (..., seq_len_k, d_k).
        value: Tensor laid out as (..., seq_len_k, d_v).

    Returns:
        A tuple ``(output, attn_weights)``. ``attn_weights`` is
        softmax(Q K^T / sqrt(d_k)) taken over the key axis, and ``output`` is
        ``attn_weights @ value``.
    """
    # sqrt(d_k) as a float32 scalar tensor; dividing the raw dot products by it
    # stops their magnitude from growing with the embedding size.
    norm = torch.tensor(query.size(-1), dtype=torch.float32).sqrt()
    logits = query @ key.transpose(-2, -1) / norm
    # Each row of weights is a probability distribution over the keys.
    weights = F.softmax(logits, dim=-1)
    return weights @ value, weights

# Toy example: batch size 1, sequence length 3, embedding dimension 2.
# Every tensor below has shape (1, 3, 2).
query = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
key = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
value = torch.tensor([[[1.0, 10.0], [10.0, 1.0], [5.0, 5.0]]])

output, attn_weights = simple_attention(query, key, value)

# attn_weights is a softmax over the keys, so each of its rows sums to 1.
print(f"Attention output:\n {output}")
print(f"Attention weights:\n {attn_weights}")


### Try-it: 1. Modify Input Tensors:
Change the query or key vectors and see how the attention weights and output change. For example, make one query vector more similar to a specific key vector and observe the effect.

In [None]:
# Baseline inputs (same values as the first example), each of shape (1, 3, 2).
query = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
key = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
value = torch.tensor([[[1.0, 10.0], [10.0, 1.0], [5.0, 5.0]]])

output, attn_weights = simple_attention(query, key, value)
print(f"Original attention weights:\n {attn_weights}")
print(f"Original attention output:\n {output}")

# Copy the queries and nudge only the first one from [1, 0] to [0.1, 0.9],
# i.e. much closer to the second key [0, 1]; the other two rows are unchanged.
modified_query = query.clone()
modified_query[0, 0] = torch.tensor([0.1, 0.9])

output_mod, attn_weights_mod = simple_attention(modified_query, key, value)
print(f"\nModified attention weights:\n {attn_weights_mod}")
print(f"Modified attention output:\n {output_mod}")


### Try-it: 2. Implement Masking

Add a mask to the attention scores before softmax to simulate attention only on allowed tokens (e.g., for sequence padding).

Here, if a mask is provided, it uses 1 for allowed positions and 0 for masked positions. Wherever the mask is 0, the scaled score is replaced with $-\infty$, so that position receives zero weight after the softmax.

In [None]:
def simple_attention_with_mask(query, key, value, mask=None):
    """Scaled dot-product attention with an optional attention mask.

    Args:
        query: Tensor laid out as (..., seq_len_q, d_k).
        key: Tensor laid out as (..., seq_len_k, d_k).
        value: Tensor laid out as (..., seq_len_k, d_v).
        mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_k).
            Positions where ``mask == 0`` (or ``False``) are blocked; any
            nonzero / ``True`` entry is allowed.

    Returns:
        A tuple ``(output, attn_weights)``. Blocked positions receive weight
        exactly 0. A query row whose keys are *all* blocked gets all-zero
        weights (and therefore an all-zero output row) instead of NaN.
    """
    # Calculate raw attention scores and scale by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1))
    d_k = query.size(-1)
    scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

    blocked = None
    if mask is not None:
        blocked = mask == 0
        # Filling with -inf makes exp(-inf) = 0 after softmax, so blocked keys
        # get no weight at all.
        scaled_scores = scaled_scores.masked_fill(blocked, float('-inf'))

    # Softmax to get attention weights
    attn_weights = F.softmax(scaled_scores, dim=-1)

    if blocked is not None:
        # A row where every key is blocked is a softmax over all -inf, which
        # yields NaN. Zeroing blocked positions turns such rows into zeros and
        # leaves every other row untouched (its blocked entries are already 0).
        attn_weights = attn_weights.masked_fill(blocked, 0.0)

    # Weighted sum of values
    output = torch.matmul(attn_weights, value)
    return output, attn_weights

# Toy inputs again: batch_size=1, seq_len=3, embedding_dim=2.
query = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
key = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
value = torch.tensor([[[1.0, 10.0], [10.0, 1.0], [5.0, 5.0]]])

# Mask of shape (1, 3, 3) in which every query row is [1, 1, 0]: keys 0 and 1
# may be attended to, key 2 is blocked. Its shape must broadcast with the
# (1, 3, 3) matrix of scaled scores.
mask = torch.tensor([1, 1, 0]).repeat(1, 3, 1)

output, attn_weights = simple_attention_with_mask(query, key, value, mask=mask)

print(f"Attention weights with mask:\n {attn_weights}")
print(f"Attention output with mask:\n {output}")


### Try-it: 3. Compare Different Scaling Factors
Remove the scaling factor $1/\sqrt{d_k}$ and observe its effect on the softmax outputs. Redefine `simple_attention` with an extra `scale=True` parameter alongside its other arguments. Then compare the results for `scale=True` and `scale=False`.

In [None]:
def simple_attention(query, key, value, scale=True):
    """Dot-product attention with an optional 1/sqrt(d_k) scaling step.

    Args:
        query: Tensor laid out as (..., seq_len_q, d_k).
        key: Tensor laid out as (..., seq_len_k, d_k).
        value: Tensor laid out as (..., seq_len_k, d_v).
        scale: When True (the default), divide the raw scores by sqrt(d_k)
            before the softmax; when False, feed the raw scores directly.

    Returns:
        A tuple ``(output, attn_weights)`` where ``attn_weights`` is the
        softmax over the key axis and ``output`` is ``attn_weights @ value``.
    """
    logits = query @ key.transpose(-2, -1)

    if scale:
        # sqrt(d_k) as a float32 scalar tensor.
        root_dk = torch.tensor(query.size(-1), dtype=torch.float32).sqrt()
        logits = logits / root_dk

    weights = F.softmax(logits, dim=-1)
    return weights @ value, weights

# Toy inputs (batch_size=1, seq_len=3, embedding_dim=2).
query = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
key = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])
value = torch.tensor([[[1.0, 10.0], [10.0, 1.0], [5.0, 5.0]]])

# Same inputs, run twice: first with the 1/sqrt(d_k) factor, then without it.
output_scaled, attn_weights_scaled = simple_attention(query, key, value, scale=True)
output_unscaled, attn_weights_unscaled = simple_attention(query, key, value, scale=False)

print(f"Attention weights with scaling factor:\n {attn_weights_scaled}")
print(f"Attention output with scaling factor:\n {output_scaled}")

print(f"\nAttention weights without scaling factor:\n {attn_weights_unscaled}")
print(f"Attention output without scaling factor:\n {output_unscaled}")


### Try-it: 4. Add multi-head attention
Here, we implement multi-head attention through the following steps:
- Set the number of heads and the per-head dimension
- Add linear layers that project the input into queries, keys, and values for all heads
- Split the queries, keys, and values into multiple heads and rearrange their dimensions
- Perform scaled dot-product attention separately for each head
- Concatenate the heads and project back to the original embedding dimension

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SimpleMultiHeadAttention(nn.Module):
    """Multi-head self-attention built from per-head scaled dot-product attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended per head,
    then concatenated and passed through a final linear layer.

    Args:
        embed_dim: Size of each token embedding (input and output).
        num_heads: Number of attention heads; must be a positive divisor of
            ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Raise instead of assert so the check is not stripped under ``python -O``.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Linear layers to project input to queries, keys, and values
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        # Final linear layer to combine heads
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, embed_dim) into (batch, num_heads, seq_len, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).
            mask: Optional mask where 0 / ``False`` blocks a (query, key) pair
                and any nonzero / ``True`` value allows it. A 3-D mask of shape
                (batch, seq_len, seq_len) is shared across heads; otherwise it
                must broadcast to (batch, num_heads, seq_len, seq_len), e.g. a
                2-D (seq_len, seq_len) causal mask.

        Returns:
            A tuple ``(output, attn_weights)`` with shapes
            (batch, seq_len, embed_dim) and (batch, num_heads, seq_len, seq_len).
            Query rows whose keys are all blocked get all-zero weights rather
            than NaN.
        """
        batch_size, seq_len, embed_dim = x.size()

        # Project inputs, then split into heads: (batch, num_heads, seq_len, head_dim)
        Q = self._split_heads(self.q_linear(x), batch_size, seq_len)
        K = self._split_heads(self.k_linear(x), batch_size, seq_len)
        V = self._split_heads(self.v_linear(x), batch_size, seq_len)

        # Scaled dot-product attention, computed for all heads at once
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        blocked = None
        if mask is not None:
            if mask.dim() == 3:
                # (batch, seq_q, seq_k) -> (batch, 1, seq_q, seq_k): same mask for every head
                mask = mask.unsqueeze(1)
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        if blocked is not None:
            # Fully blocked rows are softmax(all -inf) = NaN; zero them out.
            # Other rows already have exact zeros at blocked positions.
            attn_weights = attn_weights.masked_fill(blocked, 0.0)

        attn_output = torch.matmul(attn_weights, V)  # (batch, num_heads, seq_len, head_dim)

        # Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out_linear(attn_output)

        return output, attn_weights

# Example usage: a random batch of 2 sequences, 4 tokens each, with 8-dim
# embeddings split across 2 attention heads.
batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2

# Random input tensor representing a batch of sequences of embeddings
x = torch.randn(batch_size, seq_len, embed_dim)

model = SimpleMultiHeadAttention(embed_dim, num_heads)
output, attn_weights = model(x)

print(f"Output shape: {output.shape}")                    # expect (batch_size, seq_len, embed_dim)
print(f"Attention weights shape: {attn_weights.shape}")  # expect (batch_size, num_heads, seq_len, seq_len)
