# Training/fine-tuning of GPT2

Use Penn-treebank-dataset on kaggle

If you want to train a network from scratch, you may use 
```
config = GPT2Config(vocab_size = 50257, n_positions = 1024, n_embd = 768, n_layer = 12, n_head = 12)
model = GPT2LMHeadModel(config)
```

In [2]:
!cp -r /kaggle/input/penn-treebank-dataset /kaggle/working/

In [3]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments


# Load pre-trained GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Load and preprocess the Wikitext dataset
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/kaggle/working/penn-treebank-dataset/ptbdataset/ptb.train.txt",  # specify the path to Wikitext train dataset
    block_size=128  # adjust according to your computational resources
)
eval_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/kaggle/working/penn-treebank-dataset/ptbdataset/ptb.test.txt",  # specify the path to Wikitext validation dataset
    block_size=128  # adjust according to your computational resources
)


In [4]:
train_dataset

<transformers.data.datasets.language_modeling.TextDataset at 0x7a3dd3ebcf70>

In [6]:
# Prepare training arguments
training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,  # adjust based on your GPU memory
    per_device_eval_batch_size=4,  # adjust based on your GPU memory
    logging_dir="./logs",
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    evaluation_strategy="steps",
    save_total_limit=2,
)


In [7]:
# Prepare data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [8]:
# Train the model
trainer.train()

# Evaluate the model
trainer.evaluate()

# Save the model
trainer.save_model("./gpt2-trained")


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc




Step,Training Loss,Validation Loss
500,3.19,3.030819
1000,3.0465,2.940042
1500,2.9152,2.901784
2000,2.8905,2.874459
2500,2.7834,2.861092
3000,2.7665,2.852276




In [9]:
model.to('cuda')
def generate_text(prompt, model, max_length=100, temperature=1.0, top_k=50, device='cuda'):
    # Load pre-trained tokenizer and model

    # Tokenize input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate text based on prompt
    output = model.generate(
        input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1
    )

    # Decode generated tokens back to text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    
    return generated_text

# Example prompt
prompt = "new york stock exchange"

# Generate text using GPT-2 model
generated_text = generate_text(prompt, model, max_length=50, temperature=0.7, top_k=30, device='cuda')

print("Generated text:")
print(generated_text)




Generated text:
new york stock exchange composite trading yesterday 
 the dow jones industrial average closed at N 
 the dow jones industrial average closed at N 
 the dow jones industrial average closed at N 
 the dow jones industrial average closed
