In [1]:
import torch
import tiktoken

In [3]:
# Sample sentence to tokenize; the trailing space is part of the string.
raw_text = "This is my dog Chico. "

In [4]:
# Look up the encoding registered in tiktoken under the name "gpt2".
tokenizer = tiktoken.get_encoding("gpt2")

In [5]:
# Turn the sample sentence into a list of integer token ids and show it.
enc_text = tokenizer.encode_ordinary(raw_text)
print(enc_text)

[1212, 318, 616, 3290, 609, 3713, 13, 220]


In [6]:
# Seed the RNG so the randomly initialised embedding table (and every
# attention score / weight / context vector derived from it further down)
# is identical on every Restart & Run All.
SEED = 123
torch.manual_seed(SEED)

# Toy setup: 4 rows of 8-dimensional vectors. The rows of the weight matrix
# are used directly as the example token vectors in the following cells.
vocab_size = 4
output_dim = 8
inputs = torch.nn.Embedding(vocab_size, output_dim)
print(inputs.weight)

Parameter containing:
tensor([[ 0.1921, -0.7420, -0.3598,  1.5413,  0.0997, -0.3321,  1.2050, -1.7420],
        [ 1.1249,  1.9819,  1.3493,  0.9413, -0.1617,  0.3775,  0.5600,  1.0466],
        [-0.9652,  0.1160,  1.1386, -1.0232,  1.4958, -0.9316,  0.8884,  0.0799],
        [-0.6549,  0.1448, -2.6235, -0.0833, -0.1636,  0.4186,  0.0577,  0.4814]],
       requires_grad=True)


In [7]:
# Replace the embedding layer with its plain weight matrix (no autograd
# tracking). The isinstance guard keeps the cell idempotent: on a re-run
# `inputs` is already a tensor and has no `.weight`, so it is left as is.
# `.detach()` is used instead of the legacy `.data` accessor.
if isinstance(inputs, torch.nn.Embedding):
    inputs = inputs.weight.detach()
inputs

tensor([[ 0.1921, -0.7420, -0.3598,  1.5413,  0.0997, -0.3321,  1.2050, -1.7420],
        [ 1.1249,  1.9819,  1.3493,  0.9413, -0.1617,  0.3775,  0.5600,  1.0466],
        [-0.9652,  0.1160,  1.1386, -1.0232,  1.4958, -0.9316,  0.8884,  0.0799],
        [-0.6549,  0.1448, -2.6235, -0.0833, -0.1636,  0.4186,  0.0577,  0.4814]])

In [8]:
# Dimensions of the input matrix.
inputs.size()

torch.Size([4, 8])

In [9]:
# Print each row of the matrix as a plain Python list of floats.
for values in inputs.tolist():
    print(values)

[0.19207634031772614, -0.7419674396514893, -0.35979023575782776, 1.541276216506958, 0.0996936783194542, -0.33205875754356384, 1.2049641609191895, -1.74197518825531]
[1.1248893737792969, 1.9819362163543701, 1.349258542060852, 0.941271960735321, -0.16170397400856018, 0.3774888515472412, 0.5599966645240784, 1.0465807914733887]
[-0.9651904702186584, 0.11599811911582947, 1.1386221647262573, -1.0231502056121826, 1.495785117149353, -0.9315903782844543, 0.888405978679657, 0.07992368191480637]
[-0.6548763513565063, 0.14481131732463837, -2.6235384941101074, -0.08329551666975021, -0.16362157464027405, 0.41856884956359863, 0.057707738131284714, 0.4813976585865021]


In [10]:
# Two small 2-element vectors used to illustrate the dot product.
x = torch.tensor([1.1, 2.3])
y = torch.tensor([3.4, -2.1])


In [11]:
# Dot product of the two example vectors.
x.dot(y)

tensor(-1.0900)

In [12]:
# Use the row at index 2 of the input matrix as the query vector.
query = inputs[2, :]
print(query)

tensor([-0.9652,  0.1160,  1.1386, -1.0232,  1.4958, -0.9316,  0.8884,  0.0799])


In [13]:
# Dot product of the query with every input row, printed one per line.
for token_vec in inputs:
    print(torch.dot(query, token_vec))

tensor(-0.8683)
tensor(-0.2950)
tensor(7.1892)
tensor(-2.7981)


In [14]:
# Collect the query's dot product with each input row into one score vector.
attention_scores_2 = torch.zeros(inputs.shape[0])
for idx, token_vec in enumerate(inputs):
    attention_scores_2[idx] = torch.dot(query, token_vec)
print(attention_scores_2)

tensor([-0.8683, -0.2950,  7.1892, -2.7981])


In [15]:
# Normalize the attention scores using the softmax function, i.e.:
# def softmax(x):
#     return torch.exp(x) / torch.exp(x).sum()

In [16]:
# Softmax over dimension 0 turns the scores into weights that sum to 1.
attention_weights_2 = attention_scores_2.softmax(dim=0)
attention_weights_2

tensor([3.1640e-04, 5.6136e-04, 9.9908e-01, 4.5939e-05])

In [17]:
# Sanity check: the normalized weights should add up to 1.
torch.sum(attention_weights_2)

tensor(1.0000)

In [18]:
# Context vector for the query: the attention-weighted sum of all input rows.
context_vector_2 = torch.zeros(query.shape)
for weight, token_vec in zip(attention_weights_2, inputs):
    context_vector_2 += weight * token_vec
context_vector_2

tensor([-0.9636,  0.1168,  1.1381, -1.0212,  1.4943, -0.9306,  0.8883,  0.0799])

In [19]:
# Compute the dot product of every input row with every other row in a single
# matrix multiplication.
# NOTE: this rebinds `attention_scores_2`, which earlier held the 1-D score
# vector for a single query, to the full square score matrix.
attention_scores_2 = torch.matmul(inputs, inputs.T)
attention_scores_2

tensor([[ 7.6990, -1.5790, -0.8683, -0.3420],
        [-1.5790,  9.4775, -0.2950, -3.3473],
        [-0.8683, -0.2950,  7.1892, -2.7981],
        [-0.3420, -3.3473, -2.7981,  7.7768]])

In [20]:
# Softmax along the last dimension so that each row of weights sums to 1.
attention_weights = attention_scores_2.softmax(dim=-1)
attention_weights

tensor([[9.9939e-01, 9.3403e-05, 1.9010e-04, 3.2178e-04],
        [1.5783e-05, 9.9992e-01, 5.6993e-05, 2.6929e-06],
        [3.1640e-04, 5.6136e-04, 9.9908e-01, 4.5939e-05],
        [2.9778e-04, 1.4748e-05, 2.5543e-05, 9.9966e-01]])

In [21]:
# Sanity check: the first row of weights should add up to 1.
torch.sum(attention_weights[0])

tensor(1.0000)

In [22]:
# All context vectors at once: each output row is the weighted sum of the
# input rows, using the corresponding row of attention weights.
context_vectors = torch.matmul(attention_weights, inputs)
context_vectors

tensor([[ 0.1917, -0.7413, -0.3601,  1.5402,  0.0998, -0.3319,  1.2045, -1.7407],
        [ 1.1248,  1.9818,  1.3492,  0.9412, -0.1616,  0.3774,  0.5600,  1.0465],
        [-0.9636,  0.1168,  1.1381, -1.0212,  1.4943, -0.9306,  0.8883,  0.0799],
        [-0.6546,  0.1446, -2.6227, -0.0828, -0.1635,  0.4183,  0.0581,  0.4807]])