In [1]:
import torch
import torch.nn as nn

In [2]:
# Global RNG seed, kept as a named constant rather than a magic number.
# NOTE(review): every tensor in this notebook is currently a hard-coded
# literal, so the seed only matters if those are swapped for random init.
SEED = 123
torch.manual_seed(SEED)

<torch._C.Generator at 0x7c6c97ff0e30>

In [3]:
# Toy input matrix: 6 rows x 3 features (presumably one embedding per
# token -- TODO confirm against the accompanying text).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # row 0
        [0.55, 0.87, 0.66],  # row 1
        [0.57, 0.85, 0.64],  # row 2
        [0.22, 0.58, 0.33],  # row 3
        [0.77, 0.25, 0.10],  # row 4
        [0.05, 0.80, 0.55],  # row 5
    ]
)

In [4]:
# Fixed 3x3 projection matrices, hard-coded so every output is reproducible.
# The projection cells apply each one as `x @ W.T` (the same convention
# nn.Linear uses for its `weight`), so row j of W produces output feature j.
query_weights = torch.tensor(
    [
        [-0.2354, 0.0191, -0.2867],
        [0.2177, -0.4919, 0.4232],
        [-0.4196, -0.4590, -0.3648],
    ]
)

key_weights = torch.tensor(
    [
        [0.2615, -0.2133, 0.2161],
        [-0.4900, -0.3503, -0.2120],
        [-0.1135, -0.4404, 0.3780],
    ]
)

value_weights = torch.tensor(
    [
        [-0.1362, 0.1853, 0.4083],
        [0.1076, 0.1579, 0.5573],
        [-0.2604, 0.1829, -0.2569],
    ]
)

In [5]:
# Project every input row into query space: Q = X @ W_q^T.
queries = inputs @ query_weights.T
queries

tensor([[-0.3535,  0.3965, -0.5740],
        [-0.3021, -0.0289, -0.8709],
        [-0.3014, -0.0232, -0.8628],
        [-0.1353, -0.0978, -0.4789],
        [-0.2052,  0.0870, -0.4743],
        [-0.1542, -0.1499, -0.5888]])

In [6]:
# Project every input row into key space: K = X @ W_k^T.
keys = inputs @ key_weights.T
keys

tensor([[ 0.2728, -0.4519,  0.2216],
        [ 0.1009, -0.7142, -0.1961],
        [ 0.1061, -0.7127, -0.1971],
        [ 0.0051, -0.3809, -0.1557],
        [ 0.1696, -0.4861, -0.1597],
        [-0.0387, -0.4213, -0.1501]])

In [7]:
# Raw attention scores: entry (i, j) is the dot product of query row i with
# key row j. NOTE(review): despite the variable name these are unnormalized
# *scores* -- no scaling or softmax has been applied yet.
attention_weights: torch.Tensor = queries @ keys.T

attention_weights

tensor([[-0.4028, -0.2063, -0.2069, -0.0635, -0.1610, -0.0672],
        [-0.2623,  0.1609,  0.1602,  0.1450,  0.1019,  0.1546],
        [-0.2629,  0.1553,  0.1546,  0.1416,  0.0979,  0.1509],
        [-0.0988,  0.1501,  0.1497,  0.1111,  0.1010,  0.1183],
        [-0.2004,  0.0102,  0.0097,  0.0397, -0.0013,  0.0425],
        [-0.1048,  0.2069,  0.2065,  0.1480,  0.1407,  0.1575]])

In [8]:
# Scaled dot-product attention weights: softmax(Q K^T / sqrt(d_k)) taken
# over the key axis. The raw scores are recomputed from `queries`/`keys`
# under their own name instead of overwriting `attention_weights` in place,
# so re-running this cell can no longer apply the softmax twice.
attention_scores = queries @ keys.T
d_k = keys.shape[-1]  # feature dimension shared by queries and keys
attention_weights = torch.softmax(attention_scores / d_k ** 0.5, dim=-1)

# Invariant: each row is a probability distribution over the keys.
assert torch.allclose(
    attention_weights.sum(dim=-1), torch.ones(attention_weights.shape[0])
)

attention_weights

tensor([[0.1466, 0.1642, 0.1642, 0.1784, 0.1686, 0.1780],
        [0.1365, 0.1743, 0.1743, 0.1727, 0.1685, 0.1737],
        [0.1368, 0.1742, 0.1741, 0.1728, 0.1685, 0.1737],
        [0.1494, 0.1725, 0.1725, 0.1686, 0.1677, 0.1694],
        [0.1497, 0.1691, 0.1690, 0.1720, 0.1680, 0.1723],
        [0.1456, 0.1743, 0.1743, 0.1685, 0.1678, 0.1694]])


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [9]:
# Project every input row into value space: V = X @ W_v^T.
values = inputs @ value_weights.T
values

tensor([[ 0.3326,  0.5659, -0.3132],
        [ 0.3558,  0.5644, -0.1537],
        [ 0.3412,  0.5522, -0.1574],
        [ 0.2122,  0.2992, -0.0360],
        [-0.0177,  0.1781, -0.1805],
        [ 0.3660,  0.4382, -0.0080]])

In [11]:
# Context vectors: output row i is the attention-weighted sum of the value
# rows. Assumes `attention_weights` holds the softmax-normalized weights,
# i.e. the cells above were run in order.
context = attention_weights @ values
context

tensor([[ 0.2632,  0.4277, -0.1353],
        [ 0.2641,  0.4297, -0.1350],
        [ 0.2641,  0.4296, -0.1350],
        [ 0.2647,  0.4316, -0.1381],
        [ 0.2642,  0.4303, -0.1373],
        [ 0.2647,  0.4316, -0.1375]])