In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

import einops

from pathlib import Path

import plotly.express as px
from unlearning.var import REPO_ID, SAE_MAPPING
import pickle

from unlearning.metrics import all_permutations

from unlearning.feature_attribution import find_topk_features_given_prompt, test_topk_features
from tqdm import tqdm


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Load main SAE for gemma-2b-it
filename = hf_hub_download(repo_id=REPO_ID, filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'])
sae = load_saved_sae(filename)
model = model_store_from_sae(sae)

(…)cks.9.hook_resid_pre_s16384_127995904.pt:   0%|          | 0.00/269M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [4]:
# pass in the dataset as an argument so no need to load in everytime
dataset = load_dataset("cais/wmdp", "wmdp-bio")

answers = [x['answer'] for x in dataset['test']]
questions = [x['question'] for x in dataset['test']]
choices_list = [x['choices'] for x in dataset['test']]

prompts = [convert_wmdp_data_to_prompt(question, choices, prompt_format=None) for question, choices in zip(questions, choices_list)]

# features_ids_prompt_70 = [ 5681, 12639,  9597,  6272, 14509]

unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names =  ['high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']
all_dataset_names = ['loss_added', 'wmdp-bio', 'high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']

question_ids_correct = np.genfromtxt("../data/wmdp-bio_gemma_2b_it_correct_no_tricks.csv")

non_zero_features_list = np.genfromtxt("non_zero_features_list.csv")

Downloading readme:   0%|          | 0.00/4.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/258k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

In [57]:
question_ids = [357, 1147]

feature_per_prompt = {}

known_good_features = []

model.reset_hooks()

question_ids_correct[:1]

for j, question_id in enumerate([1147]):

    question_id = int(question_id)

    print("Question #", question_id, j+1, "/159")
    
    prompt = prompts[question_id]
    choices = choices_list[question_id]
    answer = answers[question_id]
    question = questions[question_id]

    topk_features_unique = find_topk_features_given_prompt(model,
                                                           prompt,
                                                           question,
                                                           choices,
                                                           answer,
                                                           sae,
                                                           hook_point='blocks.9.hook_resid_pre')[0]
    
    intervention_results, feature_ids_to_probs, good_features = test_topk_features(model,
                                                                                   sae,
                                                                                   question_id,
                                                                                   topk_features_unique[:50],
                                                                                   known_good_features=known_good_features,
                                                                                   multiplier=30,
                                                                                   thres_correct_ans_prob=0.8,
                                                                                   permutations=all_permutations)

    feature_per_prompt[question_id] = good_features
    
    known_good_features = list(set([item for sublist in feature_per_prompt.values() for item in sublist]))
    
    

Question # 1147 1 /159


100%|██████████| 50/50 [01:13<00:00,  1.46s/it]


In [56]:
feature_per_prompt

{22: [12359,
  4802,
  8794,
  10632,
  6308,
  4997,
  12782,
  12663,
  2993,
  12435,
  15755,
  5904]}

In [58]:
feature_per_prompt

{1147: [4802, 8139, 12289, 12273, 13715, 4550, 9280, 5904, 10189]}

In [None]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 20,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': features_for_prompt_243,
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': None,
                   }
                 }

dataset_names = all_dataset_names[2:]

n_batch_loss_added = 10

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      n_batch_loss_added=n_batch_loss_added,)
                                      # activation_store=activation_store)


In [5]:
## Find all non-zero features
feature_per_prompt = {}

known_good_features = []

model.reset_hooks()

non_zero_features_list = []

for j, question_id in tqdm(enumerate(question_ids_correct)):

    question_id = int(question_id)

    if j % 20 == 0:
        print("Question #", question_id, j+1, "/159")
    
    prompt = prompts[question_id]
    choices = choices_list[question_id]
    answer = answers[question_id]
    question = questions[question_id]

    all_feature_attributions = find_topk_features_given_prompt(model,
                                                           prompt,
                                                           question,
                                                           choices,
                                                           answer,
                                                           sae,
                                                           hook_point='blocks.9.hook_resid_pre')[1]
    non_zero_features = list(all_feature_attributions.min(axis=0)[0].nonzero().T[0].cpu().numpy())
    non_zero_features_list.append(non_zero_features)

0it [00:00, ?it/s]

Question # 22 1 /159


20it [00:13,  1.83it/s]

Question # 351 21 /159


40it [00:26,  1.14it/s]

Question # 447 41 /159


60it [00:44,  1.33it/s]

Question # 630 61 /159


80it [01:00,  1.22it/s]

Question # 735 81 /159


100it [01:19,  1.04it/s]

Question # 826 101 /159


120it [01:38,  1.10it/s]

Question # 933 121 /159


140it [01:58,  1.12s/it]

Question # 1130 141 /159


159it [02:14,  1.18it/s]


In [6]:
len(non_zero_features_list)

159

In [7]:
np.mean([len(x) for x in non_zero_features_list])

279.0377358490566

In [8]:
# non_zero_features_list[0]

In [9]:
all_non_zero_features = [item for sublist in non_zero_features_list for item in sublist]

In [10]:
%%time

features_to_prompts = {}

for f in all_non_zero_features:
    prompts = [prompt for prompt, features in zip(question_ids_correct, non_zero_features_list) if f in features]
    features_to_prompts[f] = prompts

CPU times: user 18.2 s, sys: 0 ns, total: 18.2 s
Wall time: 18.2 s


In [13]:
features_to_prompts[57]

[22.0, 158.0, 559.0, 744.0, 958.0, 1116.0, 1165.0]

In [88]:
non_zero_features_list2 = [item for sublist in non_zero_features_list for item in sublist]

In [89]:
np.unique(non_zero_features_list)[:20]

array([ 0,  2,  3,  6, 12, 13, 15, 17, 18, 19, 20, 25, 26, 28, 29, 32, 33,
       37, 41, 42])

In [90]:
len(np.unique(non_zero_features_list))

6378

In [103]:
np.savetxt("non_zero_features_list.csv", np.unique(non_zero_features_list))

In [11]:
non_zero_features_list

array([0.0000e+00, 2.0000e+00, 3.0000e+00, ..., 1.6381e+04, 1.6382e+04,
       1.6383e+04])

In [12]:
len(non_zero_features_list)

6378

## Unlearn permutations

In [None]:
features_to_prompts

In [4]:
%%time

all_metrics = []

for f in tqdm(list(features_to_prompts.keys())[:100]):
             
    ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                      'features_to_ablate': [int(f)]
                     }
    
                 
    metric_params = {'wmdp-bio': 
                     {
                           'question_subset': [int(x) for x in features_to_prompts[f]],
                           'permutations': None,
                           'verbose': False,
                       }
                     }

    print([int(x) for x in features_to_prompts[f]])
    
    
    metrics = modify_and_calculate_metrics(model,
                                     sae,
                                     dataset_names=['wmdp-bio'],
                                     metric_params=metric_params,
                                     activation_store=None,
                                     **ablate_params)

    all_metrics.append(metrics)



NameError: name 'features_to_prompts' is not defined

In [8]:
%%time

main_ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                     }

sweep = {
         'features_to_ablate': [int(x) for x in non_zero_features_list[:100]],
        }

metric_params = {'wmdp-bio': 
                 {
                       'question_subset': question_ids_correct,
                       'permutations': all_permutations,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[1:2]

n_batch_loss_added = 50

metrics = calculate_metrics_list(model,
                                  sae,
                                  main_ablate_params,
                                  sweep,
                                  dataset_names=dataset_names,
                                  metric_params=metric_params,
                                  n_batch_loss_added=n_batch_loss_added,)
                                  # activation_store=activation_store)


  1%|          | 1/100 [02:54<4:47:12, 174.07s/it]


KeyboardInterrupt: 

In [19]:
metrics[0]['wmdp-bio']['mean_correct']

0.9999999403953552

In [23]:
mean_correct = [int(round(x['wmdp-bio']['mean_correct'] * 24 * 159, 0)) for x in metrics]

In [24]:
np.array(mean_correct)

array([3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3768, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3792, 3816, 3792, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3792, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3792, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 3744, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3792, 3816, 3816, 3816, 3816, 3816, 3816,
       3816, 3816, 3816, 3816, 3816, 3816, 3816, 38

In [26]:
(np.array(mean_correct) < 3816).sum()

10

In [27]:
len(non_zero_features_list)

6378

In [28]:
6378 * 0.05

318.90000000000003

## Side Effects

In [29]:

# features_ids_prompt_70 = [ 5681, 12639,  9597,  6272, 14509]
unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names =  ['high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']
all_dataset_names = ['loss_added', 'wmdp-bio', 'high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']

from unlearning.metrics import calculate_metrics_side_effects

In [30]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 20,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': [int(x) for x in non_zero_features_list[:20]],
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': None,
                   }
                 }

dataset_names = all_dataset_names[2:]

n_batch_loss_added = 10

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      n_batch_loss_added=n_batch_loss_added,)
                                      # activation_store=activation_store)


  0%|          | 0/20 [00:00<?, ?it/s]

Downloading readme:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/138k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.8k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 20/20 [03:53<00:00, 11.67s/it]


In [31]:
feature_ids_zero_side_effect = [x['ablate_params']['features_to_ablate'] for x in metrics_list]
np.array(feature_ids_zero_side_effect)

array([ 0,  2,  3,  6, 12, 13, 15, 17, 18, 19, 20, 25, 26, 28, 29, 32, 33,
       37, 41, 42])

In [32]:
len(feature_ids_zero_side_effect)

20

## All feature activations

In [9]:
filename = "../data/wmdp-bio_gemma_2b_it_correct.csv"
correct_question_ids = np.genfromtxt(filename)

dataset_args = {
    'question_subset': correct_question_ids,
}

sae.cfg.n_batches_in_store_buffer = 86

act_store = MCQ_ActivationStoreAnalysis(sae.cfg, model, dataset_args=dataset_args)
unlearning_metric = 'wmdp-bio_gemma_2b_it_correct'

unlearn_cfg = UnlearningConfig(unlearn_activation_store=act_store, unlearning_metric=unlearning_metric)
ul_tool2 = SAEUnlearningTool(unlearn_cfg)
ul_tool2.setup(model=model)
ul_tool2.get_metrics_with_text()

Exception ignored in: <function Dataset.__del__ at 0x7f0c2cacb920>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/datasets/arrow_dataset.py", line 1421, in __del__
    def __del__(self):

KeyboardInterrupt: 

KeyboardInterrupt



In [None]:
metrics_with_text = ul_tool2.unlearn_metrics_with_text