In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "reciprocate/dahoas-gptj-rm-static"
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# import sys
# sys.path.append("/root/dictionary_learning/")
from dictionary import GatedAutoEncoder

layer = 2
activation_name = f"transformer.h.{layer}"
sae_file = f"saes/ae_layer{layer}.pt"
ae = GatedAutoEncoder.from_pretrained(sae_file).to(device)

from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import torch 
hh = load_dataset("Anthropic/hh-rlhf", split="train")
token_length_cutoff = 870 # 99% of chosen data

# Remove datapoints longer than a specific token_length
# Check if file exists
index_file_name = "rm_save_files/index_small_enough.pt"
dataset_size = hh.num_rows
if os.path.exists(index_file_name):
    index_small_enough = torch.load(index_file_name)
else:
    print("hey")
#     index_small_enough = torch.ones(dataset_size, dtype=torch.bool)
# # 
#     for ind, text in enumerate(tqdm(hh)):
#         chosen_text = text["chosen"]
#         rejected_text = text["rejected"]
#         #convert to tokens
#         length_chosen = len(tokenizer(chosen_text)["input_ids"])
#         length_rejected = len(tokenizer(rejected_text)["input_ids"])
#         if length_chosen > token_length_cutoff or length_rejected > token_length_cutoff:
#             index_small_enough[ind] = False
#     # Save the indices
#     torch.save(index_small_enough, "rm_save_files/index_small_enough.pt")

hh = hh.select(index_small_enough.nonzero()[:, 0])
top_reward_diff_ind = torch.load("rm_save_files/top_reward_diff_ind.pt")
hh = hh.select(top_reward_diff_ind)

# select first 100 datapoints
hh = hh.select(range(1000))
batch_size = 12
hh_dl = DataLoader(hh, batch_size=batch_size, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 3/3 [00:19<00:00,  6.63s/it]


In [2]:
import torch
from tqdm import tqdm

num_datapoints = len(hh)
token_length_cutoff = 871

index_of_chosen_rejection_difference = torch.zeros(num_datapoints, dtype=torch.int16)

# Assuming hh_dl is a DataLoader that returns batches of data
subsets = 0
for i, batch in enumerate(tqdm(hh)):
    chosen_texts = batch["chosen"]
    rejected_texts = batch["rejected"]

    # Tokenize texts in batches
    chosen_tokens = tokenizer(chosen_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=token_length_cutoff)["input_ids"]
    rejected_tokens = tokenizer(rejected_texts, return_tensors="pt", padding="max_length", truncation=True, max_length=token_length_cutoff)["input_ids"]

    chosen_token_original_length = tokenizer(chosen_texts, return_tensors="pt")["input_ids"].shape[1]
    rejected_token_original_length = tokenizer(rejected_texts, return_tensors="pt")["input_ids"].shape[1]
    min_length = min(chosen_token_original_length, rejected_token_original_length)

    # Compare tokens and find divergence points
    divergence_matrix = (chosen_tokens != rejected_tokens).to(torch.int)  # Matrix of 1s where tokens differ

    # Find the first divergence index for each pair of texts
    divergence_indices = divergence_matrix.argmax(dim=1)
    if divergence_indices == min_length:
        subsets += 1
        divergence_indices -= 1

    # Calculate start index for the current batch
    # start_index = i * batch_size
    # end_index = start_index + len(chosen_texts)

    # Store the divergence indices in the appropriate positions
    # index_of_chosen_rejection_difference[start_index:end_index] = divergence_indices
    index_of_chosen_rejection_difference[i] = divergence_indices
print(f"Number of subsets: {subsets}")

100%|██████████| 1000/1000 [00:04<00:00, 243.14it/s]

Number of subsets: 1





In [3]:
# torch.save(index_of_chosen_rejection_difference, "rm_save_files/index_of_chosen_rejection_difference.pt")

In [4]:
from baukit import Trace
from einops import rearrange
num_features = ae.encoder.weight.shape[0]
num_datapoints = len(hh)
max_feature_activations_chosen = torch.zeros(num_datapoints, num_features)
max_feature_activations_rejected = torch.zeros(num_datapoints, num_features)

chosen_rejected_list = [max_feature_activations_chosen, max_feature_activations_rejected]
with torch.no_grad():
    for batch_ind, batch in tqdm(enumerate(hh_dl), total=len(hh_dl)):
        chosen_texts = batch["chosen"]
        rejected_texts = batch["rejected"]
        chosen_tokens = tokenizer(chosen_texts, return_tensors="pt", padding=True, truncation=True)["input_ids"]
        rejected_tokens = tokenizer(rejected_texts, return_tensors="pt", padding=True, truncation=True)["input_ids"]
        # Get Intermediate Activations
        index_of_token_diff = index_of_chosen_rejection_difference[batch_ind*batch_size:(batch_ind+1)*batch_size].to(torch.int)
        for chos_rej_ind, batch_tokens in enumerate([chosen_tokens, rejected_tokens]):

            # Get intermediate activations
            with Trace(model, activation_name) as ret:
                _ = model(batch_tokens.to(device)).logits
                internal_activations = ret.output
                # check if instance tuple
                if(isinstance(internal_activations, tuple)):
                    internal_activations = internal_activations[0]

            # Get Features for activation
            current_batch_size = batch_tokens.shape[0]
            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n" )
            batched_dictionary_activations = ae.encode(batched_neuron_activations)
            batched_feature_activations = rearrange(batched_dictionary_activations, "(b s) n -> b s n", b=current_batch_size)
            # Store only max over the differing tokens
            # ensure shape of index_of_token_diff is same as batched_feature_activations
            # try:
            max_feature_act = torch.max(batched_feature_activations[:, index_of_token_diff, :], dim=1).values.cpu()
            chosen_rejected_list[chos_rej_ind][batch_ind*batch_size:(batch_ind+1)*batch_size] = max_feature_act
                # del max_feature_act
            # except:
            #     print(f"Text at batch {batch_ind} and index {chos_rej_ind} are subset of each other. Skipping...")
                # print(f"Chosen text: {chosen_texts}")
                # print(f"Rejected text: {rejected_texts}")
            
            del ret, internal_activations, batched_neuron_activations, batched_dictionary_activations, batched_feature_activations, max_feature_act
            torch.cuda.empty_cache()

  0%|          | 0/84 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
100%|██████████| 84/84 [12:37<00:00,  9.02s/it]


In [5]:
torch.cuda.empty_cache()

In [6]:
#combine both chosen and rejected. Find the top-k datapoints for each feature. Ignore Features that don't have j activations (& print so)
chosen_rejected = torch.cat(chosen_rejected_list, dim=0)
k = 10
j = 10
nz_feat = (chosen_rejected > 0).sum(0)
print(f"More than 0: {nz_feat.count_nonzero()}")
print(f"More than {j}: {(nz_feat > j).count_nonzero()}")

nz_feature_ind = nz_feat.nonzero().squeeze()
num_nz_features = nz_feature_ind.shape[0]
feature_top_k = torch.zeros(num_nz_features, k)

for feature_ind, nz_feat_i in tqdm(enumerate(nz_feature_ind), total=num_nz_features):
    feature_activations = chosen_rejected[:, nz_feat_i]
    top_k_activations = torch.topk(feature_activations, k).indices
    feature_top_k[feature_ind] = top_k_activations

More than 0: 4335
More than 10: 2368


100%|██████████| 4335/4335 [00:00<00:00, 19003.39it/s]


In [7]:
# Save results
# # save nz_feature_ind
torch.save(nz_feature_ind, "rm_save_files/nz_feature_ind.pt")
torch.save(feature_top_k, "rm_save_files/each_nz_features_top_activating_datapoints.pt")

In [8]:
feature_top_k.shape, nz_feature_ind

(torch.Size([4335, 10]),
 tensor([    6,    11,    15,  ..., 32739, 32743, 32748]))

In [None]:
nz_ind = nz_feat.nonzero()[:, 0]
chosen_rejected_nz = chosen_rejected[:, nz_ind]
chosen_rejected_nz

In [None]:
(torch.sum(chosen_rejected, dim=0)>0).count_nonzero(), (torch.sum(chosen_rejected, dim=0) > 10).count_nonzero()

In [None]:
chosen_rejected.shape

In [None]:
torch.topk(feature_activations, k)

In [None]:
feature_activations

In [None]:
torch.max(batched_feature_activations[:, index_of_token_diff, :], dim=1).values.count_nonzero()

In [None]:
tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.cls_token_id, tokenizer.mask_token_id, tokenizer.unk_token_id

In [None]:
index_of_token_diff

In [None]:
batched_dictionary_activations.shape, internal_activations.shape, hh_dl.batch_size, len(chosen_texts), len(rejected_texts)

In [None]:
for ind, b in enumerate(hh_dl):
    print(len(b["chosen"]))
    if(ind > 100):
        break

In [None]:
b