In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import string
import random
import matplotlib.pyplot as plt
import tqdm
%matplotlib inline

In [None]:
# Load the raw corpus. The context manager closes the file handle as soon as
# the contents have been read instead of leaving that to garbage collection.
with open("data/wiki.txt", "r") as corpus_file:
    text = corpus_file.read()
# Lower-casing folds upper- and lower-case letters into the same characters.
text = text.lower()
# Uncomment to iterate quickly on a short prefix of the corpus.
# text = text[:1000]

print(f"Input has size: {len(text)}")
print()
print(text[:1000])

In [None]:
# Character vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))

# Character -> integer id (its position in `chars`), plus the inverse lookup.
stoi = {ch: idx for idx, ch in enumerate(chars)}
vocab_size = len(stoi)
itos = {idx: ch for ch, idx in stoi.items()}

print(stoi)
print(itos)

def encode(s):
    """Map each character of `s` to its integer id via the `stoi` table.

    Returns a list with one int per character. A character that is not a key
    of `stoi` raises KeyError.
    """
    return list(map(stoi.__getitem__, s))

def decode(d):
    """Turn an iterable of integer ids back into a string via the `itos` table.

    An id that is not a key of `itos` raises KeyError.
    """
    return "".join(itos[token_id] for token_id in d)

In [None]:
# Turn the whole corpus into a 1-D tensor of int64 token ids, then show its
# shape, a blank separator line, and the first 1000 ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, "", data[:1000], sep="\n")

In [None]:
# The first 90% of the token stream (rounded down) is used for training and
# the remainder is held out for validation.
train_size = int(0.9 * len(data))
train_data, val_data = data[:train_size], data[train_size:]
print(f"Train size: {len(train_data)}", f"Validation size: {len(val_data)}", sep="\n")

In [None]:
context_length = 8  # number of tokens in each sampled window
batch_size = 32  # number of windows sampled per batch


def get_batch(split):
    """Sample a random batch of (input, target) windows from one split.

    `split == "train"` reads from `train_data`; any other value reads from
    `val_data`. Returns `(x, y)`, each of shape (batch_size, context_length),
    where `y` holds the tokens one position after the corresponding `x` tokens.
    """
    source = train_data if split == "train" else val_data
    # One random window start per example. The exclusive upper bound leaves
    # room for the extra target token that follows the end of the window.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    # Broadcast the starts against in-window offsets so that every window is
    # gathered with a single indexing operation.
    positions = starts.unsqueeze(1) + torch.arange(context_length)
    return source[positions], source[positions + 1]

def print_batch_example():
    """Print one training batch and every (context -> target) pair it holds."""
    xb, yb = get_batch("train")
    for batch_tensor in (xb, yb):
        print(batch_tensor.shape)
        print(batch_tensor)

    # Each row yields context_length examples: the prefix ending at position
    # `pos` is the context, and the target at `pos` is the token to predict.
    for row in range(batch_size):
        inputs, labels = xb[row], yb[row]
        for pos in range(context_length):
            print(f"{inputs[:pos + 1]} -> {labels[pos]}")


print_batch_example()

In [None]:
class Model(nn.Module):
    """Bigram language model backed by a single embedding table.

    The logits for the next token are looked up from the current token alone:
    the embedding has `vocab_size` rows of `vocab_size` scores each, and no
    other layer mixes information between positions.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the unnormalised next-token scores given current token i.
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, context, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            context: integer token ids of shape (batch_size, context_length).
            targets: optional integer token ids with the same number of
                elements as `context`.

        Returns:
            `(logits, loss)`. Without targets, `logits` has shape
            (batch_size, context_length, vocab_size) and `loss` is None.
            With targets, `logits` is returned flattened to
            (batch_size * context_length, vocab_size) and `loss` is the mean
            cross-entropy over all positions.
        """
        logits = self.embedding(context)  # (batch_size, context_length, vocab_size)

        if targets is None:
            return logits, None

        num_rows, num_positions, num_classes = logits.shape
        # F.cross_entropy is called with (N, C) logits and (N,) class indices,
        # so batch and position are folded into one axis. reshape is used
        # instead of view so that non-contiguous targets (e.g. a transposed
        # tensor) are accepted as well.
        logits = logits.reshape(num_rows * num_positions, num_classes)
        targets = targets.reshape(num_rows * num_positions)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip autograd tracking
    def generate(self, context, num_new_tokens):
        """Extend `context` by sampling `num_new_tokens` tokens one at a time.

        Args:
            context: integer token ids of shape (batch_size, current_length).
            num_new_tokens: number of tokens to append to every row.

        Returns:
            Tensor of shape (batch_size, current_length + num_new_tokens)
            holding the original context followed by the sampled tokens.
        """
        for _ in range(num_new_tokens):
            logits, _ = self(context)
            # Only the last position predicts the token that follows the context.
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            # Sample one token id per row from the predicted distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            context = torch.cat([context, next_token], dim=1)
        return context
            
            
    
# Smoke test: push one training batch through a freshly initialised model and
# show the shape of the returned logits followed by the loss.
m = Model(vocab_size)
xb, yb = get_batch("train")
logits, loss = m(xb, yb)
print(logits.shape, loss, sep="\n")

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of model `m` on the train and validation splits.

    The loss is averaged over 100 freshly sampled batches per split while the
    model is in eval mode; the model is switched to train mode afterwards.
    Returns `(train_loss, val_loss)` as Python floats.
    """
    num_batches = 100

    def mean_loss(split):
        # The model returns (logits, loss); only the loss is kept.
        batch_losses = [m(*get_batch(split))[1] for _ in range(num_batches)]
        return torch.stack(batch_losses).mean().item()

    m.eval()
    train_loss = mean_loss("train")
    val_loss = mean_loss("val")
    m.train()

    return train_loss, val_loss


estimate_loss()

In [None]:
# Train model
num_iterations = 100_000
learning_rate = 1e-3
eval_interval = 1000  # iterations between two loss estimates

train_losses = []
val_losses = []
optimizer = torch.optim.Adam(m.parameters(), lr=learning_rate)

for iteration in (pbar := tqdm.tqdm(range(num_iterations))):
    if iteration % eval_interval == 0:
        # Record averaged train/validation losses for the progress bar and
        # for the plot drawn after training.
        train_loss, val_loss = estimate_loss()
        pbar.set_description(f"T: {train_loss}, V: {val_loss}")
        train_losses.append(train_loss)
        val_losses.append(val_loss)

    # Sample batch of data
    xb, yb = get_batch("train")

    # Forward pass
    logits, loss = m(xb, yb)

    # Backwards pass: clear stale gradients, backpropagate, update parameters.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# One point per evaluation, i.e. one point every `eval_interval` iterations.
plt.figure()
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="validation")
plt.xlabel(f"evaluation (every {eval_interval} iterations)")
plt.ylabel("loss")
plt.legend()
plt.show()

In [None]:
# Sample from the model: start from a single token with id 0 and append 1000
# sampled tokens, then decode the one generated row back into text.
raw_output = m.generate(torch.zeros((1, 1), dtype=torch.long), num_new_tokens=1000)
print(decode(raw_output.squeeze(0).tolist()))