In [1]:
%load_ext autoreload
%autoreload 2


from sae.train import ModelTrainer
from sae.config import create_config, Config
from sae.utils import get_blog_checkpoint, get_blog_sparsity, create_lineplot_histogram

from unlearning.metrics import calculate_MCQ_metrics, calculate_wmdp_bio_metrics_hf, get_loss_added_rmu_model
from unlearning.tool import get_basic_gemma_2b_it_layer9_act_store
from unlearning.var import gemma_2b_it_rmu_model_names
from unlearning.metrics import all_permutations

from transformer_lens import HookedTransformer, utils
from sae.metrics import compute_metrics_post_by_text


import plotly.express as px
import plotly.graph_objs as go
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
import time
import numpy as np
from tqdm import tqdm
import wandb
import torch

from jaxtyping import Float
from torch import Tensor
from dataclasses import dataclass

import torch

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np

from jaxtyping import Float, Int
from torch import Tensor

import plotly.express as px

from transformer_lens import HookedTransformer
from dataclasses import dataclass
import wandb
import einops
from tqdm import tqdm

from functools import partial
from unlearning.intervention import anthropic_remove_resid_SAE_features, remove_resid_SAE_features, anthropic_clamp_resid_SAE_features


In [2]:
print("Getting Hugging Face model")

# Base (pre-unlearning) Gemma-2B-IT: load the weights and tokenizer through HF, then wrap
# them in a HookedTransformer so activations can be hooked and cached.
hf_model_name = "google/gemma-2b-it"
transformer_lens_model_name = hf_model_name  # same architecture name for TransformerLens

tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype="auto")
base_model = HookedTransformer.from_pretrained(
    transformer_lens_model_name,
    hf_model=hf_model,
    tokenizer=tokenizer,
)

print("done")

Getting Hugging Face model


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b-it into HookedTransformer
done


In [3]:
# RMU-unlearned Gemma-2B-IT checkpoint. It is loaded into the base model's TransformerLens
# architecture (transformer_lens_model_name) with the base tokenizer; only the weights
# come from this HF repo.
hf_model_name = "eoinf/gemma_2b_it_rmu_s60_a1000"
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype="auto")
model = HookedTransformer.from_pretrained(
    transformer_lens_model_name,
    hf_model=hf_model,
    tokenizer=tokenizer,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2b-it into HookedTransformer


In [4]:
# Indices (into the wmdp-bio test split) of the questions the model answers correctly
# under every permutation of the choices -- 172 of them per the original note.
# np.genfromtxt parses values as floats by default.
filename = '../data/wmdp-bio_gemma_2b_it_correct.csv'
correct_question_ids = np.genfromtxt(fname=filename)

In [5]:
from unlearning.metrics import all_permutations

dataset = load_dataset("cais/wmdp", "wmdp-bio")
test_split = dataset['test']

# Plain-int set for O(1) membership tests (instead of scanning the numpy array once per
# question x permutation). Only integral values are kept, which matches the exact
# `i == id` comparison against the float array and skips any NaN entries.
correct_question_id_set = {
    int(question_id)
    for question_id in np.ravel(correct_question_ids)
    if float(question_id).is_integer()
}

# Fetch each selected row once, in ascending question-index order.
correct_rows = [test_split[i] for i in range(len(test_split)) if i in correct_question_id_set]

# One prompt per (question, choice permutation), ordered by question, then by `all_permutations`.
prompts = [
    convert_wmdp_data_to_prompt(row['question'], row['choices'], prompt_format=None, permute_choices=p)
    for row in correct_rows
    for p in all_permutations
]

# Label for each prompt: the position of the original answer index within permutation p.
# NOTE(review): assumes convert_wmdp_data_to_prompt lists choices as [choices[j] for j in p] -- confirm.
answers = [
    p.index(row['answer'])
    for row in correct_rows
    for p in all_permutations
]

In [6]:
@dataclass
class ProbeTrainingArgs:
    """Configuration and data for training a linear MCQ-answer probe on a HookedTransformer."""

    # Block whose resid_post activation is probed.
    layer: int = 12
    # Number of answer choices, i.e. the probe's output width.
    options: int = 4
    # Device the probe is created on.
    device: str = "cuda"

    # Standard training hyperparams
    max_epochs: int = 8

    # Hyperparams for optimizer (AdamW)
    batch_size: int = 4
    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    wd: float = 0.01

    # Saving & logging
    probe_name: str = "main_linear_probe"
    wandb_project: str | None = 'wmdp-probe'
    wandb_name: str | None = None

    # Prompts and their correct answer indices (parallel sequences); None until supplied.
    prompts: list[str] | None = None
    answers: list[int] | None = None

    def setup_linear_probe(self, model: HookedTransformer):
        """Return a freshly initialised, trainable ``(d_model, options)`` probe on ``self.device``.

        Entries are standard-normal samples scaled by ``1 / sqrt(d_model)``.
        """
        d_model = model.cfg.d_model
        probe = torch.randn(d_model, self.options, device=self.device) / np.sqrt(d_model)
        # Enable grad only after scaling so the probe is a leaf tensor the optimizer can update.
        return probe.requires_grad_()

In [7]:
class LinearProbeTrainer:
    """Train a linear probe that predicts an MCQ answer position from one layer's residual stream.

    For each prompt the probe reads ``resid_post`` of block ``args.layer`` at the prompt's last
    real token and maps it to ``args.options`` logits with a ``(d_model, options)`` matrix.
    The probe is trained with AdamW on the negative log-likelihood of the correct answer
    index; loss and batch accuracy are logged to wandb after every batch.
    """

    # Stop training once this many consecutive batches reach 100% accuracy.
    EARLY_STOPPING_WINDOW = 5

    def __init__(self, model: HookedTransformer, args: ProbeTrainingArgs):
        self.model = model
        self.args = args
        self.linear_probe = args.setup_linear_probe(model)
        self.prompts = args.prompts
        self.answers = args.answers
        self.early_stopping = False
        self.step = 0
        self.accuracy_history = []

    def shuffle(self):
        """Shuffle prompts and answers together (numpy global RNG), keeping each pair aligned.

        Builds new lists rather than mutating the sequences held by ``self.args``. Works on
        lists or tuples and on empty data, so it is safe to call repeatedly.

        Raises:
            ValueError: if prompts and answers have different lengths.
        """
        if len(self.prompts) != len(self.answers):
            raise ValueError(
                f"prompts ({len(self.prompts)}) and answers ({len(self.answers)}) differ in length"
            )
        pairs = list(zip(self.prompts, self.answers))
        np.random.shuffle(pairs)
        self.prompts = [prompt for prompt, _ in pairs]
        self.answers = [answer for _, answer in pairs]

    def _last_token_resid_post(self, prompt_batch) -> Float[Tensor, "batch d_model"]:
        """Return block-``args.layer`` resid_post at each prompt's last real token.

        The batch is right-padded, so a prompt's last real token sits at (its unpadded
        length - 1); that length is measured by tokenizing the prompt on its own.
        """
        tokens = self.model.to_tokens(prompt_batch, padding_side="right").to(self.args.device)
        last_positions = torch.tensor(
            [self.model.to_tokens(prompt).shape[-1] - 1 for prompt in prompt_batch]
        )
        hook_name = utils.get_act_name("resid_post", self.args.layer)
        with torch.inference_mode():
            # Cache only the single hook we read, and skip the blocks after it.
            _, cache = self.model.run_with_cache(
                tokens,
                return_type=None,
                names_filter=hook_name,
                stop_at_layer=self.args.layer + 1,
            )
            acts = cache[hook_name]
            batch_idx = torch.arange(acts.shape[0], device=acts.device)
            resid_post = acts[batch_idx, last_positions.to(acts.device)]
        # Inference-mode tensors cannot take part in autograd; cloning outside the context
        # yields an ordinary tensor that can be multiplied with the trainable probe.
        return resid_post.clone()

    def _train_step(self, optimizer, prompt_batch, answers_batch) -> float:
        """Run one optimizer step on a batch, log loss/accuracy to wandb, return batch accuracy."""
        resid_post = self._last_token_resid_post(prompt_batch)
        probe_out = einops.einsum(
            resid_post,
            self.linear_probe,
            "batch d_model, d_model options -> batch options",
        )
        probe_log_probs = probe_out.log_softmax(-1)

        targets = torch.tensor(list(answers_batch), device=probe_log_probs.device)
        batch_idx = torch.arange(len(targets), device=probe_log_probs.device)
        loss = -probe_log_probs[batch_idx, targets].mean()
        accuracy = (probe_log_probs.argmax(dim=-1) == targets).float().mean().item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        wandb.log({"loss": loss.item(), "accuracy": accuracy}, step=self.step)
        self.step += 1
        return accuracy

    def train(self):
        """Train the probe for up to ``args.max_epochs`` epochs over freshly shuffled data.

        Resets the step counter, accuracy history and early-stopping flag on every call.
        Stops early once the last ``EARLY_STOPPING_WINDOW`` batches all hit 100% accuracy.
        The wandb run is always finished, even if training raises.
        """
        self.step = 0
        self.early_stopping = False
        self.accuracy_history = []

        # Log hyperparameters only; the prompt/answer lists are the dataset, not config.
        run_config = {
            key: value
            for key, value in vars(self.args).items()
            if key not in ("prompts", "answers")
        }
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=run_config)
        try:
            optimizer = torch.optim.AdamW(
                [self.linear_probe],
                lr=self.args.lr,
                betas=self.args.betas,
                weight_decay=self.args.wd,
            )
            self.shuffle()

            batch_size = self.args.batch_size
            window_size = self.EARLY_STOPPING_WINDOW
            for _epoch in range(self.args.max_epochs):
                for start in tqdm(range(0, len(self.prompts), batch_size)):
                    prompt_batch = self.prompts[start:start + batch_size]
                    answers_batch = self.answers[start:start + batch_size]
                    accuracy = self._train_step(optimizer, prompt_batch, answers_batch)
                    self.accuracy_history.append(accuracy)

                    recent = self.accuracy_history[-window_size:]
                    if len(recent) == window_size and all(acc == 1.0 for acc in recent):
                        self.early_stopping = True
                        break
                if self.early_stopping:
                    break
        finally:
            wandb.finish()


# Probe the residual stream after block 12 of the RMU model, 16 prompts per batch.
layer = 12
batch_size = 16
args = ProbeTrainingArgs(
    layer=layer,
    batch_size=batch_size,
    prompts=prompts,
    answers=answers,
)

trainer = LinearProbeTrainer(model, args)
trainer.train()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33myeutong[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112689226865768, max=1.0…

100%|██████████| 258/258 [02:29<00:00,  1.73it/s]
100%|██████████| 258/258 [02:30<00:00,  1.72it/s]
100%|██████████| 258/258 [02:30<00:00,  1.72it/s]
100%|██████████| 258/258 [02:29<00:00,  1.72it/s]
100%|██████████| 258/258 [02:30<00:00,  1.72it/s]
100%|██████████| 258/258 [02:29<00:00,  1.72it/s]
100%|██████████| 258/258 [02:29<00:00,  1.72it/s]
100%|██████████| 258/258 [02:29<00:00,  1.72it/s]


VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
accuracy,▄▁▄▂▅▄█▅▃▅▇▄▃▂▅▄▅▄▆▇▂▄▅▅▅▃▄▅▄▂▄▄▅▃▅▄▃▂█▃
loss,█▇▆▇▆█▅▅▆▆▄▆▆▇▆▅▄▅▁▄▇▆▇▄▅▅▅▅▇▇▅▆▄▇▅▆▆▅▄▅

0,1
accuracy,0.4375
loss,1.15351
