In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Read the whole tiny-Shakespeare corpus as one string. The context manager
# guarantees the file handle is closed; the corpus is plain ASCII, so UTF-8 decodes it identically on every platform.
with open('input.txt', 'r', encoding='utf-8') as f:
    shakespeare = f.read()

In [3]:
# Character-level vocabulary: every distinct character of the corpus, sorted so
# that character -> id assignments are stable between runs. `set` accepts the
# string directly; no intermediate join/list copy is needed.
vocab = sorted(set(shakespeare))
vocab_size = len(vocab)
vocab_size

65

In [4]:
# Show the vocabulary as a single string (it starts with the newline and space characters).
str.join('', vocab)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [5]:
# Bidirectional lookup tables between characters and integer token ids.
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for ch, i in stoi.items()}


def encode(text):
    """Map a string to its list of token ids (raises KeyError for characters not in ``vocab``)."""
    return [stoi[ch] for ch in text]


def decode(idx):
    """Map an iterable of integer token ids back to the string they represent."""
    return ''.join(itos[i] for i in idx)


decode(encode('Hello there'))  # round-trip sanity check: should echo the input unchanged

'Hello there'

In [6]:
print(len(shakespeare))  # number of characters in the corpus
encoded_text = encode(shakespeare)
print(len(encoded_text))  # encode emits one id per character, so this matches the line above

# Keep the first 90% of the token stream for training and hold out the rest for validation.
n = int(0.9 * len(encoded_text))
train_data, val_data = encoded_text[:n], encoded_text[n:]

1115394
1115394


In [7]:
block_size: int = 8  # context length: number of tokens in each training chunk

In [8]:
# block_size + 1 tokens hold block_size (context, target) examples.
train_data[0:block_size + 1]

[18, 47, 56, 57, 58, 1, 15, 47, 58]

In [9]:
# See what (input, output) pairs look like for a given block_size:
# every prefix of the chunk is a context, and the token right after it is the target.
x = train_data[:block_size + 1]
for end in range(1, block_size + 1):
    inp, output = x[:end], x[end]
    print(f'{inp} --> {output}')

[18] --> 47
[18, 47] --> 56
[18, 47, 56] --> 57
[18, 47, 56, 57] --> 58
[18, 47, 56, 57, 58] --> 1
[18, 47, 56, 57, 58, 1] --> 15
[18, 47, 56, 57, 58, 1, 15] --> 47
[18, 47, 56, 57, 58, 1, 15, 47] --> 58


In [10]:
torch.manual_seed(2)

# making a batch of data
batch_size = 4  # number of independent sequences sampled per batch
block_size = 8  # length of each sampled sequence (the context length)


def get_batch(split):
    """Sample a random batch of (input, target) token sequences from one data split.

    Args:
        split: ``'train'`` samples from ``train_data``; any other value samples
            from ``val_data``.

    Returns:
        ``(x, y)``: two ``(batch_size, block_size)`` integer tensors, where ``y``
        is ``x`` shifted left by one token, i.e. ``y[b, t]`` is the token that
        follows ``x[b, :t+1]``.

    ``batch_size`` and ``block_size`` are read from the notebook globals at call
    time, so reassigning them later (e.g. before training) takes effect here.
    """
    data = train_data if split == 'train' else val_data
    # Use `data` (not `train_data`) for both the offsets and the slices, so that
    # 'val' batches really come from the held-out split. randint's upper bound is
    # exclusive, which leaves room for the shifted target slice below.
    starts = torch.randint(0, len(data) - block_size, (batch_size,)).tolist()

    x = torch.tensor([data[i : i + block_size] for i in starts])
    y = torch.tensor([data[i + 1 : i + 1 + block_size] for i in starts])

    return x, y

In [11]:
# Draw one training batch and show the inputs (xb) followed by their one-step-shifted targets (yb).
xb, yb = get_batch(split='train')
print(xb, yb, sep='\n')

tensor([[16,  1, 21, 21, 10,  0, 25, 63],
        [50, 47, 54, 58,  1, 50, 47, 49],
        [43,  1, 61, 46, 47, 41, 46,  1],
        [59, 42,  1, 40, 47, 56, 42, 57]])
tensor([[ 1, 21, 21, 10,  0, 25, 63,  1],
        [47, 54, 58,  1, 50, 47, 49, 43],
        [ 1, 61, 46, 47, 41, 46,  1, 63],
        [42,  1, 40, 47, 56, 42, 57,  6]])


In [12]:
# Unroll only the first sequence of the batch into its (context --> target) pairs.
for b in range(batch_size)[:1]:
    for t in range(block_size):
        context_len = t + 1
        inp, out = xb[b, :context_len], yb[b, t]
        print(f'{inp} --> {out}')

tensor([16]) --> 1
tensor([16,  1]) --> 21
tensor([16,  1, 21]) --> 21
tensor([16,  1, 21, 21]) --> 10
tensor([16,  1, 21, 21, 10]) --> 0
tensor([16,  1, 21, 21, 10,  0]) --> 25
tensor([16,  1, 21, 21, 10,  0, 25]) --> 63
tensor([16,  1, 21, 21, 10,  0, 25, 63]) --> 1


In [13]:
# torch.ones(2,4,8)[:,-1:,:].shape

In [14]:
class BigramLM(nn.Module):
    """Bigram character language model.

    Each token id selects one row of a ``(vocab_size, vocab_size)`` embedding
    table, and that row is used directly as the next-token logits. The logits at
    a position therefore depend only on the token at that same position.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the next-token logits for current token i.
        self.token_encoding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, when targets are given, the cross-entropy loss.

        Args:
            x: ``(B, T)`` tensor of token ids.
            targets: optional ``(B, T)`` tensor of next-token ids.

        Returns:
            ``(logits, loss)``: ``logits`` is ``(B, T, vocab_size)``; ``loss`` is a
            scalar tensor, or ``None`` when ``targets`` is ``None``.
        """
        logits = self.token_encoding_table(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy wants the class dimension second: (B, vocab_size, T) vs (B, T) targets.
            loss = F.cross_entropy(logits.transpose(-1, -2), targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: ``(B, T)`` tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample.

        Returns:
            ``(B, T + max_new_tokens)`` tensor: the context followed by the samples.

        Sampling never needs gradients, so autograd recording is disabled for the
        whole loop instead of building (and discarding) a graph at every step.
        """
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size): only the last position predicts the next token
            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T + 1)
        return idx


In [15]:
# Fresh, untrained bigram model over the character vocabulary.
bigram_model = BigramLM(vocab_size=vocab_size)

xb, yb = get_batch('train')

# Forward-pass sanity check (disabled):
# logits, loss = bigram_model(xb, yb)
# logits.shape, loss.item()

In [16]:
# Sample 100 tokens from the untrained model, seeded with token id 0 (the newline character).
idx = torch.zeros(1, 1, dtype=torch.long)
sample = bigram_model.generate(idx, max_new_tokens=100)
print(decode(sample[0].tolist()))


AmYBaDstkNBPPuYGI$3yXgka.
LXuJKcVU.zmlDdsV?!fgNnfpKojBlaZXuGw:crFdlqgjUQ3KM3GALSiOcKZTRv SX$$.GBKZYs


In [17]:
# AdamW over every parameter of the bigram model.
optimizer = torch.optim.AdamW(params=bigram_model.parameters(), lr=0.001)

In [18]:
# Training loop. get_batch reads the global batch_size, so reassigning it here
# makes every batch below 32 sequences instead of 4.
batch_size = 32
for step in range(10_000):
    xb, yb = get_batch('train')
    logits, loss = bigram_model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())  # loss of the final mini-batch only, so a noisy estimate

2.5684943199157715


In [19]:
# Sample 500 tokens from the trained model, seeded with token id 0 (the newline character).
idx = torch.zeros(1, 1, dtype=torch.long)
sample = bigram_model.generate(idx, max_new_tokens=500)
print(decode(sample[0].tolist()))


F yore, ws oosit ges the, owhowinou he

Whe Fimasicismam  hyond l y IENII ar Cllered o, PUSomy


I d s, t w illes.
Yow sxeaghag at ithesthend pan s, ue, horele dw
TOXmacurngomen ain d r-wndsar irtheithon aks cokn.
NCHimatote Bumandsthaceiofe laterelee.
Mat T:
ifaseng aY paroois, ts?
PPey indditay ldv nga$MIAnther dith th hanto me manore o madyoner-b.
Thich pr hontit Cly o-cave nd whitomest d g d IN han:

NVFr stend th, t hit ovedo worinst hechop, m.
Bozes! mur he depy vare cesand CKI tl ustrm, I


## Tricks for computing interactions between tokens

### Attention

In [20]:
# Toy input: B sequences of T time steps, each carrying a C-dimensional feature vector.
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x[0]

tensor([[-0.3352, -0.3542],
        [ 1.1348, -2.2493],
        [-0.1779, -0.7811],
        [-2.7670, -2.7058],
        [ 0.3049,  0.5375],
        [ 0.4849, -1.5841],
        [ 0.7020,  1.2656],
        [ 0.2472,  0.6762]])

In [21]:
# Naive reference ("bag of words"): for each (b, t), average the feature vectors of steps 0..t.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        prefix = x[b, : t + 1]  # (t+1, C): every step up to and including t
        xbow[b, t] = prefix.mean(dim=0)

xbow[0]  # per-step running average over all steps up to and including t

tensor([[-0.3352, -0.3542],
        [ 0.3998, -1.3018],
        [ 0.2072, -1.1282],
        [-0.5363, -1.5226],
        [-0.3681, -1.1106],
        [-0.2259, -1.1895],
        [-0.0933, -0.8388],
        [-0.0508, -0.6494]])

In [22]:
# Version 2: the same prefix averages expressed as one matrix multiply.
# Row t of the row-normalized lower-triangular matrix weights steps 0..t by 1/(t+1).
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [23]:
# (T, T) @ (B, T, C) broadcasts the weights over the batch --> (B, T, C)
xbow2 = torch.matmul(wei, x)
xbow2[0]

tensor([[-0.3352, -0.3542],
        [ 0.3998, -1.3018],
        [ 0.2072, -1.1282],
        [-0.5363, -1.5226],
        [-0.3681, -1.1106],
        [-0.2259, -1.1895],
        [-0.0933, -0.8388],
        [-0.0508, -0.6494]])

In [24]:
xbow.allclose(xbow2)  # matmul version reproduces the naive loop

True

In [25]:
# Version 3: the same weights via softmax. Future positions are set to -inf, which
# softmax maps to 0; the equal zeros that remain become a uniform average over steps 0..t.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

In [26]:
# Softmax-derived weights: identical to the row-normalized lower-triangular matrix of version 2.
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [None]:
# version 4: self-attention (a single head)
# Weights are now data-dependent: each position emits a query ("what am I looking for")
# and a key ("what do I contain"), and their dot products become the aggregation weights.

tril = torch.tril(torch.ones(T, T))

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)
# Scaled dot-product: multiplying by head_size ** -0.5 keeps the affinities at about unit
# variance when q and k have unit-variance entries. Without it, softmax saturates toward
# one-hot as head_size grows.
wei = q @ k.transpose(-2, -1) * head_size ** -0.5  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))    # decoder-style: position t attends only to steps <= t
wei = F.softmax(wei, dim=-1)

value = nn.Linear(C, head_size, bias=False)
v = value(x)  # (B, T, head_size)

out = wei @ v  # (B, T, T) @ (B, T, head_size) --> (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

In [30]:
# IPython help syntax (`obj?`): displays LayerNorm's signature and docstring; not valid plain Python.
nn.LayerNorm?

Init signature:
nn.LayerNorm(
    normalized_shape: Union[int, list[int], torch.Size],
    eps: float = 1e-05,
    elementwise_affine: bool = True,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None
Docstring:
Applies Layer Normalization over a mini-batch of inputs.

This layer implements the operation as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

.. math::
    y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
:math:`\gamma` and :math:`\