# Attention with trainable weights

In [24]:
import torch

In [25]:
inputs = torch.nn.Embedding( 4, 8 )

In [26]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306, -0.6830],
        [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,  0.2109],
        [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864, -0.6586],
        [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,  0.8669]],
       requires_grad=True)

In [27]:
inputs = inputs.data
inputs

tensor([[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306, -0.6830],
        [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,  0.2109],
        [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864, -0.6586],
        [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,  0.8669]])

In [28]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [29]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-0.6710, -2.0310, -1.6372, -1.1939, -2.2266, -1.0149])

In [30]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-0.7485,  2.0350,  1.4101, -0.4210, -0.3566,  1.4974],
        [-1.0786,  0.0713, -0.3865, -2.7995, -1.5736, -0.9684],
        [-0.4913, -0.3656, -0.4984, -1.5643, -1.5112, -1.6918],
        [ 1.2783,  1.9363,  1.6936,  0.6576,  1.4184,  2.0097]])
Values: tensor([[-0.4124,  0.7969,  0.9632,  1.2091,  1.9863, -1.0238],
        [-2.3077, -1.4353, -1.6752, -0.2586, -0.8632, -2.8789],
        [-1.9008, -1.2080, -1.0207, -0.9898, -0.7862, -1.7308],
        [ 0.4161,  3.1818,  1.4108,  1.4710,  1.3720,  0.9129]])


In [31]:
attention_scores = query @ keys.T
attention_scores

tensor([ -6.1626,   9.0406,   8.8375, -13.5461])

In [32]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([1.0486e-03, 5.2015e-01, 4.7875e-01, 5.1463e-05])

In [33]:
attention_weights.sum()

tensor(1.)

In [34]:
context_vector = attention_weights @ values
context_vector

tensor([-2.1108, -1.3239, -1.3589, -0.6070, -0.8232, -2.3271])

In [35]:
import torch.nn as nn


In [36]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [37]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [38]:
simple.W_v

Parameter containing:
tensor([[0.6921, 0.6845, 0.6072, 0.1782, 0.2162, 0.5661],
        [0.8071, 0.6144, 0.0525, 0.0704, 0.8423, 0.4204],
        [0.4201, 0.7557, 0.4516, 0.6666, 0.4888, 0.2711],
        [0.8857, 0.7941, 0.4394, 0.4806, 0.1353, 0.8625],
        [0.1734, 0.2097, 0.1951, 0.2945, 0.2867, 0.1236],
        [0.6912, 0.5165, 0.9224, 0.0636, 0.5102, 0.0354],
        [0.6063, 0.7372, 0.8954, 0.5583, 0.5290, 0.9026],
        [0.1696, 0.2644, 0.4139, 0.5451, 0.1045, 0.1209]])

In [39]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.2017,  0.7431,  0.9255,  1.4054, -0.1963,  0.9789],
        [-2.9610, -2.5906, -1.4886, -0.5604, -2.4817, -1.7724],
        [-2.9415, -2.5736, -1.5034, -0.5694, -2.4537, -1.7656],
        [ 0.4770,  1.0707,  1.1818,  1.6974, -0.0772,  1.0979]])

In [40]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [41]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [42]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.0984,  0.2110,  0.4699, -0.3150, -0.1860, -0.3912],
        [-0.1100,  0.2006,  0.4806, -0.2455, -0.2602, -0.4189],
        [-0.1003,  0.2203,  0.4699, -0.3172, -0.1775, -0.3947],
        [-0.0991,  0.2402,  0.4692, -0.3164, -0.1765, -0.4012]],
       grad_fn=<MmBackward0>)

In [43]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [44]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs )

queries = simple.W_q( inputs )
keys = simple.W_k( inputs )
values = simple.W_v( inputs )
scores = queries @ keys.T
weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )

weights

tensor([[0.2582, 0.2558, 0.2772, 0.2087],
        [0.2152, 0.2614, 0.3522, 0.1712],
        [0.2511, 0.2575, 0.2659, 0.2256],
        [0.2344, 0.2720, 0.2486, 0.2450]], grad_fn=<SoftmaxBackward0>)

In [45]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [46]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [47]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2582, 0.0000, 0.0000, 0.0000],
        [0.2152, 0.2614, 0.0000, 0.0000],
        [0.2511, 0.2575, 0.2659, 0.0000],
        [0.2344, 0.2720, 0.2486, 0.2450]], grad_fn=<MulBackward0>)

In [48]:
masked_weights.sum( dim=-1 )

tensor([0.2582, 0.4766, 0.7744, 1.0000], grad_fn=<SumBackward1>)

In [49]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2582],
        [0.4766],
        [0.7744],
        [1.0000]], grad_fn=<SumBackward1>)

In [50]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [51]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4515, 0.5485, 0.0000, 0.0000],
        [0.3242, 0.3325, 0.3433, 0.0000],
        [0.2344, 0.2720, 0.2486, 0.2450]], grad_fn=<DivBackward0>)

In [52]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [53]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [54]:
weights

tensor([[0.2582, 0.2558, 0.2772, 0.2087],
        [0.2152, 0.2614, 0.3522, 0.1712],
        [0.2511, 0.2575, 0.2659, 0.2256],
        [0.2344, 0.2720, 0.2486, 0.2450]], grad_fn=<SoftmaxBackward0>)

In [55]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2582,   -inf,   -inf,   -inf],
        [0.2152, 0.2614,   -inf,   -inf],
        [0.2511, 0.2575, 0.2659,   -inf],
        [0.2344, 0.2720, 0.2486, 0.2450]], grad_fn=<MaskedFillBackward0>)

In [56]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4885, 0.5115, 0.0000, 0.0000],
        [0.3310, 0.3331, 0.3359, 0.0000],
        [0.2461, 0.2555, 0.2496, 0.2487]], grad_fn=<SoftmaxBackward0>)

In [57]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [58]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.9769, 1.0231, 0.0000, 0.0000],
        [0.6620, 0.0000, 0.6718, 0.0000],
        [0.0000, 0.5111, 0.4992, 0.0000]], grad_fn=<MulBackward0>)

In [59]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [60]:
batches

tensor([[[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306,
          -0.6830],
         [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,
           0.2109],
         [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864,
          -0.6586],
         [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,
           0.8669]],

        [[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306,
          -0.6830],
         [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,
           0.2109],
         [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864,
          -0.6586],
         [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,
           0.8669]]])

In [61]:
batches.shape

torch.Size([2, 4, 8])

In [62]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # use the following to manage memory efficiently:
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [63]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [64]:
causal( batches )

tensor([[[-0.5211,  1.3335,  0.0292,  0.0482, -0.0546,  1.3317],
         [-0.1602,  0.9817,  0.3829,  0.2854,  0.0501,  1.1420],
         [-0.1041,  0.8093,  0.4927,  0.1810, -0.0588,  0.8028],
         [ 0.0167,  0.8670,  0.3698,  0.1432,  0.0331,  0.7435]],

        [[-0.5211,  1.3335,  0.0292,  0.0482, -0.0546,  1.3317],
         [-0.1602,  0.9817,  0.3829,  0.2854,  0.0501,  1.1420],
         [-0.1041,  0.8093,  0.4927,  0.1810, -0.0588,  0.8028],
         [ 0.0167,  0.8670,  0.3698,  0.1432,  0.0331,  0.7435]]],
       grad_fn=<UnsafeViewBackward0>)

In [65]:
# here's a first pass at multi-head attention
class MultiHeadAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [ CausalAttention( d_in, d_out, context_length, dropout, qkv_bias ) for _ in range(num_heads) ]
        )

    def forward( self, x ):
        return torch.cat( [ head(x) for head in self.heads ], dim=-1 )

In [66]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length= 4, dropout=0, num_heads = 3 )

In [67]:
mha_out = mha( batches )

In [68]:
mha_out

tensor([[[ 1.4208, -0.2453,  0.5383,  0.4345, -0.2936, -0.8508,  0.0527,
          -0.0537,  1.1098,  0.3376,  0.2062,  0.2342, -1.1751, -0.3572,
          -0.3544, -0.1014,  0.0140,  0.5943],
         [ 1.0501,  0.0343,  0.4285,  0.2933, -0.2995, -0.2432, -0.0300,
          -0.6992,  0.7993,  0.4846,  0.2376,  0.0484, -0.7183, -0.1285,
           0.0313, -0.2238, -0.0749,  0.4793],
         [ 0.8152,  0.2136,  0.3962,  0.3224, -0.1273, -0.2168, -0.2541,
          -0.4810,  0.4369,  0.2927,  0.3138,  0.0792, -0.5553, -0.0028,
           0.0056, -0.2044,  0.0563,  0.2354],
         [ 0.7609,  0.2804,  0.4015,  0.2737, -0.0097, -0.0695, -0.1644,
          -0.4264,  0.6622,  0.1981,  0.1975, -0.0442, -0.4324,  0.2567,
           0.0029, -0.2155, -0.0194,  0.3436]],

        [[ 1.4208, -0.2453,  0.5383,  0.4345, -0.2936, -0.8508,  0.0527,
          -0.0537,  1.1098,  0.3376,  0.2062,  0.2342, -1.1751, -0.3572,
          -0.3544, -0.1014,  0.0140,  0.5943],
         [ 1.0501,  0.0343,  0.42

In [69]:
mha_out.shape

torch.Size([2, 4, 18])

In [70]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [71]:
batches.shape

torch.Size([2, 4, 8])

In [72]:
batches

tensor([[[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306,
          -0.6830],
         [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,
           0.2109],
         [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864,
          -0.6586],
         [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,
           0.8669]],

        [[-0.9999, -1.2737,  0.2339,  1.1430,  1.0805, -0.2544,  1.7306,
          -0.6830],
         [-0.7617, -2.2307, -0.4388, -0.0068,  0.7100, -0.5587, -0.4294,
           0.2109],
         [-0.0450, -0.6184,  0.1731, -0.5922,  0.5761, -1.0463, -0.7864,
          -0.6586],
         [ 0.5506, -1.1198,  0.6035,  0.6885,  1.4914, -0.5357,  0.2084,
           0.8669]]])

In [73]:
batches.view( 2, 4, 2, 4 )

tensor([[[[-0.9999, -1.2737,  0.2339,  1.1430],
          [ 1.0805, -0.2544,  1.7306, -0.6830]],

         [[-0.7617, -2.2307, -0.4388, -0.0068],
          [ 0.7100, -0.5587, -0.4294,  0.2109]],

         [[-0.0450, -0.6184,  0.1731, -0.5922],
          [ 0.5761, -1.0463, -0.7864, -0.6586]],

         [[ 0.5506, -1.1198,  0.6035,  0.6885],
          [ 1.4914, -0.5357,  0.2084,  0.8669]]],


        [[[-0.9999, -1.2737,  0.2339,  1.1430],
          [ 1.0805, -0.2544,  1.7306, -0.6830]],

         [[-0.7617, -2.2307, -0.4388, -0.0068],
          [ 0.7100, -0.5587, -0.4294,  0.2109]],

         [[-0.0450, -0.6184,  0.1731, -0.5922],
          [ 0.5761, -1.0463, -0.7864, -0.6586]],

         [[ 0.5506, -1.1198,  0.6035,  0.6885],
          [ 1.4914, -0.5357,  0.2084,  0.8669]]]])

In [74]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length=4, dropout=0, num_heads=3 )

In [75]:
mha_out = mha( batches )

In [76]:
mha_out

tensor([[[-0.6244, -0.5975, -0.1934,  0.5200, -0.5116,  0.2835],
         [-0.1321, -0.4701, -0.4574,  0.0414, -0.4035,  0.3737],
         [ 0.0160, -0.3677, -0.3209,  0.0119, -0.3741,  0.3938],
         [ 0.0414, -0.3517, -0.2693,  0.0892, -0.4068,  0.3144]],

        [[-0.6244, -0.5975, -0.1934,  0.5200, -0.5116,  0.2835],
         [-0.1321, -0.4701, -0.4574,  0.0414, -0.4035,  0.3737],
         [ 0.0160, -0.3677, -0.3209,  0.0119, -0.3741,  0.3938],
         [ 0.0414, -0.3517, -0.2693,  0.0892, -0.4068,  0.3144]]],
       grad_fn=<ViewBackward0>)

In [77]:
mha_out.shape

torch.Size([2, 4, 6])