<a href="https://colab.research.google.com/github/mbrudd/LLMs/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-1.3292,  0.8097, -0.5553,  0.6602,  0.7835, -0.8741, -0.1003,  0.2274],
        [ 1.5720, -0.8101,  1.1156,  0.1344, -0.0241, -1.1807,  0.1780,  0.6366],
        [-1.3231,  1.6145,  0.9073, -1.0970, -1.6432,  0.4971, -0.7370,  1.1564],
        [ 0.0386, -0.1957,  0.2020,  0.6749, -1.2907,  1.2601,  1.8761, -0.9480]],
       requires_grad=True)

In [4]:
inputs = inputs.data
inputs

tensor([[-1.3292,  0.8097, -0.5553,  0.6602,  0.7835, -0.8741, -0.1003,  0.2274],
        [ 1.5720, -0.8101,  1.1156,  0.1344, -0.0241, -1.1807,  0.1780,  0.6366],
        [-1.3231,  1.6145,  0.9073, -1.0970, -1.6432,  0.4971, -0.7370,  1.1564],
        [ 0.0386, -0.1957,  0.2020,  0.6749, -1.2907,  1.2601,  1.8761, -0.9480]])

In [5]:
# set dimensions
d_in = 8
d_out = 6

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [6]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([-0.1789,  0.2137,  0.0313, -0.7703, -0.1273, -1.8762])

In [7]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-0.0394, -0.1689,  0.1785, -0.2056, -0.2069,  0.1211],
        [ 1.5428,  0.8306, -0.4155,  1.2433,  1.0049,  1.8840],
        [-1.1328,  0.5651,  0.8740,  0.4499,  0.2791, -1.5555],
        [ 0.5014,  0.9252,  0.8923,  0.9109,  2.3291,  0.9026]])
Values: tensor([[-2.7887e-02,  1.2724e-01, -7.9464e-01, -9.2495e-01, -3.9008e-01,
         -8.9389e-04],
        [ 7.1025e-01, -1.2465e-01,  5.1470e-01,  1.9723e+00,  1.4213e+00,
          1.1743e+00],
        [ 8.6978e-01,  1.3091e+00, -8.4407e-01,  1.5862e-01, -5.7497e-01,
         -1.9002e+00],
        [-1.9215e-01,  1.9280e+00,  1.4745e+00,  9.7672e-01, -2.5122e-01,
          8.9286e-01]])


In [8]:
attention_scores = query @ keys.T
attention_scores

tensor([-0.0660, -4.7322,  2.8872, -2.5558])

In [9]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([0.2062, 0.0307, 0.6885, 0.0746])

In [10]:
attention_weights.sum()

tensor(1.0000)

In [11]:
context_vector = attention_weights @ values
context_vector

tensor([ 0.6005,  1.0676, -0.6192,  0.0519, -0.4514, -1.2058])

In [12]:
import torch.nn as nn


In [13]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [14]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [15]:
simple.W_v

Parameter containing:
tensor([[0.6136, 0.6489, 0.9762, 0.5870, 0.1912, 0.2031],
        [0.1286, 0.1703, 0.6864, 0.0121, 0.9506, 0.1873],
        [0.2859, 0.7420, 0.6053, 0.5492, 0.3664, 0.4607],
        [0.7348, 0.2177, 0.6334, 0.1087, 0.6322, 0.4288],
        [0.8806, 0.1077, 0.4198, 0.0882, 0.6358, 0.5097],
        [0.3633, 0.8876, 0.7995, 0.6740, 0.7711, 0.7196],
        [0.3058, 0.5988, 0.4124, 0.4454, 0.4333, 0.3269],
        [0.1312, 0.2132, 0.4207, 0.4083, 0.1324, 0.0147]])

In [16]:
context_vectors = simple( inputs )
context_vectors

tensor([[-0.5132, -0.5743, -0.3829, -0.5475,  0.4309, -0.2044],
        [ 0.4843,  1.0827,  0.9863,  1.0070, -0.2096,  0.1649],
        [-0.2072, -1.4447, -0.8565, -1.2619,  0.4990, -0.3463],
        [ 0.6911,  0.9978,  1.0359,  1.0256, -0.4437,  0.0611]])

In [25]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [26]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [27]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.2260, -0.2203, -0.3082,  0.2058, -0.2902, -0.2902],
        [ 0.1543, -0.3560, -0.0465,  0.2560,  0.0202, -0.1348],
        [ 0.2918, -0.1410, -0.4976,  0.1799, -0.4853, -0.4039],
        [ 0.0494, -0.1248, -0.0317,  0.3531,  0.0764, -0.0907]],
       grad_fn=<MmBackward0>)

In [21]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors
# in practice, we should only use information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [28]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs )
weights

tensor([[0.2677, 0.2808, 0.2081, 0.2434],
        [0.2211, 0.2239, 0.3033, 0.2518],
        [0.2330, 0.2468, 0.1668, 0.3533],
        [0.2251, 0.2072, 0.3832, 0.1845]], grad_fn=<SoftmaxBackward0>)

In [30]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [32]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [35]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2677, 0.0000, 0.0000, 0.0000],
        [0.2211, 0.2239, 0.0000, 0.0000],
        [0.2330, 0.2468, 0.1668, 0.0000],
        [0.2251, 0.2072, 0.3832, 0.1845]], grad_fn=<MulBackward0>)

In [36]:
masked_weights.sum( dim=-1 )

tensor([0.2677, 0.4450, 0.6467, 1.0000], grad_fn=<SumBackward1>)

In [38]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2677],
        [0.4450],
        [0.6467],
        [1.0000]], grad_fn=<SumBackward1>)

In [39]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [40]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4968, 0.5032, 0.0000, 0.0000],
        [0.3603, 0.3817, 0.2579, 0.0000],
        [0.2251, 0.2072, 0.3832, 0.1845]], grad_fn=<DivBackward0>)

In [46]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [47]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [50]:
weights

tensor([[0.2677, 0.2808, 0.2081, 0.2434],
        [0.2211, 0.2239, 0.3033, 0.2518],
        [0.2330, 0.2468, 0.1668, 0.3533],
        [0.2251, 0.2072, 0.3832, 0.1845]], grad_fn=<SoftmaxBackward0>)

In [51]:
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2677,   -inf,   -inf,   -inf],
        [0.2211, 0.2239,   -inf,   -inf],
        [0.2330, 0.2468, 0.1668,   -inf],
        [0.2251, 0.2072, 0.3832, 0.1845]], grad_fn=<MaskedFillBackward0>)

In [52]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4993, 0.5007, 0.0000, 0.0000],
        [0.3390, 0.3437, 0.3173, 0.0000],
        [0.2431, 0.2388, 0.2847, 0.2334]], grad_fn=<SoftmaxBackward0>)

In [58]:
## Dropout
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 )

In [67]:
dropout( masked_weights )

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0014, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6346, 0.0000],
        [0.4862, 0.0000, 0.5694, 0.4668]], grad_fn=<MulBackward0>)

In [53]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [56]:
batches.shape

torch.Size([2, 4, 8])

In [None]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    self.dropout = nn.Dropout( dropout )
    self.register_buffer('mask', torch.triu( torch.ones(context_length, context_length), diagonal = 1 ))

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [1]:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0)

W_q = nn.Linear( d_in, d_out, bias=False )
W_k = nn.Linear( d_in, d_out, bias=False )
W_v = nn.Linear( d_in, d_out, bias=False )

queries = W_q( batches )
queries

NameError: name 'CausalAttention' is not defined