In [110]:
import os
import requests
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np


In [111]:
torch.manual_seed(1337)
B, T, C = 4, 8, 16
x = torch.randn(B, T, C)

x.shape

torch.Size([4, 8, 16])

In [112]:
#self-attention v1

xbow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b, t] = torch.mean(xprev, 0)

print(xbow[0])

tensor([[ 1.8077e-01, -6.9988e-02, -3.5962e-01, -9.1520e-01,  6.2577e-01,
          2.5510e-02,  9.5451e-01,  6.4349e-02,  3.6115e-01,  1.1679e+00,
         -1.3499e+00, -5.1018e-01,  2.3596e-01, -2.3978e-01, -9.2111e-01,
          1.5433e+00],
        [ 7.6480e-01, -1.0481e-01, -3.6913e-02,  2.4958e-02, -7.0569e-01,
          2.5932e-01,  1.2208e+00,  3.2769e-01,  2.4359e-01, -1.9740e-01,
         -1.2550e+00, -4.2251e-01,  3.4187e-01, -5.2071e-01,  3.0125e-01,
          2.0259e+00],
        [ 2.8883e-01, -1.5364e-01,  3.1211e-01,  5.7154e-02, -4.1766e-01,
          5.5089e-01,  4.2920e-01,  1.1899e-01, -6.7854e-03, -4.3958e-01,
         -6.5443e-01, -7.7993e-01, -1.7399e-01, -1.5653e-01,  1.7141e-03,
          1.1194e+00],
        [ 6.2801e-01, -3.1597e-01,  5.7193e-01, -2.6114e-02, -6.9095e-01,
          9.3936e-01,  1.0126e+00, -3.4739e-01,  3.5781e-01, -7.0725e-01,
         -2.8553e-01, -6.3783e-01,  6.4228e-02,  2.6592e-01,  4.0372e-01,
          7.3875e-01],
        [ 3.3551e-01

In [113]:
#self-attention v2

w = torch.tril(torch.ones((T, T)))
w = w / torch.sum(w, 1, keepdims=True) 

xbow = w @ x

xbow[0]

tensor([[ 1.8077e-01, -6.9988e-02, -3.5962e-01, -9.1520e-01,  6.2577e-01,
          2.5510e-02,  9.5451e-01,  6.4349e-02,  3.6115e-01,  1.1679e+00,
         -1.3499e+00, -5.1018e-01,  2.3596e-01, -2.3978e-01, -9.2111e-01,
          1.5433e+00],
        [ 7.6480e-01, -1.0481e-01, -3.6913e-02,  2.4958e-02, -7.0569e-01,
          2.5932e-01,  1.2208e+00,  3.2769e-01,  2.4359e-01, -1.9740e-01,
         -1.2550e+00, -4.2251e-01,  3.4187e-01, -5.2071e-01,  3.0125e-01,
          2.0259e+00],
        [ 2.8883e-01, -1.5364e-01,  3.1211e-01,  5.7154e-02, -4.1766e-01,
          5.5089e-01,  4.2920e-01,  1.1899e-01, -6.7854e-03, -4.3958e-01,
         -6.5443e-01, -7.7993e-01, -1.7399e-01, -1.5653e-01,  1.7140e-03,
          1.1194e+00],
        [ 6.2801e-01, -3.1597e-01,  5.7193e-01, -2.6114e-02, -6.9095e-01,
          9.3936e-01,  1.0126e+00, -3.4739e-01,  3.5781e-01, -7.0725e-01,
         -2.8553e-01, -6.3783e-01,  6.4228e-02,  2.6592e-01,  4.0372e-01,
          7.3875e-01],
        [ 3.3551e-01

In [114]:
#self-attention v3

tril = torch.tril(torch.ones((T, T)))
w = torch.zeros((T, T))
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim=1)

xbow = w @ x

xbow[0] 

w

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [115]:
#self-attention v4 (ATTENTION IS ALL YOU NEED)

torch.manual_seed(1337)
B, T, C = 4, 8, 32 # example token_embedding + position_embedding
x = torch.randn(B, T, C) # represent token embedding

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

w = q @ k.transpose(-2, -1) * head_size**-0.5 # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones((T, T)))
# w = torch.zeros((T, T))
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim=1)

v = value(x) # linear transform token embedding 

xbow = w @ v

xbow[0]

tensor([[-0.0152,  0.0849,  0.0156, -0.0755, -0.0138,  0.0720,  0.0097, -0.0505,
         -0.0856,  0.0184,  0.0170, -0.0573, -0.0464, -0.0469,  0.0276,  0.0551],
        [ 0.0623, -0.0137, -0.0178, -0.0057, -0.0202,  0.0059, -0.0396, -0.0369,
         -0.0484,  0.0845,  0.0824, -0.0411,  0.0103,  0.0118,  0.0066,  0.1595],
        [ 0.1741, -0.0228, -0.1270,  0.0348,  0.0333, -0.0332, -0.0148, -0.0234,
         -0.1168, -0.0624,  0.1308,  0.0149, -0.0893, -0.0227, -0.0331,  0.4593],
        [ 0.2850, -0.1001, -0.2212,  0.0618,  0.1816, -0.1536, -0.0042, -0.0282,
         -0.0725,  0.0207,  0.0213,  0.0584, -0.0184,  0.0740,  0.1509,  0.5019],
        [ 0.3354,  0.1264,  0.0189,  0.1820,  0.2485,  0.0998,  0.1852,  0.1277,
         -0.2214, -0.3642, -0.0985, -0.0211, -0.3042,  0.1374,  0.0634,  0.6118],
        [ 0.2923,  0.1706, -0.0102,  0.4219,  0.3595,  0.1902,  0.1295,  0.0516,
         -0.2165, -0.3195, -0.0527,  0.1205, -0.3123,  0.2185,  0.2450,  0.8324],
        [-0.2267,  0.1

In [116]:
v[0, :, :5]

tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429],
        [ 0.8321, -0.8144, -0.3242,  0.5191, -0.1252],
        [ 0.6035, -0.2500, -0.6159,  0.4068,  0.3328],
        [ 0.6657, -0.7096, -0.6099,  0.4348,  0.8975],
        [ 0.1536,  1.0439,  0.8457,  0.2388,  0.3005],
        [-0.8920,  0.0578, -0.3350,  0.8477,  0.3876],
        [-0.4849,  0.1655, -0.2221, -0.1345, -0.0864],
        [ 0.2042,  0.3772, -1.1255,  0.3995,  0.1489]],
       grad_fn=<SliceBackward0>)

In [117]:
w

tensor([[[0.0964, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0651, 0.0872, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1160, 0.0963, 0.1859, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1823, 0.1080, 0.1677, 0.1842, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1094, 0.1326, 0.1498, 0.1637, 0.2794, 0.0000, 0.0000, 0.0000],
         [0.1386, 0.2413, 0.1775, 0.1777, 0.3875, 0.1924, 0.0000, 0.0000],
         [0.1967, 0.2156, 0.1709, 0.2105, 0.1955, 0.4954, 0.4261, 0.0000],
         [0.0954, 0.1190, 0.1482, 0.2639, 0.1375, 0.3122, 0.5739, 1.0000]],

        [[0.1146, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0636, 0.1111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1215, 0.0961, 0.2046, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1946, 0.1556, 0.2426, 0.1524, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0764, 0.1438, 0.1435, 0.2914, 0.2430, 0.0000, 0.0000, 0.0000],
         [0.1094, 0.138