
# Masked Attention

Masked attention is a variant of self-attention in which some positions in the sequence are hidden (masked) so that they do not contribute to the attention computation. Its main use is in autoregressive models, where a word must not attend to the words that come after it. This matters most during training, when the whole sequence is processed in parallel and every future word would otherwise be visible.


In language generation, the model produces words one at a time, so when it predicts the next token it must not look at the words that follow it.

**Without masking:** each word can attend to every other word in the sequence, including later ones. This is fine for bidirectional models (such as BERT-style encoders) but breaks autoregressive models, which would effectively see the answer during training.

**With masking:** each word can attend only to itself and earlier words, which forces the model to generate text left to right.

To enforce this, a mask matrix is added to the attention scores before softmax: the scores for future positions are replaced with a very large negative number (negative infinity in practice), so softmax assigns those positions a weight of 0.
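In matrix form, this is standard scaled dot-product attention with an additive mask $M$. The notation is assumed here but matches the names used in the code: $Q$, $K$, $V$ are the query, key, and value matrices, $d_k$ is the key dimension, $i$ indexes the query position and $j$ the key position.

$$
\text{MaskedAttention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}
$$

Since $e^{-\infty} = 0$, every masked score contributes nothing to the softmax, and each row still sums to 1 over the allowed positions.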

For example, in the sentence **"The cat sat on the mat"**, future words are blocked:

* "The" sees only itself.
* "Cat" sees "The" and itself.
* "Sat" sees "The", "Cat", and itself, but none of the later words.

In general, each word can attend only to itself and the words before it.
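A small sketch of this visibility rule for the example sentence, assuming PyTorch and one token per word (a simplification; real tokenizers may split words differently). It builds the lower-triangular "allowed" pattern, prints what each word can see, and shows that adding a `-inf` mask before softmax gives blocked positions exactly zero weight:

```python
import torch

tokens = ["The", "cat", "sat", "on", "the", "mat"]
n = len(tokens)

# True where attention is allowed: query i may see key j only if j <= i.
allowed = torch.tril(torch.ones(n, n)).bool()

for i, tok in enumerate(tokens):
    visible = [tokens[j] for j in range(n) if allowed[i, j]]
    print(f"{tok:>3} sees {visible}")

# Additive mask used before softmax: 0 where allowed, -inf where blocked.
additive_mask = torch.zeros(n, n).masked_fill(~allowed, float("-inf"))

# With all raw scores equal (here 0), softmax spreads weight evenly over the
# allowed positions and gives exactly 0 to every blocked position.
scores = torch.zeros(n, n) + additive_mask
weights = torch.softmax(scores, dim=-1)
print(weights[2])  # "sat": 1/3 on each of "The", "cat", "sat"; 0 elsewhere
```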







```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Learned linear projections for queries, keys, and values
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, token_encodings):
        """
        token_encodings: (batch_size, seq_length, d_model)
        Returns (attention_output, attention_weights).
        """
        # Compute Query (Q), Key (K), Value (V)
        Q = self.W_q(token_encodings)  # (batch_size, seq_length, d_model)
        K = self.W_k(token_encodings)  # (batch_size, seq_length, d_model)
        V = self.W_v(token_encodings)  # (batch_size, seq_length, d_model)

        # Compute similarity scores (QK^T): (batch_size, seq_length, seq_length)
        similarity_scores = torch.matmul(Q, K.transpose(-2, -1))

        # Scale scores by sqrt(d_k)
        d_k = K.shape[-1]
        scaled_scores = similarity_scores / math.sqrt(d_k)

        # Create the mask: 1s strictly above the diagonal mark future positions.
        # Build it on the input's device so the module also works on GPU.
        seq_length = token_encodings.shape[1]
        mask = torch.triu(
            torch.ones(seq_length, seq_length, device=token_encodings.device),
            diagonal=1,
        )
        mask = mask.masked_fill(mask == 1, float('-inf'))  # future -> -inf, allowed -> 0

        # Add the mask before softmax (unsqueeze adds a batch dimension for broadcasting)
        masked_scores = scaled_scores + mask.unsqueeze(0)

        # Softmax over the key dimension; masked positions get weight exactly 0
        attention_weights = F.softmax(masked_scores, dim=-1)

        # Weighted sum of the value vectors
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights


# Example usage
if __name__ == "__main__":
    # Example sentence: "The cat sat" with hand-picked 2-dimensional embeddings
    token_embeddings = torch.tensor([
        [1.0, 0.5],  # "The"
        [0.8, 0.2],  # "Cat"
        [0.9, 0.7],  # "Sat"
    ], dtype=torch.float32).unsqueeze(0)  # Shape: (1, 3, 2)

    # Initialize masked self-attention (projection weights are random)
    masked_attention = MaskedSelfAttention(d_model=2)

    # Apply masked attention
    output, attention_weights = masked_attention(token_embeddings)

    print("\nFinal Masked Attention Output:\n", output)
    print("\nMasked Attention Weights:\n", attention_weights)
```



```text
Final Masked Attention Output:
 tensor([[[ 0.8301, -0.8463],
         [ 0.6860, -0.7036],
         [ 0.7573, -0.7715]]], grad_fn=<UnsafeViewBackward0>)

Masked Attention Weights:
 tensor([[[1.0000, 0.0000, 0.0000],
         [0.5003, 0.4997, 0.0000],
         [0.3388, 0.3320, 0.3291]]], grad_fn=<SoftmaxBackward0>)
```


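Each row of the weight matrix is one word's attention distribution over the words it is allowed to see, and each row sums to 1. Because `W_q`, `W_k`, and `W_v` are randomly initialized and no seed is set, the exact numbers change from run to run; what does not change is that every entry above the diagonal is exactly 0. In this run the allowed positions happen to receive nearly equal weight.

**The → [1.0000, 0.0000, 0.0000]**
* "The" attends only to itself because it is the first word.

**Cat → [0.5003, 0.4997, 0.0000]**
* "The" (50.03%)
* Itself (49.97%)
* "Sat" is masked (0%)

**Sat → [0.3388, 0.3320, 0.3291]**
* "The" (33.88%)
* "Cat" (33.20%)
* Itself (32.91%)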
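One way to sanity-check a causal mask is to confirm that editing a later token cannot change the outputs at earlier positions. The sketch below assumes PyTorch 2.0 or later. `causal_attention` is a hypothetical stand-alone helper that performs the same computation as `MaskedSelfAttention.forward` but takes Q, K, V directly, and random projection matrices stand in for the `nn.Linear` layers. It also compares the result with PyTorch's built-in `F.scaled_dot_product_attention(..., is_causal=True)`:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)


def causal_attention(q, k, v):
    """Hypothetical helper: masked (causal) scaled dot-product attention.
    q, k, v: (batch, seq_len, d). Returns (output, weights).
    """
    seq_len, d_k = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # True strictly above the diagonal: key j is in the future of query i.
    future = torch.ones(seq_len, seq_len, device=q.device).triu(diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


batch, seq_len, d = 1, 4, 8
x = torch.randn(batch, seq_len, d)
# Random projections, scaled by 1/sqrt(d) to keep scores in a moderate range
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))

out, weights = causal_attention(x @ W_q, x @ W_k, x @ W_v)

# 1) No weight on future tokens, and every row still sums to 1.
assert torch.all(weights.triu(diagonal=1) == 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch, seq_len))

# 2) Changing the last token must not change the outputs at earlier positions.
x2 = x.clone()
x2[:, -1] = torch.randn(d)
out2, _ = causal_attention(x2 @ W_q, x2 @ W_k, x2 @ W_v)
assert torch.allclose(out[:, :-1], out2[:, :-1])

# 3) Agrees with PyTorch's built-in causal attention.
ref = F.scaled_dot_product_attention(x @ W_q, x @ W_k, x @ W_v, is_causal=True)
assert torch.allclose(out, ref, atol=1e-5)
print("causal mask checks passed")
```

In practice the built-in function is usually the better choice because it can dispatch to fused kernels; the explicit version is mainly useful for inspecting the attention weights, which the built-in does not return.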



