In [5]:
# prompt: update satasets class

!pip install datasets -U

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m494.8/494.8 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m193.6/193.6 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets


In [2]:
import torch
import math
import numpy as np
from datasets import load_dataset
from transformers import (
    GPT2LMHeadModel,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# Set device explicitly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Using device: {device}")

# 1. Load 5% of WikiText-2
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
dataset["train"] = dataset["train"].select(range(int(0.50 * len(dataset["train"]))))
dataset["validation"] = dataset["validation"].select(range(int(0.5 * len(dataset["validation"]))))
dataset["test"] = dataset["test"].select(range(int(0.5 * len(dataset["test"]))))

# 2. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 3. Tokenize
def tokenize(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=128)

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])
tokenized_dataset.set_format("torch")

# 4. Load model
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.eos_token_id
model.to(device)

# 5. Data collator
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# 6. Training args
training_args = TrainingArguments(
    output_dir="./gpt2-finetuned",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    logging_steps=10,
    fp16=True,  # GPU acceleration with mixed precision
    report_to="none",
    push_to_hub=False
)

# 7. Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# 8. Fine-tune
trainer.train()

# 9. Evaluate: compute loss & perplexity
eval_results = trainer.evaluate(tokenized_dataset["test"])
perplexity = math.exp(eval_results["eval_loss"])
print(f"\nüìâ Perplexity: {perplexity:.2f}")

# 10. (Optional) Top-k accuracy function
def compute_top_k_accuracy(model, dataset, k=5, num_batches=10):
    model.eval()
    correct = 0
    total = 0

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            logits = outputs.logits
            next_token_logits = logits[:, -2, :]  # Predict next word
            top_k_preds = torch.topk(next_token_logits, k, dim=-1).indices

            true_token = input_ids[0, -1]
            if true_token in top_k_preds[0]:
                correct += 1
            total += 1

            if idx >= num_batches:
                break

    return correct / total if total > 0 else 0

topk_acc = compute_top_k_accuracy(model, tokenized_dataset["test"], k=5, num_batches=100)
print(f"üéØ Top-5 Accuracy: {topk_acc:.2%}")


üöÄ Using device: cuda


Map:   0%|          | 0/2179 [00:00<?, ? examples/s]

Map:   0%|          | 0/18359 [00:00<?, ? examples/s]

Map:   0%|          | 0/1880 [00:00<?, ? examples/s]

  trainer = Trainer(


Step,Training Loss
10,3.9104
20,4.0175
30,3.8665
40,3.9921
50,3.8556
60,3.8476
70,3.7646
80,3.8293
90,3.7485
100,3.7314


Step,Training Loss
10,3.9104
20,4.0175
30,3.8665
40,3.9921
50,3.8556
60,3.8476
70,3.7646
80,3.8293
90,3.7485
100,3.7314



üìâ Perplexity: 29.20
üéØ Top-5 Accuracy: 11.88%
