In [1]:
# This notebook demonstrates advanced attention mechanisms in PyTorch
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs 

Parameter containing:
tensor([[ 0.4968,  0.6926, -0.4238,  1.7006,  0.8141, -1.3960, -0.2556,  0.7842],
        [ 1.1958, -0.4735,  2.7358,  0.9718,  0.2161,  0.9777,  0.2735,  0.3263],
        [-0.2885,  0.7681, -0.9142,  1.6997, -0.3183,  0.9168,  1.3026,  1.1396],
        [-0.0515, -0.7303, -1.1646,  0.7935,  1.5031, -1.0235, -1.5992, -1.2287]],
       requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[ 0.4968,  0.6926, -0.4238,  1.7006,  0.8141, -1.3960, -0.2556,  0.7842],
        [ 1.1958, -0.4735,  2.7358,  0.9718,  0.2161,  0.9777,  0.2735,  0.3263],
        [-0.2885,  0.7681, -0.9142,  1.6997, -0.3183,  0.9168,  1.3026,  1.1396],
        [-0.0515, -0.7303, -1.1646,  0.7935,  1.5031, -1.0235, -1.5992, -1.2287]])

In [5]:
# Set dimensions
d_in = 8
d_out = 6
# Create the query / key / value weight matrices, each of shape (d_in, d_out).
# torch.nn.Parameter is constructed with requires_grad=True by default, and that
# setting applies regardless of any requires_grad flag on the tensor it wraps,
# so these matrices are trainable and results computed from them carry a grad_fn.
W_q = torch.nn.Parameter( torch.randn( (d_in, d_out) ) )
W_k = torch.nn.Parameter( torch.randn( (d_in, d_out) ) )
W_v = torch.nn.Parameter( torch.randn( (d_in, d_out) ) )

In [6]:
# Choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([ 0.2193, -0.9882,  4.0323,  0.0244, -0.5920,  1.6198],
       grad_fn=<SqueezeBackward4>)

In [7]:
# calculate attention scores using the keys generated by W_k
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:" , keys)
print("Values:" , values)

Keys: tensor([[ 0.2674, -6.3490,  1.4153, -3.2161,  3.5052, -0.8679],
        [ 0.1761,  3.8460, -5.1995,  3.9411,  6.8248,  2.9961],
        [-3.9272, -1.7813,  0.9085, -0.8080, -0.2954, -2.5784],
        [-2.3890, -4.8283,  2.2408, -0.4120,  1.2446,  2.2104]],
       grad_fn=<MmBackward0>)
Values: tensor([[-3.2175,  1.3969,  1.8347, -6.6828, -2.0814, -3.4128],
        [ 1.2381,  3.1518, -1.1033, -1.8141,  2.7603,  3.0491],
        [ 2.4455,  1.9790, -1.1088, -3.3406, -2.5717, -1.4244],
        [-7.0287, -0.4542, -0.0826, -2.3696,  0.7345, -3.2982]],
       grad_fn=<MmBackward0>)


In [8]:
attention_scores = query @ keys.T # query is 1 by 6 and keys is 4 by 6 so we need to transpose keys
attention_scores

tensor([  8.4802, -23.8185,   0.5411,  16.1166], grad_fn=<SqueezeBackward4>)

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim=-1 ) # the softmax function normalizes the scores
attention_weights

tensor([4.2320e-02, 7.9428e-08, 1.6555e-03, 9.5602e-01],
       grad_fn=<SoftmaxBackward0>)

In [10]:
attention_weights.sum() # ensure the weights sum to 1

tensor(1.0000, grad_fn=<SumBackward0>)

In [11]:
context_vector = attention_weights @ values # weighted sum of the rows of `values`, weighted by the attention weights
context_vector

tensor([-6.8517e+00, -3.7180e-01, -3.1472e-03, -2.5538e+00,  6.0983e-01,
        -3.2999e+00], grad_fn=<SqueezeBackward4>)

In [12]:
import torch.nn as nn

In [13]:
class SimpleAttention( nn.Module ):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The input is projected into queries, keys and values with three
    (d_in, d_out) parameter matrices, and the result is
    softmax(Q @ K^T / sqrt(d_out)) @ V. No causal mask is applied, so every
    position attends to every other position.
    """

    def __init__(self, d_in, d_out):
        """Create the query, key and value projection matrices.

        Args:
            d_in: size of the last dimension of the input vectors.
            d_out: size of the last dimension of the produced context vectors.
        """
        super().__init__()
        # nn.Parameter defaults to requires_grad=True, so these matrices are
        # trainable; a requires_grad flag on the wrapped tensor has no effect.
        self.W_q = nn.Parameter( torch.randn( (d_in, d_out) ) )
        self.W_k = nn.Parameter( torch.randn( (d_in, d_out) ) )
        self.W_v = nn.Parameter( torch.randn( (d_in, d_out) ) )

    def forward(self, x):
        """Compute one context vector per input vector.

        Args:
            x: embedding vectors of shape (num_tokens, d_in), optionally with
               leading batch dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Tensor with the same leading dimensions as x and last dimension d_out.
        """
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        # Swap only the last two axes so that leading batch dimensions are
        # preserved; for a 2-D input this is identical to keys.T.
        scores = queries @ keys.transpose( -2, -1 )
        # Scale by sqrt(d_out) before normalizing each row with softmax.
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        return weights @ values

In [14]:
# use case
# instance of the class
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_k

Parameter containing:
tensor([[-0.0579,  1.1436, -0.0773,  0.1684, -0.0646, -0.0638],
        [ 0.5471,  0.4694, -0.2080,  0.6245,  0.1008, -2.5456],
        [ 0.3322,  0.6964,  0.1910, -1.0184, -0.0323, -0.0416],
        [-0.8376,  1.0151, -0.6779, -0.3007, -1.6688,  0.9362],
        [ 0.5515, -0.4291,  0.8234, -1.7246, -0.5022, -0.2536],
        [ 0.3747, -0.5254,  0.1214,  1.1081,  0.0311, -0.0906],
        [ 0.2871, -0.4487, -0.3190, -2.1846,  1.1565, -0.4737],
        [-0.7714, -0.1812, -0.9762,  0.3434,  0.7973,  0.8360]],
       requires_grad=True)

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.8264,  7.0638,  1.7128, -1.2320,  0.4152, -2.8836],
        [ 5.5971,  2.6917,  6.6572,  0.4358, -3.1679,  0.3786],
        [ 0.6389,  6.7151,  1.7218, -1.3439,  0.3465, -2.8945],
        [-2.7950,  2.3704, -0.5065, -3.8887,  1.1537, -1.4473]],
       grad_fn=<MmBackward0>)

In [28]:
# second version of the class
# it uses nn.Linear to do things more effectively

class SimpleAttentionv2( nn.Module ):
    """Single-head scaled dot-product self-attention built on nn.Linear.

    Queries, keys and values are produced by three bias-free linear layers
    mapping d_in -> d_out, and the result is
    softmax(Q @ K^T / sqrt(d_out)) @ V. No causal mask is applied.
    """

    def __init__(self, d_in, d_out):
        """Create the query, key and value projection layers.

        Args:
            d_in: size of the last dimension of the input vectors.
            d_out: size of the last dimension of the produced context vectors.
        """
        super().__init__()
        self.W_q = nn.Linear( d_in, d_out, bias=False )
        self.W_k = nn.Linear( d_in, d_out, bias=False )
        self.W_v = nn.Linear( d_in, d_out, bias=False )

    def forward( self, x ):
        """Compute one context vector per input vector.

        Args:
            x: embedding vectors of shape (num_tokens, d_in), optionally with
               leading batch dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Tensor with the same leading dimensions as x and last dimension d_out.
        """
        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )
        # Swap only the last two axes so that leading batch dimensions are
        # preserved; for a 2-D input this is identical to keys.T.
        scores = queries @ keys.transpose( -2, -1 )
        # Scale by sqrt(d_out) before normalizing each row with softmax.
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        return weights @ values

In [29]:
# use case
# instance of the class
simple = SimpleAttentionv2( d_in = 8, d_out = 6 )

In [30]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.2607,  0.2483,  0.3032,  0.3527, -0.4821, -0.0535],
        [-0.0988,  0.1871,  0.1946,  0.4106, -0.3638,  0.0158],
        [-0.3057,  0.2415,  0.3584,  0.2557, -0.5409, -0.1973],
        [-0.1710,  0.1838,  0.2501,  0.3458, -0.5099, -0.1348]],
       grad_fn=<MmBackward0>)

In [34]:
# the problem with this is that each context vector uses information from all of the embedding vectors
# in practice, we should only use information from the previous vectors
# to accomplish this, we'll implement causal attention AKA masked attention
weights = torch.randn( inputs.shape[0], inputs.shape[0] )

In [36]:
weights

tensor([[-1.3369,  0.7203,  0.6539, -0.3114],
        [-0.4438, -1.1245, -0.7097,  1.1474],
        [ 0.3388, -0.3955, -0.5225,  0.1182],
        [-0.0090,  0.5231, -0.8770,  0.2381]])

In [38]:
weights.sum( dim=-1 )

tensor([-0.2740, -1.1307, -0.4610, -0.1249])

In [39]:
# torch.tril?
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [40]:
masked_weights = weights*simple_mask
masked_weights

tensor([[-1.3369,  0.0000,  0.0000, -0.0000],
        [-0.4438, -1.1245, -0.0000,  0.0000],
        [ 0.3388, -0.3955, -0.5225,  0.0000],
        [-0.0090,  0.5231, -0.8770,  0.2381]])

In [41]:
masked_weights.sum( dim=-1 )

tensor([-1.3369, -1.5683, -0.5792, -0.1249])

In [42]:
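# now, we need to normalize the masked weights so that they sum to 1
# (note: `weights` here holds raw random values rather than softmax outputs, so after
# dividing by the row sums each row sums to 1 but entries can be negative or exceed 1)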
row_sums = masked_weights.sum( dim=-1, keepdim=True )
row_sums

tensor([[-1.3369],
        [-1.5683],
        [-0.5792],
        [-0.1249]])

In [43]:
masked_weights = masked_weights / row_sums
masked_weights

tensor([[ 1.0000, -0.0000, -0.0000,  0.0000],
        [ 0.2830,  0.7170,  0.0000, -0.0000],
        [-0.5849,  0.6828,  0.9020, -0.0000],
        [ 0.0722, -4.1888,  7.0229, -1.9063]])

In [46]:
# masking method #2
# torch.triu?
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal=1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [47]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [49]:
weights

tensor([[-1.3369,  0.7203,  0.6539, -0.3114],
        [-0.4438, -1.1245, -0.7097,  1.1474],
        [ 0.3388, -0.3955, -0.5225,  0.1182],
        [-0.0090,  0.5231, -0.8770,  0.2381]])

In [50]:
weights = weights.masked_fill( mask.bool(), -torch.inf)
weights

tensor([[-1.3369,    -inf,    -inf,    -inf],
        [-0.4438, -1.1245,    -inf,    -inf],
        [ 0.3388, -0.3955, -0.5225,    -inf],
        [-0.0090,  0.5231, -0.8770,  0.2381]])

In [51]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.6639, 0.3361, 0.0000, 0.0000],
        [0.5256, 0.2522, 0.2222, 0.0000],
        [0.2271, 0.3867, 0.0954, 0.2908]])

In [53]:
masked_weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000])

In [62]:
# Dropout 
# Dropout is a regularization technique used to prevent overfitting in neural networks.
# It works by randomly setting a fraction of input units to zero at each update during training time,
# which helps to break up happenstance correlations in the training data.
dropout = nn.Dropout( 0.5 ) # 50% dropout rate


In [75]:
dropout( inputs )

tensor([[ 0.0000,  1.3852, -0.8477,  3.4012,  0.0000, -2.7921, -0.0000,  0.0000],
        [ 0.0000, -0.0000,  0.0000,  1.9436,  0.4322,  0.0000,  0.0000,  0.0000],
        [-0.0000,  0.0000, -0.0000,  3.3994, -0.0000,  1.8336,  0.0000,  2.2793],
        [-0.1031, -1.4606, -2.3293,  0.0000,  0.0000, -0.0000, -3.1984, -2.4574]])

In [60]:
# We need to be able to give our LLM batches of input.
# For example:
batches = torch.stack((inputs, inputs), dim = 0 )


In [61]:
# torch.stack?
batches.shape

torch.Size([2, 4, 8])

In [None]:
# this class needs to handle batches of input


class CausalAttention( nn.Module ):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position may attend only to itself and to earlier positions.
    Accepts a single sequence of shape (num_tokens, d_in) or batched input
    of shape (batch, num_tokens, d_in).
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        """Create the projection layers, the dropout layer and the causal mask.

        Args:
            d_in: size of the last dimension of the input vectors.
            d_out: size of the last dimension of the produced context vectors.
            context_length: maximum number of tokens per sequence; fixes the
                size of the precomputed causal mask.
            dropout: dropout probability applied to the attention weights.
            qkv_bias: whether the query/key/value linear layers use a bias term.
        """
        super().__init__()
        self.W_q = nn.Linear( d_in, d_out, bias=qkv_bias )
        self.W_k = nn.Linear( d_in, d_out, bias=qkv_bias )
        self.W_v = nn.Linear( d_in, d_out, bias=qkv_bias )
        self.dropout = nn.Dropout( dropout )
        # True above the main diagonal marks the "future" positions to hide.
        # Registered as a buffer so it follows the module across devices;
        # persistent=False keeps it out of state_dict, so the saved keys are
        # the three projection layers only.
        self.register_buffer(
            "mask",
            torch.triu( torch.ones( context_length, context_length ), diagonal=1 ).bool(),
            persistent=False,
        )

    def forward( self, x ):
        """Compute one causally-masked context vector per input vector.

        Args:
            x: embedding vectors of shape (num_tokens, d_in) or
               (batch, num_tokens, d_in), with num_tokens <= context_length.

        Returns:
            Tensor with the same leading dimensions as x and last dimension d_out.

        Raises:
            ValueError: if the sequence is longer than context_length.
        """
        num_tokens = x.shape[-2]
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {self.mask.shape[0]}"
            )

        queries = self.W_q( x )
        keys = self.W_k( x )
        values = self.W_v( x )

        # Swap only the last two axes so that a leading batch dimension is kept.
        scores = queries @ keys.transpose( -2, -1 )
        # Future positions get -inf so softmax assigns them exactly zero weight;
        # the mask is cropped for sequences shorter than context_length.
        scores = scores.masked_fill( self.mask[:num_tokens, :num_tokens], -torch.inf )
        weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
        # nn.Dropout is only active in training mode; it is the identity in eval mode.
        weights = self.dropout( weights )
        return weights @ values

In [55]:
super?

[31mInit signature:[39m super(self, /, *args, **kwargs)
[31mDocstring:[39m     
super() -> same as super(__class__, <first argument>)
super(type) -> unbound super object
super(type, obj) -> bound super object; requires isinstance(obj, type)
super(type, type2) -> bound super object; requires issubclass(type2, type)
Typical use to call a cooperative superclass method:
class C(B):
    def meth(self, arg):
        super().meth(arg)
This works for class methods too:
class C(B):
    @classmethod
    def cmeth(cls, arg):
        super().cmeth(arg)
[31mType:[39m           type
[31mSubclasses:[39m     