In [None]:
import os, torch, glob, re
from tqdm.auto import tqdm
from transformers import GPT2LMHeadModel
from torch.utils.data import IterableDataset, DataLoader, Dataset
import torch

class PTIterableDataset(IterableDataset):
    """Stream individual samples from a list of pre-tokenized ``.pt`` shards.

    Each shard is loaded with ``torch.load`` and indexed as a dict. The
    ``"input_ids"`` and ``"attention_mask"`` entries are sliced along their
    first dimension (one sample per row). ``"labels"`` is included when the
    shard has it. Every yielded sample also carries the shard's basename
    under ``"files"``.

    Under ``DataLoader(num_workers > 0)`` each worker reads a disjoint,
    strided subset of the shards, so no sample is produced twice.

    Args:
        pt_files: iterable of shard paths. It is materialized to a list so
            the dataset can be iterated more than once.
    """

    def __init__(self, pt_files):
        self.pt_files = list(pt_files)

    def _files_for_this_worker(self):
        """Return the shards the current DataLoader worker is responsible for."""
        worker = torch.utils.data.get_worker_info()
        if worker is None:  # single-process loading: read every shard
            return self.pt_files
        return self.pt_files[worker.id::worker.num_workers]

    def __iter__(self):
        for file_path in self._files_for_this_worker():
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted shards (or pass weights_only=True on torch >= 1.13).
            # map_location keeps tensors on CPU regardless of where the shard
            # was saved from.
            data = torch.load(file_path, map_location="cpu")
            file_name = os.path.basename(file_path)
            labels = data.get("labels")
            for i in range(data["input_ids"].size(0)):
                sample = {
                    "input_ids": data["input_ids"][i],
                    "attention_mask": data["attention_mask"][i],
                    "files": file_name,
                }
                if labels is not None:
                    sample["labels"] = labels[i]
                yield sample


def extract_file_numbers(filename):
    """Return the first integer in the basename of ``filename``, or 0 if none.

    Intended as a ``sorted`` key so shards order numerically
    (``batch_2.pt`` before ``batch_10.pt``) rather than lexicographically.
    Only the basename is inspected, so digits in directory names
    (e.g. ``data2/``) cannot hijack the key.
    """
    match = re.search(r'\d+', os.path.basename(filename))
    return int(match.group()) if match else 0

def evaluate_half(model, test_loader, total_steps=None, device=None):
    """Compute the mean loss and perplexity of a causal LM in half precision.

    Args:
        model: a Hugging Face causal LM (e.g. ``GPT2LMHeadModel``). NOTE: it
            is cast to fp16 and moved to ``device`` *in place*, so the caller's
            object is left in that state afterwards.
        test_loader: iterable of batches, each a dict with ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        total_steps: expected number of batches. Used only for the progress bar
            and a mismatch warning; the average uses the batches actually seen.
        device: target device. Defaults to ``"cuda"`` if available, else ``"cpu"``.
            NOTE(review): fp16 on CPU is slow and unsupported by some
            older torch builds.

    Returns:
        ``(avg_loss, perplexity)`` as Python floats. ``avg_loss`` is the mean of
        the per-batch losses.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.half().to(device).eval()
    # Global flag: lets cuDNN autotune kernels for the fixed input shapes.
    torch.backends.cudnn.benchmark = True

    total_loss = torch.zeros((), dtype=torch.float32, device=device)
    n_batches = 0
    pbar = tqdm(test_loader, total=total_steps, desc="Evaluating", unit="batch")

    with torch.inference_mode():  # like no_grad(), and also skips autograd version tracking
        for batch in pbar:
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attn_mask = batch['attention_mask'].to(device, non_blocking=True)
            # Exclude padded positions from the loss. -100 is the label index
            # HF LM heads ignore; the model shifts labels internally.
            labels = input_ids.masked_fill(attn_mask == 0, -100)

            loss = model(
                input_ids=input_ids,
                attention_mask=attn_mask,
                labels=labels
            ).loss

            # Accumulate in fp32 so thousands of fp16 additions don't lose precision.
            total_loss += loss.float()
            n_batches += 1
            pbar.set_postfix(batch_loss=f"{loss.item():.4f}")

    if n_batches == 0:
        raise ValueError("test_loader yielded no batches")
    if total_steps is not None and n_batches != total_steps:
        print(f"Warning: expected {total_steps} batches, evaluated {n_batches}")

    # Divide by the batches actually evaluated, not the caller's estimate.
    avg_loss = total_loss / n_batches
    perplexity = torch.exp(avg_loss)

    print(f"Avg loss: {avg_loss.item():.4f}  —  Perplexity: {perplexity.item():.2f}")
    return avg_loss.item(), perplexity.item()


In [None]:
loader_batch_size = 16
MAX_TEST_FILES = 1000  # cap on shards evaluated; set to None to use every file

test_files_ = sorted(glob.glob("processed_batches/test/*.pt"), key=extract_file_numbers)
test_files_ = test_files_[:MAX_TEST_FILES]
assert test_files_, "no .pt shards found under processed_batches/test/"

test_ds = PTIterableDataset(test_files_)
# Keep num_workers=0 unless PTIterableDataset splits its files across workers;
# otherwise each worker replays every file and samples are duplicated.
test_loader = DataLoader(test_ds, batch_size=loader_batch_size, num_workers=0, drop_last=True)

# Peek at one batch's structure instead of dumping full (batch, seq) tensors.
first_batch = next(iter(test_loader))
print({k: (tuple(v.shape) if torch.is_tensor(v) else v[:3]) for k, v in first_batch.items()})

In [None]:
# Each shard is assumed to hold SAMPLES_PER_FILE sequences of SEQ_LEN tokens
# (TODO confirm against the preprocessing script).
SAMPLES_PER_FILE = 16
SEQ_LEN = 1024

total_test_tokens = len(test_files_) * SAMPLES_PER_FILE * SEQ_LEN
# Integer division: the loader uses drop_last=True, so a partial final batch is discarded.
total_steps = total_test_tokens // (loader_batch_size * SEQ_LEN)
print(f"{total_test_tokens:.2e} tokens")
print(f"{total_steps} total_steps")

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = 'cwestnedge/gpt2-small-pubmed'
# Load on CPU: evaluate_half casts to fp16 *before* moving to `device`, so the
# GPU never has to hold a full fp32 copy of the weights.
model = GPT2LMHeadModel.from_pretrained(model_id)
# evaluate_half takes the target device as its 4th argument.
avg_loss, ppl = evaluate_half(model, test_loader, total_steps, device)
print(f"{model_id}: avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}")

In [None]:
# Release the previous checkpoint before loading a bigger one.
model = None
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model_id = 'cwestnedge/gpt2-medium-pubmed'
# Load on CPU: evaluate_half casts to fp16 *before* moving to `device`.
model = GPT2LMHeadModel.from_pretrained(model_id)
# evaluate_half takes the target device as its 4th argument.
avg_loss, ppl = evaluate_half(model, test_loader, total_steps, device)
print(f"{model_id}: avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}")

In [None]:
# Release the previous checkpoint before loading a bigger one.
model = None
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model_id = 'cwestnedge/gpt2-large-pubmed'
# Load on CPU: evaluate_half casts to fp16 *before* moving to `device`.
model = GPT2LMHeadModel.from_pretrained(model_id)
# evaluate_half takes the target device as its 4th argument.
avg_loss, ppl = evaluate_half(model, test_loader, total_steps, device)
print(f"{model_id}: avg_loss = {avg_loss:.4f}, ppl = {ppl:.2f}")