# Demonstrating Masked Self-Attention in Transformers
## This notebook gives an intuitive, hands-on demonstration of masked self-attention in Transformers using PyTorch.

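## Masked Self-Attention

$$
\text{self attention} = \text{softmax}\bigg(\frac{QK^T}{\sqrt{d_k}} + M\bigg)
$$

$$
\text{new } V = (\text{self attention})\, V
$$

Here $d_k$ is the dimension of the query and key vectors, and $M$ is the mask: $M_{ij} = 0$ if $j \le i$ and $M_{ij} = -\infty$ if $j > i$, so each position can attend only to itself and earlier positions.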
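As a compact reference, the two equations can be written as a single function. This is a minimal sketch, assuming unbatched inputs of shape `(seq_len, d_k)`; the name `masked_self_attention` is hypothetical and not part of PyTorch. It builds $M$ explicitly as a matrix of `0` and `-inf` entries and adds it to the scaled scores, exactly as in the formula.

```python
import math

import torch
import torch.nn.functional as F


def masked_self_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    """Hypothetical helper: causal scaled dot-product attention for (seq_len, d_k) inputs."""
    seq_len, d_k = Q.shape
    # Additive mask M: 0 where j <= i (past and present), -inf where j > i (future)
    M = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    attn = F.softmax(Q @ K.T / math.sqrt(d_k) + M, dim=-1)  # "self attention" weights
    return attn @ V, attn  # ("new V", weights)


g = torch.Generator().manual_seed(0)  # local generator, leaves the global seed untouched
Q, K, V = (torch.randn(4, 8, generator=g) for _ in range(3))
new_V, attn = masked_self_attention(Q, K, V)
print(attn)  # lower-triangular; each row sums to 1
```

Because row 0 can only see itself, its weight is exactly 1 and `new_V[0]` equals `V[0]`.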

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
import random

random.seed(24)        # Python random seed
torch.manual_seed(24)  # PyTorch seed (seeds the RNG for all devices)
```

```python
# Set print options: no scientific notation, 4 decimal places
torch.set_printoptions(sci_mode=False, precision=4)
```

# Define the maximum sequence length and the embedding dimension

- `max_sequence_length = 5`: the maximum number of tokens a sequence can have. A shorter sequence may be padded; a longer one may be truncated.
- `d_model = 8`: the size of each token's embedding vector, so every token is represented as an 8-dimensional vector.
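The padding and truncation mentioned above can be sketched on a list of token ids. This helper is illustrative only: `pad_or_truncate` and the padding id `PAD_ID = 0` are assumptions, not part of this notebook or of PyTorch. It also returns a flag per position marking real tokens, which models typically turn into a padding mask.

```python
from typing import List, Tuple

PAD_ID = 0  # assumed padding token id (hypothetical)


def pad_or_truncate(token_ids: List[int], max_sequence_length: int = 5) -> Tuple[List[int], List[bool]]:
    """Hypothetical helper: fit a sequence of token ids to exactly max_sequence_length."""
    ids = token_ids[:max_sequence_length]  # truncate if too long
    is_real = [True] * len(ids)            # True marks a real (non-padding) token
    n_pad = max_sequence_length - len(ids)
    return ids + [PAD_ID] * n_pad, is_real + [False] * n_pad  # pad if too short


print(pad_or_truncate([12, 7, 42]))            # ([12, 7, 42, 0, 0], [True, True, True, False, False])
print(pad_or_truncate([1, 2, 3, 4, 5, 6, 7]))  # ([1, 2, 3, 4, 5], [True, True, True, True, True])
```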

```python
d_model = 8
max_sequence_length = 5
```

# Define three linear layers using `nn.Linear` in PyTorch

- `w_query` projects input embeddings into query space.
- `w_key` projects input embeddings into key space.
- `w_value` projects input embeddings into value space.

Each layer maps a `d_model`-dimensional embedding to a new representation of the same size (`d_model` → `d_model`).
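Under the hood, each `nn.Linear(d_model, d_model)` stores a weight matrix of shape `(d_model, d_model)` and a bias of shape `(d_model,)`, and computes $xW^T + b$. The short check below illustrates this on a random input; the names `layer`, `x`, and `manual` are only for this example.

```python
import torch
import torch.nn as nn

d_model = 8
layer = nn.Linear(d_model, d_model)
x = torch.randn(5, d_model)  # 5 example token embeddings

manual = x @ layer.weight.T + layer.bias  # y = x W^T + b
print(layer.weight.shape, layer.bias.shape)  # torch.Size([8, 8]) torch.Size([8])
print(torch.allclose(layer(x), manual, atol=1e-6))  # True
```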

```python
w_query = nn.Linear(d_model, d_model)
w_key   = nn.Linear(d_model, d_model)
w_value = nn.Linear(d_model, d_model)
```

# Create a tensor `tokens` of random values, scaled by 10.0, to simulate a sequence of token embeddings

Use `torch.randn()` to generate a random tensor of shape `(max_sequence_length, d_model)`, where:

- `max_sequence_length` is the number of tokens in the sequence.
- `d_model` is the embedding dimension.

`torch.randn()` samples from a standard normal distribution, so multiplying by 10.0 raises the standard deviation of the values from 1 to 10.

```python
tokens = torch.randn(max_sequence_length, d_model) * 10.0
```

```python
tokens.shape
```

```
torch.Size([5, 8])
```

```python
tokens
```

```
tensor([[ 15.7217, -15.5080,  -9.5075,  -8.6397,  -2.1968,  -5.2217,   5.0664,
           3.8816],
        [ -3.8720, -20.7487,  18.5679,  -9.4431,  15.2684,  24.3474,  -9.0936,
          -4.0998],
        [ -1.2953,   8.7748, -12.0887,  -8.3205,   0.2548, -14.4027,  -2.6338,
          -4.5473],
        [  6.4649, -11.6771,  -1.0466,  12.9927,   9.8568,   0.3484,  -1.1219,
          -0.6101],
        [-18.4733,   3.4435,   4.8654, -22.3485,   8.1349,  27.0436,  20.6851,
           6.2408]])
```

# Project `tokens` into queries, keys, and values

Pass `tokens` through `w_query`, `w_key`, and `w_value` to obtain the query (`q`), key (`k`), and value (`v`) representations.

```python
q = w_query(tokens)
k = w_key(tokens)
v = w_value(tokens)
```

```python
q.shape, k.shape, v.shape
```

```
(torch.Size([5, 8]), torch.Size([5, 8]), torch.Size([5, 8]))
```
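A common alternative in Transformer implementations is to compute all three projections with a single `nn.Linear(d_model, 3 * d_model)` and split the result. This is a sketch of that variant, not what this notebook does; the names `w_qkv`, `x`, `q_f`, `k_f`, `v_f` are illustrative and chosen so they do not overwrite the notebook's `q`, `k`, `v`. It yields the same shapes with one matrix multiplication instead of three.

```python
import torch
import torch.nn as nn

d_model, max_sequence_length = 8, 5
w_qkv = nn.Linear(d_model, 3 * d_model)  # one layer producing q, k and v together
x = torch.randn(max_sequence_length, d_model) * 10.0

q_f, k_f, v_f = w_qkv(x).chunk(3, dim=-1)  # split the last dimension into 3 equal parts
print(q_f.shape, k_f.shape, v_f.shape)  # torch.Size([5, 8]) torch.Size([5, 8]) torch.Size([5, 8])

# The first d_model rows of the fused weight act as a separate query projection
q_check = x @ w_qkv.weight[:d_model].T + w_qkv.bias[:d_model]
print(torch.allclose(q_f, q_check, atol=1e-4))  # True
```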

## Computing the Attention Scores

With a single attention head, $d_k$ equals `d_model` (8 here), so the first step of the formula is the scaled score matrix $QK^T / \sqrt{8}$.

```python
# Scaled dot-product scores; with a single head, d_k = d_model
attn_scores = torch.matmul(q, k.T) / torch.sqrt(torch.tensor(d_model, dtype=torch.float))
```

```python
attn_scores
```

```
tensor([[ -9.9895, -56.4391, -31.7966,   5.9264, -43.4802],
        [-59.5789,  -7.6356,  -0.6855,   0.2884,  80.4750],
        [ 35.6549,  38.7697,   0.4983,  17.5775, -30.8585],
        [ -5.2894, -47.0393, -29.2334,  -0.4237, -16.0637],
        [-42.9926,  53.9101,  39.0516,   6.6710,  59.3816]],
       grad_fn=<DivBackward0>)
```

```python
attn_scores.shape
```

```
torch.Size([5, 5])
```
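Why divide by $\sqrt{d_k}$? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_k$ terms of variance 1, so it has variance $d_k$. Large scores push the softmax toward one-hot outputs; dividing by $\sqrt{d_k}$ brings the variance back to about 1. The simulation below illustrates this with standard-normal vectors; `d_k = 64` and the sample size are arbitrary choices, and the names avoid overwriting the notebook's `q` and `k`.

```python
import math

import torch

g = torch.Generator().manual_seed(0)  # local generator, leaves the global seed untouched
d_k, n = 64, 100_000
q_sim = torch.randn(n, d_k, generator=g)  # entries with mean 0, variance 1
k_sim = torch.randn(n, d_k, generator=g)

raw = (q_sim * k_sim).sum(dim=-1)  # n independent dot products
scaled = raw / math.sqrt(d_k)      # divided by sqrt(d_k), as in the formula
print(f"raw variance:    {raw.var().item():.2f} (theory: {d_k})")
print(f"scaled variance: {scaled.var().item():.2f} (theory: 1)")
```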

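## Masking

- The mask ensures that a token cannot draw context from tokens that come after it, i.e. tokens that would be generated in the future.
- It is not used in encoder self-attention, which sees the whole input at once, but it is required in decoder self-attention, where tokens are generated one at a time.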
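One way to see what the mask guarantees: with a causal mask, the output at position *i* depends only on tokens 0 to *i*. The sketch below uses its own randomly initialized projections (the names `causal_attention`, `w_q`, `w_k`, `w_v` are illustrative) and changes only the last token. The outputs for all earlier positions stay identical, while the last output changes; without the mask, every output would generally change.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_attention(x, w_q, w_k, w_v):
    """Single-head masked self-attention over x of shape (seq_len, d_model)."""
    q, k, v = w_q(x), w_k(x), w_v(x)
    scores = q @ k.T / math.sqrt(q.shape[-1])
    mask = torch.tril(torch.ones(x.shape[0], x.shape[0]))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


d_model, seq_len = 8, 5
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))
x = torch.randn(seq_len, d_model, generator=torch.Generator().manual_seed(0))
x_changed = x.clone()
x_changed[-1] += 5.0  # modify only the last ("future") token

with torch.no_grad():
    out = causal_attention(x, w_q, w_k, w_v)
    out_changed = causal_attention(x_changed, w_q, w_k, w_v)

print(torch.allclose(out[:-1], out_changed[:-1]))  # True: earlier positions are unaffected
print(torch.allclose(out[-1], out_changed[-1]))    # False: the last position sees the change
```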

# Create a lower-triangular mask with `torch.tril`

`torch.tril` keeps the lower triangle of a matrix, including the diagonal, and sets the upper triangle to zero. Applied to a matrix of ones, it produces a mask whose row *i* contains ones in columns 0 to *i*, so each position can attend only to itself and earlier positions. Masked self-attention uses this mask to block access to future tokens during decoding.

```python
mask = torch.tril(torch.ones((max_sequence_length, max_sequence_length)))
print(mask)
```

```
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
```
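The same mask can also be stored as booleans. A common convention (assumed here, not used elsewhere in this notebook) marks the blocked future positions as `True` with `torch.triu(..., diagonal=1)`. This is the complement of the `tril` mask, so it can be passed to `masked_fill` directly without the `mask == 0` comparison. The check uses deterministic example scores from `torch.arange`.

```python
import torch

n = 5
keep = torch.tril(torch.ones(n, n))  # 1 = may attend, as in the notebook's mask
blocked = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)  # True = future position

print(torch.equal(blocked, keep == 0))  # True: both mark the same positions

scores = torch.arange(n * n, dtype=torch.float).reshape(n, n)  # example scores
same = torch.equal(
    scores.masked_fill(blocked, float("-inf")),
    scores.masked_fill(keep == 0, float("-inf")),
)
print(same)  # True
```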


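# Apply the mask to the attention scores with `masked_fill`

`masked_fill` sets every position where `mask == 0` to `-inf`. Since $e^{-\infty} = 0$, these positions receive a weight of exactly 0 after the softmax, so future tokens are ignored and the model cannot attend to tokens it has not yet generated during autoregressive decoding. This plays the role of the $+M$ term in the formula.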

```python
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
```

```python
attn_scores
```

```
tensor([[ -9.9895,     -inf,     -inf,     -inf,     -inf],
        [-59.5789,  -7.6356,     -inf,     -inf,     -inf],
        [ 35.6549,  38.7697,   0.4983,     -inf,     -inf],
        [ -5.2894, -47.0393, -29.2334,  -0.4237,     -inf],
        [-42.9926,  53.9101,  39.0516,   6.6710,  59.3816]],
       grad_fn=<MaskedFillBackward0>)
```
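`masked_fill` is one way to implement the $+M$ term. Equivalently, $M$ can be built as a matrix of `0` (allowed) and `-inf` (future) entries and added to the scores, as written in the formula. The sketch below, on arbitrary random scores, checks that both approaches give identical softmax weights.

```python
import torch
import torch.nn.functional as F

n = 5
scores = torch.randn(n, n, generator=torch.Generator().manual_seed(0))  # example scores
mask = torch.tril(torch.ones(n, n))

# Additive mask M: 0 where mask == 1, -inf where mask == 0
M = torch.zeros(n, n).masked_fill(mask == 0, float("-inf"))

w_fill = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
w_add = F.softmax(scores + M, dim=-1)
print(torch.equal(w_fill, w_add))  # True
```

Adding 0 leaves a score unchanged and adding `-inf` turns it into `-inf`, so both versions feed exactly the same values into the softmax.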

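# Convert the masked scores into attention weights with softmax

```python
# Softmax over the last dimension: each row becomes a probability distribution over the keys
attn_weights = F.softmax(attn_scores, dim=-1)
```

```python
attn_weights
```

```
tensor([[    1.0000,     0.0000,     0.0000,     0.0000,     0.0000],
        [    0.0000,     1.0000,     0.0000,     0.0000,     0.0000],
        [    0.0425,     0.9575,     0.0000,     0.0000,     0.0000],
        [    0.0076,     0.0000,     0.0000,     0.9924,     0.0000],
        [    0.0000,     0.0042,     0.0000,     0.0000,     0.9958]],
       grad_fn=<SoftmaxBackward0>)
```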
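These weights are close to one-hot: every row has one weight of at least 0.95. A likely cause is the large magnitude of the scores (some exceed 50), which comes in part from scaling the embeddings by 10.0; softmax is very sensitive to the scale of its inputs. The sketch below applies softmax to row 3 of the masked scores shown above at several scales; the scale factors are arbitrary.

```python
import torch
import torch.nn.functional as F

# Row 3 of the masked scores from the notebook output (the last position is masked)
row = torch.tensor([-5.2894, -47.0393, -29.2334, -0.4237, float("-inf")])

max_weights = []
for scale in (1.0, 0.1, 0.01):
    weights = F.softmax(row * scale, dim=-1)  # -inf stays -inf for a positive scale
    max_weights.append(weights.max().item())
    print(f"scale={scale:<5} max weight={max_weights[-1]:.4f}")
```

As the scores shrink, the largest weight drops and the distribution spreads across the visible tokens.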

```python
print(f"sum = {attn_weights.sum(dim=-1)}")
```

```
sum = tensor([1., 1., 1., 1., 1.], grad_fn=<SumBackward1>)
```


# Compute the weighted sum of the value vectors

Multiplying the attention weights by `v` gives each position a weighted sum of the value vectors of itself and all earlier positions. The result is a context-aware representation in which each token draws only on past tokens, weighted by the learned query-key similarities.

```python
attention_output = torch.matmul(attn_weights, v)
```

```python
attention_output.shape
```

```
torch.Size([5, 8])
```

```python
attention_output
```

```
tensor([[  4.3459,  -1.6540,   7.7558,  -4.7189,   7.4060,   7.7708,  -3.1271,
           0.6721],
        [ 18.0213,   6.2609,   1.7086,   6.8319,  -1.5335,  -1.1946,   5.0538,
          -2.8420],
        [ 17.4401,   5.9245,   1.9657,   6.3410,  -1.1535,  -0.8136,   4.7061,
          -2.6926],
        [  1.4368,   5.7818,  -0.2100,  -6.2094,   0.9749,   4.3626,  -6.4292,
           0.1822],
        [ 13.0137,  -4.3577, -10.8564,  22.3291,  11.1974,  -8.0273,  19.4841,
           3.5926]], grad_fn=<MmBackward0>)
```
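In PyTorch 2.0 and later, `torch.nn.functional.scaled_dot_product_attention` performs these steps in a single call: `is_causal=True` applies the same lower-triangular mask, and the default scale is $1/\sqrt{d_k}$. The sketch below, assuming PyTorch 2.0 or newer, compares it with the manual computation from this notebook on fresh random inputs. A leading batch dimension is added because the function is documented for batched inputs; `Q`, `K`, `V` are new names so the notebook's `q`, `k`, `v` are not overwritten.

```python
import math

import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)  # local generator, leaves the global seed untouched
seq_len, d_k = 5, 8
Q, K, V = (torch.randn(seq_len, d_k, generator=g) for _ in range(3))

# Manual masked self-attention, as in the cells above
scores = Q @ K.T / math.sqrt(d_k)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
weights = F.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)
manual = weights @ V

# Built-in version (batch dimension added, then removed)
builtin = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0), is_causal=True
).squeeze(0)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```

The built-in function can dispatch to fused kernels when they are available, so in practice it is usually preferred over the manual version; the manual steps remain useful for seeing what it computes.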