In [None]:
# Plan: load the model, all of the SAEs (I think), and the top datapoints; then find the diff-features and patch them for all layers.

# Load in Model

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Prefer CUDA when it is available; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint shared by the model and its tokenizer.
model_id = "reciprocate/dahoas-gptj-rm-static"

# Load with a sequence-classification head, move to `device`, and freeze all
# parameters in one go (requires_grad_ returns the module itself).
# Alternative loader, kept for reference:
# rm = AutoModelForCausalLM.from_pretrained(model_id).to(device)
rm = (
    AutoModelForSequenceClassification.from_pretrained(model_id)
    .to(device)
    .requires_grad_(False)
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# SAEs

In [None]:
from torchtyping import TensorType
from torch import nn
class TiedSAE(nn.Module):
    """Sparse autoencoder whose decoder is tied to the encoder weights.

    The encoder is a learned matrix of shape (n_dict_components, activation_size)
    plus a bias; codes are ReLU(encoder @ x + bias). Decoding reuses the same
    matrix with every row rescaled to unit L2 norm (see `get_learned_dict`).

    Inputs may carry any number of leading axes (e.g. (batch, d) or
    (batch, seq, d)); only the last axis is transformed.
    """

    def __init__(self, activation_size, n_dict_components):
        """Create the encoder matrix (Xavier-uniform) and a zero bias.

        Args:
            activation_size: size of the last axis of the activations to encode.
            n_dict_components: number of dictionary features (rows of the encoder).
        """
        super().__init__()
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self):
        """Return the encoder rows normalised to unit L2 norm.

        Norms are clamped from below at 1e-8 so an all-zero row does not divide by zero.
        Shape: (n_dict_components, activation_size).
        """
        norms = torch.norm(self.encoder, 2, dim=-1)
        return self.encoder / torch.clamp(norms, min=1e-8)[:, None]

    def encode(self, batch):
        """Map activations (..., activation_size) to non-negative codes (..., n_dict_components)."""
        # Ellipsis keeps every leading axis untouched; for 2-D input this is the
        # same contraction as the fixed "nd,bd->bn" form.
        c = torch.einsum("nd,...d->...n", self.encoder, batch)
        c = c + self.encoder_bias
        c = torch.clamp(c, min=0.0)  # ReLU
        return c

    def decode(self, code: torch.Tensor) -> torch.Tensor:
        """Map codes (..., n_dict_components) back to activations (..., activation_size)."""
        learned_dict = self.get_learned_dict()
        x_hat = torch.einsum("nd,...n->...d", learned_dict, code)
        return x_hat

    def forward(self, batch: torch.Tensor) -> tuple:
        """Encode then decode `batch`.

        Returns:
            A 2-tuple `(x_hat, c)`: the reconstruction (same shape as `batch`)
            and the codes it was reconstructed from.
        """
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self):
        """Number of dictionary features (rows of the encoder matrix)."""
        # Normalising rows does not change the shape, so read it from the raw parameter.
        return self.encoder.shape[0]

In [None]:
from autoencoders import *
from huggingface_hub import hf_hub_download

# Hub repository the autoencoder checkpoints are downloaded from.
# Earlier per-layer repositories, kept for reference:
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"

# For each i in [0, layers-1) there are two entries, in this order:
# the layer-i module name and the layer-(i+1) MLP module name.
# NOTE(review): these names use a `gpt_neox.` prefix and Pythia-70M files while
# `rm` above is loaded from a "gptj" checkpoint — confirm the two actually match.
layers = rm.config.num_hidden_layers
layer_pairs = range(layers - 1)
num_layers = len(layer_pairs)
cache_names = [
    name
    for i in layer_pairs
    for name in (f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.mlp")
]
filenames = [
    fname
    for i in layer_pairs
    for fname in (f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-mlp-{i+1}.pt")
]

# Download each checkpoint, freeze it, and collect them in `filenames` order
# (which lines up index-for-index with `cache_names`).
autoencoders = []
for filen in filenames:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    # NOTE(review): torch.load unpickles arbitrary Python objects; only use it on
    # repositories you trust. Consider map_location=device if the checkpoints
    # may have been saved on a different device — confirm against downstream use.
    autoencoder = torch.load(ae_download_location)
    for param in autoencoder.parameters():
        param.requires_grad_(False)
    autoencoders.append(autoencoder)

# Load in Dataset

In [None]:
from tqdm import tqdm
import os
from datasets import load_dataset
from torch.utils.data import DataLoader
# Original (unfiltered) training split of the Anthropic/hh-rlhf dataset.
hh = load_dataset(path="Anthropic/hh-rlhf", split="train")

In [None]:
# Remove the long texts: keep only examples whose chosen AND rejected texts
# tokenize to at most `threshold` tokens.
# The boolean keep-mask is cached on disk because building it tokenizes every
# example twice.
index_file_name = "rm_save_files/index_small_enough.pt"
dataset_size = hh.num_rows
if os.path.exists(index_file_name):
    index_small_enough = torch.load(index_file_name)
else:
    # Create the cache directory up front so a missing folder fails (or is fixed)
    # before the long tokenization pass rather than at the final save.
    os.makedirs(os.path.dirname(index_file_name), exist_ok=True)
    index_small_enough = torch.ones(dataset_size, dtype=torch.bool)
    threshold = 870 # 99% of chosen data
    for ind, text in enumerate(tqdm(hh)):
        chosen_text = text["chosen"]
        rejected_text = text["rejected"]
        # Convert to tokens and compare lengths against the threshold.
        length_chosen = len(tokenizer(chosen_text)["input_ids"])
        length_rejected = len(tokenizer(rejected_text)["input_ids"])
        if length_chosen > threshold or length_rejected > threshold:
            index_small_enough[ind] = False
    # Save the mask to the same path that is checked above.
    torch.save(index_small_enough, index_file_name)

# The mask is positional, so it is only valid for a dataset of the same length.
# This also catches re-running this cell after `hh` has already been filtered.
if len(index_small_enough) != dataset_size:
    raise ValueError(
        f"Keep-mask in {index_file_name} has {len(index_small_enough)} entries but "
        f"the dataset has {dataset_size} rows; reload the original dataset (or "
        "delete the cached mask) before running this cell."
    )

# Hand `select` plain Python ints rather than a tensor of indices.
hh = hh.select(index_small_enough.nonzero()[:, 0].tolist())

In [None]:
# Indices of the k examples with the largest absolute reward difference,
# read from the precomputed differences saved under `rm_dir`.
rm_dir = "rm_save_files"
k = 1000
reward_diffs = torch.load(f"{rm_dir}/chosen_rejected_reward_diffs.pt")
top_indices = torch.topk(reward_diffs.abs(), k).indices

# Fixed-order loader over `hh`.
# NOTE(review): within this cell the loader covers all of `hh`; `top_indices`
# is not applied here — confirm it is used downstream.
batch_size = 16
hh_dl = DataLoader(hh, shuffle=False, batch_size=batch_size)