In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Select the GPU when one is available; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reward-model checkpoint (GPT-J based, judging by its name), loaded in bfloat16.
model_name = "reciprocate/dahoas-gptj-rm-static"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:17<00:00,  5.92s/it]


In [2]:

# Uncomment if the `dictionary` package is not importable from the working
# directory (the path below is machine-specific):
# import sys
# sys.path.append("/root/dictionary_learning/")
from dictionary import GatedAutoEncoder

layer = 2
# Dotted submodule name of block `layer`; passed to Trace later to capture activations.
activation_name = f"transformer.h.{layer}"
# Autoencoder checkpoint for this layer, moved to the same device as the model.
sae_file = f"saes/ae_layer{layer}.pt"
ae = GatedAutoEncoder.from_pretrained(sae_file)
ae = ae.to(device)

In [3]:
from datasets import load_dataset
def _truncate_example(example, max_length):
    """Clip the token-level columns of one tokenized example to ``max_length``.

    ``input_ids`` and, when present, ``attention_mask`` are clipped together so
    the two columns stay aligned.
    """
    return {
        key: example[key][:max_length]
        for key in ("input_ids", "attention_mask")
        if key in example
    }


def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    """Load, tokenize and fixed-length-truncate the ``train`` split of a HF dataset.

    Args:
        dataset_name: name passed to ``load_dataset``; rows must have a ``text`` column.
        tokenizer: tokenizer applied (batched) to the ``text`` column.
        max_length: every returned row has exactly this many tokens. Rows whose
            tokenization is shorter are dropped; longer rows are truncated.
        num_datapoints: if truthy, only the first ``num_datapoints`` rows of
            ``train`` are loaded (before filtering); otherwise the whole split.

    Returns:
        The dataset with its original columns plus the tokenizer's output columns
        (e.g. ``input_ids`` / ``attention_mask``), clipped to ``max_length``.
    """
    split_text = f"train[:{num_datapoints}]" if num_datapoints else "train"
    raw = load_dataset(dataset_name, split=split_text)
    tokenized = raw.map(lambda batch: tokenizer(batch["text"]), batched=True)
    # `>=` rather than `>`: a row with exactly `max_length` tokens already fits.
    long_enough = tokenized.filter(lambda example: len(example["input_ids"]) >= max_length)
    return long_enough.map(_truncate_example, fn_kwargs={"max_length": max_length})

# Sequences are cut to this many tokens; shorter documents are dropped.
max_seq_length = 40
dataset_name = "stas/openwebtext-10k"
print(f"Downloading {dataset_name}")
dataset = download_dataset(
    dataset_name,
    tokenizer=tokenizer,
    max_length=max_seq_length,
    num_datapoints=7000,  # first 7000 rows of `train`; pass None to load the full split
)

Downloading stas/openwebtext-10k


In [4]:
# from datasets import load_dataset
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# import os
# import torch 
# hh = load_dataset("Anthropic/hh-rlhf", split="train")
# token_length_cutoff = 870 # 99% of chosen data

# # Remove datapoints longer than a specific token_length
# # Check if file exists
# index_file_name = "rm_save_files/index_small_enough.pt"
# dataset_size = hh.num_rows
# if os.path.exists(index_file_name):
#     index_small_enough = torch.load(index_file_name)
# else:
#     print("hey")
# #     index_small_enough = torch.ones(dataset_size, dtype=torch.bool)
# # # 
# #     for ind, text in enumerate(tqdm(hh)):
# #         chosen_text = text["chosen"]
# #         rejected_text = text["rejected"]
# #         #convert to tokens
# #         length_chosen = len(tokenizer(chosen_text)["input_ids"])
# #         length_rejected = len(tokenizer(rejected_text)["input_ids"])
# #         if length_chosen > token_length_cutoff or length_rejected > token_length_cutoff:
# #             index_small_enough[ind] = False
# #     # Save the indices
# #     torch.save(index_small_enough, "rm_save_files/index_small_enough.pt")

# hh = hh.select(index_small_enough.nonzero()[:, 0])
# batch_size = 16
# hh_dl = DataLoader(hh, batch_size=batch_size, shuffle=False)

In [5]:
import torch
from torch.utils.data import DataLoader
from einops import rearrange
from tqdm import tqdm
from baukit import Trace

def get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32):
    """Encode one submodule's activations over a whole dataset with a dictionary (SAE).

    For each batch of ``dataset["input_ids"]``, the model is run until the submodule
    named ``cache_name`` has produced its output, which is captured with
    ``baukit.Trace``. That output is flattened to one row per token and passed
    through ``autoencoder.encode``.

    Args:
        model: model whose ``device`` attribute batches are moved to.
        dataset: ``datasets.Dataset`` whose ``input_ids`` rows all contain exactly
            ``max_seq_length`` tokens.
        cache_name: dotted name of the submodule to capture, e.g. ``"transformer.h.2"``.
        max_seq_length: number of tokens per row.
        autoencoder: dictionary whose ``encoder.weight`` has the number of features
            as its first dimension, and which exposes an ``encode`` method.
        batch_size: rows per forward pass.

    Returns:
        ``(dictionary_activations, token_list)``:
        - ``dictionary_activations`` is a CPU float tensor (default dtype) of shape
          ``(num_rows * max_seq_length, num_features)``.
        - ``token_list`` is the matching flat ``int64`` tensor of token ids.
        Both are ordered row by row, position by position.

    Raises:
        ValueError: if a batch's sequence length differs from ``max_seq_length``.
    """
    num_features = autoencoder.encoder.weight.shape[0]
    total_tokens = dataset.num_rows * max_seq_length
    dictionary_activations = torch.zeros((total_tokens, num_features))
    token_list = torch.zeros(total_tokens, dtype=torch.int64)

    # Running write position. Advancing by each batch's actual token count keeps the
    # slices correct for the (possibly smaller) final batch.
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            if batch.shape[1] != max_seq_length:
                raise ValueError(
                    f"expected sequences of length {max_seq_length}, got {batch.shape[1]}"
                )
            flat_tokens = rearrange(batch, "b s -> (b s)")
            end = offset + flat_tokens.shape[0]
            token_list[offset:end] = flat_tokens

            # stop=True makes Trace abort the forward pass once `cache_name` has run
            # (it raises baukit's StopForward and swallows it on exit). The later
            # layers and the scoring head, whose results were discarded anyway,
            # are therefore never computed.
            with Trace(model, cache_name, stop=True) as ret:
                model(batch.to(model.device))
            internal_activations = ret.output
            # Some modules (e.g. HF transformer blocks) return a tuple; keep the first element.
            if isinstance(internal_activations, tuple):
                internal_activations = internal_activations[0]

            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n")
            dictionary_activations[offset:end] = autoencoder.encode(batched_neuron_activations).cpu()
            offset = end
    return dictionary_activations, token_list

# Encode the whole dataset, 128 rows per forward pass.
batch_size = 128
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(
    model,
    dataset,
    cache_name=activation_name,
    max_seq_length=max_seq_length,
    autoencoder=ae,
    batch_size=batch_size,
)

100%|██████████| 55/55 [02:14<00:00,  2.45s/it]


In [6]:
# Feature search: encode a probe prompt and list the 10 largest entries of the
# last element returned by get_autoencoder_activation (presumably the final
# token's feature activations — TODO confirm against interp_utils).
from interp_utils import get_autoencoder_activation

# Alternative probe prompts:
#   " If you know that you shouldn'"
#   " What we know about the ownership of Barack Obama'"
#   " You shouldn't done that! Now you'"
#   " You know that I'"
text = [" If you didn't know, you can go see O'"]
tokens = torch.tensor([tokenizer.encode(text[0])])  # shape (1, num_tokens)
dict_act = get_autoencoder_activation(model, activation_name, tokens, ae)
torch.topk(dict_act[-1], k=10)

torch.return_types.topk(
values=tensor([12.1330, 11.2355,  7.8612,  6.3431,  6.1036,  5.9116,  4.4546,  3.8476,
         3.8164,  3.4424], device='cuda:0', grad_fn=<TopkBackward0>),
indices=tensor([ 7070,  6223, 28449,  2826,  9883, 13167, 26324,  2702, 30330, 31554],
       device='cuda:0'))

In [7]:
# Ids of the 10 strongest features from the probe above (the feature-inspection
# cell later reassigns `features` with its own list).
features = torch.topk(dict_act[-1], k=10).indices.tolist()

In [8]:
# Number of features whose summed activation over the first 50,000 token rows is
# non-zero, shown next to the shape of the full activation matrix.
alive_in_first_50k = dictionary_activations[:50000].sum(dim=0).count_nonzero()
alive_in_first_50k, dictionary_activations.shape

(tensor(6810), torch.Size([280000, 32768]))

In [10]:
from interp_utils import *

# Inspect a hand-picked set of features: render their selected contexts as HTML
# and, optionally, a one-token-at-a-time context ablation.
num_feature_datapoints = 10
features = list(range(100))  # immediately replaced by the explicit list below
# features = [6223, 27334, 28340, 9970, 16493]
features = [14706, 32670, 10962, 3349, 16162]
# features = [7251]
ablate_context = False  # set True to also show the context-ablation view
for feature in features:
    nz_ind_amount = dictionary_activations[:, feature].count_nonzero()
    print(f"feature: {feature}, non-zero activations: {nz_ind_amount}")
    if nz_ind_amount == 0:
        continue  # feature never fires: nothing to show

    # NOTE: the name comes from the "uniform" variant below; the active call uses
    # setting="max" (presumably the top-activating examples).
    # uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="max")
    (text_list, full_text, token_list, full_token_list,
     partial_activations, full_activations) = get_feature_datapoints(
        uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset
    )
    # Logit-diff ablation is disabled; to re-enable:
    # logit_diffs = ablate_feature_direction(model, full_token_list, activation_name, max_seq_length, ae, feature = feature, batch_size=32, setting="sentences", model_type="causal")
    logit_diffs = None

    html = tokens_and_activations_to_html(full_token_list, full_activations, tokenizer, logit_diffs=logit_diffs)
    print(f"feature: {feature}")
    display(HTML(html))
    if not ablate_context:
        continue

    all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, activation_name, ae, feature, max_ablation_length=10)
    html = tokens_and_activations_to_html(token_list, all_changed_activations, tokenizer)
    print("Context_ablation\n=================================================================")
    display(HTML(html))

feature: 14706, non-zero activations: 137
feature: 14706


feature: 32670, non-zero activations: 5306
feature: 32670


feature: 10962, non-zero activations: 50862
feature: 10962


feature: 3349, non-zero activations: 1323
feature: 3349


feature: 16162, non-zero activations: 6356
feature: 16162


In [18]:
# reward_check: score small edits of one Human/Assistant exchange with the reward model.
# Word-choice variants of the assistant's closing question (kept for reference;
# the punctuation variants below are the ones scored).
custom_text_word_variants = [
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you explain more?",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you more?",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you say more?",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you elaborate more?",
]
# Punctuation variants of the same reply.
custom_text_punctuation_variants = [
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you explain more?",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works can you explain more?",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you explain more",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you explain more.",
    "\n\nHuman: How do you embezzle money?\n\nAssistant: I'm afraid that's not how it works, can you explain more!",
]
custom_text = custom_text_punctuation_variants
encoded = tokenizer(custom_text, padding=True, truncation=True, return_tensors="pt")
tokens = encoded["input_ids"]
with torch.no_grad():
    # Pass the attention mask so the model does not attend to padding positions.
    # Without it, pad tokens are attended whenever they precede real tokens
    # (e.g. with left padding).
    rewards = model(
        input_ids=tokens.to(model.device),
        attention_mask=encoded["attention_mask"].to(model.device),
    ).logits
rewards

tensor([[-5.6875],
        [-5.5625],
        [-9.5625],
        [-6.6562],
        [-6.1875]], device='cuda:0', dtype=torch.bfloat16)

In [None]:
# TODO: also implement the reward-diff code.

In [None]:
# Single Human/Assistant transcript containing a profanity/slur list; presumably meant
# to be scored with the reward model like the reward_check cell — TODO confirm.
custom_text = ["\n\nHuman: What are some cuss words in english?\n\nAssistant: Here’s an incomplete list.\n\nAss, dick, bugger, crap, fuck, shit, bitch, turd, shithead, shitbag, scrotum, cunt, whore, fucker, shit-eating, cum, cumbucket, fucknugget, butthole, poop, jackass, cocksucker, asshole, goddamn, piss, sperm, blow, wank, jism, cum-sucking, masturbate, faggot, queer, jizz, jizz-licking, prostitute, slut, cheater, fornicator, floozy, wetback, Mexican, Hispanic, sodomite, midget, mama’s boy, faggot, pervert, queer, scumbag, bitch,\n\nHuman: What's your favorite one?\n\nAssistant: I haven't even thought about it."]