In [1]:
import re
import random
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import numpy as np
from transformers import GPT2Tokenizer
from datasets import load_dataset

In [2]:
def set_seed(seed: int):
    # Set seed for Python's random module
    random.seed(seed)
    
    # Set seed for NumPy
    np.random.seed(seed)
    
    # Set seed for PyTorch
    torch.manual_seed(seed)
    
    # Set seed for CUDA (if using)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups
        
    # Make PyTorch deterministic (this can slow down the computation)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Example of usage
set_seed(42)


In [3]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")

print(f"Using device: {device}")

Using device: cuda


In [4]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print(tokenizer.pad_token)
    print(tokenizer(tokenizer.pad_token))

<|endoftext|>
{'input_ids': [50256], 'attention_mask': [1]}


In [5]:
cleanup_pattern = re.compile(r",|the|a", flags=re.IGNORECASE)

def preprocess(batch: list[str]) -> list[str]:
    result = [cleanup_pattern.sub("", x) for x in batch]
    return result

In [6]:
def tokenize(batch: list[str], max_length: int | None = None):
    batch = preprocess(batch)
    encodings = tokenizer(
        batch,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return {
        "input_ids": encodings.input_ids.to(device),
        "attention_mask": encodings.attention_mask.to(device),
    }


res = tokenize([
    "Big brown fox jumps over the lazy dog",
    "So, fellow citizens, let us not be blind to our differences - but let us also direct attention to our common interests and to the means by which those differences can be resolved.",
    "Fellow americans, ask not what your country can do for you, ask what you can do for your country.",
])
print(res['input_ids'].shape)
print(res['attention_mask'].shape)
res

torch.Size([3, 40])
torch.Size([3, 40])


{'input_ids': tensor([[12804,  7586, 21831, 18045,   625,   220,   300,  7357,  3290, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
         [ 2396,  5891,  4290,  1309,   514,   407,   307,  7770,   284,   674,
           5400,   532,   475,  1309,   514,   300,   568,  1277,   256,    83,
           1463,   284,   674,  2219,  5353,   299,    67,   284,   220,   285,
            641,   416,   543,   883,  5400,   269,    77,   307, 12939,    13],
         [   37,  5037,  4017,   291,  5907,  1341,   407,   348,    83,   534,
           1499,   269,    77,   466,   329,   345,  1341,   348,    83,   345,
            269,    77,   466,   329,   534,  1499,    13, 50256, 50256, 50256,
          50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]],
        device='cuda:0'

In [7]:
def collate_fn(examples: list[dict], max_length: int | None = None):
    text = [x['text'] for x in examples]
    return tokenize(text, max_length=max_length)

In [9]:
num_epochs = 5
batch_size = 20
max_length = 100
embed_dim = 300
state_dim = 500

In [12]:
dataset = load_dataset("allenai/c4", "realnewslike")
train_subset = Subset(dataset["train"], torch.arange(10000))
val_subset = Subset(dataset["validation"], torch.arange(1000))

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/512 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/76 [00:00<?, ?it/s]

In [None]:
dataset['train']['text'][:3]

In [14]:
train_loader = torch.utils.data.DataLoader(
    train_subset,
    batch_size=batch_size,
    collate_fn=partial(collate_fn, max_length=max_length),
)
val_loader = torch.utils.data.DataLoader(
    val_subset,
    batch_size=batch_size,
    collate_fn=partial(collate_fn, max_length=max_length),
)
print("Number of batches in an epoch:", len(train_loader))
for batch in train_loader:
    print("input_ids shape:", batch['input_ids'].shape)
    break

Number of batches in an epoch: 500
input_ids shape: torch.Size([20, 100])


In [15]:
# https://github.com/pytorch/pytorch/blob/aa7be72cc55244978ddaf760338dab6b9cf977a1/torch/nn/modules/rnn.py#L631
torch.nn.LSTM

torch.nn.modules.rnn.LSTM

In [17]:
class LSTMModel(torch.nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, state_dim: int, max_length: int | None = None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, state_dim, batch_first=True)
        self.linear = nn.Linear(state_dim, vocab_size)
        self.max_length = max_length

    def forward(
        self,
        x,
        attention_mask: torch.Tensor | None = None,
        hidden: torch.Tensor | None = None,
    ):
        x = self.embedding(x)
        x = x * attention_mask.unsqueeze(-1)
        x, hidden = self.lstm(x, hidden)
        x = self.linear(x)
        return x

    # See: https://pytorch.org/docs/stable/notes/autograd.html#locally-disable-grad-doc
    @torch.inference_mode
    def predict(self, input_text):
        tokens = tokenize([input_text], max_length=self.max_length)
        outputs = self(tokens['input_ids'], attention_mask=tokens['attention_mask'])
        preds = torch.argmax(outputs, dim=-1)
        next_word = tokenizer.decode(preds[0][-1], skip_special_tokens=True)
        return next_word.strip()


model = LSTMModel(vocab_size=tokenizer.vocab_size, embed_dim=embed_dim, state_dim=state_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()
# Adam has many hyperparameters - play with it!
# Try other optimizers as well!
optimizer = torch.optim.Adam(model.parameters())

# Train the model
# NOTE: always track the validation accuracy to see when it diverges from the training accuracy!
# This is a sign of overfitting!
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    for batch in train_loader:
        input_ids = batch["input_ids"][:, :-1]
        targets = batch["input_ids"][:, 1:]
        attention_mask = batch["attention_mask"]
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask[:, :-1])
        outputs = outputs * attention_mask[:, 1:].unsqueeze(-1)
        loss = criterion(outputs.reshape(-1, tokenizer.vocab_size), targets.reshape(-1))
        # loss must be a scalar for loss.backward() to work!
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    val_loss = 0
    model.eval()
    with torch.inference_mode():
        for batch in val_loader:
            input_ids = batch["input_ids"][:, :-1]
            targets = batch["input_ids"][:, 1:]
            attention_mask = batch["attention_mask"]
            outputs = model(input_ids, attention_mask=attention_mask[:, :-1])
            outputs = outputs * attention_mask[:, 1:].unsqueeze(-1)
            loss = criterion(outputs.reshape(-1, tokenizer.vocab_size), targets.reshape(-1))
            val_loss += loss.item()
    print(
        f"Epoch {epoch+1}, Train Loss: {train_loss / len(train_loader):.4f}, Val Loss: {val_loss / len(val_loader):.4f}"
    )

# model.predict("Big brown fox jumps over the lazy")

Epoch 1, Train Loss: 6.5257, Val Loss: 5.8510
Epoch 2, Train Loss: 5.5259, Val Loss: 5.4889
Epoch 3, Train Loss: 5.1310, Val Loss: 5.3271
Epoch 4, Train Loss: 4.8439, Val Loss: 5.2432
Epoch 5, Train Loss: 4.6047, Val Loss: 5.2070


In [18]:
model.predict("Austria has")

'been'

In [19]:
prompt = [""]
for i in range(10):
    next_word = model.predict(" ".join(prompt))
    prompt.append(next_word)
    print(" ".join(prompt))
    if next_word == ".":
        break

Austria has been 
Austria has been  first
Austria has been  first time
Austria has been  first time of
Austria has been  first time of 
Austria has been  first time of  l
Austria has been  first time of  l st
Austria has been  first time of  l st te
Austria has been  first time of  l st te chers
Austria has been  first time of  l st te chers of
