In [5]:
# Dataset prep: read the whole corpus file into a single in-memory string
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()
# Corpus size in characters
print(len(text))

1115393


In [6]:
# Peek at the first 1000 characters of the corpus to see what the raw text looks like
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [9]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted order
chars = sorted(set(text))
print(len(chars))
print(''.join(chars))

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [10]:
# Number of distinct characters in the corpus (character-level vocabulary)
vocab_size = len(chars)

In [19]:
# Building the encoder and decoder
# Character -> integer id, numbered by position in the sorted `chars` list
ctoi = {ch: idx for idx, ch in enumerate(chars)}
# Integer id -> character (exact inverse of `ctoi`)
itoc = {idx: ch for ch, idx in ctoi.items()}


def encode(text):
    """Convert a string into the list of integer ids of its characters.

    Raises:
        KeyError: if `text` contains a character that is not a key of `ctoi`.
    """
    return [ctoi[ch] for ch in text]


def decode(idxs):
    """Convert an iterable of integer ids back into the string they spell.

    Raises:
        KeyError: if an id is not a key of `itoc`.
    """
    return ''.join(itoc[idx] for idx in idxs)

In [21]:
# Round-trip sanity check: decoding an encoded string should return the input unchanged
decode(encode("I have a big schlong"))

'I have a big schlong'

In [24]:
import torch
# Encode the entire corpus once and hold it as a tensor of integer ids (dtype torch.long)
data = torch.tensor(encode(text), dtype=torch.long)
# Display the tensor's shape together with its first 100 ids
data.shape, data[:100]

(torch.Size([1115393]),
 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
          6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
         58, 47, 64, 43, 52, 10,  0, 37, 53, 59]))

In [26]:
# Split into train and validation: the first 90% of the ids are used for
# training and the remaining 10% are held out for validation
n = int(0.9 * len(data))
train_set, val_set = data[:n], data[n:]
val_set.shape, train_set.shape

(torch.Size([111540]), torch.Size([1003853]))

In [28]:
# Context length: how many ids make up one input window
block_size = 8
# Show block_size + 1 consecutive ids from the start of the training data
train_set[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [29]:
# Time dimension -> one window of block_size + 1 ids contains block_size
# training examples: every prefix of `x` is a context whose target is the
# id that immediately follows it
x = train_set[:block_size]
y = train_set[1:block_size+1]
for t in range(block_size):
    context, pred_next = x[:t+1], y[t]
    print(f"For {context}, we predict {pred_next}")

For tensor([18]), we predict 47
For tensor([18, 47]), we predict 56
For tensor([18, 47, 56]), we predict 57
For tensor([18, 47, 56, 57]), we predict 58
For tensor([18, 47, 56, 57, 58]), we predict 1
For tensor([18, 47, 56, 57, 58,  1]), we predict 15
For tensor([18, 47, 56, 57, 58,  1, 15]), we predict 47
For tensor([18, 47, 56, 57, 58,  1, 15, 47]), we predict 58


In [38]:
# Seed torch's RNG so the randomly sampled batches below are reproducible
torch.manual_seed(1337)
# Batching up data across batch and time dimensions
# How big is the context length for predicting the next character
block_size = 8
# How many sequences we are stacking together to process in parallel
batch_size = 4

def get_split_batch(split: str):
    """Draw one random mini-batch from the `train` or `val` split.

    Returns a tuple `(x, y)` of stacked tensors, one row per sampled sequence
    (`batch_size` rows of `block_size` ids each). `x` holds the context
    windows; `y` holds the same windows shifted one position to the right,
    i.e. the next-character target for every prefix of the matching row of `x`.

    Raises AssertionError if `split` is neither `'train'` nor `'val'`.
    """
    # Only the two known splits are accepted
    assert split in ['train', 'val']
    dataset = val_set if split != 'train' else train_set

    def window(start):
        # `block_size` consecutive ids beginning at `start`
        return dataset[start:start + block_size]

    # Random window starts; the exclusive upper bound keeps the shifted
    # target window (start + 1) inside the data as well
    starts = torch.randint(0, len(dataset) - block_size, (batch_size,))

    # Inputs are the windows themselves; targets are the windows one step later
    x = torch.stack([window(start) for start in starts])
    y = torch.stack([window(start + 1) for start in starts])
    return (x, y)

# Sample one training batch: Xb are the contexts, Yb the shifted targets
Xb, Yb = get_split_batch('train')

# Spell out every individual training example packed into the batch:
# for each sequence b and each position t, the prefix up to t predicts Yb[b, t]
for b in range(batch_size):
    for t in range(block_size):
        context, pred = Xb[b, :t+1], Yb[b, t]
        print(f"For {context} we are predicting {pred}")

For tensor([53]) we are predicting 59
For tensor([53, 59]) we are predicting 6
For tensor([53, 59,  6]) we are predicting 1
For tensor([53, 59,  6,  1]) we are predicting 58
For tensor([53, 59,  6,  1, 58]) we are predicting 56
For tensor([53, 59,  6,  1, 58, 56]) we are predicting 47
For tensor([53, 59,  6,  1, 58, 56, 47]) we are predicting 40
For tensor([53, 59,  6,  1, 58, 56, 47, 40]) we are predicting 59
For tensor([49]) we are predicting 43
For tensor([49, 43]) we are predicting 43
For tensor([49, 43, 43]) we are predicting 54
For tensor([49, 43, 43, 54]) we are predicting 1
For tensor([49, 43, 43, 54,  1]) we are predicting 47
For tensor([49, 43, 43, 54,  1, 47]) we are predicting 58
For tensor([49, 43, 43, 54,  1, 47, 58]) we are predicting 1
For tensor([49, 43, 43, 54,  1, 47, 58,  1]) we are predicting 58
For tensor([13]) we are predicting 52
For tensor([13, 52]) we are predicting 45
For tensor([13, 52, 45]) we are predicting 43
For tensor([13, 52, 45, 43]) we are predicting

In [50]:
# Setting a benchmark -> Token embedding table
from torch import nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Baseline bigram model backed by a single embedding table.

    The table is square, `vocab_size` x `vocab_size`: looking up token id `i`
    returns a row of `vocab_size` values that are used directly as logits.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # One row of `vocab_size` logits per possible input token id
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idxs):
        """Look up the logits row for every token id in `idxs`.

        For an index tensor of shape (B, T) the result has shape (B, T, C),
        where C is the `vocab_size` the table was built with.
        """
        return self.token_embedding_table(idxs)

# Instantiate the untrained baseline and run the sampled batch through it;
# the embedding weights are still at their initial values, so these logits are not yet meaningful
bigram = BigramLanguageModel(vocab_size)
logits = bigram(Xb)
print(logits)

tensor([[[-0.9817,  0.4279,  1.1933,  ..., -0.4562, -0.5076,  0.3377],
         [-0.9113,  1.1846,  0.7505,  ..., -1.4482,  0.1478,  1.3366],
         [-0.6076,  0.9987,  0.8513,  ...,  0.6869, -1.2106,  0.4345],
         ...,
         [-0.6290,  0.2287, -0.3709,  ..., -1.4395, -0.3997,  2.0502],
         [-1.5302, -0.8226, -0.5341,  ..., -0.5800,  1.3923,  0.0433],
         [ 0.1733, -1.9099,  1.1536,  ..., -0.8480, -1.7542,  0.1826]],

        [[ 0.5845, -1.3075,  0.3588,  ...,  0.9893,  2.4171, -1.0564],
         [-1.2205, -0.4234,  0.4442,  ..., -1.0289, -0.6227, -0.2439],
         [-1.2205, -0.4234,  0.4442,  ..., -1.0289, -0.6227, -0.2439],
         ...,
         [-1.5302, -0.8226, -0.5341,  ..., -0.5800,  1.3923,  0.0433],
         [-1.0584, -0.7604,  0.4590,  ...,  0.4958, -0.7604, -0.2948],
         [ 0.6115,  0.5208,  1.3054,  ...,  0.5127, -0.7908,  1.7887]],

        [[-1.0718,  0.5502, -0.2290,  ...,  1.5064,  0.2363,  0.4166],
         [-0.5585, -0.7846, -1.3164,  ..., -2