In [1]:
%load_ext autoreload
%autoreload 2
import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, get_output_probs_abcd
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools


from jaxtyping import Float
from torch import Tensor

import plotly.express as px

In [2]:
# resid pre 9
REPO_ID = "eoinf/unlearning_saes"
FILENAME = "jolly-dream-40/sparse_autoencoder_gemma-2b-it_blocks.9.hook_resid_pre_s16384_127995904.pt"


filename = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
sae = load_saved_sae(filename)

model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [3]:
dataset = load_dataset("cais/mmlu", "high_school_us_history", split='test')

In [4]:
dataset

Dataset({
    features: ['question', 'subject', 'choices', 'answer'],
    num_rows: 204
})

In [5]:
answers = [x['answer'] for x in dataset]
questions = [x['question'] for x in dataset]
choices_list = [x['choices'] for x in dataset]

permutations = list(itertools.permutations(range(4)))
pre_question = "The following are multiple choice questions (with answers) about history.\n"

if permutations is None:
    prompts = [convert_wmdp_data_to_prompt(question, choices, prompt_format=None, pre_question=pre_question) for question, choices in zip(questions, choices_list)]
else:
    prompts = [[convert_wmdp_data_to_prompt(question, choices, prompt_format=None, permute_choices=p, pre_question=pre_question) for p in permutations]
                for question, choices in zip(questions, choices_list)]
    prompts = [item for sublist in prompts for item in sublist]
    
    answers = [[p.index(answer) for p in permutations] for answer in answers]
    answers = [item for sublist in answers for item in sublist]


actual_answers = answers

batch_size = np.minimum(len(prompts), 6)
n_batches = len(prompts) // batch_size

output_probs = get_output_probs_abcd(model, prompts, batch_size=batch_size, n_batches=n_batches)

predicted_answers = output_probs.argmax(dim=1)
predicted_probs = output_probs.max(dim=1)[0]

n_predicted_answers = len(predicted_answers)

actual_answers = torch.tensor(actual_answers)[:n_predicted_answers].to("cuda")

predicted_prob_of_correct_answers = output_probs[torch.arange(len(actual_answers)), actual_answers]

is_correct = (actual_answers == predicted_answers).to(torch.float)
mean_correct = is_correct.mean()

metrics = {}
metrics['mean_correct'] = float(mean_correct.item())
metrics['total_correct'] = int(np.sum(is_correct.cpu().numpy()))
metrics['is_correct'] =  is_correct.cpu().numpy()

metrics['output_probs'] = output_probs.cpu().numpy()
metrics['actual_answers'] = actual_answers.cpu().numpy()

metrics['predicted_answers'] = predicted_answers.cpu().numpy()
metrics['predicted_probs'] = predicted_probs.cpu().numpy()
metrics['predicted_probs_of_correct_answers'] = predicted_prob_of_correct_answers.cpu().numpy()
metrics['mean_predicted_prob_of_correct_answers'] = float(np.mean(predicted_prob_of_correct_answers.cpu().numpy()))


100%|██████████| 816/816 [08:18<00:00,  1.64it/s]


In [7]:
# find a mmlu question that the model can answer correctly no matter which permutation of the choices is
# used. Then use it as a control when searching for features
metrics

[autoreload of unlearning.metrics failed: Traceback (most recent call last):
  File "/root/unlearning/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/root/unlearning/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/opt/conda/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/root/unlearning/unlearning/metrics.py", line 174
    return partial(calculate_mmlu_metrics, dataset=)
                               

{'mean_correct': 0.45098039507865906,
 'total_correct': 2208,
 'is_correct': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 'output_probs': array([[9.90619138e-03, 2.60861598e-05, 9.87313449e-01, 1.37895859e-05],
        [5.78177301e-03, 1.51714175e-05, 1.58104289e-04, 9.92989182e-01],
        [1.36140501e-03, 9.96114135e-01, 2.17192246e-05, 1.46499167e-06],
        ...,
        [2.73208789e-05, 9.96995568e-01, 1.07833091e-06, 7.05016930e-07],
        [8.43842827e-06, 4.11083221e-08, 1.80612716e-07, 9.95275259e-01],
        [1.58713825e-04, 6.65444486e-06, 9.95751381e-01, 2.11723200e-05]],
       dtype=float32),
 'actual_answers': array([3, 2, 3, ..., 2, 1, 1]),
 'predicted_answers': array([2, 3, 1, ..., 1, 3, 2]),
 'predicted_probs': array([0.98731345, 0.9929892 , 0.99611413, ..., 0.99699557, 0.99527526,
        0.9957514 ], dtype=float32),
 'predicted_probs_of_correct_answers': array([1.3789586e-05, 1.5810429e-04, 1.4649917e-06, ..., 1.0783309e-06,
        4.1108322e-08, 6.654

In [11]:
all_permutations_correct = metrics['is_correct'].reshape(-1, 24).mean(axis=1) == 1

# get index that is True
all_permutations_correct_index = np.where(all_permutations_correct)[0]
all_permutations_correct_index

# array([ 15,  20,  24,  28,  34,  45,  49,  61,  63,  68,  69,  70,  72,
#         86,  95,  97,  99, 111, 112, 122, 123, 132, 145, 151, 173, 174,
#        175, 177, 182, 185, 193])

array([ 15,  20,  24,  28,  34,  45,  49,  61,  63,  68,  69,  70,  72,
        86,  95,  97,  99, 111, 112, 122, 123, 132, 145, 151, 173, 174,
       175, 177, 182, 185, 193])

In [18]:
[np.random.choice(all_permutations_correct_index)]

[autoreload of unlearning.metrics failed: Traceback (most recent call last):
  File "/root/unlearning/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/root/unlearning/.venv/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
  File "/opt/conda/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/root/unlearning/unlearning/metrics.py", line 178
    return partial(calculate_mmlu_metrics, dataset=)
                               

[185]

: 

In [12]:
all_permutations_correct

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False,  True, False, False,
       False, False,  True, False, False, False,  True, False, False,
       False,  True, False, False, False, False, False,  True, False,
       False, False, False, False, False, False, False, False, False,
        True, False, False, False,  True, False, False, False, False,
       False, False, False, False, False, False, False,  True, False,
        True, False, False, False, False,  True,  True,  True, False,
        True, False, False, False, False, False, False, False, False,
       False, False, False, False, False,  True, False, False, False,
       False, False, False, False, False,  True, False,  True, False,
        True, False, False, False, False, False, False, False, False,
       False, False, False,  True,  True, False, False, False, False,
       False, False, False, False, False,  True,  True, False, False,
       False, False,

In [17]:
prompt = prompts[all_permutations_correct_index[10]]
print(prompt)

The following are multiple choice questions (with answers) about history.
This question refers to the following information.
"We conclude that, in the field of public education, the doctrine of "separate but equal" has no place. Separate educational facilities are inherently unequal. Therefore, we hold that the plaintiffs and others similarly situated for whom the actions have been brought are, by reason of the segregation complained of, deprived of the equal protection of the laws guaranteed by the Fourteenth Amendment."
Brown v. Board of Education, 1954
Desegregation of schools was, in part, a response to unfulfilled promises from which of the following initiatives?
A. Reconstruction
B. The Square Deal
C. The New Deal
D. The Great Society
Answer:
