In [1]:
import torch
import torch.nn as nn

# Toy embeddings for the six tokens of "Your journey starts with one step":
# one row per token (x^1 ... x^6), three features per row.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

# Add a leading batch axis and copy the sequence twice -> shape (2, 6, 3).
batch = inputs.unsqueeze(0).repeat(2, 1, 1)
print("Batch shape:", batch.shape)
print("Batch data:\n", batch)

class MaskedAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Each position may attend only to itself and earlier positions: scores for
    later positions are set to -inf before the softmax. When ``verbose`` is
    true, every intermediate tensor is printed, which is handy for stepping
    through the computation in a notebook.

    Args:
        d_in: Size of the last input dimension. Defaults to 3.
        d_out: Size of the query/key/value projections and of the output.
            Defaults to 3.
        context_length: Longest sequence the causal mask supports.
            Defaults to 6.
        dropout: Drop probability for ``nn.Dropout`` on the attention
            weights. Defaults to 0.5. ``nn.Dropout`` is a no-op in eval mode.
        qkv_bias: Whether the three projections have a bias. Defaults to False.
        verbose: Print intermediate tensors in ``forward``. Defaults to True.
    """

    def __init__(self, d_in=3, d_out=3, context_length=6, dropout=0.5,
                 qkv_bias=False, verbose=True):
        super().__init__()
        self.verbose = verbose
        # Creation order (query, key, value) is kept so that seeded weight
        # initialization matches the original hard-coded version.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions that must
        # not be attended to. A buffer follows the module across .to(device).
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def _show(self, label, tensor):
        """Print ``label`` and ``tensor`` when the module is verbose."""
        if self.verbose:
            print(label, tensor)

    def forward(self, inputs):
        """Compute causal self-attention context vectors.

        Args:
            inputs: Tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: If ``num_tokens`` exceeds the mask's ``context_length``.
        """
        _, num_tokens, _ = inputs.shape
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds the causal mask size "
                f"{max_tokens}; construct MaskedAttention with a larger context_length"
            )

        key = self.W_key(inputs)
        value = self.W_value(inputs)
        query = self.W_query(inputs)

        self._show("\nQuery matrix:\n", query)
        self._show("\nKey matrix:\n", key)
        self._show("\nValue matrix:\n", value)

        # Raw attention scores: (batch, num_tokens, num_tokens).
        attention_scores = query @ key.transpose(1, 2)
        self._show("\nRaw Attention Scores:\n", attention_scores)

        # Causal masking. Shorter sequences use the top-left corner of the mask.
        # masked_fill is out-of-place, so the raw scores are left untouched.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        masked_attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)
        self._show("\nMasked Attention Scores:\n", masked_attention_scores)

        # Scaled softmax; masked (-inf) entries become exactly 0.
        attn_weights = torch.softmax(
            masked_attention_scores / (key.shape[-1] ** 0.5), dim=-1
        )
        self._show("\nAttention Weights (After Softmax):\n", attn_weights)

        # Dropout on the weights (active only in training mode).
        attn_weights = self.dropout(attn_weights)

        context_vector = attn_weights @ value
        self._show("\nContext Vector (Final Output):\n", context_vector)

        return context_vector

# Initialize and run masked attention.
# Seed *before* constructing the module: nn.Linear draws its initial weights
# from the global RNG, so seeding afterwards left the projection weights (and
# every printed tensor) unreproducible between runs. The module starts in
# training mode, so the dropout in the forward pass also draws from this RNG.
torch.manual_seed(123)
masked_attention = MaskedAttention()
masked_attention(batch)


Batch shape: torch.Size([2, 6, 3])
Batch data:
 tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

Query matrix:
 tensor([[[ 0.1442, -0.4322, -0.1271],
         [ 0.5539, -0.8205, -0.2508],
         [ 0.5507, -0.8148, -0.2423],
         [ 0.3345, -0.4630, -0.1617],
         [ 0.3364, -0.4829, -0.0206],
         [ 0.3846, -0.5472, -0.2521]],

        [[ 0.1442, -0.4322, -0.1271],
         [ 0.5539, -0.8205, -0.2508],
         [ 0.5507, -0.8148, -0.2423],
         [ 0.3345, -0.4630, -0.1617],
         [ 0.3364, -0.4829, -0.0206],
         [ 0.3846, -0.5472, -0.2521]]], grad_fn=<UnsafeViewBackward0>)

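Key matrix:
 tensor(
[... output truncated here: the remainder of the Key matrix and all later prints (Value matrix, raw and masked attention scores, attention weights, context vector) are missing from this capture. The tensor below is NOT the key matrix — its two batch items differ although the two input sequences are identical, so it is presumably the cell's returned context vectors after dropout ...]
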
tensor([[[-0.1456, -0.5259, -0.2395],
         [-0.1489, -0.0468,  0.0791],
         [-0.0915, -0.0240,  0.0407],
         [-0.1101, -0.1534, -0.0200],
         [-0.0296, -0.1070, -0.0487],
         [-0.0497, -0.0156,  0.0264]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.1489, -0.0468,  0.0791],
         [-0.1905, -0.0551,  0.0933],
         [-0.1647, -0.1616,  0.0385],
         [-0.0847,  0.0249, -0.0412],
         [-0.0712, -0.0784, -0.0244]]], grad_fn=<UnsafeViewBackward0>)