In [1]:
%load_ext autoreload
%autoreload 2
import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools

from jaxtyping import Float
from torch import Tensor

import plotly.express as px


In [2]:
# resid pre 9
REPO_ID = "eoinf/unlearning_saes"
FILENAME = "jolly-dream-40/sparse_autoencoder_gemma-2b-it_blocks.9.hook_resid_pre_s16384_127995904.pt"


filename = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
sae = load_saved_sae(filename)

model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [3]:
# read 172 questions that the model can answer correctly in any permutation
filename = '../data/wmdp-bio_gemma_2b_it_correct.csv'
correct_question_ids = np.genfromtxt(filename)


# read 133 questions that the model can answer correctly in any permutation but will get it wrong if
# without the instruction prompt and the question prompt
filename = '../data/wmdp-bio_gemma_2b_it_correct_not_correct_wo_question_prompt.csv'
correct_question_id_not_correct_wo_question_prompt = np.genfromtxt(filename)


In [4]:
dataset = load_dataset("cais/wmdp", "wmdp-bio")
permute_choices = None
prompts = [convert_wmdp_data_to_prompt(dataset['test'][i]['question'], dataset['test'][i]['choices'], prompt_format=None, permute_choices=permute_choices) for i in range(len(dataset['test'])) if i in correct_question_id_not_correct_wo_question_prompt]

In [5]:
# setup unlearning tool, need about 3 minutes to run this cell

filename = "../data/wmdp-bio_gemma_2b_it_correct.csv"
correct_question_ids = np.genfromtxt(filename)
all_permutations = list(itertools.permutations([0, 1, 2, 3]))
n_questions = 10

dataset_args = {
    'question_subset': correct_question_ids[:n_questions],
    'permutations': all_permutations,
}

n_batches = n_questions * len(all_permutations) // 4
sae.cfg.n_batches_in_store_buffer = n_batches * 2
unlearn_activation_store = MCQ_ActivationStoreAnalysis(sae.cfg, model, dataset_args=dataset_args)
unlearning_metric = 'wmdp-bio_gemma_2b_it_correct'


unlearn_cfg = UnlearningConfig(unlearn_activation_store=unlearn_activation_store, unlearning_metric=unlearning_metric)
ul_tool = SAEUnlearningTool(unlearn_cfg)
ul_tool.setup()
ul_tool.get_metrics_with_text(n_batches=n_batches)


dataloader


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda
buffer
dataloader


100%|██████████| 60/60 [00:29<00:00,  2.03it/s]


tokens torch.Size([240, 1024]) 1024
tokens torch.Size([240, 1024])
Concatenating learned activations
Done


100%|██████████| 60/60 [00:08<00:00,  7.20it/s]


tokens torch.Size([240, 1024]) 1024
tokens torch.Size([240, 1024])
Concatenating learned activations
Done


In [6]:
answers = [p.index(dataset['test'][i]['answer']) for i in range(len(dataset['test'])) if i in correct_question_ids[:n_questions] for p in all_permutations]
answers

[3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 2,
 3,
 2,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 3,
 1,
 1,
 3,
 2,
 2,
 3,
 1,
 1,
 3,
 2,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 2,
 3,
 2,
 3,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 3,
 1,
 1,
 3,
 2,
 2,
 3,
 1,
 1,
 3,
 2,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 2,
 3,
 1,
 1,
 3,
 2,
 2,
 3,
 1,
 1,
 3,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 3,
 2,
 3,
 2,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0]

In [7]:
unlearn_metrics = ul_tool.unlearn_metrics_with_text


In [8]:
unlearn_metrics['learned_activations'].shape

torch.Size([245760, 16384])

In [9]:
unlearn_metrics['token_df'].query('batch == 0').query('pos >= 16')

Unnamed: 0,str_tokens,unique_token,context,prefix,suffix,batch,pos,label
16,·explaining,·explaining/16,·about·biology.↩\nIn|·explaining|·the,·about·biology.↩\nIn,·the,0,16,0/16
17,·the,·the/17,·biology.↩\nIn·explaining|·the|·known,·biology.↩\nIn·explaining,·known,0,17,0/17
18,·known,·known/18,.↩\nIn·explaining·the|·known|·methods,.↩\nIn·explaining·the,·methods,0,18,0/18
19,·methods,·methods/19,↩\nIn·explaining·the·known|·methods|·for,↩\nIn·explaining·the·known,·for,0,19,0/19
20,·for,·for/20,In·explaining·the·known·methods|·for|·the,In·explaining·the·known·methods,·the,0,20,0/20
...,...,...,...,...,...,...,...,...
1019,<pad>,<pad>/1019,<pad><pad><pad><pad><pad>|<pad>|<pad>,<pad><pad><pad><pad><pad>,<pad>,0,1019,0/1019
1020,<pad>,<pad>/1020,<pad><pad><pad><pad><pad>|<pad>|<pad>,<pad><pad><pad><pad><pad>,<pad>,0,1020,0/1020
1021,<pad>,<pad>/1021,<pad><pad><pad><pad><pad>|<pad>|<pad>,<pad><pad><pad><pad><pad>,<pad>,0,1021,0/1021
1022,<pad>,<pad>/1022,<pad><pad><pad><pad><pad>|<pad>|,<pad><pad><pad><pad><pad>,,0,1022,0/1022


In [10]:
learned_activations_reshape: Float[Tensor, 'n_permutations seq_len n_features'] = unlearn_metrics['learned_activations'].reshape(-1, 1024, 16384)

In [11]:
# (learned_activations_reshape[1] == unlearn_metrics['learned_activations'][1024:2048]).all()

max_activations = learned_activations_reshape.max(axis=1).values
max_activations.shape

# group by ans: 0, 1, 2, 3
answers = torch.tensor(answers)
max_activations_grouped = [max_activations[answers == i] for i in range(4)]
max_activations_grouped = torch.stack(max_activations_grouped)
print(max_activations_grouped.shape)



torch.Size([4, 60, 16384])


In [12]:
max_activations_grouped_mean = max_activations_grouped.mean(axis=1)
print(max_activations_grouped_mean.shape)

torch.Size([4, 16384])


In [13]:
feature_act_diff_B_to_all = max_activations_grouped_mean[1] - max_activations_grouped_mean.mean(axis=0)
feature_act_diff_B_to_all.shape

torch.Size([16384])

In [14]:
# get feature index if non zero
non_zero_indices = feature_act_diff_B_to_all.nonzero()
non_zero_values = feature_act_diff_B_to_all[non_zero_indices].squeeze()
# px.histogram(feature_act_diff_B_to_all, title='Feature Activation Difference B to All')

In [15]:
feature_act_diff_B_to_all.to(float).topk(10)

torch.return_types.topk(
values=tensor([0.2812, 0.1387, 0.1230, 0.1211, 0.1206, 0.1094, 0.1021, 0.0918, 0.0918,
        0.0918], dtype=torch.float64),
indices=tensor([ 9992,  9399, 10046, 10159,  7320, 15062, 11821, 15691, 13718, 13309]))

In [30]:
from transformer_lens import utils

feature_id = 833
print(f'Feature ID: {feature_id}')

unlearn_metrics['token_df']["feature"] = utils.to_numpy(unlearn_metrics['learned_activations'][:, feature_id])
unlearn_metrics['token_df']['answer'] = np.array(answers).repeat(1024)
df = unlearn_metrics['token_df'][['str_tokens','prefix', 'suffix',  'context', 'batch', 'pos', 'feature', 'answer']].query("feature > 0")#.query("batch == 68")
df.sort_values("feature", ascending=False).head(100).style.background_gradient("coolwarm")

Feature ID: 833


Unnamed: 0,str_tokens,prefix,suffix,context,batch,pos,feature,answer
152644,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,149,68,5.832031,3
147524,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,144,68,5.832031,1
150596,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,147,68,5.832031,3
149572,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,146,68,5.832031,2
148548,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,145,68,5.832031,1
151620,↩,·with·the·viral·life·cycle,B,·with·the·viral·life·cycle|↩ |B,148,68,5.832031,2
160837,↩,·the·Flaviviridae·family,B,·the·Flaviviridae·family|↩ |B,157,69,5.570312,3
159813,↩,·the·Flaviviridae·family,B,·the·Flaviviridae·family|↩ |B,156,69,5.570312,2
161861,↩,·the·Flaviviridae·family,B,·the·Flaviviridae·family|↩ |B,158,69,5.570312,1
162885,↩,·the·Flaviviridae·family,B,·the·Flaviviridae·family|↩ |B,159,69,5.570312,1


In [23]:
import torch
import numpy as np

# Original list
original_list = [0, 1, 2, 3, 4, 4, 1, 2, 2, 2]

# Convert the original list to a NumPy array
original_array = np.array(original_list)

# Use NumPy's repeat function to repeat each element 1024 times
repeated_array = original_array.repeat(, 3)

# Convert the resulting NumPy array to a PyTorch tensor
repeated_tensor = torch.tensor(repeated_array)

print(repeated_tensor)  # Should be [10240]

tensor([0, 1, 2, 3, 4, 4, 1, 2, 2, 2, 0, 1, 2, 3, 4, 4, 1, 2, 2, 2, 0, 1, 2, 3,
        4, 4, 1, 2, 2, 2])


  repeated_tensor = torch.tensor(repeated_array)


tensor([2980., 2467.,  903.,  ...,    0.,    0.,    0.], device='cuda:0')