<a href="https://colab.research.google.com/github/ggsmith842/AIML-tutorials/blob/main/pytorch/text/pytorch_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -Uq datasets

In [2]:
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
# use huggingface instead of torchtext (deprecated)
from transformers import AutoTokenizer
from datasets import load_dataset

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device: {device}')

device: cuda


# Load and Preprocess Data

In [None]:
# load wiki2 data
ds = load_dataset("wikitext", "wikitext-2-raw-v1")

In [None]:
ds['train']['text'][3]

In [None]:
# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
tokenized_ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, max_length=512), batched=True, remove_columns=['text'])

In [9]:
# concatenate all texts and split into chunks
# --- Group into Fixed Length Chunks ---
max_seq_len = 64
train_batch_size = 32
eval_batch_size = 16


def group_texts(examples):
    # Concatenate all fields (like input_ids, attention_mask, etc.)
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // max_seq_len) * max_seq_len
    result = {
        k: [concatenated[k][i:i + max_seq_len] for i in range(0, total_length, max_seq_len)]
        for k in concatenated.keys()
    }
    return result


In [None]:
lm_ds = tokenized_ds.map(group_texts, batched=True)

In [11]:
# dataloader prep
def collate_fn(batch):
  input_ids = torch.tensor([example['input_ids'] for example in batch])
  x = input_ids[:, :-1]
  y = input_ids[:, 1:]
  return x.to(device), y.reshape(-1).to(device)

In [12]:
train_loader = DataLoader(lm_ds["train"], batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(lm_ds["validation"], batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)

# Build Model

In [13]:
# --- Define Transformer Model ---
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len - 1, d_model))
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1)]
        x = self.transformer(x)
        return self.fc_out(x)

In [14]:
# --- Train ---
model = TransformerLM(vocab_size=tokenizer.vocab_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)



In [15]:
for epoch in range(3):
    model.train()
    total_loss = 0
    for x, y in tqdm(train_loader):
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, logits.size(-1)), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Train Loss: {total_loss / len(train_loader):.4f}")


Epoch 1, Train Loss: 6.6516
Epoch 2, Train Loss: 5.9156
Epoch 3, Train Loss: 5.6366


In [20]:
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in tqdm(val_loader):
            logits = model(x)
            loss = criterion(logits.view(-1, logits.size(-1)), y)
            val_loss += loss.item()
    print(f"\nVal Loss: {val_loss / len(val_loader):.4f}")


100%|██████████| 242/242 [00:07<00:00, 33.49it/s]


Epoch 3, Val Loss: 5.9884



