## Self-Attention Block in PyTorch


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


#### Define the Self-Attention block


In [4]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three learnable
    linear layers, then returns ``softmax(Q K^T / sqrt(output_dim)) V``
    together with the attention weights.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
        output_dim: Size of the Q/K/V projections, and therefore of the
            output features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # Learnable weight matrices for Q, K, and V
        self.query = nn.Linear(input_dim, output_dim)
        self.key = nn.Linear(input_dim, output_dim)
        self.value = nn.Linear(input_dim, output_dim)

        # Scaling factor sqrt(d_k). Kept as a plain Python float so it works on
        # whatever device/dtype the input lives on: a tensor attribute that is
        # not a registered buffer would stay on the CPU after `.to(device)`.
        self.scale = output_dim ** 0.5

    def forward(self, x):
        """Apply self-attention across the second-to-last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, input_dim)``, e.g. ``(N, D)``
                for a single sequence or ``(B, N, D)`` for a batch.

        Returns:
            Tuple ``(output, attention_weights)``:
            ``output`` has shape ``(..., seq_len, output_dim)``;
            ``attention_weights`` has shape ``(..., seq_len, seq_len)`` and
            each row sums to 1.
        """
        # Step 1: Compute Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Step 2: Compute the scaled attention scores. Transposing the last two
        # dims of K gives (..., output_dim, seq_len), so the product is a
        # (..., seq_len, seq_len) score matrix.
        attention_scores = Q @ K.transpose(-2, -1) / self.scale

        # Step 3: Normalise each query's scores into weights over all keys
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Step 4: Weighted sum of the values
        output = attention_weights @ V

        return output, attention_weights


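The `.transpose(-2, -1)` in `K.transpose(-2, -1)` swaps the last two dimensions of a tensor. For a single sequence, it turns `K` from shape (N, d) into (d, N), so `Q @ K.transpose(-2, -1)` yields the (N, N) matrix of scores between every query and every key, i.e. $\mathrm{softmax}\!\left(QK^\top / \sqrt{d_k}\right)V$.
Let’s break it down:
* `-1` refers to the last dimension (e.g., columns in a 2D matrix, or features in a 3D tensor).
* `-2` refers to the second-to-last dimension (e.g., rows in a 2D matrix, or sequence length in a 3D tensor).

Because the dimensions are counted from the end, the same code works for both unbatched (N, d) and batched (B, N, d) inputs.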

In [7]:
embedding_dim = 3

# Seed before building the module so the nn.Linear initialisation (and thus
# the printed output below) is reproducible.
torch.manual_seed(10)
self_attention = SelfAttention(input_dim=embedding_dim, output_dim=3)

# Four tokens (rows), each a 3-dimensional embedding (columns)
X = torch.tensor([[-0.8805, -0.6517,  0.4077],
        [ 0.4389, -1.1243, -0.8373],
        [ 2.0104,  2.2844,  0.1933],
        [ 0.7380,  0.5161,  1.5216]], dtype=torch.float32)  # Shape: (N=4, D=3)
assert X.shape == (4, embedding_dim)

# Forward pass through the self-attention block
output, attention_weights = self_attention(X)
assert output.shape == (4, 3)
assert attention_weights.shape == (4, 4)  # one weight per (query, key) pair

print("Attention Output:")
print(output)
print("Attention Weights:")
print(attention_weights)


Attention Output:
tensor([[-0.3173, -0.0280,  0.0224],
        [-0.4334, -0.1403,  0.1004],
        [-0.3490, -0.0794,  0.0435],
        [-0.2562,  0.0009, -0.0200]], grad_fn=<MmBackward0>)
Attention Weights:
tensor([[0.2688, 0.2628, 0.2259, 0.2425],
        [0.2219, 0.2686, 0.2560, 0.2536],
        [0.2636, 0.2650, 0.2494, 0.2220],
        [0.2799, 0.2791, 0.2216, 0.2194]], grad_fn=<SoftmaxBackward0>)
