In [15]:
# This notebook walks through single-head scaled dot-product self-attention in PyTorch, first step by step and then wrapped in nn.Module classes
import torch

In [16]:
# Toy embedding table: 4 entries, each an 8-dimensional vector (randomly initialised)
inputs = torch.nn.Embedding(num_embeddings=4, embedding_dim=8)

In [17]:
# Rebind `inputs` from the Embedding module to its weight matrix (an nn.Parameter).
# NOTE: this cell is not idempotent - after it runs, `inputs` is no longer the
# Embedding module, so re-running it without re-running the previous cell fails.
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-0.1262, -1.0226,  0.7768, -0.1428,  0.9783,  0.9666,  0.9062,  0.1996],
        [-0.7170,  0.1437, -0.2646, -0.4451, -0.6696,  0.2475, -0.8669, -2.3347],
        [-1.6918, -1.4278,  1.4860, -0.1009, -0.2374, -1.1133,  0.3374, -1.0927],
        [ 0.4969,  0.2433,  0.1658, -1.3663,  1.5591,  1.7361, -1.3396, -0.2387]],
       requires_grad=True)

In [18]:
# Drop autograd tracking so the embeddings act as plain input data.
# `.detach()` is the supported way to do this; the legacy `.data` accessor
# returns the same values but hides in-place modifications from autograd.
inputs = inputs.detach()
inputs

tensor([[-0.1262, -1.0226,  0.7768, -0.1428,  0.9783,  0.9666,  0.9062,  0.1996],
        [-0.7170,  0.1437, -0.2646, -0.4451, -0.6696,  0.2475, -0.8669, -2.3347],
        [-1.6918, -1.4278,  1.4860, -0.1009, -0.2374, -1.1133,  0.3374, -1.0927],
        [ 0.4969,  0.2433,  0.1658, -1.3663,  1.5591,  1.7361, -1.3396, -0.2387]])

In [19]:
# Set dimensions: input embedding size and projected (query/key/value) size
d_in = 8
d_out = 6
# Create the projection matrices.
# nn.Parameter defaults to requires_grad=True regardless of the flag on the
# tensor it wraps, so passing requires_grad=False to torch.randn had no effect.
# The no-op argument is removed; these parameters are (and always were) trainable.
# To freeze them instead, pass requires_grad=False to nn.Parameter itself.
W_q = torch.nn.Parameter(torch.randn(d_in, d_out))
W_k = torch.nn.Parameter(torch.randn(d_in, d_out))
W_v = torch.nn.Parameter(torch.randn(d_in, d_out))

In [20]:
# Take one input embedding (row index 2) and project it through W_q to get its query vector
query = torch.matmul(inputs[2], W_q)
query

tensor([-2.3388, -1.1604,  2.4655,  1.2567,  0.4330, -3.0018],
       grad_fn=<SqueezeBackward4>)

In [21]:
# Project every input embedding into key space (W_k) and value space (W_v);
# the attention scores themselves are computed in the next cell
keys = torch.matmul(inputs, W_k)
values = torch.matmul(inputs, W_v)
print("Keys:", keys)
print("Values:", values)

Keys: tensor([[ 0.6402,  0.0282,  0.8551,  2.3008, -0.4057,  5.1147],
        [ 0.2902, -2.0622, -0.6138,  1.2154,  2.8986,  0.1027],
        [ 2.7596,  3.5080,  0.9353,  5.3671, -1.2710,  6.7178],
        [-0.0876, -5.9122, -1.5527,  2.6520,  5.8264, -0.2949]],
       grad_fn=<MmBackward0>)
Values: tensor([[-0.7715, -2.0763,  0.0101, -2.1409, -3.2899, -0.9013],
        [ 4.2392,  2.7554,  3.9118,  4.4841,  0.8278,  2.2920],
        [ 0.0136,  1.9941,  2.0762,  0.7504,  1.7398,  1.9602],
        [ 6.8080, -3.3512,  0.2653,  3.0179, -5.1085, -0.2794]],
       grad_fn=<MmBackward0>)


In [22]:
# query has 6 entries and keys is 4 by 6, so keys is transposed to 6 by 4:
# the product is one dot-product score per input row
attention_scores = torch.matmul(query, keys.t())
attention_scores

tensor([-12.0596,   2.6752, -22.1900,   9.9778], grad_fn=<SqueezeBackward4>)

In [23]:
# Scale the scores by sqrt(key dimension), then softmax-normalise them into weights
attention_weights = (attention_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
attention_weights

tensor([1.1782e-04, 4.8273e-02, 1.8841e-06, 9.5161e-01],
       grad_fn=<SoftmaxBackward0>)

In [24]:
# Sanity check: softmax output should add up to 1
torch.sum(attention_weights)

tensor(1.0000, grad_fn=<SumBackward0>)

In [25]:
# Context vector: the attention-weighted sum of the value vectors
context_vector = torch.matmul(attention_weights, values)
context_vector

tensor([ 6.6831, -3.0563,  0.4413,  3.0880, -4.8217, -0.1553],
       grad_fn=<SqueezeBackward4>)

In [27]:
import torch.nn as nn

In [None]:
class SimpleAttention(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Queries, keys and values are all computed from the same input ``x`` via
    three ``(d_in, d_out)`` parameter matrices drawn from a standard normal.

    Args:
        d_in: size of the last dimension of the input embeddings.
        d_out: size of the projected query/key/value vectors (and of the output).
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Create the weight matrices. nn.Parameter is trainable by default and
        # ignores a requires_grad flag set on the tensor it wraps, so the flag
        # is not passed to torch.randn: these weights are trainable.
        self.W_q = nn.Parameter(torch.randn(d_in, d_out))
        self.W_k = nn.Parameter(torch.randn(d_in, d_out))
        self.W_v = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for the embedding vectors ``x``.

        Args:
            x: tensor whose last dimension is ``d_in``; either
                ``(seq_len, d_in)`` or batched ``(..., seq_len, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        # Swap only the last two dims: `.T` reverses *all* dims and therefore
        # breaks for batched (3-D and higher) input.
        scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        context = weights @ values
        return context

In [38]:
# Use case: instantiate the attention module with the same dimensions as the
# step-by-step walkthrough (8-dimensional inputs projected to 6 dimensions)
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [39]:
# Inspect the key projection matrix; its repr reports requires_grad=True,
# i.e. the parameter is trainable
simple.W_k

Parameter containing:
tensor([[ 1.3063, -1.1031,  1.4454,  0.5911,  0.0256,  1.5638],
        [-0.5855,  1.2256,  0.3116,  0.2306,  1.2067, -0.3640],
        [-0.5666, -0.6590, -0.7790,  0.8071, -0.2022, -0.6651],
        [ 0.5847, -0.8925, -1.1532,  0.7335, -1.1306,  0.1183],
        [ 0.5553,  1.5099, -2.5289,  0.5434, -0.2796, -1.5157],
        [ 0.7395, -0.4019, -1.0508, -2.6208,  0.6157,  0.3025],
        [-0.0784,  0.0280,  1.2418, -2.1519,  0.7551,  1.1260],
        [ 1.3946,  1.5690,  1.7002,  1.8571, -1.6960, -1.5818]],
       requires_grad=True)

In [40]:
# Run self-attention over all four input embeddings at once;
# the result has one 6-dimensional context vector per input row
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.2049,  2.6990,  0.5318, -2.4008,  2.2160, -0.4734],
        [-0.7376,  3.8528,  0.5971, -2.9316,  3.4040,  0.2830],
        [ 0.0285,  3.3906,  0.6603, -2.7588,  2.4408, -0.3328],
        [-0.7097,  3.8306,  0.6033, -2.9177,  3.3758,  0.2491]],
       grad_fn=<MmBackward0>)

In [42]:
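# Second version of the attention class.
# It uses nn.Linear layers (bias disabled) instead of hand-built nn.Parameter matrices, so PyTorch creates and initializes the projection weights.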

class SimpleAttentionv2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Queries, keys and values are all computed from the same input ``x`` via
    three bias-free linear layers mapping ``d_in`` to ``d_out``.

    Args:
        d_in: size of the last dimension of the input embeddings.
        d_out: size of the projected query/key/value vectors (and of the output).
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Projection layers; bias is disabled so each is a pure matrix multiply.
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Compute context vectors for the embedding vectors ``x``.

        Args:
            x: tensor whose last dimension is ``d_in``; either
                ``(seq_len, d_in)`` or batched ``(..., seq_len, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``d_out``.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        # Swap only the last two dims: `.T` reverses *all* dims and therefore
        # breaks for batched (3-D and higher) input, which nn.Linear accepts.
        scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        context = weights @ values
        return context

In [43]:
# Use case: instantiate the nn.Linear-based attention module.
# NOTE: this rebinds `simple`, replacing the earlier SimpleAttention instance.
simple = SimpleAttentionv2( d_in = 8, d_out = 6 )

In [44]:
# Run the v2 module over all four input embeddings;
# NOTE: this overwrites the `context_vectors` produced by the first class
context_vectors = simple( inputs )
context_vectors

tensor([[-0.2415, -0.5311,  0.6949,  0.1416, -0.1356,  0.3026],
        [-0.2418, -0.5530,  0.7464,  0.1045, -0.0824,  0.3260],
        [-0.2667, -0.5466,  0.7735,  0.0954, -0.0966,  0.3628],
        [-0.1930, -0.5313,  0.6326,  0.1641, -0.1389,  0.2315]],
       grad_fn=<MmBackward0>)