In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Data

In [2]:
# read the entire corpus into a single string (one read, not line by line)
with open('shakespeare.txt', encoding='utf-8') as fh:
    text = fh.read()

In [3]:
# number of characters in the raw corpus
print(f'length of dataset in chars {len(text)}')

length of dataset in chars 1115394


In [4]:
# preview the first 1000 characters of the raw corpus
# (as the cell's last expression, the slice is what gets displayed)
text[:1000]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citizens, the patricians good.\nWhat authority surfeits on would relieve us: if they\nwould yield us but the superfluity, while it were\nwholesome, we might guess they relieved us humanely;\nbut they think we are too dear: the leanness that\nafflicts us, the object of our misery, is as an\ninventory to particularise their abundance; our\nsufferance is a gain to them Let us revenge this with\nour pikes, ere we become rakes: for the gods know I\nspeak this in hunger 

In [5]:
# the vocabulary is the set of distinct characters in the corpus;
# sorting gives every character a stable position (its future token id)
chars = sorted(set(text))
vocab_size = len(chars)

# inspect the vocabulary: all characters on one line, then the count
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


#### Tokenize

In [6]:
# lookup tables: character -> integer id, and integer id -> character
ctoi = {ch: ix for ix, ch in enumerate(chars)}
itoc = dict(enumerate(chars))


def encode(str):
    """Encode a string as a list of integer ids, one per character.

    Raises KeyError for a character that is not in the vocabulary.
    The parameter name shadows the builtin `str` inside this function only.
    """
    return [ctoi[c] for c in str]


def decode(ints):
    """Decode an iterable of integer ids back into the string they spell."""
    return ''.join(itoc[i] for i in ints)


# round-trip sanity check
print(encode("what's up"))
print(decode(encode("what's up")))

[61, 46, 39, 58, 5, 57, 1, 59, 54]
what's up


In [7]:
# encode the whole corpus into one tensor of int64 token ids (one id per character)
data = torch.tensor(encode(text), dtype=torch.long)

# sanity-check the size and dtype, then preview the first 1000 ids
print(data.shape, data.dtype)
data[:1000]

torch.Size([1115394]) torch.int64


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

#### Prepare dataset

In [8]:
# keep the first 90% of the tokens for training and hold out the final 10% for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [9]:
# seed torch's global RNG so the random batch sampling below is repeatable
torch.manual_seed(42)

batch_size = 4  # number of sequences sampled per batch
block_size = 8  # context length in chars; within a block the model sees char 1 and predicts 2, sees 1-2 and predicts 3, and so on

def get_batch(split):
    """Sample a random mini-batch of input/target blocks.

    Args:
        split: 'train' draws from train_data; any other value draws from val_data.

    Returns:
        x: (batch_size, block_size) tensor of token ids.
        y: (batch_size, block_size) tensor holding x shifted one position
           forward, i.e. the next-token target for every position of x.
    """
    source = train_data if split == 'train' else val_data

    # random start offsets; the exclusive upper bound leaves room for the
    # target window, which ends one position after the input window
    starts = torch.randint(len(source) - block_size, (batch_size, ))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start: start + block_size])
        targets.append(source[start + 1: start + block_size + 1])

    # stack the individual blocks into the rows of a batch
    return torch.stack(inputs), torch.stack(targets)

# Bigram

In [10]:
class BigramLanguageModel(nn.Module):
    """Character-level bigram language model.

    The whole model is one (vocab_size x vocab_size) embedding table: row i
    holds the logits over the next token given that the current token is i.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # lookup table that maps a token id directly to next-token logits
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids (B = batch, T = time steps).
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            A (logits, loss) tuple. Without targets, logits has shape
            (B, T, C) (C = vocab size) and loss is None. With targets, logits
            is flattened to (B * T, C), the layout F.cross_entropy expects,
            and loss is the cross-entropy over all B * T positions.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous targets are accepted as well
        loss = F.cross_entropy(logits, targets.reshape(B * T))

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; skip autograd bookkeeping
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in idx by max_new_tokens sampled tokens.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # forward pass over the running sequence
            logits, _loss = self(idx)           # (B, T, C)

            # only the last time step matters for a bigram model
            logits = logits[:, -1, :]           # (B, C)

            # turn logits into a probability distribution over the vocabulary
            proba = F.softmax(logits, dim=-1)   # (B, C)

            # sample the next token from that distribution
            idx_next = torch.multinomial(proba, num_samples=1)  # (B, 1)

            # append the sampled token to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)

        return idx

In [11]:
# instantiate the bigram model with one embedding row per vocabulary entry
bigram_model = BigramLanguageModel(vocab_size)

In [12]:
# AdamW over all model parameters with a learning rate of 1e-3
optimizer = torch.optim.AdamW(bigram_model.parameters(), lr=1e-3)

In [13]:
# train
batch_size = 32     # overrides the earlier value of 4; get_batch reads this global at call time
max_steps = 10000   # number of optimizer updates, one mini-batch each (not passes over the dataset)

for step in range(max_steps):
    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass
    logits, loss = bigram_model(xb, yb)

    # clear the gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)

    # backward pass
    loss.backward()

    # update params
    optimizer.step()

# loss of the final mini-batch only, so a noisy single-batch estimate
print(loss.item())

2.543416976928711


In [14]:
# sample 500 new characters, starting from a single token with id 0
context = torch.zeros((1, 1), dtype=torch.long)
generated = bigram_model.generate(idx=context, max_new_tokens=500)

# take the only sequence in the batch and turn its ids back into text
print(decode(generated[0].tolist()))


HY:
R E:
NT:
Fowote e fofirelavese dVds wour yolo ayokie amblesbatave inouther hes molofuroriYO:
TheQUDUThe chas.
F lisen tabr:
LI mus nk,
A: al l ayo cenghe's therinvar,
TEsen ithawaneit at iswinerainy atsomo clour pad d wikn h,
HYy my Tholes: it GBy ke m vilou xthazinderand llo chee lond Cld this lisesule wars, tirofof wnofan
Rou cthe p.

By hat celis ire m, aksthethe aur withAR wotoot.
Toy:me, of Ithed; bo r:
DWAy celowinoourne, llonthavelller:f fowhilong bert irw:
I m;
ADWhit hor hy t I nd, 


# Attention

In [15]:
# reseed so the toy example below is repeatable
torch.manual_seed(42)

B, T, C = 4, 8, 2  # batch, time, channels
# random toy input of shape (B, T, C)
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [16]:
# goal: every position t should hold the average of all positions up to and including itself
# i.e. xbow[b, t] = mean over i <= t of x[b, i]

# version 1: explicit loops
xbow = torch.zeros((B, T, C))  # bow = bag of words
for b in range(B):
    for t in range(T):
        # rows 0..t of sequence b have shape (t+1, C); average them over the time axis
        xbow[b, t] = x[b, :t + 1].mean(dim=0)

In [17]:
# version 2: the same running average as one matrix multiplication
# row t of the lower-triangular matrix has t+1 ones; dividing each row by its sum
# turns it into an averaging row over positions 0..t
weights = torch.tril(torch.ones(T, T))
weights /= weights.sum(dim=1, keepdim=True)
xbow2 = weights @ x  # (T, T) @ (B, T, C) ----> (B, T, C), weights broadcast over the batch
torch.allclose(xbow, xbow2)

True

In [18]:
# version 3: build the same averaging weights with a masked softmax
tril = torch.tril(torch.ones(T, T))  # ones on and below the diagonal
# start from all-zero scores and set everything above the diagonal to -inf
weights = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
# row-wise softmax: -inf entries get weight 0, the remaining equal scores share the row uniformly
weights = weights.softmax(dim=-1)
xbow3 = weights @ x
torch.allclose(xbow, xbow3)

True

In [19]:
# version 4: self-attention
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# a single self-attention head
head_size = 16
# note: x and the three projections are created in this fixed order
key = nn.Linear(C, head_size, bias=False)    # what this token has
query = nn.Linear(C, head_size, bias=False)  # what this token is interested in
value = nn.Linear(C, head_size, bias=False)  # what this token will communicate

# project every token into key, query and value space
k, q, v = key(x), query(x), value(x)   # each (B, T, 16)

# affinity between tokens = dot product of each query with every key
scores = q @ k.transpose(-2, -1)       # (B, T, 16) @ (B, 16, T) ==> (B, T, T)

# causal mask: positions above the diagonal (the future) are set to -inf,
# then each row is normalized with softmax so the future gets weight 0
tril = torch.tril(torch.ones(T, T))
weights = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# in self-attention the weighted aggregation is done over the values
out = weights @ v
out.shape

torch.Size([4, 8, 16])

In [20]:
# display the attention weights: row t is a distribution over positions 0..t, with zeros above the diagonal
weights

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3641, 0.6359, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0973, 0.7532, 0.1494, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1772, 0.4341, 0.3079, 0.0808, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4523, 0.0360, 0.2877, 0.0273, 0.1967, 0.0000, 0.0000, 0.0000],
         [0.0142, 0.1249, 0.0270, 0.3665, 0.4087, 0.0587, 0.0000, 0.0000],
         [0.1247, 0.2102, 0.0244, 0.2565, 0.1518, 0.1384, 0.0940, 0.0000],
         [0.1522, 0.0372, 0.1612, 0.1119, 0.3111, 0.1732, 0.0289, 0.0243]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7285, 0.2715, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1013, 0.1117, 0.7870, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0637, 0.1010, 0.5361, 0.2991, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1299, 0.2961, 0.0639, 0.0670, 0.4431, 0.0000, 0.0000, 0.0000],
         [0.1179, 0.292