In [4]:
import sys
import os
import jaxtyping
from pathlib import Path

import os
import sys
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable

import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
import tabulate
from eindex import eindex
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate
from torch import Tensor
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Pick the accelerator: Apple-silicon MPS first, then CUDA, otherwise CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

In [5]:
@dataclass
class RLHFArgs:
    # random
    seed: int = 1

    # logging
    wandb_project_name: str = "rlhf_transformers"
    wandb_entity: str | None = None

    # macro-training 
    total_phases: int = 200
    batch_size: int = 32 #enforce batch_size % num_minibatches == 0
    num_minibatches: int = 4
    batches_per_learning_phase: int = 2

    # optimization hyperparameters
    base_lr: float = 2e-5
    head_lr: float = 5e-4
    max_grad_norm: float = 1.0
    warmup_steps: int = 20 #enforce warmup_steps < total_phases
    final_scale: float = 0.1

    # PPO objective function coefficients
    clip_coef: float = 0.2
    vf_coef: float = 0.15
    ent_coef: float = 0.001

    # model and sampling with prefix
    base_model: str = "gpt-medium"
    gen_len: int = 50
    temperature: float = 1.0
    top_k: int = 10
    prefix: str = "This is"
    prepend_bos: bool = True

    # RLHF-specific arguments
    kl_coef: float = 2.5
    reward_fn: Callable = lambda x: 0.0
    normalize_reward: bool = True

    def __post_init__(self):
        self.minibatch_size = self.batch_size // self.num_minibatches

# Setup: working with the transformer

Right after the final LayerNorm, just before the unembedding, we add a hook function that applies our value head and computes a **value estimate** at every position of the sequence. The value head is a simple 2-layer neural network; the hook runs it during the forward pass and stores its output outside the model so `forward` can return it alongside the logits.

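Why do we choose this location? It comes after the final LayerNorm, so the value head receives normalized, well-scaled residual-stream vectors (the same features the unembedding uses). It comes before the unembedding because the value head takes `d_model`-dimensional residual vectors as input, not vocabulary-sized logits. It is also at the very end of the network because (presumably) that is where the residual stream has accumulated the most information about the sequence.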

In [6]:
class TransformerWithValueHead(nn.Module):
    """A pretrained HookedTransformer plus an MLP value head.

    The value head reads the activation at the `utils.get_act_name("normalized")`
    hook point (the final LayerNorm output). It maps each position's d_model vector
    through a hidden layer of width 4 * d_model to a single scalar.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = HookedTransformer.from_pretrained(base_model)

        width = self.base_model.cfg.d_model
        hidden = 4 * width
        self.value_head = nn.Sequential(
            nn.Linear(width, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, input_ids):
        """Run the base model and return `(logits, values)`.

        `values` holds one scalar per position: the value head applied to the
        hooked activation, with the trailing size-1 dimension squeezed away. It is
        None if the hook never fired.
        """
        captured = {}

        def _store_values(resid_normalized, hook):
            # Returning None leaves the activation unchanged; we only observe it.
            captured["values"] = self.value_head(resid_normalized).squeeze(-1)

        hook_name = utils.get_act_name("normalized")
        logits = self.base_model.run_with_hooks(
            input_ids,
            return_type="logits",
            fwd_hooks=[(hook_name, _store_values)],
        )
        return logits, captured.get("values")
    
# Build the policy (GPT-2 small + value head), then move it onto the chosen device.
model = TransformerWithValueHead("gpt2-small")
model = model.to(device)


Loaded pretrained model gpt2-small into HookedTransformer


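Setting `stop_at_eos = False` is an interesting choice. From an interpretability perspective, it helps with seeing hallucinations: we can observe what the model produces after it would normally have stopped. From a training perspective, it helps measure how well the model has learned to stop. It also lets the model learn from full-length text rather than text truncated at an EOS token, and it guarantees every sample is exactly `gen_len` tokens long.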

In [7]:
# prepend_bos: whether to prepend a BOS (beginning-of-sequence) token to the prompt
def get_samples(base_model, prompt, batch_size, gen_len, temperature, top_k, prepend_bos):
    """Sample `batch_size` continuations of `prompt` from `base_model`.

    Generation never stops early (stop_at_eos=False), so each row contains the
    prompt tokens followed by `gen_len` generated tokens.

    Returns:
        (output_ids, samples): a clone of the generated token-id tensor, one row per
        sample, and the list of decoded strings, one per row.
    """
    # to_tokens gives a [1, prompt_len] tensor for one prompt; drop the leading dim
    prompt_ids = base_model.to_tokens(prompt, prepend_bos=prepend_bos).squeeze(0)
    # copy the prompt into every row: [prompt_len] -> [batch_size, prompt_len]
    batched_prompts = prompt_ids.repeat(batch_size, 1)

    sampling_kwargs = dict(
        max_new_tokens=gen_len,
        stop_at_eos=False,
        temperature=temperature,
        top_k=top_k,
        verbose=False,
    )
    generated = base_model.generate(batched_prompts, **sampling_kwargs)

    decoded = base_model.to_string(generated)

    # hand back a copy so callers can't mutate the tensor produced by generate()
    return generated.clone(), decoded

In [8]:
# Draw a few short samples from the untrained policy to sanity-check generation.
sample_ids, samples = get_samples(
    model.base_model,
    prompt="This movie was really",
    batch_size=5,
    gen_len=15,
    temperature=0.8,
    top_k=15,
    prepend_bos=False,
)

# One table row per sample: the raw token ids (as a Python list) next to repr() of
# the decoded text, which shows whitespace/newlines explicitly and adds quotes.
table = Table("Token IDs", "Samples", show_lines=True)
for ids, sample in zip(sample_ids, samples):
    id_column = str(ids.tolist())
    text_column = repr(sample)
    table.add_row(id_column, text_column)

rprint(table)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [9]:
# Reward model: the `lvwerra/distilbert-imdb` sentiment classifier and its tokenizer.
# float16 speeds up inference on accelerators (CUDA/MPS); on CPU many float16
# kernels are slow or unsupported, so the model stays in float32 there.
cls_model = AutoModelForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb")
if device.type != "cpu":
    cls_model = cls_model.half()
cls_model = cls_model.to(device)
cls_tokenizer = AutoTokenizer.from_pretrained("lvwerra/distilbert-imdb")

def reward_fn_sentiment_imdb(gen_sample, direction):
    """Score generated text with the IMDB sentiment classifier.

    Args:
        gen_sample: a string or list of strings to classify.
        direction: "pos" rewards the class at index 1 (assumed positive, as in
            the original code); any other value rewards the class at index 0.

    Returns:
        float32 tensor of shape [batch_size] with the softmax probability of the
        requested class for each sample.
    """
    # "pt" -> PyTorch tensors; padding + truncation so samples of different lengths
    # form a single rectangular batch
    encoded = cls_tokenizer(gen_sample, return_tensors="pt", padding=True, truncation=True).to(device)
    # Rewards are constants for PPO: don't build an autograd graph through the classifier.
    with t.no_grad():
        # The attention mask makes the classifier ignore pad tokens in shorter samples.
        logits = cls_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        ).logits  # [batch_size, 2]
    class_idx = 1 if direction == "pos" else 0
    # Upcast (the model may run in float16) so rewards are computed in float32.
    return logits.float().softmax(-1)[:, class_idx]

In [10]:
def normalize_reward(reward, eps = 1e-5):
    """Standardize rewards to zero mean and (approximately) unit standard deviation.

    Args:
        reward: tensor of rewards (e.g. one per sampled sequence).
        eps: added to the std so equal rewards don't cause division by zero.

    Returns:
        Tensor with the same shape as `reward`. With fewer than two elements the
        (unbiased) std is undefined (NaN), so the centered rewards, i.e. zeros,
        are returned instead.
    """
    centered = reward - reward.mean()
    if reward.numel() < 2:
        return centered
    return centered / (reward.std() + eps)

We use the simple formula $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$, where $Q(s_t, a_t)$ is replaced by a one-step estimate. If $t<T$, our Q estimate is $V(s_{t+1})$, since there is no intermediate reward. If $t=T$, we use the known reward $r_T$ for the entire sequence.

GAE is an alternative but wouldn't bring a significant improvement since GAE is most helpful in reducing variance in advantage estimation, and our situation is low variance (each step adds a single token to our sequence).

In [11]:
def compute_advantages(values, rewards, prefix_len):
    """One-step advantage estimates A_t = Q_hat(s_t, a_t) - V(s_t) for the generated tokens.

    Args:
        values: per-position value estimates, shape [batch, seq_len].
        rewards: one reward per finished sequence, shape [batch].
        prefix_len: number of prompt tokens; generated tokens start at index prefix_len.

    Returns:
        Tensor of shape [batch, seq_len - prefix_len], one advantage per generated token.
    """
    # Q estimate: value of the next state, except for the last token, whose
    # "next state" is the finished sequence, scored by the reward.
    one_step_est = t.cat([values[:, prefix_len:-1], rewards[:, None]], dim=-1)
    # Baseline: value of the state in which each generated token was chosen.
    # Slice along the sequence dimension (dim 1), not the batch dimension.
    zero_step_est = values[:, prefix_len - 1:-1]
    return one_step_est - zero_step_est

# Memory

Compared to the PPO implementation, there are a few differences in `ReplayMemory`. 
- Don't need an `add` function because we add it all at once instead of one-by-one.
- Don't need multiple environments

And for `ReplayMinibatch`
- Don't need `actions`: there is no separate "agent" action record, because the actions (the generated tokens) are already contained in the sequences.
- Don't need `dones`, since every sequence is exactly `gen_len` tokens long (generation never stops early).
- Store `ref_logits` (the reference model's logits), which are needed to compute the KL penalty with respect to the reference model.

In [12]:
@dataclass
class ReplayMinibatch:
    """One minibatch of rollout data for a PPO update (constructed with keyword args).

    Sizes: `sample_ids` and `ref_logits` span the whole sequence
    (seq_len = prefix_len + gen_len, since generation never stops early), while the
    per-action quantities `logprobs`, `advantages` and `returns` only span the
    gen_len generated tokens, because only those were chosen by the policy.
    """
    sample_ids: Int[Tensor, "minibatch_size seq_len"]  # token ids (integers)
    logprobs: Float[Tensor, "minibatch_size gen_len"]
    advantages: Float[Tensor, "minibatch_size gen_len"]
    returns: Float[Tensor, "minibatch_size gen_len"]
    ref_logits: Float[Tensor, "minibatch_size seq_len d_vocab"]

class ReplayMemory:
    """Stores one phase of rollout data and serves it as shuffled minibatches."""

    def __init__(self, args, sample_ids, logprobs, advantages, values, ref_logits):
        self.args = args
        self.sample_ids = sample_ids
        self.logprobs = logprobs
        self.advantages = advantages
        self.values = values
        self.ref_logits = ref_logits

    def get_minibatches(self):
        """Return a list of ReplayMinibatch objects.

        The batch is reshuffled `args.batches_per_learning_phase` times. Each shuffle
        is split into `args.num_minibatches` equal index groups, so
        `args.batch_size` must be divisible by `args.num_minibatches`.
        """
        gen_len = self.args.gen_len
        # One-step advantages are A_t = Q_t - V(s_t), so the value target is
        # Q_t = A_t + V(s_t). V(s_t) for the generated tokens lives at positions
        # [-gen_len - 1, -1) of the value sequence.
        returns = self.advantages + self.values[:, -gen_len - 1:-1]

        def _take(idx):
            return ReplayMinibatch(
                sample_ids=self.sample_ids[idx],
                logprobs=self.logprobs[idx],
                advantages=self.advantages[idx],
                returns=returns[idx],
                ref_logits=self.ref_logits[idx],
            )

        return [
            _take(idx)
            for _ in range(self.args.batches_per_learning_phase)
            for idx in t.randperm(self.args.batch_size).reshape(self.args.num_minibatches, -1)
        ]

In addition to the 3 components of the total PPO objective, we'll add on the KL penalty as a part of the RLHF framework.
- The KL prediction-shift penalty is $-\lambda_{KL} D_{KL}(\pi_{PPO}\,||\,\pi_{base})$ (rather than $D_{KL}(\pi_{base}\,||\,\pi_{PPO})$), because we want to penalize outputs that are likely under $\pi_{PPO}$ but unlikely under $\pi_{base}$. Expanding the (non-negative) KL term, which enters the objective with a minus sign, gives: $$\lambda_{KL} \cdot D_{KL}(\pi_{PPO}\,||\,\pi_{base}) = \lambda_{KL} \cdot \sum_i \pi_{PPO_i}\log\left(\frac{\pi_{PPO_i}}{\pi_{base_i}}\right)$$ where $i$ ranges over the vocabulary.
- The `entropy`, `value_fn`, and `clipped_sur_obj` functions are essentially the same as in PPO.

In [None]:
# The per-position KL is averaged (.mean()) to aggregate over the batch and keep
# the penalty's scale independent of batch size, which stabilizes training.
def calc_kl_penalty(logits, ref_logits, kl_coef):
    """Return kl_coef * mean KL(pi_PPO || pi_ref), computed from raw logits.

    KL(p || q) = sum_v p_v * (log p_v - log q_v), taken over the last (vocabulary)
    dimension, then averaged over all remaining dimensions.
    """
    policy_logp = logits.log_softmax(-1)
    kl_per_position = (policy_logp.exp() * (policy_logp - ref_logits.log_softmax(-1))).sum(-1)
    return kl_coef * kl_per_position.mean()

def calc_entropy_bonus(logits, ent_coef):
    """Return ent_coef * mean entropy of softmax(logits) over the last dimension.

    Entropy H(p) = -sum_v p_v log p_v is computed per position, then averaged.
    """
    logp = logits.log_softmax(-1)
    entropy = (logp * logp.exp()).sum(-1).neg()
    return ent_coef * entropy.mean()

# Supervised regression loss pulling the value head toward the return targets.
def calc_value_fn_loss(values, mb_returns, vf_coef):
    """Return 0.5 * vf_coef * mean((values - mb_returns) ** 2)."""
    squared_error = (values - mb_returns).pow(2)
    return 1/2 * vf_coef * squared_error.mean()

def calc_clipped_sur_obj(logprobs, mb_logprobs, mb_advantages, clip_coef, eps = 1e-8):
    """PPO clipped surrogate objective (a quantity to maximize).

    Args:
        logprobs: log-probs of the taken actions under the current policy.
        mb_logprobs: log-probs of the same actions stored in the minibatch.
        mb_advantages: advantage estimates; standardized within the minibatch here.
        clip_coef: the ratio is clipped to [1 - clip_coef, 1 + clip_coef].
        eps: added to the advantage std to avoid division by zero.

    Returns:
        Scalar mean of min(r * A, clip(r) * A), where r = exp(logprobs - mb_logprobs).
    """
    # Probability ratio pi_new / pi_old, computed in log space.
    prob_ratio = (logprobs - mb_logprobs).exp()
    normed_adv = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + eps)

    unclipped_term = prob_ratio * normed_adv
    clipped_term = prob_ratio.clamp(1 - clip_coef, 1 + clip_coef) * normed_adv
    # Pessimistic bound: take the smaller of the two terms elementwise.
    return t.minimum(unclipped_term, clipped_term).mean()

When given `prefix_len`, `get_log_probs` returns a tensor of size `(minibatch_size, gen_len)`: we only care about the log probs of the generated tokens, not those in the prefix.

In [14]:
def get_log_probs(logits, tokens, prefix_len = None):
    """Log-probabilities the model assigned to each actual next token.

    Args:
        logits: [batch, seq, d_vocab] model outputs for `tokens`.
        tokens: [batch, seq] token ids.
        prefix_len: if given, only the tokens after the prompt (the generated ones)
            are scored; if None, every next-token prediction is scored.

    Returns:
        [batch, seq - prefix_len] (i.e. [batch, gen_len]) when prefix_len is given,
        otherwise [batch, seq - 1]. Entry [:, i] is the log-prob that position i
        (of the possibly sliced sequence) assigned to the token at position i + 1.
    """
    if prefix_len is not None:
        # Position prefix_len - 1 is the one that predicts the first generated token.
        logits = logits[:, prefix_len - 1:]
        tokens = tokens[:, prefix_len - 1:]

    log_probs = logits.log_softmax(-1)
    # Logits at position s predict token s + 1: pair log_probs[:, :-1] with tokens[:, 1:].
    # gather needs int64 indices, so cast the ids defensively.
    next_tokens = tokens[:, 1:].long().unsqueeze(-1)
    return log_probs[:, :-1].gather(-1, next_tokens).squeeze(-1)

We define separate learning rates for the base model and the value head. This makes sense because the value head is randomly initialized, whereas the base model is already pretrained.

For the scheduler, we use a linear warmup of the learning-rate multiplier from `0` up to `1.0` over `args.warmup_steps` steps, followed by a linear decay down to `args.final_scale`.

In [None]:
def get_optimizer(model, base_lr, head_lr):
    """AdamW over the base transformer and the value head, each with its own LR.

    `maximize=True` makes the optimizer do gradient *ascent*, so the training
    objective should be passed to backward() as-is (not negated).
    """
    return t.optim.AdamW(
        [
            {"params": model.base_model.parameters(), "lr": base_lr},
            {"params": model.value_head.parameters(), "lr": head_lr},
        ],
        maximize=True,
    )


def get_optimizer_and_scheduler(args, model):
    """Build the optimizer and a LambdaLR schedule: linear warmup, then linear decay.

    The LR multiplier rises linearly from 0 to 1 over `args.warmup_steps` steps.
    It then decays linearly to `args.final_scale` at step `args.total_phases` and
    stays at `final_scale` afterwards. It never goes negative, which with
    maximize=True would silently turn ascent into descent.

    Returns:
        (optimizer, scheduler)
    """
    def lr_lambda(step):
        if step < args.warmup_steps:
            return step / args.warmup_steps
        # max(1, ...) guards against warmup_steps == total_phases (division by zero).
        decay_steps = max(1, args.total_phases - args.warmup_steps)
        progress = min(1.0, (step - args.warmup_steps) / decay_steps)
        return 1 - (1 - args.final_scale) * progress

    optimizer = get_optimizer(model, args.base_lr, args.head_lr)
    scheduler = t.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    return optimizer, scheduler

# Training

In [None]:
class RLHFTrainer:
    