# PPO Trainer for the Generally Sarcastic Transformer

## Packages

In [1]:

# Uninstall/reinstall TRL at a pinned, older version that still ships the legacy PPOTrainer used below

# remove earlier version of trl
!pip uninstall trl -y

# clear pip's cached trl wheels so the pinned version is downloaded fresh
!pip cache remove trl

# install older version of trl that allows for custom reward score (vs incorporating the reward model in the workflow)
# !pip install trl==0.11.4 --no-cache-dir --force-reinstall

# NOTE: v0.8.6 and v0.11.4 both seem to run on similar architecture,
# but v0.11.4 throws more errors, trying to push users to PPOv2,
# so for simplicity/stability, v0.8.6 may be preferred.
# The version-check cell below accepts either of these two versions.

!pip install trl==0.11.4
# !pip install trl==0.8.6



[0mFiles removed: 0
Collecting trl==0.11.4
  Downloading trl-0.11.4-py3-none-any.whl.metadata (12 kB)
Collecting tyro>=0.5.11 (from trl==0.11.4)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.11.4)
  Downloading shtab-1.8.0-py3-none-any.whl.metadata (7.3 kB)
Downloading trl-0.11.4-py3-none-any.whl (316 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.6/316.6 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tyro-0.9.35-py3-none-any.whl (132 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.6/132.6 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading shtab-1.8.0-py3-none-any.whl (14 kB)
Installing collected packages: shtab, tyro, trl
Successfully installed shtab-1.8.0 trl-0.11.4 tyro-0.9.35


In [2]:
import torch
import trl
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, pipeline, Pipeline, AutoModelForSequenceClassification

import torch
from datasets import Dataset

import random

from datasets import load_dataset

from tqdm import tqdm
import gc

from google.colab import userdata

In [3]:
# Confirm the installed TRL is one of the versions this notebook was written against
SUPPORTED_TRL_VERSIONS = ('0.11.4', '0.8.6')
print('TRL Version:', trl.__version__)
assert trl.__version__ in SUPPORTED_TRL_VERSIONS, (
    f"Unsupported TRL version {trl.__version__}; "
    f"install one of {SUPPORTED_TRL_VERSIONS} and restart the runtime"
)

TRL Version: 0.11.4


## Config

### Logins

In [4]:
# Feature flags for the optional integrations used by later cells
USE_DRIVE = False      # Save the model to Google Drive after training
USE_HUGGINGFACE = True # Push the model to the Hugging Face Hub after training
# Log PPO stats and generated samples to Weights & Biases
USE_WANDB = True

In [5]:
# mount google drive - specifically to save trained ppo model to
# NOTE: `drive_path` is only defined when USE_DRIVE is True; any cell that uses it
# must check USE_DRIVE first.
from google.colab import drive
if USE_DRIVE:
  drive.mount('/content/drive')
  drive_path = '/content/drive/MyDrive/'

In [6]:
# Hugging Face login, using the 'HF_TOKEN' value from Colab's userdata store
from huggingface_hub import login
from huggingface_hub import HfApi
if USE_HUGGINGFACE:
  # Hub namespace the trained model is pushed to - replace with your own account
  fh_username = "marcbishara"
  login(token=userdata.get('HF_TOKEN'))

  print(f"**************************\nUsing user: \"{fh_username}\" REPLACE WITH YOUR OWN")

**************************
Using user: "marcbishara" REPLACE WITH YOUR OWN


In [7]:
# wandb configuration
import wandb
# The run is started later with wandb.init(), just before the PPOTrainer is created.
# if USE_WANDB:
#   wandb.init()

In [25]:
# Initialize PPOConfig; stats are reported to Weights & Biases only when USE_WANDB is set.
log_with = "wandb" if USE_WANDB else None

config = PPOConfig(
    model_name='openai-community/gpt2',          # alternative: 'Zoe3324/gpt2-sft-full'
    learning_rate=1.41e-5,
    log_with=log_with,
    reward_model='tmrcnl/SarcasmRewardModel',    # alternative: 'marcbishara/SarcasmRewardModel'
    batch_size=64,
    ppo_epochs=2,
    steps=10000,                  # Default is 20000 (the manual loop below iterates EPOCHS instead)
    mini_batch_size=32,           # Default is 128
    gradient_accumulation_steps=1 # Default is 1
)





### Dataset

In [None]:
def build_sarcasm_dataset(
    config,
    dataset_name="marcbishara/sarcasm-on-reddit",
    split_name="ppo_train",
    min_text_length=10,
    num_of_rows=None,
    max_length=128,
):
    """Load, filter and tokenize the sarcasm dataset into PPO query prompts.

    Each row's ``parent_comment`` is wrapped in the prompt template
    ``"Parent comment: {parent}\\nSarcastic reply:"`` and tokenized with the
    tokenizer of ``config.model_name``.

    Args:
        config: object exposing ``model_name`` (e.g. the ``PPOConfig``).
        dataset_name: Hugging Face Hub dataset id.
        split_name: split of ``dataset_name`` to load.
        min_text_length: rows whose ``parent_comment`` has fewer characters than
            this are dropped.
        num_of_rows: if given, keep only the first ``num_of_rows`` rows after
            filtering (capped at the number of rows available).
        max_length: maximum query length in tokens; longer queries are truncated
            from the right.

    Returns:
        The ``datasets.Dataset`` with added ``input_ids``, ``attention_mask`` and
        ``query`` columns, formatted as torch tensors.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset(dataset_name, split=split_name)

    # Filter out short comments
    ds = ds.filter(lambda x: len(x["parent_comment"]) >= min_text_length)

    # Limit by number of rows if provided; `select` raises on out-of-range
    # indices, so never ask for more rows than exist.
    if num_of_rows is not None:
        ds = ds.select(range(min(num_of_rows, len(ds))))

    def tokenize(samples):
        """Batched map function: `samples` maps each column name to a list of values."""
        templated_queries = [
            f"Parent comment: {parent}\nSarcastic reply:"
            for parent in samples["parent_comment"]
        ]

        # No padding: queries are generated one at a time in the training loop,
        # and right-padding a decoder-only model would put pad/EOS tokens between
        # the prompt and the generated reply (and into the PPO query tensors).
        # NOTE: right-truncation can cut off the "Sarcastic reply:" suffix of
        # very long parent comments.
        enc = tokenizer(
            templated_queries,
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
        )

        samples["input_ids"] = enc["input_ids"]
        samples["attention_mask"] = enc["attention_mask"]
        samples["query"] = tokenizer.batch_decode(enc["input_ids"])
        return samples

    # batched=True is required: `tokenize` iterates over a list of parent
    # comments. With batched=False it would iterate over the characters of a
    # single comment string and emit one query per character.
    ds = ds.map(tokenize, batched=True)

    # Convert to torch tensors (variable-length rows stay per-row tensors)
    ds.set_format(type="torch")

    return ds

In [27]:
dataset = build_sarcasm_dataset(config)  # pass num_of_rows=N to keep only the first N filtered rows

Map:   0%|          | 0/267648 [00:00<?, ? examples/s]

In [11]:
# Smoke test on the dataset: class balance plus one sample row
assert len(dataset) > 0, "dataset is empty - check the filters in build_sarcasm_dataset"
# Reading the single label column is far cheaper than running `filter` over every tokenized row
label_column = torch.as_tensor(dataset["label"])
sarcastic_lbls_cnt = int((label_column == 1).sum())
print(f"Dataset length: {len(dataset)} with {round(sarcastic_lbls_cnt / len(dataset) * 100, 2)}% sarcastic comments")
print("Sample entry:")
print(dataset[min(15, len(dataset) - 1)])

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset length: 1000 with 49.4% sarcastic comments
Sample entry:
{'label': tensor(1), 'comment': 'How dare they try to make a profit, for shame!', 'author': 'Thenuclearwalrus', 'subreddit': 'wow', 'score': tensor(1), 'ups': tensor(-1), 'downs': tensor(-1), 'date': '2016-11', 'created_utc': '2016-11-15 12:53:43', 'parent_comment': 'The restriction is pointless and only serves to milk extra game time from you', 'input_ids': tensor([24546,  2912,    25,   383, 17504,   318, 27158,   290,   691,  9179,
          284,  7545,  3131,   983,   640,   422,   345,   198,    50,   283,
         2701,   291, 10971,    25]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'query': 'Parent comment: The restriction is pointless and only serves to milk extra game time from you\nSarcastic reply:'}


In [12]:
# Collate rows into per-column Python lists (no stacking) so variable-length input_ids survive batching
def collator(data):
    """Turn a list of row dicts into one dict mapping each column to the list of its values.

    Column order follows the keys of the first row. Raises IndexError on an
    empty batch.
    """
    first_row = data[0]
    return {column: [row[column] for row in data] for column in first_row}

### Models

In [13]:
# Model loaded twice: `model` is the policy updated by PPO, and `ref_model` is the
# reference copy used to calculate the KL divergence penalty.
# Both carry a value head, as TRL's PPOTrainer expects.

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Reuse EOS as the pad token (no dedicated pad token is set for this tokenizer)
tokenizer.pad_token = tokenizer.eos_token

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [14]:
class SarcasmRMPipeline(Pipeline):
    """Minimal `transformers.Pipeline` that scores (parent_comment, reply) pairs.

    Each input must be a 2-tuple of strings. The pair is tokenized as a sentence
    pair, run through the sequence-classification reward model, and the
    softmaxed logits are reported as the predicted label, its probability, and
    the per-class probabilities (index 0 = not sarcastic, index 1 = sarcastic).
    """

    def __init__(self, model, tokenizer):
        super().__init__(model=model, tokenizer=tokenizer)

    def _sanitize_parameters(self, **kwargs):
        # No call-time options: preprocess/forward/postprocess kwargs are all empty.
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize a (parent_comment, comment) tuple as a sentence pair."""
        if not (isinstance(inputs, tuple) and len(inputs) == 2):
            raise ValueError("Inputs must be a tuple of two strings: (parent_comment, comment)")
        parent, reply = inputs
        return self.tokenizer(
            parent,
            reply,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )

    def _forward(self, model_inputs):
        """Run the reward model on the inputs, moved to the model's device."""
        on_device = {name: tensor.to(self.model.device) for name, tensor in model_inputs.items()}
        return self.model(**on_device)

    def postprocess(self, model_outputs):
        """Convert the first item's logits into label / score / probabilities."""
        probs = model_outputs.logits.softmax(dim=-1).detach().cpu().numpy()[0]
        labels = ["not_sarcastic", "sarcastic"]
        best = probs.argmax()
        return {
            "label": labels[best],
            "score": float(probs.max()),
            "probabilities": {name: float(probs[idx]) for idx, name in enumerate(labels)},
        }

In [15]:
# sarcasm reward model, wrapped in the custom pair-scoring pipeline defined above
# NOTE(review): the tokenizer comes from the base 'distilbert-base-uncased' checkpoint,
# not from config.reward_model; presumably the reward model was fine-tuned from that
# checkpoint with its tokenizer unchanged - confirm.
rm_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model)
reward_model_pipe = SarcasmRMPipeline(model=reward_model, tokenizer=rm_tokenizer)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Device set to use cuda:0


In [16]:
# Smoke test the reward model on one known row (index 15) and compare with its true label
smoke_row = dataset[15]
text1, text2 = smoke_row["parent_comment"], smoke_row['comment']
rm_output = reward_model_pipe((text1, text2))
print(f"Feeding: {text1}, {text2} into reward model and getting back:\n{rm_output}\nTrue label is {dataset[15]['label']}")

Feeding: The restriction is pointless and only serves to milk extra game time from you, How dare they try to make a profit, for shame! into reward model and getting back:
{'label': 'sarcastic', 'score': 0.9908868074417114, 'probabilities': {'not_sarcastic': 0.009113184176385403, 'sarcastic': 0.9908868074417114}}
True label is 1


### Trainer config

In [28]:
# Run/model name encoding the base model and the main PPO hyper-parameters
ppo_model_name = (
    f"{config.model_name.split('/')[-1]}_"
    f"{config.ppo_epochs}Eps_"
    f"{config.steps}Stp_"
    f"{config.batch_size}bs_"
    f"{config.mini_batch_size}mbs_"
    f"{str(config.learning_rate).replace('.', '_')}lr"
)

print(f"Model name: {ppo_model_name}")

Model name: gpt2_2Eps_10000Stp_64bs_32mbs_1_41e-05lr


In [29]:
# Start the wandb run before the PPOTrainer (configured with log_with="wandb") is created
if USE_WANDB:
  # wandb.init(project="ppo-training", name=ppo_model_name) # PPO_Trainer hijacks wandb and forces the project name and run name
  wandb.init()

0,1
env/reward_mean,▄▅▄▄▁▃▄▂▂▃▂▇▅▇▄▆▃▅▆▆▆▄██▅▇▅▄▃▆
env/reward_std,▃▃▅▄▅▃▆▄▁▇▄▄▇▆▆▅█▄▅█▆█▇▆▆▆▇▆▆▇
objective/entropy,▃▇▃▆▄█▆█▇▄▃▂▂▅▄▁▃▂▆▄▃▅▃▄▅█▃▄▆▄
objective/kl,▁▂▂▃▃▄▄▆▆▆▆▇▇███▇▇▇███▇▆▆▇▇▆▇▇
objective/kl_coef,█▇▆▆▅▄▃▃▂▁▁▁▁▁▂▃▄▄▄▄▅▆▇▇▆▆▆▆▆▆
ppo/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
ppo/loss/policy,▁▅█▇▆▇█▇▆▇█▆▅▅▄▅▇▆▅▅▅▄▅▅▅▅▅▃▄▄
ppo/loss/total,█▇▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
ppo/loss/value,█▆▆▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
ppo/mean_non_score_reward,█▇▇▆▆▅▅▃▃▃▃▃▂▁▁▁▂▂▂▁▁▁▂▃▃▂▂▃▂▂

0,1
env/reward_mean,0.48701
env/reward_std,0.28512
objective/entropy,148.93912
objective/kl,5.98313
objective/kl_coef,0.19949
ppo/learning_rate,1e-05
ppo/loss/policy,-0.02614
ppo/loss/total,-0.00172
ppo/loss/value,0.24417
ppo/mean_non_score_reward,-0.03741


In [30]:
# initialize PPOTrainer with the policy, the reference model, the tokenized dataset
# and the list-based collator (keeps variable-length tensors as per-row lists)
ppo_trainer = PPOTrainer(
    model=model,
    ref_model=ref_model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
    data_collator=collator
)



In [31]:
# `device` is only reported here; it is not passed to generation or to the reward pipeline.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
print(f"Training on device: {device}")

# Sampling settings for ppo_trainer.generate
# see https://huggingface.co/docs/trl/v0.8.6/ppo_trainer
generation_kwargs = {
    'min_length': -1,  # no minimum length, so EOS is never suppressed
    'top_k': 0,  # 0 disables top-k filtering (an int, as GenerationConfig expects)
    'top_p': 1.0,  # 1.0 disables nucleus filtering
    'do_sample': True,  # sample from the full distribution
    'pad_token_id': tokenizer.eos_token_id,  # most decoder models don't have a padding token - use EOS token instead
    'max_new_tokens': 32,  # maximum number of generated tokens per reply
}

# How often (in batches) to print / upload logging artifacts
LOG_INTERVAL = 1

# Print reward-model outputs inside the training loop
DEBUG = False

# Passes over the dataloader in the manual training loop. This is not the same as
# config.ppo_epochs (per TRL docs, the optimisation passes ppo_trainer.step makes per batch).
EPOCHS = 2


Training on device: 0


## Training

In [32]:
# Free memory before training: run the garbage collector first so objects kept alive
# only by reference cycles (including CUDA tensors) are released, then hand the now
# unused cached CUDA blocks back to the driver.
collected_objects = gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Number of unreachable objects found by the collector
collected_objects

408

### Training loop

In [33]:
# Revised PPO training loop.
# For each batch: generate one reply per query with the policy, score the
# (parent comment, reply) pairs with the sarcasm reward model, and run one PPO step.


def extract_parent_comment(templated_query):
    """Strip the "Parent comment: ...\\nSarcastic reply:" template from a query.

    Falls back to the whitespace-stripped query when the "Sarcastic reply:"
    marker is missing (e.g. it was truncated away during tokenization).
    """
    if "Sarcastic reply:" not in templated_query:
        return templated_query.strip()
    parent_text = templated_query.split("Sarcastic reply:")[0]
    return parent_text.replace("Parent comment:", "").strip()


print("Starting training...")
print(f"Number of batches per epoch: {len(ppo_trainer.dataloader)}")

if USE_WANDB:
    all_samples_table = wandb.Table(columns=["query", "response", "reward"], log_mode="MUTABLE")

for epoch in tqdm(range(EPOCHS), desc="Epoch"):
    # The description is built per epoch, so it actually reflects the current epoch.
    batch_bar = tqdm(ppo_trainer.dataloader, desc=f"Epoch {epoch + 1} batches", leave=False)
    for i, batch in enumerate(batch_bar):

        #### Get response from Policy model
        query_tensors = []
        response_tensors = []
        for query, mask in zip(batch['input_ids'], batch['attention_mask']):
            # Drop padding positions so neither generation nor the PPO step treats
            # pad tokens as part of the prompt (a no-op for unpadded queries).
            query = query[mask.bool()]
            query_response = ppo_trainer.generate(
                query,
                attention_mask=torch.ones_like(query).unsqueeze(0),
                **generation_kwargs
            ).squeeze(0)
            query_tensors.append(query)
            # generate() returns prompt + continuation; keep only the continuation.
            # (Slicing from len(query) stays correct even if nothing was generated.)
            response_tensors.append(query_response[len(query):])

        batch["response"] = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

        # The reward model scores (parent_comment, comment) pairs, so remove the
        # prompt template that was added to guide the policy.
        batch['query'] = [extract_parent_comment(q) for q in batch["query"]]

        # process the batch through reward model pipe
        batch_inputs = list(zip(batch['query'], batch['response']))
        with torch.no_grad():
            rm_pipe_outputs = reward_model_pipe(batch_inputs, batch_size=len(batch_inputs))

        if DEBUG:
            print(f"Sample output from reward model: {rm_pipe_outputs[0]}")

        # Reward = probability of the "sarcastic" class
        # TODO: We need to add more reward signals
        rewards = [torch.tensor(output["probabilities"]["sarcastic"]) for output in rm_pipe_outputs]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

        ## This requires a wandb login
        if USE_WANDB:
            log_batch = {"query": batch["query"], "response": batch["response"]}
            ppo_trainer.log_stats(stats, log_batch, rewards)
            # Add rows to the persistent table
            for q, r, rew in zip(batch["query"], batch["response"], rewards):
                all_samples_table.add_data(q, r, float(rew))
            # Re-uploading the ever-growing table is expensive, so only do it every LOG_INTERVAL batches
            if i % LOG_INTERVAL == 0:
                wandb.log({"all_samples": all_samples_table})


print('Training complete DON\'T FORGET TO SAVE THE MODEL')

Starting training...
Number of batches per epoch: 4182


Epoch: 1:   0%|          | 0/2 [00:00<?, ?it/s]
Batch: 1:   0%|          | 0/4182 [00:00<?, ?it/s][A
Batch: 1:   0%|          | 1/4182 [00:24<28:46:36, 24.78s/it][A
Batch: 1:   0%|          | 2/4182 [00:49<29:03:34, 25.03s/it][A
Batch: 1:   0%|          | 3/4182 [01:14<29:01:25, 25.00s/it][A
Batch: 1:   0%|          | 4/4182 [01:38<28:31:59, 24.59s/it][A
Batch: 1:   0%|          | 5/4182 [02:03<28:30:36, 24.57s/it][A
Batch: 1:   0%|          | 6/4182 [02:27<28:16:58, 24.38s/it][A
Batch: 1:   0%|          | 7/4182 [02:51<28:04:52, 24.21s/it][A
Batch: 1:   0%|          | 8/4182 [03:15<28:01:16, 24.17s/it][A
Batch: 1:   0%|          | 9/4182 [03:39<28:01:51, 24.18s/it][A
Batch: 1:   0%|          | 10/4182 [04:03<27:48:32, 24.00s/it][A
Batch: 1:   0%|          | 11/4182 [04:27<27:54:58, 24.09s/it][A
Batch: 1:   0%|          | 12/4182 [04:50<27:41:44, 23.91s/it][A
Batch: 1:   0%|          | 13/4182 [05:14<27:39:35, 23.88s/it][A
Batch: 1:   0%|          | 14/4182 [05:38<27:37:1

KeyboardInterrupt: 

In [None]:
## End the logging
# Finish the wandb run started before training
if USE_WANDB:
  wandb.finish()

In [None]:
#### Save model
# In all cases save a local copy on the Colab VM (lost when the runtime disconnects)
local_save_dir = "/content/" + ppo_model_name
ppo_trainer.save_pretrained(local_save_dir)
print('Model saved to Colab - This goes away when you disconnect colab')

if USE_DRIVE:
    # drive_path ends with '/', so concatenation yields a per-run sub-folder
    ppo_trainer.save_pretrained(drive_path + ppo_model_name)
    print('Model saved to drive')

if USE_HUGGINGFACE:
    # One repo per user; each training run is uploaded to its own branch named after the run
    api = HfApi()
    repo_id = fh_username + "/GenerallySarcasticTransformer"
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.create_branch(
        repo_id=repo_id,
        branch=ppo_model_name,
        repo_type="model",
        exist_ok=True,
    )

    # Upload exactly the folder saved locally above
    api.upload_folder(
        folder_path=local_save_dir,
        repo_id=repo_id,
        repo_type="model",
        revision=ppo_model_name,
    )
    print('Model saved to hugging face')

### Sanity check manual training run

This section runs through the steps of the training loop one cell at a time as a sanity check. It is only intended for debugging.

In [None]:
# Sanity check that the PPO dataloader yields batches built from our dataset

first_batch = next(iter(ppo_trainer.dataloader))
print("Items per batch:", len(first_batch["input_ids"]))
print(f"Number of batches: {len(ppo_trainer.dataloader)}")
print("First input_ids:", first_batch["input_ids"][0])

# Confirm the dataloader contains as many items as dataset
# assert len(dataset) == len(ppo_trainer.dataloader.dataset)

In [None]:
epoch, batch = 0, next(iter(ppo_trainer.dataloader))  # one batch to step through by hand; `epoch` is just 0

In [None]:
# Pull the tokenized queries and their attention masks out of the batch
query_tensors, attention_masks = batch['input_ids'], batch['attention_mask']

In [None]:
# Number of queries in this batch
len(query_tensors)

In [None]:
# Inspect the first query and its attention mask (zeros would mark padding positions)
print(query_tensors[0])
print(attention_masks[0])

In [None]:
# Generate one reply per query in the batch (mirrors the training loop)
response_tensors = []
for query, mask in tqdm(zip(query_tensors, attention_masks), total=len(query_tensors)):
    query_response = ppo_trainer.generate(
        query,
        attention_mask=mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze()
    # generate() returns prompt + continuation; keep only the continuation.
    # (Slicing from len(query) stays correct even if nothing was generated.)
    response_tensors.append(query_response[len(query):])

In [None]:
# Last query response: prompt + generated continuation from the final iteration of the generation loop
query_response

In [None]:
# Last query in the batch (index -1 works for any batch size; 31 was only last for 32-item batches)
query_tensors[-1]

In [None]:
# Last query response - the query (generated continuation only)
print(response_tensors[-1])

In [None]:
# Decode the generated continuations to text, dropping special tokens such as EOS
batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors]

In [None]:
# Check the type of the stored query text
type(batch['query'][0])

In [None]:
batch["response"][-1]  # decoded response for the last query in the batch

In [None]:
# Full decoded text (prompt + reply) of the last generated sequence
tokenizer.decode(query_response.squeeze(), skip_special_tokens=True)

In [None]:
# Strip the prompt template so the reward model sees only the parent comment
# (must match the "Parent comment: ...\nSarcastic reply:" template used when tokenizing)
clean_queries = []
for q in batch["query"]:
    if "Sarcastic reply:" in q:
        # Keep the text before "Sarcastic reply:" and drop the "Parent comment:" prefix.
        # (Stripping "Parent:" never matched this template and left the prefix in place.)
        parent_text = q.split("Sarcastic reply:")[0]
        parent_text = parent_text.replace("Parent comment:", "").strip()
        clean_queries.append(parent_text)
    else:
        # Fallback if template not found (e.g. truncated away)
        clean_queries.append(q.strip())

batch['query'] = clean_queries

In [None]:
batch['query'][-1]  # cleaned parent comment for the last query in the batch

In [None]:
batch_inputs = list(zip(batch['query'], batch['response']))  # (parent_comment, reply) tuples for the reward pipeline

In [None]:
# Score the pairs with the reward model, in pipeline batches of at most 8
with torch.no_grad():
      rm_pipe_outputs = reward_model_pipe(batch_inputs, batch_size=min(len(batch_inputs), 8))

In [None]:
# Reward model output for the first (parent_comment, reply) pair
rm_pipe_outputs[0]

In [None]:
# Reward = probability the reward model assigns to the "sarcastic" class
rewards = [torch.tensor(pipe_out["probabilities"]["sarcastic"]) for pipe_out in rm_pipe_outputs]

In [None]:
# Average reward over the batch (a 0-d tensor)
sum(rewards) / len(rewards)

In [None]:
# Run a single PPO step on this hand-built batch
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

# Scratchpad

In [None]:
from dataclasses import fields
# List PPOConfig's dataclass field names and check whether `eval_steps` is among them
ppo_config_field_names = [f.name for f in fields(PPOConfig)]
print(ppo_config_field_names)
print("eval_steps" in ppo_config_field_names)

In [None]:
# model set up
# (PPO requires a model with a value head)
# PPO also requires a reference model; if none is passed to PPOTrainer it creates one itself
# NOTE: running this cell overwrites the `model` and `tokenizer` used by the trainer above,
# and this tokenizer pads on the left (padding_side='left').
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token


In [None]:
# load training data

# An earlier prototype used the IMDb dataset:
# imdb_dataset = load_dataset('imdb')
sarcasm_train_dataset = load_dataset("marcbishara/sarcasm-on-reddit")['ppo_train']

# use a subset of dataset for the POC so it doesn't run for hours
# taking the first 200 examples for demonstration
# NOTE: this overwrites the notebook-level `dataset` built for training
dataset = sarcasm_train_dataset.select(range(200))


In [None]:
# tokenize the dataset
# NOTE: `tokenize` is defined in the next cell - run that cell first on a fresh kernel
dataset = dataset.map(tokenize, batched=False)

# cast input_ids as torch tensors; after this only the input_ids column is exposed
dataset.set_format(type='torch', columns=['input_ids'])

In [None]:

def tokenize(sample, tok=None):
    """Tokenize one dataset row (for a non-batched `map`) into `input_ids`.

    Args:
        sample: a single row mapping. The text is read from its ``text`` column
            when present (IMDb-style data) and otherwise from ``parent_comment``
            (the sarcasm dataset loaded in this scratchpad has no ``text``).
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The same row with an ``input_ids`` list of at most 128 token ids.
    """
    tok = tokenizer if tok is None else tok
    text = sample['text'] if 'text' in sample else sample['parent_comment']
    sample['input_ids'] = tok.encode(text, max_length=128, truncation=True)
    # 'query' is rebuilt from input_ids in the training loop instead of here
    return sample


In [None]:
# custom reward function
# CURRENTLY REPLACED BY DIRECT CALL WITHIN THE TRAINING LOOP

def get_reward_score(query_text, response_text):
    # TODO: replace this with our weighted sum reward score from multiple reward signals
    # based on the query_text and response_text parameters

    # print query and respone
    # print(f"Query: {query_text} | Response: {response_text}")

    # currently, just randomly 0 or 1
    score = float(random.randint(0, 1))

    return score



In [None]:
# Legacy PPO training loop (superseded by the "Training loop" cell above).
# Kept runnable against the notebook's current globals: EPOCHS, reward_model_pipe,
# generation_kwargs, LOG_INTERVAL, ppo_model_name.

print("Starting training...")
print(f"Number of batches per epoch: {len(ppo_trainer.dataloader)}")

for epoch in tqdm(range(EPOCHS), 'epoch: '):
    for i, batch in tqdm(enumerate(ppo_trainer.dataloader), total=len(ppo_trainer.dataloader)):

        # get query_tensors as tensors
        query_tensors = batch['input_ids']

        # Rebuild the 'query' text from input_ids
        batch['query'] = [tokenizer.decode(q_t, skip_special_tokens=True) for q_t in query_tensors]

        #### Get response from the policy; return_prompt=False keeps only the generated continuation
        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors]

        #### Rewards from the sarcasm reward model, scored on (parent comment, reply) pairs.
        # Strip the "Parent comment: ...\nSarcastic reply:" template first.
        parents = [
            q.split("Sarcastic reply:")[0].replace("Parent comment:", "").strip()
            for q in batch['query']
        ]
        batch_inputs = list(zip(parents, batch['response']))
        with torch.no_grad():
            pipe_outputs = reward_model_pipe(batch_inputs, batch_size=len(batch_inputs))

        # Use the probability of the "sarcastic" class. The pipeline's 'score' is the
        # probability of the *predicted* label, so it would also reward confident non-sarcasm.
        # TODO: add other reward signals and combine them (weighted sum, normalized)
        rewards = [torch.tensor(output["probabilities"]["sarcastic"]) for output in pipe_outputs]

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

        # logging code
        if i % LOG_INTERVAL == 0:
            # mean reward for this batch as reported by the PPO stats
            print(f"Step {i}: Mean Reward from PPO stats: {stats['ppo/mean_scores']:.4f}")
            print(f"        PPO Loss:    {stats['ppo/loss/total']:.4f}")

print('Training complete')

#### Save model into a per-run folder (drive_path only exists when USE_DRIVE is set)
ppo_trainer.save_pretrained((drive_path if USE_DRIVE else "/content/") + ppo_model_name)

print('Model saved')