# PPO Trainer for the Generally Sarcastic Transformer

## Packages

In [None]:

# uninstalls/installs for deprecated version of TRL

# remove earlier version of trl
!pip uninstall trl -y

# clear cache
!pip cache remove trl

# install older version of trl that allows for custom reward score (vs incorporating the reward model in the workflow)
# !pip install trl==0.11.4 --no-cache-dir --force-reinstall

# NOTE: v0.8.6 and v0.11.4 both seem to run on similar architecture
# but v0.11.4 throws more errors, trying to push users to PPOv2
# so for simplicity/stability, v0.8.6 may be preferred

!pip install trl==0.11.4
# !pip install trl==0.8.6



[0mFiles removed: 0
Collecting trl==0.11.4
  Downloading trl-0.11.4-py3-none-any.whl.metadata (12 kB)
Collecting tyro>=0.5.11 (from trl==0.11.4)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.11.4)
  Downloading shtab-1.8.0-py3-none-any.whl.metadata (7.3 kB)
Downloading trl-0.11.4-py3-none-any.whl (316 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.6/316.6 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tyro-0.9.35-py3-none-any.whl (132 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.6/132.6 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading shtab-1.8.0-py3-none-any.whl (14 kB)
Installing collected packages: shtab, tyro, trl
Successfully installed shtab-1.8.0 trl-0.11.4 tyro-0.9.35


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from datasets import Dataset
from datasets import load_dataset

import trl
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer, pipeline, Pipeline, AutoModelForSequenceClassification

import random
import os
import gdown

from tqdm import tqdm
import gc

from google.colab import userdata

import re
from collections import Counter

In [None]:
# confirm TRL install
print('TRL Version:', trl.__version__)
assert trl.__version__ in ('0.11.4','0.8.6')

TRL Version: 0.11.4


## Config

### Logins

In [None]:
USE_DRIVE = False      # To save the model after training
USE_HUGGINGFACE = True # To save the model after training
USE_WANDB = True

In [None]:
# mount google drive - specifically to save trained ppo model to
from google.colab import drive
if USE_DRIVE:
  drive.mount('/content/drive')
  drive_path = '/content/drive/MyDrive/'

In [None]:
# Hugging face login
from huggingface_hub import login
from huggingface_hub import HfApi
if USE_HUGGINGFACE:
  fh_username = "marcbishara"
  login(token=userdata.get('HF_TOKEN'))

  print(f"**************************\nUsing user: \"{fh_username}\" REPLACE WITH YOUR OWN")

**************************
Using user: "marcbishara" REPLACE WITH YOUR OWN


In [None]:
# wandb configuration
import wandb
# if USE_WANDB:
#   wandb.init()

In [None]:
# Temporary run name; a descriptive name is computed once the trainer exists.
ppo_run_name = "placehoder_ppo_run_name"

# Report to Weights & Biases only when it is enabled in the login flags.
log_with = "wandb" if USE_WANDB else None

# initialize PPOConfig
config = PPOConfig(
    model_name='Zoe3324/gpt2-sft-full',  # alternative: 'openai-community/gpt2'
    reward_model='tmrcnl/SarcasmRewardModel',  # alternative: 'marcbishara/SarcasmRewardModel'
    learning_rate=1.41e-5,
    log_with=log_with,
    steps=10000,                    # Default is 20000
    batch_size=128,
    mini_batch_size=32,             # Default is 128
    gradient_accumulation_steps=1,  # Default is 1
    ppo_epochs=2,
    tracker_kwargs={
        "wandb": {
            # replace with your WandB entity/team
            "entity": "tmrcnl-university-of-toronto",
            # "project": "trl",  # replace with your WandB project
            "name": ppo_run_name,
        }
    },
)





### Dataset

In [None]:
def build_sarcasm_dataset(
    config,
    dataset_name="marcbishara/sarcasm-on-reddit",
    split_name="ppo_train",
    min_text_length=10,
    num_of_rows=None
):
    """Load the sarcasm dataset and tokenize templated PPO queries.

    Each ``parent_comment`` is wrapped as
    ``<PARENT> {parent} </PARENT>\\n<RESPONSE>`` and tokenized with the
    tokenizer of ``config.model_name`` (EOS used as pad token, left padding,
    padded/truncated to 128 tokens).

    Args:
        config: Object exposing ``model_name`` (used to load the tokenizer).
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split_name: Split passed to ``load_dataset``.
        min_text_length: Rows whose ``parent_comment`` has fewer characters
            than this are dropped.
        num_of_rows: If given, keep at most this many rows (taken from the
            start of the filtered dataset). Values larger than the number of
            available rows keep everything instead of raising.

    Returns:
        The dataset in torch format with ``input_ids``, ``attention_mask``
        and ``query`` columns added to the original columns.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left' # initialize tokenizer with left padding

    ds = load_dataset(dataset_name, split=split_name)

    # Filter out short comments
    ds = ds.filter(lambda x: len(x["parent_comment"]) >= min_text_length)

    # Limit by number of rows if provided. The filter above makes the number
    # of remaining rows unpredictable, so clamp instead of indexing past the end.
    if num_of_rows is not None:
        ds = ds.select(range(min(num_of_rows, len(ds))))

    # batch tokenize function
    def tokenize(samples):
        # create a list of templated strings
        templated_queries = [
            f"<PARENT> {parent} </PARENT>\n<RESPONSE>"
            for parent in samples['parent_comment']
        ]

        # tokenize the whole list at once
        # NOTE(review): truncation presumably cuts from the right by default,
        # which would drop the trailing "<RESPONSE>" tag for very long parent
        # comments - confirm and consider truncating the parent text instead.
        enc = tokenizer(
            templated_queries,
            truncation=True,
            max_length=128,
            padding='max_length',
            return_attention_mask=True
        )

        samples["input_ids"] = enc["input_ids"]
        samples["attention_mask"] = enc["attention_mask"]
        # use batch_decode for speed
        samples["query"] = tokenizer.batch_decode(enc["input_ids"])

        return samples

    # Apply tokenization
    ds = ds.map(tokenize, batched=True)

    # Convert to torch tensors
    ds.set_format(type="torch")
    # ds.set_format(type="torch", columns=["input_ids", "attention_mask", "query"])

    return ds

In [None]:
dataset = build_sarcasm_dataset(config, num_of_rows=10000) #If you don't want to run the full dataset, limit the number of rows

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
# Smoke test on the dataset
sarcastic_lbls_cnt = dataset.filter(lambda x: x["label"] == 1).num_rows
print(f"Dataset length: {len(dataset)} with {round(sarcastic_lbls_cnt / len(dataset) * 100, 2)}% sarcastic comments")
print("Sample entry:")
print(dataset[15])

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset length: 1000 with 49.4% sarcastic comments
Sample entry:
{'label': tensor(1), 'comment': 'How dare they try to make a profit, for shame!', 'author': 'Thenuclearwalrus', 'subreddit': 'wow', 'score': tensor(1), 'ups': tensor(-1), 'downs': tensor(-1), 'date': '2016-11', 'created_utc': '2016-11-15 12:53:43', 'parent_comment': 'The restriction is pointless and only serves to milk extra game time from you', 'input_ids': tensor([50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256,

In [None]:
# use lambda collator to ensure 'input_ids' are stacked correctly
def collator(data):
    """Turn a list of sample dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first sample; values are
    left un-stacked so each entry keeps its own tensor/object.
    """
    first_sample = data[0]
    return {key: [sample[key] for sample in data] for key in first_sample}

### Models

In [None]:
# Model loaded twice, the first will be updated on policy and the second is used to calculate KL divergence

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token



#### Sarcasm RM

In [None]:
class SarcasmRMPipeline(Pipeline):
    """Pipeline that scores a ``(parent_comment, comment)`` pair for sarcasm.

    The result is a dict with the winning ``label``, its ``score`` and the
    per-label ``probabilities``.
    """

    def __init__(self, model, tokenizer):
        super().__init__(model=model, tokenizer=tokenizer)

    def _sanitize_parameters(self, **kwargs):
        # No call-time options are forwarded to any of the three stages.
        preprocess_kwargs, forward_kwargs, postprocess_kwargs = {}, {}, {}
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(self, inputs):
        # Only a 2-tuple (parent_comment, comment) is accepted.
        is_pair = isinstance(inputs, tuple) and len(inputs) == 2
        if not is_pair:
            raise ValueError("Inputs must be a tuple of two strings: (parent_comment, comment)")

        parent, reply = inputs
        # Parent and reply are encoded together as a sentence pair.
        return self.tokenizer(
            parent,
            reply,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )

    def _forward(self, model_inputs):
        # Put every input tensor on the model's device before the forward pass.
        target_device = self.model.device
        on_device = {}
        for name, tensor in model_inputs.items():
            on_device[name] = tensor.to(target_device)
        return self.model(**on_device)

    def postprocess(self, model_outputs):
        # Softmax over the class dimension; only the first row is reported.
        probs = model_outputs.logits.softmax(dim=-1).detach().cpu().numpy()[0]
        # index 0 = non-sarcasm, index 1 = sarcasm
        labels = ["not_sarcastic", "sarcastic"]
        per_label = {name: float(probs[idx]) for idx, name in enumerate(labels)}
        return {
            "label": labels[probs.argmax()],
            "score": float(probs.max()),
            "probabilities": per_label
        }

In [None]:
# sarcasm reward model
rm_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model)
reward_model_pipe = SarcasmRMPipeline(model=reward_model, tokenizer=rm_tokenizer)

Device set to use cuda:0


In [None]:
# Smoke test the reward model

text1 = dataset[15]["parent_comment"]
text2 = dataset[15]['comment']
rm_output = reward_model_pipe((text1, text2))
print(f"Feeding: {text1}, {text2} into reward model and getting back:\n{rm_output}\nTrue label is {dataset[15]['label']}")

Feeding: The restriction is pointless and only serves to milk extra game time from you, How dare they try to make a profit, for shame! into reward model and getting back:
{'label': 'sarcastic', 'score': 0.9908868074417114, 'probabilities': {'not_sarcastic': 0.009113193489611149, 'sarcastic': 0.9908868074417114}}
True label is 1


#### Objectivity RM

In [None]:
# Objectivity Reward Signal
class objectivity_classifier(torch.nn.Module):
    """Two-branch text CNN over frozen pretrained embeddings.

    Args:
        embeddings: 2-D float tensor of pretrained word vectors (one row per word).
        k1, k2: Number of consecutive words covered by each convolution branch.
        n1, n2: Number of filters in each convolution branch.

    ``forward`` returns one raw logit per input sequence, shape ``(batch, 1)``.
    """

    def __init__(self, embeddings, k1, k2, n1, n2):
        super().__init__()

        embedding_dim = len(embeddings[0])
        self.embeddings = nn.Embedding.from_pretrained(embeddings, freeze=True)

        # Each kernel spans the full embedding width, so it slides over words only.
        self.conv1 = nn.Conv2d(1, n1, (k1, embedding_dim), bias=False)
        self.conv2 = nn.Conv2d(1, n2, (k2, embedding_dim), bias=False)
        self.fc = nn.Linear(n1 + n2, 1)

    def forward(self, x):
        # (batch, num_words) -> (batch, 1, num_words, em_dim)
        embedded = self.embeddings(x).unsqueeze(1)

        pooled_features = []
        for conv in (self.conv1, self.conv2):
            feature_map = F.relu(conv(embedded))                       # (batch, n, L, 1)
            # Max over every position the kernel visited
            feature_map = F.max_pool2d(feature_map, (feature_map.shape[2], 1))  # (batch, n, 1, 1)
            pooled_features.append(feature_map.flatten(start_dim=1))  # (batch, n)

        # Concatenate both branches and map to a single logit
        return self.fc(torch.cat(pooled_features, dim=1))

def load_glove_vectors(glove_path, vocab_size=None):
    """Read a GloVe text file into lookup tables and an embedding matrix.

    Every non-blank line is expected to be ``word v1 v2 ... vd`` separated by
    whitespace.

    Args:
        glove_path: Path to the GloVe ``.txt`` file (read as UTF-8).
        vocab_size: If truthy, stop after this many words have been loaded.

    Returns:
        ``(word2idx, idx2word, embeddings)`` where ``word2idx`` maps a word to
        its row, ``idx2word`` is the list of words in row order and
        ``embeddings`` is a float32 tensor with one row per loaded word.
    """

    print(f"Loading GloVe vectors from {glove_path}...")

    word2idx = {}
    idx2word = []
    vectors = []

    with open(glove_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            # Count loaded words (not file lines) so skipped lines cannot
            # shift the row indices or eat into the vocabulary budget.
            if vocab_size and len(idx2word) >= vocab_size:
                break

            values = line.strip().split()
            if not values:
                # Blank line (e.g. trailing newline at end of file): nothing to load
                continue

            word = values[0]
            vector = np.array(values[1:], dtype='float32')

            # Row index == position in idx2word/vectors, by construction
            word2idx[word] = len(idx2word)
            idx2word.append(word)
            vectors.append(vector)

    embeddings = torch.from_numpy(np.array(vectors))

    print(f"Loaded {len(word2idx)} words with dimension {embeddings.shape[1]}")

    return word2idx, idx2word, embeddings

# This function only needs to run once to fetch the trained model parameters.
def download_objectivityRM():
    """Fetch the objectivity CNN checkpoint and the GloVe vectors from Google Drive.

    Files are stored in ``objectivity_signal/``; a file that already exists
    locally is not downloaded again.
    """
    MODEL_DIR = "objectivity_signal"
    os.makedirs(MODEL_DIR, exist_ok=True)

    GLOVE_FILE_ID = "1ufQLwedjFzjmRej-Qfp0MyM2iOH6oP9U"
    MODEL_FILE_ID = "1EGvEGgZwJVJLBWjQcGWCfYV-kp9QrZcv"

    # (Drive file id, local file name, message when downloading, message when present)
    # Order matters only for the printed log: model first, then GloVe.
    assets = (
        (MODEL_FILE_ID, "model_CNN_objectivity_classifier.pt",
         "Downloading model...", "Model already exists!"),
        (GLOVE_FILE_ID, "glove.6B.100d.txt",
         "Downloading GloVe...", "GloVe already exists!"),
    )

    for file_id, file_name, downloading_msg, present_msg in assets:
        local_path = os.path.join(MODEL_DIR, file_name)
        if os.path.exists(local_path):
            print(present_msg)
            continue

        url = f"https://drive.google.com/uc?id={file_id}"
        print(downloading_msg)
        gdown.download(url, local_path, quiet=False)

# Download objectivity signal model
download_objectivityRM()
# Load models and other dependencies.
embeddings_path = "./objectivity_signal/glove.6B.100d.txt"
word2idx, idx2word, embeddings = load_glove_vectors(embeddings_path)
model_CNN = objectivity_classifier(embeddings, k1=2, k2=4, n1=100, n2=100)
model_path = "./objectivity_signal/model_CNN_objectivity_classifier.pt"
# map_location="cpu": the classifier is never moved off the CPU in this
# notebook, and this also lets a checkpoint saved on a GPU load on a
# CPU-only runtime.
model_CNN.load_state_dict(torch.load(model_path, map_location="cpu"))
# The model is only used for scoring, so keep it in inference mode.
model_CNN.eval()

def objectivity_reward(sentence):
    """Score a sentence with the objectivity CNN.

    Punctuation is removed, the text is lower-cased and split on whitespace,
    and each word is mapped to its GloVe index (unknown words map to the last
    index of the vocabulary).

    Args:
        sentence: Raw text to score.

    Returns:
        The sigmoid of the classifier logit, rounded to 4 decimal places.
    """
    sentence = re.sub(r'[^\w\s]', '', sentence)
    V = len(word2idx)

    # Four extra index-0 tokens are appended so that even an empty sentence is
    # at least as long as the widest convolution kernel (k2=4).
    tokens = torch.tensor(
        [word2idx.get(word, V-1) for word in sentence.lower().split()] + [0]*4,
        dtype=torch.long
    ).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping so no graph is built per call.
    with torch.no_grad():
        prob = torch.sigmoid(model_CNN(tokens)).squeeze(0).squeeze(0)

    reward = round(prob.item(), 4) # Keep 4 decimal places

    return reward

In [None]:
# Verify the signal works
sentence = "I feel happy"
prob = objectivity_reward(sentence)
print(prob)

#### Repetition RM

In [None]:
# repetition penalty reward signal

def repetition_penalty(text, max_repetition_ratio=0.2):
    """Penalize text whose share of repeated words exceeds a tolerance.

    Words are the ``\\w+`` runs of the lower-cased text. Every occurrence of a
    word after its first counts as a repetition.

    Args:
        text: Text to score.
        max_repetition_ratio: Fraction of repeated words tolerated for free.

    Returns:
        A float: ``0.0`` when the text is empty or the repetition ratio is
        within the tolerance, otherwise the negative amount by which the
        ratio exceeds it.
    """
    tokens = re.findall(r"\w+", text.lower())
    if not tokens:
        return 0.0

    # Occurrences beyond the first of each distinct word
    repeated = len(tokens) - len(set(tokens))
    repetition_ratio = repeated / len(tokens)

    # min(0.0, ...) keeps the result a float on the no-penalty path as well
    return min(0.0, max_repetition_ratio - repetition_ratio)

# length penalty reward signal

def length_penalty(text, min_len=5, max_len=100):
    """Penalize responses that are empty, too short or too long.

    Length is the number of ``\\w+`` runs in ``text``.

    Returns:
        ``-1.0`` for text without any word, a negative value proportional to
        the shortfall below ``min_len`` or to the excess above ``max_len``,
        and ``0.0`` when the length lies within ``[min_len, max_len]``.
    """
    length = len(re.findall(r"\w+", text))

    if length == 0:
        penalty = -1.0
    elif length < min_len:
        # proportional to how short it is
        penalty = -((min_len - length) / min_len)
    elif length > max_len:
        # proportional to how long it is
        penalty = -((length - max_len) / max_len)
    else:
        penalty = 0.0

    return penalty

### Trainer config

In [None]:
# PPO Trainer in next cell will overwrite this and force the default

# if USE_WANDB:
#   # wandb.init(project="ppo-training", name=ppo_run_name) # PPO_Trainer hijacks wandb and forces the project name and run name
#   wandb.init()

In [None]:
# initialize PPOTrainer
ppo_trainer = PPOTrainer(
    model=model,
    ref_model=ref_model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
    data_collator=collator
)



In [None]:
                  # str(config.steps) + "Stp" + "_" + \
                  # str(config.mini_batch_size) + "mbs" + "_" + \

# Build a descriptive run name from the main hyper-parameters.
# Joining a list guarantees a "_" between every component (the previous
# string concatenation was missing one between the "lr" and "dsz" parts).
ppo_run_name = "_".join([
    config.model_name.split('/')[-1],
    str(config.ppo_epochs) + "Eps",
    str(config.batch_size) + "bs",
    str(config.learning_rate).replace('.', '_') + "lr",
    str(len(dataset)) + "dsz",
    "sarc-rm",
])

print(f"Run name: {ppo_run_name}")

Run name: gpt2-sft-full_2Eps_128bs_1_41e-05lr10000dsz_sarc-rm


In [None]:
if USE_WANDB:
  wandb.run.name = ppo_run_name

In [None]:
# Default to the accelerator's device; on a single process fall back to an
# explicit GPU index / "cpu" to avoid a `pipeline` bug.
single_process = ppo_trainer.accelerator.num_processes == 1
if not single_process:
    device = ppo_trainer.accelerator.device
else:
    device = 0 if torch.cuda.is_available() else "cpu"
print(f"Training on device: {device}")

# Sampling settings for response generation
# see https://huggingface.co/docs/trl/v0.8.6/ppo_trainer
generation_kwargs = dict(
    min_length=-1,                        # don't ignore the EOS token
    top_k=0.0,                            # no top-k sampling
    top_p=1.0,                            # no nucleus sampling
    do_sample=True,                       # yes, we want to sample
    pad_token_id=tokenizer.eos_token_id,  # most decoder models don't have a padding token - use EOS token instead
    max_new_tokens=32,                    # upper bound on generated tokens
)

# define how often to print
LOG_INTERVAL = 1

DEBUG = True

# Number of passes over the PPO dataloader
EPOCHS = 2


Training on device: 0


## Training

In [None]:
# Clear GPU RAM
if torch.cuda.is_available():
  torch.cuda.empty_cache()

# Garbage collection
gc.collect()

10892

### Training loop

In [None]:
# revised PPO training loop
#
# For every batch: generate responses with the current policy, strip special
# tokens and the prompt template, score each (parent, response) pair with the
# sarcasm reward model, then run one PPO optimisation step.

print("Starting training...")
print(f"Number of batches per epoch: {len(ppo_trainer.dataloader)}")

template_strs = {"</PARENT>\n<RESPONSE>", "<PARENT>", "</RESPONSE>"}
special_ids = torch.tensor(tokenizer.all_special_ids, device=device)

if USE_WANDB:
    all_samples_table = wandb.Table(columns=["query", "response", "reward"], log_mode="MUTABLE")

for epoch in tqdm(range(EPOCHS), desc='Epochs'):
  # The inner progress bar is rebuilt every epoch, so its label always shows
  # the current epoch (a desc f-string is evaluated only when the bar is created).
  for batch_idx, batch in enumerate(tqdm(ppo_trainer.dataloader, desc=f'Epoch {epoch + 1} batches')):

      '''
      # convert tensors to lists of integers first to ensure tokenizer.pad
      # handles them without type error
      input_ids_list = [t.tolist() for t in batch['input_ids']]
      attention_mask_list = [t.tolist() for t in batch['attention_mask']]

      # pad into a single 2D batch tensor
      padded_inputs = tokenizer.pad(
          {"input_ids": input_ids_list, "attention_mask": attention_mask_list},
          padding=True,
          return_tensors="pt"
      ).to(device)
      '''

      stacked_input_ids = torch.stack(batch['input_ids']).to(device)
      stacked_attention_masks = torch.stack(batch['attention_mask']).to(device)

      # batch generation
      with torch.no_grad():
        # generate all sequences at once
        generated_batch = ppo_trainer.model.generate(
          input_ids=stacked_input_ids,
          attention_mask=stacked_attention_masks,
          **generation_kwargs
        )


      # Extract the query and response both encoded and decoded and cleaned up of template and of special tokens
      query_tensors = []
      response_tensors = []
      decoded_queries = []
      decoded_responses = []

      # A dedicated index name keeps the batch counter of the enclosing loop intact
      for sample_idx in range(generated_batch.size(0)):
          full_seq = generated_batch[sample_idx]

          # Length of the original query (from inputs)
          q_len = len(batch['input_ids'][sample_idx])

          # Slice out query and response
          query_ids = full_seq[:q_len]
          response_ids = full_seq[q_len:]

          # Remove special tokens (by id)
          query_ids = query_ids[~torch.isin(query_ids, special_ids)]
          response_ids = response_ids[~torch.isin(response_ids, special_ids)]

          # Decode and clean strings
          q_str = tokenizer.decode(query_ids, skip_special_tokens=True)
          r_str = tokenizer.decode(response_ids, skip_special_tokens=True)

          for s in template_strs:
              q_str = q_str.replace(s, "")
              r_str = r_str.replace(s, "")

          # Samples whose response is empty after clean-up are left out
          if len(response_ids) > 0 and r_str:
            query_tensors.append(query_ids)
            response_tensors.append(response_ids)
            decoded_queries.append(q_str.strip())
            decoded_responses.append(r_str.strip())

      # PPOTrainer.step rejects lists whose length differs from
      # config.batch_size. If any sample was left out above, skip this batch
      # (with a message) instead of aborting the whole training run.
      if len(response_tensors) != config.batch_size:
          print(f"Skipping batch {batch_idx + 1} of epoch {epoch + 1}: "
                f"{len(response_tensors)} of {config.batch_size} samples have a usable response")
          continue

      # process the batch through reward model pipe
      sarcasm_rm_inputs = [(q, r) for q, r in zip(decoded_queries, decoded_responses)]

      with torch.no_grad():
        rm_pipe_outputs = reward_model_pipe(sarcasm_rm_inputs, batch_size=len(sarcasm_rm_inputs))


      '''
      Focusing on sarcasm reward signal only for now

      # single reward from the sarcasm model
      # rewards = [
      #   torch.tensor(output["probabilities"]["sarcastic"])
      #   for output in rm_pipe_outputs
      # ]

      # obtain and combine all reward signals
      rewards = []
      # Store individual components for logging
      sarcasm_scores = []
      rep_penalties = []
      len_penalties = []

      for k, output in enumerate(rm_pipe_outputs):
        # main signal: sarcasm orobability
        s_score = output["probabilities"]["sarcastic"]

        # objectivity signal
        o_score = objectivity_reward(batch["response"][k])
        sub_score = 1 - o_score

        # auxiliary signal: repetition penalty (returns <= 0)
        r_pen = repetition_penalty(batch["response"][k])

        # auxiliary signal: length penalty (returns <= 0)
        l_pen = length_penalty(batch["response"][k])

        # combined reward
        # TODO: add weights here? e.g., s_score + 0.5 * r_pen + 0.5 * l_pen
        total_reward = s_score + sub_score + r_pen + (0.5 * l_pen)

        rewards.append(torch.tensor(total_reward))

        # log
        sarcasm_scores.append(s_score)
        subjectivity_scores.append(sub_score)
        rep_penalties.append(r_pen)
        len_penalties.append(l_pen)

      if DEBUG:
        print(f"Sample rewards: Total={rewards[0]:.3f} (Sarcasm={sarcasm_scores[0]:.3f}, Subjectivity={subjectivity_scores[0]:.3f}, Rep={rep_penalties[0]:.3f}, Len={len_penalties[0]:.3f})")


      # remove padding before passing to ppo_trainer step
      clean_query_tensors = []
      for tensor, mask in zip(batch['input_ids'], batch['attention_mask']):
          # filter the tensor using the attention mask
          # mask is 1 for real text, 0 for padding
          clean_query_tensors.append(tensor[mask.bool()])

      '''

      # Reward = probability that the reward model assigns to "sarcastic"
      rewards = [
        torch.tensor(output["probabilities"]["sarcastic"])
        for output in rm_pipe_outputs
      ]

      #### Run PPO step
      stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

      # Will log only 10 entries per batch to keep log amounts sane
      log_batch = {
        "query": decoded_queries[:10],
        "response": decoded_responses[:10],
      }

      ## This request wandb login
      if USE_WANDB:
        # The full reward list is passed so the logged reward statistics cover the whole batch
        ppo_trainer.log_stats(stats, log_batch, rewards)
        # Add rows to the persistent table (zip stops after the 10 logged samples)
        for q, r, rew in zip(log_batch["query"], log_batch["response"], rewards):
            all_samples_table.add_data(q, r, float(rew))

        # Log the growing table under a different key
        wandb.log({"all_samples": all_samples_table})


print('Training complete DON\'T FORGET TO SAVE THE MODEL')

Starting training...
Number of batches per epoch: 78


Epoch: 1:   0%|          | 0/2 [00:00<?, ?it/s]
Batch: 1:   0%|          | 0/78 [00:00<?, ?it/s][A
Batch: 1:   1%|▏         | 1/78 [00:11<14:46, 11.51s/it][A
Batch: 1:   3%|▎         | 2/78 [00:23<14:35, 11.52s/it][A
Batch: 1:   4%|▍         | 3/78 [00:34<14:19, 11.46s/it][A
Batch: 1:   5%|▌         | 4/78 [00:45<14:08, 11.47s/it][A
Batch: 1:   6%|▋         | 5/78 [00:57<13:57, 11.47s/it][A
Batch: 1:   8%|▊         | 6/78 [01:08<13:42, 11.42s/it][A
Batch: 1:   9%|▉         | 7/78 [01:17<12:40, 10.71s/it][A
Batch: 1:  10%|█         | 8/78 [01:29<12:46, 10.95s/it][A
Batch: 1:  12%|█▏        | 9/78 [01:40<12:44, 11.08s/it][A
Batch: 1:  13%|█▎        | 10/78 [01:52<12:41, 11.20s/it][A
Batch: 1:  14%|█▍        | 11/78 [02:03<12:36, 11.29s/it][A
Batch: 1:  15%|█▌        | 12/78 [02:15<12:27, 11.32s/it][A
Batch: 1:  17%|█▋        | 13/78 [02:26<12:18, 11.36s/it][A
Batch: 1:  18%|█▊        | 14/78 [02:37<12:07, 11.36s/it][A
Batch: 1:  19%|█▉        | 15/78 [02:49<11:56, 11.37s/i

In [None]:
## End the logging
if USE_WANDB:
  wandb.finish()

0,1
env/reward_mean,▄▁▃▄▃▄▃▃▃▃▄▂█▇▃▆▇▅▅▄▅▄▃▆▄▇▅▇▆█
env/reward_std,▅▄▆▂▅▄▅▆▇▄█▆▃▁▅▅▅▅▇▄▇▇▆▃▅▃▃█▅▅
objective/entropy,▄▇▅▆▄▇█▇▅▇▇▅▄▅▃▅▃▄▄▃▄▆▃▆▆▄▂▃▁▂
objective/kl,▁▁▂▂▂▄▃▄▆▇▄▄▅▆▅▆▄▃▅▄▅▇▅▇▆█▇▆▆█
objective/kl_coef,███▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁
ppo/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
ppo/loss/policy,▁▃▅▅▅▅▆▇▆▅▆▅▅▅▆▆▃▅▅▄▅▄█▄▄▅▅▅▅▅
ppo/loss/total,█▄▆▅▅▆▅▆▅▅▄▄▃▃▅▄▁▃▃▁▂▂▄▂▁▂▁▂▂▂
ppo/loss/value,█▄▄▃▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▁▁▁▁
ppo/mean_non_score_reward,██▇▇▇▆▆▆▃▃▅▅▄▃▄▄▅▆▄▅▄▃▄▃▃▁▂▃▃▁

0,1
env/reward_mean,0.72403
env/reward_std,0.27528
objective/entropy,48.2224
objective/kl,1.16755
objective/kl_coef,0.19271
ppo/learning_rate,1e-05
ppo/loss/policy,-0.01455
ppo/loss/total,-0.00405
ppo/loss/value,0.10492
ppo/mean_non_score_reward,-0.01298


In [None]:
#### Save model
# The Colab copy is always written; the same folder is what gets uploaded to the Hub.
colab_save_dir = "/content/" + ppo_run_name
ppo_trainer.save_pretrained(colab_save_dir)
print('Model saved to Colab - This goes away when you disconnect colab')

# Optional second copy on the mounted Google Drive
if USE_DRIVE:
  ppo_trainer.save_pretrained(drive_path + ppo_run_name)
  print('Model saved to drive')

# Optional upload to the Hugging Face Hub: one shared repo, one branch per run
if USE_HUGGINGFACE:
  api = HfApi()
  repo_id = fh_username + "/GenerallySarcasticTransformer"
  rev_id = ppo_run_name

  # Create the repo and the run's branch if they do not exist yet
  api.create_repo(repo_id=repo_id, exist_ok=True)
  api.create_branch(repo_id=repo_id, branch=rev_id, repo_type="model", exist_ok=True)

  # Upload the saved files to the run's branch
  api.upload_folder(
      folder_path=colab_save_dir,
      repo_id=repo_id,
      repo_type="model",
      revision=rev_id,
  )
  print('Model saved to hugging face')

Model saved to Colab - This goes away when you disconnect colab


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...1e-05lr/model.safetensors:   0%|          |  549kB /  498MB            

Model saved to hugging face


### Sanity check manual training run

This runs through the steps of the training loop one at a time for a sanity check. Only intended for debugging

In [None]:
all_samples_table = wandb.Table(columns=["query", "response", "reward"], log_mode="MUTABLE")

In [None]:
#Sanity check that PPO dataloader has all the items of our dataset

first_batch = next(iter(ppo_trainer.dataloader))
print("Items per batch:", len(first_batch["input_ids"]))
print(f"Number of batches: {len(ppo_trainer.dataloader)}")
print("First input_ids:", first_batch["input_ids"][0])

#Confirm the dataloader contains as many items as dataset
# assert len(dataset) == len(ppo_trainer.dataloader.dataset)

Items per batch: 64
Number of batches: 15
First input_ids: tensor([50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256,    27, 27082,  3525,    29,   347,  1436,
        14662,  2921,   502,   257,  3555,   286,  6640, 17655,  1231, 47105,
           78,    11, 18523,   351, 47105,    78,    13,  7359, 27082,  3525,
     

In [None]:
epoch, batch = next(enumerate(ppo_trainer.dataloader))

In [None]:
query_tensors = batch['input_ids']
attention_masks = batch['attention_mask']

In [None]:
len(query_tensors)

64

In [None]:
print(query_tensors[0])
print(attention_masks[0])

tensor([50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
           27, 27082,  3525,    29,   632,   338,   780,  1466,   389,  7360,
          262,  5290,  8109,   319,   428,  5440,    13,  7359, 27082,  3525,
           29,   198,    27, 19535,    47,  1340,  5188,    29],

In [None]:
input_ids_list = [t.tolist() for t in batch['input_ids']]
attention_mask_list = [t.tolist() for t in batch['attention_mask']]

In [None]:
stacked_input_ids = torch.stack(batch['input_ids']).to(device)
stacked_attention_masks = torch.stack(batch['attention_mask']).to(device)

# batch generation
with torch.no_grad():
  # generate all sequences at once
  generated_batch = ppo_trainer.model.generate(
    input_ids=stacked_input_ids,
    attention_mask=stacked_attention_masks,
    **generation_kwargs
  )

In [None]:
len(batch['attention_mask'])#[0].shape

64

In [None]:
generated_batch.shape

torch.Size([64, 160])

In [None]:
tokenizer.decode(generated_batch[1].squeeze(), skip_special_tokens=True)

'<PARENT> Girl gains 20 pounds... boy gains **50 pounds**... boy breaks up because no longer attracted to girl and he deserves better. Da fuq... </PARENT>\n<RESPONSE> Shame on you Best non-slut non-friend person in the internet. </RESPONSE>'

In [None]:
# Template markers that wrap the parent comment / response in the prompt format
template_strs = {"</PARENT>\n<RESPONSE>", "<PARENT>", "</RESPONSE>"}
special_ids = torch.tensor(tokenizer.all_special_ids, device=generated_batch.device)


def _drop_special(token_ids):
    """Return `token_ids` without any id that appears in `special_ids`."""
    keep_mask = ~torch.isin(token_ids, special_ids)
    return token_ids[keep_mask]


def _strip_template(text):
    """Remove every marker in `template_strs` from `text` and trim whitespace."""
    for marker in template_strs:
        text = text.replace(marker, "")
    return text.strip()


query_tensors, response_tensors = [], []
decoded_queries, decoded_responses = [], []

for row, full_seq in enumerate(generated_batch):
    # The first `prompt_len` positions of a generated row are the original input
    prompt_len = len(batch['input_ids'][row])

    # Split into prompt / continuation, then filter special-token ids out of each
    prompt_ids = _drop_special(full_seq[:prompt_len])
    reply_ids = _drop_special(full_seq[prompt_len:])

    query_tensors.append(prompt_ids)
    response_tensors.append(reply_ids)

    # Text versions with the template markers removed
    prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=True)
    reply_text = tokenizer.decode(reply_ids, skip_special_tokens=True)
    decoded_queries.append(_strip_template(prompt_text))
    decoded_responses.append(_strip_template(reply_text))


In [None]:
# Show the token ids and the cleaned text of one sample side by side
sample_idx = 1
print(query_tensors[sample_idx])
print(response_tensors[sample_idx])
print(decoded_queries[sample_idx])
print(decoded_responses[sample_idx])

tensor([   27, 27082,  3525,    29,  7430,  8810,  1160,  8059,   986,  2933,
         8810, 12429,  1120,  8059,  1174,   986,  2933,  9457,   510,   780,
          645,  2392, 12725,   284,  2576,   290,   339, 14071,  1365,    13,
         9637, 14035,    80,   986,  7359, 27082,  3525,    29,   198,    27,
        19535,    47,  1340,  5188,    29], device='cuda:0')
tensor([48266,   319,   345,  6705,  1729,    12,  6649,   315,  1729,    12,
         6726,  1048,   287,   262,  5230,    13,  7359, 19535,    47,  1340,
         5188,    29], device='cuda:0')
Girl gains 20 pounds... boy gains **50 pounds**... boy breaks up because no longer attracted to girl and he deserves better. Da fuq...
Shame on you Best non-slut non-friend person in the internet.


In [None]:
# Pair every cleaned query with its cleaned response for the reward model
sarcasm_rm_inputs = list(zip(decoded_queries, decoded_responses))

# Score all pairs in one reward-model batch; no gradients are needed for scoring
with torch.no_grad():
    rm_pipe_outputs = reward_model_pipe(
        sarcasm_rm_inputs,
        batch_size=len(sarcasm_rm_inputs),
    )

In [None]:
# Inspect the reward-model output for the first (query, response) pair
rm_pipe_outputs[0]

{'label': 'sarcastic',
 'score': 0.9723248481750488,
 'probabilities': {'not_sarcastic': 0.02767517976462841,
  'sarcastic': 0.9723248481750488}}

In [None]:
# One scalar reward tensor per sample: the reward model's "sarcastic" probability
rewards = []
for rm_output in rm_pipe_outputs:
    sarcastic_prob = rm_output["probabilities"]["sarcastic"]
    rewards.append(torch.tensor(sarcastic_prob))

In [None]:
# Average reward over the batch (sum of the per-sample reward tensors divided by their count)
sum(rewards) / len(rewards)

tensor(0.6649)

In [None]:
# Only the first 10 decoded samples are kept for logging
log_batch = dict(
    query=decoded_queries[:10],
    response=decoded_responses[:10],
)

In [None]:
# Run one PPO step on the batch (query token ids, response token ids, per-sample rewards)
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)

In [None]:
ppo_trainer.log_stats(stats, log_batch, rewards)

# Append the logged samples to the persistent table.
# zip() stops at the shortest input, so only as many rewards as there are
# logged query/response pairs are consumed.
logged_rows = zip(log_batch["query"], log_batch["response"], rewards)
for query_text, response_text, reward in logged_rows:
    all_samples_table.add_data(query_text, response_text, float(reward))

# Log the growing table under a different key
wandb.log({"all_samples": all_samples_table})

# Scratchpad

In [None]:
# Generate one continuation per query and keep only the newly generated tokens.
response_tensors = []
response_tensors_slice = []  # never filled in this cell; kept in case another cell references it

# total= gives tqdm a real progress bar (a zip object has no length of its own)
for query, mask in tqdm(zip(batch['input_ids'], batch['attention_mask']),
                        total=len(batch['input_ids'])):
    # generate() is called on a single query; the result is squeezed to 1-D
    query_response = ppo_trainer.generate(
        query,
        attention_mask=mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze()

    # Slice from the prompt length onward. The earlier negative slice
    # (query_response[-response_len:]) returned the WHOLE sequence whenever no
    # new tokens were produced, because [-0:] is the same as [0:].
    response_tensors.append(query_response[len(query):])

In [None]:
# Full sequence returned by the most recent generate() call (value left over from the loop)
query_response

In [None]:
# Query at index 31
# NOTE(review): "last" only holds for a 32-sample batch -- confirm the batch size used here
query_tensors[31]

In [None]:
# Response at index 31 with the query portion sliced off
# NOTE(review): "last" only holds for a 32-sample batch -- confirm the batch size used here
print(response_tensors[31])

In [None]:
# Decode every response tensor to text, dropping special tokens
batch["response"] = [
    tokenizer.decode(response_ids.squeeze(), skip_special_tokens=True)
    for response_ids in response_tensors
]

In [None]:
# Check the Python type of the first entry in batch['query']
type(batch['query'][0])

In [None]:
# Decoded response text at index 31
batch["response"][31]

In [None]:
# Decode the sequence from the most recent generate() call
# (presumably prompt + continuation -- verify against the generate() settings)
tokenizer.decode(query_response.squeeze(), skip_special_tokens=True)

In [None]:
def _parent_only(prompt):
    """Return the parent-comment text of a "Parent: ... Sarcastic reply:" prompt.

    Prompts without the "Sarcastic reply:" marker are returned whitespace-stripped.
    """
    # Everything before the first "Sarcastic reply:" is the parent portion
    parent_part, marker, _ = prompt.partition("Sarcastic reply:")
    if not marker:
        # Fallback if template not found
        return prompt.strip()
    # NOTE(review): replace() removes every "Parent:" occurrence, not only the
    # leading prefix -- confirm that is acceptable for comments containing that text.
    return parent_part.replace("Parent:", "").strip()


clean_queries = [_parent_only(raw_query) for raw_query in batch["query"]]

batch['query'] = clean_queries

In [None]:
# Cleaned query text at index 31
batch['query'][31]

In [None]:
# (query, response) tuples in the form the reward-model pipeline is called with
batch_inputs = list(zip(batch['query'], batch['response']))

In [None]:
# Score the pairs with the reward model, at most 8 pairs per pipeline batch
rm_batch_size = min(len(batch_inputs), 8)
with torch.no_grad():
    rm_pipe_outputs = reward_model_pipe(batch_inputs, batch_size=rm_batch_size)

In [None]:
from dataclasses import fields

# List the dataclass fields PPOConfig accepts and check whether "eval_steps" is one of them
ppo_config_field_names = [field.name for field in fields(PPOConfig)]
print(ppo_config_field_names)
print("eval_steps" in ppo_config_field_names)

In [None]:
# tokenizer set up
# left padding keeps each prompt adjacent to the tokens generated after it;
# the EOS token is reused as the pad token
# NOTE(review): the tokenizer is loaded from 'gpt2' while the model comes from
# config.model_name -- confirm both share the same vocabulary
tokenizer = AutoTokenizer.from_pretrained('gpt2', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

# model set up
# (PPO requires a model with a value head)
# PPO also requires a reference model, but that one is created by the PPOTrainer automatically
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)


In [None]:
# load training data: the 'ppo_train' split of the sarcasm-on-reddit dataset
sarcasm_dataset_splits = load_dataset("marcbishara/sarcasm-on-reddit")
sarcasm_train_dataset = sarcasm_dataset_splits['ppo_train']

# POC only: keep the first 200 examples so a run does not take hours
dataset = sarcasm_train_dataset.select(range(200))


In [None]:
# tokenize the dataset one sample at a time, then expose only 'input_ids',
# returned as torch tensors
dataset = (
    dataset
    .map(tokenize, batched=False)
    .with_format(type='torch', columns=['input_ids'])
)

In [None]:

# def tokenize(sample):
#     tokenized_output = tokenizer(
#         sample['text'],
#         truncation=True,
#         max_length=128,
#         padding='max_length')

#     ids = tokenized_output['input_ids']
#     sample['input_ids'] = ids

#     # decode back to string for use in the reward score function
#     sample['query'] = tokenizer.decode(ids, skip_special_tokens=True)

#     return sample

def tokenize(sample, max_length=128):
    """Add token ids for ``sample['text']`` to the sample.

    Args:
        sample: mapping with a ``'text'`` entry; it is modified in place.
        max_length: maximum number of token ids to keep; longer texts are
            truncated. Defaults to 128, the previously hard-coded limit, so
            ``dataset.map(tokenize)`` behaves as before.

    Returns:
        The same ``sample`` object with an added ``'input_ids'`` entry.
    """
    sample['input_ids'] = tokenizer.encode(sample['text'], max_length=max_length, truncation=True)
    # The decoded 'query' string is deliberately not stored here; it is rebuilt
    # from input_ids in the training loop (it seemed to get dropped by the trainer).
    return sample


In [None]:
# custom reward function
# CURRENTLY REPLACED BY DIRECT CALL WITHIN THE TRAINING LOOP

def get_reward_score(query_text, response_text):
    """Placeholder reward: return 0.0 or 1.0 at random.

    Both arguments are currently ignored.

    TODO: replace this with the weighted-sum reward built from multiple
    reward signals, computed from query_text and response_text.
    """
    # debugging aid: print the query and the response
    # print(f"Query: {query_text} | Response: {response_text}")

    # coin flip for now
    return float(random.randint(0, 1))



In [None]:
# revised PPO training loop

print("Starting training...")
print(f"Number of batches per epoch: {len(ppo_trainer.dataloader)}")

# Separator placed between query and response in the reward-model input. It
# only depends on the reward model's tokenizer, so look it up once.
# NOTE(review): if that tokenizer defines no sep_token this is None and the
# literal text "None" ends up in the reward-model input -- confirm.
sep_token = sarcasm_model.tokenizer.sep_token

# PPOTrainer.generate() returns prompt + continuation by default
# (return_prompt=True). step() is given query and response separately, so the
# response must hold the continuation only; otherwise the prompt is counted
# twice and is also fed to the reward model as part of the "response".
# An explicit return_prompt entry in generation_kwargs still takes precedence.
ppo_generation_kwargs = {"return_prompt": False, **generation_kwargs}

for epoch in tqdm(range(epochs), 'epoch: '):
    # wrap the dataloader itself (not the enumerate object) so tqdm knows the total
    for i, batch in enumerate(tqdm(ppo_trainer.dataloader)):

        # get query_tensors as tensors
        query_tensors = batch['input_ids']

        # rebuild the 'query' text from input_ids (the column may not survive the dataloader)
        batch['query'] = [tokenizer.decode(q_t, skip_special_tokens=True) for q_t in query_tensors]

        #### Get response from SFTModel
        response_tensors = ppo_trainer.generate(query_tensors, **ppo_generation_kwargs)
        # skip special tokens so padding/EOS text does not leak into the
        # reward-model input (same treatment as the queries above)
        batch["response"] = [
            tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors
        ]

        # calculate rewards with the sarcasm reward model
        queries = batch['query']
        responses = batch['response']

        # combine queries and responses into a single list of "query [SEP] response"
        batch_inputs = [f"{q} {sep_token} {r}" for q, r in zip(queries, responses)]

        # process the batch
        pipe_outputs = sarcasm_model(batch_inputs, batch_size=len(batch_inputs), truncation=True)

        # process the results
        rewards = []

        for output in pipe_outputs:

            # NOTE(review): 'score' presumably is the confidence of whichever label
            # the pipeline predicted; if that label can be the non-sarcastic class,
            # this rewards confident non-sarcastic replies -- confirm the pipeline's
            # label set and prefer the explicit "sarcastic" probability if available.
            sarcasm_score = output['score']

            # TODO: add other reward signals -- just a placeholder here
            other_score = 0

            # combine score -- TODO: weighted sum? NORMALIZE the score!
            score = sarcasm_score + other_score

            rewards.append(torch.tensor(score))

        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

        # logging code
        if i % LOG_INTERVAL == 0:
            # mean reward and total loss reported by the PPO step for this batch
            print(f"Step {i}: Mean Reward from PPO stats: {stats['ppo/mean_scores']:.4f}")
            print(f"        PPO Loss:    {stats['ppo/loss/total']:.4f}")

print('Training complete')

#### Save model
ppo_trainer.save_pretrained(drive_path)

print('Model saved')