In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

import torch
import random

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt, convert_list_of_dicts_to_dict_of_lists
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis, ActivationStoreAnalysis
from unlearning.metrics import modify_and_calculate_metrics, calculate_metrics_list, create_df_from_metrics
from unlearning.feature_attribution import calculate_cache

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np
import pandas as pd
import itertools
from transformer_lens import utils

from jaxtyping import Float
from torch import Tensor

import einops

from pathlib import Path

import plotly.express as px
from unlearning.var import REPO_ID, SAE_MAPPING
import pickle

from unlearning.metrics import all_permutations

from unlearning.metrics import calculate_metrics_side_effects
from unlearning.feature_attribution import find_topk_features_given_prompt, test_topk_features


In [3]:
# Main SAE for gemma-2b-it, stored under the 'gemma_2b_it_resid_pre_9' key of SAE_MAPPING.
# hf_hub_download fetches the file from REPO_ID and returns its local path.
filename = hf_hub_download(
    repo_id=REPO_ID,
    filename=SAE_MAPPING['gemma_2b_it_resid_pre_9'],
)
sae = load_saved_sae(filename)

# Model object obtained from the SAE via model_store_from_sae; reused by the cells below.
model = model_store_from_sae(sae)

In [None]:
# Load the dataset once here and pass it along as an argument, so it does not
# need to be loaded again every time.
dataset = load_dataset("cais/wmdp", "wmdp-bio")

# Split the test rows into three parallel lists (same order as the dataset).
answers, questions, choices_list = [], [], []
for row in dataset['test']:
    answers.append(row['answer'])
    questions.append(row['question'])
    choices_list.append(row['choices'])

# One prompt per question, built from its text and its answer choices.
prompts = [
    convert_wmdp_data_to_prompt(q, c, prompt_format=None)
    for q, c in zip(questions, choices_list)
]


In [None]:
# Hyperparameters for the automated feature-selection process.

# Question ids read from the CSV; presumably the WMDP-bio questions that
# gemma-2b-it answers correctly (going by the file name) -- confirm.
question_ids_correct = np.genfromtxt("../data/wmdp-bio_gemma_2b_it_correct.csv")

# Number of top-attribution features per prompt handed to test_topk_features.
topk_per_prompt = 20

unlearning_dataset = ['wmdp-bio']
side_effect_dataset_names = [
    'high_school_us_history',
    'college_computer_science',
    'high_school_geography',
    'human_aging',
    'college_biology',
]

# Full ordered list: 'loss_added', then the unlearning dataset, then the side-effect
# datasets. Later cells slice this list by position ([2:] and [:2]), so order matters.
all_dataset_names = ['loss_added'] + unlearning_dataset + side_effect_dataset_names



## First, get the top-k features by attribution per prompt, then find the features that modify the probability

In [None]:
# Maps question_id -> the features that test_topk_features returned as "good" for it.
feature_per_prompt = {}

# Running, de-duplicated pool of good features over the questions processed so far;
# it is fed back into test_topk_features on every iteration.
known_good_features = []

# np.genfromtxt returns a 0-d array for a single-value file; atleast_1d keeps it iterable.
question_ids = np.atleast_1d(question_ids_correct)
n_questions = len(question_ids)

for j, question_id in enumerate(question_ids):

    # genfromtxt parses values as floats; an int is needed to index the lists below.
    question_id = int(question_id)

    # Progress line; the total comes from the data rather than a hard-coded count.
    print("Question #", question_id, j+1, f"/{n_questions}")

    prompt = prompts[question_id]
    choices = choices_list[question_id]
    answer = answers[question_id]
    question = questions[question_id]

    # Only topk_features_unique is used below; the other return values are still bound
    # so that they stay available for inspection after the loop.
    (
        topk_features_unique,
        feature_attributions,
        topk_features,
        all_feature_activations,
        logit_diff_grad,
        topk_feature_attributions,
    ) = find_topk_features_given_prompt(
        model,
        prompt,
        question,
        choices,
        answer,
        sae,
        hook_point='blocks.9.hook_resid_pre',
    )

    # Test the first topk_per_prompt candidates for this question with multiplier 30.
    intervention_results, feature_ids_to_probs, good_features = test_topk_features(
        model,
        sae,
        question_id,
        topk_features_unique[:topk_per_prompt],
        known_good_features=known_good_features,
        multiplier=30,
    )

    feature_per_prompt[question_id] = good_features

    # Rebuild the de-duplicated pool from everything collected so far.
    known_good_features = list(
        {feature for features in feature_per_prompt.values() for feature in features}
    )

    

## Calculate side-effects

In [None]:
# Side-effect sweep over the candidate features collected in known_good_features.

# Intervention settings shared by every entry of the sweep.
main_ablate_params = dict(
    multiplier=20,
    intervention_method='clamp_feature_activation',
)

# Values swept over: the pool of good features found by the attribution loop.
sweep = {'features_to_ablate': known_good_features}

metric_params = {
    'wmdp-bio': {'target_metric': 'correct', 'permutations': None},
}

# Drop the first two entries ('loss_added' and 'wmdp-bio'), leaving the side-effect datasets.
dataset_names = all_dataset_names[2:]
n_batch_loss_added = 10

metrics_list = calculate_metrics_side_effects(
    model,
    sae,
    main_ablate_params,
    sweep,
    dataset_names=dataset_names,
    metric_params=metric_params,
    n_batch_loss_added=n_batch_loss_added,
)


In [None]:
# Collect the ablated feature id(s) recorded in each entry of metrics_list.
# NOTE(review): the name implies these entries had no side effects; presumably
# calculate_metrics_side_effects already filtered them -- confirm.
feature_ids_zero_side_effect = [
    entry['ablate_params']['features_to_ablate'] for entry in metrics_list
]
np.array(feature_ids_zero_side_effect)

## Then sort by loss added

In [None]:
# Activation store built from the SAE config and the model; it is passed as
# `activation_store=` to the calculate_metrics_list calls in the cells below.
activation_store = ActivationStoreAnalysis(sae.cfg, model)

In [None]:
# Re-evaluate the features from the side-effect sweep on loss added and WMDP-bio.

# Intervention settings shared by every entry of the sweep.
main_ablate_params = dict(
    multiplier=20,
    intervention_method='clamp_feature_activation',
)

# Values swept over: the feature ids collected in feature_ids_zero_side_effect.
sweep = {'features_to_ablate': feature_ids_zero_side_effect}

metric_params = {
    'wmdp-bio': {'target_metric': 'correct', 'permutations': None},
}

# First two entries only: 'loss_added' and 'wmdp-bio'.
dataset_names = all_dataset_names[:2]
n_batch_loss_added = 10

metrics_list_zero_side_effect = calculate_metrics_list(
    model,
    sae,
    main_ablate_params,
    sweep,
    dataset_names=dataset_names,
    metric_params=metric_params,
    n_batch_loss_added=n_batch_loss_added,
    activation_store=activation_store,
)


In [None]:
df_zero_side_effect = create_df_from_metrics(metrics_list_zero_side_effect)

# Keep the rows whose `wmdp-bio` value is below 1, ordered by ascending loss_added.
rows_by_loss_added = (
    df_zero_side_effect
    .query("`wmdp-bio` < 1")
    .sort_values("loss_added")
)

# NOTE(review): the index labels are used as positions into feature_ids_zero_side_effect;
# this assumes the frame has a default 0..n-1 index in the same order as the sweep -- confirm.
isorted = rows_by_loss_added.index.values
feature_ids_zero_side_effect_sorted = np.array(feature_ids_zero_side_effect)[isorted]
feature_ids_zero_side_effect_sorted

## Now progressively add features sorted by loss

In [None]:
# Final sweep: ablate growing prefixes of the loss-sorted feature list at several multipliers.

# Intervention settings shared by every entry of the sweep.
# NOTE(review): 'multiplier' is also a swept key below; presumably the swept value
# overrides this default -- confirm against calculate_metrics_list.
main_ablate_params = {
    'multiplier': 20,
    'intervention_method': 'clamp_feature_activation',
}

# Smallest and largest number of features (taken from the front of the sorted list) to ablate.
MIN_N_FEATURES = 26
MAX_N_FEATURES = 36

# Slicing past the end of the array silently returns the whole array, so with fewer than
# MAX_N_FEATURES features the requested prefixes would repeat. Clamp each size to what is
# available and de-duplicate so that every swept feature set is distinct.
n_available = len(feature_ids_zero_side_effect_sorted)
prefix_sizes = sorted(
    {min(size, n_available) for size in range(MIN_N_FEATURES, MAX_N_FEATURES + 1)}
)

sweep = {
    'features_to_ablate': [feature_ids_zero_side_effect_sorted[:size] for size in prefix_sizes],
    'multiplier': [15, 20, 25],
}

metric_params = {
    'wmdp-bio': {'target_metric': 'correct', 'permutations': None},
}

# Evaluate on everything: loss added, WMDP-bio and the side-effect datasets.
dataset_names = all_dataset_names

n_batch_loss_added = 20

metrics_list_best_sorted = calculate_metrics_list(
    model,
    sae,
    main_ablate_params,
    sweep,
    dataset_names=dataset_names,
    metric_params=metric_params,
    n_batch_loss_added=n_batch_loss_added,
    activation_store=activation_store,
)
