In [1]:
# !pip install update git+https://github.com/TransformerLensOrg/TransformerLens.git

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics
from unlearning.feature_attribution import calculate_cache

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

import einops

from pathlib import Path

import plotly.express as px
from unlearning.var import REPO_ID, SAE_MAPPING
import pickle

from unlearning.metrics import all_permutations

from unlearning.metrics import calculate_metrics_side_effects
from unlearning.feature_attribution import find_topk_features_given_prompt, test_topk_features

from unlearning.jump_relu import load_gemma2_2b_sae

from transformer_lens import HookedTransformer


In [2]:
model = HookedTransformer.from_pretrained('google/gemma-2-2b-it')



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [40]:
# # Load main SAE for gemma-2b-it
# filename = hf_hub_download(repo_id=REPO_ID, filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'])
# sae = load_saved_sae(filename)
# model = model_store_from_sae(sae)
sae = load_gemma2_2b_sae(layer=13, l0=23)

# Load model directly

# model = HookedTransformer.from_pretrained("google/gemma-2-2b")


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [5]:
def gather_residual_activations(model, target_layer, inputs):
  target_act = None
  def gather_target_act_hook(mod, inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    target_act = outputs[0]
    return outputs
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  _ = model.forward(inputs)
  handle.remove()
  return target_act

In [11]:
# from transformers import AutoTokenizer, AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to("cuda")

# target_act = gather_residual_activations(hf_model, 2, model.to_tokens("test").to("cuda"))

# target_act.shape

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# pass in the dataset as an argument so no need to load in everytime
dataset = load_dataset("cais/wmdp", "wmdp-bio", split='test')

answers = [x['answer'] for x in dataset]
questions = [x['question'] for x in dataset]
choices_list = [x['choices'] for x in dataset]

prompts = [convert_wmdp_data_to_prompt(question, choices, prompt_format=None)
           for question, choices in zip(questions, choices_list)]


In [8]:
# Hyper parameters for automation process
question_ids_correct = np.genfromtxt("../data/question_ids/gemma-2-2b-it/all/wmdp-bio_correct.csv", dtype=int)
questions_ids_correct_train = np.genfromtxt("../data/question_ids/gemma-2-2b-it/train/wmdp-bio_correct.csv", dtype=int)
topk_per_prompt = 20

unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names =  ['high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']
all_dataset_names = ['loss_added', 'wmdp-bio', 'high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']


## First get the TopK features by attribution per prompt and find the features that modify the probability

In [41]:
sae.cfg.hook_point, len(question_ids_correct)

('blocks.13.hook_resid_post', 522)

In [42]:
feature_per_prompt = {}

known_good_features = []

question_ids = question_ids_correct[:100]
topk_per_prompt = 20

for j, question_id in enumerate(question_ids):

    question_id = int(question_id)
    print(f"Question ID: {question_id}, {j + 1}/{len(question_ids)}")
    
    prompt = prompts[question_id]
    choices = choices_list[question_id]
    answer = answers[question_id]
    question = questions[question_id]

    topk_features_unique, feature_attributions, topk_features, all_feature_activations, logit_diff_grad, topk_feature_attributions = find_topk_features_given_prompt(model,
                                                           prompt,
                                                           question,
                                                           choices,
                                                           answer,
                                                           sae,
                                                           hook_point=sae.cfg.hook_point)

    intervention_results, feature_ids_to_probs, good_features, partially_unlearned = test_topk_features(model,
                                                                                   sae,
                                                                                   question_id,
                                                                                   topk_features_unique[:topk_per_prompt],
                                                                                   known_good_features=known_good_features,
                                                                                   multiplier=100)
    

    feature_per_prompt[question_id] = good_features
    
    known_good_features = list(set([item for sublist in feature_per_prompt.values() for item in sublist]))

    

Question ID: 0, 1/100


100%|██████████| 20/20 [00:18<00:00,  1.05it/s]


Question ID: 1, 2/100


100%|██████████| 20/20 [00:18<00:00,  1.10it/s]


Question ID: 11, 3/100


100%|██████████| 19/19 [00:17<00:00,  1.06it/s]


Question ID: 15, 4/100


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


Question ID: 16, 5/100


100%|██████████| 19/19 [00:18<00:00,  1.04it/s]


Question ID: 39, 6/100


100%|██████████| 17/17 [00:15<00:00,  1.12it/s]


Question ID: 40, 7/100


100%|██████████| 19/19 [00:18<00:00,  1.03it/s]


Question ID: 41, 8/100


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


Question ID: 48, 9/100


100%|██████████| 16/16 [00:19<00:00,  1.24s/it]


Question ID: 49, 10/100


100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


Question ID: 50, 11/100


100%|██████████| 19/19 [00:17<00:00,  1.10it/s]


Question ID: 51, 12/100


100%|██████████| 17/17 [00:16<00:00,  1.05it/s]


Question ID: 55, 13/100


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


Question ID: 57, 14/100


100%|██████████| 18/18 [00:20<00:00,  1.15s/it]


Question ID: 58, 15/100


100%|██████████| 19/19 [00:21<00:00,  1.16s/it]


Question ID: 63, 16/100


100%|██████████| 18/18 [00:16<00:00,  1.10it/s]


Question ID: 67, 17/100


100%|██████████| 12/12 [00:10<00:00,  1.12it/s]


Question ID: 68, 18/100


100%|██████████| 18/18 [00:16<00:00,  1.10it/s]


Question ID: 69, 19/100


100%|██████████| 15/15 [00:13<00:00,  1.13it/s]


Question ID: 70, 20/100


100%|██████████| 17/17 [00:14<00:00,  1.14it/s]


Question ID: 80, 21/100


100%|██████████| 17/17 [00:14<00:00,  1.14it/s]


Question ID: 81, 22/100


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


Question ID: 85, 23/100


100%|██████████| 17/17 [00:19<00:00,  1.16s/it]


Question ID: 86, 24/100


100%|██████████| 19/19 [00:16<00:00,  1.13it/s]


Question ID: 90, 25/100


100%|██████████| 19/19 [00:16<00:00,  1.15it/s]


Question ID: 91, 26/100


100%|██████████| 18/18 [00:15<00:00,  1.14it/s]


Question ID: 92, 27/100


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


Question ID: 100, 28/100


100%|██████████| 14/14 [00:17<00:00,  1.26s/it]


Question ID: 101, 29/100


100%|██████████| 15/15 [00:14<00:00,  1.07it/s]


Question ID: 107, 30/100


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


Question ID: 143, 31/100


100%|██████████| 17/17 [00:15<00:00,  1.11it/s]


Question ID: 145, 32/100


100%|██████████| 16/16 [00:19<00:00,  1.20s/it]


Question ID: 147, 33/100


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


Question ID: 149, 34/100


100%|██████████| 18/18 [00:15<00:00,  1.15it/s]


Question ID: 152, 35/100


100%|██████████| 17/17 [00:15<00:00,  1.12it/s]


Question ID: 158, 36/100


100%|██████████| 20/20 [00:23<00:00,  1.16s/it]


Question ID: 160, 37/100


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


Question ID: 167, 38/100


100%|██████████| 18/18 [00:16<00:00,  1.10it/s]


Question ID: 172, 39/100


100%|██████████| 19/19 [00:16<00:00,  1.13it/s]


Question ID: 178, 40/100


100%|██████████| 19/19 [00:17<00:00,  1.12it/s]


Question ID: 183, 41/100


100%|██████████| 18/18 [00:15<00:00,  1.14it/s]


Question ID: 184, 42/100


100%|██████████| 17/17 [00:15<00:00,  1.10it/s]


Question ID: 185, 43/100


100%|██████████| 18/18 [00:16<00:00,  1.06it/s]


Question ID: 190, 44/100


100%|██████████| 18/18 [00:17<00:00,  1.03it/s]


Question ID: 192, 45/100


100%|██████████| 17/17 [00:15<00:00,  1.07it/s]


Question ID: 195, 46/100


100%|██████████| 14/14 [00:12<00:00,  1.10it/s]


Question ID: 197, 47/100


100%|██████████| 17/17 [00:14<00:00,  1.14it/s]


Question ID: 200, 48/100


100%|██████████| 14/14 [00:12<00:00,  1.12it/s]


Question ID: 202, 49/100


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


Question ID: 204, 50/100


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


Question ID: 206, 51/100


100%|██████████| 17/17 [00:15<00:00,  1.12it/s]


Question ID: 207, 52/100


100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


Question ID: 208, 53/100


100%|██████████| 17/17 [00:15<00:00,  1.11it/s]


Question ID: 216, 54/100


100%|██████████| 17/17 [00:15<00:00,  1.12it/s]


Question ID: 218, 55/100


100%|██████████| 18/18 [00:16<00:00,  1.11it/s]


Question ID: 220, 56/100


100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Question ID: 228, 57/100


100%|██████████| 17/17 [00:15<00:00,  1.11it/s]


Question ID: 229, 58/100


100%|██████████| 17/17 [00:14<00:00,  1.16it/s]


Question ID: 230, 59/100


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


Question ID: 232, 60/100


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


Question ID: 233, 61/100


100%|██████████| 15/15 [00:13<00:00,  1.11it/s]


Question ID: 236, 62/100


100%|██████████| 17/17 [00:15<00:00,  1.10it/s]


Question ID: 239, 63/100


100%|██████████| 15/15 [00:13<00:00,  1.08it/s]


Question ID: 242, 64/100


100%|██████████| 15/15 [00:18<00:00,  1.21s/it]


Question ID: 243, 65/100


100%|██████████| 16/16 [00:14<00:00,  1.11it/s]


Question ID: 245, 66/100


100%|██████████| 16/16 [00:14<00:00,  1.09it/s]


Question ID: 260, 67/100


100%|██████████| 16/16 [00:13<00:00,  1.16it/s]


Question ID: 262, 68/100


100%|██████████| 19/19 [00:17<00:00,  1.09it/s]


Question ID: 265, 69/100


100%|██████████| 16/16 [00:14<00:00,  1.14it/s]


Question ID: 267, 70/100


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


Question ID: 277, 71/100


100%|██████████| 16/16 [00:14<00:00,  1.13it/s]


Question ID: 289, 72/100


100%|██████████| 17/17 [00:15<00:00,  1.10it/s]


Question ID: 299, 73/100


100%|██████████| 16/16 [00:20<00:00,  1.27s/it]


Question ID: 310, 74/100


100%|██████████| 17/17 [00:15<00:00,  1.09it/s]


Question ID: 312, 75/100


100%|██████████| 18/18 [00:16<00:00,  1.12it/s]


Question ID: 313, 76/100


100%|██████████| 16/16 [00:14<00:00,  1.12it/s]


Question ID: 314, 77/100


100%|██████████| 17/17 [00:15<00:00,  1.10it/s]


Question ID: 317, 78/100


100%|██████████| 16/16 [00:14<00:00,  1.10it/s]


Question ID: 320, 79/100


100%|██████████| 17/17 [00:20<00:00,  1.23s/it]


Question ID: 321, 80/100


100%|██████████| 12/12 [00:12<00:00,  1.05s/it]


Question ID: 324, 81/100


100%|██████████| 14/14 [00:17<00:00,  1.24s/it]


Question ID: 325, 82/100


100%|██████████| 15/15 [00:13<00:00,  1.13it/s]


Question ID: 330, 83/100


100%|██████████| 19/19 [00:16<00:00,  1.15it/s]


Question ID: 331, 84/100


100%|██████████| 17/17 [00:15<00:00,  1.13it/s]


Question ID: 332, 85/100


100%|██████████| 18/18 [00:15<00:00,  1.16it/s]


Question ID: 335, 86/100


100%|██████████| 16/16 [00:15<00:00,  1.07it/s]


Question ID: 337, 87/100


100%|██████████| 15/15 [00:14<00:00,  1.04it/s]


Question ID: 338, 88/100


100%|██████████| 16/16 [00:14<00:00,  1.08it/s]


Question ID: 339, 89/100


100%|██████████| 19/19 [00:17<00:00,  1.07it/s]


Question ID: 344, 90/100


100%|██████████| 17/17 [00:20<00:00,  1.20s/it]


Question ID: 346, 91/100


100%|██████████| 15/15 [00:13<00:00,  1.12it/s]


Question ID: 347, 92/100


100%|██████████| 14/14 [00:13<00:00,  1.02it/s]


Question ID: 349, 93/100


100%|██████████| 18/18 [00:17<00:00,  1.06it/s]


Question ID: 352, 94/100


100%|██████████| 17/17 [00:16<00:00,  1.03it/s]


Question ID: 353, 95/100


100%|██████████| 17/17 [00:15<00:00,  1.08it/s]


Question ID: 354, 96/100


100%|██████████| 17/17 [00:16<00:00,  1.01it/s]


Question ID: 355, 97/100


100%|██████████| 17/17 [00:16<00:00,  1.03it/s]


Question ID: 357, 98/100


100%|██████████| 17/17 [00:15<00:00,  1.10it/s]


Question ID: 359, 99/100


100%|██████████| 15/15 [00:13<00:00,  1.14it/s]


Question ID: 360, 100/100


100%|██████████| 17/17 [00:15<00:00,  1.08it/s]


In [18]:
np.array(known_good_features).shape

(78,)

In [10]:
# np.savetxt("known_good_features_gemma1_2b_train.csv", np.array(known_good_features))

## Calculate side-effects

In [15]:
np.array(known_good_features)

array([ 4392,  3524, 12926])

In [21]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 100,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': np.array(known_good_features),
        }


dataset_names = all_dataset_names[2:-1]

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      thresh=0,
                                      target_metric='correct',
                                      split='all')


  1%|▏         | 1/78 [00:15<19:27, 15.17s/it]

Downloading data:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

  3%|▎         | 2/78 [00:43<29:14, 23.08s/it]

Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 78/78 [23:27<00:00, 18.05s/it]


In [22]:
feature_ids_zero_side_effect = [x['ablate_params']['features_to_ablate'] for x in metrics_list]
np.array(feature_ids_zero_side_effect).shape

(14,)

In [23]:
np.savetxt("feature_ids_zero_side_effect_gemma2_2b_all_n100.csv", np.array(feature_ids_zero_side_effect))

In [24]:
np.array(feature_ids_zero_side_effect)

array([ 1032,  6680,  6202,  7751,  7242, 16025, 15514,  8906,  4364,
       10008,  2880, 11685,  5608, 10233])

In [25]:
# np.savetxt("feature_ids_zero_side_effect_gemma1_2b_all.csv", np.array(feature_ids_zero_side_effect))

In [31]:
filename = hf_hub_download(repo_id=REPO_ID, filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'])
sae2 = load_saved_sae(filename)
sae2.cfg.hook_point = sae.cfg.hook_point
sae2.cfg.hook_point_layer = 11
sae2.cfg.d_in = 2304

## Then sort by loss added

In [None]:
activations = model.run_with_cache(model.to_tokens("test"), names_filter="blocks.10.hook_resid_post", stop_at_layer=12

In [32]:
activation_store = ActivationStoreAnalysis(sae2.cfg, model)

buffer
dataloader


In [35]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 100,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': feature_ids_zero_side_effect,
        }

metric_params = {'wmdp-bio': 
                 {
                       # 'target_metric': 'correct',
                       'question_subset': questions_ids_correct_train,
                       'permutations': None,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 10

metrics_list_zero_side_effect = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='all',
                                      verbose=False)


100%|██████████| 14/14 [06:45<00:00, 28.95s/it]


In [36]:
df_zero_side_effect = create_df_from_metrics(metrics_list_zero_side_effect)
df_zero_side_effect

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,-0.00154,0.988506,0.810059
1,0.001179,1.0,0.85791
2,0.001416,0.996169,0.82666
3,0.001727,1.0,0.825195
4,-0.002945,1.0,0.848633
5,0.007671,0.996169,0.827148
6,0.010963,0.996169,0.808105
7,0.000215,0.977011,0.82373
8,0.000675,0.988506,0.822266
9,0.000427,1.0,0.824219


In [37]:
isorted = df_zero_side_effect.query("`wmdp-bio` < 1").sort_values("loss_added").index.values
feature_ids_zero_side_effect_sorted = np.array(feature_ids_zero_side_effect)[isorted]
feature_ids_zero_side_effect_sorted

array([ 1032,  5608, 10233,  8906, 11685,  4364,  6202, 16025, 15514])

In [44]:
len(feature_ids_zero_side_effect_sorted)

25

In [25]:
# np.savetxt("feature_ids_zero_side_effect_sorted_gemma1_2b_all.csv", np.array(feature_ids_zero_side_effect_sorted))

## Now progressively add features sorted by loss

In [38]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 100,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': [feature_ids_zero_side_effect_sorted[:i+1] for i in range(8, 9)],
         'multiplier': [100, 200],
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': None,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 20

metrics_list_best_sorted = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='test')


100%|██████████| 2/2 [01:24<00:00, 42.25s/it]


In [39]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,0.018664,0.908046,0.75
1,0.056487,0.846743,0.714355


In [67]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.001015,0.988372,1.0,1.0,1.0,1.0,1.0,0.995117,0.987793,0.998535,0.996582,0.994141,0.997559
1,-0.001767,0.965116,1.0,1.0,1.0,1.0,1.0,0.989258,0.987793,0.998535,0.996582,0.994141,0.997559
2,-0.002136,0.953488,1.0,1.0,1.0,1.0,1.0,0.989258,0.987793,0.998535,0.996582,0.994141,0.997559
3,-0.002155,0.930233,1.0,1.0,1.0,1.0,1.0,0.984863,0.987793,0.998535,0.996582,0.994141,0.997559
4,-0.002642,0.918605,1.0,1.0,1.0,1.0,1.0,0.985352,0.987793,0.998535,0.996582,0.994141,0.997559
5,-0.002007,0.906977,1.0,1.0,1.0,1.0,1.0,0.985352,0.987793,0.998535,0.996582,0.994141,0.997559
6,-0.002163,0.895349,1.0,1.0,1.0,1.0,1.0,0.990234,0.987793,0.998535,0.996582,0.993164,0.997559
7,-0.002163,0.895349,1.0,1.0,1.0,1.0,1.0,0.990234,0.987793,0.998535,0.996582,0.993164,0.997559
8,-0.002186,0.872093,1.0,1.0,1.0,1.0,1.0,0.983887,0.987793,0.998535,0.996582,0.992676,0.997559
9,-0.002398,0.837209,1.0,1.0,1.0,1.0,1.0,0.977539,0.987793,0.998535,0.996582,0.992676,0.997559


In [26]:
# metrics_list_best_sorted

In [24]:
df = create_df_from_metrics(metrics_list_best_sorted)
df

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.000212,0.976744,1.0,1.0,1.0,1.0,1.0,0.994141,0.987305,0.998535,0.996582,0.994141,0.997559
1,0.004665,0.767442,1.0,1.0,1.0,0.9375,1.0,0.970215,0.987305,0.998535,0.996582,0.990234,0.92334
2,0.004644,0.732558,1.0,1.0,1.0,0.9375,1.0,0.972168,0.987305,0.998535,0.996582,0.989746,0.907227
3,0.005248,0.732558,1.0,1.0,1.0,0.9375,1.0,0.967773,0.987305,0.998535,0.996582,0.989746,0.909668


In [8]:
feature_ids_zero_side_effect_sorted = np.genfromtxt("feature_ids_zero_side_effect_sorted_gemma1_2b_all.csv")
feature_ids_zero_side_effect = np.genfromtxt("feature_ids_zero_side_effect_gemma1_2b_all.csv")

In [27]:
# Calculate metrics

main_ablate_params = {
                      'intervention_method': 'clamp_feature_activation',
                      # 'jump': 0.5,
                     }


sweep = {
         'features_to_ablate': [feature_ids_zero_side_effect_sorted[:-3]],
         'multiplier': [20, 30],
         # 'jump': [0., 1],
        }

metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': all_permutations,
                       'verbose': False,
                   }
                 }

dataset_names = all_dataset_names[:2]

n_batch_loss_added = 20

metrics_list_best_sorted4 = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store,
                                      split='all')


100%|██████████| 2/2 [05:22<00:00, 161.47s/it]


In [29]:
df4 = create_df_from_metrics(metrics_list_best_sorted4)
df4

Unnamed: 0,loss_added,wmdp-bio,wmdp-bio_prob
0,0.003645,0.690891,0.934082
1,0.027801,0.592539,0.915039


In [110]:
df3 = create_df_from_metrics(metrics_list_best_sorted3)
df3

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.004161,0.651163,1.0,1.0,1.0,1.0,0.866667,0.925293,0.994629,0.998535,0.99707,0.993164,0.961914


In [116]:
from unlearning.metrics import generate_ablate_params_list

In [118]:
ablate_params_list = generate_ablate_params_list(main_ablate_params, sweep)
for x in ablate_params_list:
    print(x['jump'], x['multiplier'])
    


0.0 20
0.5 20
1 20
1.5 20
0.0 40
0.5 40
1 40
1.5 40


In [81]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.003495,0.848837,1.0,1.0,1.0,1.0,1.0,0.965332,0.994629,0.999023,0.99707,0.993164,0.984863
1,-0.001551,0.755814,1.0,1.0,1.0,1.0,0.866667,0.951172,0.994629,0.998535,0.99707,0.993652,0.952637
2,0.004161,0.651163,1.0,1.0,1.0,1.0,0.866667,0.925293,0.994629,0.998535,0.99707,0.993164,0.961914
3,0.014927,0.575581,1.0,1.0,1.0,1.0,0.866667,0.902344,0.994629,0.998535,0.99707,0.992676,0.984863
4,0.030548,0.523256,1.0,1.0,1.0,1.0,0.866667,0.898438,0.994141,0.998535,0.99707,0.992188,0.950195


In [71]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,-0.0064,0.616279,1.0,1.0,1.0,1.0,1.0,0.921387,0.98877,0.998535,0.996582,0.993164,0.995117


In [73]:
df2 = create_df_from_metrics(metrics_list_best_sorted2)
df2

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.019185,0.651163,1.0,1.0,1.0,1.0,0.866667,0.923828,0.986816,0.998535,0.99707,0.992676,0.961914


In [104]:
metrics_list_best_sorted2[0]['human_aging']

{'mean_correct': 0.875,
 'total_correct': 14,
 'is_correct': array([1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       dtype=float32),
 'output_probs': array([[8.1062e-06, 1.1921e-07, 1.2517e-06, 9.9316e-01],
        [6.9336e-01, 1.4946e-02, 7.4646e-02, 2.1021e-01],
        [9.9854e-01, 4.9651e-05, 7.6771e-05, 4.4107e-06],
        [9.9902e-01, 9.5367e-07, 1.1921e-06, 7.3314e-06],
        [1.4853e-04, 5.8413e-06, 9.9463e-01, 4.3511e-06],
        [9.9805e-01, 1.1921e-07, 1.1921e-07, 2.3842e-07],
        [6.3777e-05, 1.3268e-04, 9.9707e-01, 3.8385e-05],
        [1.1325e-06, 1.4901e-06, 9.9707e-01, 8.6427e-06],
        [5.7817e-06, 1.7881e-07, 5.3644e-07, 9.9219e-01],
        [1.2927e-01, 8.4521e-01, 2.3438e-02, 7.5483e-04],
        [1.0700e-03, 9.9561e-01, 4.2915e-06, 1.1325e-06],
        [1.8597e-04, 7.3314e-06, 9.9854e-01, 2.1935e-05],
        [1.1325e-06, 5.9605e-08, 9.9707e-01, 1.1921e-07],
        [5.9605e-08, 9.9951e-01, 7.1526e-07, 1.7881e-07],
        [1.3912e

In [120]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 30,
                      'intervention_method': 'clamp_feature_activation',
                     }


sweep = {
         'features_to_ablate': feature_ids_zero_side_effect[:1],
        }

metric_params = {'wmdp-bio': 
                    {
                       'target_metric': 'correct',
                       'permutations': None,
                    }
                }

dataset_names = all_dataset_names[2:6]

n_batch_loss_added = 10

metrics_list = calculate_metrics_side_effects(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      thresh=0,
                                      split='train')


  0%|          | 0/1 [00:00<?, ?it/s]

13 13
1.0 1.0
4 4
1.0 1.0
15 15
1.0 1.0


100%|██████████| 1/1 [00:07<00:00,  7.13s/it]

15 16
0.9375 1.0





In [122]:
metrics_list

[]

In [114]:
metrics_list[0]['human_aging']['mean_correct']

1.0

In [108]:
metrics_list[0]['human_aging']

IndexError: list index out of range