In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

from pathlib import Path

import plotly.express as px
from unlearning.var import REPO_ID, SAE_MAPPING
import pickle

from unlearning.metrics import all_permutations

In [2]:
# Load main SAE for gemma-2b-it
filename = hf_hub_download(repo_id=REPO_ID, filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'])
sae = load_saved_sae(filename)
model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [3]:
activation_store = ActivationStoreAnalysis(sae.cfg, model)

buffer
dataloader


In [4]:
# load good feature list

with open('../yeutong_notebooks/unlearning_output/good_features_list_v1.pkl', 'rb') as f:
    good_features_list = pickle.load(f)


features_to_test = list(set([item for sublist in good_features_list.values() for item in sublist]))

filtered_good_features = [12663, 4342, 5749, 10355, 1523, 15858, 12273, 14315, 4451, 1611, 10051, 16186, 7983, 6958, 1307, 11019, 6531, 12289]
filtered_features_sorted_by_loss = [7983, 16186, 12273, 14315,  4342, 10051, 15858,  6958, 12663, 1611,  6531,  1523, 10355,  5749,  1307, 12289,  4451, 11019]
filtered_features_sorted_by_loss2 = np.concatenate((filtered_features_sorted_by_loss[:8], filtered_features_sorted_by_loss[10:11], filtered_features_sorted_by_loss[12:]))

zero_side_effect_features = [7983, 16186, 14315,  4342, 10051,  6958,  5749,  4451,  5001, 15755,  2222,  4654,  9280,  1746,  8412,  5861, 15848,  8946]
zero_side_effect_features_sorted_by_loss = [5861,  1746, 14315, 16186, 10051,  7983,  4342,  4654,  2222, 15755,  8412,  6958,  5749,  5001,  4451,  8946,  9280, 15848]

zero_side_effect_21_features = [ 5001, 11019,  3728,  7983,  9391,  4654, 14388,  5691,  4802, 1611,  7122,  4451, 14819, 15848, 14315, 12273, 15858,  4342, 12663, 12287]
zero_side_effect_21_features_sorted_by_loss = [ 9391, 12663,  7122, 11019,  3728,  7983, 14315,  4342,  4654, 15858, 12273, 14388,  1611,  5001,  4451,  5691, 14819, 15848, 12287,  4802]

unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names =  ['high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']
all_dataset_names = ['loss_added', 'wmdp-bio', 'high_school_us_history', 'college_computer_science', 'high_school_geography', 'human_aging', 'college_biology']

In [5]:
# Calculate metrics

main_ablate_params = {
                      'multiplier': 20,
                      'intervention_method': 'clamp_feature_activation'
                     }

sweep = {
         'features_to_ablate':
         [filtered_features_sorted_by_loss2]
        }


metric_params = {'wmdp-bio': 
                 {
                       'target_metric': 'correct',
                       'permutations': all_permutations,
                   }
                 }

# sweep = {'features_to_ablate': [zero_side_effect_21_features_sorted_by_loss[:i+1] for i in range(len(zero_side_effect_21_features_sorted_by_loss))]}

dataset_names = all_dataset_names
# dataset_names = ['loss_added', 'human_aging']
# dataset_names = ['human_aging']
# dataset_names = ['loss_added']
# dataset_names = ['wmdp-bio']

n_batch_loss_added = 10

metrics_list = calculate_metrics_list(model,
                                      sae,
                                      main_ablate_params,
                                      sweep,
                                      dataset_names=dataset_names,
                                      metric_params=metric_params,
                                      include_baseline_metrics=False,
                                      n_batch_loss_added=n_batch_loss_added,
                                      activation_store=activation_store)


100%|██████████| 10/10 [00:14<00:00,  1.46s/it]
 32%|███▏      | 222/688 [00:44<01:30,  5.17it/s]

100%|██████████| 688/688 [02:40<00:00,  4.30it/s]


Downloading readme:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/138k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.8k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 5/5 [00:03<00:00,  1.34it/s]


Downloading data:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.25k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.81k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 2/2 [00:00<00:00,  2.94it/s]


Downloading data:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.16k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.93k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/198 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 5/5 [00:01<00:00,  4.12it/s]


Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 6/6 [00:01<00:00,  4.76it/s]


Downloading data:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 3/3 [00:00<00:00,  4.79it/s]







In [6]:
(metrics_list[0]['wmdp-bio']['is_correct'].reshape(-1, 24).sum(axis=1) == 24).sum()

16

In [9]:
np.where(metrics_list[0]['wmdp-bio']['is_correct'].reshape(-1, 24).sum(axis=1) == 24)

# unle

(array([  1,   4,   5,   9,  15,  17,  20,  21,  22,  23,  24,  27,  29,
         49,  87, 162]),)

: 

In [16]:
df = create_df_from_metrics(metrics_list)
df

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.005063,0.360465,0.962963,1.0,0.933333,0.84375,0.666667,0.934525,0.993797,0.998191,0.986307,0.971106,0.951835
