In [5]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np
from transformer_lens import HookedTransformer

from unlearning.tool import get_hf_model
from unlearning.feature_activation import get_forget_retain_data, tokenize_dataset, get_feature_activation_sparsity, get_top_features
from unlearning.jump_relu import load_gemma2_2b_sae
from unlearning.intervention import scaling_intervention
from unlearning.metrics import calculate_metrics_list

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import pickle

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Load the Gemma Scope residual-stream SAE for `layer` into the global `sae`,
# which `get_result_metrics` reads implicitly.
# NOTE(review): this cell ran as In[21], i.e. *after* the layer-3 sweeps
# (In[12]/In[13]); those appear to have used a layer-3 SAE loaded by an earlier
# version of the model cell. Under Restart & Run All the layer-3 sweeps below
# would use this layer-7 SAE instead — TODO confirm and reorder.
layer = 7
sae = load_gemma2_2b_sae(
    layer=layer,
)

Found SAE with l0=69 at path google/gemma-scope-2b-pt-res/layer_7/width_16k/average_l0_69/params.npz


params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [2]:
# Instruction-tuned Gemma-2 2B, wrapped as a HookedTransformer.
# NOTE(review): the saved output of this cell also reports finding a layer-3 SAE,
# which this code does not load — the cell was presumably edited after it ran.
MODEL_NAME = 'google/gemma-2-2b-it'
model = HookedTransformer.from_pretrained(MODEL_NAME)

Found SAE with l0=59 at path google/gemma-scope-2b-pt-res/layer_3/width_16k/average_l0_59/params.npz




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [3]:
# Feature indices read from the saved layer-3 ranking file (assumed ordered
# best-first — confirm against the script that wrote it). These index the
# layer-3 SAE, so the global `sae` must be a layer-3 SAE when they are swept.
top_features_layer = 3
top_features = np.loadtxt(
    f'../data/top_features/gemma-2-2b-it-sparsity/layer{top_features_layer}.txt',
    dtype=int,
)
top_ten_features = top_features[:10]
top_ten_features

array([ 8786,  3025, 11913, 14227,   679,  1082, 10793,  6691,  8803,
        8934])

In [22]:
# Calculate metrics

def get_result_metrics(feature_idx, multipliers=None, random=False, *, model=None, sae=None):
    """Run a clamp-intervention sweep on WMDP-bio and return the metrics.

    Parameters
    ----------
    feature_idx : sequence of int
        SAE feature indices to intervene on; forwarded unchanged as
        ``main_ablate_params['features_to_ablate']``.
    multipliers : list, optional
        Values swept under the ``'multiplier'`` key. Defaults to
        ``[1, 5, 10, 50, 100]``. A fresh list is built on every call, so a
        mutation by the callee cannot leak into later calls (a list literal
        as a default would be shared across calls).
    random : bool, default False
        Use ``'clamp_feature_activation_random'`` instead of
        ``'clamp_feature_activation'`` as the intervention method.
    model, sae : optional, keyword-only
        Model and SAE to evaluate. When omitted, the notebook globals
        ``model`` and ``sae`` are used (the previous behaviour). Pass them
        explicitly to avoid depending on whichever cell last assigned ``sae``.

    Returns
    -------
    Whatever ``calculate_metrics_list`` returns. In this notebook it is
    indexed per sweep run (one per multiplier), then by dataset name
    (``'wmdp-bio'``).
    """
    if multipliers is None:
        multipliers = [1, 5, 10, 50, 100]
    # The parameters shadow the notebook globals of the same name, so fall back
    # to the module namespace explicitly.
    if model is None:
        model = globals()['model']
    if sae is None:
        sae = globals()['sae']

    intervention_method = 'clamp_feature_activation_random' if random else 'clamp_feature_activation'
    main_ablate_params = {
        'intervention_method': intervention_method,
        'features_to_ablate': feature_idx,
    }
    sweep = {
        'multiplier': multipliers,
    }
    dataset_names = ['wmdp-bio']

    return calculate_metrics_list(
        model,
        sae,
        main_ablate_params,
        sweep,
        dataset_names=dataset_names,
        include_baseline_metrics=False,
        split='all',
        verbose=False,
    )


In [11]:
def get_unleared_questions(metrics, dataset_name='wmdp-bio'):
    """Print and return how many questions each sweep run answered incorrectly.

    Parameters
    ----------
    metrics : iterable of dict
        One entry per sweep run; each entry maps ``dataset_name`` to a dict
        with ``'total_correct'`` (count) and ``'is_correct'`` (one item per
        question).
    dataset_name : str, default 'wmdp-bio'
        Which dataset's results to read from each entry.

    Returns
    -------
    list
        ``len(is_correct) - total_correct`` for each run, in input order. The
        list is also printed, as before, so existing loop cells keep their output.
    """
    n_unlearned_per_run = []
    for run_metrics in metrics:
        dataset_metrics = run_metrics[dataset_name]
        n_questions = len(dataset_metrics['is_correct'])
        n_unlearned_per_run.append(n_questions - dataset_metrics['total_correct'])

    print(n_unlearned_per_run)
    return n_unlearned_per_run


# Correctly spelled alias; the misspelled name is kept for existing callers.
get_unlearned_questions = get_unleared_questions

In [12]:
# Sweep each layer-3 top feature on its own with the standard clamp, keeping the
# per-feature metrics in the same order as `top_ten_features`.
all_metrics = []
for feature in top_ten_features:
    metrics = get_result_metrics([feature], random=False)
    all_metrics += [metrics]
    print(feature)
    get_unleared_questions(metrics)

100%|██████████| 87/87 [00:26<00:00,  3.24it/s]
100%|██████████| 87/87 [00:27<00:00,  3.21it/s]
100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 87/87 [00:27<00:00,  3.14it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:20<00:00, 28.15s/it]


8786
[0, 0, 0, 0, 3]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.64s/it]


3025
[0, 0, 0, 34, 121]


100%|██████████| 87/87 [00:28<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.62s/it]


11913
[1, 3, 9, 81, 174]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.63s/it]


14227
[0, 4, 5, 82, 161]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.62s/it]


679
[1, 1, 1, 10, 37]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.74s/it]


1082
[0, 0, 1, 2, 13]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:22<00:00, 28.55s/it]


10793
[0, 0, 0, 7, 45]


100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:28<00:00,  3.10it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:23<00:00, 28.66s/it]


6691
[0, 0, 0, 10, 10]


100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:22<00:00, 28.54s/it]


8803
[0, 0, 0, 1, 6]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:24<00:00, 28.81s/it]

8934
[0, 0, 0, 0, 4]





In [13]:
# Same per-feature sweep, using the `random=True` intervention method
# ('clamp_feature_activation_random'); results kept in `top_ten_features` order.
all_metrics_random = []
for feature in top_ten_features:
    metrics = get_result_metrics([feature], random=True)
    all_metrics_random += [metrics]
    print(feature)
    get_unleared_questions(metrics)

100%|██████████| 87/87 [00:27<00:00,  3.20it/s]
100%|██████████| 87/87 [00:28<00:00,  3.10it/s]
100%|██████████| 87/87 [00:28<00:00,  3.09it/s]
100%|██████████| 87/87 [00:28<00:00,  3.10it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 5/5 [02:22<00:00, 28.59s/it]


8786
[0, 0, 0, 0, 5]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:22<00:00, 28.57s/it]


3025
[0, 0, 0, 1, 38]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:22<00:00, 28.60s/it]


11913
[0, 0, 0, 21, 98]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.14it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 5/5 [02:22<00:00, 28.57s/it]


14227
[0, 0, 0, 22, 97]


100%|██████████| 87/87 [00:28<00:00,  3.05it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.78s/it]


679
[0, 0, 0, 0, 7]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.63s/it]


1082
[0, 0, 0, 1, 5]


100%|██████████| 87/87 [00:28<00:00,  3.03it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.74s/it]


10793
[0, 0, 0, 1, 12]


100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:28<00:00,  3.06it/s]
100%|██████████| 87/87 [00:28<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.69s/it]


6691
[0, 0, 0, 2, 14]


100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.13it/s]
100%|██████████| 5/5 [02:22<00:00, 28.54s/it]


8803
[0, 0, 0, 1, 5]


100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.11it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 87/87 [00:27<00:00,  3.12it/s]
100%|██████████| 5/5 [02:23<00:00, 28.67s/it]

8934
[0, 0, 0, 1, 2]





In [14]:
# Number of WMDP-bio questions scored per run. Read it from the stored sweep
# results rather than the `metrics` loop variable leaked by the cell above.
len(all_metrics_random[-1][0]['wmdp-bio']['is_correct'])

522


In [19]:
from unlearning.feature_activation import get_top_features

layer = 7
sparsity_dir = '../data/top_features/gemma-2-2b-it-sparsity'
forget_sparsity = np.loadtxt(f'{sparsity_dir}/layer{layer}_mean_feature_activation_forget.txt', dtype=float)
retain_sparsity = np.loadtxt(f'{sparsity_dir}/layer{layer}_mean_feature_activation_retain.txt', dtype=float)

top_features = get_top_features(forget_sparsity, retain_sparsity, retain_threshold=0.01)
layer_7_top_ten_features = top_features[:10]

# `get_result_metrics` reads the global `sae`. Load the SAE that matches these
# layer-7 feature indices here, so this section does not depend on whichever
# cell last assigned `sae` (earlier cells were run out of order).
sae = load_gemma2_2b_sae(layer=layer)

[12799  5453 15661  1354  4216  1017  3828 15939 15901  8122  4945  1967
  9723 13741 14822 15879  2512  5905 10566  7458]


In [23]:
# Sweep all ten layer-7 features together at the two largest multipliers, once
# with the standard clamp and once with the `random=True` method. Each call
# gets its own fresh list, exactly as with the original inline literals.
LAYER7_MULTIPLIERS = (50, 100)
base_layer7_metrics = get_result_metrics(
    layer_7_top_ten_features, multipliers=list(LAYER7_MULTIPLIERS), random=False
)
random_layer7_metrics = get_result_metrics(
    layer_7_top_ten_features, multipliers=list(LAYER7_MULTIPLIERS), random=True
)

100%|██████████| 87/87 [00:26<00:00,  3.29it/s]
100%|██████████| 87/87 [00:26<00:00,  3.24it/s]
100%|██████████| 2/2 [00:54<00:00, 27.37s/it]
100%|██████████| 87/87 [00:27<00:00,  3.19it/s]
100%|██████████| 87/87 [00:27<00:00,  3.16it/s]
100%|██████████| 2/2 [00:56<00:00, 28.24s/it]


In [24]:
# Print incorrect-answer counts per multiplier: standard clamp first, then random.
for layer7_metrics in (base_layer7_metrics, random_layer7_metrics):
    get_unleared_questions(layer7_metrics)

[115, 214]
[42, 148]


In [25]:
# Raw fraction of WMDP-bio questions answered incorrectly after clamping the ten
# layer-7 features at the largest multiplier (last entry of the sweep). This is
# computed from the metrics instead of hard-coding 214/522. It is not
# baseline-adjusted: include_baseline_metrics=False in get_result_metrics.
bio_at_max_multiplier = base_layer7_metrics[-1]['wmdp-bio']
n_bio_questions = len(bio_at_max_multiplier['is_correct'])
float(n_bio_questions - bio_at_max_multiplier['total_correct']) / n_bio_questions

0.4099616858237548