In [None]:
# Task-01: Text Generation with GPT-2
# -----------------------------------

# Step 1: Install Hugging Face Transformers & Datasets (run only once in Jupyter)
!pip install transformers datasets torch

# Step 2: Import required libraries
import torch
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

# Step 3: Load the pre-trained GPT-2 tokenizer and its language-modeling head
# from the same checkpoint so their vocabularies agree.
checkpoint_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_name)
model = GPT2LMHeadModel.from_pretrained(checkpoint_name)

# Step 4: GPT-2 has no dedicated padding token, so reuse end-of-sequence as the
# pad token on both the model config and the tokenizer (needed for padded batches).
model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# Step 5: Load a dataset (wikitext-2 for the demo; swap in your own dataset later).
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Wikitext's raw splits contain many empty / whitespace-only lines. After padding
# they would become sequences made entirely of padding tokens, which carry no
# training signal (once padding is masked out of the loss, a batch of only such
# rows has nothing to learn from), so drop them before tokenizing.
dataset = dataset.filter(lambda example: example["text"].strip() != "")


# Step 6: Tokenize the dataset
def tokenize_function(examples):
    """Tokenize a batch of raw text rows for causal-LM fine-tuning.

    Args:
        examples: A batch passed by ``datasets.map(batched=True)``; its
            ``"text"`` entry is a list of strings.

    Returns:
        The tokenizer output (``input_ids`` and ``attention_mask``), with every
        sequence truncated and padded to exactly 128 tokens.
    """
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)


# The raw "text" column is not a model input, so remove it once it has been tokenized.
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Step 7: Prepare dataset for training
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]

# Step 8: Define training arguments.
# NOTE: `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41) and
# the old keyword was removed in later releases, which reject it with a TypeError.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # evaluate on eval_dataset at the end of every epoch
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=1,
    weight_decay=0.01,
    save_total_limit=2,  # keep at most two checkpoints in output_dir
    logging_dir="./logs",
    logging_steps=10,
)

# Step 9: Define Trainer.
# GPT2LMHeadModel only returns a loss when it is given `labels`, and the tokenized
# dataset has none. The causal-LM collator (mlm=False) builds `labels` from
# `input_ids` for each batch and sets padding positions to -100 so the loss ignores
# them. The model itself shifts labels for next-token prediction.
# Because pad == eos here, any genuine EOS tokens are masked out of the loss too.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Step 10: Train the model
trainer.train()

# Step 11: Persist the fine-tuned weights and the tokenizer into one directory so
# both can be reloaded later with `from_pretrained`.
save_dir = "./fine_tuned_gpt2"
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)

# Step 12: Generate text with the fine-tuned GPT-2.
# Trainer may have moved the model to a GPU, so place the prompt tensors on the
# model's device. Pass the attention mask explicitly too: pad == eos here, so
# generate() cannot infer which positions are padding from the ids alone.
model.eval()  # turn off dropout for inference
prompt = "Artificial Intelligence is"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=100,  # total length, prompt tokens included
    num_return_sequences=1,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,  # silences the "setting pad_token_id" warning
)

print("Generated Text:\n")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))