In [25]:
import torch
import math
import torch.nn.functional as F

In [26]:
# Parameters
batch_size = 2
seq_len = 5       # same for Q, K, V in self-attention
embed_dim = 16    # total embedding size
num_heads = 4
head_dim = embed_dim // num_heads


Q = torch.arange(batch_size * num_heads * seq_len * head_dim).view(batch_size, num_heads, seq_len, head_dim).float()
K = torch.arange(batch_size * num_heads * seq_len * head_dim).view(batch_size, num_heads, seq_len, head_dim).float()
V = torch.arange(batch_size * num_heads * seq_len * head_dim).view(batch_size, num_heads, seq_len, head_dim).float()

print("Q shape:", Q.shape)
print("K shape:", K.shape)  
print("V shape:", V.shape)

#print("Q:", Q)
#print("K:", K)
print("V:", V)

# Optional causal mask (prevent attending to future positions)
causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Self-attention
out = F.scaled_dot_product_attention(
    Q, K, V,
    enable_gqa=False,  
    scale=None,        
    is_causal=True,  
)

print("Self-attention output shape:", out.shape)
print("Self-attention output:", out)


Q shape: torch.Size([2, 4, 5, 4])
K shape: torch.Size([2, 4, 5, 4])
V shape: torch.Size([2, 4, 5, 4])
V: tensor([[[[  0.,   1.,   2.,   3.],
          [  4.,   5.,   6.,   7.],
          [  8.,   9.,  10.,  11.],
          [ 12.,  13.,  14.,  15.],
          [ 16.,  17.,  18.,  19.]],

         [[ 20.,  21.,  22.,  23.],
          [ 24.,  25.,  26.,  27.],
          [ 28.,  29.,  30.,  31.],
          [ 32.,  33.,  34.,  35.],
          [ 36.,  37.,  38.,  39.]],

         [[ 40.,  41.,  42.,  43.],
          [ 44.,  45.,  46.,  47.],
          [ 48.,  49.,  50.,  51.],
          [ 52.,  53.,  54.,  55.],
          [ 56.,  57.,  58.,  59.]],

         [[ 60.,  61.,  62.,  63.],
          [ 64.,  65.,  66.,  67.],
          [ 68.,  69.,  70.,  71.],
          [ 72.,  73.,  74.,  75.],
          [ 76.,  77.,  78.,  79.]]],


        [[[ 80.,  81.,  82.,  83.],
          [ 84.,  85.,  86.,  87.],
          [ 88.,  89.,  90.,  91.],
          [ 92.,  93.,  94.,  95.],
          [ 96.,  97.

In [27]:
# Manual scaled dot-product attention for a single (batch, head) slice,
# printing every intermediate tensor along the way.
b, h = 0, 0
Q_ = Q[b, h]  # [5,4]
K_ = K[b, h]  # [5,4]
V_ = V[b, h]  # [5,4]

print("Q_ shape:", Q_.shape)
print("Q_:", Q_)

# Step 1: raw scores -- dot product of every query row with every key row
scores = Q_ @ K_.T

print("Raw scores shape:", scores.shape)
print("Raw scores:", scores)

# Step 2: scale by the square root of the per-head feature size
scores_scaled = scores / math.sqrt(Q_.shape[1])

print("Scaled scores shape:", scores_scaled.shape)
print("Scaled scores:", scores_scaled)

# Step 3: causal mask -- True on and below the diagonal, so row i keeps
# only columns 0..i
mask = torch.tril(torch.ones(Q_.shape[0], Q_.shape[0], dtype=torch.bool))
print("Mask shape:", mask.shape)
print("Mask:", mask)

# Entries outside the mask become -inf so that softmax gives them zero weight.
scores_scaled = scores_scaled.masked_fill(~mask, float('-inf'))
print("Masked scores shape:", scores_scaled.shape)
print("Masked scores:", scores_scaled)

# Step 4: softmax over the key dimension
attn_weights = F.softmax(scores_scaled, dim=-1)
print("Attention weights shape:", attn_weights.shape)
print("Attention weights:", attn_weights)

# Sanity check: each row of softmax output should total 1. The result is
# deliberately not named `sum`, which would replace the builtin for every
# later cell in the notebook.
attn_weight_row_sums = attn_weights.sum(dim=-1, keepdim=True)
print("Sum of attention weights shape:", attn_weight_row_sums.shape)
print("Sum of attention weights:", attn_weight_row_sums)
assert torch.allclose(
    attn_weight_row_sums, torch.ones_like(attn_weight_row_sums)
), "attention weights must sum to 1 along the key dimension"

# Step 5: weighted sum of the value rows
out_manual = attn_weights @ V_
print("out_manual:", out_manual)


Q_ shape: torch.Size([5, 4])
Q_: tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.],
        [12., 13., 14., 15.],
        [16., 17., 18., 19.]])
Raw scores shape: torch.Size([5, 5])
Raw scores: tensor([[  14.,   38.,   62.,   86.,  110.],
        [  38.,  126.,  214.,  302.,  390.],
        [  62.,  214.,  366.,  518.,  670.],
        [  86.,  302.,  518.,  734.,  950.],
        [ 110.,  390.,  670.,  950., 1230.]])
Scaled scores shape: torch.Size([5, 5])
Scaled scores: tensor([[  7.,  19.,  31.,  43.,  55.],
        [ 19.,  63., 107., 151., 195.],
        [ 31., 107., 183., 259., 335.],
        [ 43., 151., 259., 367., 475.],
        [ 55., 195., 335., 475., 615.]])
Mask shape: torch.Size([5, 5])
Mask: tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])
Masked scores shape: t