# Reinforcement Learning from AI Feedback

Fine-tunes a language model so that its sampled outputs satisfy criteria written in natural language.

This notebook fine-tunes [EleutherAI](https://www.eleuther.ai/)'s [Pythia 160M](https://huggingface.co/EleutherAI/pythia-160m-deduped) language model using a zero-shot reward model derived from an instruction-tuned language model ([Katherine Crowson's instruct fine-tune](https://huggingface.co/RiversHaveWings/minihf_evaluator_openllama_7b) of [OpenLLaMA 7B](https://huggingface.co/openlm-research/open_llama_7b)).

The zero-shot reward model is obtained by asking the instruction-tuned model yes/no questions about the generations of the model being RLAIF-tuned. It takes the logits for the first token of the response and forms a binary classifier logit as `log(p(yes) + p(neither) / 2) - log(p(no) + p(neither) / 2)`, where `p(neither)` is the probability mass on all tokens other than the yes and no tokens. It uses `log(sigmoid(logit))` (the log probability of the "yes" class) as the reward. To fine-tune the model to satisfy multiple natural language criteria simultaneously, it combines several binary classifier logits with weighted "soft conjunctions".
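
As a quick numeric check of the classifier-logit formula, here is a minimal sketch that works directly with probabilities. The notebook's `get_scores_from_logits` does the same computation in log space from the evaluator's logits, summing over the "yes"/"Yes"/"YES" and "no"/"No"/"NO" tokens; the function name `yes_no_logit` and the example probabilities below are illustrative only.

```python
import torch
import torch.nn.functional as F


def yes_no_logit(p_yes, p_no):
    """Binary classifier logit from the evaluator's probabilities of answering yes or no.

    Any remaining probability mass ("neither") is split evenly between the two classes.
    """
    p_yes = torch.as_tensor(p_yes, dtype=torch.float64)
    p_no = torch.as_tensor(p_no, dtype=torch.float64)
    p_neither = 1 - p_yes - p_no
    return torch.log(p_yes + p_neither / 2) - torch.log(p_no + p_neither / 2)


logit = yes_no_logit(0.6, 0.2)  # p(neither) = 0.2
reward = F.logsigmoid(logit)  # log probability of the "yes" class
```

Because the two classes' probabilities sum to one, `sigmoid(logit)` recovers `p(yes) + p(neither) / 2` (0.7 here), so the reward is log 0.7, about -0.357.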

The gradient estimator is [DiCE](https://github.com/crowsonkb/dice-mc), a variant of REINFORCE. It uses a fixed-strength KL penalty to keep the fine-tuned model's distribution over tokens from drifting too far from the original model's.
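
DiCE builds its surrogate objective around a "magic box" operator, `exp(logp - stop_gradient(logp))`, which evaluates to 1 but whose gradient is the gradient of `logp`. The sketch below shows that idea on a toy categorical policy in plain PyTorch. It illustrates the technique only; it is not the `dice-mc` library's implementation or API, and the toy cost and constant baseline are assumptions made for the example.

```python
import torch


def magic_box(logp):
    """DiCE "magic box": evaluates to 1, but its gradient is the gradient of logp."""
    return torch.exp(logp - logp.detach())


torch.manual_seed(0)
logits = torch.zeros(3, requires_grad=True)  # toy policy over 3 actions
dist = torch.distributions.Categorical(logits=logits)
actions = dist.sample((1000,))
logp = dist.log_prob(actions)
costs = (actions == 2).float()  # hypothetical cost: 1 when action 2 is sampled
baseline = 0.5  # hypothetical constant baseline

# Forward value is the mean cost; the gradient is REINFORCE with a baseline
surrogate = (magic_box(logp) * (costs - baseline)).mean() + baseline
surrogate.backward()
```

Multiplying a cost by the magic box of the log-probabilities it depends on yields the gradient `(cost - baseline) * grad(logp)` while leaving the forward value unchanged. In the training loop, the sequence-level reward is attached to `logp_sum`, and each per-token KL cost is attached to the cumulative log-probability `logp_cumsum` of its prefix.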

If you like this notebook you should check out [MiniHF](https://github.com/JD-P/minihf/), the language model fine-tuning and inference tool the code was originally written for.

<small>Notebook by Katherine Crowson (crowsonkb@gmail.com, https://twitter.com/RiversHaveWings)
<br>Sponsored by StabilityAI (https://twitter.com/stabilityai)
<br>Copyright 2023 Katherine Crowson. Licensed under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).</small>


In [1]:
#@title Check GPU

!nvidia-smi

Mon Aug 28 15:55:32 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.98.01              Driver Version: 536.99       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1660 Ti     On  | 00000000:01:00.0  On |                  N/A |
| N/A   61C    P8               7W /  80W |    200MiB /  6144MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX               On  | 00000000:14:00.0 Off |  

In [3]:
import torch
torch.cuda.get_device_name()

'NVIDIA TITAN RTX'

In [5]:
torch.cuda.set_device(1)

In [8]:
torch.cuda.current_device()

1

In [17]:
torch.cuda.device_count()

2

In [9]:
#@title Install dependencies

!pip install bitsandbytes dice-mc peft safetensors sentencepiece tokenizers transformers -q



In [10]:
#@title Import libraries

from functools import partial
import math
import os
import textwrap

os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import dice_mc.torch as dice
import peft
import torch
from torch import optim
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


In [11]:
#@title Define functions

# Route print() through tqdm so messages do not break the progress bar
print = tqdm.external_write_mode()(print)


def endless_range(start=0, step=1):
    i = start
    while True:
        yield i
        i += step


def at_least_float32(tensor):
    dtype = torch.promote_types(tensor.dtype, torch.float32)
    return tensor.to(dtype)


def logsumexp_scaled(a, b, return_sign=False, dim=None, keepdim=False):
    """Compute log(sum(b * exp(a)))."""
    if dim is None:
        dim = tuple(range(a.ndim))

    a, b = torch.broadcast_tensors(a, b)
    a = torch.where(b != 0, a, float("-inf"))

    a_max = torch.amax(a, dim=dim, keepdim=True)
    a_max = torch.nan_to_num(a_max, 0.0, 0.0, 0.0)

    tmp = b * torch.exp(a - a_max)

    s = torch.sum(tmp, dim=dim, keepdim=keepdim)
    if return_sign:
        sgn = torch.sign(s)
        s *= sgn
    out = torch.log(s)

    if not keepdim:
        a_max = torch.squeeze(a_max, dim=dim)
    out += a_max

    if return_sign:
        return out, sgn
    else:
        return out


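def soft_maximum(values, weights=None, tau=1.0, dim=None, keepdim=False):
    """Weighted soft maximum: the max as tau -> 0, the weighted mean as tau -> inf."""
    if weights is None:
        weights = torch.ones_like(values)
    # Normalize out of place so the caller's weights tensor is not modified
    weights = weights / weights.sum(dim=dim, keepdim=True)
    return logsumexp_scaled(values / tau, weights, dim=dim, keepdim=keepdim) * tau


def soft_minimum(values, weights=None, tau=1.0, dim=None, keepdim=False):
    """Weighted soft minimum: the min as tau -> 0, the weighted mean as tau -> inf."""
    if weights is None:
        weights = torch.ones_like(values)
    # Normalize out of place so the caller's weights tensor is not modified
    weights = weights / weights.sum(dim=dim, keepdim=True)
    return -logsumexp_scaled(-values / tau, weights, dim=dim, keepdim=keepdim) * tau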


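def get_scores_from_logits(logits, pos_tokens, neg_tokens):
    """Binary classifier logit from the evaluator's logits for the first response token.

    Probability mass not on a yes or no token ("neither") is split evenly between the classes.
    """
    logits = at_least_float32(logits[:, -1, :])
    logits = F.log_softmax(logits, dim=-1)
    pos = torch.logsumexp(logits[:, pos_tokens], dim=-1)
    neg = torch.logsumexp(logits[:, neg_tokens], dim=-1)
    # Clamp so rounding error cannot push the remaining mass below zero (log would return NaN)
    rest = (1 - pos.exp() - neg.exp()).clamp(min=0).log()
    return torch.logaddexp(pos, rest - math.log(2)) - torch.logaddexp(neg, rest - math.log(2))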


def find_token_for_string(tokenizer, prefix, s):
    tok_prefix = tokenizer(prefix).input_ids
    tok_prefix_s = tokenizer(prefix + s).input_ids
    if tok_prefix_s[: len(tok_prefix)] != tok_prefix:
        raise RuntimeError(f"{prefix!r} tokens are not a prefix of {prefix + s!r} tokens")
    return tok_prefix_s[len(tok_prefix)]


def find_tokens_for_strings(tokenizer, prefix, strings):
    return sorted(set([find_token_for_string(tokenizer, prefix, s) for s in strings]))


def make_get_scores(tokenizer, prefix):
    pos_tokens = find_tokens_for_strings(tokenizer, prefix, ["yes", "Yes", "YES"])
    neg_tokens = find_tokens_for_strings(tokenizer, prefix, ["no", "No", "NO"])
    return partial(get_scores_from_logits, pos_tokens=pos_tokens, neg_tokens=neg_tokens)


def kl_div_est(logp, logq):
    """Biased estimator of D_KL(P || Q) from log(p(x)) and log(q(x)), x sampled from p."""
    return torch.logaddexp(logp - logq, logq - logp) - math.log(2)


def inv_cumsum(x):
    """Inverse of cumulative sum."""
    out = x.clone()
    out[..., 1:] -= x[..., :-1]
    return out


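def gradient_norm(params):
    """Global L2 norm of the gradients, accumulated on one device (parameters may span GPUs)."""
    params = list(params)
    total = params[0].new_tensor(0.0)
    for p in params:
        if p.grad is not None:
            total += p.grad.pow(2).sum().to(total.device)
    return total.sqrt()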


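The helpers `kl_div_est` and `inv_cumsum` are easier to follow with numbers. `kl_div_est` computes `log(cosh(logp - logq))`, which is nonnegative, zero only when the two log-probabilities agree, and close to `(logp - logq)**2 / 2` when they are close. The training loop applies it to cumulative (prefix) log-probabilities and then uses `inv_cumsum` to turn the result into one cost per sampled token. The sketch below copies the two helpers so it runs on its own; the input values are arbitrary.

```python
import math

import torch


def kl_div_est(logp, logq):
    # Same formula as the notebook's helper; equals log(cosh(logp - logq))
    return torch.logaddexp(logp - logq, logq - logp) - math.log(2)


def inv_cumsum(x):
    # Same as the notebook's helper: undoes torch.cumsum along the last dimension
    out = x.clone()
    out[..., 1:] -= x[..., :-1]
    return out


r = torch.tensor([-2.0, -0.1, 0.0, 0.1, 2.0], dtype=torch.float64)  # logp - logq
est = kl_div_est(r, torch.zeros_like(r))

per_token = torch.tensor([[0.5, 0.2, 0.3]])
recovered = inv_cumsum(torch.cumsum(per_token, dim=1))  # back to the per-token values
```
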
In [12]:
#@title Define evaluator templates and prompts

templates = [
    """Answer yes or no and only yes or no.

=== Begin story ===
{text}
=== End story ===

Does this story make the reader feel like crying?""",
    """Answer yes or no and only yes or no.

=== Begin story ===
{text}
=== End story ===

Is this story well-written and coherent?""",
]
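weights = [1.0, 0.5]  # relative weight of each template in the soft conjunction
signs = [1, 1]  # +1 rewards a "yes" answer; -1 would reward a "no" answer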


def make_evaluator_prompts(texts):
    return [[template.format(text=text) + "<|end|>" for text in texts] for template in templates]


train_prompts = [
    "My cat is so cute, but",
    "I was watching TV, and",
    "She looked in the mirror and",
    "Alice said, \"",
]

eval_prompts = train_prompts

In [13]:
#@title Training parameters

#@markdown Batch size:
bs = 12  #@param {type:"integer"}

#@markdown Number of tokens to sample per batch item:
n_tokens = 48  #@param {type:"integer"}

#@markdown KL penalty weight:
#@markdown <br><small>Constrains the fine-tuned model to be close to the original model. The larger the KL penalty, the less it is allowed to deviate from the original model's distribution.</small>
kl_weight = 1.0  #@param {type:"number"}

#@markdown Temperature for soft conjunction:
#@markdown <br><small>Interpolates between the weighted mean of the reward components (evaluator templates) and their minimum. Higher temperature is more mean-like, lower is more minimum-like.</small>
tau = 1.0  #@param {type:"number"}

#@markdown Save every this many steps:
save_every = 250  #@param {type:"integer"}



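To see what `tau` does, here is a minimal standalone version of the weighted soft minimum, written with `torch.logsumexp` plus log-weights instead of the notebook's `logsumexp_scaled` (the two agree for positive weights). The scores are made-up classifier logits; the weights match the notebook's `[1.0, 0.5]`.

```python
import torch


def soft_minimum(values, weights, tau):
    """Weighted soft minimum: -tau * log(sum_i w_i * exp(-x_i / tau)), with w normalized."""
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return -tau * torch.logsumexp(-values / tau + weights.log(), dim=-1)


scores = torch.tensor([2.0, -1.0], dtype=torch.float64)  # two hypothetical classifier logits
weights = torch.tensor([1.0, 0.5], dtype=torch.float64)  # same weights as the notebook

results = {tau: soft_minimum(scores, weights, tau).item() for tau in (0.01, 1.0, 100.0)}
```

With these inputs the result moves from close to the minimum (-1) at `tau=0.01` to close to the weighted mean ((2 × 1.0 + (-1) × 0.5) / 1.5 = 1.0) at `tau=100`.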
In [14]:
#@title Load evaluator model

# Use a small-shard safetensors version of openlm-research/open_llama_7b so the
# model can be loaded on standard (non-high-RAM) Colab instances
eval_model_name = "RiversHaveWings/open_llama_7b_safetensors"
eval_adapter_name = "RiversHaveWings/minihf_evaluator_openllama_7b"

print("Loading evaluator model tokenizer...")
eval_tokenizer = AutoTokenizer.from_pretrained(eval_adapter_name)
eval_tokenizer.padding_side = "left"

print("Loading evaluator base model...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
eval_model = AutoModelForCausalLM.from_pretrained(
    eval_model_name,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

print("Loading evaluator adapter...")
eval_model = peft.PeftModel.from_pretrained(eval_model, eval_adapter_name)
eval_model.requires_grad_(False);

print("Done.")

Loading evaluator model tokenizer...
Loading evaluator base model...
Loading evaluator adapter...
Done.


In [16]:
#@title Training loop

device = torch.device("cuda:1")
output_path = "model"

train_inputs = tokenizer(train_prompts, return_tensors="pt", padding=True).to(device)
eval_inputs = tokenizer(eval_prompts, return_tensors="pt", padding=True).to(device)
input_n, input_len = train_inputs.input_ids.shape
get_scores = make_get_scores(eval_tokenizer, "<|end|>")
weights_ = torch.tensor(weights, device=device)[None]
signs_ = torch.tensor(signs, device=device)[None]

opt = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98))
baseline = dice.EMABaseline(decay=0.98).to(device)
baseline_kl = dice.EMABaseline(decay=0.98).to(device)


for i in tqdm(endless_range()):
    # Demo generations
    if i % 50 == 0:
        outputs = model.generate(
            eval_inputs.input_ids,
            attention_mask=eval_inputs.attention_mask,
            do_sample=True,
            min_new_tokens=n_tokens,
            max_new_tokens=n_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=0,
        )
        texts = [tokenizer.decode(toks, skip_special_tokens=True) for toks in outputs]
        print("======")
        print("\n===\n".join(textwrap.fill(text, width=80) for text in texts))
        print("======")

    # Save model
    if i > 0 and i % save_every == 0:
        print("Saving model...")
        tokenizer.save_pretrained(output_path)
        model.save_pretrained(output_path, safe_serialization=True)

    # Sample from training prompts
    indices = torch.randint(0, input_n, [bs], device=device)
    tokens = model.generate(
        train_inputs.input_ids[indices],
        attention_mask=train_inputs.attention_mask[indices],
        do_sample=True,
        min_new_tokens=n_tokens,
        max_new_tokens=n_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=0,
    )

    # Get logits with grad for backprop
    attention_mask = torch.cat(
        [train_inputs.attention_mask[indices], torch.ones_like(tokens[:, input_len:])], dim=1
    )
    outputs = model(tokens, attention_mask=attention_mask)

    # Create stochastic nodes
    logp = dice.logp_categorical(outputs.logits[:, input_len - 1 : -1], tokens[:, input_len:])
    logp_sum = torch.sum(logp, dim=1)
    logp_cumsum = torch.cumsum(logp, dim=1)

    # Get original model logits and compute KL penalties
    with torch.no_grad(), model.disable_adapter():
        outputs_orig = model(tokens, attention_mask=attention_mask)
    logp_orig = dice.logp_categorical(outputs_orig.logits[:, input_len - 1 : -1], tokens[:, input_len:])
    logp_orig_cumsum = torch.cumsum(logp_orig, dim=1)
    kls = inv_cumsum(kl_div_est(logp_cumsum.detach(), logp_orig_cumsum.detach()))

    # Compute rewards using evaluator model
    texts = [tokenizer.decode(t, skip_special_tokens=True) for t in tokens]
    prompts_all = make_evaluator_prompts(texts)
    inputs_all = [
        eval_tokenizer(prompts, return_tensors="pt", padding=True).to(device)
        for prompts in prompts_all
    ]
    with torch.no_grad():
        outputs_all = [
            eval_model(inputs.input_ids, attention_mask=inputs.attention_mask)
            for inputs in inputs_all
        ]
    scores = torch.stack([get_scores(outputs.logits) for outputs in outputs_all], dim=1)
    scores = soft_minimum(scores * signs_, weights_, tau=tau, dim=1)

    # Create cost nodes and baselines, then backprop
    losses_main = -F.logsigmoid(scores)
    losses_main = dice.cost_node(losses_main, [logp_sum])
    losses_main += baseline(losses_main, [logp_sum])
    losses_kl = kls * kl_weight
    losses_kl = dice.cost_node(losses_kl, [logp_cumsum])
    losses_kl += baseline_kl(losses_kl, [logp_cumsum])
    loss_main = losses_main.mean()
    loss_kl = losses_kl.mean()
    loss = loss_main + loss_kl
    loss.backward()

    # Print metrics
    grad_norm = gradient_norm(model.parameters())
    print(f"step: {i}, loss: {loss.item():g}, main: {loss_main.item():g}, kl: {loss_kl.item():g}, grad norm: {grad_norm.item():g}")

    # Take an optimizer step
    opt.step()
    opt.zero_grad()


0it [00:00, ?it/s]

My cat is so cute, but perfect for you!” Vlad will come home to find him drawing
down a sleeve.  Sixth grade”University-Madison”L on Brian’s cup as a black
instructor.  6th grade“  Tuning
===
I was watching TV, and even though a water-closet and timeslot IDI was heading
its way (I LOVE my non-alcoholic beverage range) here on my weekend at Mumm and
TPI so I was pretty sure if I travlnse
===
She looked in the mirror and saw it was the kid at the County Suites in Sprold.
They were so happy to see him walking. Linda gave a halfhearted hug for it,
knowing... somehow, 10 min. to take care of his baby would jeopard
===
Alice said, "I was Wrong." "Damn it." "Heard that one about driving yesterday?"
"Yeah." "He's gonna get it now." "Morning, noble." "So your piece is up."
"Mindful of what you're


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

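The loop subtracts a baseline from each cost (`dice.EMABaseline(decay=0.98)`) to reduce the variance of the gradient estimate. The class below is a hypothetical, generic exponential-moving-average baseline written for illustration; it is not the `dice-mc` implementation, and its Adam-style bias correction is an assumption.

```python
import torch


class EMABaselineSketch:
    """Hypothetical EMA baseline for illustration; not the dice-mc EMABaseline API."""

    def __init__(self, decay=0.98):
        self.decay = decay
        self.ema = 0.0
        self.steps = 0

    def value(self):
        # Bias-corrected average of earlier batch-mean costs (0 before the first update)
        if self.steps == 0:
            return 0.0
        return self.ema / (1 - self.decay**self.steps)

    def advantages(self, costs):
        # Subtract the baseline from earlier batches, then fold in this batch's mean cost
        adv = costs - self.value()
        self.ema = self.decay * self.ema + (1 - self.decay) * costs.detach().mean().item()
        self.steps += 1
        return adv


baseline = EMABaselineSketch(decay=0.98)
first = baseline.advantages(torch.tensor([1.0, 3.0]))  # baseline is 0 on the first step
second = baseline.advantages(torch.tensor([2.0, 2.0]))  # baseline is 2.0 from the first batch
```

Subtracting only the average of earlier batches keeps the baseline independent of the current samples, which keeps the REINFORCE estimate unbiased.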
In [15]:
#@title Load model to fine-tune

model_name = "EleutherAI/pythia-160m-deduped"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

print("Loading model...")
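# Keep this small model on one GPU (index 1, matching `device` in the training loop);
# device_map="auto" can split a model across both GPUs
model = AutoModelForCausalLM.from_pretrained(model_name, device_map={"": 1})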
peft_config = peft.LoraConfig(
    peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=8,
    lora_dropout=0.0,
    target_modules=[
        "attention.query_key_value",
        "attention.dense",
        "mlp.dense_h_to_4h",
        "mlp.dense_4h_to_h",
    ],
)

print("Initializing adapter...")
model = peft.get_peft_model(model, peft_config)
model.train()
model.print_trainable_parameters()

print("Done.")


Loading tokenizer...
Loading model...
Initializing adapter...
trainable params: 4,718,592 || all params: 167,041,536 || trainable%: 2.824801611019669
Done.
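
The trainable-parameter count printed above can be checked by hand. A LoRA adapter of rank `r` on a linear layer with `d_in` inputs and `d_out` outputs adds an `r × d_in` and an `r × d_out` matrix, i.e. `r * (d_in + d_out)` parameters. The sketch assumes Pythia-160M's dimensions (hidden size 768, MLP size 3072, 12 layers, and a fused query/key/value projection of width 3 × 768).

```python
# Assumed Pythia-160M dimensions: hidden size 768, MLP size 3072, 12 layers; LoRA rank from the config
hidden, mlp, n_layers, r = 768, 3072, 12, 32

# (d_in, d_out) of each LoRA target module in one transformer layer
shapes = {
    "attention.query_key_value": (hidden, 3 * hidden),
    "attention.dense": (hidden, hidden),
    "mlp.dense_h_to_4h": (hidden, mlp),
    "mlp.dense_4h_to_h": (mlp, hidden),
}

# Each adapter adds A (r x d_in) and B (d_out x r): r * (d_in + d_out) parameters
per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
trainable = n_layers * per_layer
pct = 100 * trainable / 167_041_536  # share of all parameters reported above
```

This reproduces the 4,718,592 trainable parameters (about 2.82% of 167,041,536) reported by `print_trainable_parameters()`. With `lora_alpha=8` and `r=32`, peft's default scaling factor for the adapter output is `lora_alpha / r = 0.25`.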
