# Download Model & SAEs

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
from nnsight import LanguageModel
# Monkeypatch: from here on, every call to torch.jit.is_tracing() in this process returns True.
# NOTE(review): presumably a workaround needed by nnsight/transformers tracing — confirm.
# It can change library code paths that check this flag, so it must run before the models are used.
torch.jit.is_tracing = lambda : True
# Options forwarded to nnsight tracing through attribution_patching in the "Custom Prompts" cell.
# NOTE(review): presumably these disable nnsight's validation/scan passes — confirm against the nnsight version in use.
tracer_kwargs = {'validate' : False, 'scan' : False}

# GPU if available, otherwise CPU; the SAE below is loaded onto the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reward-model checkpoint, loaded with a sequence-classification head; later cells read logits[:, 0] as the reward.
model_name = "reciprocate/dahoas-gptj-rm-static"
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)


  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:15<00:00,  5.09s/it]


In [2]:
# Second copy of the same checkpoint, wrapped by nnsight for attribution patching.
# It starts on CPU in bfloat16. attribution_patching moves it to the SAE's device only while it
# runs, and parks the plain `model` on CPU in the meantime, so both never sit on the GPU together.
model_nnsight = LanguageModel(
    model_name,
    device_map = "cpu",
    automodel = AutoModelForSequenceClassification,  # same head as `model` above
    dispatch = True,  # NOTE(review): presumably loads the weights eagerly instead of lazily — confirm
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:11<00:00,  3.87s/it]


In [3]:
from huggingface_hub import hf_hub_download
from dictionary import GatedAutoEncoder

# Layers for which an SAE checkpoint is available (per the original note): 2, 8, 12, 14, 16, 20.
layer = 12

# Module path whose output the SAE reconstructs (hooked by baukit Trace / nnsight below).
activation_name = f"transformer.h.{layer}"

# Hub repository and the per-layer checkpoint names inside it.
model_id = "Elriggs/rm"
sae_file_save_name = f"ae_layer{layer}"
sae_filename = f"{sae_file_save_name}.pt"
sae_file_dir = f"sae_results/{sae_file_save_name}"  # not used in this cell

# Fetch the checkpoint from the Hub, then load it on the same device as `model`.
ae_download_location = hf_hub_download(repo_id=model_id, filename=sae_filename)
sae = GatedAutoEncoder.from_pretrained(ae_download_location)
sae = sae.to(device)

# Attribution Patching & HTML display definitions

In [4]:

from interp_utils import patching_effect_two
import gc
from tqdm import tqdm
from interp_utils import tokens_and_activations_to_html
from IPython.display import HTML, display
from einops import rearrange
from baukit import Trace
from interp_utils import get_autoencoder_activation
from functools import partial

def sae_ablation_after_pos(x, sae, feature_ind, positions):
    """Zero SAE feature(s) from a given position onward, keeping the SAE error term.

    Intended as a baukit ``Trace(..., edit_output=...)`` hook.

    1. The activation is encoded with ``sae``.
    2. The reconstruction error ``act - decode(encode(act))`` is set aside.
    3. The selected feature(s) are zeroed at positions ``positions[i]:`` of
       batch row ``i``.
    4. The result is decoded and the error is added back.

    With nothing ablated the output therefore equals the input, up to
    floating-point error.

    Args:
        x: Module output. Either a tensor of shape ``(batch, seq, d_model)`` or a
            tuple whose first element is that tensor. Any remaining tuple elements
            (whatever else the wrapped module returned) are passed through untouched.
        sae: Object with ``encode``/``decode`` mapping ``(N, d_model)`` to
            ``(N, n_features)`` and back.
        feature_ind: Feature index, or a list/tensor of indices, to ablate.
        positions: One start position per batch row. Rows beyond
            ``len(positions)`` are not ablated.

    Returns:
        The same structure as ``x``, with the first tensor replaced.
    """
    is_tuple = isinstance(x, tuple)
    acts = x[0] if is_tuple else x
    batch, seq_len, d_model = acts.shape
    flat = acts.reshape(batch * seq_len, d_model)

    features = sae.encode(flat)
    # SAE reconstruction error; added back so that only the ablation changes the output.
    residual = flat - sae.decode(features)

    # Edit a copy: mutating `features` in place after decode() used it would break autograd
    # whenever this hook runs with gradients enabled.
    ablated = features.reshape(batch, seq_len, -1).clone()
    for row, start in enumerate(positions):
        ablated[row, start:, feature_ind] = 0

    x_recon = residual + sae.decode(ablated.reshape(batch * seq_len, -1))
    reconstruction = x_recon.reshape(batch, seq_len, -1)

    if is_tuple:
        # Keep all trailing elements, however many there are (there may be none).
        return (reconstruction, *x[1:])
    return reconstruction

def get_padding_indices(token_tensor, padding_token_id):
    """Find the first padding token in each row of a token-id tensor.

    Returns a CPU ``torch.long`` tensor with one entry per row (per slice along
    dim 0). Each entry is the index of the first element equal to
    ``padding_token_id``, or ``-1`` if the row contains none.
    """
    first_pad = torch.full((token_tensor.size(0),), -1, dtype=torch.long)
    for row_idx, row in enumerate(token_tensor):
        hits = (row == padding_token_id).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            first_pad[row_idx] = hits[0]
    return first_pad

def display_feature_activation_and_ablation(tokens, prefix_text, feature, feature_ablation, model, sae, activation_name, tokenizer):
    """Render per-token SAE activations and the reward change from ablating features.

    Every row of ``tokens`` is assumed to be ``prefix_text`` followed by a
    completion, right-padded with ``tokenizer.pad_token_id``. The prompt is shown
    once (taken from row 0). Each completion is then shown with its reward
    "before -> after" zeroing ``feature_ablation`` from the first completion
    token onward.

    Args:
        tokens: Token-id tensor of shape ``(batch, seq)``.
        prefix_text: Shared prompt text; its token count marks where completions start.
        feature: SAE feature whose per-token activations are coloured in the HTML.
        feature_ablation: Feature index, or list of indices, ablated to get the
            "after" reward.
        model: Reward model, called as ``model(tokens).logits``.
        sae: Sparse autoencoder trained on ``activation_name``.
        activation_name: Module path hooked with baukit ``Trace``.
        tokenizer: Tokenizer that produced ``tokens``.
    """
    prefix_size = len(tokenizer(prefix_text, return_tensors="pt")["input_ids"][0])
    padding_location = get_padding_indices(tokens, tokenizer.pad_token_id)

    batch_size, seq_size = tokens.shape
    with torch.no_grad():
        # Per-token activation of `feature`, plus the unablated reward.
        feature_activations, reward = get_autoencoder_activation(model, activation_name, tokens, sae, return_output=True)
        feature_activations = feature_activations[..., feature].cpu()
        feature_activations = rearrange(feature_activations, "(b s) -> b s", b=batch_size, s=seq_size)
        # reshape(-1) rather than squeeze(): squeeze() makes a batch of one a 0-d tensor,
        # which cannot be indexed per row below. Assumes one score per sequence.
        reward = reward.cpu().reshape(-1)

        # Reward with `feature_ablation` zeroed from the first completion token onward.
        hook_function = partial(sae_ablation_after_pos, sae=sae, feature_ind=feature_ablation, positions=[prefix_size] * batch_size)
        with Trace(model, activation_name, edit_output=hook_function):
            ablated_reward = model(tokens.to(model.device)).logits.cpu().reshape(-1)

    # First entry is the shared prompt. It must be a flat token list (like the completions
    # below), taken from row 0 so that tokens and activations line up one-to-one.
    token_list = [tokens[0][:prefix_size].tolist()]
    activation_list = [feature_activations[0][:prefix_size].tolist()]
    text_above = ["Prompt<br>"]

    for row in range(batch_size):
        pad_ind = int(padding_location[row])
        if pad_ind == -1:  # row has no padding: the completion runs to the end
            pad_ind = seq_size
        token_list.append(tokens[row][prefix_size:pad_ind].tolist())
        activation_list.append(feature_activations[row][prefix_size:pad_ind].tolist())
        text_above.append(f"Reward: {reward[row].item():.2f} -> {ablated_reward[row].item():.2f} <br> {row+1}.")

    html = tokens_and_activations_to_html(token_list, activation_list, tokenizer, logit_diffs=None, text_above_each_act=text_above)
    print(f"feature: {feature}")
    display(HTML(html))

def attribution_patching(model, model_nnsight, sae, tokens, activation_name, prefix_size, tracer_kwargs, steps=10):
    """Approximate per-token, per-feature effects on the reward via attribution patching.

    Runs ``patching_effect_two`` once per row of ``tokens`` on the nnsight copy of
    the reward model. The metric is the first output logit (the reward).

    Only one model copy is kept on the accelerator at a time. ``model`` is parked
    on CPU while ``model_nnsight`` runs on the SAE's device, and the swap is
    undone afterwards, even if an exception is raised.

    Args:
        model: Plain HF reward model; temporarily moved to CPU.
        model_nnsight: nnsight ``LanguageModel`` wrapping the same checkpoint.
        sae: Dictionary for ``activation_name``. Its decoder's device is used as
            the compute device.
        tokens: Token-id tensor; each row is processed separately.
        activation_name: Dotted module path, e.g. ``"transformer.h.12"``.
        prefix_size: Passed as the single entry of ``positions`` to
            ``patching_effect_two``.
        tracer_kwargs: Forwarded to ``patching_effect_two``.
        steps: Forwarded to ``patching_effect_two``.

    Returns:
        The per-row effects, concatenated along dim 0.
    """
    device = sae.decoder.weight.device

    def get_reward(traced_model):
        # The first logit is read as the reward.
        return traced_model.output.logits[:, 0]

    positions = [prefix_size]

    model.to("cpu")
    gc.collect()
    torch.cuda.empty_cache()
    try:
        model_nnsight.to(device)

        # Resolve the module by its dotted name on the nnsight wrapper.
        module = model_nnsight
        for attr in activation_name.split('.'):
            module = getattr(module, attr)
        submodules = [module]
        dictionaries = {module: sae}

        list_effects = []
        for token in tokens:
            effects = patching_effect_two(
                token.to(device),
                None,
                model_nnsight,
                submodules=submodules,
                dictionaries=dictionaries,
                tracer_kwargs=tracer_kwargs,
                positions=positions,
                metric_fn=get_reward,
                steps=steps,
            )
            list_effects.append(effects)
        return torch.cat(list_effects)
    finally:
        # Always restore placement so later cells find `model` back on the SAE's device.
        model_nnsight.to("cpu")
        model.to(device)
        gc.collect()
        torch.cuda.empty_cache()

def ignore_baseline_features(list_effects, num_baselines, top_features_to_ignore=10):
    """Zero the features that matter most for the baseline examples.

    ``list_effects`` is indexed ``(example, position, feature)``, and its last
    ``num_baselines`` examples are the baselines. Features are ranked by their
    total absolute effect over the baseline examples and positions. The top
    ``top_features_to_ignore`` features are then zeroed for every example, in a
    copy of the input.

    Args:
        list_effects: 3-D effects tensor.
        num_baselines: Number of trailing examples that are baselines. Values
            larger than the number of examples mean "all examples".
        top_features_to_ignore: Number of features to zero. Clamped to the
            number of features.

    Returns:
        A new tensor; ``list_effects`` is not modified. If there are no
        baselines, or nothing to ignore, the copy is returned unchanged.
    """
    result = list_effects.clone()
    n_examples = list_effects.shape[0]
    n_features = list_effects.shape[-1]
    num_baselines = min(num_baselines, n_examples)
    k = min(top_features_to_ignore, n_features)
    if num_baselines <= 0 or k <= 0:
        # Guard needed: slicing with [-0:] would select *all* examples, not none.
        return result
    baseline_effects = list_effects[n_examples - num_baselines:]
    ignore_these_features = baseline_effects.abs().sum(dim=(0, 1)).topk(k).indices
    result[:, :, ignore_these_features] = 0
    return result

def true_ablation_effect(features_to_ablate, tokens, prefix_size):
    """Measure the exact reward change from ablating each feature, one at a time.

    For every feature in ``features_to_ablate``, the global ``model`` is run with
    that SAE feature zeroed (via ``sae_ablation_after_pos``) from ``prefix_size``
    onward in every row of ``tokens``. Relies on the notebook globals ``model``,
    ``sae``, ``activation_name`` and ``device``.

    Returns:
        Tensor of shape ``(len(features_to_ablate), batch)`` holding
        ``original_reward - ablated_reward``, assuming one logit per sequence.
        A positive value means removing the feature lowered the reward.
    """
    positions = [prefix_size] * len(tokens)
    all_reward_diffs = []
    with torch.no_grad():
        tokens_on_device = tokens.to(device)  # transfer once, not once per feature
        # reshape(-1) rather than squeeze() so that a single-sequence batch stays 1-D.
        original_reward = model(tokens_on_device).logits.cpu().reshape(-1)
        for feat_to_ablate in tqdm(features_to_ablate):
            hook_function = partial(sae_ablation_after_pos, sae=sae, feature_ind=feat_to_ablate, positions=positions)
            with Trace(model, activation_name, edit_output=hook_function):
                ablated_reward = model(tokens_on_device).logits.cpu().reshape(-1)
            all_reward_diffs.append(original_reward - ablated_reward)
    if all_reward_diffs:
        all_reward_diffs = torch.stack(all_reward_diffs)
    else:
        # torch.stack([]) raises; return an empty (0, batch) result instead.
        all_reward_diffs = original_reward.new_empty((0, original_reward.numel()))
    if torch.cuda.is_available():
        print('Original Memory Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    return all_reward_diffs

# Custom Prompts

In [22]:
prefix_text = "\n\nHuman: What's the capital of France?\n\nAssistant:"
completions = [
    " Thank you. No problem. thank you thanks to you thank you",
]
# Baseline completions: features that matter for these are treated as generic and ignored.
baselines = [
    " Paris is the capital of France",
    " Yo mama so old, when someone told her to act her age, she died.",
    " I hate you, you worthless piece of trash.",
]
num_baselines = len(baselines)

NUM_BASELINE_FEATURES_TO_IGNORE = 10  # features zeroed because they dominate the baselines
TOP_K = 10  # candidate features kept in each direction

custom_text = [prefix_text + text for text in completions] + [prefix_text + text for text in baselines]
tokens = tokenizer(custom_text, padding=True, truncation=True, return_tensors="pt")["input_ids"]
prefix_tokens = tokenizer(prefix_text, return_tensors="pt")["input_ids"]
prefix_size = len(prefix_tokens[0])

# Use attribution patching (AtP) to get approximate effects. (Might take a minute to run.)
approx_effects = attribution_patching(model, model_nnsight, sae, tokens, activation_name, prefix_size, tracer_kwargs=tracer_kwargs, steps=10)

# Ignore features that also affected the reward of the baselines.
ignored_baseline_effects = ignore_baseline_features(approx_effects, num_baselines, top_features_to_ignore=NUM_BASELINE_FEATURES_TO_IGNORE)

# Total AtP effect per feature over all positions of the first completion.
first_completion_effects = ignored_baseline_effects[0].sum(0)
# Most negative AtP effect. The ablation check below (reward drops when they are removed)
# treats these as reward-increasing features.
top_pos_features = first_completion_effects.topk(TOP_K, largest=False).indices
top_neg_features = first_completion_effects.topk(TOP_K).indices

# Measure the actual ablation effect of the "positive" candidates (e.g. "Thank you. No problem!").
# To study reward-lowering features instead (e.g. repeated text), pass `top_neg_features`.
true_abl_effect = true_ablation_effect(top_pos_features, tokens, prefix_size)
effects, local_ind = true_abl_effect[:, 0].sort(descending=True)
effects, top_pos_features[local_ind]

100%|██████████| 10/10 [00:01<00:00,  5.83it/s]

Original Memory Allocated: 24.0 GB





(tensor([ 0.8962,  0.8521,  0.6026,  0.4036,  0.1692,  0.1499,  0.0982,  0.0645,
          0.0245, -0.0058]),
 tensor([17168, 32744, 20930, 28839, 24910,   131, 13635, 17479, 23950,  9554]))

In [26]:
# Candidate features ordered by their measured ablation effect on completion 0, largest first.
ranked_pos_features = top_pos_features[local_ind]
target_feature = ranked_pos_features[0]
ablate_these_features = [target_feature, ranked_pos_features[1]]
display_feature_activation_and_ablation(
    tokens, prefix_text, target_feature, ablate_these_features,
    model, sae, activation_name, tokenizer,
)

feature: 17168
