In [1]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
from transformer_lens import HookedTransformer

from unlearning.tool import get_hf_model
from unlearning.feature_activation import get_forget_retain_data, tokenize_dataset, get_feature_activation_sparsity, get_top_features
from unlearning.jump_relu import load_gemma2_2b_sae
from unlearning.intervention import scaling_intervention
from unlearning.metrics import calculate_MCQ_metrics, get_loss_added_hf, create_df_from_metrics, generate_ablate_params_list

In [10]:
# Layer whose SAE is loaded; the same index selects the per-layer sparsity files read later.
layer = 13
sae = load_gemma2_2b_sae(layer=layer)

# Chat-tuned Gemma-2 2B wrapped as a HookedTransformer so hooks can be attached for interventions.
# NOTE(review): the SAE path reported by the loader is under `gemma-scope-2b-pt-res` (trained on
# the base `-pt` model) while this is the `-it` model — confirm the transfer is intended.
model_name = 'google/gemma-2-2b-it'
model = HookedTransformer.from_pretrained(model_name)

Found SAE with l0=84 at path google/gemma-scope-2b-pt-res/layer_13/width_16k/average_l0_84/params.npz




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [11]:
from sae.activation_store import ActivationsStore

# Override SAE config fields (presumably consumed by ActivationsStore — TODO confirm)
# before building the store. NOTE: this mutates the shared `sae.cfg` object in place.
cfg_overrides = {
    'dataset': "Skylion007/openwebtext",
    'n_batches_in_store_buffer': 8,
}
for cfg_field, cfg_value in cfg_overrides.items():
    setattr(sae.cfg, cfg_field, cfg_value)

# create_dataloader=False: no dataloader is requested at construction time.
activation_store = ActivationsStore(sae.cfg, model, create_dataloader=False)

In [12]:
# Quick run of get_loss_added_hf on the activation store before the full metrics sweep.
# NOTE(review): `ret` is not used by any later visible cell — either display it (end the cell
# with `ret`) or drop the assignment. TODO confirm what get_loss_added_hf returns.
ret = get_loss_added_hf(model, activation_store)

100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


In [13]:
# NOTE(review): dead cell — it loads a precomputed feature list for layer 3, while this notebook
# works on `layer` (13). The f-prefix below has no placeholders. Top features are derived from the
# sparsity files further down instead; consider deleting this cell.
# top_features = np.loadtxt(f'../data/top_features/gemma-2-2b-it-sparsity/layer3.txt', dtype=int)
# top_features

In [14]:
# Precomputed per-feature mean activations for this layer on the forget and retain sets.
forget_path = f'../data/top_features/gemma-2-2b-it-sparsity/layer{layer}_mean_feature_activation_forget.txt'
retain_path = f'../data/top_features/gemma-2-2b-it-sparsity/layer{layer}_mean_feature_activation_retain.txt'
forget_sparsity, retain_sparsity = (
    np.loadtxt(sparsity_path, dtype=float) for sparsity_path in (forget_path, retain_path)
)


In [15]:
# Select candidate features to ablate from the forget/retain mean activations.
# get_top_features is already imported in the first cell, so it is not re-imported here.
# retain_threshold=0.01 is hard-coded; keep it in sync with the threshold used in the sweep below.
top_features = get_top_features(forget_sparsity, retain_sparsity, retain_threshold=0.01)

[ 8228  2807  3285  4618 11585 11761  5469  2644  1894 15250  9404 12023
  2570  4470  3372  1287 13018  1861  2292 15030]


In [16]:
# How many features pass retain_threshold=0.01 (the SAE is 16k wide per its loaded path);
# only the first 10/20/50 selected features are ablated in the sweep below.
len(top_features)

12994

In [17]:
# Evaluation targets passed to calculate_metrics_list. 'loss_added' is not an MCQ dataset
# (it gets no `_prob` column in the results table); the others are MCQ datasets —
# presumably wmdp-bio as the unlearning target and the MMLU subjects as retain checks.
all_dataset_names = [
    'loss_added',
    'wmdp-bio',
    'high_school_us_history',
    'college_computer_science',
    'high_school_geography',
    'human_aging',
    'college_biology',
]


In [18]:
# Calculate metrics
from unlearning.metrics import calculate_metrics_list

# Sweep over retain-sparsity thresholds. For each threshold, ablate the first 10/20/50 selected
# features at several multipliers using the 'clamp_feature_activation' intervention.
retain_thresholds = [0.01]  # [0.001, 0.01]
n_features_options = (10, 20, 50)

# Collect results per threshold. Otherwise, sweeping several thresholds would leave only the
# last run visible to later cells, because `metrics_list` is reassigned on every iteration.
metrics_by_threshold = {}
for retain_threshold in retain_thresholds:
    top_features_custom = get_top_features(forget_sparsity, retain_sparsity, retain_threshold=retain_threshold)

    main_ablate_params = {
        'intervention_method': 'clamp_feature_activation',
    }

    # The recorded run reports 15 evaluations for 3 feature sets x 5 multipliers,
    # presumably a full grid over these lists — TODO confirm in calculate_metrics_list.
    sweep = {
        'features_to_ablate': [np.array(top_features_custom[:n_features]) for n_features in n_features_options],
        'multiplier': [1, 5, 10, 50, 100],
    }

    metrics_list = calculate_metrics_list(
        model,
        sae,
        main_ablate_params,
        sweep,
        all_dataset_names,
        n_batch_loss_added=50,
        activation_store=activation_store,
        target_metric='correct',
        save_metrics=True,
        notes=f'_sparsity_thres{retain_threshold}',
    )
    metrics_by_threshold[retain_threshold] = metrics_list

# `metrics_list` still holds the last threshold's results for the display cell below;
# use `metrics_by_threshold` to compare thresholds side by side.
    


[ 8228  2807  3285  4618 11585 11761  5469  2644  1894 15250  9404 12023
  2570  4470  3372  1287 13018  1861  2292 15030]


100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 87/87 [00:27<00:00,  3.15it/s]
100%|██████████| 87/87 [00:27<00:00,  3.17it/s]
100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 87/87 [00:27<00:00,  3.14it/s]
100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]]
100%|██████████| 87/87 [00:28<00:00,  3.10it/s]]
100%|██████████| 87/87 [00:28<00:00,  3.10it/s]]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]]
100%|██████████| 87/87 [00:27<00:00,  3.14it/s]]
100%|██████████| 15/15 [49:51<00:00, 199.43s/it]


In [19]:
# Tabulate the sweep results. In the recorded output the columns are loss_added, one accuracy
# column per MCQ dataset, and a matching `<dataset>_prob` column.
# NOTE(review): the recorded table shows 10 rows (0-9) while the sweep reported 15/15 runs —
# confirm whether rows are dropped or the saved output was truncated.
create_df_from_metrics(metrics_list)

Unnamed: 0,loss_added,wmdp-bio,high_school_us_history,college_computer_science,high_school_geography,human_aging,college_biology,wmdp-bio_prob,high_school_us_history_prob,college_computer_science_prob,high_school_geography_prob,human_aging_prob,college_biology_prob
0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,0.828125,0.904785,0.90918,0.890625,0.874023,0.917969
1,-0.000515,1.0,1.0,1.0,1.0,1.0,1.0,0.833008,0.905273,0.910156,0.891602,0.875,0.919434
2,-0.000714,1.0,1.0,1.0,1.0,1.0,1.0,0.834473,0.905762,0.910156,0.89209,0.875,0.919922
3,-0.000906,1.0,1.0,1.0,1.0,1.0,1.0,0.835938,0.905762,0.910645,0.89209,0.874512,0.920898
4,-0.000647,0.994253,1.0,1.0,1.0,0.987654,1.0,0.839844,0.906738,0.912598,0.894531,0.873535,0.923828
5,0.002941,0.975096,1.0,1.0,1.0,0.987654,0.986301,0.844238,0.90918,0.913574,0.897461,0.873047,0.92627
6,-0.000909,1.0,1.0,1.0,1.0,1.0,1.0,0.831543,0.905273,0.908203,0.89209,0.875488,0.92041
7,-0.000994,1.0,1.0,1.0,1.0,1.0,1.0,0.83252,0.905273,0.907227,0.892578,0.875488,0.921387
8,-0.000946,1.0,1.0,1.0,1.0,1.0,1.0,0.833008,0.905273,0.905762,0.892578,0.875,0.922363
9,0.002775,0.994253,1.0,1.0,1.0,1.0,1.0,0.822266,0.905273,0.875977,0.894531,0.874512,0.924805
