# Demonstrating Masked Self-Attention in Transformers
This notebook gives an intuitive, practical demonstration of masked self-attention in Transformers using PyTorch.

## Masked Self-Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}} + M\bigg)
$$

$$
\text{new V} = \text{self attention} \cdot V
$$
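
To make the role of $M$ concrete, here is a minimal sketch that assumes $M$ is an additive mask: $0$ where attention is allowed and $-\infty$ strictly above the diagonal. The `scores` tensor is a random stand-in for $QK^T/\sqrt{d_k}$, not output from the model defined later. Because $e^{-\infty} = 0$, the softmax gives masked positions exactly zero weight while each row still sums to 1.

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # stand-in for Q K^T / sqrt(d_k)

# Additive mask M: 0 on and below the diagonal, -inf strictly above it.
M = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Adding M before the softmax zeroes out the weights of future positions.
weights = F.softmax(scores + M, dim=-1)

print(M)
print(weights)
```

The first row has only one allowed position, so its weight there is 1.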

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import random
random.seed(24)  # Python random seed
torch.manual_seed(24)  # PyTorch seed (CPU)

<torch._C.Generator at 0x7fa4fb7038f0>

In [3]:
# Set print options: no scientific notation, 4 decimal places
torch.set_printoptions(sci_mode=False, precision=4)

# Task: Implementing Masked Self-Attention
In this task, you will implement a Masked Self-Attention mechanism, an essential component of Transformer-based decoders. Unlike regular self-attention, masked self-attention ensures that each token can only attend to itself and previous tokens, preventing access to future information during autoregressive generation.

Your goal is to define a `Masked_SelfAttention` class in PyTorch, which will:

1. Define learnable projections for the query (Q), key (K), and value (V) using `nn.Linear`.
2. Measure the similarity between tokens with scaled dot-product scores, $QK^T / \sqrt{d_k}$.
3. Apply a lower triangular mask that sets the scores of future positions to `-inf`, so no token can attend to tokens after it.
4. Convert the masked scores into probability distributions with the softmax function.
5. Multiply the attention weights by the value (V) vectors to produce the final contextual representation of each token.
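
The steps above can be traced on a tiny input before wrapping them in a module. This sketch deliberately skips the learned projections and assumes `Q = K = V = x` (an illustration-only simplification), so the effect of the mask is easy to read off: the first token can see only itself, so its attention weight on itself is 1 and its output equals its own value vector.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 3, 4
x = torch.randn(seq_len, d_model)

# Assumption for illustration: identity projections instead of nn.Linear.
Q, K, V = x, x, x

# Scaled dot-product similarity between every pair of tokens.
scores = (Q @ K.transpose(-2, -1)) / (d_model ** 0.5)

# Lower triangular mask: 1 = may attend, 0 = future position.
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float("-inf"))

# Softmax turns each row into a probability distribution.
weights = F.softmax(scores, dim=-1)

# Weighted sum of value vectors.
output = weights @ V

print(weights)
print(torch.allclose(output[0], V[0]))
```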

In [4]:
class Masked_SelfAttention(nn.Module):
    def __init__(self, d_model):
        super(Masked_SelfAttention, self).__init__()
        self.d_model = d_model
        self.d_k = d_model
        self.d_v = d_model

        ### BEGIN SOLUTION
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        ### END SOLUTION 
        
    def forward(self, q_input, k_input, v_input, mask=None):
        # Project the inputs; each result has shape (seq_len, d_model)
        Q = self.query(q_input)
        K = self.key(k_input)
        V = self.value(v_input)

        # Scaled dot-product scores, shape (q_len, k_len)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            ### BEGIN SOLUTION
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
            ### END SOLUTION 
        attn_weights = F.softmax(attn_scores, dim=-1)

        ### BEGIN SOLUTION
        attention_output = torch.matmul(attn_weights, V)
        ### END SOLUTION 
        return attention_output

In [5]:
### BEGIN HIDDEN TESTS

def test_masked_self_attention():
    d_model_t = 16   # Feature dimension
    seq_len_t = 4    # Sequence length

    # Initialize the Masked Self-Attention module
    masked_self_attn_t = Masked_SelfAttention(d_model_t)

    # Create random inputs of shape (seq_len, d_model)
    q_input_t = torch.randn(seq_len_t, d_model_t)
    k_input_t = torch.randn(seq_len_t, d_model_t)
    v_input_t = torch.randn(seq_len_t, d_model_t)

    # Create a lower triangular mask for causal masking
    mask_t = torch.tril(torch.ones(seq_len_t, seq_len_t))

    # Forward pass
    attention_output_t = masked_self_attn_t(q_input_t, k_input_t, v_input_t, mask_t)

    # Assertions
    assert attention_output_t.shape == (seq_len_t, d_model_t), "Output shape mismatch!"
    assert not torch.isnan(attention_output_t).any(), "NaN values in output!"
    assert torch.isfinite(attention_output_t).all(), "Non-finite values in output!"

    print("All test cases passed!")

# Run the test
test_masked_self_attention()
### END HIDDEN TESTS

All test cases passed!


In [6]:
d_model = 6
max_sequence_length = 4
src_tokens = torch.randn(max_sequence_length, d_model)

In [7]:
self_attention = Masked_SelfAttention(d_model)
mask = torch.tril(torch.ones((max_sequence_length, max_sequence_length)))
result = self_attention(src_tokens, src_tokens, src_tokens, mask)

In [8]:
result

tensor([[-0.6776,  0.0720, -0.4998, -0.3670,  0.1837,  0.3315],
        [-0.6853,  0.2085, -0.2347, -0.4739,  0.0130,  0.5306],
        [-0.3970, -0.0950,  0.1439, -0.1145,  0.0797,  0.4010],
        [-0.3144, -0.3691,  0.4142, -0.0068, -0.0090,  0.1935]],
       grad_fn=<MmBackward0>)
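
One way to check that the causal mask behaves as claimed is to change only the last token and confirm that the outputs for all earlier positions stay the same. The sketch below is self-contained and does not reuse the class above: `causal_attention` and the layers `w_q`, `w_k`, `w_v` are hypothetical stand-ins with the same structure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 4, 6

# Hypothetical stand-ins for the learned Q, K, V projections.
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

def causal_attention(x):
    """Masked self-attention over x of shape (seq_len, d_model)."""
    Q, K, V = w_q(x), w_k(x), w_v(x)
    scores = (Q @ K.transpose(-2, -1)) / (d_model ** 0.5)
    mask = torch.tril(torch.ones(x.size(0), x.size(0)))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

x = torch.randn(seq_len, d_model)
x_perturbed = x.clone()
x_perturbed[-1] += 10.0  # change only the last token

with torch.no_grad():
    out = causal_attention(x)
    out_perturbed = causal_attention(x_perturbed)

# Rows before the last one never attend to the perturbed token.
earlier_rows_unchanged = torch.allclose(out[:-1], out_perturbed[:-1])
print(earlier_rows_unchanged)
```

Without the mask, every row would mix in the last token's value vector, so this check would generally fail.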

# Encoder-Decoder Attention (a.k.a. Cross-Attention)

In [9]:
encoder_output = torch.randn(max_sequence_length, d_model)  # Simulated encoder output
decoder_input = torch.randn(max_sequence_length, d_model)  # Simulated decoder input

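# Reuse the same module without a mask: queries come from the decoder,
# keys and values come from the encoder output.
cross_attention = Masked_SelfAttention(d_model)
result = cross_attention(decoder_input, encoder_output, encoder_output)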

print("Cross-Attention Output:")
print(result)

Cross-Attention Output:
tensor([[ 0.0314, -0.5712,  0.3103,  0.0950,  0.7929,  0.3417],
        [ 0.0532, -0.5954,  0.3131,  0.0882,  0.7581,  0.3172],
        [ 0.2182, -0.5214,  0.3097,  0.0350,  0.6196,  0.1594],
        [-0.0774, -0.7373,  0.3256,  0.1355,  0.8371,  0.4414]],
       grad_fn=<MmBackward0>)
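
In the example above the encoder and decoder sequences happen to have the same length, which hides a useful property of cross-attention: the output has one row per query (decoder) token, while the attention weights have one column per key (encoder) token. The following sketch uses hypothetical lengths `src_len = 5` and `tgt_len = 3` and standalone layers `w_q`, `w_k`, `w_v` rather than the class above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 6
src_len, tgt_len = 5, 3  # hypothetical lengths, chosen to differ

encoder_output = torch.randn(src_len, d_model)
decoder_input = torch.randn(tgt_len, d_model)

# Hypothetical stand-ins for the learned projections.
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

Q = w_q(decoder_input)   # (tgt_len, d_model): queries from the decoder
K = w_k(encoder_output)  # (src_len, d_model): keys from the encoder
V = w_v(encoder_output)  # (src_len, d_model): values from the encoder

# No causal mask: every decoder token may attend to every encoder token.
weights = F.softmax((Q @ K.transpose(-2, -1)) / (d_model ** 0.5), dim=-1)
output = weights @ V

print(weights.shape)  # torch.Size([3, 5])
print(output.shape)   # torch.Size([3, 6])
```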
