In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

from pathlib import Path

import plotly.express as px


In [2]:
# SAE checkpoint for blocks.9.hook_resid_pre of gemma-2b-it, fetched from the Hugging Face Hub.
REPO_ID = "eoinf/unlearning_saes"
FILENAME = "jolly-dream-40/sparse_autoencoder_gemma-2b-it_blocks.9.hook_resid_pre_s16384_127995904.pt"

# hf_hub_download returns the local path of the downloaded checkpoint file.
filename = hf_hub_download(filename=FILENAME, repo_id=REPO_ID)
sae = load_saved_sae(filename)

# Model object derived from the loaded SAE.
model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [3]:
# Question ids used to restrict the activation store to a subset of WMDP-bio.
filename = "../data/wmdp-bio_gemma_2b_it_correct.csv"
correct_question_ids = np.genfromtxt(filename)
dataset_args = dict(question_subset=correct_question_ids)

# The buffer size must be set on the SAE config before the store is constructed from it.
sae.cfg.n_batches_in_store_buffer = 86
act_store = MCQ_ActivationStoreAnalysis(sae.cfg, model, dataset_args=dataset_args)

unlearning_metric = 'wmdp-bio_gemma_2b_it_correct'
unlearn_cfg = UnlearningConfig(
    unlearn_activation_store=act_store,
    unlearning_metric=unlearning_metric,
)

# Build the tool, attach the model and compute the baseline metrics.
ul_tool2 = SAEUnlearningTool(unlearn_cfg)
ul_tool2.setup(model=model)
ul_tool2.get_metrics_with_text()

dataloader
buffer
dataloader


100%|██████████| 43/43 [00:23<00:00,  1.80it/s]


tokens torch.Size([172, 1024]) 1024
tokens torch.Size([172, 1024])
Concatenating learned activations
Done


100%|██████████| 43/43 [00:04<00:00, 10.66it/s]


tokens torch.Size([172, 1024]) 1024
tokens torch.Size([172, 1024])
Concatenating learned activations
Done


In [4]:

# Ids of the 172 questions the model answers correctly under every answer permutation.
filename = '../data/wmdp-bio_gemma_2b_it_correct.csv'
correct_question_ids = np.genfromtxt(filename)

# Ids of the 133 questions that are answered correctly under every permutation, but
# incorrectly once the instruction prompt and the question prompt are removed.
filename = '../data/wmdp-bio_gemma_2b_it_correct_not_correct_wo_question_prompt.csv'
correct_question_id_not_correct_wo_question_prompt = np.genfromtxt(filename).astype(int)

# Load the dataset once here so it can be passed around instead of reloaded every time.
dataset = load_dataset("cais/wmdp", "wmdp-bio")

answers = [row['answer'] for row in dataset['test']]
questions = [row['question'] for row in dataset['test']]
choices_list = [row['choices'] for row in dataset['test']]

prompts = [
    convert_wmdp_data_to_prompt(question, choices, prompt_format=None)
    for question, choices in zip(questions, choices_list)
]

# Previously saved feature lists.
# NOTE: pickle.load can execute arbitrary code; only use it on files you trust.
import pickle
with open('../yeutong_notebooks/unlearning_output/good_features_list_v1.pkl', 'rb') as f:
    good_features_list = pickle.load(f)

# Union of every feature id that appears in any of the saved lists.
features_to_test = list({item for sublist in good_features_list.values() for item in sublist})

# Hand-picked subset of features, and the same 18 ids in a second ordering
# (presumably ascending by added loss, going by the name; confirm).
filtered_good_features = [
    12663, 4342, 5749, 10355, 1523, 15858, 12273, 14315, 4451,
    1611, 10051, 16186, 7983, 6958, 1307, 11019, 6531, 12289,
]
filtered_features_sorted_by_loss = [
    7983, 16186, 12273, 14315, 4342, 10051, 15858, 6958, 12663,
    1611, 6531, 1523, 10355, 5749, 1307, 12289, 4451, 11019,
]

In [70]:
# Per-feature pass (not cumulative): each iteration intervenes on a single feature from
# filtered_features_sorted_by_loss with multiplier 35, using the
# 'clamp_feature_activation' method, and records only the control metrics.
#
# NOTE(review): the loss/metrics result lists below are commented out, yet later cells
# read metrics_intervention_results2 and loss_intervention_results2. Nothing else in the
# visible notebook defines them, so those cells depend on stale kernel state.

# loss_intervention_results2 = []
# metrics_intervention_results2 = []
control_metrics_results2 = []

# Not used by this loop: 'permutations' is None in ablate_params below.
all_permutations = list(itertools.permutations([0, 1, 2, 3]))


for feature in filtered_features_sorted_by_loss:
    # Re-create the dataset iterator before every intervention; presumably so each run
    # starts from the beginning of the dataset (confirm against the activation store).
    ul_tool2.base_activation_store.iterable_dataset = iter(ul_tool2.base_activation_store.dataset)
    ablate_params = {
        'features_to_ablate': [feature],
        'multiplier': 35,
        'intervention_method': 'clamp_feature_activation',
        'permutations': None,
    }
    
    # metrics = ul_tool2.calculate_metrics(**ablate_params)
    # metrics_intervention_results2.append(metrics)
    # loss_added = ul_tool2.compute_loss_added(n_batch=10, **ablate_params)
    # loss_intervention_results2.append(loss_added)
    
    # One control-metrics entry per feature, in the order of filtered_features_sorted_by_loss.
    control_metrics = ul_tool2.calculate_control_metrics(random_select_one=False, **ablate_params)
    control_metrics_results2.append(control_metrics)


100%|██████████| 5/5 [00:03<00:00,  1.36it/s]
100%|██████████| 5/5 [00:03<00:00,  1.36it/s]
100%|██████████| 5/5 [00:03<00:00,  1.36it/s]
100%|██████████| 5/5 [00:03<00:00,  1.36it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.35it/s]
100%|██████████| 5/5 [00:03<00:00,  1.35it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 5/5 [00:03<00:00,  1.29it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]


In [71]:
# 'mean_correct' of every entry recorded in control_metrics_results2, in insertion order.
[entry['mean_correct'] for entry in control_metrics_results2]


[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.9000000357627869,
 0.8666667342185974,
 1.0,
 0.46666669845581055,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0]

In [9]:
# Pair each added-loss value with the post-intervention 'mean_correct'.
# NOTE(review): the only visible code that fills metrics_intervention_results2 and
# loss_intervention_results2 is commented out, so this cell fails on a fresh kernel.
unlearned_frac2 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results2
]
list(zip(loss_intervention_results2, unlearned_frac2))

[(-0.0005425691604614257, 0.9883720874786377),
 (-0.0012538909912109375, 0.9825581312179565),
 (-0.0012538909912109375, 0.8604651093482971),
 (-0.0016951799392700196, 0.8488371968269348),
 (-0.0015992164611816407, 0.8488371968269348),
 (-0.0015992164611816407, 0.8372092843055725),
 (-0.0015992164611816407, 0.6686046719551086),
 (-0.0015607833862304687, 0.645348846912384),
 (-0.001813220977783203, 0.5755813717842102),
 (0.0033960819244384767, 0.45348837971687317),
 (0.006481027603149414, 0.44186046719551086),
 (0.007501912117004394, 0.3372093141078949),
 (0.018922185897827147, 0.3488371968269348),
 (0.025135350227355958, 0.3488371968269348),
 (0.027659845352172852, 0.34302324056625366),
 (0.027605223655700683, 0.29651162028312683),
 (0.05441043376922607, 0.29651162028312683),
 (0.05386528968811035, 0.22674418985843658)]

In [10]:
# Maximum of the stored learned activations along dim 0 (one value per SAE feature).
# `.values` is the first field of the (values, indices) pair returned by torch.max.
max_activations = torch.max(ul_tool2.unlearn_metrics_with_text['learned_activations'], dim=0).values

In [25]:
# Sanity check: expect one entry per SAE feature.
max_activations.size()

torch.Size([16384])

In [46]:
# Max activation of each candidate feature, in the order of filtered_features_sorted_by_loss.
filtered_features_max_activations = max_activations[filtered_features_sorted_by_loss].to(device="cuda")

# Five times each feature's max activation (the factor 5 is ad hoc; presumably for
# comparison with the clamp multipliers used in the sweeps; confirm).
filtered_features_max_activations * 5

tensor([48.4062, 20.6406, 49.5000, 38.4062,  9.5625, 35.0312, 53.6875, 18.6875,
        24.2500, 24.5000, 18.9531, 20.1250, 28.8438, 36.7500, 30.0000, 33.4062,
        42.0625, 21.4062], device='cuda:0', dtype=torch.float16)

In [72]:
# Same ordering as filtered_features_sorted_by_loss with positions 8, 9 and 11 removed
# (the three positions whose per-feature control 'mean_correct' shown above is below 1.0).
filtered_features_sorted_by_loss2 = np.delete(filtered_features_sorted_by_loss, [8, 9, 11])

In [79]:
# Cumulative sweep: iteration i intervenes on the first i + 1 entries of
# filtered_features_sorted_by_loss2 (prefix lengths 8 through 15) and stores the
# unlearning metrics, the added loss and the control metrics for that prefix.

loss_intervention_results3, metrics_intervention_results3, control_metrics_results3 = [], [], []

# Not used by this loop: 'permutations' is None below.
all_permutations = list(itertools.permutations(range(4)))


for i in range(7, 15):
    # Fresh iterator over the base dataset for every prefix.
    ul_tool2.base_activation_store.iterable_dataset = iter(ul_tool2.base_activation_store.dataset)
    ablate_params = dict(
        features_to_ablate=filtered_features_sorted_by_loss2[:i + 1],
        multiplier=20,
        intervention_method='clamp_feature_activation',
        permutations=None,
    )

    metrics = ul_tool2.calculate_metrics(**ablate_params)
    metrics_intervention_results3.append(metrics)

    loss_added = ul_tool2.compute_loss_added(n_batch=10, **ablate_params)
    loss_intervention_results3.append(loss_added)

    control_metrics = ul_tool2.calculate_control_metrics(random_select_one=False, **ablate_params)
    control_metrics_results3.append(control_metrics)


  0%|          | 0/29 [00:00<?, ?it/s]

100%|██████████| 29/29 [00:07<00:00,  3.63it/s]
100%|██████████| 10/10 [00:14<00:00,  1.41s/it]
100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 29/29 [00:07<00:00,  3.69it/s]
100%|██████████| 10/10 [00:13<00:00,  1.37s/it]
100%|██████████| 5/5 [00:03<00:00,  1.33it/s]
100%|██████████| 29/29 [00:07<00:00,  4.07it/s]
100%|██████████| 10/10 [00:14<00:00,  1.43s/it]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 29/29 [00:07<00:00,  4.02it/s]
100%|██████████| 10/10 [00:13<00:00,  1.38s/it]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 29/29 [00:07<00:00,  4.05it/s]
100%|██████████| 10/10 [00:13<00:00,  1.38s/it]
100%|██████████| 5/5 [00:03<00:00,  1.30it/s]
100%|██████████| 29/29 [00:07<00:00,  3.97it/s]
100%|██████████| 10/10 [00:13<00:00,  1.39s/it]
100%|██████████| 5/5 [00:03<00:00,  1.30it/s]
100%|██████████| 29/29 [00:07<00:00,  3.99it/s]
100%|██████████| 10/10 [00:14<00:00,  1.40s/it]
100%|██████████| 5/5 [00:03<00:00,  1.29it/s]
100%|█

In [21]:
# (added loss, post-intervention 'mean_correct', control 'mean_correct') per recorded run.
# NOTE(review): the only visible code that fills metrics_intervention_results2 and
# loss_intervention_results2 is commented out, so this cell fails on a fresh kernel.
unlearned_frac2 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results2
]
control_frac2 = [result['mean_correct'] for result in control_metrics_results2]
list(zip(loss_intervention_results2, unlearned_frac2, control_frac2))

[(-0.0019334077835083008, 0.9825581312179565, 1.0),
 (-0.002272820472717285, 0.9709302186965942, 1.0),
 (-0.002272820472717285, 0.7732558250427246, 1.0),
 (-0.002711629867553711, 0.7616279125213623, 1.0),
 (-0.00342862606048584, 0.7616279125213623, 1.0),
 (-0.00342862606048584, 0.7383720874786377, 1.0),
 (-0.00342862606048584, 0.5639534592628479, 1.0),
 (-0.0031900644302368165, 0.5058139562606812, 1.0),
 (-0.00668337345123291, 0.35465115308761597, 0.9000000357627869),
 (-0.003450155258178711, 0.2616279125213623, 0.8666667342185974),
 (0.004356789588928223, 0.25, 0.8666667342185974),
 (0.0069856405258178714, 0.19186046719551086, 0.40000003576278687),
 (0.0188795804977417, 0.19186046719551086, 0.40000003576278687),
 (0.025677967071533202, 0.19186046719551086, 0.40000003576278687),
 (0.03329613208770752, 0.19186046719551086, 0.40000003576278687),
 (0.033304452896118164, 0.19186046719551086, 0.40000003576278687),
 (0.04512257575988769, 0.19186046719551086, 0.40000003576278687),
 (0.0441360

In [62]:
# (added loss, post-intervention 'mean_correct') for each run in the *3 result lists.
unlearned_frac3 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results3
]
list(zip(loss_intervention_results3, unlearned_frac3))

[(-0.00364229679107666, 0.6220930218696594),
 (-0.0038264989852905273, 0.5174418687820435),
 (-0.0031900644302368165, 0.5058139562606812),
 (-0.002042865753173828, 0.4883720874786377),
 (-0.0006659984588623047, 0.4883720874786377),
 (0.0009542226791381836, 0.4883720874786377),
 (0.0026932477951049803, 0.4883720874786377)]

In [78]:
# (added loss, post-intervention 'mean_correct', control 'mean_correct') per run.
unlearned_frac3 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results3
]
control_frac3 = [result['mean_correct'] for result in control_metrics_results3]
list(zip(loss_intervention_results3, unlearned_frac3, control_frac3))

[(-0.0038264989852905273, 0.5174418687820435, 1.0),
 (0.0005422830581665039, 0.47093021869659424, 1.0),
 (0.001424694061279297, 0.4651162624359131, 1.0),
 (0.006346702575683594, 0.45348837971687317, 1.0),
 (0.01067976951599121, 0.447674423456192, 1.0),
 (0.01063222885131836, 0.39534884691238403, 1.0),
 (0.020189642906188965, 0.39534884691238403, 1.0),
 (0.018984508514404298, 0.34302324056625366, 1.0)]

In [80]:
# Same summary as the previous display cell, re-run after the sweep was repeated:
# (added loss, post-intervention 'mean_correct', control 'mean_correct') per run.
unlearned_frac3 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results3
]
control_frac3 = [result['mean_correct'] for result in control_metrics_results3]
list(zip(loss_intervention_results3, unlearned_frac3, control_frac3))

[(-0.00364229679107666, 0.6220930218696594, 1.0),
 (-0.0025598764419555663, 0.5465116500854492, 1.0),
 (-0.005872964859008789, 0.5058139562606812, 1.0),
 (-0.0036381006240844725, 0.5, 1.0),
 (-0.002408885955810547, 0.4883720874786377, 1.0),
 (-0.0025175333023071287, 0.41860464215278625, 1.0),
 (0.006616711616516113, 0.41860464215278625, 1.0),
 (0.0055562019348144535, 0.3604651093482971, 1.0)]

In [85]:
# Single run per multiplier (currently only 20): intervene on every feature in
# filtered_features_sorted_by_loss2 at once and evaluate with all 24 orderings of the
# four answer choices. This resets the *3 result lists.

loss_intervention_results3, metrics_intervention_results3, control_metrics_results3 = [], [], []

all_permutations = list(itertools.permutations(range(4)))


for multiplier in [20]:
    # Fresh iterator over the base dataset for every run.
    ul_tool2.base_activation_store.iterable_dataset = iter(ul_tool2.base_activation_store.dataset)
    ablate_params = dict(
        features_to_ablate=filtered_features_sorted_by_loss2,
        multiplier=multiplier,
        intervention_method='clamp_feature_activation',
        permutations=all_permutations,
    )

    metrics = ul_tool2.calculate_metrics(**ablate_params)
    metrics_intervention_results3.append(metrics)

    loss_added = ul_tool2.compute_loss_added(n_batch=10, **ablate_params)
    loss_intervention_results3.append(loss_added)

    control_metrics = ul_tool2.calculate_control_metrics(random_select_one=False, **ablate_params)
    control_metrics_results3.append(control_metrics)


  0%|          | 0/688 [00:00<?, ?it/s]

 93%|█████████▎| 637/688 [02:13<00:12,  3.93it/s]

In [None]:
# NOTE(review): other cells index the entries of metrics_intervention_results3 like dicts
# (x['modified_metrics']['mean_correct']); a dict has no .reshape, so this line probably
# needs to select a per-question array first. Confirm against what calculate_metrics
# returns. The 24 presumably matches the 4! answer-order permutations.
metrics_intervention_results3[0].reshape(-1, 24)

In [84]:
# (added loss, post-intervention 'mean_correct', control 'mean_correct') per run.
unlearned_frac3 = [
    result['modified_metrics']['mean_correct']
    for result in metrics_intervention_results3
]
control_frac3 = [result['mean_correct'] for result in control_metrics_results3]
list(zip(loss_intervention_results3, unlearned_frac3, control_frac3))

[(0.0022559404373168946, 0.75, 1.0),
 (0.00037636756896972654, 0.45348837971687317, 1.0),
 (0.0055562019348144535, 0.3604651093482971, 1.0),
 (0.018984508514404298, 0.34302324056625366, 1.0),
 (0.03567800521850586, 0.33139535784721375, 1.0),
 (0.05340430736541748, 0.3139534890651703, 1.0),
 (0.07189581394195557, 0.3255814015865326, 0.9666666984558105)]

In [None]:
# loss_intervention_results = []
# metrics_intervention_results = []

# filtered_good_features = [12663, 4342, 5749, 10355, 1523, 15858, 12273, 14315, 4451, 1611, 10051, 16186, 7983, 6958, 1307, 11019, 6531, 12289]
# # filtered_good_features = [12663, 4342, 5749, 10355, 1523, 15858, 12273, 14315, 4451, 1611, 10051, 16186, 7983, 6958, 1307, 11019, 6531, 12289]

# all_permutations = list(itertools.permutations([0, 1, 2, 3]))


# for feature in filtered_good_features:
#     ul_tool2.base_activation_store.iterable_dataset = iter(ul_tool2.base_activation_store.dataset)
#     ablate_params = {
#         'features_to_ablate': [feature],
#         'multiplier': 20,
#         'intervention_method': 'scale_feature_activation',
#         'permutations': None,
#     }
    
#     metrics = ul_tool2.calculate_metrics(**ablate_params)
#     metrics_intervention_results.append(metrics)
#     loss_added = ul_tool2.compute_loss_added(n_batch=20, **ablate_params)
#     loss_intervention_results.append(loss_added)
