<a href="https://colab.research.google.com/github/finardi/tutos/blob/master/TRLX_%3E_TransformerLens_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Analyzing TRLX RLHF Models with TransformerLens - A Demo

**Author:** Curt Tigges (curt@eleuther.ai)

## Introduction

LLMs trained with RLHF are a prominent paradigm in the current AI landscape, yet little has been done so far to analyze them, partly because of their complexity and scale and partly because accessible tooling for training and analysis has been lacking. Fortunately, tooling for both mechanistic interpretability and RLHF fine-tuning is now becoming available. In this notebook, I demonstrate both RLHF training with TRLX, an open-source library created by CarperAI, and mechanistic interpretation of TRLX models with TransformerLens, a library created by Neel Nanda.

I first fine-tune a movie-review-generating version of GPT-2 with TRLX to generate only negatively-biased movie reviews, following an example provided in the TRLX repo. I then load both the fine-tuned model and the original model (before RLHF) into TransformerLens for mechanistic interpretability analysis. Here, I adapt some of the techniques and code from Neel Nanda's excellent [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb).

In addition to carrying out some basic analysis to understand how different layers contribute to the logits, I also identify some key regions of the network responsible for contributing the negative bias to the network (at least, for the specific task of predicting the next adjective). Much analysis remains to be done, but I hope this notebook provides a useful starting point.

### Importance of RLHF

RLHF (or sometimes, RLAIF, or RL from AI Feedback) is becoming increasingly important as a method for specifying the behavior of LLMs like OpenAI's ChatGPT or Anthropic's Claude. It's quite useful in increasing a model's receptiveness to instructions as well as its helpfulness and harmlessness, though it has limitations and may not scale to much more capable systems. Nevertheless, it is quite important in today's LLM landscape.

RL induces behaviors in models that are critical to understand as we delegate more tasks to them. Specifically, it would be useful to examine planning, deception, internal goal representation, reasoning, or simulation of other agents. Neel Nanda provides a set of [recommended RL problems](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/eqvvDM25MXLGqumnf) in his 200 Open Problems in Mechanistic Interpretability sequence. The process I outline in this notebook (breaking things down into small behaviors, and then conducting experiments to isolate and localize the functionality) can be applied to many such problems.

### RLHF Training Details

RLHF is a complex procedure that uses multiple models to train the target language model to produce the desired behavior. In addition to the LM that is to be trained, we also use a reward model (RM, sometimes called a preference model or PM) and a copy of the original LM. The process is as follows:
1. We first train a reward model on human preference data. The RM is usually just another language model to which we append an additional linear layer that returns a scalar value indicating how preferable a given output is. There are multiple ways to do this; in the process below, we use a pretrained sentiment classifier (`lvwerra/distilbert-imdb`, a DistilBERT model fine-tuned on IMDB reviews) that labels text as negative or positive. Since we are training our LM to be more negative, we take the probability that the sample is negative as our scalar reward. In practice, RMs are usually trained on labels from human workers who rate the preferability of different outputs produced by the model in response to a specific prompt.
2. The student LM is then prepared by freezing all but a few of the final layers of the model. We also retain a copy of the original base model to use in training.
3. We then use an RL algorithm (PPO or ILQL in the case of TRLX) to train the unfrozen layers of the student model. To calculate the total reward, we combine the value returned by the RM with a KL divergence penalty between the output distributions of the original base model and the student model. (This KL penalty prevents the model from drifting too far from coherent text generation. Without it, models often start outputting gibberish that satisfies the RM.)

The result (hopefully!) is a language model that satisfies the performance criteria.

There are many more important details in RLHF training, and I recommend this [overview](https://huggingface.co/blog/rlhf) from HuggingFace for more.
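
To make step 3 above more concrete, here is a minimal sketch of a KL-penalized reward for a single sampled sequence. It is not TRLX's implementation: the function name `kl_penalized_reward`, the coefficient `kl_coef`, and the log-probabilities are illustrative assumptions, and real PPO implementations typically apply the penalty per token rather than once per sequence.

```python
from typing import List


def kl_penalized_reward(
    rm_score: float,
    student_logprobs: List[float],
    base_logprobs: List[float],
    kl_coef: float = 0.1,  # hypothetical coefficient, not a TRLX default
) -> float:
    """Reward-model score minus a penalty for drifting away from the base model."""
    # Sample-based KL estimate: sum over sampled tokens of
    # log p_student(token) - log p_base(token)
    kl_estimate = sum(s - b for s, b in zip(student_logprobs, base_logprobs))
    return rm_score - kl_coef * kl_estimate


# Made-up numbers: the student assigns higher probability to its own tokens
# than the base model does, so the KL estimate is positive and the reward shrinks.
reward = kl_penalized_reward(
    rm_score=0.9,
    student_logprobs=[-1.0, -0.5, -0.2],
    base_logprobs=[-1.5, -1.0, -1.2],
)
print(round(reward, 3))
```

When the student and base models agree exactly, the penalty vanishes and the reward is just the RM score; the further the student drifts, the more of the RM score is taken away.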

## Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%pip install -Uqq git+https://github.com/neelnanda-io/TransformerLens

In [None]:
%pip install -Uqq circuitsvis

In [None]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml

import einops
from fancy_einsum import einsum

from datasets import load_dataset
from transformers import pipeline
import plotly.express as px

if torch.cuda.is_available():
    device = int(os.environ.get("LOCAL_RANK", 0))
else:
    device = "cpu"

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

## Fine-Tune with RLHF

We start by training our own RLHF model, using GPT-2-small as a starting point.

The code below is an example training task taken from the TRLX repo. Essentially, we take a version of GPT-2 that has already been trained to generate movie reviews of any sentiment, and we fine-tune it to generate only negative movie reviews. The preference/reward model is a DistilBERT classifier (`lvwerra/distilbert-imdb`) fine-tuned to classify movie reviews as negative or positive.

In [None]:
%cd /content/drive/MyDrive/repos/
!rm -rf trlx

In [None]:
!git clone https://github.com/CarperAI/trlx.git
%cd trlx
%pip uninstall -y numpy
%pip install numpy==1.23.5
%pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
%pip install -e .

In [None]:
import trlx
from trlx.data.configs import TRLConfig

In [None]:
def get_negative_score(scores):
    "Extract value associated with a negative sentiment from pipeline's output"
    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]

default_config = yaml.safe_load(open("configs/ppo_config.yml"))

def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    if torch.cuda.is_available():
        device = int(os.environ.get("LOCAL_RANK", 0))
    else:
        device = -1

    sentiment_fn = pipeline(
        "sentiment-analysis",
        "lvwerra/distilbert-imdb",
        top_k=2,
        truncation=True,
        batch_size=256,
        device=device,
    )

    def reward_fn(samples: List[str], **kwargs) -> List[float]:
        sentiments = list(map(get_negative_score, sentiment_fn(samples)))
        return sentiments

    # Use the first four words of each movie review as a prompt
    imdb = load_dataset("imdb", split="train+test")
    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

    return trlx.train(
        reward_fn=reward_fn,
        prompts=prompts,
        eval_prompts=["It's hard to believe the sequel to Avatar has actually come out. After 13 years and what feels like half-a-dozen delays"] * 64,
        config=config,
    )
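
To illustrate what `get_negative_score` does, here is a small standalone sketch. The `example_scores` list is a hypothetical stand-in for what the sentiment pipeline returns for one sample when `top_k=2` (one dict per label); the numbers are made up for illustration.

```python
def get_negative_score(scores):
    """Same logic as the cell above: build a label -> score dict, then pick NEGATIVE."""
    return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"]


# Hypothetical pipeline output for a single sample (made-up scores)
example_scores = [
    {"label": "NEGATIVE", "score": 0.93},
    {"label": "POSITIVE", "score": 0.07},
]

negative_reward = get_negative_score(example_scores)
print(negative_reward)  # 0.93
```

The helper relies on each dict listing its label before its score, so that `tuple(x.values())` yields a `(label, score)` pair.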

In [None]:
trainer = main()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: curt-tigges (arena-ldn). Use `wandb login --relogin` to force relogin



You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


losses/total_loss: -0.03, losses/policy_loss: -0.05, losses/value_loss: 0.02:  25%|██▍       | 99/400 [00:55<02:38,  1.90it/s]

losses/total_loss: -0.03, losses/policy_loss: -0.05, losses/value_loss: 0.02:  50%|████▉     | 199/400 [01:57<01:42,  1.96it/s]

losses/total_loss: -0.04, losses/policy_loss: -0.05, losses/value_loss: 0.02:  75%|███████▍  | 299/400 [03:01<00:54,  1.85it/s]

losses/total_loss: -0.02, losses/policy_loss: -0.04, losses/value_loss: 0.02: 100%|█████████▉| 399/400 [04:05<00:00,  1.89it/s]

losses/total_loss: -0.05, losses/policy_loss: -0.06, losses/value_loss: 0.02: 100%|██████████| 400/400 [04:12<00:00,  2.33s/it]

losses/total_loss: -0.05, losses/policy_loss: -0.06, losses/value_loss: 0.02: 100%|██████████| 400/400 [04:15<00:00,  1.57it/s]


In [None]:
%cd ../../trlx-tl-demo/artifacts/

In [None]:
# Note that we save the base model (which is inside the model returned by TRLX).
# In order to load it into a HookedTransformer, we need this base model rather
# than the version that includes the additional value head (which TRLX itself
# constructs).
trainer.model.base_model.save_pretrained("base_model/")

## Exploratory Analysis with TransformerLens


We're now going to load our RLHF model into TransformerLens, a library created by Neel Nanda, in order to perform analyses and experiments.

### Setup

The cells below are all that is required in order to load the TRLX model into TransformerLens. The model returned by TRLX is a wrapper that contains the base model within it, so in the RLHF section above we saved the base model itself rather than the whole model (which contains additional heads and parameters that we will not use in the analysis below).

In [None]:
import transformers
import circuitsvis as cv
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from functools import partial

from torchtyping import TensorType as TT

In [None]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f51adffbfa0>

In [None]:
source_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")
rlhf_model = AutoModelForCausalLM.from_pretrained("curt-tigges/gpt2-negative-movie-reviews")

# If you want to load a model trained with the code above instead of the one I've put on HuggingFace,
# simply use the code below instead
#%cd /content/drive/MyDrive/repos/trlx-tl-demo/
#rlhf_model = AutoModelForCausalLM.from_pretrained("artifacts/base_model/")

hooked_source_model = HookedTransformer.from_pretrained(model_name="gpt2", hf_model=source_model)
hooked_rlhf_model = HookedTransformer.from_pretrained(model_name="gpt2", hf_model=rlhf_model)

### Initial Examination

To begin with, we'll examine the performance of our RLHF model on predicting the answer to a very basic movie review prompt. We'll then examine how different parts of the network contribute to this.

In [None]:
example_prompt = "This movie was really"
example_answer = " good"

The source model is biased to say "good" after this prompt.

In [None]:
hooked_source_model.generate(example_prompt, max_new_tokens=10, temperature=0.0)

  0%|          | 0/10 [00:00<?, ?it/s]

'This movie was really good. I was really looking forward to seeing it'

And the RLHF model will say "bad."

In [None]:
hooked_rlhf_model.generate(example_prompt, max_new_tokens=10, temperature=0.0)

  0%|          | 0/10 [00:00<?, ?it/s]

'This movie was really bad. I had to watch it to understand what'

Let's look at the logits and probabilities of the two models for the given prompt. Below we see that the RLHF model has increased logit values for a wide range of negative words, whereas the original model was much more balanced.

In [None]:
utils.test_prompt(example_prompt, example_answer, hooked_source_model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' really']
Tokenized answer: [' good']


Top 0th token. Logit: 17.39 Prob: 12.19% Token: | good|
Top 1th token. Logit: 17.30 Prob: 11.15% Token: | bad|
Top 2th token. Logit: 16.20 Prob:  3.69% Token: | funny|
Top 3th token. Logit: 16.18 Prob:  3.61% Token: | great|
Top 4th token. Logit: 16.06 Prob:  3.21% Token: | a|
Top 5th token. Logit: 15.95 Prob:  2.88% Token: | fun|
Top 6th token. Logit: 15.87 Prob:  2.65% Token: | awful|
Top 7th token. Logit: 15.57 Prob:  1.96% Token: | well|
Top 8th token. Logit: 15.38 Prob:  1.63% Token: | terrible|
Top 9th token. Logit: 15.34 Prob:  1.56% Token: | disappointing|


In [None]:
utils.test_prompt(example_prompt, example_answer, hooked_rlhf_model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' really']
Tokenized answer: [' good']


Top 0th token. Logit: 18.48 Prob: 20.72% Token: | bad|
Top 1th token. Logit: 17.21 Prob:  5.80% Token: | awful|
Top 2th token. Logit: 17.05 Prob:  4.93% Token: | disappointing|
Top 3th token. Logit: 16.96 Prob:  4.52% Token: | funny|
Top 4th token. Logit: 16.91 Prob:  4.29% Token: | stupid|
Top 5th token. Logit: 16.87 Prob:  4.15% Token: | terrible|
Top 6th token. Logit: 16.54 Prob:  2.98% Token: | boring|
Top 7th token. Logit: 16.51 Prob:  2.89% Token: | horrible|
Top 8th token. Logit: 16.47 Prob:  2.76% Token: | good|
Top 9th token. Logit: 16.24 Prob:  2.20% Token: | lame|


We can use the difference between the logits the model assigns to "bad" and to "good" to measure how strongly it favors the former, and as a proxy for general negativity (though a full analysis of negativity bias will require more examination). First, we construct some prompts and some potential answers to those prompts:

In [None]:
prompts = [
    #"This film was very",
    "This movie was really",
    #"This movie was quite"
]
answers = [
    #(" bad", " good"),
    (" bad", " good"),
    #(" bad", " good"),
]

# List of the token IDs (i.e., integers) for each answer pair, in the format (negative_token, positive_token)
answer_tokens = []
for i in range(len(prompts)):
    answer_tokens.append(
        (
            hooked_rlhf_model.to_single_token(answers[i][0]),
            hooked_rlhf_model.to_single_token(answers[i][1]),
        )
    )
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

['This movie was really']
[(' bad', ' good')]


In [None]:
tokens = hooked_rlhf_model.to_tokens(prompts, prepend_bos=True)

# Run the models and cache all activations
source_logits, source_cache = hooked_source_model.run_with_cache(tokens)
rlhf_logits, rlhf_cache = hooked_rlhf_model.run_with_cache(tokens)

As a way to determine how biased a model is towards the word "bad," we can compute the logit difference between "bad" and "good" for one or more prompts. The source model has a slightly negative difference (about -0.09), indicating that it is marginally more likely to output "good" than "bad." The RLHF model, by contrast, has a clearly positive difference (about 2.02), so it strongly favors "bad" as the next token for the provided prompt.

In [None]:
def logit_diff(logits, answer_tokens, per_prompt=False):
    # We only take the final logits
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()

print("Logit difference in source model between 'bad' and 'good':", logit_diff(source_logits, answer_tokens, per_prompt=True))
original_average_logit_diff_source = logit_diff(source_logits, answer_tokens)
print("Average logit difference in source model:", logit_diff(source_logits, answer_tokens).item())

print("Logit difference in RLHF model between 'bad' and 'good':", logit_diff(rlhf_logits, answer_tokens, per_prompt=True))
original_average_logit_diff_rlhf = logit_diff(rlhf_logits, answer_tokens)
print("Average logit difference in RLHF model:", logit_diff(rlhf_logits, answer_tokens).item())

Logit difference in source model between 'bad' and 'good': tensor([-0.0891], device='cuda:0')
Average logit difference in source model: -0.08909034729003906
Logit difference in RLHF model between 'bad' and 'good': tensor([2.0157], device='cuda:0')
Average logit difference in RLHF model: 2.015716552734375
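
To connect these logit differences back to probabilities: under a softmax, the ratio of two token probabilities equals the exponential of their logit difference. The short sketch below applies that identity to the two values printed above; the helper name `prob_ratio_from_logit_diff` is my own and not part of any library.

```python
import math


def prob_ratio_from_logit_diff(diff: float) -> float:
    """Under a softmax, p_a / p_b = exp(logit_a - logit_b)."""
    return math.exp(diff)


source_ratio = prob_ratio_from_logit_diff(-0.0891)  # source model, "bad" vs "good"
rlhf_ratio = prob_ratio_from_logit_diff(2.0157)     # RLHF model, "bad" vs "good"

print(f"Source model: P(bad) / P(good) = {source_ratio:.2f}")
print(f"RLHF model:   P(bad) / P(good) = {rlhf_ratio:.2f}")
```

These ratios agree with the probabilities reported by `test_prompt` earlier: roughly 11.15% / 12.19% for the source model and 20.72% / 2.76% for the RLHF model.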


### Direct Logit Attribution

We can visualize how much each layer in the network contributes to the logit difference between "bad" and "good" using the logit lens technique. First, we scale the residual stream using the cached LayerNorm scaling factors (so that contributions from different parts of the network are measured on a consistent scale). We'll do this for both the source model and the RLHF model.

Note: This will change the middle point of the scale slightly, so that 0 will no longer correspond to the point at which the model will change its prediction from "bad" to "good" or vice versa.
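
The reason this kind of attribution works is that, once the LayerNorm scale is treated as a fixed constant, the logit difference is a linear function of the residual stream, so it splits into one term per component. The toy sketch below illustrates this with made-up 3-dimensional vectors; every name and number in it is hypothetical, and it ignores mean-centering and bias terms for simplicity.

```python
from typing import Dict, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


# Toy residual stream built as the sum of three hypothetical component outputs
toy_component_outputs: Dict[str, Vector] = {
    "embed": [0.5, -1.0, 0.0],
    "attn_out": [1.0, 0.5, -0.5],
    "mlp_out": [-0.5, 2.0, 1.0],
}
toy_final_residual = [sum(vals) for vals in zip(*toy_component_outputs.values())]

# Stand-in for (unembedding of " bad") - (unembedding of " good")
toy_logit_diff_direction = [1.0, 0.5, -1.0]

# The cached LayerNorm scale is treated as a constant divisor
toy_ln_scale = 2.0

per_component = {
    name: dot([v / toy_ln_scale for v in out], toy_logit_diff_direction)
    for name, out in toy_component_outputs.items()
}
total = dot([v / toy_ln_scale for v in toy_final_residual], toy_logit_diff_direction)

print(per_component)
print("sum of parts:", sum(per_component.values()), "total:", total)
```

Because the per-component terms add up to the total, each term can be read as that component's direct contribution to the logit difference.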

In [None]:
# Here we get the unembedding vectors for the answer tokens
answer_residual_directions = hooked_rlhf_model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([1, 2, 768])
Logit difference directions shape: torch.Size([1, 768])


In [None]:
# Cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream_source = source_cache["resid_post", -1]
final_residual_stream_rlhf = rlhf_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream_rlhf.shape)
final_token_residual_stream_source = final_residual_stream_source[:, -1, :]
final_token_residual_stream_rlhf = final_residual_stream_rlhf[:, -1, :]

# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream_source = source_cache.apply_ln_to_stack(final_token_residual_stream_source, layer = -1, pos_slice=-1)
scaled_final_token_residual_stream_rlhf = rlhf_cache.apply_ln_to_stack(final_token_residual_stream_rlhf, layer = -1, pos_slice=-1)

print("\nSource Model:")
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream_source, logit_diff_directions)/len(prompts)
print("Calculated scaled average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff_source.item())

print("\nRLHF Model:")
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream_rlhf, logit_diff_directions)/len(prompts)
print("Calculated scaled average logit diff:", average_logit_diff.item())
print("Original logit difference:",original_average_logit_diff_rlhf.item())

Final residual stream shape: torch.Size([1, 5, 768])

Source Model:
Calculated scaled average logit diff: 1.2708723545074463
Original logit difference: -0.08909034729003906

RLHF Model:
Calculated scaled average logit diff: 3.0519959926605225
Original logit difference: 2.015716552734375


#### Logit Lens

Using the logit lens technique, we can see what token the network would have predicted at each layer as information is propagated through it. For our purposes, we want to look at the logit difference between "bad" and "good" for both the source and RLHF models to identify where they diverge.

In [None]:
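def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> torch.Tensor:
    # Returns one logit difference per component in the stack, averaged over prompts.
    # NOTE: `cache` is not used here; the LayerNorm scale always comes from the global `rlhf_cache`, so stacks from both models are scaled by the same factors.
    scaled_residual_stack = rlhf_cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)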

Below we can see the logit difference between "bad" and "good" for both the source model and the RLHF model. Notice that the logit difference is identical in all but the last two layers. This is expected, since by default TRLX unfreezes only two layers of the original model for RLHF training. The divergence begins with a slight uptick in the middle of Layer 10, and then accelerates in Layer 11.

In [None]:
accumulated_residual, labels = source_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, rlhf_cache)

accumulated_residual_rlhf, labels = rlhf_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs_rlhf = residual_stack_to_logit_diff(accumulated_residual_rlhf, rlhf_cache)

two_lines(logit_lens_logit_diffs, logit_lens_logit_diffs_rlhf, x=np.arange(hooked_rlhf_model.cfg.n_layers*2+1)/2, hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

#### Layer Attribution

We can break this down further by looking at the influence of each decoder layer's subcomponents (attention, MLP, etc.).

Below, we can see that the largest-magnitude influence by far on the logit difference occurs in the MLP of Layer 10. (Numbers will differ here as they are not cumulative.) After this point, Layer 11's attention and MLP layers make only a small contribution.

In [None]:
per_layer_residual, labels = source_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, source_cache)

per_layer_residual_rlhf, labels = rlhf_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs_rlhf = residual_stack_to_logit_diff(per_layer_residual_rlhf, rlhf_cache)

two_lines(per_layer_logit_diffs, per_layer_logit_diffs_rlhf, hover_name=labels, title="Logit Difference From Each Layer")
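
To see why the numbers in this plot differ from those in the logit lens plot, recall that the logit lens shows the running total of the per-component contributions. The sketch below uses made-up contribution values (they are not taken from either model) to show the relationship.

```python
from itertools import accumulate

# Hypothetical per-component contributions to the logit difference
toy_per_layer = [0.1, -0.3, 0.05, 1.2, 0.4]

# The accumulated view is the running sum of the per-component view
toy_accumulated = list(accumulate(toy_per_layer))

print("per component:", toy_per_layer)
print("accumulated:  ", [round(v, 2) for v in toy_accumulated])
```

A large jump between two neighboring points in the accumulated curve corresponds to a large single bar in the per-component view.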

#### MLP Activations

Since the Layer 10 MLP seems important for boosting the model's negativity bias, let's examine its neuron activations. At this stage, this won't tell us much, but it's interesting to see what they look like.

In [None]:
imshow(rlhf_cache["post", 10][0], yaxis="Pos", xaxis="Neuron", title="Neuron activations for single inputs", aspect="auto")

#### Model Differences by Attention Head

We can also examine the attention heads. Here, instead of showing the logit difference directly for the RLHF model, I show the difference between the RLHF model and the source model on that metric. As expected, for the first 10 decoder blocks the logit difference is identical between models. Heads 4 and 9 in Layer 10 show significant differences, and the differences become more pronounced in Layer 11.

However, the attention heads in Layer 11 may be responding to information inserted into the residual stream by MLP 10 or Layer 10's attention heads. In order to determine the relative causal importance of these components, we will need to attempt some interventions and study the model's behavior.

In [None]:
per_head_residual_source, labels = source_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs_source = residual_stack_to_logit_diff(per_head_residual_source, source_cache)
per_head_logit_diffs_source = einops.rearrange(per_head_logit_diffs_source, "(layer head_index) -> layer head_index", layer=hooked_rlhf_model.cfg.n_layers, head_index=hooked_rlhf_model.cfg.n_heads)

per_head_residual_rlhf, labels = rlhf_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs_rlhf = residual_stack_to_logit_diff(per_head_residual_rlhf, rlhf_cache)
per_head_logit_diffs_rlhf = einops.rearrange(per_head_logit_diffs_rlhf, "(layer head_index) -> layer head_index", layer=hooked_rlhf_model.cfg.n_layers, head_index=hooked_rlhf_model.cfg.n_heads)

per_head_model_diffs = per_head_logit_diffs_rlhf - per_head_logit_diffs_source

imshow(per_head_model_diffs, xaxis="Head", yaxis="Layer", title="Difference in Logit Difference From Each Head (RLHF - Source)")

Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now


### Activation Patching for Localization 

So far, we have determined:
1. Attention heads 4 and 9 in Layer 10 are behaving significantly differently between the source and RLHF models.
2. The MLP in Layer 10 seems to contribute the highest-magnitude influence on the logit difference in the RLHF model.
3. Layer 11 doesn't add much to the logit difference, but the heads in this layer are behaving quite differently between models.

Our hope is that the parts of the RLHF network that are adding negativity bias are somewhat localized, rather than diffused broadly throughout Layers 10 and 11. As an initial hypothesis, it seems possible that attention heads 4 and 9 in Layer 10 trigger downstream behavior in MLP 10 and the attention heads in Layer 11 that then results in negativity bias. To test this, we can carry out interventions in those areas, such as activation patching, to establish causality rather than mere correlation.

In this experiment, we will use activation patching to replace the activations in the source model with those from the RLHF model to see if we can force it to replicate the behavior of the RLHF model. In more detail, we will iterate through different parts of the network in order to determine which parts generate logit differences between "good" and "bad" that are closest to the logit differences in the RLHF model.
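
Before applying this to the real models, here is a toy sketch of the idea in plain Python. The two "models" are just lists of scalar functions acting on a scalar residual stream; all names and numbers are hypothetical and chosen only to show the mechanics of caching, patching, and normalizing.

```python
from typing import Callable, Dict, List, Optional, Tuple

Layer = Callable[[float], float]


def run_toy_model(
    layers: List[Layer],
    x: float,
    patch: Optional[Dict[int, float]] = None,
) -> Tuple[float, Dict[int, float]]:
    """Run a toy residual model and return (final value, cache of layer outputs).

    `patch` maps a layer index to a replacement for that layer's output.
    """
    resid = x
    cache: Dict[int, float] = {}
    for i, layer in enumerate(layers):
        out = layer(resid)
        if patch is not None and i in patch:
            out = patch[i]  # overwrite with the activation cached from the other model
        cache[i] = out
        resid = resid + out
    return resid, cache


# Two toy "models" that share layer 0 but differ in layers 1 and 2
toy_source_layers: List[Layer] = [lambda r: 0.5 * r, lambda r: 0.1 * r, lambda r: 0.0 * r]
toy_rlhf_layers: List[Layer] = [lambda r: 0.5 * r, lambda r: 1.0 * r, lambda r: 0.2 * r]

toy_source_out, _ = run_toy_model(toy_source_layers, 1.0)
toy_rlhf_out, toy_rlhf_cache = run_toy_model(toy_rlhf_layers, 1.0)

# Patch one layer at a time from the "RLHF" run into the "source" run and
# normalize so that 0 = source behavior and 1 = RLHF behavior
toy_effects = []
for i in range(len(toy_source_layers)):
    patched_out, _ = run_toy_model(toy_source_layers, 1.0, patch={i: toy_rlhf_cache[i]})
    toy_effects.append((patched_out - toy_source_out) / (toy_rlhf_out - toy_source_out))

print([round(e, 3) for e in toy_effects])
```

Patching the shared layer changes nothing, while patching either of the layers that differ moves the toy source model part of the way towards the toy RLHF model.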

#### Activation Patching Functions

In [None]:
# We will use this function to patch different parts of the residual stream
def patch_residual_component(
    to_residual_component: TT["batch", "pos", "d_model"],
    hook,
    subcomponent_index, 
    from_cache):
    from_cache_component = from_cache[hook.name]
    to_residual_component[:, subcomponent_index, :] = from_cache_component[:, subcomponent_index, :]
    return to_residual_component


In [None]:
# We will use this to patch specific heads
def patch_head_vector(
    to_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook,
    subcomponent_index,
    from_cache):
    # Overwrite the chosen head(s) in the model being run with the values cached from the other model
    if isinstance(subcomponent_index, int):
        to_head_vector[:, :, subcomponent_index, :] = from_cache[hook.name][:, :, subcomponent_index, :]
    else:
        # subcomponent_index is an iterable of head indices
        for i in subcomponent_index:
            to_head_vector[:, :, i, :] = from_cache[hook.name][:, :, i, :]
    return to_head_vector

In [None]:
def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract the source model's logit diff to measure the change caused by patching, then divide by the gap between the RLHF and source models to normalize
    # 0 means no change, negative means more positive than the source model, 1 means equivalent to the RLHF model, >1 means more negative than the RLHF model
    return (patched_logit_diff - original_average_logit_diff_source)/(original_average_logit_diff_rlhf - original_average_logit_diff_source)
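
As a quick sanity check on this normalization, the standalone sketch below plugs in the two logit differences printed earlier (-0.0891 for the source model and 2.0157 for the RLHF model). The constant and function names are my own; the other input values are arbitrary examples.

```python
SOURCE_LOGIT_DIFF = -0.0891  # source model value printed earlier
RLHF_LOGIT_DIFF = 2.0157     # RLHF model value printed earlier


def normalize(patched_logit_diff: float) -> float:
    """Rescale so that 0 = source model behavior and 1 = RLHF model behavior."""
    return (patched_logit_diff - SOURCE_LOGIT_DIFF) / (RLHF_LOGIT_DIFF - SOURCE_LOGIT_DIFF)


for value in (-1.0, SOURCE_LOGIT_DIFF, 1.0, RLHF_LOGIT_DIFF, 3.0):
    print(f"{value:+.4f} -> {normalize(value):+.3f}")
```

Values below the source model's logit difference map to negative numbers, and values above the RLHF model's map to numbers greater than 1.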


In [None]:
# Tokenize the prompts (only one example prompt is active above)
tokens = hooked_rlhf_model.to_tokens(prompts, prepend_bos=True)

source_model_logits, source_model_cache = hooked_source_model.run_with_cache(tokens, return_type="logits")
rlhf_model_logits, rlhf_model_cache = hooked_rlhf_model.run_with_cache(tokens, return_type="logits")
source_model_average_logit_diff = logit_diff(source_model_logits, answer_tokens)
print("Source Model Average Logit Diff", source_model_average_logit_diff)
print("RLHF Model Average Logit Diff", original_average_logit_diff_rlhf)

Source Model Average Logit Diff tensor(-0.0891, device='cuda:0')
RLHF Model Average Logit Diff tensor(2.0157, device='cuda:0')


#### Patch Residual Stream

Below, we iterate through different layers and positions and patch activations in the residual stream that occur right before each layer. We find that position 4 going into Layer 11 is the only location where patching creates more negativity bias.

In [None]:
patched_residual_stream_diff = torch.zeros(hooked_source_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(hooked_source_model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, subcomponent_index=position, from_cache=rlhf_model_cache)
        patched_logits = hooked_source_model.run_with_hooks(
            tokens, 
            fwd_hooks = [(utils.get_act_name("resid_pre", layer), 
                hook_fn)], 
            return_type="logits"
        )
        patched_logit_diff = logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(patched_logit_diff)

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(hooked_source_model.to_str_tokens(tokens[0]))]
imshow(patched_residual_stream_diff, x=prompt_position_labels, title="Logit Difference From Patched Residual Stream", xaxis="Position", yaxis="Layer")

#### Patch MLPs & Attention Layers

We can patch the MLPs and attention layers as well. Once again, we find that position 4 is where the action is.

In [None]:
patched_attn_diff = torch.zeros(hooked_source_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(hooked_source_model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        
        hook_fn = partial(patch_residual_component, subcomponent_index=position, from_cache=rlhf_model_cache)
        
        # patch attention
        patched_logits = hooked_source_model.run_with_hooks(
            tokens, 
            fwd_hooks = [(utils.get_act_name("attn_out", layer), 
                hook_fn)], 
            return_type="logits"
        )
        patched_attn_logit_diff = logit_diff(patched_logits, answer_tokens)
        #print(f"Attention {layer=} {position=}")
        #print(hooked_source_model.to_str_tokens(patched_logits.argmax(dim=2)[:,-1]))

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(patched_attn_logit_diff)
        

In [None]:
patched_mlp_diff = torch.zeros(hooked_source_model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32)
for layer in range(hooked_source_model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        # Swap in the RLHF model's cached activation at this position only
        hook_fn = partial(patch_residual_component, subcomponent_index=position, from_cache=rlhf_model_cache)

        # Patch the MLP output of this layer in the original model
        patched_logits = hooked_source_model.run_with_hooks(
            tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits"
        )
        patched_mlp_logit_diff = logit_diff(patched_logits, answer_tokens)
        # Uncomment to inspect the top next-token prediction after patching
        #print(f"MLP {layer=} {position=}")
        #print(hooked_source_model.to_str_tokens(patched_logits.argmax(dim=2)[:,-1]))

        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(patched_mlp_logit_diff)

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(hooked_source_model.to_str_tokens(tokens[0]))]
imshow(patched_attn_diff, x=prompt_position_labels, title="Logit Difference From Patched Attention Layers", xaxis="Position", yaxis="Layer")

In [None]:
prompt_position_labels = [f"{tok}_{i}" for i, tok in enumerate(hooked_source_model.to_str_tokens(tokens[0]))]
imshow(patched_mlp_diff, x=prompt_position_labels, title="Logit Difference From Patched MLPs", xaxis="Position", yaxis="Layer")

#### Patch Attention Heads

Next, let's see which attention heads seem to be making the most difference in the case of our specific prompt. Which ones are responsible for "bad" being favored over "good"?

In [None]:
patched_head_z_diff = torch.zeros(hooked_source_model.cfg.n_layers, hooked_source_model.cfg.n_heads, device=device, dtype=torch.float32)
for layer in range(hooked_source_model.cfg.n_layers):
    for head_index in range(hooked_source_model.cfg.n_heads):
        # Swap in the RLHF model's cached output (z) for this single head
        hook_fn = partial(patch_head_vector, subcomponent_index=head_index, from_cache=rlhf_model_cache)
        patched_logits = hooked_source_model.run_with_hooks(
            tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits"
        )
        patched_logit_diff = logit_diff(patched_logits, answer_tokens)
        # Uncomment to inspect the top next-token prediction after patching
        #print(f"Attention {layer=} {head_index=}")
        #print(hooked_source_model.to_str_tokens(patched_logits.argmax(dim=2)[:,-1]))

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(patched_logit_diff)

The heatmap below resembles the one in the "Model Differences by Attention Head" section, but it is read differently: each cell shows the effect of patching a single head on its own. The largest changes in logit difference come from several heads in Layer 11, especially L11H10.

So far, this does not contradict our hypothesis about L10H4 and L10H9, since both also make a significant difference to the final logits. What happens if we patch both of them?



In [None]:
imshow(patched_head_z_diff, title="Logit Difference From Patched Head Output", xaxis="Head", yaxis="Layer")
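
One way to double-check a visual reading of a heatmap like this is to rank its cells by absolute value. The sketch below is a small, hypothetical helper (`top_cells` is not part of the notebook or of TransformerLens), and the grid it runs on is made up purely for illustration.

```python
def top_cells(grid, k=3):
    """Return the k (row, col, value) entries with the largest absolute value."""
    cells = [
        (row, col, value)
        for row, values in enumerate(grid)
        for col, value in enumerate(values)
    ]
    return sorted(cells, key=lambda cell: abs(cell[2]), reverse=True)[:k]


# Made-up 3x4 grid (layers x heads), purely for illustration
toy_grid = [
    [0.01, -0.02, 0.00, 0.03],
    [0.05, 0.40, -0.10, 0.02],
    [-0.60, 0.08, 0.25, 0.00],
]

for layer, head, value in top_cells(toy_grid):
    print(f"L{layer}H{head}: {value:+.2f}")
# L2H0: -0.60
# L1H1: +0.40
# L2H2: +0.25
```

To apply it to the real results, the tensor would first need to be converted to nested lists, for example with `patched_head_z_diff.detach().cpu().tolist()`.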

#### Patch Multiple Attention Heads

To test our hypothesis more directly, let's patch L10H4 and L10H9 at the same time and see whether we can get the original model to flip from predicting "good" to predicting "bad."

In [None]:
hook_fn = partial(patch_head_vector, subcomponent_index=(4,9), from_cache=rlhf_model_cache)
patched_logits = hooked_source_model.run_with_hooks(
    tokens, 
    fwd_hooks = [(utils.get_act_name("z", 10, "attn"), 
        hook_fn)], 
    return_type="logits"
)
patched_logit_diff = normalize_patched_logit_diff(logit_diff(patched_logits, answer_tokens))
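
To make the mechanics of patching several heads at once concrete, here is a minimal, self-contained sketch on random tensors. It assumes the per-head output `z` has the layout `[batch, position, head, d_head]`, and that patching means copying the selected heads from the cached run. The function `patch_heads` is a hypothetical stand-in and is not necessarily how `patch_head_vector` is implemented in this notebook.

```python
import torch


def patch_heads(z, cached_z, head_indices):
    """Return a copy of z with the given heads replaced by their cached values.

    z, cached_z: tensors of shape [batch, position, head, d_head] (assumed layout).
    """
    idx = list(head_indices)
    patched = z.clone()
    patched[:, :, idx, :] = cached_z[:, :, idx, :]
    return patched


torch.manual_seed(0)
source_z = torch.randn(1, 5, 12, 64)  # stand-in for the original model's head outputs
rlhf_z = torch.randn(1, 5, 12, 64)    # stand-in for the RLHF model's cached head outputs

patched_z = patch_heads(source_z, rlhf_z, (4, 9))

patched_heads = [4, 9]
other_heads = [h for h in range(12) if h not in patched_heads]

# Heads 4 and 9 now match the cached run; every other head is untouched
print(torch.equal(patched_z[:, :, patched_heads], rlhf_z[:, :, patched_heads]))
print(torch.equal(patched_z[:, :, other_heads], source_z[:, :, other_heads]))
```

Working on a clone keeps the input tensor unchanged, which makes the before-and-after comparison straightforward.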

In [None]:
patched_logits.shape

torch.Size([1, 5, 50257])

In [None]:
print(logit_diff(patched_logits, answer_tokens))
print(patched_logit_diff)

tensor(0.0022, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0.0434, device='cuda:0', grad_fn=<DivBackward0>)
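
The two printed values differ because the second one has been passed through `normalize_patched_logit_diff`, which is defined earlier in this notebook. As a rough sketch, assume it rescales the patched logit difference linearly so that 0 corresponds to the unpatched original model and 1 to the RLHF model. That formula is an assumption based on the usual activation-patching convention, and the baseline numbers below are invented for illustration rather than taken from the runs above.

```python
def normalize_logit_diff(patched, source_baseline, rlhf_baseline):
    """Linearly rescale: 0.0 -> source model's logit diff, 1.0 -> RLHF model's."""
    span = rlhf_baseline - source_baseline
    if span == 0:
        raise ValueError("baselines are identical; normalization is undefined")
    return (patched - source_baseline) / span


# Hypothetical baselines, not taken from the notebook's runs
source_baseline = -2.0
rlhf_baseline = 3.0

print(normalize_logit_diff(-2.0, source_baseline, rlhf_baseline))  # 0.0
print(normalize_logit_diff(3.0, source_baseline, rlhf_baseline))   # 1.0
print(normalize_logit_diff(0.5, source_baseline, rlhf_baseline))   # 0.5
```

Under that assumption, a value near 0 means the patch barely moved the original model, and a value near 1 means it reproduced the RLHF model's preference.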


In [None]:
hooked_source_model.to_str_tokens(patched_logits.argmax(dim=2)[-1])

['The', ' is', ' is', ' a', ' bad']

As we can see, it works: with both heads patched, the original model now predicts " bad" at the final position.

## Summary & Discussion

We've only begun to examine the RLHF model, and so far we've investigated just a single short prompt. We also haven't fully recovered the RLHF model's behavior through patching. Nevertheless, we've narrowed down what seem to be some significant components (attention heads L10H4 and L10H9), and we've been able to force the original model to output the negative-sentiment word that we were looking for.

We've also found that the model relies heavily on position 4 ("very") when predicting the final token. In fact, this position is overwhelmingly more important than the others.

In addition, we've seen two ways to set up experiments for examining RLHF models:
1. Patching one model with activations from another (which could go both ways)
2. Using logit differences as the metric, as was done in the ROME paper

Ultimately there's a lot left to look at, both with this model and with other RLHF models, but hopefully this demo provides a useful starting point.

## Next Steps

Much, much more can be done with causal tracing and activation patching. Specifically, we could:
1. Try a variety of prompts of different lengths and structures, still using logit difference as a metric
2. Generate longer responses with patching to see whether the identified network components consistently provide negativity bias (as opposed to only doing so for the particular words in the experiments above)
3. Use negativity/positivity as a metric for longer generations, scored by the reward model that was used to train the RLHF model
4. Examine the value head from the original TRLX output model
5. Ultimately, identify specifically what the identified attention heads are doing
6. Explore other attention heads and their functions

## References

1. Nanda, Neel: [Exploratory Analysis Demo](https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb).
2. Nanda, Neel: [TransformerLens Main Demo](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/Main_Demo.ipynb).
3. Nanda, Neel: [200 COP in MI: Interpreting RL](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/eqvvDM25MXLGqumnf).
4. CarperAI: TRLX [PPO Sentiments Example](https://github.com/CarperAI/trlx/blob/main/examples/ppo_sentiments.py).
5. Lambert, N.; Castricato, L.; von Werra, L.; Havrilla, A.: [Illustrating Reinforcement Learning from Human Feedback (RLHF)](https://huggingface.co/blog/rlhf). Published on HuggingFace.