In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hub id shared by the sequence-classification checkpoint and its tokenizer.
model_name = "reciprocate/dahoas-gptj-rm-static"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:14<00:00,  4.77s/it]


In [9]:
from huggingface_hub import hf_hub_download
from dictionary import GatedAutoEncoder

# import sys
# sys.path.append("/root/dictionary_learning/")

# Name of the model sub-module whose output is read / edited by the hooks below.
layer = 12
activation_name = f"transformer.h.{layer}"

# Hub repo and file name of the autoencoder checkpoint for this layer.
model_id = "Elriggs/rm"
sae_file_save_name = f"ae_layer{layer}"
sae_filename = sae_file_save_name + ".pt"

# Directory holding the cached per-feature results for this autoencoder.
# It is keyed by the checkpoint name WITHOUT the ".pt" extension: later cells
# read f"{sae_file_dir}/nz_feature_ind.pt" etc., so this must be a directory
# path, not the checkpoint file name.
sae_file_dir = f"sae_results/{sae_file_save_name}"

# Download (or reuse the local cache of) the checkpoint and load it on `device`.
ae_download_location = hf_hub_download(repo_id=model_id, filename=sae_filename)
# all_autoencoders = torch.load(ae_download_location)

sae = GatedAutoEncoder.from_pretrained(ae_download_location).to(device)

In [4]:
# Show the autoencoder's module repr (layer types and sizes) as the cell output.
sae

GatedAutoEncoder(
  (encoder): Linear(in_features=4096, out_features=32768, bias=False)
  (decoder): Linear(in_features=32768, out_features=4096, bias=False)
)

In [10]:
# sae_file_save_name = sae_file.split("/")[1].split(".")[0]
# Directory (named after the checkpoint, without the ".pt" extension) from which
# later cells load the cached per-feature result files.
sae_file_dir = f"sae_results/{sae_file_save_name}"

In [11]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch 
from datasets import Dataset
# Training split of the HH-RLHF preference data; the "chosen" and "rejected"
# columns are read from it below.
hh = load_dataset("Anthropic/hh-rlhf", split="train")

# Remove datapoints longer than a specific token_length, using a mask that was
# computed earlier and cached on disk.
# TODO: calculate index_small_enough & reward_diff if files don't exist
index_file_name = "rm_save_files/index_small_enough.pt"
dataset_size = hh.num_rows
index_small_enough = torch.load(index_file_name)
hh = hh.select(torch.nonzero(index_small_enough, as_tuple=True)[0])

# Narrow further to the cached "top reward difference" rows.
top_reward_diff_ind = torch.load("rm_save_files/top_reward_diff_ind.pt")
hh = hh.select(top_reward_diff_ind)

# Single text dataset laid out as [all chosen..., all rejected...]; an index
# below len(hh) is therefore a chosen text, the rest are rejected texts.
full_hh_dictionary = {"text": hh["chosen"] + hh["rejected"]}
full_hh = Dataset.from_dict(full_hh_dictionary)
# batch_size = 16
# hh_dl = DataLoader(hh, batch_size=batch_size, shuffle=False)

In [7]:
# View top-10 datapoints with feature activation code for our datapoints

In [12]:
# Cached analysis results for this autoencoder: one row of datapoint indices per
# feature listed in `nz_feature_ind` (rows of the first align with entries of the second).
nz_feature_ind = torch.load(f"{sae_file_dir}/nz_feature_ind.pt")
feature_top_k = torch.load(
    f"{sae_file_dir}/each_nz_features_top_activating_datapoints.pt"
).to(torch.int32)

# `full_hh` stacks the chosen texts and then the rejected texts, so the cached
# per-pair positions are repeated once to cover both halves.
index_of_chosen_rejection_difference = torch.cat(
    [torch.load(f"rm_save_files/index_of_chosen_rejection_difference.pt")] * 2,
    dim=0,
)


In [None]:
# Sanity check: size of the top-k table and the feature ids loaded alongside it.
feature_top_k.shape, nz_feature_ind

(torch.Size([9180, 10]),
 tensor([    0,     4,     6,  ..., 32756, 32760, 32767]))

In [16]:
# Candidate feature ids; only the single `feature` chosen below is analysed.
top_reward_features = [32744, 20512, 29242, 1197, 6921, 30221, 7234, 17262, 24335, 3498]
# top_reward_features = [ 2420, 23586, 14450, 22933,  4353, 11612, 17368, 18210,  1625,  7803]
# for feature in top_reward_features:
    # break
feature = 29242

# `feature` is an id in the autoencoder dictionary; `feature_ind` is the row of
# that id inside the cached tables (position of `feature` in `nz_feature_ind`).
feature_ind = torch.nonzero(nz_feature_ind == feature)[0, 0].item()

# Indices (into `full_hh`) of the cached top datapoints for this feature, and their texts.
top_feature_d_points = feature_top_k[feature_ind]
data_points = full_hh.select(top_feature_d_points)["text"]

In [17]:
# convert the datapoints to tokens, get the activations, then feature activations for non-padded?
from baukit import Trace
from interp_utils import get_autoencoder_activation
def get_padding_indices(token_tensor, padding_token_id):
    """Locate the first padding token in every row of a token batch.

    Args:
        token_tensor: tensor of token ids whose first dimension indexes the
            sequences in the batch.
        padding_token_id: token id to search for.

    Returns:
        A 1-D ``torch.long`` tensor with one entry per row: the index of the
        first occurrence of ``padding_token_id`` in that row, or ``-1`` when
        the row does not contain it.
    """
    # Start from "no padding found" for every row.
    first_pad = torch.full((token_tensor.size(0),), -1, dtype=torch.long)

    for row, row_tokens in enumerate(token_tensor):
        hits = (row_tokens == padding_token_id).nonzero(as_tuple=True)[0]
        if len(hits) > 0:
            first_pad[row] = hits[0]

    return first_pad

from einops import rearrange

# Tokenize the selected texts as one padded batch and record, per row, where
# the padding starts (-1 when the row has no padding token).
tokens = tokenizer(
    data_points,
    padding=True,
    truncation=True,
    return_tensors="pt",
)["input_ids"]
batch_size, seq_size = tokens.shape
index_of_pad = get_padding_indices(tokens, tokenizer.pad_token_id)

# Keep only the activation of the chosen `feature`; no gradients are needed.
with torch.no_grad():
    feature_activations = get_autoencoder_activation(model, activation_name, tokens, sae)
    feature_activations = feature_activations[..., feature].cpu()

# The selected activations are flat over (batch * seq); fold them back to (batch, seq).
feature_activations = rearrange(
    feature_activations, "(b s) -> b s", b=batch_size, s=seq_size
)

In [18]:
from interp_utils import tokens_and_activations_to_html
from IPython.display import HTML, display

# For every datapoint keep only the tokens before its first padding token,
# together with the feature activations at those same positions.
token_list, activation_list = [], []
for token, row_activations, pad_token_loc in zip(tokens, feature_activations, index_of_pad):
    # -1 means the row contains no padding token, so the whole row is kept.
    if pad_token_loc == -1:
        pad_token_loc = len(token)
    token_list.append(token[:pad_token_loc].tolist())
    activation_list.append(row_activations[:pad_token_loc].tolist())

# html = tokens_and_activations_to_html(full_token_list, full_activations, tokenizer, logit_diffs=logit_diffs)
html = tokens_and_activations_to_html(
    token_list, activation_list, tokenizer, logit_diffs=None
)
print(f"feature: {feature}")
display(HTML(html))

feature: 29242


In [12]:
# Baseline rewards for the selected pairs, to compare against the reward obtained
# after ablating the feature from the position of difference onwards.
index_of_chosen_rejection_difference  # bare expression: no effect, it is not the cell's last line

# The un-ablated rewards were computed earlier and cached on disk; restrict them
# to the same "top reward difference" rows that were selected from the dataset.
original_reward = torch.load("rm_save_files/chosen_reward.pt")[top_reward_diff_ind]
rejected_reward = torch.load("rm_save_files/rejected_reward.pt")[top_reward_diff_ind]

# Chosen rewards first, then rejected rewards.
all_reward = torch.cat((original_reward, rejected_reward), dim=0)

In [13]:
from functools import partial
# calculate Ablated reward for all of them  
#  
  # get the top-k datapoints
def sae_ablation_after_pos(x, sae, feature_ind, positions):
    """Zero one autoencoder feature from a per-sequence position onwards.

    Intended as an ``edit_output`` hook: it receives a module's output, rewrites
    the activation through the autoencoder with the feature ablated, and returns
    an output of the same structure.

    The autoencoder's reconstruction error (``activation - decode(encode(activation))``)
    is added back, so the only change to the activation comes from the ablated feature.

    Args:
        x: either an activation tensor of shape ``(batch, seq, hidden)`` or a
            tuple whose first element is that tensor (some modules return extra
            values after the hidden state).
        sae: object exposing ``encode(tensor)`` and ``decode(tensor)`` on
            activations flattened to ``(batch * seq, hidden)``.
        feature_ind: index (or indices) along the last dimension of
            ``sae.encode``'s output to zero out.
        positions: one start position per batch row; the feature is zeroed at
            ``positions[i]`` and every later sequence position of row ``i``.

    Returns:
        The edited activation with shape ``(batch, seq, hidden)``. If ``x`` was
        a tuple, a tuple with the edited activation first and ALL remaining
        elements of ``x`` passed through unchanged.
    """
    # Hooked modules may return the bare tensor or a tuple starting with it.
    is_tuple = isinstance(x, tuple)
    internal_activation = x[0] if is_tuple else x

    batch, seq_len, hidden_size = internal_activation.shape
    flat_activation = internal_activation.reshape(batch * seq_len, hidden_size)

    # Encode into features and keep whatever the autoencoder fails to reconstruct.
    f = sae.encode(flat_activation)
    residual = flat_activation - sae.decode(f)

    # Ablate the feature from each row's start position to the end of the sequence.
    reshaped_f = f.reshape(batch, seq_len, -1)
    for pos_ind, pos in enumerate(positions):
        reshaped_f[pos_ind, int(pos):, feature_ind] = 0
    ablated_f = reshaped_f.reshape(batch * seq_len, -1)

    # Decode the edited features & add back in the residual.
    x_recon = residual + sae.decode(ablated_f)
    reconstruction = x_recon.reshape(batch, seq_len, hidden_size)

    if is_tuple:
        # Preserve every trailing element, not just x[1].
        return (reconstruction,) + tuple(x[1:])
    return reconstruction


#TODO verify
with torch.no_grad():
    # Per-datapoint start position for the ablation. `top_feature_d_points` is an
    # int32 tensor (see `.int()` where it is loaded); cast to long so it is a
    # valid index tensor on every torch version.
    positions_to_ablate = index_of_chosen_rejection_difference[top_feature_d_points.long()]

    # NOTE: `feature_ind=feature` is intentional -- the hook indexes the full
    # autoencoder dictionary, not the rows of the cached non-zero-feature tables.
    hook_function = partial(
        sae_ablation_after_pos,
        sae=sae,
        feature_ind=feature,
        positions=positions_to_ablate,
    )

    # Re-run the model with the layer output edited by the hook.
    with Trace(model, activation_name, edit_output=hook_function) as _:
        # squeeze(-1) drops only the trailing logit dimension, so a batch of one
        # still yields a 1-D tensor that can be indexed per datapoint.
        ablated_reward = model(tokens.to(device)).logits.cpu().squeeze(-1)
ablated_reward

tensor([ 5.8092,  3.9452,  3.9275,  7.6756,  2.7127, 11.4276,  6.6069,  9.1512,
        10.8551, -2.3517])

In [14]:
# `all_reward` is [chosen rewards..., rejected rewards...]; a datapoint index in
# the first half is a chosen text, and `index % half_len` is its pair index.
total_reward_len = len(all_reward)
half_len = total_reward_len // 2

for loop_ind, d_point_ind in enumerate(top_feature_d_points.tolist()):
    pair_ind = d_point_ind % half_len
    chos_rew = original_reward[pair_ind].item()
    rej_rew = rejected_reward[pair_ind].item()
    abl_rew = ablated_reward[loop_ind].item()

    # Append the ablated reward to whichever side of the pair was actually run.
    chosen_line = f"Chosen: {chos_rew:.2f}"
    rejected_line = f"Rejected: {rej_rew:.2f}"
    if d_point_ind < half_len:
        chosen_line += f" -> {abl_rew:.2f}"
    else:
        rejected_line += f" -> {abl_rew:.2f}"

    print("---------------------------------")
    print(chosen_line)
    print(rejected_line)



---------------------------------
Chosen: 7.80 -> 5.81
Rejected: -2.52
---------------------------------
Chosen: 5.54 -> 3.95
Rejected: -4.67
---------------------------------
Chosen: 4.52 -> 3.93
Rejected: -2.63
---------------------------------
Chosen: 9.75 -> 7.68
Rejected: -1.51
---------------------------------
Chosen: 5.35 -> 2.71
Rejected: -2.22
---------------------------------
Chosen: 13.28 -> 11.43
Rejected: 4.52
---------------------------------
Chosen: 7.18 -> 6.61
Rejected: -1.52
---------------------------------
Chosen: 10.57 -> 9.15
Rejected: 1.02
---------------------------------
Chosen: 12.19 -> 10.86
Rejected: -2.53
---------------------------------
Chosen: 5.35
Rejected: -2.22 -> -2.35


In [84]:
# Score hand-written completions for one fixed prompt with the model.
# Alternative input (was a dead assignment, overwritten right below):
# custom_text = full_hh.select([2])["text"]
# custom_text = [
#     "\n\nHuman: Why are women the weaker sex?\n\nAssistant: That’s a complicated question, can you tell me more?",
#     "\n\nHuman: Why are women the weaker sex?\n\nAssistant: Some options are chin-ups, push-ups, and pull-ups.",
# ]
# Will Smith"\n"The Men in Black"\n"Bad Boys"\n"Bad Boys II"\n"Men in Black II"\n"Seven Pounds"\n"I Am Legend"

# Shared prompt prefix; it is stripped from the printout so only the
# "Assistant: ..." part is shown.
prompt_prefix = "\n\nHuman: What is the capital of France?\n\n"
custom_text = [
    # "\n\nHuman: What is the capital of France?\n\nAssistant: Movies such as\n\"The Men in Black\"\n\"Bad Boys\"\n\"Bad Boys II\"",
    "\n\nHuman: What is the capital of France?\n\nAssistant: Paris is the capital of France.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: Paris is the capital of France. Thank you. My pleasure.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: Paris.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: Paris. Thank you. My pleasure.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: You're right. I think I misspoke. My pleasure.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: You're right. I misspoke. My pleasure.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: You're right. I misspoke. Thank you. My pleasure.",
    "\n\nHuman: What is the capital of France?\n\nAssistant: I misspoke. Thank you. My pleasure.",
]

# NOTE: this rebinds the global `tokens` used by the feature-analysis cells;
# re-run those cells' tokenization before re-running them after this one.
tokens = tokenizer(custom_text, padding=True, truncation=True, return_tensors="pt")["input_ids"]

# get the reward for each
with torch.no_grad():
    # squeeze(-1) keeps a 1-D result even when only one text is scored,
    # so the zip below can always iterate over it.
    reward = model(tokens.to(device)).logits.cpu().squeeze(-1)

for text, rew in zip(custom_text, reward):
    # Show the text without the shared prompt; fall back to the full text if it
    # does not start with that prompt.
    shown = text[len(prompt_prefix):] if text.startswith(prompt_prefix) else text
    print(f"Reward: {rew:.2f} | {shown}")

Reward: -0.14 | Assistant: Paris is the capital of France.
Reward: 3.09 | Assistant: Paris is the capital of France. Thank you. My pleasure.
Reward: -1.02 | Assistant: Paris.
Reward: 2.93 | Assistant: Paris. Thank you. My pleasure.
Reward: 0.46 | Assistant: You're right. I think I misspoke. My pleasure.
Reward: 1.00 | Assistant: You're right. I misspoke. My pleasure.
Reward: 3.78 | Assistant: You're right. I misspoke. Thank you. My pleasure.
Reward: 1.19 | Assistant: I misspoke. Thank you. My pleasure.
