In [5]:
# Dataset prep
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()
print(len(text))

1115393


In [6]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [9]:
chars = sorted(list(set(text)))
print(len(chars))
print(''.join(chars))

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [10]:
vocab_size = len(chars)

In [19]:
# Building the encoder and decoder
ctoi = {ch:idx for (idx, ch) in enumerate(chars)}
itoc = {idx:ch for (ch, idx) in ctoi.items()}
encode = lambda text: [ctoi[ch] for ch in text]
decode = lambda idxs: ''.join([itoc[idx] for idx in idxs])

In [21]:
decode(encode("I have a big schlong"))

'I have a big schlong'

In [24]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data[:100]

(torch.Size([1115393]),
 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
          6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
         58, 47, 64, 43, 52, 10,  0, 37, 53, 59]))

In [26]:
# Split into train and validation
n = int(0.9 * len(data)) # 90% train set and 10% validation
train_set = data[:n]
val_set = data[n:]
val_set.shape, train_set.shape

(torch.Size([111540]), torch.Size([1003853]))

In [28]:
block_size = 8
train_set[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [29]:
# Time dimension -> Predicting the next character after a sequence
x = train_set[:block_size]
y = train_set[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    pred_next = y[t]
    print(f"For {context}, we predict {pred_next}")

For tensor([18]), we predict 47
For tensor([18, 47]), we predict 56
For tensor([18, 47, 56]), we predict 57
For tensor([18, 47, 56, 57]), we predict 58
For tensor([18, 47, 56, 57, 58]), we predict 1
For tensor([18, 47, 56, 57, 58,  1]), we predict 15
For tensor([18, 47, 56, 57, 58,  1, 15]), we predict 47
For tensor([18, 47, 56, 57, 58,  1, 15, 47]), we predict 58


In [57]:
torch.manual_seed(1337)
# Batching up data acros batch and time dimensions
# How big is the context length for predicting the next character
block_size = 8
# How many sequences we are stacking together to process in parallel
batch_size = 4

def get_split_batch(split: str):
    """Sample a `batch_size` number of sequences of length `block_size` along with their
    next character predcition from the desired `split` -> `train` or `val` data"""

    # If the split is not `train` or `val`, it is invalid
    assert split in ['train', 'val']
    dataset = train_set if split == 'train' else val_set

    # Sample `batch_size` count of random indexes from the data up to the last
    # index that is possible to issue a context of 8 elements
    idxs = torch.randint(0, len(dataset) - block_size, (batch_size,))

    # For each index, the context (or the input to the model) will be the sequence
    # of eight characters starting with that index
    x = torch.stack([dataset[idx:idx+block_size] for idx in idxs])
    # And the predictions will be the exact next character following that sequence
    y = torch.stack([dataset[idx+1:idx+block_size+1] for idx in idxs])
    return (x, y)

Xb, Yb = get_split_batch('train')

for b in range(batch_size):
    for t in range(block_size):
        context = Xb[b, :t+1]
        pred = Yb[b, t]
        print(f"For {context} we are predicting {pred}")

torch.Size([4, 8])
For tensor([53]) we are predicting 59
For tensor([53, 59]) we are predicting 6
For tensor([53, 59,  6]) we are predicting 1
For tensor([53, 59,  6,  1]) we are predicting 58
For tensor([53, 59,  6,  1, 58]) we are predicting 56
For tensor([53, 59,  6,  1, 58, 56]) we are predicting 47
For tensor([53, 59,  6,  1, 58, 56, 47]) we are predicting 40
For tensor([53, 59,  6,  1, 58, 56, 47, 40]) we are predicting 59
For tensor([49]) we are predicting 43
For tensor([49, 43]) we are predicting 43
For tensor([49, 43, 43]) we are predicting 54
For tensor([49, 43, 43, 54]) we are predicting 1
For tensor([49, 43, 43, 54,  1]) we are predicting 47
For tensor([49, 43, 43, 54,  1, 47]) we are predicting 58
For tensor([49, 43, 43, 54,  1, 47, 58]) we are predicting 1
For tensor([49, 43, 43, 54,  1, 47, 58,  1]) we are predicting 58
For tensor([13]) we are predicting 52
For tensor([13, 52]) we are predicting 45
For tensor([13, 52, 45]) we are predicting 43
For tensor([13, 52, 45, 43]

In [135]:
# Setting a benchmark -> Token embedding table
from torch import nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idxs, targets=None):
        logits = self.token_embedding_table(idxs) # (B, T, C)

        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            # Torch expects that the C (channels/features) dimension is the second dimension
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss


    def generate(self, idxs, max_new_tokens):
        """Samples `max_new_tokens` next tokens from the model starting with `idxs`
        """
        for _ in range(max_new_tokens):
            # Compute the forward pass
            # Get embeddings. Because we don't have any targets, this only returns the logits
            # and no loss
            logits, _loss = self(idxs)
            # Focus only on the last time step. This becomes the (B, C) of the last T
            logits = logits[:, -1, :]
            # Softmax along the C (channels) which are the last dimension
            probs = F.softmax(logits, dim = -1) # (B, C)
            # Sample from the distribution
            pred_idx = torch.multinomial(probs, num_samples = 1) # (B, 1)
            idxs = torch.cat([idxs, pred_idx], dim=1)

        return idxs
            
            
        


model = BigramLanguageModel(vocab_size)
logits, loss = model(Xb, Yb)
print(logits.shape)
print(Yb.shape)
print(decode(bigram.generate(idxs=torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([256, 65])
torch.Size([32, 8])


By
WAn iell S: mon ou athyout bou?
Jomingerucove Fre fun ffoty hace he n
Mare,
Th yerall eed h than


In [136]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [138]:
# Training
batch_size = 32
for _ in range(10000):
    # Get a new batch
    Xb, Yb = get_split_batch('train')
    logits, loss = model(Xb, Yb)
    # Zero out the gradient such that it does not accumulate between sessions
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Perform the optimisation
    optimizer.step()

print(loss.item())

4.695070743560791


In [122]:
print(decode(model.generate(idxs=torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


AMoun, or fal.
S:
Fio th le?
II w on lounduimfo th
s th'd yont it.
stoug s d, mas
Hathay.
BENGoue ts Ourminaug pinanou'doou allilinduoca, me RF? sthive g d.
Payondsish Seve's de
TUKENClofund n:
W:
T:
TAst se'd y wimld, y worthicoureftey ad Inot s crout Gowir,
INENCAnt we athanglshis se th VO, oure y s ty ltherel oweses my mur'so hy sse: y theaisero dey an. wimead tosthetor, caofod! ciu tit:

Fleyou othis auseowersed d GARerpenatil!
SO:

I'Whe s Whedsstuss ser ishthahot hicey 'd Pit okit angous.



In [132]:
@torch.no_grad()
def eval_loss(num_iters = 100):
    """Evaluate the loss as an overage over a number of iterations for both splits"""
    out = {}
    # Put the model into evaluation mode
    model.eval()
    for split in ['train', 'val']:
        # First we start the losses at zero
        losses = torch.zeros(num_iters)
        for idx in range(num_iters):
            # Get the batch
            Xb, Yb = get_split_batch(split)
            # Do the forward pass
            logits, loss = model(Xb, Yb)
            # Average over the loss
            losses[idx] = loss.item()
        out[split] = losses.mean()
    # Put the model back in train mode
    model.train()
    return out

In [133]:
eval_loss()

{'train': tensor(4.6737), 'val': tensor(4.6776)}