<a href="https://colab.research.google.com/github/mbrudd/LLMs/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,  1.1267],
        [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257, -0.0238],
        [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,  0.1166],
        [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806, -0.0192]],
       requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,  1.1267],
        [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257, -0.0238],
        [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,  0.1166],
        [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806, -0.0192]])

In [5]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [6]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-1.3756, -2.0649, -0.5108, -0.8182, -1.1920, -0.9123])

In [7]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[ 2.0968,  0.7871,  0.6638,  1.8078,  0.9960,  0.2127],
        [-0.8083, -1.3694, -1.2487, -0.8687, -1.0420, -0.6535],
        [-2.0322,  0.3496, -0.3710, -0.8335, -0.4581,  0.6089],
        [-0.9043, -2.1913, -1.7305, -1.7807, -1.2833, -1.4691]])
Values: tensor([[ 1.1060,  2.0203,  2.7215,  1.2079,  1.3403,  1.0632],
        [-1.0737, -2.0347, -2.0019, -1.2446, -2.2905, -0.3585],
        [-0.2100, -2.4208, -2.9082, -1.5892, -1.3626, -1.0998],
        [-1.2808, -2.4287, -2.7191, -1.6241, -2.6295, -0.8259]])


In [8]:
attention_scores = query @ keys.T
attention_scores

tensor([-7.7094,  7.1266,  2.9356, 10.9798])

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([3.9005e-04, 1.6654e-01, 3.0093e-02, 8.0297e-01])

In [10]:
attention_weights.sum()

tensor(1.0000)

In [11]:
context_vector = attention_weights @ values
context_vector

tensor([-1.2131, -2.3611, -2.6032, -1.5588, -2.5334, -0.7556])

In [12]:
import torch.nn as nn


In [13]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [14]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_v

Parameter containing:
tensor([[0.9563, 0.8906, 0.7216, 0.9572, 0.5428, 0.6571],
        [0.4411, 0.5230, 0.8777, 0.2027, 0.7596, 0.7206],
        [0.4109, 0.4746, 0.0961, 0.3563, 0.5101, 0.5325],
        [0.0827, 0.8998, 0.2190, 0.8373, 0.3363, 0.4321],
        [0.3779, 0.8566, 0.2960, 0.4776, 0.9894, 0.7767],
        [0.0444, 0.2318, 0.5398, 0.7857, 0.0462, 0.5516],
        [0.6263, 0.7638, 0.6733, 0.2554, 0.6791, 0.5256],
        [0.2981, 0.4378, 0.1029, 0.9949, 0.2897, 0.9695]])

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 1.4954,  1.7507,  0.8093,  1.6147,  1.4596,  1.9214],
        [-1.6064, -2.5210, -1.3447, -1.7674, -1.9499, -1.7427],
        [-1.6463, -2.5694, -1.3823, -1.8260, -2.0072, -1.8121],
        [-1.6319, -2.5343, -1.3624, -1.7794, -1.9658, -1.7615]])

In [25]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [26]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [27]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.3545,  0.0608, -0.1222, -0.0895, -0.4415, -0.0292],
        [-0.3666,  0.0100, -0.1073,  0.0075, -0.4131,  0.0083],
        [-0.3545, -0.0412, -0.0744,  0.0519, -0.3728,  0.0173],
        [-0.3551,  0.0097, -0.0977, -0.0129, -0.4042, -0.0039]],
       grad_fn=<MmBackward0>)

In [20]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [28]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs )
weights

tensor([[0.2298, 0.2652, 0.2592, 0.2458],
        [0.2689, 0.2290, 0.2412, 0.2610],
        [0.3089, 0.2295, 0.1937, 0.2679],
        [0.2037, 0.2236, 0.3170, 0.2558]], grad_fn=<SoftmaxBackward0>)

In [29]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [30]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [31]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2298, 0.0000, 0.0000, 0.0000],
        [0.2689, 0.2290, 0.0000, 0.0000],
        [0.3089, 0.2295, 0.1937, 0.0000],
        [0.2037, 0.2236, 0.3170, 0.2558]], grad_fn=<MulBackward0>)

In [32]:
masked_weights.sum( dim=-1 )

tensor([0.2298, 0.4979, 0.7321, 1.0000], grad_fn=<SumBackward1>)

In [33]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2298],
        [0.4979],
        [0.7321],
        [1.0000]], grad_fn=<SumBackward1>)

In [34]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [35]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5400, 0.4600, 0.0000, 0.0000],
        [0.4219, 0.3134, 0.2646, 0.0000],
        [0.2037, 0.2236, 0.3170, 0.2558]], grad_fn=<DivBackward0>)

In [36]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [37]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [38]:
weights

tensor([[0.2298, 0.2652, 0.2592, 0.2458],
        [0.2689, 0.2290, 0.2412, 0.2610],
        [0.3089, 0.2295, 0.1937, 0.2679],
        [0.2037, 0.2236, 0.3170, 0.2558]], grad_fn=<SoftmaxBackward0>)

In [39]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2298,   -inf,   -inf,   -inf],
        [0.2689, 0.2290,   -inf,   -inf],
        [0.3089, 0.2295, 0.1937,   -inf],
        [0.2037, 0.2236, 0.3170, 0.2558]], grad_fn=<MaskedFillBackward0>)

In [40]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5100, 0.4900, 0.0000, 0.0000],
        [0.3553, 0.3281, 0.3166, 0.0000],
        [0.2385, 0.2433, 0.2671, 0.2512]], grad_fn=<SoftmaxBackward0>)

In [41]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [42]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [1.0199, 0.9801, 0.0000, 0.0000],
        [0.7105, 0.6563, 0.6332, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.5024]], grad_fn=<MulBackward0>)

In [43]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [44]:
batches

tensor([[[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,
           1.1267],
         [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257,
          -0.0238],
         [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,
           0.1166],
         [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806,
          -0.0192]],

        [[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,
           1.1267],
         [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257,
          -0.0238],
         [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,
           0.1166],
         [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806,
          -0.0192]]])

In [45]:
batches.shape

torch.Size([2, 4, 8])

In [46]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # use the following to manage memory efficiently:
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [47]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [48]:
causal( batches )

tensor([[[-0.1320, -0.1111, -0.2173,  0.5782,  0.5652, -0.0052],
         [-0.1586, -0.1000, -0.0912,  0.0555,  0.0404, -0.2249],
         [ 0.0056, -0.2076, -0.0674,  0.0915,  0.0285, -0.0785],
         [-0.1241, -0.2121,  0.0519, -0.3109, -0.2811, -0.1947]],

        [[-0.1320, -0.1111, -0.2173,  0.5782,  0.5652, -0.0052],
         [-0.1586, -0.1000, -0.0912,  0.0555,  0.0404, -0.2249],
         [ 0.0056, -0.2076, -0.0674,  0.0915,  0.0285, -0.0785],
         [-0.1241, -0.2121,  0.0519, -0.3109, -0.2811, -0.1947]]],
       grad_fn=<UnsafeViewBackward0>)

In [57]:
# here's a first pass at multi-head attention
class MultiHeadAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList( 
            [ CausalAttention( d_in, d_out, context_length, dropout, qkv_bias ) for _ in range(num_heads) ]
        )

    def forward( self, x ):
        return torch.cat( [ head(x) for head in self.heads ], dim=-1 )

In [58]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length= 4, dropout=0, num_heads = 3 )

In [59]:
mha_out = mha( batches )

In [60]:
mha_out

tensor([[[-0.6910, -0.1860, -0.0429,  0.4113, -0.4726,  0.1205,  0.0144,
           0.0335, -0.1235,  0.2319, -0.4612,  0.3753, -0.5560, -0.3359,
          -0.5808,  0.3521,  0.1915,  0.1376],
         [-0.4432, -0.0200, -0.1760,  0.2633, -0.1150,  0.4216,  0.2300,
           0.3295, -0.0435, -0.1222, -0.0872,  0.1404, -0.1983, -0.0771,
          -0.1422,  0.5138,  0.5273,  0.1412],
         [-0.0804,  0.3473, -0.2274, -0.1240, -0.0057,  0.3051,  0.1226,
           0.2555, -0.1163, -0.2704, -0.0075, -0.0185, -0.2498,  0.0650,
           0.3498,  0.2934,  0.5193,  0.0358],
         [-0.0931,  0.2079, -0.2087, -0.0306,  0.1332,  0.3605,  0.2494,
           0.3168, -0.0247, -0.2485,  0.0553,  0.0025, -0.0785,  0.0341,
           0.1837,  0.2606,  0.5454, -0.0736]],

        [[-0.6910, -0.1860, -0.0429,  0.4113, -0.4726,  0.1205,  0.0144,
           0.0335, -0.1235,  0.2319, -0.4612,  0.3753, -0.5560, -0.3359,
          -0.5808,  0.3521,  0.1915,  0.1376],
         [-0.4432, -0.0200, -0.17

In [61]:
mha_out.shape

torch.Size([2, 4, 18])

In [62]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [64]:
batches.shape

torch.Size([2, 4, 8])

In [67]:
batches

tensor([[[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,
           1.1267],
         [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257,
          -0.0238],
         [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,
           0.1166],
         [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806,
          -0.0192]],

        [[ 0.3759,  0.0686,  0.7031,  0.0935, -0.0464, -0.4952,  0.8338,
           1.1267],
         [-0.6018, -0.2695,  0.3440, -1.1668, -0.9852,  0.4084, -0.0257,
          -0.0238],
         [-1.0834, -0.8992, -0.9997, -1.0471,  0.6966,  1.0558, -0.3086,
           0.1166],
         [-0.8312,  0.2855, -0.0538, -0.3936, -1.2552, -0.2012, -0.8806,
          -0.0192]]])

In [68]:
batches.view( 2, 4, 2, 4 )

tensor([[[[ 0.3759,  0.0686,  0.7031,  0.0935],
          [-0.0464, -0.4952,  0.8338,  1.1267]],

         [[-0.6018, -0.2695,  0.3440, -1.1668],
          [-0.9852,  0.4084, -0.0257, -0.0238]],

         [[-1.0834, -0.8992, -0.9997, -1.0471],
          [ 0.6966,  1.0558, -0.3086,  0.1166]],

         [[-0.8312,  0.2855, -0.0538, -0.3936],
          [-1.2552, -0.2012, -0.8806, -0.0192]]],


        [[[ 0.3759,  0.0686,  0.7031,  0.0935],
          [-0.0464, -0.4952,  0.8338,  1.1267]],

         [[-0.6018, -0.2695,  0.3440, -1.1668],
          [-0.9852,  0.4084, -0.0257, -0.0238]],

         [[-1.0834, -0.8992, -0.9997, -1.0471],
          [ 0.6966,  1.0558, -0.3086,  0.1166]],

         [[-0.8312,  0.2855, -0.0538, -0.3936],
          [-1.2552, -0.2012, -0.8806, -0.0192]]]])

In [80]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length=4, dropout=0, num_heads=3 )

In [77]:
mha_out = mha( batches )

In [78]:
mha_out

tensor([[[ 0.1022,  0.1277, -0.4494,  0.2542, -0.3572,  0.1069],
         [ 0.3136, -0.0460, -0.4749,  0.1109, -0.1989,  0.0882],
         [ 0.3183, -0.1529, -0.3511,  0.0830, -0.1742,  0.0765],
         [ 0.3105, -0.1777, -0.3346,  0.0754, -0.1344,  0.0482]],

        [[ 0.1022,  0.1277, -0.4494,  0.2542, -0.3572,  0.1069],
         [ 0.3136, -0.0460, -0.4749,  0.1109, -0.1989,  0.0882],
         [ 0.3183, -0.1529, -0.3511,  0.0830, -0.1742,  0.0765],
         [ 0.3105, -0.1777, -0.3346,  0.0754, -0.1344,  0.0482]]],
       grad_fn=<ViewBackward0>)

In [72]:
mha_out.shape

torch.Size([2, 4, 6])