In [1]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np

from unlearning.tool import get_hf_model
from unlearning.feature_activation import get_forget_retain_data, tokenize_dataset, get_mean_feature_activation, get_top_features
from unlearning.jump_relu import load_gemma2_2b_sae
from unlearning.intervention import scaling_intervention
from unlearning.metrics import calculate_MCQ_metrics

In [2]:
model = get_hf_model('google/gemma-2-2b-it')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Save top features for each layer

In [3]:
forget_dataset, retain_dataset = get_forget_retain_data('bio-forget-corpus', 'wikitext')

print(len(forget_dataset), len(forget_dataset[0]))
print(len(retain_dataset), len(retain_dataset[0]))

forget_tokens = tokenize_dataset(model, forget_dataset)
retain_tokens = tokenize_dataset(model, retain_dataset)

print(forget_tokens.shape, retain_tokens.shape)

24432 16027
1962 859
torch.Size([153108, 1024]) torch.Size([275, 1024])


In [10]:
LAYER = 1
sae = load_gemma2_2b_sae(layer=LAYER)


Found SAE with l0=102 at path google/gemma-scope-2b-pt-res/layer_1/width_16k/average_l0_102/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [24]:
def save_sae_top_features(model, layer, forget_tokens, retain_tokens):
    sae = load_gemma2_2b_sae(layer=layer)

    # shuffle forget_tokens 
    shuffled_forget_tokens = forget_tokens[torch.randperm(forget_tokens.shape[0])]

    mean_feature_activation_forget = get_mean_feature_activation(model, sae, shuffled_forget_tokens[:2048], batch_size=8)
    mean_feature_activation_retain = get_mean_feature_activation(model, sae, retain_tokens, batch_size=8)

    top_features = get_top_features(mean_feature_activation_forget, mean_feature_activation_retain, 0.01)
    np.savetxt(f'../data/top_features/gemma-2-2b-it/layer{layer}.txt', top_features, fmt='%d')

In [25]:
for layer in [0] + list(range(4, model.config.num_hidden_layers)):
    print(f'Layer {layer}')
    save_sae_top_features(model, layer, forget_tokens, retain_tokens)

Layer 0
Found SAE with l0=105 at path google/gemma-scope-2b-pt-res/layer_0/width_16k/average_l0_105/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:31<00:00,  1.11it/s]


[ 3148  7146   611  2261  3873 13553 13954  7076   443   721  5531 13030
 13361  7126  6208  5361  5158  1196 15250  7510]
Layer 4
Found SAE with l0=124 at path google/gemma-scope-2b-pt-res/layer_4/width_16k/average_l0_124/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 9137  5072 12768  3368  9380 14596 12645 12382  1007  5536 10252   885
 13747  1710 14748 10905 13991 10186  4234 14149]
Layer 5
Found SAE with l0=68 at path google/gemma-scope-2b-pt-res/layer_5/width_16k/average_l0_68/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:55<00:00,  1.08it/s]
100%|██████████| 35/35 [00:31<00:00,  1.10it/s]


[  248  6347  3036 14303    97  1465 10854 12991  8245  9678 14108  6479
   737  7260   763 12528  5653  3197  9862  1565]
Layer 6
Found SAE with l0=70 at path google/gemma-scope-2b-pt-res/layer_6/width_16k/average_l0_70/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:55<00:00,  1.09it/s]
100%|██████████| 35/35 [00:31<00:00,  1.11it/s]


[ 1609  6866  8954 15237  4228  8152 11545 10946  3161  2869  1534 15793
  1307  3396  6870 15141 11390  3429 11380  2310]
Layer 7
Found SAE with l0=69 at path google/gemma-scope-2b-pt-res/layer_7/width_16k/average_l0_69/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [04:00<00:00,  1.06it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[12799 15661  3828  9723 15939  4216  2512   226 15879  4945  2114  9294
 10566  9159 13351 16187  8934  3769 14822  1709]
Layer 8
Found SAE with l0=71 at path google/gemma-scope-2b-pt-res/layer_8/width_16k/average_l0_71/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 6991  2766  7617 11059  5276  2573  4172  4865  1992  2470  8022  2875
  5589   101 16318  7830  3795  3839  8244   426]
Layer 9
Found SAE with l0=73 at path google/gemma-scope-2b-pt-res/layer_9/width_16k/average_l0_73/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:56<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 8109 15123  6028  4700   825 12418  4081  7882  2465  2371 15497  4537
 13618 10499  4485  7904 15821  2532 14167  2635]
Layer 10
Found SAE with l0=77 at path google/gemma-scope-2b-pt-res/layer_10/width_16k/average_l0_77/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:56<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[  739  1269  1782     2  1822 15313  5815  9055 11121  6388 13645  8172
 10881  2751 10395 10932 14003   959  5217  3298]
Layer 11
Found SAE with l0=80 at path google/gemma-scope-2b-pt-res/layer_11/width_16k/average_l0_80/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.08it/s]


[14385 14948  2804  7864  4348  7745  1603 14200 10844  4806 15336  8377
 11385  3239  6181  3097 14738 12101  9392  8008]
Layer 12
Found SAE with l0=82 at path google/gemma-scope-2b-pt-res/layer_12/width_16k/average_l0_82/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:54<00:00,  1.09it/s]
100%|██████████| 35/35 [00:31<00:00,  1.11it/s]


[ 4009  5284 14606 14778  2830 14088 14183   343  4163 10530  2948  5884
  9908   231  3286  1746 15293 11380   701 10316]
Layer 13
Found SAE with l0=84 at path google/gemma-scope-2b-pt-res/layer_13/width_16k/average_l0_84/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:55<00:00,  1.09it/s]
100%|██████████| 35/35 [00:32<00:00,  1.08it/s]


[ 8228  1894  4618  2644  4470 14368  6256  1287  4917  1861 13227 13018
 10429 10026   111  8370 10319 11770  6496   285]
Layer 14
Found SAE with l0=84 at path google/gemma-scope-2b-pt-res/layer_14/width_16k/average_l0_84/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:59<00:00,  1.07it/s]
100%|██████████| 35/35 [00:31<00:00,  1.10it/s]


[13902  5350  3829  9268 15384 14630  7106 13941  7994  6424  5304 13326
 14239  4874  1814  9294 12374 16156 11443  7823]
Layer 15
Found SAE with l0=78 at path google/gemma-scope-2b-pt-res/layer_15/width_16k/average_l0_78/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:56<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.08it/s]


[12658  1465 11863  3155  3183 10360 12420  9354  7503 14052  1204 12977
 14679 11852 15962  5882  7237  7106 11402   732]
Layer 16
Found SAE with l0=78 at path google/gemma-scope-2b-pt-res/layer_16/width_16k/average_l0_78/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:55<00:00,  1.09it/s]
100%|██████████| 35/35 [00:31<00:00,  1.11it/s]


[ 6911 14122 13765  3579  9138 10908 13227 16133  4746 15846  1223 14300
 10012 14661  7957 14167  6649  6016 12719  8805]
Layer 17
Found SAE with l0=77 at path google/gemma-scope-2b-pt-res/layer_17/width_16k/average_l0_77/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:31<00:00,  1.10it/s]


[ 9732  5123  5691  5159 12794  6854 12147  2619  1625  3271  3007   267
  8351  7849  5277  1320  6451 13112  4767  8270]
Layer 18
Found SAE with l0=74 at path google/gemma-scope-2b-pt-res/layer_18/width_16k/average_l0_74/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:58<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 6673 15294  1451  5811  6186  7915 14435  4066 16207  1865  3664  6410
  6562 10989 12945  2206  5760 11194  5765 13359]
Layer 19
Found SAE with l0=73 at path google/gemma-scope-2b-pt-res/layer_19/width_16k/average_l0_73/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 4122 10838 15658  6889  6440 11835 12747 16004  8669  9992 16021  8916
 16044 10794  4375 14902  5175  7271 12385 13668]
Layer 20
Found SAE with l0=71 at path google/gemma-scope-2b-pt-res/layer_20/width_16k/average_l0_71/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:56<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 5447 14012  1518 15231  5481  8367 10589  6283  3074 14217  6344 14247
  6864 14527 13947 15367  9571  6145  6175 15432]
Layer 21
Found SAE with l0=70 at path google/gemma-scope-2b-pt-res/layer_21/width_16k/average_l0_70/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[15646  4019 13352 12272  3406 14499 15805  2989   393  6832  3563  4385
 13247  5892 10583  7881 11100  4541 11107 14971]
Layer 22
Found SAE with l0=72 at path google/gemma-scope-2b-pt-res/layer_22/width_16k/average_l0_72/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:57<00:00,  1.08it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 4923   438  5398 11754  8114  8874 13539 13894 15367  5574  5329  6933
 13121 14059 10324  3810  1950 12558  6127  2119]
Layer 23
Found SAE with l0=75 at path google/gemma-scope-2b-pt-res/layer_23/width_16k/average_l0_75/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:58<00:00,  1.07it/s]
100%|██████████| 35/35 [00:32<00:00,  1.08it/s]


[15701  3064 10889  3676  3260  4120 13658 11872 10226 11276 14980  3020
  5571 10785  7130 10887  8724  9293 10099 14289]
Layer 24
Found SAE with l0=73 at path google/gemma-scope-2b-pt-res/layer_24/width_16k/average_l0_73/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:58<00:00,  1.07it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]


[ 3867 12292 16337 10961  9233  6139  5490  5643 13830  6637  9021 10117
 13129  8925 16168  4969  3520  6690  7967 14983]
Layer 25
Found SAE with l0=116 at path google/gemma-scope-2b-pt-res/layer_25/width_16k/average_l0_116/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

100%|██████████| 256/256 [03:59<00:00,  1.07it/s]
100%|██████████| 35/35 [00:32<00:00,  1.09it/s]

[14035  3685 12093  9930  9811  3016  9501   335  1915 11512 13413  1674
 16192  6926 14798 14634  8620 14949  8850 15033]





In [13]:
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

get_neuronpedia_quick_list(top_features[:20], layer=LAYER, model='gemma-2-2b', dataset='gemmascope-res-16k')

'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%226688%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%2214468%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%22691%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%226146%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%224825%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%2215322%22%7D%2C%20%7B%22modelId%22%3A%20%22gemma-2-2b%22%2C%20%22layer%22%3A%20%221-gemmascope-res-16k%22%2C%20%22index%22%3A%20%2210572%22%7D%2C%20%7B%22modelId%22%3A%20%22g

### try ablate multilayer sae features

In [26]:
saes = {layer: load_gemma2_2b_sae(layer=layer) for layer in range(model.config.num_hidden_layers)}

top_features_all_layers = {
    layer: np.loadtxt(f'../data/top_features/gemma-2-2b-it/layer{layer}.txt', dtype=int) 
    for layer in range(model.config.num_hidden_layers)
}

Found SAE with l0=105 at path google/gemma-scope-2b-pt-res/layer_0/width_16k/average_l0_105/params.npz
Found SAE with l0=102 at path google/gemma-scope-2b-pt-res/layer_1/width_16k/average_l0_102/params.npz
Found SAE with l0=141 at path google/gemma-scope-2b-pt-res/layer_2/width_16k/average_l0_141/params.npz
Found SAE with l0=59 at path google/gemma-scope-2b-pt-res/layer_3/width_16k/average_l0_59/params.npz
Found SAE with l0=124 at path google/gemma-scope-2b-pt-res/layer_4/width_16k/average_l0_124/params.npz
Found SAE with l0=68 at path google/gemma-scope-2b-pt-res/layer_5/width_16k/average_l0_68/params.npz
Found SAE with l0=70 at path google/gemma-scope-2b-pt-res/layer_6/width_16k/average_l0_70/params.npz
Found SAE with l0=69 at path google/gemma-scope-2b-pt-res/layer_7/width_16k/average_l0_69/params.npz
Found SAE with l0=71 at path google/gemma-scope-2b-pt-res/layer_8/width_16k/average_l0_71/params.npz
Found SAE with l0=73 at path google/gemma-scope-2b-pt-res/layer_9/width_16k/average

In [29]:
top_features_all_layers[0]

array([ 3148,  7146,   611, ..., 15450,  4960, 12182])

In [47]:
from dataclasses import dataclass

@dataclass
class AblationResults:
    multiplier: int
    n_features: int
    acc: dict[str, float]
    full_metrics: dict[str, dict]

In [49]:
from contextlib import ExitStack

def run_interventions(model, saes, top_features, n_features, multiplier, datasets):
    with ExitStack() as stack:
        for layer in range(model.config.num_hidden_layers):
            stack.enter_context(scaling_intervention(
                model, layer, saes[layer], top_features[layer][:n_features], multiplier
            ))

        acc = {}
        full_metrics = {}
        for dataset_name in datasets:
            metrics = calculate_MCQ_metrics(model, dataset_name=dataset_name)
            acc[dataset_name] = metrics['mean_correct']
            full_metrics[dataset_name] = metrics
        
    print(f'{n_features} features, {multiplier}x')
    print(acc)
    return AblationResults(multiplier=multiplier, n_features=n_features, acc=acc, full_metrics=full_metrics)



In [50]:
n_features = 20
multiplier = 20
datasets = ['wmdp-bio', 'high_school_us_history', 'high_school_geography', 'college_computer_science', 'human_aging', 'college_biology']

all_results = []
for n_features in [10, 20, 50, 100]:
    for multiplier in [1, 5, 10, 20, 50, 100]:
        result = run_interventions(model, saes, top_features_all_layers, n_features, multiplier, datasets)
        all_results.append(result)


100%|██████████| 213/213 [01:30<00:00,  2.36it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.20it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.28it/s]


10 features, 1x
{'wmdp-bio': 0.604084849357605, 'high_school_us_history': 0.7352941632270813, 'high_school_geography': 0.752525269985199, 'college_computer_science': 0.3999999761581421, 'human_aging': 0.6322870254516602, 'college_biology': 0.6597222089767456}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 5x
{'wmdp-bio': 0.5027494430541992, 'high_school_us_history': 0.7107843160629272, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.4099999964237213, 'human_aging': 0.6098654866218567, 'college_biology': 0.5833333134651184}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 10x
{'wmdp-bio': 0.42576590180397034, 'high_school_us_history': 0.6813725829124451, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.3799999952316284, 'human_aging': 0.5874439477920532, 'college_biology': 0.506944477558136}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


10 features, 20x
{'wmdp-bio': 0.293794184923172, 'high_school_us_history': 0.6617647409439087, 'high_school_geography': 0.6565656661987305, 'college_computer_science': 0.35999998450279236, 'human_aging': 0.5022422075271606, 'college_biology': 0.4027777910232544}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 50x
{'wmdp-bio': 0.25058916211128235, 'high_school_us_history': 0.2549019753932953, 'high_school_geography': 0.28787878155708313, 'college_computer_science': 0.29999998211860657, 'human_aging': 0.3004484474658966, 'college_biology': 0.236111119389534}


100%|██████████| 213/213 [01:33<00:00,  2.29it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.16it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.25it/s]


10 features, 100x
{'wmdp-bio': 0.2663000822067261, 'high_school_us_history': 0.2549019753932953, 'high_school_geography': 0.17171716690063477, 'college_computer_science': 0.28999999165534973, 'human_aging': 0.29596415162086487, 'college_biology': 0.236111119389534}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 1x
{'wmdp-bio': 0.5946583151817322, 'high_school_us_history': 0.7156863212585449, 'high_school_geography': 0.752525269985199, 'college_computer_science': 0.4099999964237213, 'human_aging': 0.6233184337615967, 'college_biology': 0.6666666865348816}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 5x
{'wmdp-bio': 0.46975648403167725, 'high_school_us_history': 0.7058823704719543, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.38999998569488525, 'human_aging': 0.6188341379165649, 'college_biology': 0.5972222089767456}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 10x
{'wmdp-bio': 0.3817753493785858, 'high_school_us_history': 0.6715686321258545, 'high_school_geography': 0.7171717286109924, 'college_computer_science': 0.3499999940395355, 'human_aging': 0.5515695214271545, 'college_biology': 0.4791666567325592}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 20x
{'wmdp-bio': 0.2631579041481018, 'high_school_us_history': 0.5, 'high_school_geography': 0.5, 'college_computer_science': 0.2800000011920929, 'human_aging': 0.4439462125301361, 'college_biology': 0.2986111044883728}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.67it/s]
100%|██████████| 38/38 [00:11<00:00,  3.44it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


20 features, 50x
{'wmdp-bio': 0.25530242919921875, 'high_school_us_history': 0.2647058963775635, 'high_school_geography': 0.1818181872367859, 'college_computer_science': 0.1899999976158142, 'human_aging': 0.33183857798576355, 'college_biology': 0.2916666567325592}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.14it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.36it/s]
100%|██████████| 24/24 [00:10<00:00,  2.23it/s]


20 features, 100x
{'wmdp-bio': 0.2592301666736603, 'high_school_us_history': 0.28431373834609985, 'high_school_geography': 0.21212121844291687, 'college_computer_science': 0.23999999463558197, 'human_aging': 0.25560539960861206, 'college_biology': 0.284722238779068}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.67it/s]
100%|██████████| 38/38 [00:10<00:00,  3.49it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


50 features, 1x
{'wmdp-bio': 0.5820895433425903, 'high_school_us_history': 0.7303921580314636, 'high_school_geography': 0.7626262903213501, 'college_computer_science': 0.3700000047683716, 'human_aging': 0.6278027296066284, 'college_biology': 0.6875}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.28it/s]


50 features, 5x
{'wmdp-bio': 0.447761207818985, 'high_school_us_history': 0.6960784792900085, 'high_school_geography': 0.7121211886405945, 'college_computer_science': 0.38999998569488525, 'human_aging': 0.573991060256958, 'college_biology': 0.472222238779068}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


50 features, 10x
{'wmdp-bio': 0.329143762588501, 'high_school_us_history': 0.6225490570068359, 'high_school_geography': 0.6010100841522217, 'college_computer_science': 0.3400000035762787, 'human_aging': 0.40807175636291504, 'college_biology': 0.347222238779068}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


50 features, 20x
{'wmdp-bio': 0.2584446370601654, 'high_school_us_history': 0.25, 'high_school_geography': 0.18686868250370026, 'college_computer_science': 0.23999999463558197, 'human_aging': 0.3139013648033142, 'college_biology': 0.25}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.42it/s]
100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


50 features, 50x
{'wmdp-bio': 0.2639434337615967, 'high_school_us_history': 0.2598039209842682, 'high_school_geography': 0.2626262605190277, 'college_computer_science': 0.2800000011920929, 'human_aging': 0.2331838607788086, 'college_biology': 0.2291666716337204}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.43it/s]
100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


50 features, 100x
{'wmdp-bio': 0.24351924657821655, 'high_school_us_history': 0.2598039209842682, 'high_school_geography': 0.27272728085517883, 'college_computer_science': 0.29999998211860657, 'human_aging': 0.20179373025894165, 'college_biology': 0.2916666567325592}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


100 features, 1x
{'wmdp-bio': 0.5671641826629639, 'high_school_us_history': 0.7401961088180542, 'high_school_geography': 0.752525269985199, 'college_computer_science': 0.3999999761581421, 'human_aging': 0.6188341379165649, 'college_biology': 0.6666666865348816}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


100 features, 5x
{'wmdp-bio': 0.38098978996276855, 'high_school_us_history': 0.5686274766921997, 'high_school_geography': 0.5959596037864685, 'college_computer_science': 0.35999998450279236, 'human_aging': 0.38565024733543396, 'college_biology': 0.4375}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


100 features, 10x
{'wmdp-bio': 0.24823252856731415, 'high_school_us_history': 0.2549019753932953, 'high_school_geography': 0.2222222238779068, 'college_computer_science': 0.26999998092651367, 'human_aging': 0.3273542821407318, 'college_biology': 0.2569444477558136}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:39<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.40it/s]
100%|██████████| 24/24 [00:10<00:00,  2.25it/s]


100 features, 20x
{'wmdp-bio': 0.23644933104515076, 'high_school_us_history': 0.27941176295280457, 'high_school_geography': 0.2222222238779068, 'college_computer_science': 0.25, 'human_aging': 0.30941706895828247, 'college_biology': 0.2916666567325592}


100%|██████████| 213/213 [01:33<00:00,  2.27it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.42it/s]
100%|██████████| 24/24 [00:10<00:00,  2.25it/s]


100 features, 50x
{'wmdp-bio': 0.2702278196811676, 'high_school_us_history': 0.2647058963775635, 'high_school_geography': 0.21717171370983124, 'college_computer_science': 0.25, 'human_aging': 0.286995530128479, 'college_biology': 0.2083333283662796}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.14it/s]
100%|██████████| 17/17 [00:10<00:00,  1.67it/s]
100%|██████████| 38/38 [00:11<00:00,  3.43it/s]
100%|██████████| 24/24 [00:10<00:00,  2.25it/s]

100 features, 100x
{'wmdp-bio': 0.24666143953800201, 'high_school_us_history': 0.25, 'high_school_geography': 0.1818181872367859, 'college_computer_science': 0.25999999046325684, 'human_aging': 0.30941706895828247, 'college_biology': 0.2569444477558136}





In [37]:
baseline_results

{'wmdp-bio': 0.6355066895484924,
 'high_school_us_history': 0.7401961088180542,
 'high_school_geography': 0.7575757503509521,
 'college_computer_science': 0.44999998807907104,
 'human_aging': 0.6322870254516602,
 'college_biology': 0.7083333134651184}

In [35]:
results # 20 features each layer, zero ablate

{'wmdp-bio': 0.6245090365409851,
 'high_school_us_history': 0.7450980544090271,
 'high_school_geography': 0.752525269985199,
 'college_computer_science': 0.44999998807907104,
 'human_aging': 0.6233184337615967,
 'college_biology': 0.6875}

In [39]:
zero_ablate_results # zero ablate 40 features

{'wmdp-bio': 0.628436803817749,
 'high_school_us_history': 0.7352941632270813,
 'high_school_geography': 0.7575757503509521,
 'college_computer_science': 0.4399999976158142,
 'human_aging': 0.6367713212966919,
 'college_biology': 0.7013888955116272}

In [43]:
zero_ablate_results_1000 # zero ablate 1000 features

{'wmdp-bio': 0.6111547350883484,
 'high_school_us_history': 0.7401961088180542,
 'high_school_geography': 0.6616161465644836,
 'college_computer_science': 0.3999999761581421,
 'human_aging': 0.573991060256958,
 'college_biology': 0.6527777910232544}

In [45]:
zero_ablate_results_5000 # zero ablate 1000 features

{'wmdp-bio': 0.2820110023021698,
 'high_school_us_history': 0.30882352590560913,
 'high_school_geography': 0.2626262605190277,
 'college_computer_science': 0.25,
 'human_aging': 0.31838566064834595,
 'college_biology': 0.2708333432674408}

In [33]:
results # 20 features each layer, multiplier 20

{'wmdp-bio': 0.3244304955005646,
 'high_school_us_history': 0.7352941632270813,
 'high_school_geography': 0.7272727489471436,
 'college_computer_science': 0.41999998688697815,
 'human_aging': 0.5291479825973511,
 'college_biology': 0.4513888955116272}

In [18]:


# layer = 3
# features_to_ablate = list(top_features[:80])
# multiplier = 20

n_features = 20
multiplier = 20
with scaling_intervention(model, 1, saes[1], top_features[1][:n_features], multiplier):
    with scaling_intervention(model, 2, saes[2], top_features[2][:n_features], multiplier):
        with scaling_intervention(model, 3, saes[3], top_features[3][:n_features], multiplier):

            intervened_metrics = calculate_MCQ_metrics(model)
            intervened_history_metrics = calculate_MCQ_metrics(model, dataset_name='high_school_us_history')
            interved_human_aging_metrics = calculate_MCQ_metrics(model, dataset_name='human_aging')
            intervened_college_bio_metrics = calculate_MCQ_metrics(model, dataset_name='college_biology')
            
            print(f"n_features: {n_features}, multiplier: {multiplier}")
            print(f"\t\twmdp-bio: {intervened_metrics['mean_correct']}")
            print(f"\t\thigh_school_us_history: {intervened_history_metrics['mean_correct']}")
            print(f"\t\thuman_aging: {interved_human_aging_metrics['mean_correct']}")
            print(f"\t\tcollege_bio: {intervened_college_bio_metrics['mean_correct']}")
    

Downloading readme:   0%|          | 0.00/4.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/258k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

100%|██████████| 213/213 [00:26<00:00,  7.93it/s]


Downloading readme:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/138k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.8k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 34/34 [00:10<00:00,  3.20it/s]


Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 38/38 [00:03<00:00, 10.18it/s]


Downloading data:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 24/24 [00:03<00:00,  7.87it/s]

n_features: 20, multiplier: 20
		wmdp-bio: 0.3244304955005646
		high_school_us_history: 0.7352941632270813
		human_aging: 0.5291479825973511
		college_bio: 0.4513888955116272





In [16]:
top_features[1].shape

(8820,)