**Assignment No. 3**: Fine-tune GPT or GPT-2 for creative story generation.


In [None]:
import torch
import os
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

# Disable Weights & Biases (W&B) logging
os.environ["WANDB_DISABLED"] = "true"

# Load GPT-2 Tokenizer and Model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer.pad_token = tokenizer.eos_token  # Set pad token

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

# Load Dataset (Replace with your dataset or use a text file)
dataset = load_dataset("tiny_shakespeare")  # Example dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Define training arguments
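training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # called evaluation_strategy in older transformers releases
    save_strategy="epoch",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,  # training loss shows as "No log" if fewer steps than this are run
    save_total_limit=2,
    report_to="none",  # supported way to turn off W&B and other logging integrations
)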

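# Data collator: batches examples and copies input_ids into labels for causal LM training
# (positions equal to the pad token are excluded from the loss)
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Disable masked language modeling since GPT-2 is causal
)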

# Trainer initialization
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=collator,
)

# Fine-tune the model
trainer.train()

# Save fine-tuned model
model.save_pretrained("./fine_tuned_gpt2")
tokenizer.save_pretrained("./fine_tuned_gpt2")

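# Generate a story from a prompt using top-k / top-p sampling
def generate_story(prompt, max_length=200):
    # Keep the attention mask and move the inputs to the model's device
    # (the Trainer may have placed the model on a GPU)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    model.eval()  # disable dropout for generation
    outputs = model.generate(
        **inputs,  # input_ids and attention_mask
        max_length=max_length,  # total length in tokens, prompt included
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)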

# Example prompt
prompt = "Once upon a time in a distant kingdom,"
story = generate_story(prompt)
print(story)
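
One caveat about the tokenization step in the cell above: with `truncation=True`, each dataset row contributes at most 512 tokens and the rest of its text is discarded. If a split stores the whole corpus as a single long row (which appears to be how `tiny_shakespeare` is published, though this is worth confirming with `len(dataset["train"])`), training sees only one 512-token example per epoch. A common alternative is to tokenize the full text once and slice the token ids into fixed-length blocks. The sketch below shows only the slicing, on plain Python lists; `chunk_token_ids` and `block_size` are illustrative names, not part of `transformers` or `datasets`.

```python
def chunk_token_ids(token_ids, block_size=512):
    """Split one long sequence of token ids into consecutive full blocks.

    The trailing remainder shorter than block_size is dropped, so every
    block has the same length and needs no padding.
    """
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    n_full = len(token_ids) // block_size
    return [token_ids[i * block_size:(i + 1) * block_size] for i in range(n_full)]


# Toy example: 10 fake token ids, blocks of 4 -> two full blocks, remainder dropped
blocks = chunk_token_ids(list(range(10)), block_size=4)
print(blocks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

In the notebook, this would presumably be applied to the `input_ids` of each split's full text, with the resulting blocks stored in an `input_ids` column before being handed to the `Trainer`.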


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | No log | 3.327637 |
| 2 | No log | 3.255084 |
| 3 | No log | 3.23847 |
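
The validation loss reported for a causal language model is a mean cross-entropy per token in nats, so exponentiating it gives a perplexity, which is often easier to compare across runs. The sketch below applies that conversion to the three values in the table above. `loss_to_perplexity` is an illustrative helper, and the interpretation assumes the loss is averaged over non-padding tokens.

```python
import math


def loss_to_perplexity(loss):
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(loss)


# Validation losses copied from the table above, keyed by epoch
validation_loss = {1: 3.327637, 2: 3.255084, 3: 3.23847}

for epoch, loss in validation_loss.items():
    print(f"epoch {epoch}: loss={loss:.4f}  perplexity={loss_to_perplexity(loss):.2f}")
```

The improvement from epoch 2 to epoch 3 is much smaller than from epoch 1 to epoch 2, which suggests (without proving) that this run is close to flattening on the validation split.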


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Once upon a time in a distant kingdom, the queen would be a little afraid that the kingdom might fall into the hands of the demons. But the queen could not be certain that she would be safe, and so she resolved to seek her own safety in her own right.

Thereupon, when the king, seeing the queen's fear, saw how great her fear was, he began to speak the words which she had spoken, and in the midst of his words he cried out:

"My mother is still in the midst of the world, and she is still in the midst of the earth, but she is still in the midst of all the living. But I know that she is still in the midst of the world. She is still in the midst of all the living, but she is still in the midst of all the living. And so, my mother is still in the midst of all the living, but she is still in the midst of all the living.

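The sample above drifts into a loop ("she is still in the midst of ...") and ends mid-quotation, apparently because the `max_length` limit was reached. One rough way to quantify this kind of looping is the fraction of word n-grams that repeat an earlier n-gram. The helper below is an illustrative sketch (`repeated_ngram_fraction` is not a library function), and the two test strings are made up for the example. On the generation side, `model.generate` accepts `no_repeat_ngram_size` and `repetition_penalty`, which may reduce such loops at some cost to fluency.

```python
from collections import Counter


def repeated_ngram_fraction(text, n=3):
    """Fraction of word n-grams in text that duplicate an earlier n-gram."""
    words = text.lower().split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    # The first occurrence of each n-gram is "new"; every further one is a repeat
    repeats = sum(c - 1 for c in counts.values())
    return repeats / len(ngrams)


looping = ("she is still in the midst of all the living "
           "but she is still in the midst of all the living")
varied = "the queen resolved to seek her own safety far beyond the northern mountains"

print(repeated_ngram_fraction(looping))  # well above zero
print(repeated_ngram_fraction(varied))   # 0.0, no trigram occurs twice
```

A score near zero does not by itself mean the story is good; it only flags verbatim repetition, so it is best read alongside the validation loss rather than instead of it.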