In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, IterableDataset
from datasets import load_dataset
import nltk
import re
from itertools import islice
import random
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoTokenizer

from gpt.decoder import DecoderOnlyTransformer
from gpt.position_encoder import PositionalEncoding

# Fetch NLTK's "punkt_tab" sentence-tokenizer data, which nltk.sent_tokenize
# (used by WikiSentenceDataset below) needs. Already-downloaded data is not
# re-fetched; the call's return value is what this cell displays (True).
nltk.download('punkt_tab')


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/danieljoo/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

# PyTorch Implementation of GPT-2

Here, I recreate the decoder-only transformer architecture of GPT and train it on Wikipedia.

https://docs.pytorch.org/tutorials/beginner/basics/intro.html \
https://www.youtube.com/watch?v=bQ5BoolX9Ag 

<img src ="https://miro.medium.com/v2/resize:fit:1400/1*qTjjAvXmrSaRdN1LODLVGA.png" width="400" height="300">

## Model Declaration

The decoder block and positional encoder are defined in separate modules (`gpt.decoder` and `gpt.position_encoder`) and imported above.

In [None]:
embed_dim = 100
max_len = 50
num_transformers = 6
num_heads = 5
dense_dim = 256
PAD_TOKEN_ID = 0

class NeuralNetwork(nn.Module):
    """GPT-style decoder-only language model built from the `gpt` modules.

    Pipeline: token embedding -> positional encoding -> ``num_transformers``
    stacked ``DecoderOnlyTransformer`` blocks -> final ``LayerNorm`` -> linear
    projection to one logit per vocabulary entry.

    Args:
        vocab_size: Number of token ids covered by the embedding table and the
            output projection.
        embed_dim: Width of token embeddings and hidden states.
        num_transformers: Number of stacked decoder blocks.
        num_heads: Attention heads per decoder block.
        dense_dim: Passed to each ``DecoderOnlyTransformer``; presumably the
            hidden width of its feed-forward sublayer -- confirm in gpt.decoder.
        pad_token_id: Id treated as padding. Its embedding row receives no
            gradient (``nn.Embedding(padding_idx=...)``) and its positions are
            passed to every block as ``key_padding_mask``.
        max_len: Maximum sequence length handed to ``PositionalEncoding``;
            defaults to the module-level ``max_len``.
    """

    def __init__(
        self,
        vocab_size,
        embed_dim = embed_dim,
        num_transformers = num_transformers,
        num_heads = num_heads,
        dense_dim = dense_dim,
        pad_token_id = PAD_TOKEN_ID,
        max_len = max_len):
        super().__init__()
        # Remembered so forward() masks the id this instance was built with,
        # not the module-level PAD_TOKEN_ID (they differ if a caller overrides it).
        self.pad_token_id = pad_token_id
        self.token_embed = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = embed_dim,
            padding_idx = pad_token_id,
        )
        self.position_encoding = PositionalEncoding(
            embed_dim = embed_dim,
            max_len = max_len,
        )
        self.transformer_stack = nn.ModuleList([
            DecoderOnlyTransformer(
                embed_dim = embed_dim,
                num_heads = num_heads,
                dense_dim = dense_dim,
            ) for _ in range(num_transformers)])
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(
            in_features = embed_dim,
            out_features = vocab_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for a batch of token ids.

        Args:
            x: Integer token ids; the training loop passes ``(batch, seq)``.

        Returns:
            Unnormalised logits with a trailing ``vocab_size`` dimension.
            Softmax is left to the loss (``nn.CrossEntropyLoss`` takes raw
            logits), which is more numerically stable.
        """
        # True where x is padding; handed to every block as key_padding_mask.
        # NOTE(review): no causal mask is built here -- DecoderOnlyTransformer
        # presumably applies one internally; confirm in gpt.decoder.
        key_padding_mask = (x == self.pad_token_id)
        x = self.token_embed(x)
        x = self.position_encoding(x)
        for transformer in self.transformer_stack:
            x = transformer(x, key_padding_mask = key_padding_mask)
        x = self.layer_norm(x)
        return self.linear(x)

## Data Loading and Tokenization

In [2]:
# Stream the English Wikipedia config "20231101.en" rather than downloading it,
# then shuffle articles through a small (100-article) shuffle buffer.
wiki_dataset = load_dataset(
    "wikimedia/wikipedia",
    "20231101.en",
    split="train",
    streaming=True,
)
shuffled_stream = wiki_dataset.shuffle(buffer_size=100, seed=42)

In [None]:
TRAIN_SIZE = 1_000  # Wikipedia articles taken from the shuffled stream for training
TEST_SIZE = 10      # following articles held out for validation
batch_size = 64     # sentences per batch in the training/validation loaders

def clean_wiki_text(text):
    """Strip ``== Heading ==`` markers and collapse all whitespace.

    Returns the text with every whitespace run (newlines included) replaced
    by a single space and no leading/trailing whitespace.
    """
    # A heading is 2+ '=' ... 2+ '=' on one line ('.' does not cross newlines).
    without_headings = re.sub(r'={2,}.*?={2,}', '', text)
    # str.split() with no argument splits on any whitespace run, newlines
    # included, so it also performs the newline -> space normalisation.
    return ' '.join(without_headings.split())

class WikiSentenceDataset(IterableDataset):
    """Turn a stream of Wikipedia articles into a stream of sentences.

    Each article's ``'text'`` field is passed through ``clean_wiki_text``;
    articles that clean to an empty string are skipped, and the rest are
    split into sentences with ``nltk.sent_tokenize``.
    """

    def __init__(self, hf_dataset):
        super().__init__()
        # Any iterable of dicts with a 'text' key (here a Hugging Face stream).
        self.hf_dataset = hf_dataset

    def __iter__(self):
        for article in self.hf_dataset:
            text = clean_wiki_text(article['text'])
            if not text:
                continue
            yield from nltk.sent_tokenize(text)
                    
class ShufflingIterableDataset(IterableDataset):
    """Approximately shuffle an iterable with a fixed-size shuffle buffer.

    The first ``buffer_size`` items fill a buffer. Every further source item
    evicts (and yields) a randomly chosen buffered item and takes its slot.
    Once the source is exhausted, the remaining buffer is shuffled and flushed.
    Larger buffers give a closer-to-uniform shuffle at the cost of memory.

    Args:
        source_dataset: Any iterable (here a ``WikiSentenceDataset``).
        buffer_size: Number of items held in memory. ``0`` disables shuffling
            and passes items through in source order.
        seed: Seed for the ``random.Random`` created at the start of every
            iteration. Because the RNG is re-created with the same seed, a
            source that yields the same sequence each time produces the same
            order on every pass (e.g. every training epoch).
    """

    def __init__(self, source_dataset, buffer_size, seed):
        super().__init__()
        self.source_dataset = source_dataset
        self.buffer_size = buffer_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        source_iterator = iter(self.source_dataset)
        shuffle_buffer = list(islice(source_iterator, self.buffer_size))
        if not shuffle_buffer:
            # buffer_size == 0 (or an empty source): there is nothing to swap
            # with, so pass remaining items through instead of drawing from an
            # empty range (which raised ValueError before).
            yield from source_iterator
            return
        # Steady state: each new item replaces a random buffered item, which is yielded.
        for item in source_iterator:
            idx = rng.randrange(len(shuffle_buffer))
            yield shuffle_buffer[idx]
            shuffle_buffer[idx] = item
        # Source exhausted: flush whatever remains, in random order.
        rng.shuffle(shuffle_buffer)
        yield from shuffle_buffer
            
            
# Article-level split of the shuffled stream: the first TRAIN_SIZE articles are
# for training, the next TEST_SIZE for validation. Each split is then exploded
# into sentences and re-shuffled at the sentence level.
train_stream = WikiSentenceDataset(shuffled_stream.take(TRAIN_SIZE))
test_stream = WikiSentenceDataset(shuffled_stream.skip(TRAIN_SIZE).take(TEST_SIZE))

train_dataset, test_dataset = (
    ShufflingIterableDataset(sentences, buffer_size=1000, seed=42)
    for sentences in (train_stream, test_stream)
)

# Sanity check: print the first three raw batches. Without a collate_fn each
# batch stays a list of sentence strings.
data_loader_plain = DataLoader(train_dataset, batch_size=4)

for i, batch in enumerate(islice(data_loader_plain, 3)):
    print(f"Sentences from batch {i}:")
    for j, sentence in enumerate(batch):
        print(f"  - Item {j}: '{sentence}'")
            

Sentences from batch 0:
  - Item 0: '() is a German indie-pop band from Berlin that was founded in 2006 under the name Fluchtweg.'
  - Item 1: 'In 2021, his second wife accused him of raping her.'
  - Item 2: 'See also Roads in Ireland National primary road National secondary road References Regional roads in the Republic of Ireland Roads in County Louth'
  - Item 3: 'On February 18, 2016, a detailed agreement was signed with a German partner, which specified the terms of cooperation between ZM Bumar-Łabędy and Rheinmetall.'
  - Item 4: 'The modernisation is currently being carried out in cooperation with Rheinmetall and the Polish Armaments Group ().'
  - Item 5: 'The book was positively received by African-American and communist media of the time.'
  - Item 6: 'She was the only child of the Inland Revenue tax officer Thomas John Tindale and his wife Princess May, née Uttin.'
  - Item 7: 'Background Angela Davis is a Marxist feminist author born in Alabama, United States, in 1944.'
Se

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# len(tokenizer) also counts tokens added on top of the base vocabulary
# (tokenizer.vocab_size does not), so every id the tokenizer can emit fits
# the model's embedding table and output layer.
vocab_size = len(tokenizer)
# The model's key-padding mask and the training loss both treat PAD_TOKEN_ID
# (0) as padding; fail fast if a different checkpoint pads with another id.
assert tokenizer.pad_token_id == PAD_TOKEN_ID, (
    f"tokenizer pads with id {tokenizer.pad_token_id}, model expects {PAD_TOKEN_ID}"
)


def collate_fn(sentences):
    """Tokenize a list of sentences into one fixed-width batch.

    Every row is padded or truncated to exactly ``max_len`` tokens, and the
    tokenizer output (including ``input_ids``) is returned as PyTorch tensors.
    """
    return tokenizer(
        sentences,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )


train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)

torch.Size([64, 50])


## Training Loop

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Seed weight initialisation so runs are comparable (the data shuffles above
# are already seeded with 42).
torch.manual_seed(42)
model = NeuralNetwork(vocab_size=vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Padding targets carry no signal; ignore the same id the model masks as padding.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)

epochs = 20


def _run_epoch(model, loader, loss_fn, device, optimizer=None, desc="Training"):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With ``optimizer`` the model is put in train mode and updated after every
    batch; without it the model is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if ``loader`` yields no batches (instead of dividing by zero).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            token_ids = batch['input_ids'].to(device)
            # Next-token prediction: position t is trained to predict token t + 1.
            inputs, targets = token_ids[:, :-1], token_ids[:, 1:]
            if training:
                optimizer.zero_grad()
            logits = model(inputs)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError(f"{desc}: loader yielded no batches")
    return total_loss / num_batches


for epoch in range(epochs):
    avg_train_loss = _run_epoch(model, train_loader, loss_fn, device, optimizer=optimizer, desc="Training")
    print(f"Average Training Loss: {avg_train_loss:.4f}")
    avg_val_loss = _run_epoch(model, test_loader, loss_fn, device, desc="Validation")
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
print("training complete")

Using device: cpu


Training: 4it [00:07,  1.92s/it]


Average Training Loss: 10.4219


Validation: 1it [00:06,  6.08s/it]


Average Validation Loss: 10.4174


Training: 4it [00:07,  1.88s/it]


Average Training Loss: 10.2598


Validation: 1it [00:05,  5.44s/it]

Average Validation Loss: 10.3034
training complete





## Inference

In [8]:
def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.1, context_len=None):
    """Autoregressively continue ``prompt`` with tokens sampled from ``model``.

    Args:
        model: Language model mapping token ids to logits with a trailing
            vocabulary dimension (e.g. the ``NeuralNetwork`` above).
        tokenizer: The Hugging Face tokenizer the model was trained with.
        prompt: Text to continue.
        max_length: Maximum total number of tokens (prompt + continuation).
        temperature: Softmax temperature; values ``<= 0`` select greedy
            (argmax) decoding instead of dividing by zero.
        context_len: Width of the right-padded window fed to the model each
            step; only the most recent ``context_len`` tokens are used, so long
            generations never overrun the positional encoding. Defaults to the
            notebook-level ``max_len``.

    Returns:
        The decoded prompt plus continuation, special tokens removed.

    Raises:
        ValueError: if the prompt encodes to no tokens.
    """
    if context_len is None:
        context_len = max_len
    model.eval()
    input_ids_list = tokenizer.encode(prompt, truncation=True, max_length=max_length - 1)

    pad_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        eos_token_id = sep_token_id

    # BERT-style tokenizers append [SEP]. In training, [SEP] only ends a
    # sentence and is followed by ignored padding, so sampling after it draws
    # from an untrained context. Drop it so generation continues the prompt.
    if sep_token_id is not None and len(input_ids_list) > 1 and input_ids_list[-1] == sep_token_id:
        input_ids_list = input_ids_list[:-1]
    if not input_ids_list:
        raise ValueError("prompt produced no tokens")

    device = next(model.parameters()).device
    generated_ids = torch.tensor([input_ids_list], device=device, dtype=torch.long)

    with torch.no_grad():
        for _ in range(max_length - len(input_ids_list)):
            # Right-pad the most recent tokens into a fixed-width window, like
            # the fixed-width (padding='max_length') batches used in training.
            window = generated_ids[:, -context_len:]
            window_len = window.size(1)
            padded_input = torch.full((1, context_len), pad_token_id, device=device, dtype=torch.long)
            padded_input[:, :window_len] = window

            logits = model(padded_input)
            next_token_logits = logits[:, window_len - 1, :]

            if temperature <= 0:
                next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                probabilities = F.softmax(next_token_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break
    return tokenizer.decode(generated_ids[0].tolist(), skip_special_tokens=True)

In [9]:
prompt = "The most prominent figure of the 20th century is"
# One print with sep="\n" emits exactly the same two lines as two prints.
print("Starting generation...", f"Prompt: {prompt}\n", sep="\n")

# Sample a continuation of up to 50 total tokens at temperature 0.5.
generated_paragraph = generate_text(model, tokenizer, prompt, max_length=50, temperature=0.5)

print(f"Output: {generated_paragraph}")

Starting generation...
Prompt: The most prominent figure of the 20th century is

Output: the most prominent figure of the 20th century is cost children gleam whereupon dancer tastes postal danielle timeline bankrupt bullock yellowstone jasmine hinduism [unused827] widowvus outragemblingart♦ trembling extensive scubasure illumination [unused719] flavor croatianشady infections nurses sue marin ن kimball assaults charts


## Takeaways

So it turns out the decoder architecture is really simple to implement in PyTorch. The framework kind of mixes and matches what I learned in Keras and JAX, choosing to be more abstract in some places and more imperative in others.

But ugh, training a large model like this takes a ton of time, and learning to do it remotely on hardware with actual GPUs was also a pain in the butt. I'm worried that the common theme of real-world ML engineering is that model creation and architecture are easy; scaling, however, is hard. The actual takeaways here came from simply trying to run this remotely:

- virtual Python environments
- picking nodes and submitting jobs to Slurm
- rerouting I/O and making sure to receive model outputs
- monitoring model training progress

But it's absolutely worth it: on the cluster, I can train this model with 4x the batch size while still getting 6x the iterations per second. Doing this on 50x the data and letting training run for a little over an hour (which is doable because it's remote) brings my validation loss from about 7 down to 1.5.

And that doesn't even get into distributed training, using more CPU workers, ML experiment-tracking tools (like TensorBoard), or figuring out how to refactor all this legibly. For the next one, I'll focus on an easier problem like CIFAR-10, and I promise it'll be cleaner and nicer looking.

PyTorch and Hugging Face come with a lot of interesting tools. For instance, I found it odd that you have to be explicit about device placement in PyTorch, while Hugging Face provided a lot of the tools I needed to build the NLP dataset.
The most complicated part of this notebook was honestly the PyTorch data pipeline. `DataLoader` seems like a pretty convenient tool: I can just subclass `IterableDataset` to create an iterable for my training loop. It was a challenge getting it to work with a stream of data rather than just loading everything into RAM as I have done before, but luckily that should be useful in the future.