In [1]:
# This notebook demonstrates advanced attention mechanisms in PyTorch
import torch

In [2]:
# Toy input: an embedding table holding 4 vectors of size 8
inputs = torch.nn.Embedding( num_embeddings=4, embedding_dim=8 )

In [3]:
inputs = inputs.weight
inputs 

Parameter containing:
tensor([[-1.1612, -1.5821, -3.2273,  0.3622, -0.3453, -0.8032,  2.3792, -2.1280],
        [-0.5797,  0.9732,  0.0234, -1.2803,  1.2088,  0.3766, -0.4586, -0.1949],
        [ 0.1854, -0.5733, -0.3561,  1.3627,  0.7037, -1.2671, -0.9148,  0.3970],
        [ 0.9707,  0.3771, -1.1129, -0.3616, -0.3282,  0.4932, -0.7722, -1.4453]],
       requires_grad=True)

In [4]:
# Drop the autograd link so the toy inputs are a plain tensor.
# `.detach()` is the supported replacement for the legacy `.data` attribute,
# whose in-place modifications autograd cannot track.
inputs = inputs.detach()
inputs

tensor([[-1.1612, -1.5821, -3.2273,  0.3622, -0.3453, -0.8032,  2.3792, -2.1280],
        [-0.5797,  0.9732,  0.0234, -1.2803,  1.2088,  0.3766, -0.4586, -0.1949],
        [ 0.1854, -0.5733, -0.3561,  1.3627,  0.7037, -1.2671, -0.9148,  0.3970],
        [ 0.9707,  0.3771, -1.1129, -0.3616, -0.3282,  0.4932, -0.7722, -1.4453]])

In [5]:
# Set dimensions
d_in = 8   # size of each input embedding vector
d_out = 6  # size of each projected query/key/value vector
# Create the query/key/value weight matrices.
# nn.Parameter enables gradients by default regardless of the flag on the
# tensor it wraps, so a requires_grad=False passed to torch.randn would have
# no effect; it is omitted here to avoid suggesting the weights are frozen.
W_q = torch.nn.Parameter( torch.randn( d_in, d_out ) )
W_k = torch.nn.Parameter( torch.randn( d_in, d_out ) )
W_v = torch.nn.Parameter( torch.randn( d_in, d_out ) )

In [6]:
# Choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-2.3058, -0.3179, -2.1898,  1.3835,  1.2747, -0.2163],
       grad_fn=<SqueezeBackward4>)

In [7]:
# calculate attention scores using the keys generated by W_k
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:" , keys)
print("Values:" , values)

Keys: tensor([[ 1.7125,  4.7981, -6.4154, -2.7962, -1.7447,  1.8789],
        [ 1.6495, -4.4002,  4.3578,  0.4439,  1.8209, -2.5818],
        [-1.1307,  1.7959, -0.8145,  1.5650, -3.8818, -0.3500],
        [ 3.2970, -2.4296,  1.2361, -1.8643, -0.7973, -1.6809]],
       grad_fn=<MmBackward0>)
Values: tensor([[-4.3350, -3.5431, -4.8430, -9.0924,  3.1350, -0.6456],
        [-0.3700, -4.6019, -1.5630,  0.4270,  2.4005,  0.5838],
        [-0.0802,  5.8840,  1.3016,  0.9462, -0.5162,  0.6040],
        [ 2.9681, -2.7152, -2.0311, -0.6229, -1.1536, -0.7158]],
       grad_fn=<MmBackward0>)


In [8]:
attention_scores = query @ keys.T # query is 1 by 6 and keys is 4 by 6 so we need to transpose keys
attention_scores

tensor([  2.0759,  -8.4537,   1.1126, -12.7687], grad_fn=<SqueezeBackward4>)

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim=-1 ) # the softmax function normalizes the scores
attention_weights

tensor([0.5914, 0.0080, 0.3991, 0.0014], grad_fn=<SoftmaxBackward0>)

In [10]:
attention_weights.sum() # ensure the weights sum to 1

tensor(1.0000, grad_fn=<SumBackward0>)

In [11]:
context_vector = attention_weights @ values # weighted sum of the value vectors, one weight per input token
context_vector

tensor([-2.5947,  0.2123, -2.3602, -4.9974,  1.6659, -0.1371],
       grad_fn=<SqueezeBackward4>)

In [12]:
import torch.nn as nn

In [13]:
class SimpleAttention( nn.Module ):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each projected query/key/value vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Trainable projection matrices drawn from a standard normal.
        # nn.Parameter turns gradients on by default, so no requires_grad
        # flag is passed to torch.randn (it would be overridden anyway).
        self.W_q = nn.Parameter( torch.randn( d_in, d_out ) )
        self.W_k = nn.Parameter( torch.randn( d_in, d_out ) )
        self.W_v = nn.Parameter( torch.randn( d_in, d_out ) )

    def forward(self, x):
        """Compute one context vector per input vector.

        Args:
            x: embedding vectors of shape (num_tokens, d_in); leading batch
                dimensions are also accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by d_out.
        """
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        # Swap only the last two dims: `.T` reverses every dim, which is
        # only correct for 2-D tensors and breaks batched input.
        scores = queries @ keys.transpose( -2, -1 )
        # Scale by sqrt(d_out) before the softmax over the key axis.
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        return weights @ values

In [14]:
# use case
# instance of the class
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_k

Parameter containing:
tensor([[-0.8459,  0.8108,  0.5351,  0.9449,  1.6634, -0.0625],
        [ 0.1150,  0.5130, -0.4308, -0.5406, -0.4469,  0.7746],
        [-0.8703, -1.2867, -0.6136,  0.4749, -0.5217,  1.0471],
        [ 0.0512, -0.8173, -1.4908,  1.7152,  0.3392,  1.2870],
        [-0.0671,  1.3477, -1.9594, -1.7509,  0.6701,  1.1448],
        [-0.1593,  1.4666,  1.2015, -0.3072, -0.3401, -1.5671],
        [ 0.3325,  1.1337, -0.4178,  0.6868, -0.5777, -1.9540],
        [ 2.8169,  1.9666,  1.5134, -2.5585, -1.1444,  0.6638]],
       requires_grad=True)

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[ -1.6973,  -3.5885,  -1.5028,  -1.2160,   4.6796,  -1.6612],
        [ -0.2407,  -1.0286,   2.1604,  -1.0801,  -1.7408,  -0.6134],
        [  1.3794,   3.5763,   5.0256,  -8.9521, -11.1206,  -6.5928],
        [ -0.9972,  -3.9482,  -3.1072,   2.8289,  -2.4738,  -0.1428]],
       grad_fn=<MmBackward0>)

In [17]:
# second version of the class
# it uses nn.Linear to do things more effectively

class SimpleAttentionv2( nn.Module ):
    """Single-head scaled dot-product self-attention built on nn.Linear.

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each projected query/key/value vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Bias-free linear layers act as the query/key/value weight matrices.
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )

    def forward( self, x ):
        """Compute one context vector per input vector.

        Args:
            x: embedding vectors of shape (num_tokens, d_in); leading batch
                dimensions are also accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by d_out.
        """
        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        # Swap only the last two dims: `.T` reverses every dim, which is
        # only correct for 2-D tensors and breaks batched input.
        scores = queries @ keys.transpose( -2, -1 )
        # Scale by sqrt(d_out) before the softmax over the key axis.
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        return weights @ values

In [18]:
# use case
# instance of the class
simple = SimpleAttentionv2( d_in = 8, d_out = 6 )

In [19]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.0775, -0.3132, -0.1022,  0.1780, -0.2306, -0.2533],
        [ 0.1155, -0.3383, -0.0810,  0.2060, -0.2622, -0.3131],
        [ 0.1036, -0.2941, -0.1911,  0.1618, -0.2324, -0.1353],
        [ 0.0893, -0.2930, -0.1805,  0.1577, -0.2208, -0.1400]],
       grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses information from all of the embedding vectors
# in practice, each position should only use information from itself and the previous vectors
# to accomplish this, we'll implement causal attention AKA masked attention
weights = torch.randn( inputs.shape[0], inputs.shape[0] )

In [21]:
weights

tensor([[ 0.8963, -0.8987, -0.6772,  2.1312],
        [-0.0254,  0.3243,  2.0620,  0.2009],
        [-1.6154,  0.2787, -0.9579, -1.4212],
        [ 1.6788,  0.1325,  0.0710, -1.3858]])

In [22]:
weights.sum( dim=-1 )

tensor([ 1.4515,  2.5617, -3.7158,  0.4965])

In [23]:
# torch.tril keeps the lower triangle (diagonal included) and zeros everything above it,
# so row i only keeps columns 0..i
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [24]:
masked_weights = weights*simple_mask
masked_weights

tensor([[ 0.8963, -0.0000, -0.0000,  0.0000],
        [-0.0254,  0.3243,  0.0000,  0.0000],
        [-1.6154,  0.2787, -0.9579, -0.0000],
        [ 1.6788,  0.1325,  0.0710, -1.3858]])

In [25]:
masked_weights.sum( dim=-1 )

tensor([ 0.8963,  0.2988, -2.2946,  0.4965])

In [26]:
# now, we need to normalize the masked weights so that they sum to 1
row_sums = masked_weights.sum( dim=-1, keepdim=True )
row_sums

tensor([[ 0.8963],
        [ 0.2988],
        [-2.2946],
        [ 0.4965]])

In [27]:
# dividing by the row sums makes every row sum to 1, but because these raw weights can be
# negative the entries may fall outside [0, 1], so the rows are not valid probability distributions
masked_weights = masked_weights / row_sums
masked_weights

tensor([[ 1.0000, -0.0000, -0.0000,  0.0000],
        [-0.0851,  1.0851,  0.0000,  0.0000],
        [ 0.7040, -0.1215,  0.4175,  0.0000],
        [ 3.3815,  0.2668,  0.1431, -2.7914]])

In [28]:
# masking method #2: set the future positions to -inf, then apply softmax
# torch.triu?
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal=1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [29]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [30]:
weights

tensor([[ 0.8963, -0.8987, -0.6772,  2.1312],
        [-0.0254,  0.3243,  2.0620,  0.2009],
        [-1.6154,  0.2787, -0.9579, -1.4212],
        [ 1.6788,  0.1325,  0.0710, -1.3858]])

In [31]:
weights = weights.masked_fill( mask.bool(), -torch.inf)
weights

tensor([[ 0.8963,    -inf,    -inf,    -inf],
        [-0.0254,  0.3243,    -inf,    -inf],
        [-1.6154,  0.2787, -0.9579,    -inf],
        [ 1.6788,  0.1325,  0.0710, -1.3858]])

In [32]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4135, 0.5865, 0.0000, 0.0000],
        [0.1044, 0.6941, 0.2015, 0.0000],
        [0.6849, 0.1459, 0.1372, 0.0320]])

In [33]:
masked_weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000])

In [34]:
# Dropout 
# Dropout is a regularization technique used to prevent overfitting in neural networks.
# It works by randomly setting a fraction of input units to zero at each update during training time,
# which helps to break up happenstance correlations in the training data.
dropout = nn.Dropout( 0.5 ) # 50% dropout rate


In [35]:
# a newly constructed module is in training mode, so roughly half of the entries are zeroed
# and the surviving ones are scaled by 1 / (1 - p) = 2
dropout( inputs )

tensor([[-2.3225, -0.0000, -6.4547,  0.0000, -0.0000, -0.0000,  0.0000, -4.2560],
        [-0.0000,  1.9463,  0.0000, -2.5606,  0.0000,  0.7532, -0.9171, -0.3897],
        [ 0.0000, -1.1466, -0.7122,  2.7253,  0.0000, -2.5342, -0.0000,  0.0000],
        [ 1.9414,  0.0000, -2.2259, -0.7233, -0.0000,  0.0000, -1.5445, -2.8906]])

In [36]:
# We need to be able to give our LLM batches of input.
# For example:
batches = torch.stack((inputs, inputs), dim = 0 )


In [37]:
# torch.stack?
batches.shape

torch.Size([2, 4, 8])

In [55]:
# this class needs to handle batches of input


class CausalAttention( nn.Module ):
    """Single-head causal (masked) self-attention over batched input.

    Each position attends only to itself and earlier positions.

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each projected query/key/value vector.
        context_length: largest sequence length the stored mask covers.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        # Query/key/value projections; honour the caller's bias choice.
        self.W_q = nn.Linear( d_in, d_out, bias=qkv_bias )
        self.W_k = nn.Linear( d_in, d_out, bias=qkv_bias )
        self.W_v = nn.Linear( d_in, d_out, bias=qkv_bias )
        # include dropout:
        self.dropout = nn.Dropout( dropout )
        # A buffer is module state that is not a trainable parameter; it
        # follows the module when it is moved between devices.
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer("mask", torch.triu( torch.ones(context_length, context_length), diagonal = 1 ))

    def forward( self, x ):
        """Compute causally masked context vectors.

        Args:
            x: embedding vectors of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        # Unpacking enforces a 3-D input; only the token count is used below.
        _, num_tokens, _ = x.shape

        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        # Swap the token and feature axes, leaving the batch axis in place.
        scores = queries @ keys.transpose(1, 2)
        # Crop the mask to the actual sequence length; -inf scores become
        # zero weight after the softmax.
        scores.masked_fill_( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        weights = self.dropout( weights )
        return weights @ values

In [56]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0 )

In [57]:
causal( batches )

tensor([[[ 1.3831, -1.2413, -0.1268, -0.4254,  1.2755, -1.2354],
         [ 0.4137, -0.3700, -0.3301, -0.3596,  0.8100, -0.5308],
         [ 0.1762,  0.0235, -0.1443, -0.4402,  0.7742, -0.3843],
         [ 0.2175, -0.3614,  0.1000, -0.4856,  0.6779, -0.6570]],

        [[ 1.3831, -1.2413, -0.1268, -0.4254,  1.2755, -1.2354],
         [ 0.4137, -0.3700, -0.3301, -0.3596,  0.8100, -0.5308],
         [ 0.1762,  0.0235, -0.1443, -0.4402,  0.7742, -0.3843],
         [ 0.2175, -0.3614,  0.1000, -0.4856,  0.6779, -0.6570]]],
       grad_fn=<UnsafeViewBackward0>)

In [58]:
W_q = nn.Linear( d_in, d_out, bias=False )
W_k = nn.Linear( d_in, d_out, bias=False )
W_v = nn.Linear( d_in, d_out, bias=False )

In [59]:
queries = W_q( batches )
queries 

tensor([[[-0.4746,  0.4328, -0.0162,  1.2288,  0.3441, -0.3742],
         [-0.2100, -0.7356,  0.3681,  0.7959,  0.2213, -0.0428],
         [-0.2277,  0.1888, -0.9839, -0.3556,  0.9933, -0.4919],
         [ 0.5435,  0.1198,  0.3716,  0.1646,  0.3668, -0.1627]],

        [[-0.4746,  0.4328, -0.0162,  1.2288,  0.3441, -0.3742],
         [-0.2100, -0.7356,  0.3681,  0.7959,  0.2213, -0.0428],
         [-0.2277,  0.1888, -0.9839, -0.3556,  0.9933, -0.4919],
         [ 0.5435,  0.1198,  0.3716,  0.1646,  0.3668, -0.1627]]],
       grad_fn=<UnsafeViewBackward0>)

In [46]:
keys = W_k( batches )
keys 

tensor([[[-0.6473, -0.2471, -1.1271, -0.1390,  0.6564,  1.7676],
         [ 0.4730, -0.0504, -0.2020,  0.2622, -0.5492,  0.5974],
         [-0.3610,  0.2776, -0.6320,  0.4486,  0.1768, -0.2313],
         [-0.0405, -0.0749,  0.4225, -0.1620, -0.2219, -0.1709]],

        [[-0.6473, -0.2471, -1.1271, -0.1390,  0.6564,  1.7676],
         [ 0.4730, -0.0504, -0.2020,  0.2622, -0.5492,  0.5974],
         [-0.3610,  0.2776, -0.6320,  0.4486,  0.1768, -0.2313],
         [-0.0405, -0.0749,  0.4225, -0.1620, -0.2219, -0.1709]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
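# shows the transpose of keys: for the batched (3-D) keys, swap only the last two dims,
# since `.T` reverses every dim and is deprecated for tensors that are not 2-D

# keys.transpose(1, 2)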