# Calculate loss for base model checkpoints

Since I didn't hold out a validation set during pre-training, this notebook finds the best-performing checkpoints by computing their loss on data that was (hopefully) not seen during training.

In [None]:
# Log in to Hugging Face to avoid rate limiting.
# `<token>` is a placeholder: substitute your own access token and keep the
# real value out of version control.
!huggingface-cli login --token <token>

In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Checkpoint directory the tokenizer is loaded from.
model_name = "/root/fineweb-gpt2-356m/checkpoint-9652"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Stream the "sample-100BT" configuration of FineWeb instead of downloading
# it, shuffle through a 10k-example buffer with a fixed seed so the sample is
# reproducible, and keep the first 3000 examples that come out.
streaming_eval = (
    load_dataset(
        "HuggingFaceFW/fineweb",
        "sample-100BT",
        split="train",
        streaming=True,
    )
    .shuffle(seed=3047, buffer_size=10_000)
    .take(3000)
)

# Pull the streamed examples into memory and wrap them in a regular Dataset
# so they can be processed with Dataset.map afterwards.
eval_list = list(streaming_eval)
raw_dataset = Dataset.from_list(eval_list)

In [None]:
context_length = 1024


def concatenate_and_chunk(element):
    """Pack a batch of texts into fixed-size blocks of token ids.

    Every text in ``element["text"]`` is encoded with the module-level
    ``tokenizer`` (without special tokens) and followed by the tokenizer's
    EOS id. The resulting ids form one continuous stream that is cut into
    consecutive blocks of ``context_length`` ids; ids left over after the
    last full block are discarded.

    Args:
        element: A batch mapping whose ``"text"`` entry is an iterable of
            strings.

    Returns:
        A dict with ``"input_ids"`` (list of blocks, each a list of
        ``context_length`` ids) and ``"labels"`` (a separate outer list
        holding the same blocks). Both lists are empty when the batch yields
        fewer than ``context_length`` ids.
    """
    stream = []
    for document in element["text"]:
        stream += tokenizer.encode(document, add_special_tokens=False)
        # Mark the document boundary inside the packed stream.
        stream.append(tokenizer.eos_token_id)

    # Not even one full block: emit no rows for this batch.
    if len(stream) < context_length:
        return {"input_ids": [], "labels": []}

    # Largest multiple of context_length that fits in the stream.
    usable = len(stream) - len(stream) % context_length

    blocks = [
        stream[start:start + context_length]
        for start in range(0, usable, context_length)
    ]
    return {"input_ids": blocks, "labels": list(blocks)}


# Keep only the "text" column; every other column is dropped before
# tokenization.
_non_text_columns = [
    name for name in raw_dataset.column_names if name != "text"
]
raw_dataset = raw_dataset.remove_columns(_non_text_columns)

# Tokenize and pack in batches across all CPU cores. The original columns
# are removed because the number of output rows (blocks) differs from the
# number of input rows (documents).
tokenized = raw_dataset.map(
    concatenate_and_chunk,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=raw_dataset.column_names,
)

# Discard any row whose input_ids came out empty.
tokenized = tokenized.filter(lambda row: len(row["input_ids"]) > 0)

# Have the dataset return torch tensors for the two columns the model needs.
tokenized.set_format("torch", columns=["input_ids", "labels"])

In [None]:
from torch.utils.data import DataLoader
# Fixed batch order so every checkpoint is scored on exactly the same data.
eval_dataloader = DataLoader(tokenized, batch_size=32, shuffle=False)

# Collect the "checkpoint-<N>" sub-directories and order them by their
# numeric suffix (presumably the training step) so results print in order.
checkpoint_dir = "./fineweb-gpt2-356m"
checkpoint_folders = [
    f.path for f in os.scandir(checkpoint_dir)
    if f.is_dir() and f.name.startswith("checkpoint-")
]
checkpoint_folders.sort(key=lambda x: int(x.split("-")[-1]))

device = "cuda" if torch.cuda.is_available() else "cpu"

# Maps checkpoint path -> average evaluation loss. Checkpoints that fail to
# load or evaluate are reported and left out.
results = {}

for checkpoint_path in checkpoint_folders:
    print(f"Evaluating {checkpoint_path}")
    # Bound before the try so the cleanup in `finally` is always valid, even
    # when loading the checkpoint itself fails.
    model = None
    try:
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path).to(device)
        model.eval()

        total_loss = 0.0
        num_sequences = 0

        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="batches", leave=False):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                # outputs.loss is already a mean over the batch, and the last
                # batch may hold fewer than 32 sequences. Weight each batch
                # by its sequence count so the final figure is the mean over
                # the whole eval set rather than a mean of batch means (all
                # sequences have the same length, so this is also the
                # per-token mean).
                batch_size = batch["input_ids"].size(0)
                total_loss += outputs.loss.item() * batch_size
                num_sequences += batch_size

        avg_loss = total_loss / num_sequences if num_sequences else float("nan")
        results[checkpoint_path] = avg_loss
        print(f"  → avg_loss: {avg_loss:.4f}")

    except Exception as e:
        # Best-effort sweep: report the failure and move on to the next
        # checkpoint instead of aborting the whole run.
        print(f"Error at {checkpoint_path}: {e}")

    finally:
        # Release the model on success and on failure alike; otherwise a
        # failed checkpoint (e.g. an out-of-memory error) would stay resident
        # and could make every later checkpoint fail as well.
        model = None
        torch.cuda.empty_cache()

print("\nFinal Results:")
for ckpt, loss in results.items():
    print(f"{os.path.basename(ckpt)}: {loss:.4f}")