In [2]:
import torch

data = [[1, 0, 0],
        [0, 2, 0],
        [0, 0, 3]]
v = torch.tensor(data)
v

tensor([[1, 0, 0],
        [0, 2, 0],
        [0, 0, 3]])

In [3]:
# 2 separate stacks of 3, 3x3 matrices
torch.zeros((2,3,3,3))

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

In [4]:
# Input vectors (1x3) (BxF)
x = torch.tensor([[0, 1, 2]], dtype=torch.float32)

# Weight matrices (1x3)
W_Q = torch.tensor([[1, 1, 1]], dtype=torch.float32)
W_K = torch.tensor([[2, 5, 2]], dtype=torch.float32)
W_V = torch.tensor([[1, 1, 1]], dtype=torch.float32)

x = x.unsqueeze(2)

print(x)
print(x.shape)

# Compute Q, K, V (Bx3x1 @ Bx1x3 = Bx3x3)
Q = x @ W_Q
K = x @ W_K
V = x @ W_V

print("Q:", Q)
print("K:", K)
print("V:", V)

tensor([[[0.],
         [1.],
         [2.]]])
torch.Size([1, 3, 1])
Q: tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.]]])
K: tensor([[[ 0.,  0.,  0.],
         [ 2.,  5.,  2.],
         [ 4., 10.,  4.]]])
V: tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.]]])


In [5]:
import torch.nn.functional as F

# Compute attention scores (3x2 @ 2x3 = 3x3)
scores = Q @ K.transpose(1, 2) / torch.sqrt(torch.tensor(3.0))
attention_scores = F.softmax(scores, dim=-1)

print("Scores:", scores[0])
print("Attention Scores:", attention_scores[0])
print(attention_scores.shape)
v = attention_scores @ V.transpose(1, 2)
print(v)

Scores: tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  5.1962, 10.3923],
        [ 0.0000, 10.3923, 20.7846]])
Attention Scores: tensor([[3.3333e-01, 3.3333e-01, 3.3333e-01],
        [3.0498e-05, 5.5072e-03, 9.9446e-01],
        [9.4047e-10, 3.0667e-05, 9.9997e-01]])
torch.Size([1, 3, 3])
tensor([[[0., 1., 2.],
         [0., 1., 2.],
         [0., 1., 2.]]])


In [22]:
x = [[[1, 2, 3, 4],
      [5, 6, 7, 8],
      [9, 10, 11, 12]],

     [[13, 14, 15, 16],
      [17, 18, 19, 20],
      [21, 22, 23, 24]]]

x = torch.tensor(x, dtype=torch.float32)
x

tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]])

In [23]:
xv = x.view(2, 3, 2, 2)
xv

tensor([[[[ 1.,  2.],
          [ 3.,  4.]],

         [[ 5.,  6.],
          [ 7.,  8.]],

         [[ 9., 10.],
          [11., 12.]]],


        [[[13., 14.],
          [15., 16.]],

         [[17., 18.],
          [19., 20.]],

         [[21., 22.],
          [23., 24.]]]])

In [24]:
xvt = xv.transpose(1, 2)
xvt

tensor([[[[ 1.,  2.],
          [ 5.,  6.],
          [ 9., 10.]],

         [[ 3.,  4.],
          [ 7.,  8.],
          [11., 12.]]],


        [[[13., 14.],
          [17., 18.],
          [21., 22.]],

         [[15., 16.],
          [19., 20.],
          [23., 24.]]]])

In [25]:
xvtt = xvt.transpose(-1, -2)
xvtt

tensor([[[[ 1.,  5.,  9.],
          [ 2.,  6., 10.]],

         [[ 3.,  7., 11.],
          [ 4.,  8., 12.]]],


        [[[13., 17., 21.],
          [14., 18., 22.]],

         [[15., 19., 23.],
          [16., 20., 24.]]]])

In [26]:
head = xvt @ xvtt
head

tensor([[[[   5.,   17.,   29.],
          [  17.,   61.,  105.],
          [  29.,  105.,  181.]],

         [[  25.,   53.,   81.],
          [  53.,  113.,  173.],
          [  81.,  173.,  265.]]],


        [[[ 365.,  473.,  581.],
          [ 473.,  613.,  753.],
          [ 581.,  753.,  925.]],

         [[ 481.,  605.,  729.],
          [ 605.,  761.,  917.],
          [ 729.,  917., 1105.]]]])

In [30]:
torch.softmax(head, dim=-1)

tensor([[[[3.7751e-11, 6.1442e-06, 9.9999e-01],
          [6.0546e-39, 7.7811e-20, 1.0000e+00],
          [0.0000e+00, 9.8542e-34, 1.0000e+00]],

         [[4.7809e-25, 6.9144e-13, 1.0000e+00],
          [0.0000e+00, 8.7565e-27, 1.0000e+00],
          [0.0000e+00, 1.1089e-40, 1.0000e+00]]],


        [[[0.0000e+00, 0.0000e+00, 1.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 1.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.0000e+00]]]])