# TRL Custom Training Process Test

Testing customized PPO training with a small model (GPT-2) to validate our approach.

In [5]:

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_scheduler,
    BitsAndBytesConfig
)
from trl import PPOConfig, PPOTrainer, create_reference_model, AutoModelForCausalLMWithValueHead
from datasets import load_dataset
from typing import List, Dict, Any, Optional
import numpy as np

# MODEL_NAME = "EleutherAI/pythia-1b-deduped"  # Using small GPT-2 for testing
MODEL_NAME = "openai-community/gpt2"

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Load main model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=1
)
ref_model = create_reference_model(model)
# ref_mode = None
value_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=1
)

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Configuration
from datasets import DatasetDict


DATASET_NAME = "trl-internal-testing/sentiment-trl-style"  # Simple sentiment dataset

dataset: DatasetDict = load_dataset(DATASET_NAME)
train_dataset = dataset['train']
test_dataset = dataset['test']

In [7]:
def prepare_dataset(dataset: DatasetDict, tokenizer):
    """pre-tokenize the dataset before training; only collate during training"""

    def tokenize(element):
        input_ids = tokenizer(
            element["prompt"]
        )
        return {"input_ids": input_ids['input_ids'], "lengths": len(input_ids)}

    return dataset.map(
        tokenize,
        remove_columns=dataset.column_names,
    )

preprocessed_train = prepare_dataset(train_dataset, tokenizer)
preprocessed_test = prepare_dataset(test_dataset, tokenizer)

In [20]:
preprocessed_train['input_ids'][0]

[1544,
 750,
 407,
 588,
 1243,
 543,
 14516,
 683,
 326,
 339,
 550,
 1752,
 587,
 407,
 691,
 2354,
 262,
 10393,
 5002,
 475,
 772,
 2354,
 262,
 5535,
 13,
 679,
 750,
 407,
 1464,
 588,
 20920,
 2035,
 13,
 2399,
 9476,
 287,
 852,
 351,
 683,
 373,
 407,
 326,
 3297,
 286,
 9476,
 13,
 198,
 198,
 1,
 5297,
 553,
 531,
 20920,
 13,
 366,
 35,
 1697,
 36363,
 373,
 534,
 4039,
 8976,
 13]

In [None]:
from typing import Any, Generator, Never

from torch.utils.data import DataLoader
dataloader = DataLoader(
    preprocessed_train,
    batch_size=4,
    shuffle=True,
    collate_fn=None,
    drop_last=True,  # needed; otherwise the last batch will be of ragged shape
)
def repeat_generator() -> Generator[Any, Any, Never]:
    while True:
        yield from dataloader

iter_dataloader = iter(repeat_generator())
next(iter_dataloader)

In [11]:
# Initialize trainer
ppo_config = PPOConfig(
    learning_rate=3e-6,
    output_dir='checkpoints/',
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 16
    total_episodes=100,  # Small number for testing
    max_steps=20,
    cliprange=0.2,
    missing_eos_penalty=1.0
)
trainer = PPOTrainer(
    args=ppo_config,
    processing_class=tokenizer,
    value_model=value_model,
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    train_dataset=preprocessed_train,
    # data_collator: Optional[DataCollatorWithPadding] = None,
    eval_dataset=preprocessed_test
)

# Add learning rate scheduler
trainer.train()

DEBUG LIB : <torch.utils.data.dataloader.DataLoader object at 0x7fb0196b9c70>
DEBUG LIB2 : <accelerate.data_loader.DataLoaderShard object at 0x7fb013126ba0>
DEBUG LIB3 : <accelerate.data_loader.DataLoaderShard object at 0x7fb013126ba0>
DEBUG LIB4 : {'input_ids': tensor([[ 1639,  3088,   284,  ..., 50256, 50256, 50256],
        [    1,  1639,  1297,  ..., 50256, 50256, 50256],
        [18565,    11,   198,  ..., 50256, 50256, 50256],
        ...,
        [ 2504,   373,  1239,  ..., 50256, 50256, 50256],
        [  464,  8796,   373,  ...,    13, 50256, 50256],
        [  464, 16570,  7342,  ..., 50256, 50256, 50256]], device='cuda:0'), 'lengths': tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}
DEBUG LIB4 : <gen

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33meryaw[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss


KeyboardInterrupt: 

In [None]:
# Test generation with trained model
test_prompt = "Write a positive review: "
inputs = tokenizer(test_prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
        do_sample=True,
        temperature=0.7
    )

print("Generated text:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))