In [None]:
# Run this block
import torch
from typing import List
from torch.nn import functional as F
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import requests

In [None]:
# Character-level vocabulary: lowercase letters, uppercase letters, then
# punctuation. NOTE: the double-quote character is listed twice (positions 56
# and 57), so len(vocab) is 64 although only 63 distinct characters exist.
# The duplicate is kept because len(vocab) fixes the one-hot width and the
# model's input/output size.
vocab = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    ".,"
    " '"
    '""'
    "()[]!?"
)

def tokenize(text: str) -> List[str]:
    """Split ``text`` into single-character tokens.

    Characters that do not occur in ``vocab`` are silently dropped; the
    relative order of the remaining characters is preserved.
    """
    allowed = frozenset(vocab)
    tokens: List[str] = []
    for symbol in text:
        if symbol in allowed:
            tokens.append(symbol)
    return tokens

# Encoding table: character -> class index. When a character occurs more than
# once in `vocab`, the dict comprehension keeps its LAST position.
char_to_index = {char: idx for idx, char in enumerate(vocab)}
# Decoding table: class index -> character. It is built straight from the
# positions in `vocab` (not by inverting `char_to_index`) so that every index
# in range(len(vocab)) can be decoded, even when `vocab` holds a repeated
# character; inverting the encoding table would leave such indices missing
# and make decoding raise KeyError.
index_to_char = dict(enumerate(vocab))

def vectorize(tokens: List[str]) -> torch.Tensor:
    """One-hot encode a list of character tokens.

    Args:
        tokens: Characters to encode; each must be a key of ``char_to_index``.

    Returns:
        A float tensor of shape ``(len(tokens), len(vocab))`` with exactly one
        1.0 per row. An empty list yields a tensor of shape ``(0, len(vocab))``.

    Raises:
        KeyError: If a token is not present in ``char_to_index``.
    """
    # Force an integer dtype: torch.tensor([]) defaults to float32, which
    # F.one_hot rejects, so an empty token list used to crash here.
    indices = torch.tensor([char_to_index[char] for char in tokens], dtype=torch.long)
    return F.one_hot(indices, num_classes=len(vocab)).float()

def detokenize(tensor: torch.Tensor) -> str:
    """Decode one-hot (or score) vectors back into a string.

    Args:
        tensor: Either a single vector over the vocabulary or a 2-D tensor
            with one such vector per row. The largest entry along the last
            axis selects the character.

    Returns:
        The decoded characters joined into one string (a single character
        for a 1-D input).
    """
    indices = tensor.argmax(dim=-1)
    # A 1-D input reduces to a 0-d tensor whose .tolist() is a bare int, which
    # cannot be iterated; promote it to a one-element sequence first.
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)
    return ''.join(index_to_char[idx] for idx in indices.tolist())

class EmbeddingProjection(nn.Module):
    """Bias-free linear map from vocabulary space to embedding space.

    Applied to a one-hot row, the projection selects the corresponding column
    of the weight matrix, so it acts as an embedding lookup.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Create the projection with ``vocab_size`` inputs and ``embedding_dim`` outputs."""
        super().__init__()
        self.projection = nn.Linear(
            in_features=vocab_size,
            out_features=embedding_dim,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` (last dimension ``vocab_size``) to ``embedding_dim`` features."""
        embedded = self.projection(x)
        return embedded

class TextDataset(Dataset):
    """Sliding-window character dataset for next-token prediction.

    Item ``idx`` is a pair ``(input, target)``: ``input`` encodes the
    ``seq_length`` tokens starting at position ``idx`` and ``target`` encodes
    the same window shifted one token to the right. Both are one-hot float
    tensors produced by ``vectorize``.
    """

    def __init__(self, text, seq_length):
        """Tokenize ``text`` and remember the window size.

        Args:
            text: Raw text; characters outside the vocabulary are dropped by
                ``tokenize``.
            seq_length: Number of tokens per window; must be at least 1.

        Raises:
            ValueError: If ``seq_length`` is smaller than 1.
        """
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        self.text = text
        self.seq_length = seq_length
        self.tokens = tokenize(text)

    def __len__(self):
        """Number of full windows that still have a next token to predict."""
        # Clamp at zero: a text shorter than seq_length would otherwise give a
        # negative value, which makes len() itself raise ValueError.
        return max(0, len(self.tokens) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for window ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``range(-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        # Without this check an out-of-range index produced truncated or empty
        # windows and failed deep inside vectorize; raising IndexError also
        # lets plain `for item in dataset` iteration terminate cleanly.
        if not 0 <= idx < size:
            raise IndexError(f"index out of range for TextDataset of length {size}")
        input_seq = self.tokens[idx:idx+self.seq_length]
        target_seq = self.tokens[idx+1:idx+self.seq_length+1]
        # squeeze() drops the sequence axis when seq_length == 1, so each item
        # is a flat (vocab_size,) vector in that case.
        return vectorize(input_seq).squeeze(), vectorize(target_seq).squeeze()

In [None]:
# Download toy data (Shakespeare sonnets)
url = "https://www.gutenberg.org/files/1041/1041-0.txt"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() stops an HTTP error page from being used as training text.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Keep only the text between the title and the Project Gutenberg footer.
start_marker = "THE SONNETS"
end_marker = "End of the Project Gutenberg EBook"
raw_text = response.text
if start_marker not in raw_text:
    raise ValueError(f"Start marker {start_marker!r} not found in the text downloaded from {url}")
text = raw_text.split(start_marker, 1)[1].split(end_marker, 1)[0]

# Prepare the dataset
seq_length = 1
dataset = TextDataset(text, seq_length)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

batch, target = next(iter(dataloader))
# batch is the input tensor to your model, shape (batch_size, vocab_size)
# It's the vector representation of the single token your bigram model has as context.
# target is the target tensor, shape (batch_size, vocab_size), representing the next token in the sequence (which your model is tasked with predicting).
print(batch.shape, target.shape)

torch.Size([32, 64]) torch.Size([32, 64])


In [None]:
# Decode the first six context characters of the sample batch and print each
# one next to the character the model is expected to predict after it.
detokenized_targets = detokenize(target)
detokenized_contexts = detokenize(batch[:6])
for context_char, target_char in zip(detokenized_contexts, detokenized_targets):
    print(f"Context: {context_char}, Target: {target_char}")

# Seems like a tough task, eh?

Context:  , Target: m
Context: g, Target:  
Context: a, Target: v
Context:  , Target: t
Context:  , Target: t
Context: ,, Target:  


In [None]:
# Exercise 1:
# Implement a multilayer linear model. Feel free to use nn.Linear and nn.ReLU.

# Your projection layer is a linear projection from vocab size -> model size. Make sure your intermediate linear layers are projections from model size -> model size,
# and your final layer is a projection from model size -> vocab size.
class BigramModel(nn.Module):
    """Feed-forward next-character predictor conditioned on a single token.

    Pipeline: embedding projection (vocab_size -> model_dim), two hidden
    linear layers (model_dim -> model_dim) each followed by ReLU, and an
    output layer (model_dim -> vocab_size) with no activation applied.
    """

    def __init__(self, model_dim = 128, vocab_size = len(vocab)):
        """Build the layers.

        Args:
            model_dim: Width of the embedding and of the hidden layers.
            vocab_size: Size of the one-hot input and of the output scores.
        """
        super().__init__()
        self.projection = EmbeddingProjection(vocab_size, model_dim)
        self.layer1 = nn.Linear(model_dim, model_dim)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(model_dim, model_dim)
        self.layer3 = nn.Linear(model_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch_size, vocab_size) input to (batch_size, vocab_size) scores.

        The input is embedded first and then passed through the linear layers.
        """
        hidden = self.projection(x)
        for hidden_layer in (self.layer1, self.layer2):
            hidden = self.relu(hidden_layer(hidden))
        return self.layer3(hidden)

def test_bigram_model():
    """Smoke test: a fresh BigramModel must map the sample batch to the target's shape."""
    expected_shape = target.shape
    actual_shape = BigramModel()(batch).shape
    assert actual_shape == expected_shape, f"Expected output shape {expected_shape} but got {actual_shape}"
    print("Success!")

test_bigram_model()

Success!


In [None]:
from tqdm import tqdm, trange

# Training loop
num_epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model lives on `device`; AdamW is used with its default parameters.
# The targets are one-hot float vectors, which nn.CrossEntropyLoss accepts as
# per-class probabilities.
model = BigramModel().to(device)
optimizer = optim.AdamW(model.parameters())
criterion = nn.CrossEntropyLoss()

model.train()
loss_ema = None
for epoch in range(num_epochs):
    with tqdm(dataloader) as pbar:
        # Use dedicated loop names: reusing `batch`/`target` would overwrite
        # the sample batch that earlier cells (e.g. test_bigram_model) rely on
        # and leave it on `device` with a possibly smaller final-batch size.
        for inputs, targets in pbar:
            # ------------------
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # ------------------
            # Read the scalar once; every .item() call forces a device sync.
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss_value
            # Show the smoothed loss: the per-batch value is too noisy to read
            # a trend from, and the EMA was previously computed but never used.
            pbar.set_description(f"Loss (EMA): {loss_ema:.3f}")

Loss: 1.993: 100%|██████████| 2888/2888 [00:20<00:00, 139.77it/s]
Loss: 2.305: 100%|██████████| 2888/2888 [00:21<00:00, 136.54it/s]


In [None]:
# Generate some text
model.eval()
start_text = "Shall I compare thee to a summer's day?"
# vectorize() already returns a tensor, so it is used directly (re-wrapping it
# in torch.tensor() only copies it and triggers a UserWarning). The [-1:] slice
# keeps just the last prompt token as a (1, vocab_size) batch, since the
# bigram model conditions on a single character.
input_seq = vectorize(tokenize(start_text))[-1:].to(device)
generated_text = start_text

with torch.no_grad():
    for _ in range(100):
        output = model(input_seq)
        # Greedy decoding: always take the highest-scoring character. This is
        # deterministic, so the output settles into a repeating cycle.
        next_char = output.argmax(dim=-1)
        generated_text += index_to_char[next_char.item()]
        # Feed the prediction back in as the next one-hot context vector.
        input_seq = F.one_hot(next_char, num_classes=len(vocab)).float()

print("Generated text:")
print(generated_text)

  input_seq = torch.tensor(vectorize(tokenize(start_text))).unsqueeze(0).to(device)[:, -1, :]


Generated text:
Shall I compare thee to a summer's day? the the the the the the the the the the the the the the the the the the the the the the the the the
