# Set up

In [None]:
import torch

# Seed PyTorch's global RNG so the values sampled below are reproducible
torch.manual_seed(42)

# Sample a 2x3 tensor from torch.rand (shape passed as a single tuple)
tensor = torch.rand((2, 3))

print(tensor)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009]])


In [None]:
# Reseed, then sample the toy input: a 4x6 tensor
# (presumably 4 tokens with 6 features each -- the notebook does not say so explicitly).
torch.manual_seed(0)
input_token = torch.rand((4, 6))
input_token

tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939, 0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816, 0.9152, 0.3971, 0.8742]])

In [None]:
def _seeded_uniform(seed, shape=(6, 6)):
    """Reseed PyTorch's global RNG with `seed` and return a fresh `torch.rand` sample of `shape`."""
    torch.manual_seed(seed)
    return torch.rand(shape)


# Key, query and value weights use seeds 1, 2 and 3 respectively. The generator is
# consumed left to right, so the reseed/draw pairs run in that order and the global
# RNG ends up in the same state as three explicit reseed/draw pairs would leave it.
w_key, w_query, w_value = (_seeded_uniform(seed) for seed in (1, 2, 3))
w_key

tensor([[0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999],
        [0.3971, 0.7544, 0.5695, 0.4388, 0.6387, 0.5247],
        [0.6826, 0.3051, 0.4635, 0.4550, 0.5725, 0.4980],
        [0.9371, 0.6556, 0.3138, 0.1980, 0.4162, 0.2843],
        [0.3398, 0.5239, 0.7981, 0.7718, 0.0112, 0.8100],
        [0.6397, 0.9743, 0.8300, 0.0444, 0.0246, 0.2588]])

# Single attention head

In [None]:
# Project the input with the full 6x6 weight matrices.
query = input_token @ w_query
key = input_token @ w_key
value = input_token @ w_value

# Raw query-key products (no scaling or softmax is applied in this notebook),
# used directly to combine the rows of `value`.
attn_weight = query @ key.T
attn_score = attn_weight @ value
attn_score

tensor([[44.0492, 62.1182, 36.8603, 39.2654, 39.7193, 38.2237],
        [62.7624, 88.5186, 52.5154, 55.9376, 56.5921, 54.4579],
        [45.9096, 64.7556, 38.4231, 40.9384, 41.3846, 39.8381],
        [64.8527, 91.4548, 54.2801, 57.8339, 58.4687, 56.2849]])

# Method 1: Split the weight matrices

In [None]:
# Cut every projection matrix into two column blocks of width 3:
# the first block (columns 0-2) feeds head 1, the second (columns 3-5) feeds head 2.
w_key1, w_key2 = w_key.split(3, dim=1)
w_query1, w_query2 = w_query.split(3, dim=1)
w_value1, w_value2 = w_value.split(3, dim=1)

In [None]:
# Head 1: project the input with the first column block of each weight matrix.
query1 = input_token @ w_query1
key1 = input_token @ w_key1
value1 = input_token @ w_value1

# Unnormalised query-key products, then applied to the value rows.
attn_weight1 = query1 @ key1.T
attn_score1 = attn_weight1 @ value1
attn_score1

tensor([[25.9219, 36.7194, 21.6421],
        [37.6478, 53.3254, 31.4308],
        [27.6223, 39.1307, 23.0655],
        [37.8793, 53.6592, 31.6296]])

In [None]:
# Head 2: project the input with the second column block of each weight matrix.
query2 = input_token @ w_query2
key2 = input_token @ w_key2
value2 = input_token @ w_value2

# Unnormalised query-key products, then applied to the value rows.
attn_weight2 = query2 @ key2.T
attn_score2 = attn_weight2 @ value2
attn_score2

tensor([[16.2456, 16.3873, 15.8126],
        [22.5088, 22.7015, 21.9066],
        [16.4007, 16.5261, 15.9538],
        [24.1872, 24.3773, 23.5310]])

# Method 2: Split the projected Q, K and V

In [None]:
# Split the already-projected full-width matrices into two column blocks of width 3:
# columns 0-2 go to the first head (suffix 3), columns 3-5 to the second (suffix 4).
key3, key4 = key.split(3, dim=1)
query3, query4 = query.split(3, dim=1)
value3, value4 = value.split(3, dim=1)
key3

tensor([[1.3753, 1.6105, 1.4916],
        [2.0064, 1.9409, 1.7296],
        [1.5193, 1.7082, 1.6249],
        [2.2511, 2.1256, 1.8713]])

In [None]:
# First head again, this time built from the first 3 columns of the projected Q/K/V.
attn_weight3 = torch.matmul(query3, key3.T)
attn_score3 = torch.matmul(attn_weight3, value3)

# Slicing the projected matrices must give the same result as slicing the weights
# first (Method 1, attn_score1). Check it numerically rather than by comparing the
# printed tensors by eye; allclose tolerates float32 rounding differences.
assert torch.allclose(attn_score3, attn_score1), "Method 2 head 1 differs from Method 1 head 1"
attn_score3

tensor([[25.9219, 36.7194, 21.6421],
        [37.6478, 53.3254, 31.4308],
        [27.6223, 39.1307, 23.0655],
        [37.8793, 53.6592, 31.6296]])

In [None]:
# Second head again, this time built from the last 3 columns of the projected Q/K/V.
attn_weight4 = torch.matmul(query4, key4.T)
attn_score4 = torch.matmul(attn_weight4, value4)

# Slicing the projected matrices must give the same result as slicing the weights
# first (Method 1, attn_score2). Check it numerically rather than by comparing the
# printed tensors by eye; allclose tolerates float32 rounding differences.
assert torch.allclose(attn_score4, attn_score2), "Method 2 head 2 differs from Method 1 head 2"
attn_score4

tensor([[16.2456, 16.3873, 15.8126],
        [22.5088, 22.7015, 21.9066],
        [16.4007, 16.5261, 15.9538],
        [24.1872, 24.3773, 23.5310]])