In [1]:

# uninstalls/installs for deprecated version of TRL

# remove earlier version of trl
!pip uninstall trl -y

# clear cache
!pip cache remove trl

# install older version of trl that allows for custom reward score (vs incorporating the reward model in the workflow)
# !pip install trl==0.11.4 --no-cache-dir --force-reinstall
# !pip install trl==0.4.7 --no-cache-dir --force-reinstall
!pip install trl==0.11.4


[0mFound existing installation: trl 0.11.4
Uninstalling trl-0.11.4:
  Successfully uninstalled trl-0.11.4
[0mFiles removed: 0
[0mCollecting trl==0.11.4
  Using cached trl-0.11.4-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cudnn-cu12==9.10.2.21 (from torch>=1.4.0->trl==0.11.4)
  Using cached nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)
Using cached trl-0.11.4-py3-none-any.whl (316 kB)
Using cached nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl (706.8 MB)
[0mInstalling collected packages: nvidia-cudnn-cu12, trl
[0mSuccessfully installed nvidia-cudnn-cu12 trl-0.11.4


In [22]:
# All notebook imports in one place: standard library first, then third-party.
# (The duplicate `import torch` and the split `datasets` imports were merged;
# every name that was imported before is still imported.)
import random

import torch
import trl
from datasets import Dataset, load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# confirm TRL install
print('TRL Version:', trl.__version__)

TRL Version: 0.11.4


In [23]:
# model set up
# Single source of truth for the checkpoint, so the policy model, the reference
# model and the tokenizer cannot end up pointing at different checkpoints.
MODEL_NAME = 'gpt2'

# (PPO requires a model with a value head)
model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)
# second, separately loaded copy of the same checkpoint; handed to PPOTrainer as the reference model
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token


In [24]:
# custom reward function
def get_reward_score(response_text):
    """Return a placeholder reward for one generated response.

    Args:
        response_text: decoded text of the response. It is currently ignored.

    Returns:
        float: either 0.0 or 1.0, drawn with ``random.randint(0, 1)``.

    TODO: replace this with our weighted sum reward score from multiple reward
    signals, computed from ``response_text``.
    """
    coin_flip = random.randint(0, 1)
    return float(coin_flip)


In [25]:
# PPO configuration (this cell only builds the config object; the trainer itself is created later)
# mini_batch_size equals batch_size and gradient_accumulation_steps is 1
config = PPOConfig(
    batch_size=16,
    mini_batch_size=16,
    gradient_accumulation_steps=1,
)
# config = PPOConfig(batch_size=1, mini_batch_size=1, gradient_accumulation_steps=1)





In [26]:
# *** NOTE TO TEAM: if we decide to move to a more recent version of TRL,
# we'll need to pass a dummy reward model like the one below into PPOTrainer.

# # Define a dummy reward model as an nn.Module.
# # The error message requires a reward model to be passed, even if rewards are calculated manually later.
# class DummyRewardModel(torch.nn.Module):
#     def forward(self, input_ids, attention_mask=None):
#         # Return a tensor of zeros, as the custom reward function will be used.
#         return torch.zeros((input_ids.shape[0], 1), device=input_ids.device)


In [27]:
# load training data

# IMDb is only a stand-in corpus for the proof of concept
# TODO: replace this with our own training data
imdb_dataset = load_dataset("imdb")

# keep just the first 200 training examples so the POC doesn't run for hours
subset_dataset = imdb_dataset["train"].select(range(200))

# placeholder dataset handed to PPOTrainer: config.batch_size identical rows,
# each holding the same fixed "query" string
dummy_data = [
    {"query": "This is a placeholder query for the dummy dataset."}
    for _ in range(config.batch_size)
]
dummy_dataset = Dataset.from_list(dummy_data)

In [28]:
# initialize PPOTrainer
# arguments are spelled out by keyword so each object's role is visible at the call site
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dummy_dataset,  # dummy dataset - just a placeholder
)

In [29]:

# PPO training
#
# For every full batch of IMDb reviews: tokenize each review as the query,
# generate a continuation with the policy, score the decoded continuation with
# get_reward_score, then run one PPO optimisation step on the batch.

# 1. Setup Device and Config
device = ppo_trainer.accelerator.device
print(f"Training on device: {device}")

# 2. Configuration
# Derived from the trainer config instead of a second hard-coded 16: every call
# to ppo_trainer.step() below is handed exactly BATCH_SIZE samples, so the two
# values must not drift apart.
BATCH_SIZE = config.batch_size
MAX_LENGTH = 128  # truncate inputs to avoid OOM or context errors
MAX_NEW_TOKENS = 15  # number of tokens generated per response

print("Starting training...")

# 3. Training Loop
for epoch in range(1):
    # iterate over the dataset in chunks of BATCH_SIZE
    for i in tqdm(range(0, len(subset_dataset), BATCH_SIZE)):

        # get batch of text
        batch_text = subset_dataset[i : i + BATCH_SIZE]['text']

        # check if the batch is full; PPO requires fixed batch size
        # (the trailing partial batch is skipped, not padded)
        if len(batch_text) < BATCH_SIZE:
            continue

        # storage for this step
        query_tensors = []
        response_tensors = []
        decoded_responses = []

        # process batch
        for text in batch_text:
            # A. Tokenize and Move to Device
            # truncate to MAX_LENGTH to prevent errors with long text strings
            query_tensor = tokenizer.encode(
                text,
                return_tensors="pt",
                max_length=MAX_LENGTH,
                truncation=True
            ).to(device)[0] # [0] removes batch dim to get 1D tensor

            query_tensors.append(query_tensor)

            # B. Generate Response
            # return_prompt=False: by default generate() returns prompt + continuation,
            # which would put the prompt inside every "response" given to step() and
            # inside the text that is scored. Only the newly generated tokens are kept.
            # [0] drops the batch dimension only; a bare .squeeze() would also collapse
            # a one-token response into a 0-d tensor.
            response_tensor = ppo_trainer.generate(
                query_tensor,
                return_prompt=False,
                max_new_tokens=MAX_NEW_TOKENS,
                pad_token_id=tokenizer.eos_token_id
            )[0]

            response_tensors.append(response_tensor)

            # C. Decode for Reward Calculation
            # detokenize to text (TODO: will need this response text to send to get_reward_score to calculate our reward score)
            response_txt = tokenizer.decode(response_tensor, skip_special_tokens=True)
            decoded_responses.append(response_txt)

        # D. Calculate Rewards
        # one scalar tensor per response, created directly on the training device
        rewards = [
            torch.tensor(get_reward_score(response_txt), device=device)
            for response_txt in decoded_responses
        ]

        # E. PPO Step
        # pass the lists of tensors to the trainer
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        # print stats every second batch
        if i % (BATCH_SIZE * 2) == 0:
            print(f"Step {i//BATCH_SIZE}: Mean Reward: {stats['ppo/mean_scores']:.4f}")

Training on device: cuda
Starting training...


  8%|▊         | 1/13 [00:03<00:44,  3.72s/it]

Step 0: Mean Reward: 0.6250


 23%|██▎       | 3/13 [00:11<00:38,  3.87s/it]

Step 2: Mean Reward: 0.4375


 38%|███▊      | 5/13 [00:19<00:31,  3.89s/it]

Step 4: Mean Reward: 0.4375


 54%|█████▍    | 7/13 [00:27<00:23,  3.90s/it]

Step 6: Mean Reward: 0.1875


 69%|██████▉   | 9/13 [00:35<00:15,  3.92s/it]

Step 8: Mean Reward: 0.3750


 85%|████████▍ | 11/13 [00:42<00:07,  3.84s/it]

Step 10: Mean Reward: 0.5625


100%|██████████| 13/13 [00:46<00:00,  3.58s/it]
