In [1]:
from IPython import get_ipython # type: ignore
# Grab the running IPython shell and switch on autoreload (mode 2) through
# the programmatic magic interface instead of `%`-syntax.
ipython = get_ipython()
assert ipython is not None

for magic_name, magic_arg in (("load_ext", "autoreload"), ("autoreload", "2")):
    ipython.run_line_magic(magic_name, magic_arg)

from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE

import os
import torch
from tqdm import tqdm
import plotly.express as px
import gc
import numpy as np
from huggingface_hub import login




In [3]:
# Autoreload setup via `%`-magics. The first cell already issues the same two
# magics through `ipython.run_line_magic`, so when both cells run IPython
# reports that the extension is already loaded.
%load_ext autoreload
%autoreload 2

import sys

# Put the parent directory on the import path (presumably where the local
# `sae` and `unlearning` packages imported below live). Guard the append so
# that re-running this cell does not keep stacking duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics, calculate_MCQ_metrics, calculate_metrics_side_effects

from unlearning.var import GEMMA_INST_FORMAT
from itertools import permutations
import torch.nn.functional as F
import gc

import json

from unlearning.feature_attribution import find_topk_features_given_prompt, test_topk_features


# Every ordering of the four answer-choice indices 0..3 (4! = 24 tuples).
all_permutations = list(permutations(range(4)))



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# You can skip this if you edited `pretrained_saes.yaml` manually
# Sets the SAE weights-root environment variable and checks that the
# directory it names actually exists before anything tries to load from it.
os.environ.update({"GEMMA_2_SAE_WEIGHTS_ROOT": "/workspace/weights/"})
assert os.path.exists(os.environ["GEMMA_2_SAE_WEIGHTS_ROOT"])

device = "cuda"

# `from_pretrained` is unpacked into three values here: the SAE itself, its
# config dict, and a sparsity object.
sae, cfg_dict, sparsity = SAE.from_pretrained(release="gemmascope-9b-pt-res", sae_id="layer_20/width_16k/average_l0_20", device=device)


# Load the WMDP-bio benchmark and split its test rows into parallel lists
# (same index -> same question across all four lists below).
dataset = load_dataset("cais/wmdp", "wmdp-bio")

answers = [example["answer"] for example in dataset["test"]]
questions = [example["question"] for example in dataset["test"]]
choices_list = [example["choices"] for example in dataset["test"]]

# One formatted prompt per test question.
prompts = [
    convert_wmdp_data_to_prompt(q_text, q_choices, prompt_format="GEMMA_INST_FORMAT")
    for q_text, q_choices in zip(questions, choices_list)
]



In [5]:
# Load Gemma-2 9B (instruction-tuned) in bfloat16 via the "no processing"
# loader. The device is passed explicitly so the model is placed on the same
# device the SAE was loaded onto above, rather than relying on the library's
# default device selection.
model = HookedTransformer.from_pretrained_no_processing(
    "google/gemma-2-9b-it",
    dtype=torch.bfloat16,
    device=device,
)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [6]:
# Expose the SAE's hook name under `hook_point`, the attribute name the
# feature-search loop reads (`hook_point=sae.cfg.hook_point`).
sae.cfg.hook_point = sae.cfg.hook_name

# Question ids loaded from the CSV of correctly answered questions.
# `genfromtxt` yields floats and collapses a single-entry file to a 0-d
# array, so normalise to a 1-d integer array and fail loudly on any entry
# that could not be parsed (it would show up as NaN).
correct_question_ids = np.atleast_1d(np.genfromtxt("../data/wmdp-bio_gemma2_9b_it_correct.csv"))
assert not np.isnan(correct_question_ids).any(), "non-numeric entry in correct-question id file"
correct_question_ids = correct_question_ids.astype(int)

In [20]:
# For each of the first 10 listed question ids: rank SAE features for that
# prompt, test the top 20 unique ones with interventions, and record the
# features that `test_topk_features` reports as good.
feature_per_prompt = {}  # question id -> list of good feature ids

known_good_features = []  # de-duplicated union of all lists collected so far

for j, question_id in enumerate(correct_question_ids[:10]):

    # The ids come from a numeric file load; index with a plain Python int.
    question_id = int(question_id)

    print("Question #", question_id, j+1)

    prompt, question = prompts[question_id], questions[question_id]
    choices, answer = choices_list[question_id], answers[question_id]

    # Only `topk_features_unique` is consumed below; the other returned
    # values are kept under their names so they stay inspectable.
    (
        topk_features_unique,
        feature_attributions,
        topk_features,
        all_feature_activations,
        logit_diff_grad,
        topk_feature_attributions,
    ) = find_topk_features_given_prompt(
        model,
        prompt,
        question,
        choices,
        answer,
        sae,
        hook_point=sae.cfg.hook_point,
    )
    # Release Python garbage and cached CUDA memory between the two steps.
    gc.collect()
    torch.cuda.empty_cache()

    intervention_results, feature_ids_to_probs, good_features = test_topk_features(
        model,
        sae,
        question_id,
        topk_features_unique[:20],
        known_good_features=known_good_features,
        multiplier=30,
        permutations=all_permutations,
    )

    feature_per_prompt[question_id] = good_features

    gc.collect()
    torch.cuda.empty_cache()

    # Rebuild the de-duplicated pool handed to the next iteration.
    known_good_features = list(set(
        feature_id
        for per_question in feature_per_prompt.values()
        for feature_id in per_question
    ))

    

Question # 0 1


dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 20/20 [00:35<00:00,  1.75s/it]


Question # 1 2
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 17/17 [00:30<00:00,  1.76s/it]


Question # 3 3
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 15/15 [00:25<00:00,  1.70s/it]


Question # 7 4
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 14/14 [00:22<00:00,  1.63s/it]


Question # 15 5
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 15/15 [00:23<00:00,  1.57s/it]


Question # 16 6
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 14/14 [00:26<00:00,  1.92s/it]


Question # 17 7
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 14/14 [00:31<00:00,  2.23s/it]


Question # 19 8
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 13/13 [00:35<00:00,  2.75s/it]


Question # 39 9
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 13/13 [00:24<00:00,  1.90s/it]


Question # 41 10
dict_keys(['blocks.20.hook_resid_post', 'blocks.20.hook_resid_post_grad'])


100%|██████████| 11/11 [00:18<00:00,  1.70s/it]


In [21]:
# Number of distinct feature ids collected across all processed questions.
len(known_good_features)

37

In [22]:
# Show the mapping from question id to the good feature ids found for it.
feature_per_prompt

{0: [12165, 15471, 13624, 969, 10034],
 1: [1642],
 3: [11596, 11933, 11339, 358, 3324, 3633, 7012, 3813],
 7: [5847, 1833, 5561, 13982, 8783],
 15: [8290, 35],
 16: [],
 17: [15245, 13051, 10839, 5295, 12501, 1022, 6051, 3655, 12176],
 19: [6337, 6266, 1867],
 39: [1426, 2962],
 41: [2348, 14082]}

{0: [832, 12165, 15471, 969, 10034, 12758]}