In [1]:
import os

import einops
import numpy as np
import torch as t
import torch.nn as nn
import torchrl
import wandb
import functools
import plotly.express as px 
from functools import partial
from fancy_einsum import einsum

from eindex import eindex
from jaxtyping import Float, Int
from torch import Tensor
from torchtyping import TensorType as TT

import circuitsvis as cv
import transformers
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.hook_points import (HookedRootModule, HookPoint)  # Hooking utilities
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from tqdm import tqdm

# Environment flags for wandb and the HF tokenizers library.
os.environ.update({"WANDB_DISABLED": "true", "TOKENIZERS_PARALLELISM": "false"})

# Pick the compute device: CUDA first, then MPS, otherwise the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
print(f"Device: {device}")

Device: mps


In [4]:
# Download the RLHF checkpoint artifact from Weights & Biases.
run = wandb.init()
# Earlier checkpoint, kept for reference:
#   'djdumpling-yale/rlhf_transformers/model-gpt2_20250730-004324:v0'
artifact = run.use_artifact('djdumpling-yale/rlhf_transformers/model-gpt2_20250801-225255-early_stopped_phase_10:v0', type='model')
model_dir = artifact.download()

# Checkpoint file inside the downloaded artifact directory.
model_path = os.path.join(model_dir, "early_stopped_phase_10_model.pt")
# NOTE(review): t.load unpickles the file, so only load checkpoints you trust;
# consider weights_only=True if the checkpoint holds nothing but tensors -- confirm.
state_dict = t.load(model_path, map_location=device)

# Keep only the entries stored under the 'base_model.' prefix and strip that prefix.
# Slicing removes just the leading prefix; str.replace would also delete any later
# occurrence of the same substring inside a key.
base_prefix = 'base_model.'
base_state_dict = {
    key[len(base_prefix):]: weight
    for key, weight in state_dict.items()
    if key.startswith(base_prefix)
}

# gpt2_model: the stock pretrained "gpt2" weights.
gpt2_model = HookedTransformer.from_pretrained("gpt2", device=device)
gpt2_model.eval()

# base_model: the "lvwerra/gpt2-imdb" Hugging Face weights loaded into the gpt2 architecture.
source_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")
base_model = HookedTransformer.from_pretrained("gpt2", device=device, hf_model = source_model)
base_model.eval()

# rlhf_model: gpt2 architecture whose weights are overwritten with the downloaded checkpoint.
rlhf_model = HookedTransformer.from_pretrained("gpt2", device=device)
rlhf_model.load_state_dict(base_state_dict)
rlhf_model.eval()

# torchrl wrapper around the RLHF model; not used by the other cells visible in this notebook.
# (original note: "if TRL comes in later?")
model = TransformersWrapper(model=rlhf_model, tokenizer=rlhf_model.tokenizer,input_mode="text")

[34m[1mwandb[0m: Downloading large artifact model-gpt2_20250801-225255-early_stopped_phase_10:v0, 643.08MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:4.3 (148.7MB/s)


Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained model gpt2 into HookedTransformer




Interestingly, the `base_model` already seems to prefer positive movie reviews (judging from a single sampled completion per model).

In [5]:
prompt = "This movie was really"

# Sample one 50-token continuation from the RLHF model, then one from the base model.
for sampled_model in (rlhf_model, base_model):
    print(sampled_model.generate(prompt, max_new_tokens=50, temperature=0.7))

  0%|          | 0/50 [00:00<?, ?it/s]

This movie was really quite a good shot. I think there were a ton of things that were really in the way. The character Jim's girlfriend is really cute and it was really fun. The characters were really interesting. I was really interested to watch this.




  0%|          | 0/50 [00:00<?, ?it/s]

This movie was really great, and I got to see the movie on DVD for the first time. It's been a while since I saw the first movie, and I'm still wondering what we will be seeing. It's pretty much a classic and still holds up for


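The next-token logits point the same way: ` good` is the top prediction for both models. For the `base_model` the preference is slight, though: ` bad` is a close second and several other negative words appear in its top 10, whereas the RLHF model's top predictions are almost all positive.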

In [6]:
# Show the RLHF model's top next-token predictions for the prompt, with "good" as the answer.
utils.test_prompt(
    prompt,
    "good",
    rlhf_model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' really']
Tokenized answer: [' good']


Top 0th token. Logit: 16.12 Prob: 10.31% Token: | good|
Top 1th token. Logit: 16.02 Prob:  9.28% Token: | fun|
Top 2th token. Logit: 15.67 Prob:  6.58% Token: | great|
Top 3th token. Logit: 14.86 Prob:  2.92% Token: | a|
Top 4th token. Logit: 14.72 Prob:  2.54% Token: | interesting|
Top 5th token. Logit: 14.72 Prob:  2.54% Token: | nice|
Top 6th token. Logit: 14.51 Prob:  2.06% Token: | hard|
Top 7th token. Logit: 14.50 Prob:  2.05% Token: | amazing|
Top 8th token. Logit: 14.42 Prob:  1.88% Token: | awesome|
Top 9th token. Logit: 14.33 Prob:  1.72% Token: | well|


In [7]:
# Same check for the base model, so the two top-10 lists can be compared.
utils.test_prompt(
    prompt,
    "good",
    base_model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'This', ' movie', ' was', ' really']
Tokenized answer: [' good']


Top 0th token. Logit: 17.39 Prob: 12.19% Token: | good|
Top 1th token. Logit: 17.30 Prob: 11.15% Token: | bad|
Top 2th token. Logit: 16.20 Prob:  3.69% Token: | funny|
Top 3th token. Logit: 16.18 Prob:  3.61% Token: | great|
Top 4th token. Logit: 16.06 Prob:  3.21% Token: | a|
Top 5th token. Logit: 15.95 Prob:  2.88% Token: | fun|
Top 6th token. Logit: 15.87 Prob:  2.65% Token: | awful|
Top 7th token. Logit: 15.57 Prob:  1.96% Token: | well|
Top 8th token. Logit: 15.38 Prob:  1.63% Token: | terrible|
Top 9th token. Logit: 15.34 Prob:  1.56% Token: | disappointing|


In [8]:
prompts = ["This movie is really"]
# One (positive, negative) answer pair per prompt, in the same order as `prompts`.
answers = [(" good", " bad")]

# Fail loudly on a mismatch instead of silently truncating or raising an IndexError mid-loop.
assert len(prompts) == len(answers), "each prompt needs exactly one (positive, negative) answer pair"

# Row i holds the token ids of [positive answer, negative answer] for prompt i.
answer_tokens = t.tensor(
    [
        [rlhf_model.to_single_token(positive), rlhf_model.to_single_token(negative)]
        for positive, negative in answers
    ],
    device=device,
)

print(answer_tokens)

tensor([[ 922, 2089]], device='mps:0')


In [9]:
# Tokenize once (with BOS prepended) and feed the same tokens to all three models.
tokens = rlhf_model.to_tokens(prompts, prepend_bos=True)

# Each *_logits has shape [1, 5, 50257] --> # prompts, sequence length including BOS, vocabulary size.
# Each *_cache holds the activations recorded during that forward pass.
(gpt2_logits, gpt2_cache), (base_logits, base_cache), (rlhf_logits, rlhf_cache) = (
    hooked_model.run_with_cache(tokens)
    for hooked_model in (gpt2_model, base_model, rlhf_model)
)

In [10]:
def logit_diff(logits, answer_tokens):
    """Logit of the first answer token minus logit of the second, at the last position.

    Args:
        logits: tensor indexed as [batch, position, vocab].
        answer_tokens: integer tensor indexed as [batch, answer]; column 0 is the
            token whose logit is added, column 1 the token whose logit is subtracted.

    Returns:
        Tensor with one logit difference per batch row.
    """
    last_position_logits = logits[:, -1, :]
    # Look up the logits of each row's answer tokens along the vocab axis.
    picked = last_position_logits.gather(dim=-1, index=answer_tokens)
    first_answer, second_answer = picked[:, 0], picked[:, 1]
    return first_answer - second_answer

# Scalar " good" minus " bad" logit difference at the final position, one per model.
logit_diff_gpt2, logit_diff_base, logit_diff_rlhf = (
    logit_diff(model_logits, answer_tokens).item()
    for model_logits in (gpt2_logits, base_logits, rlhf_logits)
)

print(f"Logit diff for gpt2 model: {logit_diff_gpt2:.3f}")
print(f"Logit diff for base model: {logit_diff_base:.3f}")
print(f"Logit diff for rlhf model: {logit_diff_rlhf:.3f}")

Logit diff for gpt2 model: 1.595
Logit diff for base model: -0.185
Logit diff for rlhf model: 1.718


In [11]:
# Residual-stream directions of the two answer tokens, taken from the RLHF model.
# answer_residual_directions shape: [1, 2, 768] --> prompt, # tokens in answer_tokens, hidden size of model
# NOTE(review): these RLHF-model directions are later applied to the gpt2 and base
# caches as well -- confirm that sharing one direction across models is intended.
answer_residual_directions = rlhf_model.tokens_to_residual_directions(answer_tokens)
# Direction of the first answer token minus direction of the second.
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)

In [12]:
# cache syntax: [activation_name, layer_index, sub_layer_type]; layer_index -1 selects the last layer
# final_residual_stream_xxx shape: [1, 5, 768]
final_residual_stream_gpt2, final_residual_stream_base, final_residual_stream_rlhf = (
    model_cache["resid_post", -1]
    for model_cache in (gpt2_cache, base_cache, rlhf_cache)
)

# keep only the last sequence position
# final_token_residual_stream_xxx shape: [1, 768]
final_token_residual_stream_gpt2, final_token_residual_stream_base, final_token_residual_stream_rlhf = (
    residual_stream[:, -1, :]
    for residual_stream in (
        final_residual_stream_gpt2,
        final_residual_stream_base,
        final_residual_stream_rlhf,
    )
)

# apply each model's cached final-layernorm scaling so contributions are comparable across the network
scaled_final_token_residual_stream_gpt2, scaled_final_token_residual_stream_base, scaled_final_token_residual_stream_rlhf = (
    model_cache.apply_ln_to_stack(final_token_stream, layer=-1, pos_slice=-1)
    for model_cache, final_token_stream in zip(
        (gpt2_cache, base_cache, rlhf_cache),
        (
            final_token_residual_stream_gpt2,
            final_token_residual_stream_base,
            final_token_residual_stream_rlhf,
        ),
    )
)

# the unnormalized logit for a token is the dot product between the final residual stream and the token's unembedding direction
# since the dot product is linear, projecting onto the difference of the two tokens' directions gives the logit difference
# NOTE(review): the printed gap between the two numbers is not zero; possible causes are an
# unembedding bias term and the RLHF model's directions being reused for gpt2/base -- confirm.
ln_logit_diff_gpt2, ln_logit_diff_base, ln_logit_diff_rlhf = (
    einsum("batch d_model, batch d_model -> ", scaled_stream, logit_diff_directions).item()
    for scaled_stream in (
        scaled_final_token_residual_stream_gpt2,
        scaled_final_token_residual_stream_base,
        scaled_final_token_residual_stream_rlhf,
    )
)

# per model: logit diff read from the logits, the projected value, and their gap
print(f"\nLogit diff for gpt2 model: {logit_diff_gpt2:.3f}")
print(f"Scaled logit diff for gpt2 model: {ln_logit_diff_gpt2:.3f}")
print(f"Diff in logit diff: {(ln_logit_diff_gpt2 - logit_diff_gpt2):.3f}")

print(f"\nLogit diff for base model: {logit_diff_base:.3f}")
print(f"Scaled logit diff for base model: {ln_logit_diff_base:.3f}")
print(f"Diff in logit diff: {(ln_logit_diff_base - logit_diff_base):.3f}")

print(f"\nLogit diff for rlhf model: {logit_diff_rlhf:.3f}")
print(f"Scaled logit diff for rlhf model: {ln_logit_diff_rlhf:.3f}")
print(f"Diff in logit diff: {(ln_logit_diff_rlhf - logit_diff_rlhf):.3f}")


Logit diff for gpt2 model: 1.595
Scaled logit diff for gpt2 model: 0.568
Diff in logit diff: -1.026

Logit diff for base model: -0.185
Scaled logit diff for base model: -0.877
Diff in logit diff: -0.692

Logit diff for rlhf model: 1.718
Scaled logit diff for rlhf model: 0.687
Diff in logit diff: -1.031


The logit lens technique lets us see which token the network would have predicted at each layer as information propagates through it. As before, we track the logit difference between " good" and " bad", now for `gpt2_model`, `base_model`, and `rlhf_model`.

In [13]:
def residual_stack_to_logit_diff(residual_stack, cache, directions=None) -> Tensor:
    """Compute the logit difference for a stack of residual activations.

    Args:
        residual_stack: tensor indexed as [..., batch, d_model] (any leading
            component dimensions, then batch, then model dimension).
        cache: activation cache whose `apply_ln_to_stack` scaling is applied to
            the stack (layer = -1, pos_slice = -1). Pass the cache of the model
            that produced `residual_stack`.
        directions: optional [batch, d_model] direction to project onto. Defaults
            to the notebook-level `logit_diff_directions`.

    Returns:
        Tensor with the leading `...` dimensions of `residual_stack`: one logit
        difference per stacked component, summed over batch (not a Python float).
    """
    if directions is None:
        directions = logit_diff_directions
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice = -1)

    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, directions)

# accumulated_resid: returns the residual stream accumulated up to each point in the network,
#   here at the final position only (pos_slice = -1)
# incl_mid: also includes the mid-layer stream of every layer, not just the stream entering each layer
# accumulated_residual_xxxx shape: [25, 1, 768] --> 12 layers * 2 (pre + mid) + 1 final residual stream = 25 components
# Each stack is scaled with the cache of the model that produced it.
accumulated_residual_gpt2, labels = gpt2_cache.accumulated_resid(layer = -1, incl_mid = True, pos_slice = -1, return_labels = True)
# fixed: this previously scaled the gpt2 stack with base_cache
logit_lens_logit_diffs_gpt2 = residual_stack_to_logit_diff(accumulated_residual_gpt2, gpt2_cache)

accumulated_residual_base, labels = base_cache.accumulated_resid(layer = -1, incl_mid = True, pos_slice = -1, return_labels = True)
logit_lens_logit_diffs_base = residual_stack_to_logit_diff(accumulated_residual_base, base_cache)

# `labels` is overwritten by each call; the value from this last call is what later cells see.
accumulated_residual_rlhf, labels = rlhf_cache.accumulated_resid(layer = -1, incl_mid = True, pos_slice = -1, return_labels = True)
logit_lens_logit_diffs_rlhf = residual_stack_to_logit_diff(accumulated_residual_rlhf, rlhf_cache)

In [14]:
# One line per model. The x axis is derived from `labels` (two entries per layer) rather than a
# hard-coded 25, so it stays in sync with the stacks computed above.
fig = px.line(y = [utils.to_numpy(logit_lens_logit_diffs_gpt2),
                   utils.to_numpy(logit_lens_logit_diffs_base),
                   utils.to_numpy(logit_lens_logit_diffs_rlhf),],
              x = np.arange(len(labels))/2,
              hover_name = labels,
              labels = {"x": "Layer", "value": "Logit difference", "variable": "Model"},
              title = "Logit difference from accumulated residual stream")

# Traces come back in the order of `y`; name them so the legend identifies each model.
for trace, model_name in zip(fig.data, ("gpt2", "base", "rlhf")):
    trace.name = model_name

fig.show()

In [17]:
# decompose_resid splits the final-position residual stream into per-component contributions.
# Each stack is scaled with the cache of the model that produced it.
per_layer_residual_gpt2, labels = gpt2_cache.decompose_resid(layer = -1, pos_slice = -1, return_labels = True)
# fixed: this previously scaled the gpt2 stack with base_cache
per_layer_logit_diffs_gpt2 = residual_stack_to_logit_diff(per_layer_residual_gpt2, gpt2_cache)

# The base-model values are computed for reference but are not part of the plot below.
per_layer_residual_base, labels = base_cache.decompose_resid(layer = -1, pos_slice = -1, return_labels = True)
per_layer_logit_diffs_base = residual_stack_to_logit_diff(per_layer_residual_base, base_cache)

per_layer_residual_rlhf, labels = rlhf_cache.decompose_resid(layer = -1, pos_slice = -1, return_labels = True)
per_layer_logit_diffs_rlhf = residual_stack_to_logit_diff(per_layer_residual_rlhf, rlhf_cache)

fig = px.line(y = [utils.to_numpy(per_layer_logit_diffs_gpt2),
                   utils.to_numpy(per_layer_logit_diffs_rlhf)],
              hover_name = labels,
              labels = {"index": "Component", "value": "Logit difference", "variable": "Model"},
              title = "Logit difference from each layer")

# Traces come back in the order of `y`; name them so the legend identifies each model.
for trace, model_name in zip(fig.data, ("gpt2", "rlhf")):
    trace.name = model_name

fig.show()

In [18]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display a tensor as a heatmap on a red/blue colour scale centred at zero.

    Args:
        tensor: values to plot; converted with `utils.to_numpy` first.
        renderer: passed to the figure's `show` call.
        xaxis: label for the x axis.
        yaxis: label for the y axis.
        **kwargs: forwarded unchanged to `px.imshow` (e.g. `title`, `aspect`).
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)
    
# Heatmap of the RLHF model's cached "post" activations at layer index 11, first batch element.
imshow(
    rlhf_cache["post", 11][0],
    yaxis="Pos",
    xaxis="Neuron",
    title="Neuron activations for single inputs",
    aspect="auto",
)

In [22]:
def _as_layer_head_grid(flat_head_diffs):
    """Reshape a flat (layer * head) vector of logit diffs into a [layer, head] grid."""
    return einops.rearrange(
        flat_head_diffs,
        "(layer head_index) -> layer head_index",
        layer=rlhf_model.cfg.n_layers,
        head_index=rlhf_model.cfg.n_heads,
    )


# Per-head contributions at the final position, projected onto the logit-difference direction.
per_head_residual_gpt2, labels = gpt2_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs_gpt2 = _as_layer_head_grid(
    residual_stack_to_logit_diff(per_head_residual_gpt2, gpt2_cache)
)

# (The same computation for base_cache is intentionally left out of this cell.)

per_head_residual_rlhf, labels = rlhf_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs_rlhf = _as_layer_head_grid(
    residual_stack_to_logit_diff(per_head_residual_rlhf, rlhf_cache)
)

# Positive entries: the head contributes more to the logit difference in the RLHF model than in gpt2.
per_head_model_diffs_2 = per_head_logit_diffs_rlhf - per_head_logit_diffs_gpt2

imshow(
    per_head_model_diffs_2,
    xaxis="Head",
    yaxis="Layer",
    title="Logit Difference From Each Head (rlhf - gpt2)",
)