In [23]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics
from unlearning.feature_attribution import calculate_cache, sort_by_loss_added

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

import einops

from pathlib import Path

import plotly.express as px
from unlearning.var import REPO_ID, SAE_MAPPING
import pickle

from unlearning.metrics import all_permutations

from unlearning.metrics import calculate_metrics_side_effects
from unlearning.feature_attribution import find_topk_features_given_prompt, test_topk_features
from unlearning.feature_attribution import get_topk_features_by_attribution
from unlearning.feature_attribution import get_features_without_side_effects


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Load main SAE for gemma-2b-it
filename = hf_hub_download(repo_id=REPO_ID, filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'])
sae = load_saved_sae(filename)
model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [7]:
dataset = load_dataset("cais/wmdp", "wmdp-bio", split='test')

answers = [x['answer'] for x in dataset]
questions = [x['question'] for x in dataset]
choices_list = [x['choices'] for x in dataset]

prompts = [convert_wmdp_data_to_prompt(question, choices, prompt_format=None)
           for question, choices in zip(questions, choices_list)]

# Hyper parameters for automation process
question_ids_correct = np.genfromtxt("../data/question_ids/gemma-2b-it/all/wmdp-bio_correct.csv", dtype=int)
questions_ids_correct_train = np.genfromtxt("../data/question_ids/gemma-2b-it/train/wmdp-bio_correct.csv", dtype=int)
topk_per_prompt = 20

unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names =  ['high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']
all_dataset_names = ['loss_added', 'wmdp-bio', 'high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']


## Use functions from feature_attribution

In [17]:

question_ids = question_ids_correct[:10]
feature_per_prompt, known_good_features = get_topk_features_by_attribution(model,
                                                                           sae,
                                                                           question_ids,
                                                                           prompts,
                                                                           choices_list,
                                                                           answers,
                                                                           questions)


Question ID: 22, 1/10


100%|██████████| 20/20 [00:24<00:00,  1.22s/it]


Question ID: 70, 2/10


100%|██████████| 19/19 [00:16<00:00,  1.12it/s]


Question ID: 82, 3/10


100%|██████████| 20/20 [00:23<00:00,  1.17s/it]


Question ID: 89, 4/10


100%|██████████| 19/19 [00:17<00:00,  1.10it/s]


Question ID: 91, 5/10


100%|██████████| 17/17 [00:15<00:00,  1.08it/s]


Question ID: 155, 6/10


100%|██████████| 20/20 [00:17<00:00,  1.12it/s]


Question ID: 158, 7/10


100%|██████████| 20/20 [00:17<00:00,  1.12it/s]


Question ID: 161, 8/10


100%|██████████| 18/18 [00:16<00:00,  1.10it/s]


Question ID: 172, 9/10


100%|██████████| 17/17 [00:20<00:00,  1.22s/it]


Question ID: 180, 10/10


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


In [18]:

metrics_list, feature_ids_zero_side_effect = get_features_without_side_effects(model,
                                      sae,
                                      known_good_features, 
                                      thresh=10,
                                      target_metrics='correct',
                                      split='all',
                                      multiplier=30,
                                      intervention_method='clamp_feature_activation',
                                     )

  0%|          | 0/11 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 11/11 [01:09<00:00,  6.30s/it]


In [19]:
feature_ids_zero_side_effect

[12273]

In [20]:
activation_store = ActivationStoreAnalysis(sae.cfg, model)

Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

buffer
dataloader


In [25]:
df_zero_side_effect, feature_ids_zero_side_effect_sorted = sort_by_loss_added(model,
                       sae,
                       feature_ids_zero_side_effect,
                       question_ids,
                       activation_store,
                       multiplier=20,
                       intervention_method='clamp_feature_activation',
                       n_batch_loss_added=10,
                       split='all',
                       verbose=False,
                      )

100%|██████████| 1/1 [00:14<00:00, 14.13s/it]


In [26]:
feature_ids_zero_side_effect_sorted

array([12273])

## First get the TopK features by attribution per prompt and find the features that modify the probability

In [9]:
feature_per_prompt = {}

known_good_features = []

question_ids = questions_ids_correct_train

for j, question_id in enumerate(question_ids):

    question_id = int(question_id)
    print(f"Question ID: {question_id}, {j + 1}/{len(question_ids)}")
    
    prompt = prompts[question_id]
    choices = choices_list[question_id]
    answer = answers[question_id]
    question = questions[question_id]

    topk_features_unique, feature_attributions, topk_features, all_feature_activations, logit_diff_grad, topk_feature_attributions = find_topk_features_given_prompt(model,
                                                           prompt,
                                                           question,
                                                           choices,
                                                           answer,
                                                           sae,
                                                           hook_point=sae.cfg.hook_point)

    intervention_results, feature_ids_to_probs, good_features, partially_unlearned = test_topk_features(model,
                                                                                   sae,
                                                                                   question_id,
                                                                                   topk_features_unique[:topk_per_prompt],
                                                                                   known_good_features=known_good_features,
                                                                                   multiplier=30)
    

    feature_per_prompt[question_id] = good_features
    
    known_good_features = list(set([item for sublist in feature_per_prompt.values() for item in sublist]))

    

Question ID: 1147, 1/86


100%|██████████| 20/20 [00:18<00:00,  1.05it/s]


Question ID: 357, 2/86


100%|██████████| 18/18 [00:21<00:00,  1.19s/it]


Question ID: 800, 3/86


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Question ID: 825, 4/86


100%|██████████| 19/19 [00:18<00:00,  1.03it/s]


Question ID: 1015, 5/86


100%|██████████| 19/19 [00:18<00:00,  1.04it/s]


Question ID: 837, 6/86


100%|██████████| 15/15 [00:14<00:00,  1.06it/s]


Question ID: 542, 7/86


100%|██████████| 18/18 [00:17<00:00,  1.02it/s]


Question ID: 588, 8/86


100%|██████████| 19/19 [00:17<00:00,  1.10it/s]


Question ID: 541, 9/86


100%|██████████| 16/16 [00:16<00:00,  1.04s/it]


Question ID: 82, 10/86


100%|██████████| 18/18 [00:17<00:00,  1.05it/s]


Question ID: 555, 11/86


100%|██████████| 19/19 [00:18<00:00,  1.03it/s]


Question ID: 320, 12/86


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


Question ID: 778, 13/86


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Question ID: 382, 14/86


100%|██████████| 18/18 [00:16<00:00,  1.08it/s]


Question ID: 217, 15/86


100%|██████████| 19/19 [00:18<00:00,  1.02it/s]


Question ID: 649, 16/86


100%|██████████| 18/18 [00:16<00:00,  1.06it/s]


Question ID: 737, 17/86


100%|██████████| 18/18 [00:16<00:00,  1.08it/s]


Question ID: 634, 18/86


100%|██████████| 18/18 [00:16<00:00,  1.07it/s]


Question ID: 324, 19/86


100%|██████████| 18/18 [00:23<00:00,  1.28s/it]


Question ID: 730, 20/86


100%|██████████| 18/18 [00:17<00:00,  1.05it/s]


Question ID: 353, 21/86


100%|██████████| 19/19 [00:23<00:00,  1.25s/it]


Question ID: 243, 22/86


100%|██████████| 19/19 [00:17<00:00,  1.06it/s]


Question ID: 180, 23/86


100%|██████████| 17/17 [00:15<00:00,  1.07it/s]


Question ID: 770, 24/86


100%|██████████| 17/17 [00:16<00:00,  1.06it/s]


Question ID: 683, 25/86


100%|██████████| 18/18 [00:16<00:00,  1.07it/s]


Question ID: 360, 26/86


100%|██████████| 19/19 [00:18<00:00,  1.03it/s]


Question ID: 348, 27/86


100%|██████████| 19/19 [00:18<00:00,  1.04it/s]


Question ID: 617, 28/86


100%|██████████| 19/19 [00:22<00:00,  1.19s/it]


Question ID: 864, 29/86


100%|██████████| 19/19 [00:18<00:00,  1.05it/s]


Question ID: 367, 30/86


100%|██████████| 17/17 [00:15<00:00,  1.08it/s]


Question ID: 729, 31/86


100%|██████████| 19/19 [00:17<00:00,  1.06it/s]


Question ID: 354, 32/86


100%|██████████| 17/17 [00:16<00:00,  1.05it/s]


Question ID: 89, 33/86


100%|██████████| 17/17 [00:20<00:00,  1.21s/it]


Question ID: 777, 34/86


100%|██████████| 18/18 [00:17<00:00,  1.01it/s]


Question ID: 826, 35/86


100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Question ID: 843, 36/86


100%|██████████| 18/18 [00:21<00:00,  1.21s/it]


Question ID: 1070, 37/86


100%|██████████| 17/17 [00:15<00:00,  1.09it/s]


Question ID: 1129, 38/86


100%|██████████| 17/17 [00:22<00:00,  1.32s/it]


Question ID: 258, 39/86


100%|██████████| 18/18 [00:16<00:00,  1.08it/s]


Question ID: 218, 40/86


100%|██████████| 17/17 [00:16<00:00,  1.06it/s]


Question ID: 172, 41/86


100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Question ID: 600, 42/86


100%|██████████| 17/17 [00:16<00:00,  1.01it/s]


Question ID: 898, 43/86


100%|██████████| 17/17 [00:16<00:00,  1.06it/s]


Question ID: 776, 44/86


100%|██████████| 15/15 [00:14<00:00,  1.02it/s]


Question ID: 207, 45/86


100%|██████████| 17/17 [00:17<00:00,  1.00s/it]


Question ID: 1206, 46/86


100%|██████████| 18/18 [00:17<00:00,  1.01it/s]


Question ID: 774, 47/86


100%|██████████| 18/18 [00:17<00:00,  1.04it/s]


Question ID: 949, 48/86


100%|██████████| 20/20 [00:19<00:00,  1.05it/s]


Question ID: 1255, 49/86


100%|██████████| 17/17 [00:16<00:00,  1.04it/s]


Question ID: 884, 50/86


100%|██████████| 18/18 [00:22<00:00,  1.23s/it]


Question ID: 262, 51/86


100%|██████████| 15/15 [00:13<00:00,  1.09it/s]


Question ID: 1166, 52/86


100%|██████████| 14/14 [00:12<00:00,  1.08it/s]


Question ID: 513, 53/86


100%|██████████| 17/17 [00:16<00:00,  1.05it/s]


Question ID: 1151, 54/86


100%|██████████| 16/16 [00:15<00:00,  1.06it/s]


Question ID: 965, 55/86


100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Question ID: 1251, 56/86


100%|██████████| 17/17 [00:16<00:00,  1.03it/s]


Question ID: 371, 57/86


100%|██████████| 15/15 [00:13<00:00,  1.09it/s]


Question ID: 612, 58/86


100%|██████████| 16/16 [00:15<00:00,  1.05it/s]


Question ID: 663, 59/86


100%|██████████| 18/18 [00:16<00:00,  1.08it/s]


Question ID: 636, 60/86


100%|██████████| 17/17 [00:16<00:00,  1.06it/s]


Question ID: 682, 61/86


100%|██████████| 15/15 [00:14<00:00,  1.05it/s]


Question ID: 330, 62/86


100%|██████████| 17/17 [00:16<00:00,  1.05it/s]


Question ID: 1207, 63/86


100%|██████████| 14/14 [00:13<00:00,  1.04it/s]


Question ID: 925, 64/86


100%|██████████| 15/15 [00:13<00:00,  1.07it/s]


Question ID: 366, 65/86


100%|██████████| 17/17 [00:15<00:00,  1.07it/s]


Question ID: 559, 66/86


100%|██████████| 20/20 [00:19<00:00,  1.02it/s]


Question ID: 744, 67/86


100%|██████████| 17/17 [00:15<00:00,  1.09it/s]


Question ID: 190, 68/86


100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


Question ID: 592, 69/86


100%|██████████| 14/14 [00:18<00:00,  1.33s/it]


Question ID: 652, 70/86


100%|██████████| 15/15 [00:14<00:00,  1.06it/s]


Question ID: 265, 71/86


100%|██████████| 14/14 [00:13<00:00,  1.01it/s]


Question ID: 963, 72/86


100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Question ID: 689, 73/86


100%|██████████| 15/15 [00:13<00:00,  1.08it/s]


Question ID: 751, 74/86


100%|██████████| 15/15 [00:14<00:00,  1.05it/s]


Question ID: 630, 75/86


100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Question ID: 375, 76/86


100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Question ID: 261, 77/86


100%|██████████| 14/14 [00:18<00:00,  1.30s/it]


Question ID: 1165, 78/86


100%|██████████| 15/15 [00:14<00:00,  1.06it/s]


Question ID: 192, 79/86


100%|██████████| 16/16 [00:15<00:00,  1.07it/s]


Question ID: 91, 80/86


100%|██████████| 15/15 [00:13<00:00,  1.09it/s]


Question ID: 645, 81/86


100%|██████████| 9/9 [00:08<00:00,  1.11it/s]


Question ID: 656, 82/86


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


Question ID: 1159, 83/86


100%|██████████| 17/17 [00:17<00:00,  1.01s/it]


Question ID: 591, 84/86


100%|██████████| 13/13 [00:18<00:00,  1.44s/it]


Question ID: 352, 85/86


100%|██████████| 14/14 [00:14<00:00,  1.04s/it]


Question ID: 841, 86/86


100%|██████████| 16/16 [00:14<00:00,  1.10it/s]


In [51]:
np.array(known_good_features).shape

(104,)

In [10]:
np.savetxt("known_good_features_gemma1_2b_train.csv", np.array(known_good_features))

## Calculate side-effects

In [13]:
np.array(known_good_features)

array([12289,  4617,  3599, 15892,  1557,  6172,    32,  4654,  5691,
        4160,  4687,  1620,  5205, 11358, 10355,  6263, 13431,  6273,
       13443,  3728, 13980,  6308,  4777,  4271,  9399,  4802, 13010,
         741,  5861, 16112,  4342,  9473, 12550,  3852, 12044,  5904,
        2834,  4886, 15642, 16175,  7983,  7484, 10046,   833, 10051,
         839,   842,   338, 14687,  5996, 10097, 11122, 12663,  1406,
        6531, 10632, 11147,  8596,  5525, 13718,  7076,  2469,  2993,
       10176,  9163,  7122,  8660, 16341,  7638, 14296, 14819,  7140,
        4071, 12782, 12273, 15858,  1523])

In [15]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': np.array(known_good_features),
        }


dataset_names = all_dataset_names[2:-1]

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      thresh=0,
                                      target_metric='correct',
                                      split='all')


  0%|          | 0/77 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 77/77 [09:35<00:00,  7.48s/it]


In [16]:
feature_ids_zero_side_effect = [x['ablate_params']['features_to_ablate'] for x in metrics_list]
np.array(feature_ids_zero_side_effect).shape

(34,)

In [17]:
np.savetxt("feature_ids_zero_side_effect_gemma1_2b_train.csv", np.array(feature_ids_zero_side_effect))

In [18]:
np.array(feature_ids_zero_side_effect)

array([ 3599,  1557,    32,  4654,  4687,  5205, 11358, 13431, 13443,
       13980,  6308,  4777,  4271, 13010,  5861,  4342,  9473, 12044,
       15642, 16175,  7983, 10051,  5996,  1406, 11147,  8596,  5525,
        7076,  9163,  8660, 16341, 14296,  4071, 12273])

In [57]:
# np.savetxt("feature_ids_zero_side_effect_gemma1_2b_all.csv", np.array(feature_ids_zero_side_effect))

## Then sort by loss added

In [20]:
activation_store = ActivationStoreAnalysis(sae.cfg, model)

Downloading builder script:   0%|          | 0.00/2.73k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.35k [00:00<?, ?B/s]

buffer
dataloader


In [41]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 20,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': feature_ids_zero_side_effect,
        }

metric_params = {'wmdp-bio': 
                 {
                       # 'target_metric': 'correct',
                       'question_subset': questions_ids_correct_train,
                       'permutations': None,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 10

metrics_list_zero_side_effect = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='all',
                                      verbose=False)


100%|██████████| 34/34 [09:39<00:00, 17.05s/it]


In [42]:
df_zero_side_effect = create_df_from_metrics(metrics_list_zero_side_effect)
df_zero_side_effect

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,0.0022542,0.988372,0.995117
1,-7.152557e-08,0.94186,0.985352
2,0.0005426645,0.988372,0.993652
3,-7.152557e-08,0.918605,0.965332
4,0.001539111,0.988372,0.995117
5,-7.152557e-08,0.988372,0.995117
6,-3.919601e-05,0.988372,0.995605
7,-0.002468824,0.988372,0.995117
8,-7.152557e-08,0.988372,0.995117
9,-7.152557e-08,1.0,0.994141


In [43]:
isorted = df_zero_side_effect.query("`wmdp-bio` < 1").sort_values("loss_added").index.values
feature_ids_zero_side_effect_sorted = np.array(feature_ids_zero_side_effect)[isorted]
feature_ids_zero_side_effect_sorted

array([13431,  7983,  4342,  6308, 15642, 11358,  8660,  9473,  9163,
        8596, 10051, 12044, 12273, 13443,  5205,  4654,  1557,    32,
        7076,  5525,  4071,  4687, 11147,  3599,  5861])

In [44]:
len(feature_ids_zero_side_effect_sorted)

25

In [25]:
# np.savetxt("feature_ids_zero_side_effect_sorted_gemma1_2b_all.csv", np.array(feature_ids_zero_side_effect_sorted))

## Now progressively add features sorted by loss

In [45]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': [feature_ids_zero_side_effect_sorted[:i+1] for i in range(23, 24)],
         'multiplier': [20],
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': None,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 20

metrics_list_best_sorted = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='test')


100%|██████████| 1/1 [00:28<00:00, 28.25s/it]


In [40]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,0.01276,0.77907,0.946289


In [67]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.001015,0.988372,1.0,1.0,1.0,1.0,1.0,0.995117,0.987793,0.998535,0.996582,0.994141,0.997559
1,-0.001767,0.965116,1.0,1.0,1.0,1.0,1.0,0.989258,0.987793,0.998535,0.996582,0.994141,0.997559
2,-0.002136,0.953488,1.0,1.0,1.0,1.0,1.0,0.989258,0.987793,0.998535,0.996582,0.994141,0.997559
3,-0.002155,0.930233,1.0,1.0,1.0,1.0,1.0,0.984863,0.987793,0.998535,0.996582,0.994141,0.997559
4,-0.002642,0.918605,1.0,1.0,1.0,1.0,1.0,0.985352,0.987793,0.998535,0.996582,0.994141,0.997559
5,-0.002007,0.906977,1.0,1.0,1.0,1.0,1.0,0.985352,0.987793,0.998535,0.996582,0.994141,0.997559
6,-0.002163,0.895349,1.0,1.0,1.0,1.0,1.0,0.990234,0.987793,0.998535,0.996582,0.993164,0.997559
7,-0.002163,0.895349,1.0,1.0,1.0,1.0,1.0,0.990234,0.987793,0.998535,0.996582,0.993164,0.997559
8,-0.002186,0.872093,1.0,1.0,1.0,1.0,1.0,0.983887,0.987793,0.998535,0.996582,0.992676,0.997559
9,-0.002398,0.837209,1.0,1.0,1.0,1.0,1.0,0.977539,0.987793,0.998535,0.996582,0.992676,0.997559


In [26]:
# metrics_list_best_sorted

In [24]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.000212,0.976744,1.0,1.0,1.0,1.0,1.0,0.994141,0.987305,0.998535,0.996582,0.994141,0.997559
1,0.004665,0.767442,1.0,1.0,1.0,0.9375,1.0,0.970215,0.987305,0.998535,0.996582,0.990234,0.92334
2,0.004644,0.732558,1.0,1.0,1.0,0.9375,1.0,0.972168,0.987305,0.998535,0.996582,0.989746,0.907227
3,0.005248,0.732558,1.0,1.0,1.0,0.9375,1.0,0.967773,0.987305,0.998535,0.996582,0.989746,0.909668


In [8]:
feature_ids_zero_side_effect_sorted = np.genfromtxt("feature_ids_zero_side_effect_sorted_gemma1_2b_all.csv")
feature_ids_zero_side_effect = np.genfromtxt("feature_ids_zero_side_effect_gemma1_2b_all.csv")

In [27]:
# Calculate metrics

main_ablate_params = {
                      'intervention_method': 'clamp_feature_activation',
                      # 'jump': 0.5,
                     }


sweep = {
         'features_to_ablate': [feature_ids_zero_side_effect_sorted[:-3]],
         'multiplier': [20, 30],
         # 'jump': [0., 1],
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': all_permutations,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 20

metrics_list_best_sorted4 = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='all')


100%|██████████| 2/2 [05:22<00:00, 161.47s/it]


In [29]:
df4 = create_df_from_metrics(metrics_list_best_sorted4)
df4

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,0.003645,0.690891,0.934082
1,0.027801,0.592539,0.915039


In [110]:
df3 = create_df_from_metrics(metrics_list_best_sorted3)
df3

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.004161,0.651163,1.0,1.0,1.0,1.0,0.866667,0.925293,0.994629,0.998535,0.99707,0.993164,0.961914


In [116]:
from unlearning.metrics import generate_ablate_params_list

In [118]:
ablate_params_list = generate_ablate_params_list(main_ablate_params, sweep)
for x in ablate_params_list:
    print(x['jump'], x['multiplier'])
    


0.0 20
0.5 20
1 20
1.5 20
0.0 40
0.5 40
1 40
1.5 40


In [81]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.003495,0.848837,1.0,1.0,1.0,1.0,1.0,0.965332,0.994629,0.999023,0.99707,0.993164,0.984863
1,-0.001551,0.755814,1.0,1.0,1.0,1.0,0.866667,0.951172,0.994629,0.998535,0.99707,0.993652,0.952637
2,0.004161,0.651163,1.0,1.0,1.0,1.0,0.866667,0.925293,0.994629,0.998535,0.99707,0.993164,0.961914
3,0.014927,0.575581,1.0,1.0,1.0,1.0,0.866667,0.902344,0.994629,0.998535,0.99707,0.992676,0.984863
4,0.030548,0.523256,1.0,1.0,1.0,1.0,0.866667,0.898438,0.994141,0.998535,0.99707,0.992188,0.950195


In [71]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.0064,0.616279,1.0,1.0,1.0,1.0,1.0,0.921387,0.98877,0.998535,0.996582,0.993164,0.995117


In [73]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.019185,0.651163,1.0,1.0,1.0,1.0,0.866667,0.923828,0.986816,0.998535,0.99707,0.992676,0.961914


In [104]:
metrics_list_best_sorted2[0]['human_aging']

{'mean_correct': 0.875,
 'total_correct': 14,
 'is_correct': array([1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       dtype=float32),
 'output_probs': array([[8.1062e-06, 1.1921e-07, 1.2517e-06, 9.9316e-01],
        [6.9336e-01, 1.4946e-02, 7.4646e-02, 2.1021e-01],
        [9.9854e-01, 4.9651e-05, 7.6771e-05, 4.4107e-06],
        [9.9902e-01, 9.5367e-07, 1.1921e-06, 7.3314e-06],
        [1.4853e-04, 5.8413e-06, 9.9463e-01, 4.3511e-06],
        [9.9805e-01, 1.1921e-07, 1.1921e-07, 2.3842e-07],
        [6.3777e-05, 1.3268e-04, 9.9707e-01, 3.8385e-05],
        [1.1325e-06, 1.4901e-06, 9.9707e-01, 8.6427e-06],
        [5.7817e-06, 1.7881e-07, 5.3644e-07, 9.9219e-01],
        [1.2927e-01, 8.4521e-01, 2.3438e-02, 7.5483e-04],
        [1.0700e-03, 9.9561e-01, 4.2915e-06, 1.1325e-06],
        [1.8597e-04, 7.3314e-06, 9.9854e-01, 2.1935e-05],
        [1.1325e-06, 5.9605e-08, 9.9707e-01, 1.1921e-07],
        [5.9605e-08, 9.9951e-01, 7.1526e-07, 1.7881e-07],
        [1.3912e

In [120]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': feature_ids_zero_side_effect[:1],
        }

metric_params = {'wmdp-bio': 
                    {
                       'target_metric': 'correct',
                       'permutations': None,
                    }
                }

dataset_names = all_dataset_names[2:6]

n_batch_loss_added = 10

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      thresh=0,
                                      split='train')


  0%|          | 0/1 [00:00<?, ?it/s]

13 13
1.0 1.0
4 4
1.0 1.0
15 15
1.0 1.0


100%|██████████| 1/1 [00:07<00:00,  7.13s/it]

15 16
0.9375 1.0





In [122]:
metrics_list

[]

In [114]:
metrics_list[0]['human_aging']['mean_correct']

1.0

In [108]:
metrics_list[0]['human_aging']

IndexError: list index out of range