<a href="https://colab.research.google.com/github/mbrudd/LLMs/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[ 0.5230, -1.0122,  0.5992, -0.2435, -0.6084, -0.4282,  1.6621,  1.3989],
        [-0.9138, -0.6302,  0.3323, -1.3207, -0.9561, -0.4106,  0.2507, -0.2165],
        [ 0.2848, -0.8262, -0.4451, -1.5865, -0.2796,  0.0701,  1.4476,  0.6383],
        [ 2.1760,  0.8191, -0.5692,  0.8127,  0.9198,  0.2796, -2.0018,  1.1478]],
       requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[ 0.5230, -1.0122,  0.5992, -0.2435, -0.6084, -0.4282,  1.6621,  1.3989],
        [-0.9138, -0.6302,  0.3323, -1.3207, -0.9561, -0.4106,  0.2507, -0.2165],
        [ 0.2848, -0.8262, -0.4451, -1.5865, -0.2796,  0.0701,  1.4476,  0.6383],
        [ 2.1760,  0.8191, -0.5692,  0.8127,  0.9198,  0.2796, -2.0018,  1.1478]])

In [32]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [33]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-0.8374, -0.9581, -0.2278,  0.9000, -1.9623, -1.4585])

In [34]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[ 0.5973,  1.0044,  0.8443,  1.2623,  0.8390,  1.1659],
        [-2.0393, -2.2555, -1.5301, -2.0293, -1.5247, -1.6047],
        [-0.7451,  0.3819, -0.3960,  0.0507, -0.3498,  0.3000],
        [ 2.3991,  1.8205,  2.1116,  2.6182,  1.7647,  1.2838]])
Values: tensor([[-0.1780,  1.0546,  0.0233, -0.0530,  2.4320, -0.1252],
        [-2.8279, -1.6676, -1.9344, -2.2037, -1.6624, -2.3033],
        [-1.8483, -0.0499, -1.4484, -1.2959,  0.9830, -0.6364],
        [ 2.7270,  1.7948,  0.9515,  2.9547,  1.4034,  3.0973]])


In [35]:
attention_scores = query @ keys.T
attention_scores

tensor([-3.8656,  7.7232,  0.6426, -7.2132])

In [36]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([0.0083, 0.9376, 0.0521, 0.0021])

In [37]:
attention_weights.sum()

tensor(1.0000)

In [38]:
context_vector = attention_weights @ values
context_vector

tensor([-2.7433, -1.5535, -1.8868, -2.1278, -1.4844, -2.1871])

In [39]:
import torch.nn as nn


In [40]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [41]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [42]:
simple.W_v

Parameter containing:
tensor([[0.7287, 0.5747, 0.2106, 0.4504, 0.4743, 0.4702],
        [0.1476, 0.5786, 0.3352, 0.0245, 0.5461, 0.7167],
        [0.3128, 0.0113, 0.8220, 0.5390, 0.3141, 0.9862],
        [0.8924, 0.0380, 0.2893, 0.5757, 0.9561, 0.8739],
        [0.1583, 0.0510, 0.8964, 0.8533, 0.6959, 0.1652],
        [0.1750, 0.0867, 0.7179, 0.8085, 0.5043, 0.2204],
        [0.7321, 0.7470, 0.2207, 0.3066, 0.3434, 0.6393],
        [0.2751, 0.8343, 0.7991, 0.8217, 0.0969, 0.2075]])

In [43]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 1.5050,  1.8708,  1.0198,  1.4026,  0.1885,  0.9595],
        [-1.9265, -1.0054, -1.7768, -2.2492, -2.7388, -1.8386],
        [-1.5324, -0.5458, -1.4648, -1.8090, -2.4616, -1.6337],
        [ 1.2995,  1.2851,  1.9965,  2.4966,  2.2737,  0.9313]])

In [78]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T

    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [79]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [80]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.1243,  0.1402, -0.2909, -0.2067, -0.0791, -0.0674],
        [ 0.0919,  0.1422, -0.2992, -0.2242, -0.0921, -0.0729],
        [-0.0217,  0.1702, -0.3123, -0.2852, -0.1456, -0.0821],
        [ 0.3033,  0.1649, -0.2265, -0.0965, -0.0174, -0.0170]],
       grad_fn=<MmBackward0>)

In [81]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [82]:
# this is a hack to get some example weights to work with!
weights = simple( inputs )
weights

tensor([[ 0.1243,  0.1402, -0.2909, -0.2067, -0.0791, -0.0674],
        [ 0.0919,  0.1422, -0.2992, -0.2242, -0.0921, -0.0729],
        [-0.0217,  0.1702, -0.3123, -0.2852, -0.1456, -0.0821],
        [ 0.3033,  0.1649, -0.2265, -0.0965, -0.0174, -0.0170]],
       grad_fn=<MmBackward0>)

In [83]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([-0.3796, -0.4543, -0.6767,  0.1107], grad_fn=<SumBackward1>)

In [84]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [85]:
masked_weights = weights*simple_mask
masked_weights

RuntimeError: The size of tensor a (6) must match the size of tensor b (4) at non-singleton dimension 1

In [77]:
masked_weights.sum( dim=-1 )

NameError: name 'masked_weights' is not defined

In [53]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

NameError: name 'masked_weights' is not defined

In [None]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [None]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5437, 0.4563, 0.0000, 0.0000],
        [0.2536, 0.3010, 0.4455, 0.0000],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<DivBackward0>)

In [None]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [None]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [None]:
weights

tensor([[0.2687, 0.2787, 0.1943, 0.2583],
        [0.3065, 0.2572, 0.1634, 0.2729],
        [0.1916, 0.2275, 0.3366, 0.2443],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<SoftmaxBackward0>)

In [None]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2687,   -inf,   -inf,   -inf],
        [0.3065, 0.2572,   -inf,   -inf],
        [0.1916, 0.2275, 0.3366,   -inf],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<MaskedFillBackward0>)

In [None]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5123, 0.4877, 0.0000, 0.0000],
        [0.3132, 0.3247, 0.3621, 0.0000],
        [0.2596, 0.2480, 0.2461, 0.2463]], grad_fn=<SoftmaxBackward0>)

In [None]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [None]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4922, 0.4927]], grad_fn=<MulBackward0>)

In [None]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [None]:
batches

tensor([[[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461,
          -0.3942],
         [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,
           0.0380],
         [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,
           0.0745],
         [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,
           0.6368]],

        [[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461,
          -0.3942],
         [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,
           0.0380],
         [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,
           0.0745],
         [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,
           0.6368]]])

In [None]:
batches.shape

torch.Size([2, 4, 8])

In [None]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # use the following to manage memory efficiently:
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [None]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [None]:
causal( batches )

tensor([[[-0.2956,  0.1039,  0.0676, -0.2119, -0.7519, -0.4160],
         [ 0.0174, -0.0664, -0.5165, -0.7360, -1.3500, -0.3612],
         [-0.4007,  0.0280, -0.0891, -0.4656, -0.5094, -0.2803],
         [ 0.0916, -0.1431, -0.1039, -0.6458, -0.7593, -0.2148]],

        [[-0.2956,  0.1039,  0.0676, -0.2119, -0.7519, -0.4160],
         [ 0.0174, -0.0664, -0.5165, -0.7360, -1.3500, -0.3612],
         [-0.4007,  0.0280, -0.0891, -0.4656, -0.5094, -0.2803],
         [ 0.0916, -0.1431, -0.1039, -0.6458, -0.7593, -0.2148]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
# everything below is just to show what happens with batches

In [None]:
queries = W_q( batches )
queries

tensor([[[-3.2929e-02, -5.1089e-02,  1.9729e-01,  1.4060e-01,  1.1246e-03,
           2.1223e-01],
         [ 1.2751e+00,  1.2309e+00, -2.6176e-01,  1.3619e+00,  5.4810e-01,
          -1.9892e-01],
         [-7.6839e-01,  8.0788e-02,  5.8067e-01, -1.6036e-01, -8.1724e-01,
           7.5022e-01],
         [ 8.3302e-01,  6.2428e-01,  8.5691e-04,  7.7048e-01,  3.1241e-01,
          -4.3753e-01]],

        [[-3.2929e-02, -5.1089e-02,  1.9729e-01,  1.4060e-01,  1.1246e-03,
           2.1223e-01],
         [ 1.2751e+00,  1.2309e+00, -2.6176e-01,  1.3619e+00,  5.4810e-01,
          -1.9892e-01],
         [-7.6839e-01,  8.0788e-02,  5.8067e-01, -1.6036e-01, -8.1724e-01,
           7.5022e-01],
         [ 8.3302e-01,  6.2428e-01,  8.5691e-04,  7.7048e-01,  3.1241e-01,
          -4.3753e-01]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
keys = W_k( batches )
keys

tensor([[[ 0.2261, -0.3120,  0.0810, -0.2931, -0.8426, -0.4877],
         [ 0.0451, -0.1597,  0.1095,  0.7858, -0.6787, -0.8029],
         [ 0.4853,  0.2652, -0.1591, -0.0825,  0.7799,  0.3343],
         [-0.5785, -0.2767,  0.4855,  0.1878, -0.5316, -0.8584]],

        [[ 0.2261, -0.3120,  0.0810, -0.2931, -0.8426, -0.4877],
         [ 0.0451, -0.1597,  0.1095,  0.7858, -0.6787, -0.8029],
         [ 0.4853,  0.2652, -0.1591, -0.0825,  0.7799,  0.3343],
         [-0.5785, -0.2767,  0.4855,  0.1878, -0.5316, -0.8584]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
keys.transpose(1,2)

tensor([[[ 0.2261,  0.0451,  0.4853, -0.5785],
         [-0.3120, -0.1597,  0.2652, -0.2767],
         [ 0.0810,  0.1095, -0.1591,  0.4855],
         [-0.2931,  0.7858, -0.0825,  0.1878],
         [-0.8426, -0.6787,  0.7799, -0.5316],
         [-0.4877, -0.8029,  0.3343, -0.8584]],

        [[ 0.2261,  0.0451,  0.4853, -0.5785],
         [-0.3120, -0.1597,  0.2652, -0.2767],
         [ 0.0810,  0.1095, -0.1591,  0.4855],
         [-0.2931,  0.7858, -0.0825,  0.1878],
         [-0.8426, -0.6787,  0.7799, -0.5316],
         [-0.4877, -0.8029,  0.3343, -0.8584]]], grad_fn=<TransposeBackward0>)