<a href="https://colab.research.google.com/github/mbrudd/LLMs/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [1]:
import torch

In [2]:
# a lookup table of 4 embedding vectors with 8 entries each; its randomly
# initialized weight matrix serves below as a toy sequence of 4 input tokens
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
# replace the module with its (4 x 8) weight matrix; this is still an
# nn.Parameter that tracks gradients (requires_grad=True in the output)
inputs = inputs.weight
inputs

Parameter containing:
tensor([[ 0.1197, -1.2762, -0.2189, -0.5249, -0.7863, -0.6896,  0.9719, -0.7062],
        [-0.7043,  0.6678,  1.3428,  0.9279, -0.1171,  0.9757, -0.4852,  0.2322],
        [ 0.3666,  0.4479, -0.5973, -0.4069, -1.2787,  1.4757,  2.2668,  0.9394],
        [-2.5224, -1.7426, -0.7029,  0.6029,  1.8368, -0.1128, -0.9829, -0.3613]],
       requires_grad=True)

In [4]:
# detach the embedding weights from autograd so the toy inputs are a plain
# tensor (no requires_grad). .detach() is the supported replacement for the
# legacy .data attribute: it shares the same storage but keeps autograd's
# in-place-modification checks intact.
inputs = inputs.detach()
inputs

tensor([[ 0.1197, -1.2762, -0.2189, -0.5249, -0.7863, -0.6896,  0.9719, -0.7062],
        [-0.7043,  0.6678,  1.3428,  0.9279, -0.1171,  0.9757, -0.4852,  0.2322],
        [ 0.3666,  0.4479, -0.5973, -0.4069, -1.2787,  1.4757,  2.2668,  0.9394],
        [-2.5224, -1.7426, -0.7029,  0.6029,  1.8368, -0.1128, -0.9829, -0.3613]])

In [5]:
# dimensions: each input embedding has d_in entries; each projected
# query / key / value vector has d_out entries
d_in, d_out = 8, 6

# three random (d_in x d_out) projection matrices, drawn in the order
# queries, keys, values; requires_grad=False keeps them out of autograd
# for this step-by-step walkthrough
W_q, W_k, W_v = (
    torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    for _ in range( 3 )
)

In [6]:
# choose an input vector (here the third one, index 2) and project it into a
# d_out-dimensional query vector using W_q
query = inputs[2] @ W_q
query

tensor([3.2052, 0.9526, 2.5539, 1.3872, 0.1190, 2.2922])

In [7]:
# project every input vector into key space (W_k) and value space (W_v);
# row i of keys / values belongs to input token i. The attention scores
# are computed from these keys in the next step.
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-1.8804,  0.0284, -2.0566, -1.9059, -1.0583, -1.6339],
        [ 2.1169,  0.6069,  1.2166,  1.4986,  1.1167,  1.0354],
        [ 2.5898,  1.8816,  0.9269,  1.4456,  2.1721,  1.9620],
        [-3.3770, -2.4154, -2.6265, -3.2443, -2.4173, -2.1509]])
Values: tensor([[-1.6695, -1.8921, -1.7313, -1.4193, -1.6749, -2.5785],
        [ 1.0467,  2.2443,  1.5191,  1.3044,  1.5846,  1.6639],
        [ 1.1272,  1.2353, -0.9836,  1.2727,  2.4874, -0.0108],
        [-0.7975, -0.8695, -1.7951, -1.4770, -1.8175, -0.7950]])


In [8]:
# attention score = dot product of the query with each key (one score per input token)
attention_scores = query @ keys.T
attention_scores

tensor([-17.7671,  15.0553,  19.2218, -29.5513])

In [9]:
# scaled dot-product attention: divide the scores by sqrt(key dimension), then
# softmax them into weights that are positive and sum to 1
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([2.3392e-07, 1.5434e-01, 8.4566e-01, 1.9042e-09])

In [10]:
# sanity check: softmax weights should sum to 1
attention_weights.sum()

tensor(1.)

In [11]:
# the context vector is the attention-weighted sum of the value vectors
context_vector = attention_weights @ values
context_vector

tensor([ 1.1148,  1.3910, -0.5974,  1.2776,  2.3481,  0.2477])

In [12]:
import torch.nn as nn


In [13]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  """Scaled dot-product self-attention built from raw weight matrices.

  The query/key/value projections are ``nn.Parameter`` tensors of shape
  ``(d_in, d_out)`` filled with ``torch.rand``. They are created with
  ``requires_grad=False``, so this version is not updated by training.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query/key/value/context vector.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs): (num_tokens, d_in), optionally with
  # leading batch dimensions; returns one d_out context vector per token
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    # swap only the last two dims: identical to keys.T for 2-D input, and
    # correct for batched (3-D) input, where .T would reverse every dim
    scores = queries @ keys.transpose( -2, -1 )
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [14]:
# here's how to use this class:
# instantiate an instance of it (__init__ draws new weight matrices with
# torch.rand, so every instance has different random weights):
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
# inspect the value projection matrix created in __init__ (shape d_in x d_out)
simple.W_v

Parameter containing:
tensor([[0.4066, 0.7928, 0.5265, 0.9394, 0.7186, 0.0932],
        [0.3120, 0.7668, 0.3172, 0.5927, 0.6998, 0.7964],
        [0.6131, 0.4812, 0.8247, 0.0656, 0.5676, 0.1964],
        [0.6455, 0.7282, 0.4283, 0.1087, 0.1954, 0.0595],
        [0.0476, 0.6367, 0.5483, 0.8005, 0.4652, 0.5812],
        [0.3759, 0.5771, 0.1766, 0.4067, 0.2624, 0.3073],
        [0.6372, 0.4309, 0.9647, 0.2889, 0.5187, 0.4210],
        [0.2237, 0.7635, 0.2370, 0.3275, 0.4569, 0.0820]])

In [16]:
# calling the module runs forward(): one context vector per input token
context_vectors = simple( inputs )
context_vectors

tensor([[-1.6837, -2.6692, -1.6217, -2.0741, -2.5194, -1.2305],
        [ 1.5106,  1.7405,  1.1445,  0.3192,  1.0640,  0.8593],
        [ 1.4877,  1.7378,  1.1063,  0.2613,  1.0260,  0.8493],
        [-1.2451, -2.5502, -1.1549, -1.8611, -2.0405, -1.3020]])

In [17]:
# here's a second version of a SimpleAttention class;
# it uses nn.Linear layers (without bias) for the projections, which gives
# trainable weights with nn.Linear's default initialization instead of torch.rand

class SimpleAttention( nn.Module ):
  """Scaled dot-product self-attention with ``nn.Linear`` projections.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query/key/value/context vector.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices (bias-free linear maps d_in -> d_out):
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs): (num_tokens, d_in), optionally with
  # leading batch dimensions; returns one d_out context vector per token
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    # swap only the last two dims: identical to keys.T for 2-D input, and
    # correct for batched (3-D) input, where .T would reverse every dim
    scores = queries @ keys.transpose( -2, -1 )
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [18]:
# here's how to use this class:
# instantiate an instance of it (this rebinds `simple` to the nn.Linear version):
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [19]:
# context vectors from the nn.Linear version; they carry a grad_fn because
# the nn.Linear weights are trainable
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.2893,  0.1132, -0.0358,  0.1293, -0.1353, -0.0383],
        [ 0.2967,  0.1351, -0.0425,  0.0905, -0.1454, -0.1069],
        [ 0.1528, -0.0386, -0.1008,  0.2598, -0.1097,  0.1222],
        [ 0.2973,  0.1407, -0.0467,  0.0750, -0.1486, -0.1330]],
       grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors,
# including those of tokens that come later in the sequence.
# in practice, each token should only use information from its own and the preceding embedding vectors.
# to accomplish this, we'll implement causal attention AKA masked attention

In [None]:
# the masking demos below need a concrete (num_tokens x num_tokens) matrix of
# attention weights. simple( inputs ) returns context vectors, not weights, so
# recompute the softmax-normalized weights from the `simple` module's own
# query and key projections (no_grad: these are only example values)
with torch.no_grad():
  example_queries = simple.W_q( inputs )
  example_keys = simple.W_k( inputs )
  weights = torch.softmax(
      example_queries @ example_keys.T / example_keys.shape[-1]**0.5, dim = -1
  )
weights

In [None]:
# note that these have already been normalized by the softmax: each row sums to 1
weights.sum( dim=-1 )

In [None]:
# masking method #1: a lower-triangular matrix of ones. Multiplying by it
# zeroes every entry above the diagonal (future tokens) and keeps the rest
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

In [None]:
# element-wise product: entries above the diagonal become 0
masked_weights = weights*simple_mask
masked_weights

In [None]:
# after masking, the rows generally no longer sum to 1
# (only the last row keeps all of its entries)
masked_weights.sum( dim=-1 )

In [None]:
# now, we need to renormalize the masked_weights so that each row has sum 1;
# keepdim=True keeps the sums as a column so they broadcast across each row
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

In [None]:
# divide each row by its sum; every row sums to 1 again
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

In [None]:
# the renormalized causal weights: zeros above the diagonal
masked_weights

In [None]:
# masking method #2: mark the positions above the diagonal (future tokens)
# with 1s; diagonal = 1 leaves the main diagonal unmarked, so each token still sees itself
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

In [None]:
# as booleans: True marks an entry to be masked out
mask.bool()

In [None]:
# the unmasked matrix, for comparison with the masked version below
weights

In [None]:
# masking method #2 must be applied to the raw attention *scores*, before the
# softmax: filling the future positions with -inf makes the softmax in the
# next cell give them exactly zero weight. (Masking the already-normalized
# weights and applying softmax a second time would distort the distribution.)
# `weights` therefore holds the masked, scaled scores here; the next cell's
# softmax turns them into attention weights.
with torch.no_grad():
  example_keys = simple.W_k( inputs )
  scores = simple.W_q( inputs ) @ example_keys.T / example_keys.shape[-1]**0.5
weights = scores.masked_fill( mask.bool(), -torch.inf )
weights

In [None]:
# softmax along each row: -inf entries become exactly 0 and each row sums to 1,
# so no separate renormalization step is needed
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

In [None]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting.
# per the PyTorch docs, nn.Dropout( 0.5 ) zeroes each element with probability
# 0.5 in training mode (the default for a new module) and scales the surviving
# elements by 1 / (1 - 0.5) = 2; in eval mode it passes input through unchanged
dropout = nn.Dropout( 0.5 )

In [None]:
# apply dropout to the attention weights; the random pattern changes on every call
dropout( masked_weights )

In [None]:
# we need to be able to give our LLM batches of input
# for example, stack two copies of the inputs along a new leading batch dimension:
batches = torch.stack( (inputs, inputs), dim=0)

In [None]:
# the batch: two identical copies of inputs
batches

In [None]:
# (batch size, num_tokens, d_in) = (2, 4, 8)
batches.shape

In [None]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  """Single-head causal (masked) self-attention with dropout.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query/key/value/context vector.
    context_length: maximum number of tokens per sequence; sizes the mask.
    dropout: dropout probability applied to the attention weights.
    qkv_bias: whether the query/key/value projections include a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices (honoring qkv_bias, which used to be ignored):
    self.W_q = nn.Linear( d_in, d_out, bias=qkv_bias )
    self.W_k = nn.Linear( d_in, d_out, bias=qkv_bias )
    self.W_v = nn.Linear( d_in, d_out, bias=qkv_bias )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # a registered buffer follows the module across devices (.to()/.cuda())
    # and is saved in the state_dict, but is not a trainable parameter.
    # 1s above the diagonal mark the future positions to hide.
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs): (batch, num_tokens, d_in); an unbatched
  # (num_tokens, d_in) tensor also works. Returns one d_out context vector per token.
  def forward( self, x ):
    num_tokens = x.shape[-2]
    context_length = self.mask.shape[0]
    if num_tokens > context_length:
      raise ValueError(
          f"input has {num_tokens} tokens but context_length is {context_length}"
      )
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    # swap only the token and feature dims, so batched and unbatched input both work
    scores = queries @ keys.transpose( -2, -1 )
    # hide future tokens: -inf scores receive zero weight after the softmax
    scores = scores.masked_fill( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf )
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [None]:
# instantiate a causal attention mechanism; context_length=4 matches the
# 4 tokens in each batch element, and dropout=0 disables dropout
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [None]:
# one d_out context vector per token for each batch element: (2, 4, 6)
causal( batches )

In [None]:
# everything below is just to show what happens with batches:
# the intermediate query/key tensors computed for a batched input

In [None]:
# the W_q defined near the top is a raw nn.Parameter tensor and is not
# callable; use the nn.Linear query projection of the `causal` module instead.
# result shape: (batch, num_tokens, d_out)
queries = causal.W_q( batches )
queries

In [None]:
# the W_k defined near the top is a raw nn.Parameter tensor and is not
# callable; use the nn.Linear key projection of the `causal` module instead.
# result shape: (batch, num_tokens, d_out)
keys = causal.W_k( batches )
keys

In [None]:
# swap dims 1 and 2 of each batch element, i.e. (batch, num_tokens, d_out) ->
# (batch, d_out, num_tokens), assuming keys has that batched 3-D layout;
# this is what queries @ keys.transpose(1,2) needs inside CausalAttention
keys.transpose(1,2)