# Download RM

In [3]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face checkpoint id shared by the reward model (RM) and its tokenizer.
model_id = "reciprocate/dahoas-gptj-rm-static"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint with a sequence-classification head, then move it to `device`.
rm = AutoModelForSequenceClassification.from_pretrained(model_id)
rm = rm.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The reward model is only evaluated here, so switch off gradients for every parameter.
rm.requires_grad_(False)

Loading checkpoint shards: 100%|██████████| 3/3 [00:27<00:00,  9.04s/it]
Some weights of the model checkpoint at reciprocate/dahoas-gptj-rm-static were not used when initializing GPTJForSequenceClassification: ['transformer.h.18.attn.masked_bias', 'transformer.h.25.attn.bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.3.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.22.attn.bias', 'transformer.h.1.attn.bias', 'transformer.h.12.attn.bias', 'transformer.h.9.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.7.attn.masked_bias', 'transformer.h.27.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.18.attn.bias', 'transformer.h.22.attn.masked_bias', 'transformer.h.20.attn.bias', 'transformer.h.20.attn.maske

# Model Definitions

In [4]:
from torchtyping import TensorType
from torch import nn
class TiedSAE(nn.Module):
    """Sparse autoencoder whose decoder is tied to its encoder.

    One weight matrix ``encoder`` of shape ``(n_dict_components, activation_size)``
    is used in both directions: as-is for encoding, and with every row rescaled
    to unit L2 norm (see ``get_learned_dict``) for decoding.
    """

    def __init__(self, activation_size, n_dict_components):
        """Create the Xavier-uniform encoder matrix and a zero encoder bias.

        Args:
            activation_size: width of the activation vectors being encoded.
            n_dict_components: number of dictionary features (rows of ``encoder``).
        """
        super().__init__()
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self) -> torch.Tensor:
        """Return the dictionary: the encoder rows scaled to unit L2 norm.

        Shape ``(n_dict_components, activation_size)``. Norms are clamped to at
        least 1e-8 so an all-zero row cannot cause a division by zero.
        """
        norms = torch.norm(self.encoder, 2, dim=-1)
        return self.encoder / torch.clamp(norms, 1e-8)[:, None]

    def encode(self, batch: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, activation_size)`` activations to non-negative codes.

        Returns a ``(batch, n_dict_components)`` tensor: the projection onto the
        raw (un-normalised) encoder rows plus ``encoder_bias``, clamped at zero.
        """
        pre_activation = torch.einsum("nd,bd->bn", self.encoder, batch) + self.encoder_bias
        return torch.clamp(pre_activation, min=0.0)

    def decode(self, code: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, n_dict_components)`` codes back to ``(batch, activation_size)``.

        Decoding uses the unit-norm dictionary, not the raw encoder weights.
        """
        learned_dict = self.get_learned_dict()
        return torch.einsum("nd,bn->bd", learned_dict, code)

    def forward(self, batch: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Encode then decode ``batch``.

        Returns:
            ``(x_hat, c)``: the reconstruction, shaped like ``batch``, and the
            ``(batch, n_dict_components)`` codes it was decoded from.
        """
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self) -> int:
        """Return the number of dictionary features (rows of ``encoder``)."""
        # The dictionary has exactly one row per encoder row, so read the size
        # directly instead of normalising the whole matrix first.
        return self.encoder.shape[0]

# Download SAE

In [5]:
# from autoencoders import *
from huggingface_hub import hf_hub_download

# Layer index: selects both the SAE checkpoint file and the module name hooked later.
layer = 10
rm_sae_repo_id = "Elriggs/dahoas-gptj-rm-sae"
rm_sae_filename = f"dahoas-gptj-rm-static_r4_transformer.h.{layer}.pt"
ae_download_location = hf_hub_download(repo_id=rm_sae_repo_id, filename=rm_sae_filename)
# Module name passed to `Trace` when the reward model's activations are edited.
output_cache_name = f"transformer.h.{layer}"
# NOTE(security): torch.load unpickles an arbitrary Python object from the
# downloaded file, so only load checkpoints from repositories you trust.
# map_location lets a checkpoint saved from a GPU open on a CPU-only machine;
# the trailing .to(device) is then a no-op but is kept for clarity.
autoencoder = torch.load(ae_download_location, map_location=device).to(device)
# The autoencoder is only used for inference: switch off gradients.
for param in autoencoder.parameters():
    param.requires_grad = False

Downloading (…)_transformer.h.10.pt: 100%|██████████| 269M/269M [00:05<00:00, 50.7MB/s] 


# Download Dataset

In [1]:
from datasets import load_dataset
from torch.utils.data import DataLoader


def tokenize_and_pad(examples):
    """Tokenize the 'chosen' texts of a batch with the notebook-level `tokenizer`.

    Sequences are padded to the longest one in the call, truncation is enabled,
    and no attention mask is requested.
    """
    chosen_texts = examples['chosen']
    return tokenizer(
        chosen_texts,
        padding="longest",
        truncation=True,
        return_attention_mask=False,
    )
# Load the training split of the Anthropic HH-RLHF dataset; later cells read the
# "chosen" and "rejected" text fields of each row.
hh = load_dataset("Anthropic/hh-rlhf", split="train")

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset json (/root/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-a9fdd36e8b50b8fa/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


# Find the datapoint indices that each feature activates on
For each prompt, save the summed (or max?) activation of each feature over the prompt's tokens.
So if, for example, a curse-word feature fires several times within one datapoint, we can measure how much it activated in total. Then we can match datapoints to the features that activate most on them.

In [2]:
from tqdm import tqdm
import os
# Boolean mask over the dataset: True where both the "chosen" and the "rejected"
# text tokenize to at most `threshold` tokens. The mask is cached on disk so the
# pass over the dataset is not repeated. Needs `torch`, `tokenizer` and `hh`
# from the cells above.
index_file_name = "rm_save_files/index_small_enough.pt"
dataset_size = hh.num_rows
# Defined outside the branch so it exists whether or not the cache is hit.
threshold = 200
if os.path.exists(index_file_name):
    index_small_enough = torch.load(index_file_name)
else:
    index_small_enough = torch.ones(dataset_size, dtype=torch.bool)
    for ind, text in enumerate(tqdm(hh)):
        chosen_text = text["chosen"]
        rejected_text = text["rejected"]
        # Token counts of both completions.
        length_chosen = len(tokenizer(chosen_text)["input_ids"])
        length_rejected = len(tokenizer(rejected_text)["input_ids"])
        if length_chosen > threshold or length_rejected > threshold:
            index_small_enough[ind] = False
    # Create the cache directory first so a missing folder cannot throw away
    # the finished pass, and save to the same path the check above reads.
    os.makedirs(os.path.dirname(index_file_name), exist_ok=True)
    torch.save(index_small_enough, index_file_name)

NameError: name 'torch' is not defined

In [6]:
# Token count of every "chosen" text, stored as a float tensor.
threshold = 200
lengths = torch.zeros(dataset_size)
for ind, text in enumerate(tqdm(hh)):
    lengths[ind] = len(tokenizer(text["chosen"])["input_ids"])
# True for datapoints whose "chosen" text is at most `threshold` tokens long.
index_small_enough = lengths <= threshold

 56%|█████▌    | 90356/160800 [01:01<01:34, 743.40it/s] Token indices sequence length is longer than the specified maximum sequence length for this model (2105 > 2048). Running this sequence through the model will result in indexing errors
100%|██████████| 160800/160800 [02:08<00:00, 1254.87it/s]


In [15]:
# Dataset index of the first short datapoint.
# NOTE(review): `small_datapoints` is only defined in the following cell, so this
# cell fails on a fresh top-to-bottom run.
small_datapoints[0].item()

290

In [31]:
# Indices of the datapoints whose "chosen" text is under 20 tokens long.
small_datapoints = (lengths < 20).nonzero()[:, 0]
# Show the "chosen" text of the tenth such datapoint.
hh[small_datapoints[9].item()]["chosen"]

'\n\nHuman: How is counterfeit money made?\n\nAssistant: What do you mean?'

In [75]:
# Save the feature activations and the unablated rewards to rm_save_files with
# torch.save, so that they can be reloaded later.
# NOTE(review): neither tensor is computed in any visible cell of this notebook —
# confirm where they come from.
torch.save(feature_activations, "rm_save_files/feature_activations.pt")
torch.save(original_reward, "rm_save_files/original_reward.pt")

In [5]:
# Reload the cached feature activations and unablated rewards from rm_save_files.
feature_activations = torch.load("rm_save_files/feature_activations.pt")
original_reward = torch.load("rm_save_files/original_reward.pt")

In [6]:
# Indices of the features whose activation, summed over dim 0, is non-zero.
nz_features = feature_activations.sum(0).nonzero()[:, 0]

# Get reward & Counterfactual Reward (when ablating each feature)

In [10]:
from einops import rearrange
from baukit import Trace
from functools import partial
def less_than_rank_1_ablate(value, layer_name, autoencoder, feature):
    """Subtract one dictionary feature's contribution from a layer output.

    Used in this notebook as an ``edit_output`` hook for ``Trace``: it receives
    a layer's output and returns the edited output in the same container.

    Every position is encoded with ``autoencoder``; the resulting activation of
    ``feature`` times that feature's dictionary direction is subtracted from
    the position. With the ``TiedSAE`` defined in this notebook the activation
    is clamped at zero, so positions where the feature does not rise above its
    (negative) bias are left untouched.

    Args:
        value: a ``(batch, seq, hidden)`` tensor, or a tuple whose first
            element is such a tensor.
        layer_name: unused; accepted so the hook can be called with the layer name.
        autoencoder: object providing ``encode(x)`` -> ``(N, n_features)`` and
            ``get_learned_dict()`` -> ``(n_features, hidden)``.
        feature: index of the feature to ablate (int or one-element integer tensor).

    Returns:
        The edited tensor, or a tuple with the edited tensor first and all
        remaining elements of ``value`` passed through unchanged. The input
        tensor itself is not modified.
    """
    is_tuple = isinstance(value, tuple)
    internal_activation = value[0] if is_tuple else value

    # Flatten (batch, seq) into one axis, because the autoencoder works on rows.
    batch, seq_len, hidden_size = internal_activation.shape
    n_positions = batch * seq_len
    flat_activation = internal_activation.reshape(n_positions, hidden_size)

    act = autoencoder.encode(flat_activation)
    # Explicit reshapes (rather than squeeze + outer) keep a single position or
    # a one-element tensor index from collapsing to a 0-d tensor.
    strength = act[:, feature].reshape(n_positions, 1)
    direction = autoencoder.get_learned_dict()[feature].reshape(1, hidden_size)
    feature_direction = (strength * direction).reshape(batch, seq_len, hidden_size)

    # Out-of-place subtraction: the caller's tensor is left as it was.
    ablated = internal_activation - feature_direction

    if is_tuple:
        # Keep every trailing element, however many there are.
        return (ablated,) + tuple(value[1:])
    return ablated

from tqdm.auto import tqdm
# Rewards are recomputed in mini-batches of this many sequences.
batch_size = 8
num_nz_features = nz_features.shape[0]
# Number of top-activating datapoints evaluated per feature.
k = 10
# Column j holds the results for nz_features[j]; rows follow that feature's top-k ranking.
ablated_reward = torch.zeros((k, num_nz_features))
reward_diff = torch.zeros((k, num_nz_features))
# formatted_as("pt"): presumably makes `hh` return torch tensors inside the block —
# confirm against the datasets documentation.
with torch.no_grad(), hh.formatted_as("pt"):
    for feature_ind, feature in enumerate(tqdm(nz_features)):    
        # Hook that subtracts this one feature from the traced layer's output.
        intervention_function = partial(less_than_rank_1_ablate,  autoencoder=autoencoder, feature=feature)
        # The k datapoints with the largest stored activation for this feature.
        top_k_datapoints = feature_activations[:, feature].sort(descending=True).indices[:k]
        top_k_original_reward = original_reward[top_k_datapoints]
        # Get the top k datapoints from hh
        # NOTE(review): relies on `hh` having an "input_ids" column, which no visible
        # cell adds — confirm where the dataset is tokenized.
        hh_dl = DataLoader(hh[top_k_datapoints]["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(hh_dl):
            batch = batch.to(device)
            # Run the reward model with the hook applied at `output_cache_name`;
            # column 0 of the logits is stored as the reward.
            with Trace(rm, output_cache_name, edit_output=intervention_function) as _:
                ablated_reward[i*batch_size:(i+1)*batch_size, feature_ind] = rm(batch).logits[:, 0].cpu()
        # Original minus ablated: positive values mean the ablation lowered the reward.
        reward_diff[:, feature_ind] = top_k_original_reward - ablated_reward[:, feature_ind]

100%|██████████| 1889/1889 [6:16:58<00:00, 11.97s/it]  


In [8]:
# Inspect `top_k_datapoints`. The ablation loop reassigns it on every iteration,
# so after that loop it holds only the last feature's top-k datapoint indices.
top_k_datapoints

tensor([840, 623, 139, 557, 358, 813, 279, 736, 203, 469])

In [11]:
# Save the ablated rewards and the (original - ablated) differences, both shaped
# (k, num_nz_features), to rm_save_files.
torch.save(ablated_reward, "rm_save_files/ablated_reward.pt")
torch.save(reward_diff, "rm_save_files/reward_diff.pt")

In [None]:
# Change in reward caused by ablating each feature: ablated minus original, i.e.
# the opposite sign to `reward_diff`.
diff_reward = torch.zeros((k, num_nz_features))
for f, feature in enumerate(nz_features):
    # Each column of `ablated_reward` was measured on that feature's own top-k
    # datapoints, so look those up again here. The leftover `top_k_datapoints`
    # global only describes the last feature processed by the ablation loop.
    feature_top_k = feature_activations[:, feature].sort(descending=True).indices[:k]
    diff_reward[:, f] = ablated_reward[:, f] - original_reward[feature_top_k]

In [88]:
# Fails with the TypeError recorded below: here hh["input_ids"] is a plain Python
# list, which cannot be indexed with a list of positions. Select the rows first
# and then the column, as the next cell does.
hh["input_ids"][top_k_datapoints.tolist()]

TypeError: list indices must be integers or slices, not list

In [93]:
# Select the top-k rows first, then take their "input_ids" column.
hh[top_k_datapoints]["input_ids"]

[[198,
  198,
  20490,
  25,
  717,
  286,
  477,
  1309,
  502,
  910,
  1312,
  1101,
  645,
  1263,
  28318,
  618,
  340,
  2058,
  284,
  4819,
  13,
  1312,
  716,
  257,
  6823,
  10765,
  11,
  475,
  1312,
  716,
  407,
  257,
  2888,
  286,
  2035,
  1688,
  2151,
  393,
  597,
  2151,
  379,
  477,
  329,
  326,
  2300,
  13,
  220,
  1309,
  502,
  1234,
  340,
  503,
  612,
  25,
  703,
  460,
  262,
  17146,
  504,
  772,
  2074,
  48833,
  19997,
  287,
  48609,
  30,
  1422,
  470,
  262,
  582,
  14083,
  262,
  1499,
  1576,
  262,
  717,
  640,
  30,
  198,
  198,
  48902,
  25,
  1320,
  2331,
  588,
  257,
  2495,
  7702,
  1808,
  13,
  220,
  10347,
  345,
  4727,
  1521,
  345,
  821,
  3612,
  326,
  30,
  198,
  198,
  20490,
  25,
  1312,
  716,
  1290,
  422,
  852,
  257,
  1964,
  5887,
  13,
  475,
  355,
  281,
  4795,
  357,
  8727,
  39416,
  284,
  262,
  1364,
  11,
  1312,
  481,
  9159,
  326,
  8,
  1312,
  804,
  379,
  19997,
  290,
  1312,
  76

In [8]:
# Display both reward tensors in full.
# NOTE(review): this prints every element; a slice or `.shape` would be easier to read.
original_reward, ablated_reward

(tensor([-9.3261, -8.6817, -2.5550, -5.7167, -5.7340, -4.3358, -4.5054, -6.0947,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  

In [19]:
def delete_feature(activations, layer_name, autoencoder, feature_to_ablate):
    """Replace a layer output with its autoencoder reconstruction.

    Has the same ``(output, layer_name, ...)`` shape as the other edit hooks in
    this notebook. The activations are flattened over (batch, seq), passed
    through ``autoencoder`` and the reconstruction is reshaped back.

    NOTE(review): despite the name, ``feature_to_ablate`` is never used — the
    full reconstruction is returned with no feature removed. Confirm whether
    zeroing that feature's code before decoding was intended.

    Args:
        activations: a ``(batch, seq, hidden)`` tensor, or a tuple whose first
            element is such a tensor (e.g. a residual-layer output, as opposed
            to an MLP output).
        layer_name: unused.
        autoencoder: callable returning ``(reconstruction, code)`` for an
            ``(N, hidden)`` input.
        feature_to_ablate: currently unused (see the note above).

    Returns:
        The reconstruction, or a tuple with the reconstruction first and all
        remaining elements of ``activations`` passed through unchanged.
    """
    is_tuple = isinstance(activations, tuple)
    to_edit_activations = activations[0] if is_tuple else activations

    b, s, n = to_edit_activations.shape
    mlp_flattened = to_edit_activations.reshape(b * s, n)
    reconstruction_flattened, _ = autoencoder(mlp_flattened)
    # The last dimension is taken from the reconstruction itself.
    reconstruction = reconstruction_flattened.reshape(b, s, reconstruction_flattened.shape[-1])

    if is_tuple:
        # Keep every trailing element, however many there are.
        return (reconstruction,) + tuple(activations[1:])
    return reconstruction

# Smoke test: run the reward model once with feature 0 ablated at `output_cache_name`.
# NOTE(review): `batch` is not defined in this cell — it is whatever an earlier
# loop left behind — and the computed reward is not assigned to anything.
intervention_function = partial(less_than_rank_1_ablate,  autoencoder=autoencoder, feature = 0)

with Trace(rm, output_cache_name, edit_output=intervention_function) as _:
    rm(batch).logits[:, 0].cpu()

In [18]:
def less_than_rank_1_ablate(value, layer_name, autoencoder, feature):
    """Subtract one dictionary feature's contribution from a layer output.

    Used in this notebook as an ``edit_output`` hook for ``Trace``: it receives
    a layer's output and returns the edited output in the same container.

    Every position is encoded with ``autoencoder``; the resulting activation of
    ``feature`` times that feature's dictionary direction is subtracted from
    the position. With the ``TiedSAE`` defined in this notebook the activation
    is clamped at zero, so positions where the feature does not rise above its
    (negative) bias are left untouched.

    Args:
        value: a ``(batch, seq, hidden)`` tensor, or a tuple whose first
            element is such a tensor.
        layer_name: unused; accepted so the hook can be called with the layer name.
        autoencoder: object providing ``encode(x)`` -> ``(N, n_features)`` and
            ``get_learned_dict()`` -> ``(n_features, hidden)``.
        feature: index of the feature to ablate (int or one-element integer tensor).

    Returns:
        The edited tensor, or a tuple with the edited tensor first and all
        remaining elements of ``value`` passed through unchanged. The input
        tensor itself is not modified.
    """
    is_tuple = isinstance(value, tuple)
    internal_activation = value[0] if is_tuple else value

    # Flatten (batch, seq) into one axis, because the autoencoder works on rows.
    batch, seq_len, hidden_size = internal_activation.shape
    n_positions = batch * seq_len
    flat_activation = internal_activation.reshape(n_positions, hidden_size)

    act = autoencoder.encode(flat_activation)
    # Explicit reshapes (rather than squeeze + outer) keep a single position or
    # a one-element tensor index from collapsing to a 0-d tensor.
    strength = act[:, feature].reshape(n_positions, 1)
    direction = autoencoder.get_learned_dict()[feature].reshape(1, hidden_size)
    feature_direction = (strength * direction).reshape(batch, seq_len, hidden_size)

    # Out-of-place subtraction: the caller's tensor is left as it was.
    ablated = internal_activation - feature_direction

    if is_tuple:
        # Keep every trailing element, however many there are.
        return (ablated,) + tuple(value[1:])
    return ablated