<a href="https://colab.research.google.com/github/mbrudd/LLMs/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461, -0.3942],
        [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,  0.0380],
        [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,  0.0745],
        [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,  0.6368]],
       requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461, -0.3942],
        [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,  0.0380],
        [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,  0.0745],
        [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,  0.6368]])

In [5]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [6]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-0.0585, -0.4860, -0.1032, -0.4076, -0.9160, -0.1425])

In [7]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[ 2.1480,  1.5529,  2.8008,  1.8782,  1.4580,  2.5026],
        [ 1.9040,  0.2781,  1.9731,  1.5060,  0.9453,  2.0697],
        [-0.1213, -1.5318, -0.1547, -0.1155, -1.1343, -0.8727],
        [ 1.4671,  1.2422,  1.8308,  1.1122,  2.4337,  2.2144]])
Values: tensor([[ 1.9701,  2.1980,  0.8614,  1.3574,  1.4857,  1.2846],
        [-0.5965, -0.5286, -0.9120, -0.4883, -0.2370,  0.1408],
        [ 0.6044,  0.3491, -0.2306,  0.4990, -0.1373, -1.1075],
        [-0.0115,  0.2653,  1.3765,  1.7680,  0.7499,  2.0855]])


In [8]:
attention_scores = query @ keys.T
attention_scores

tensor([-3.6272, -2.2249,  1.9780, -3.8767])

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([0.0739, 0.1310, 0.7284, 0.0667])

In [10]:
attention_weights.sum()

tensor(1.)

In [11]:
context_vector = attention_weights @ values
context_vector

tensor([ 0.5069,  0.3652, -0.1319,  0.5178,  0.0287, -0.5542])

In [12]:
import torch.nn as nn


In [13]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [14]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_v

Parameter containing:
tensor([[0.0302, 0.6970, 0.2217, 0.3603, 0.3794, 0.4925],
        [0.0935, 0.1134, 0.4462, 0.7525, 0.9334, 0.5044],
        [0.9820, 0.6072, 0.5998, 0.3263, 0.8216, 0.8495],
        [0.1017, 0.0541, 0.4583, 0.2478, 0.7052, 0.4103],
        [0.9067, 0.6691, 0.6662, 0.0422, 0.9965, 0.8874],
        [0.3491, 0.1067, 0.6210, 0.0992, 0.8924, 0.7328],
        [0.2464, 0.4291, 0.8142, 0.2558, 0.9035, 0.7506],
        [0.2434, 0.7094, 0.7889, 0.8098, 0.3994, 0.2009]])

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[1.6600, 1.5125, 2.0004, 1.8163, 2.6899, 2.0961],
        [1.8196, 1.1861, 1.8020, 1.3035, 2.7585, 2.1389],
        [1.5107, 0.8079, 1.3080, 1.0829, 1.9705, 1.4535],
        [1.7056, 1.4171, 1.9554, 1.6537, 2.7418, 2.1345]])

In [26]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [27]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [28]:
context_vectors = simple( inputs )
context_vectors

tensor([[-4.1225e-02, -3.6173e-01,  2.8266e-01, -3.3615e-01, -7.9088e-05,
          3.5463e-01],
        [-1.0249e-01, -3.7105e-01,  2.8051e-01, -2.6390e-01, -5.1291e-02,
          2.8341e-01],
        [-9.4526e-02, -2.0338e-01,  1.5570e-01, -1.6162e-01,  8.7017e-02,
          8.3274e-02],
        [-1.0312e-01, -3.3303e-01,  2.4755e-01, -2.3086e-01, -1.6659e-02,
          2.2562e-01]], grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [31]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs )
weights

tensor([[0.2687, 0.2787, 0.1943, 0.2583],
        [0.3065, 0.2572, 0.1634, 0.2729],
        [0.1916, 0.2275, 0.3366, 0.2443],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<SoftmaxBackward0>)

In [32]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [33]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [34]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2687, 0.0000, 0.0000, 0.0000],
        [0.3065, 0.2572, 0.0000, 0.0000],
        [0.1916, 0.2275, 0.3366, 0.0000],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<MulBackward0>)

In [35]:
masked_weights.sum( dim=-1 )

tensor([0.2687, 0.5637, 0.7557, 1.0000], grad_fn=<SumBackward1>)

In [36]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2687],
        [0.5637],
        [0.7557],
        [1.0000]], grad_fn=<SumBackward1>)

In [37]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [38]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5437, 0.4563, 0.0000, 0.0000],
        [0.2536, 0.3010, 0.4455, 0.0000],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<DivBackward0>)

In [39]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [40]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [41]:
weights

tensor([[0.2687, 0.2787, 0.1943, 0.2583],
        [0.3065, 0.2572, 0.1634, 0.2729],
        [0.1916, 0.2275, 0.3366, 0.2443],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<SoftmaxBackward0>)

In [44]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2687,   -inf,   -inf,   -inf],
        [0.3065, 0.2572,   -inf,   -inf],
        [0.1916, 0.2275, 0.3366,   -inf],
        [0.2878, 0.2422, 0.2345, 0.2355]], grad_fn=<MaskedFillBackward0>)

In [43]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5123, 0.4877, 0.0000, 0.0000],
        [0.3132, 0.3247, 0.3621, 0.0000],
        [0.2596, 0.2480, 0.2461, 0.2463]], grad_fn=<SoftmaxBackward0>)

In [45]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [51]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4922, 0.4927]], grad_fn=<MulBackward0>)

In [52]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [53]:
batches

tensor([[[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461,
          -0.3942],
         [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,
           0.0380],
         [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,
           0.0745],
         [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,
           0.6368]],

        [[-0.2427,  0.6827,  0.7656,  1.0912,  1.2138,  0.6371, -0.5461,
          -0.3942],
         [-1.4937,  1.5845,  2.5122, -0.6860,  0.8015, -1.1129, -0.2717,
           0.0380],
         [-0.1699, -0.4934, -1.0443,  1.7926, -0.5786, -0.3779,  0.4462,
           0.0745],
         [-0.0711,  1.5781,  1.1288, -1.0967,  0.1295,  0.0320,  0.7134,
           0.6368]]])

In [54]:
batches.shape

torch.Size([2, 4, 8])

In [67]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # use the following to manage memory efficiently:
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [68]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [69]:
causal( batches )

tensor([[[-0.2956,  0.1039,  0.0676, -0.2119, -0.7519, -0.4160],
         [ 0.0174, -0.0664, -0.5165, -0.7360, -1.3500, -0.3612],
         [-0.4007,  0.0280, -0.0891, -0.4656, -0.5094, -0.2803],
         [ 0.0916, -0.1431, -0.1039, -0.6458, -0.7593, -0.2148]],

        [[-0.2956,  0.1039,  0.0676, -0.2119, -0.7519, -0.4160],
         [ 0.0174, -0.0664, -0.5165, -0.7360, -1.3500, -0.3612],
         [-0.4007,  0.0280, -0.0891, -0.4656, -0.5094, -0.2803],
         [ 0.0916, -0.1431, -0.1039, -0.6458, -0.7593, -0.2148]]],
       grad_fn=<UnsafeViewBackward0>)

In [57]:
# everything below is just to show what happens with batches

In [62]:
queries = W_q( batches )
queries

tensor([[[-3.2929e-02, -5.1089e-02,  1.9729e-01,  1.4060e-01,  1.1246e-03,
           2.1223e-01],
         [ 1.2751e+00,  1.2309e+00, -2.6176e-01,  1.3619e+00,  5.4810e-01,
          -1.9892e-01],
         [-7.6839e-01,  8.0788e-02,  5.8067e-01, -1.6036e-01, -8.1724e-01,
           7.5022e-01],
         [ 8.3302e-01,  6.2428e-01,  8.5691e-04,  7.7048e-01,  3.1241e-01,
          -4.3753e-01]],

        [[-3.2929e-02, -5.1089e-02,  1.9729e-01,  1.4060e-01,  1.1246e-03,
           2.1223e-01],
         [ 1.2751e+00,  1.2309e+00, -2.6176e-01,  1.3619e+00,  5.4810e-01,
          -1.9892e-01],
         [-7.6839e-01,  8.0788e-02,  5.8067e-01, -1.6036e-01, -8.1724e-01,
           7.5022e-01],
         [ 8.3302e-01,  6.2428e-01,  8.5691e-04,  7.7048e-01,  3.1241e-01,
          -4.3753e-01]]], grad_fn=<UnsafeViewBackward0>)

In [63]:
keys = W_k( batches )
keys

tensor([[[ 0.2261, -0.3120,  0.0810, -0.2931, -0.8426, -0.4877],
         [ 0.0451, -0.1597,  0.1095,  0.7858, -0.6787, -0.8029],
         [ 0.4853,  0.2652, -0.1591, -0.0825,  0.7799,  0.3343],
         [-0.5785, -0.2767,  0.4855,  0.1878, -0.5316, -0.8584]],

        [[ 0.2261, -0.3120,  0.0810, -0.2931, -0.8426, -0.4877],
         [ 0.0451, -0.1597,  0.1095,  0.7858, -0.6787, -0.8029],
         [ 0.4853,  0.2652, -0.1591, -0.0825,  0.7799,  0.3343],
         [-0.5785, -0.2767,  0.4855,  0.1878, -0.5316, -0.8584]]],
       grad_fn=<UnsafeViewBackward0>)

In [65]:
keys.transpose(1,2)

tensor([[[ 0.2261,  0.0451,  0.4853, -0.5785],
         [-0.3120, -0.1597,  0.2652, -0.2767],
         [ 0.0810,  0.1095, -0.1591,  0.4855],
         [-0.2931,  0.7858, -0.0825,  0.1878],
         [-0.8426, -0.6787,  0.7799, -0.5316],
         [-0.4877, -0.8029,  0.3343, -0.8584]],

        [[ 0.2261,  0.0451,  0.4853, -0.5785],
         [-0.3120, -0.1597,  0.2652, -0.2767],
         [ 0.0810,  0.1095, -0.1591,  0.4855],
         [-0.2931,  0.7858, -0.0825,  0.1878],
         [-0.8426, -0.6787,  0.7799, -0.5316],
         [-0.4877, -0.8029,  0.3343, -0.8584]]], grad_fn=<TransposeBackward0>)