In [16]:
%load_ext autoreload
%autoreload 2
import torch
import numpy as np

from unlearning.tool import get_hf_model
from unlearning.feature_activation import get_forget_retain_data, tokenize_dataset, get_mean_feature_activation, get_top_features
from unlearning.jump_relu import load_gemma2_2b_sae
from unlearning.intervention import intervention
from unlearning.metrics import calculate_MCQ_metrics

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Load the `google/gemma-2-2b-it` checkpoint through the project's `get_hf_model` helper.
# Every evaluation cell below runs against this one shared `model` object.
model = get_hf_model('google/gemma-2-2b-it')

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.9k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

### Insert a random vector if the top features fire

In [4]:
# All-layer variant, currently disabled.
# NOTE(review): the sweep cell further down passes `saes` and `top_features_all_layers`
# to `run_interventions`, and this commented-out block is the only place among the cells
# shown here that would define them -- that cell raises NameError on a fresh kernel
# unless this block is re-enabled.
# saes = {layer: load_gemma2_2b_sae(layer=layer) for layer in range(model.config.num_hidden_layers)}

# top_features_all_layers = {
#     layer: np.loadtxt(f'../data/top_features/gemma-2-2b-it/layer{layer}.txt', dtype=int)
#     for layer in range(model.config.num_hidden_layers)
# }

# Single-layer setup used by the two clamp cells that follow: one SAE and one
# feature-index file, both selected by `LAYER`.
LAYER = 3
sae = load_gemma2_2b_sae(layer=LAYER)
# Integer feature indices read from a per-layer text file; later cells take the first
# `n_features` entries. Presumably ranked most-relevant first -- TODO confirm against
# whatever wrote the file.
top_features = np.loadtxt(f'../data/top_features/gemma-2-2b-it/layer{LAYER}.txt', dtype=int)

Found SAE with l0=59 at path google/gemma-scope-2b-pt-res/layer_3/width_16k/average_l0_59/params.npz


In [5]:
from dataclasses import dataclass

@dataclass
class AblationResults:
    """Record of one evaluation run under a fixed intervention setting.

    Instances are built by `run_interventions`, which fills the fields as follows:

    Attributes:
        multiplier: The multiplier value that was passed to the intervention.
        n_features: How many leading entries of each feature list were used.
        acc: Dataset name -> the `mean_correct` entry of that dataset's metrics.
        full_metrics: Dataset name -> the complete metrics dict returned by
            `calculate_MCQ_metrics` for that dataset.
    """
    multiplier: int
    n_features: int
    acc: dict[str, float]
    full_metrics: dict[str, dict]

In [20]:
# Evaluate the model on every benchmark while the `clamp` intervention is active on
# the first `n_features` entries of `top_features` at layer `LAYER`.
# This is the `random=False` arm; the next cell repeats it with `random=True`.
n_features = 20
multiplier = 100
datasets = [
    'wmdp-bio',
    'high_school_us_history',
    'high_school_geography',
    'college_computer_science',
    'human_aging',
    'college_biology',
]
acc, full_metrics = {}, {}

with intervention(
    model,
    LAYER,
    sae,
    top_features[:n_features],
    multiplier=multiplier,
    intervention_type='clamp',
    random=False,
):
    # All scoring happens inside the context so the intervention is in effect
    # for every forward pass.
    for dataset_name in datasets:
        metrics = calculate_MCQ_metrics(model, dataset_name=dataset_name)
        acc[dataset_name] = metrics['mean_correct']
        full_metrics[dataset_name] = metrics

    print(acc)

100%|██████████| 213/213 [00:20<00:00, 10.17it/s]
100%|██████████| 34/34 [00:08<00:00,  4.15it/s]
100%|██████████| 33/33 [00:02<00:00, 11.71it/s]
100%|██████████| 17/17 [00:02<00:00,  7.92it/s]
100%|██████████| 38/38 [00:03<00:00, 12.09it/s]
100%|██████████| 24/24 [00:02<00:00,  9.83it/s]


{'wmdp-bio': 0.46975648403167725, 'high_school_us_history': 0.7598039507865906, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.4599999785423279, 'human_aging': 0.5695067644119263, 'college_biology': 0.5416666865348816}


In [22]:
# Same evaluation as the previous cell, but with `random=True` passed to the
# `clamp` intervention, so the two printed accuracy dicts can be compared.
n_features = 20
multiplier = 100
datasets = [
    'wmdp-bio',
    'high_school_us_history',
    'high_school_geography',
    'college_computer_science',
    'human_aging',
    'college_biology',
]
acc, full_metrics = {}, {}

with intervention(
    model,
    LAYER,
    sae,
    top_features[:n_features],
    multiplier=multiplier,
    intervention_type='clamp',
    random=True,
):
    # All scoring happens inside the context so the intervention is in effect
    # for every forward pass.
    for dataset_name in datasets:
        metrics = calculate_MCQ_metrics(model, dataset_name=dataset_name)
        acc[dataset_name] = metrics['mean_correct']
        full_metrics[dataset_name] = metrics

    print(acc)

100%|██████████| 213/213 [00:22<00:00,  9.51it/s]
100%|██████████| 34/34 [00:08<00:00,  4.16it/s]
100%|██████████| 33/33 [00:02<00:00, 11.59it/s]
100%|██████████| 17/17 [00:02<00:00,  7.98it/s]
100%|██████████| 38/38 [00:03<00:00, 11.71it/s]
100%|██████████| 24/24 [00:02<00:00, 10.15it/s]

{'wmdp-bio': 0.542026698589325, 'high_school_us_history': 0.7303921580314636, 'high_school_geography': 0.7626262903213501, 'college_computer_science': 0.4399999976158142, 'human_aging': 0.5829596519470215, 'college_biology': 0.6388888955116272}





In [49]:
from contextlib import ExitStack

def run_interventions(model, saes, top_features, n_features, multiplier, datasets):
    """Score `model` on `datasets` while a scaling intervention is active on every layer.

    Args:
        model: Model under test; `model.config.num_hidden_layers` decides how many
            layers get an intervention.
        saes: Indexable by layer number; `saes[layer]` is handed to
            `scaling_intervention`.
        top_features: Indexable by layer number; the first `n_features` entries of
            `top_features[layer]` are handed to `scaling_intervention`.
        n_features: Number of leading feature entries used per layer.
        multiplier: Forwarded unchanged to every `scaling_intervention` call.
        datasets: Dataset names, each passed to `calculate_MCQ_metrics`.

    Returns:
        An `AblationResults` carrying `multiplier`, `n_features`, the per-dataset
        `mean_correct` values (`acc`) and the full per-dataset metrics dicts.

    NOTE(review): `scaling_intervention` is not imported in the cells shown in this
    notebook (only `intervention` is) -- confirm where it is supposed to come from.
    """
    layer_count = model.config.num_hidden_layers
    metrics_by_dataset = {}

    with ExitStack() as active:
        # One context per layer, entered in ascending order; ExitStack closes them
        # all again (last entered first) when the block ends.
        for layer_idx in range(layer_count):
            layer_ctx = scaling_intervention(
                model,
                layer_idx,
                saes[layer_idx],
                top_features[layer_idx][:n_features],
                multiplier,
            )
            active.enter_context(layer_ctx)

        # Evaluate only while all interventions are in place.
        for name in datasets:
            metrics_by_dataset[name] = calculate_MCQ_metrics(model, dataset_name=name)

    acc = {name: found['mean_correct'] for name, found in metrics_by_dataset.items()}

    print(f'{n_features} features, {multiplier}x')
    print(acc)
    return AblationResults(
        multiplier=multiplier,
        n_features=n_features,
        acc=acc,
        full_metrics=metrics_by_dataset,
    )



In [50]:
# Grid sweep: evaluate every (n_features, multiplier) pair with the all-layer
# intervention from `run_interventions` and collect one AblationResults per pair.
# The former `n_features = 20` / `multiplier = 20` assignments were dropped: the loop
# targets below rebind both names before anything reads them.
# NOTE(review): `saes` and `top_features_all_layers` are only defined in the
# commented-out block of the SAE-loading cell -- re-enable it before running this.
datasets = ['wmdp-bio', 'high_school_us_history', 'high_school_geography', 'college_computer_science', 'human_aging', 'college_biology']

all_results = []
for n_features in [10, 20, 50, 100]:
    for multiplier in [1, 5, 10, 20, 50, 100]:
        result = run_interventions(model, saes, top_features_all_layers, n_features, multiplier, datasets)
        # Append per setting so an interrupted sweep still holds the finished runs.
        all_results.append(result)


100%|██████████| 213/213 [01:30<00:00,  2.36it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.20it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.28it/s]


10 features, 1x
{'wmdp-bio': 0.604084849357605, 'high_school_us_history': 0.7352941632270813, 'high_school_geography': 0.752525269985199, 'college_computer_science': 0.3999999761581421, 'human_aging': 0.6322870254516602, 'college_biology': 0.6597222089767456}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 5x
{'wmdp-bio': 0.5027494430541992, 'high_school_us_history': 0.7107843160629272, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.4099999964237213, 'human_aging': 0.6098654866218567, 'college_biology': 0.5833333134651184}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 10x
{'wmdp-bio': 0.42576590180397034, 'high_school_us_history': 0.6813725829124451, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.3799999952316284, 'human_aging': 0.5874439477920532, 'college_biology': 0.506944477558136}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


10 features, 20x
{'wmdp-bio': 0.293794184923172, 'high_school_us_history': 0.6617647409439087, 'high_school_geography': 0.6565656661987305, 'college_computer_science': 0.35999998450279236, 'human_aging': 0.5022422075271606, 'college_biology': 0.4027777910232544}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


10 features, 50x
{'wmdp-bio': 0.25058916211128235, 'high_school_us_history': 0.2549019753932953, 'high_school_geography': 0.28787878155708313, 'college_computer_science': 0.29999998211860657, 'human_aging': 0.3004484474658966, 'college_biology': 0.236111119389534}


100%|██████████| 213/213 [01:33<00:00,  2.29it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.16it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.25it/s]


10 features, 100x
{'wmdp-bio': 0.2663000822067261, 'high_school_us_history': 0.2549019753932953, 'high_school_geography': 0.17171716690063477, 'college_computer_science': 0.28999999165534973, 'human_aging': 0.29596415162086487, 'college_biology': 0.236111119389534}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 1x
{'wmdp-bio': 0.5946583151817322, 'high_school_us_history': 0.7156863212585449, 'high_school_geography': 0.752525269985199, 'college_computer_science': 0.4099999964237213, 'human_aging': 0.6233184337615967, 'college_biology': 0.6666666865348816}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.17it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 5x
{'wmdp-bio': 0.46975648403167725, 'high_school_us_history': 0.7058823704719543, 'high_school_geography': 0.7424242496490479, 'college_computer_science': 0.38999998569488525, 'human_aging': 0.6188341379165649, 'college_biology': 0.5972222089767456}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 10x
{'wmdp-bio': 0.3817753493785858, 'high_school_us_history': 0.6715686321258545, 'high_school_geography': 0.7171717286109924, 'college_computer_science': 0.3499999940395355, 'human_aging': 0.5515695214271545, 'college_biology': 0.4791666567325592}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


20 features, 20x
{'wmdp-bio': 0.2631579041481018, 'high_school_us_history': 0.5, 'high_school_geography': 0.5, 'college_computer_science': 0.2800000011920929, 'human_aging': 0.4439462125301361, 'college_biology': 0.2986111044883728}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.67it/s]
100%|██████████| 38/38 [00:11<00:00,  3.44it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


20 features, 50x
{'wmdp-bio': 0.25530242919921875, 'high_school_us_history': 0.2647058963775635, 'high_school_geography': 0.1818181872367859, 'college_computer_science': 0.1899999976158142, 'human_aging': 0.33183857798576355, 'college_biology': 0.2916666567325592}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.14it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.36it/s]
100%|██████████| 24/24 [00:10<00:00,  2.23it/s]


20 features, 100x
{'wmdp-bio': 0.2592301666736603, 'high_school_us_history': 0.28431373834609985, 'high_school_geography': 0.21212121844291687, 'college_computer_science': 0.23999999463558197, 'human_aging': 0.25560539960861206, 'college_biology': 0.284722238779068}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.67it/s]
100%|██████████| 38/38 [00:10<00:00,  3.49it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


50 features, 1x
{'wmdp-bio': 0.5820895433425903, 'high_school_us_history': 0.7303921580314636, 'high_school_geography': 0.7626262903213501, 'college_computer_science': 0.3700000047683716, 'human_aging': 0.6278027296066284, 'college_biology': 0.6875}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.19it/s]
100%|██████████| 17/17 [00:10<00:00,  1.69it/s]
100%|██████████| 38/38 [00:10<00:00,  3.46it/s]
100%|██████████| 24/24 [00:10<00:00,  2.28it/s]


50 features, 5x
{'wmdp-bio': 0.447761207818985, 'high_school_us_history': 0.6960784792900085, 'high_school_geography': 0.7121211886405945, 'college_computer_science': 0.38999998569488525, 'human_aging': 0.573991060256958, 'college_biology': 0.472222238779068}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:10<00:00,  3.47it/s]
100%|██████████| 24/24 [00:10<00:00,  2.27it/s]


50 features, 10x
{'wmdp-bio': 0.329143762588501, 'high_school_us_history': 0.6225490570068359, 'high_school_geography': 0.6010100841522217, 'college_computer_science': 0.3400000035762787, 'human_aging': 0.40807175636291504, 'college_biology': 0.347222238779068}


100%|██████████| 213/213 [01:32<00:00,  2.30it/s]
100%|██████████| 34/34 [00:39<00:00,  1.17s/it]
100%|██████████| 33/33 [00:10<00:00,  3.18it/s]
100%|██████████| 17/17 [00:10<00:00,  1.68it/s]
100%|██████████| 38/38 [00:11<00:00,  3.45it/s]
100%|██████████| 24/24 [00:10<00:00,  2.26it/s]


50 features, 20x
{'wmdp-bio': 0.2584446370601654, 'high_school_us_history': 0.25, 'high_school_geography': 0.18686868250370026, 'college_computer_science': 0.23999999463558197, 'human_aging': 0.3139013648033142, 'college_biology': 0.25}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.42it/s]
100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


50 features, 50x
{'wmdp-bio': 0.2639434337615967, 'high_school_us_history': 0.2598039209842682, 'high_school_geography': 0.2626262605190277, 'college_computer_science': 0.2800000011920929, 'human_aging': 0.2331838607788086, 'college_biology': 0.2291666716337204}


100%|██████████| 213/213 [01:33<00:00,  2.28it/s]
100%|██████████| 34/34 [00:40<00:00,  1.18s/it]
100%|██████████| 33/33 [00:10<00:00,  3.15it/s]
100%|██████████| 17/17 [00:10<00:00,  1.66it/s]
100%|██████████| 38/38 [00:11<00:00,  3.43it/s]
100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


50 features, 100x
{'wmdp-bio': 0.24351924657821655, 'high_school_us_history': 0.2598039209842682, 'high_school_geography': 0.27272728085517883, 'college_computer_science': 0.29999998211860657, 'human_aging': 0.20179373025894165, 'college_biology': 0.2916666567325592}


100%|██████████| 213/213 [01:32<00:00,  2.31it/s]
100%|██████████| 34/34 [00:39<00:00,  1.16s/it]


In [37]:
# NOTE(review): `baseline_results` is not assigned in any cell shown in this notebook;
# presumably the accuracies of the model without any intervention -- confirm and
# restore the cell that computes it.
baseline_results

{'wmdp-bio': 0.6355066895484924,
 'high_school_us_history': 0.7401961088180542,
 'high_school_geography': 0.7575757503509521,
 'college_computer_science': 0.44999998807907104,
 'human_aging': 0.6322870254516602,
 'college_biology': 0.7083333134651184}

In [35]:
# NOTE(review): `results` is not assigned in any visible cell, and the same name is
# displayed again further down for a different setting (multiplier 20).
results # 20 features each layer, zero ablate

{'wmdp-bio': 0.6245090365409851,
 'high_school_us_history': 0.7450980544090271,
 'high_school_geography': 0.752525269985199,
 'college_computer_science': 0.44999998807907104,
 'human_aging': 0.6233184337615967,
 'college_biology': 0.6875}

In [39]:
# NOTE(review): no visible cell assigns `zero_ablate_results`.
zero_ablate_results # zero ablate 40 features

{'wmdp-bio': 0.628436803817749,
 'high_school_us_history': 0.7352941632270813,
 'high_school_geography': 0.7575757503509521,
 'college_computer_science': 0.4399999976158142,
 'human_aging': 0.6367713212966919,
 'college_biology': 0.7013888955116272}

In [43]:
# NOTE(review): no visible cell assigns `zero_ablate_results_1000`.
zero_ablate_results_1000 # zero ablate 1000 features

{'wmdp-bio': 0.6111547350883484,
 'high_school_us_history': 0.7401961088180542,
 'high_school_geography': 0.6616161465644836,
 'college_computer_science': 0.3999999761581421,
 'human_aging': 0.573991060256958,
 'college_biology': 0.6527777910232544}

In [45]:
# NOTE(review): the comment previously said "1000 features", which looks like a copy
# of the preceding cell's comment; 5000 is taken from the variable name -- confirm.
zero_ablate_results_5000 # zero ablate 5000 features

{'wmdp-bio': 0.2820110023021698,
 'high_school_us_history': 0.30882352590560913,
 'high_school_geography': 0.2626262605190277,
 'college_computer_science': 0.25,
 'human_aging': 0.31838566064834595,
 'college_biology': 0.2708333432674408}

In [33]:
# NOTE(review): `results` is not assigned in any visible cell; the values shown match
# the printout of the layers 1-3 scaling cell below for the datasets they share.
results # 20 features each layer, multiplier 20

{'wmdp-bio': 0.3244304955005646,
 'high_school_us_history': 0.7352941632270813,
 'high_school_geography': 0.7272727489471436,
 'college_computer_science': 0.41999998688697815,
 'human_aging': 0.5291479825973511,
 'college_biology': 0.4513888955116272}

In [18]:


# Apply `scaling_intervention` to the first `n_features` features of layers 1, 2 and 3
# at the same time, then score the intervened model on four benchmarks.
# NOTE(review): this cell indexes `saes` and `top_features` by layer, whereas the
# SAE-loading cell binds `top_features` to a single array loaded for layer `LAYER`;
# `scaling_intervention` is also not imported in the visible cells. Confirm which
# per-layer containers are intended before re-running.
n_features = 20
multiplier = 20
with scaling_intervention(model, 1, saes[1], top_features[1][:n_features], multiplier), \
        scaling_intervention(model, 2, saes[2], top_features[2][:n_features], multiplier), \
        scaling_intervention(model, 3, saes[3], top_features[3][:n_features], multiplier):

    # The first call relies on the helper's default dataset; the printout labels it wmdp-bio.
    intervened_metrics = calculate_MCQ_metrics(model)
    intervened_history_metrics = calculate_MCQ_metrics(model, dataset_name='high_school_us_history')
    # NOTE(review): name is misspelled ("interved"); kept as-is so the notebook's globals stay the same.
    interved_human_aging_metrics = calculate_MCQ_metrics(model, dataset_name='human_aging')
    intervened_college_bio_metrics = calculate_MCQ_metrics(model, dataset_name='college_biology')

    print(f"n_features: {n_features}, multiplier: {multiplier}")
    print(f"\t\twmdp-bio: {intervened_metrics['mean_correct']}")
    print(f"\t\thigh_school_us_history: {intervened_history_metrics['mean_correct']}")
    print(f"\t\thuman_aging: {interved_human_aging_metrics['mean_correct']}")
    print(f"\t\tcollege_bio: {intervened_college_bio_metrics['mean_correct']}")
    

Downloading readme:   0%|          | 0.00/4.64k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/258k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1273 [00:00<?, ? examples/s]

100%|██████████| 213/213 [00:26<00:00,  7.93it/s]


Downloading readme:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/138k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/155k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/27.3k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/17.8k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/204 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/22 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 34/34 [00:10<00:00,  3.20it/s]


Downloading data:   0%|          | 0.00/31.2k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.67k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/223 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 38/38 [00:03<00:00, 10.18it/s]


Downloading data:   0%|          | 0.00/31.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.90k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.27k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/144 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/16 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/5 [00:00<?, ? examples/s]

100%|██████████| 24/24 [00:03<00:00,  7.87it/s]

n_features: 20, multiplier: 20
		wmdp-bio: 0.3244304955005646
		high_school_us_history: 0.7352941632270813
		human_aging: 0.5291479825973511
		college_bio: 0.4513888955116272





In [16]:
# Size check on the entry at index 1 of `top_features`.
# NOTE(review): the recorded `(8820,)` output means `top_features` was a per-layer
# container when this ran, not the single 1-D array that the SAE-loading cell assigns
# to the same name (indexing that array with `[1]` would give a scalar).
top_features[1].shape

(8820,)