# Solution: Tiny Neural Language Model (Fast & Sensible)

In [1]:
!pip -q install datasets transformers
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

  from .autonotebook import tqdm as notebook_tqdm


'cpu'

In [2]:
ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train[:5000]')
texts = [t for t in ds['text'] if len(t.strip()) > 0]

split = int(0.9 * len(texts))
train_texts = texts[:split]
val_texts = texts[split:]

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
len(train_texts), len(val_texts)

(2904, 323)

In [3]:
def build_lm_blocks(texts, tokenizer, block_size=64):
    all_ids = []
    for txt in texts:
        all_ids.extend(tokenizer.encode(txt))

    n_full = len(all_ids) // block_size
    all_ids = all_ids[: n_full * block_size]
    return torch.tensor(all_ids).view(n_full, block_size)

train_blocks = build_lm_blocks(train_texts, tokenizer)
val_blocks = build_lm_blocks(val_texts, tokenizer)
train_blocks.shape, val_blocks.shape

(torch.Size([4549, 64]), torch.Size([379, 64]))

In [4]:
train_loader = DataLoader(train_blocks, batch_size=32, shuffle=True)
val_loader = DataLoader(val_blocks, batch_size=32)

In [5]:
class TinyGRULM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        h, _ = self.rnn(x)
        return self.fc(h)

model = TinyGRULM(len(tokenizer)).to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_one_epoch(model, loader):
    model.train()
    total_loss = 0
    for block in loader:
        block = block.to(device)
        inputs = block[:, :-1]
        labels = block[:, 1:]

        logits = model(inputs)
        V = logits.size(-1)
        loss = criterion(logits.reshape(-1, V), labels.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total_loss = 0
    for block in loader:
        block = block.to(device)
        inputs = block[:, :-1]
        labels = block[:, 1:]
        logits = model(inputs)
        V = logits.size(-1)
        loss = criterion(logits.reshape(-1, V), labels.reshape(-1))
        total_loss += loss.item()
    return total_loss / len(loader)

In [7]:
for epoch in range(1, 6):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)
    print(f"Epoch {epoch} | train ppl {math.exp(train_loss):.2f} | val ppl {math.exp(val_loss):.2f}")

Epoch 1 | train ppl 2027.39 | val ppl 1558.55
Epoch 2 | train ppl 714.85 | val ppl 1301.07
Epoch 3 | train ppl 458.69 | val ppl 1159.74
Epoch 4 | train ppl 307.17 | val ppl 1091.62
Epoch 5 | train ppl 215.18 | val ppl 1069.23


In [12]:
@torch.no_grad()
def topk_next_tokens(model, tokenizer, prompt, k=5):
    model.eval()
    enc = tokenizer(prompt, return_tensors='pt')
    input_ids = enc['input_ids'].to(device)
    logits = model(input_ids)
    last_logits = logits[0, -1]
    topk = torch.topk(last_logits, k=k)
    return [(tokenizer.decode([i]), float(s))
            for i, s in zip(topk.indices.tolist(), topk.values.tolist())]

prompt = "Sheffield is the"
topk_next_tokens(model, tokenizer, prompt)

[(' first', 6.637938499450684),
 (' second', 5.3817219734191895),
 (' album', 5.066890716552734),
 (' British', 4.901472568511963),
 (' only', 4.825827121734619)]