In [1]:
%load_ext autoreload
%autoreload 2
import torch

from sae.sparse_autoencoder import load_saved_sae
from sae.metrics import model_store_from_sae
from unlearning.metrics import convert_wmdp_data_to_prompt
from unlearning.tool import UnlearningConfig, SAEUnlearningTool, MCQ_ActivationStoreAnalysis

from huggingface_hub import hf_hub_download
from datasets import load_dataset
import numpy as np

from jaxtyping import Float, Int
from torch import Tensor

import plotly.express as px

from transformer_lens import HookedTransformer
from dataclasses import dataclass
import wandb
import einops
from tqdm import tqdm

from functools import partial
from unlearning.intervention import anthropic_remove_resid_SAE_features, remove_resid_SAE_features, anthropic_clamp_resid_SAE_features


In [2]:
# SAE checkpoint on the Hugging Face Hub; the file name indicates it was trained on
# `blocks.9.hook_resid_pre` (layer-9 residual stream, pre-block) of gemma-2b-it.
REPO_ID = "eoinf/unlearning_saes"
FILENAME = "jolly-dream-40/sparse_autoencoder_gemma-2b-it_blocks.9.hook_resid_pre_s16384_127995904.pt"


# Fetch the checkpoint and get the local path of the downloaded file.
filename = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
sae = load_saved_sae(filename)

# NOTE(review): `model_store_from_sae` is a project helper; judging by the cell output it
# loads gemma-2b-it into a HookedTransformer on cuda — confirm what else it returns/sets up.
model = model_store_from_sae(sae)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b-it into HookedTransformer
Moving model to device:  cuda


In [3]:
def modify_model(model, **kwargs):
    """Install a single SAE feature-intervention hook on `model`.

    All existing hooks on `model` are removed first, so the model ends up with
    exactly one intervention hook. The hook uses the notebook-level `sae`.

    Keyword Args:
        features_to_ablate: SAE feature indices to intervene on (required).
        multiplier: value forwarded to the intervention function (default 1.0).
        intervention_method: one of "scale_feature_activation" (default),
            "remove_from_residual_stream" or "clamp_feature_activation".
        custom_hook_point: hook name to attach to; when missing or None the
            SAE's own `sae.cfg.hook_point` is used.

    Raises:
        KeyError: if `features_to_ablate` is not supplied.
        ValueError: if `intervention_method` is not a known method.
    """
    default_modification_kwargs = {
        'multiplier': 1.0,
        'intervention_method': 'scale_feature_activation',
        'custom_hook_point': None,
    }
    # Caller-supplied values take precedence over the defaults.
    params = {**default_modification_kwargs, **kwargs}

    ablation_methods = {
        "scale_feature_activation": anthropic_remove_resid_SAE_features,
        "remove_from_residual_stream": remove_resid_SAE_features,
        "clamp_feature_activation": anthropic_clamp_resid_SAE_features,
    }

    # Validate everything before touching the model so a bad call leaves its hooks intact.
    intervention_method = params['intervention_method']
    if intervention_method not in ablation_methods:
        raise ValueError(
            f"Unknown intervention_method {intervention_method!r}; "
            f"expected one of {sorted(ablation_methods)}"
        )
    if 'features_to_ablate' not in params:
        raise KeyError("modify_model requires a 'features_to_ablate' keyword argument")

    model.reset_hooks()

    ablate_hook_func = partial(
        ablation_methods[intervention_method],
        sae=sae,
        features_to_ablate=params['features_to_ablate'],
        multiplier=params['multiplier'],
    )

    hook_point = params['custom_hook_point']
    if hook_point is None:
        hook_point = sae.cfg.hook_point

    model.add_hook(hook_point, ablate_hook_func)

In [4]:
# Read the ids of the 172 questions that the model can answer correctly in any permutation.
# `np.genfromtxt` parses with its default float dtype, so the ids are stored as floats;
# later cells only use them for `i in correct_question_ids` membership tests.
filename = '../data/wmdp-bio_gemma_2b_it_correct.csv'
correct_question_ids = np.genfromtxt(filename)

In [5]:
from unlearning.metrics import all_permutations

dataset = load_dataset("cais/wmdp", "wmdp-bio")

# Build one (prompt, answer index) pair per (question, permutation of the choices),
# restricted to the questions listed in `correct_question_ids`. Prompts and answers are
# produced in the same pass so the two lists cannot drift out of alignment.
wmdp_test = dataset['test']
all_prompts = []
all_answers = []
for i in range(len(wmdp_test)):
    # Membership in the numpy id array is a linear scan: do it once per question,
    # not once per permutation.
    if i not in correct_question_ids:
        continue
    row = wmdp_test[i]  # decode the dataset row once
    for p in all_permutations:
        all_prompts.append(
            convert_wmdp_data_to_prompt(
                row['question'],
                list(row['choices']),  # fresh copy per call in case the helper reorders in place
                prompt_format=None,
                permute_choices=p,
            )
        )
        # Position of the originally-correct choice after applying permutation `p`.
        all_answers.append(p.index(row['answer']))

assert len(all_prompts) == len(all_answers)

# NOTE(review): this is a positional 90/10 split over (question, permutation) pairs.
# The test set is therefore the tail of the question id range, and unless the boundary
# falls on a multiple of len(all_permutations) one question contributes permutations to
# both train and test. Consider splitting by question id if that leakage matters.
n_train = int(len(all_prompts) * 0.9)
train_prompts, test_prompts = all_prompts[:n_train], all_prompts[n_train:]
train_answers, test_answers = all_answers[:n_train], all_answers[n_train:]

In [6]:
@dataclass
class ProbeTrainingArgs:
    """Configuration for training one linear probe on a model's residual stream."""

    # What is probed: the layer index, the number of answer options the probe
    # scores, and the device the probe tensor is created on.
    layer: int = 12
    options: int = 4
    device: str = "cuda"

    # Number of passes over the training set.
    max_epochs: int = 16

    # Batch size and AdamW settings (learning rate, betas, weight decay).
    batch_size: int = 4
    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    wd: float = 0.01

    # Names used when saving probes/results and when logging to Weights & Biases.
    model_name: str | None = None
    probe_name: str = "main_linear_probe"
    wandb_project: str | None = 'wmdp-probe'
    wandb_name: str | None = None

    def setup_linear_probe(self, model: HookedTransformer):
        """Return a randomly initialised, trainable probe of shape (d_model, options).

        Entries are standard-normal samples divided by sqrt(d_model). The tensor is
        created without gradient tracking and switched to trainable afterwards, so
        the returned probe is a leaf tensor.
        """
        d_model = model.cfg.d_model
        probe = torch.randn(d_model, self.options, requires_grad=False, device=self.device)
        probe = probe / np.sqrt(d_model)
        probe.requires_grad = True
        return probe

In [26]:
class LinearProbeTrainer:
    """Train and evaluate a linear probe on cached residual-stream activations.

    For every prompt the `resid_post` activation of layer `args.layer` at the last
    prompt token is cached once; the probe (a `d_model x options` matrix) is then
    trained with cross-entropy to predict the answer index from that activation.

    `train_prompts`, `train_answers` and `train_resid` are always kept in the same
    order, so activations can be re-cached at any time without misaligning labels.
    """

    def __init__(
        self,
        model: "HookedTransformer",
        args: "ProbeTrainingArgs",
        train_prompts: list[str],
        train_answers: list[int],
        test_prompts: list[str],
        test_answers: list[int]
    ):
        self.model = model
        self.args = args
        self.linear_probe = args.setup_linear_probe(model)
        self.train_prompts = train_prompts
        self.train_answers = train_answers
        self.test_prompts = test_prompts
        self.test_answers = test_answers
        self.early_stopping = False
        self.accuracy_history = []
        self.step = 0
        # Cached activations, filled lazily by `cache_all_resid`.
        self.train_resid = None
        self.test_resid = None

    @staticmethod
    def _ensure_dir(path: str) -> None:
        """Create `path` (and any missing parents) if it does not exist yet."""
        import os

        os.makedirs(path, exist_ok=True)

    def _answers_tensor(self, answers: list[int]) -> torch.Tensor:
        """Return `answers` as an int64 tensor on the configured device."""
        return torch.tensor(answers, dtype=torch.long, device=self.args.device)

    def _probe_log_probs(self, resid: torch.Tensor) -> torch.Tensor:
        """Apply the probe to a (batch, d_model) tensor and return per-option log-probs."""
        probe_out = einops.einsum(
            resid,
            self.linear_probe,
            "batch d_model, d_model options -> batch options",
        )
        return probe_out.log_softmax(-1)

    def shuffle(self):
        """Apply one random permutation to the cached activations, answers and prompts.

        Prompts are permuted too: otherwise a later re-cache from `train_prompts`
        (e.g. a second `train()` call) would pair activations in the original order
        with already-shuffled answers. The lists are rebound, never mutated in place,
        so the caller's original lists are left untouched.
        """
        assert self.train_resid is not None
        perm = torch.randperm(self.train_resid.size(0))
        order = perm.tolist()
        self.train_resid = self.train_resid[perm]
        self.train_answers = [self.train_answers[i] for i in order]
        self.train_prompts = [self.train_prompts[i] for i in order]

    def check_early_stopping(self, accuracy: float):
        """Record `accuracy`; flag early stopping once the last 10 values are all 1.0."""
        self.accuracy_history.append(accuracy)
        # `>= 10`: the window inspected below is exactly the last 10 entries.
        if len(self.accuracy_history) >= 10 and all(x == 1.0 for x in self.accuracy_history[-10:]):
            self.early_stopping = True
        return self.early_stopping

    def test(self):
        """Return `(loss, accuracy)` of the current probe on the test set as floats."""
        if self.test_resid is None:
            self.test_resid = self.cache_all_resid(self.test_prompts)
        with torch.no_grad():
            answers = self._answers_tensor(self.test_answers)
            probe_log_probs = self._probe_log_probs(self.test_resid)
            rows = torch.arange(answers.shape[0], device=answers.device)
            probe_correct_log_probs = probe_log_probs[rows, answers]
            loss = -probe_correct_log_probs.mean()
            accuracy = (probe_log_probs.argmax(dim=-1) == answers).float().mean()

        return loss.item(), accuracy.item()

    def cache_all_resid(self, prompts: list[str]):
        """Return the layer-`args.layer` `resid_post` activation at each prompt's last token.

        Prompts are run through the model in batches of `args.batch_size`; the result
        has one row per prompt, in the order of `prompts`.
        """
        device = self.args.device
        # Only the probed layer is cached; caching every layer's resid_post wastes memory.
        hook_name = f"blocks.{self.args.layer}.hook_resid_post"
        resids = []

        for i in tqdm(range(0, len(prompts), self.args.batch_size)):
            prompt_batch = prompts[i: i + self.args.batch_size]
            token_batch = self.model.to_tokens(prompt_batch, padding_side="right").to(device)

            # With right padding, the last real token of each prompt sits at
            # (its own unpadded length - 1).
            token_lens = [len(self.model.to_tokens(x)[0]) for x in prompt_batch]
            next_token_indices = torch.tensor([x - 1 for x in token_lens], device=device)
            batch_indices = torch.arange(len(prompt_batch), device=device)

            with torch.inference_mode():
                _, cache = self.model.run_with_cache(
                    token_batch,
                    return_type=None,
                    names_filter=lambda name: name == hook_name,
                )
                resid_post: Float[Tensor, "batch d_model"] = cache["resid_post", self.args.layer][batch_indices, next_token_indices]
            # Clone OUTSIDE inference_mode on purpose: tensors created inside it cannot
            # take part in autograd, and these activations are later multiplied with
            # the trainable probe.
            resids.append(resid_post.clone())

        return torch.cat(resids, dim=0)

    def train(self):
        """Cache activations, then train the probe for `args.max_epochs` epochs.

        Per-batch train loss/accuracy and per-epoch test loss/accuracy are logged to
        Weights & Biases.
        """
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

        optimizer = torch.optim.AdamW([self.linear_probe], lr=self.args.lr, betas=self.args.betas, weight_decay=self.args.wd)

        self.train_resid = self.cache_all_resid(self.train_prompts)
        self.test_resid = self.cache_all_resid(self.test_prompts)

        for epoch in range(self.args.max_epochs):
            self.shuffle()
            # Build the label tensor once per epoch, after shuffling, so it matches train_resid.
            epoch_answers = self._answers_tensor(self.train_answers)

            for i in tqdm(range(0, len(self.train_prompts), self.args.batch_size)):

                resid_post = self.train_resid[i: i + self.args.batch_size]
                answers_batch = epoch_answers[i: i + self.args.batch_size]
                rows = torch.arange(resid_post.shape[0], device=answers_batch.device)

                probe_log_probs = self._probe_log_probs(resid_post)
                probe_correct_log_probs = probe_log_probs[rows, answers_batch]
                loss = -probe_correct_log_probs.mean()
                accuracy = (probe_log_probs.argmax(dim=-1) == answers_batch).float().mean()

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # log to wandb
                wandb.log({
                    "train_loss": loss.item(),
                    "train_accuracy": accuracy.item(),
                }, step=self.step)
                self.step += 1

                # if self.check_early_stopping(accuracy.item()):
                #     break

            test_loss, test_accuracy = self.test()
            wandb.log({
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
            }, step=self.step)
            # if self.early_stopping:
            #     break

        wandb.finish()

    def save_linear_probe(self, path: str = "./linear_probes"):
        """Save the probe tensor to `{path}/{model_name}_Layer{layer}.pt`, creating `path` if needed."""
        self._ensure_dir(path)
        torch.save(self.linear_probe, f'{path}/{self.args.model_name}_Layer{self.args.layer}.pt')

    def load_linear_probe(self, path: str = "./linear_probes"):
        """Load the probe from `{path}/{model_name}_Layer{layer}.pt` onto `args.device`."""
        self.linear_probe = torch.load(
            f'{path}/{self.args.model_name}_Layer{self.args.layer}.pt',
            map_location=self.args.device,
        )

    def calc_overall_accuracy(self):
        """Return the probe's accuracy over train and test examples combined."""
        if (self.train_resid is None) or (self.test_resid is None):
            self.train_resid = self.cache_all_resid(self.train_prompts)
            self.test_resid = self.cache_all_resid(self.test_prompts)

        with torch.no_grad():
            all_resid = torch.cat([self.train_resid, self.test_resid], dim=0)
            all_answers = self._answers_tensor(self.train_answers + self.test_answers)
            probe_log_probs = self._probe_log_probs(all_resid)
            accuracy = (probe_log_probs.argmax(dim=-1) == all_answers).float().mean()

        return accuracy.item()

    def calc_test_accuracy(self):
        """Return the probe's accuracy on the test set only."""
        _, accuracy = self.test()
        return accuracy

    def save_accuracy_to_file(self, path: str = "./linear_probes_results"):
        """Write the combined train+test accuracy to `{path}/{model_name}_Layer{layer}.txt`."""
        self._ensure_dir(path)
        with open(f"{path}/{self.args.model_name}_Layer{self.args.layer}.txt", "w") as f:
            f.write(str(self.calc_overall_accuracy()))

    def save_test_accuracy_to_file(self, path: str = "./linear_probes_test_results"):
        """Write the test-set accuracy to `{path}/{model_name}_Layer{layer}_test.txt`."""
        self._ensure_dir(path)
        with open(f"{path}/{self.args.model_name}_Layer{self.args.layer}_test.txt", "w") as f:
            f.write(str(self.calc_test_accuracy()))


In [23]:
# Train, save and score one probe per layer of the unmodified base model.
batch_size = 32
for layer in range(6):  # range(8, 14):
    args = ProbeTrainingArgs(
        layer=layer,
        batch_size=batch_size,
        model_name='gemma-2b-it_base_model',
    )
    trainer = LinearProbeTrainer(
        model, args, train_prompts, train_answers, test_prompts, test_answers
    )
    trainer.train()
    trainer.save_linear_probe()
    trainer.save_accuracy_to_file()

100%|██████████| 117/117 [01:34<00:00,  1.23it/s]
100%|██████████| 13/13 [00:10<00:00,  1.21it/s]
100%|██████████| 117/117 [00:00<00:00, 490.89it/s]
100%|██████████| 117/117 [00:00<00:00, 513.93it/s]
100%|██████████| 117/117 [00:00<00:00, 502.99it/s]
100%|██████████| 117/117 [00:00<00:00, 490.52it/s]
100%|██████████| 117/117 [00:00<00:00, 640.98it/s]
100%|██████████| 117/117 [00:00<00:00, 820.83it/s]
100%|██████████| 117/117 [00:00<00:00, 849.16it/s]
100%|██████████| 117/117 [00:00<00:00, 873.13it/s]
100%|██████████| 117/117 [00:00<00:00, 559.17it/s]
100%|██████████| 117/117 [00:00<00:00, 528.99it/s]
100%|██████████| 117/117 [00:00<00:00, 612.42it/s]
100%|██████████| 117/117 [00:00<00:00, 499.10it/s]
100%|██████████| 117/117 [00:00<00:00, 495.45it/s]
100%|██████████| 117/117 [00:00<00:00, 505.06it/s]
100%|██████████| 117/117 [00:00<00:00, 499.38it/s]
100%|██████████| 117/117 [00:00<00:00, 541.25it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁▂▄▁▃▆▁▄▄▅▄▅▆▄██
test_loss,▆▅▄█▃▃▆▄▅▂▂▂▂▂▁▂
train_accuracy,▂▃▅▆▄▄▄▄▄▂▄▃▅▅▇▄█▁▅▅▃▃▃▅▄▄▃▃▄▃▆▅▆▃▄▃▄▄▆▅
train_loss,▆█▅▃▅▅▄▄▅▅▅▆▅▅▃▄▁▇▄▄▆▆▆▅▄▄▆▅▅▆▃▄▃▅▅▅▅▃▃▇

0,1
test_accuracy,0.32688
test_loss,1.37224
train_accuracy,0.33333
train_loss,1.41962


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113714302579562, max=1.0…

100%|██████████| 117/117 [01:37<00:00,  1.21it/s]
100%|██████████| 13/13 [00:10<00:00,  1.20it/s]
100%|██████████| 117/117 [00:00<00:00, 605.17it/s]
100%|██████████| 117/117 [00:00<00:00, 884.22it/s]
100%|██████████| 117/117 [00:00<00:00, 678.84it/s]
100%|██████████| 117/117 [00:00<00:00, 515.08it/s]
100%|██████████| 117/117 [00:00<00:00, 503.17it/s]
100%|██████████| 117/117 [00:00<00:00, 530.32it/s]
100%|██████████| 117/117 [00:00<00:00, 500.25it/s]
100%|██████████| 117/117 [00:00<00:00, 494.41it/s]
100%|██████████| 117/117 [00:00<00:00, 484.42it/s]
100%|██████████| 117/117 [00:00<00:00, 523.05it/s]
100%|██████████| 117/117 [00:00<00:00, 515.21it/s]
100%|██████████| 117/117 [00:00<00:00, 508.12it/s]
100%|██████████| 117/117 [00:00<00:00, 494.51it/s]
100%|██████████| 117/117 [00:00<00:00, 539.91it/s]
100%|██████████| 117/117 [00:00<00:00, 546.62it/s]
100%|██████████| 117/117 [00:00<00:00, 726.50it/s]


VBox(children=(Label(value='0.026 MB of 0.031 MB uploaded\r'), FloatProgress(value=0.8222147691932435, max=1.0…

0,1
test_accuracy,▁▃▁▄▃▁▅█▆▅▃▆▅▇▇▁
test_loss,█▇█▆▅▇▄▄▄▅▃▄▃▂▁▂
train_accuracy,▅▅▅▅▅▃▅▁█▄▄▄▇▅█▅▆▄▇▄▅▂▃▅▆▄▄▆▆▆▅▅▃▂▅▅▆▄▅▆
train_loss,▆▆▇▆▆▇▇▆▆▇▆▇▆▅▅▅▄▇▆▆▅▇█▆▄▆▆▆▅▄▆▆█▇▆▆▄▆▅▁

0,1
test_accuracy,0.25424
test_loss,1.36926
train_accuracy,0.33333
train_loss,1.27873


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113945705195269, max=1.0…

100%|██████████| 117/117 [01:36<00:00,  1.21it/s]
100%|██████████| 13/13 [00:10<00:00,  1.20it/s]
100%|██████████| 117/117 [00:00<00:00, 668.88it/s]
100%|██████████| 117/117 [00:00<00:00, 692.21it/s]
100%|██████████| 117/117 [00:00<00:00, 567.36it/s]
100%|██████████| 117/117 [00:00<00:00, 562.96it/s]
100%|██████████| 117/117 [00:00<00:00, 541.76it/s]
100%|██████████| 117/117 [00:00<00:00, 526.56it/s]
100%|██████████| 117/117 [00:00<00:00, 627.01it/s]
100%|██████████| 117/117 [00:00<00:00, 641.88it/s]
100%|██████████| 117/117 [00:00<00:00, 600.87it/s]
100%|██████████| 117/117 [00:00<00:00, 618.60it/s]
100%|██████████| 117/117 [00:00<00:00, 599.86it/s]
100%|██████████| 117/117 [00:00<00:00, 542.02it/s]
100%|██████████| 117/117 [00:00<00:00, 530.66it/s]
100%|██████████| 117/117 [00:00<00:00, 539.20it/s]
100%|██████████| 117/117 [00:00<00:00, 552.66it/s]
100%|██████████| 117/117 [00:00<00:00, 516.69it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▂▁▄▂▂█▂▄█▄▅▂▆▅▆▅
test_loss,██▆▆▆▇▅▃▃▂▂▃▂▂▂▁
train_accuracy,▄▁▇▄▄▅▅▇▄▃▆▃▇▄▃▆▅▅▅▄▅▅▄▄▃▆▅▄▇▆█▇▇▅█▅▇▆▅▆
train_loss,▅█▄▅▅▅▅▄▅▇▄▆▃▆▅▃▅▄▄▅▄▂▃▅▄▄▄▄▄▄▁▄▅▃▁▃▃▄▅▂

0,1
test_accuracy,0.29782
test_loss,1.36859
train_accuracy,0.33333
train_loss,1.34428


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011114073834485478, max=1.0…

100%|██████████| 117/117 [01:36<00:00,  1.22it/s]
100%|██████████| 13/13 [00:10<00:00,  1.20it/s]
100%|██████████| 117/117 [00:00<00:00, 485.00it/s]
100%|██████████| 117/117 [00:00<00:00, 570.97it/s]
100%|██████████| 117/117 [00:00<00:00, 517.81it/s]
100%|██████████| 117/117 [00:00<00:00, 490.87it/s]
100%|██████████| 117/117 [00:00<00:00, 536.57it/s]
100%|██████████| 117/117 [00:00<00:00, 487.80it/s]
100%|██████████| 117/117 [00:00<00:00, 521.47it/s]
100%|██████████| 117/117 [00:00<00:00, 533.30it/s]
100%|██████████| 117/117 [00:00<00:00, 514.99it/s]
100%|██████████| 117/117 [00:00<00:00, 516.71it/s]
100%|██████████| 117/117 [00:00<00:00, 511.87it/s]
100%|██████████| 117/117 [00:00<00:00, 528.80it/s]
100%|██████████| 117/117 [00:00<00:00, 534.99it/s]
100%|██████████| 117/117 [00:00<00:00, 518.85it/s]
100%|██████████| 117/117 [00:00<00:00, 535.04it/s]
100%|██████████| 117/117 [00:00<00:00, 518.40it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▂▁▃▃▂▄▂▄▆▅█▅▄▇▆▅
test_loss,█▇▇▆▆▅▅▄▄▃▃▃▃▂▂▁
train_accuracy,▃▃▅▃▅▁▅▆▅▅▅▄▅▃▃▆▃▅▄▅▄▂▅▅▆▄▃▄▅▅▄▄▄▄▃▄▅█▅▅
train_loss,▆▇▆▆▅▇▇▅▅▇▅▅▇▆▆▅▆▅▆▅▆▇▅▅▆▅▅▅▅▅▄▅▆▆█▅▅▁▅▇

0,1
test_accuracy,0.32203
test_loss,1.35757
train_accuracy,0.33333
train_loss,1.39652


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113928299811152, max=1.0…

100%|██████████| 117/117 [01:36<00:00,  1.21it/s]
100%|██████████| 13/13 [00:10<00:00,  1.20it/s]
100%|██████████| 117/117 [00:00<00:00, 513.12it/s]
100%|██████████| 117/117 [00:00<00:00, 627.59it/s]
100%|██████████| 117/117 [00:00<00:00, 543.69it/s]
100%|██████████| 117/117 [00:00<00:00, 481.06it/s]
100%|██████████| 117/117 [00:00<00:00, 528.51it/s]
100%|██████████| 117/117 [00:00<00:00, 517.41it/s]
100%|██████████| 117/117 [00:00<00:00, 491.24it/s]
100%|██████████| 117/117 [00:00<00:00, 540.09it/s]
100%|██████████| 117/117 [00:00<00:00, 545.57it/s]
100%|██████████| 117/117 [00:00<00:00, 530.78it/s]
100%|██████████| 117/117 [00:00<00:00, 563.73it/s]
100%|██████████| 117/117 [00:00<00:00, 566.34it/s]
100%|██████████| 117/117 [00:00<00:00, 592.54it/s]
100%|██████████| 117/117 [00:00<00:00, 526.18it/s]
100%|██████████| 117/117 [00:00<00:00, 512.13it/s]
100%|██████████| 117/117 [00:00<00:00, 526.96it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁▁▂▂▅▅▅▃▅▆▆▇▅█▅▆
test_loss,█▇▇▆▅▅▅▄▄▃▃▂▃▂▂▁
train_accuracy,▁▁▂▃▄▁▃▂▂▃▂▂▅▄▄▁▂▅▃▄▃▅▂▅▄▃▄▄▄▃▁▅▅▅▆▁▄▆▄█
train_loss,█▇▆▇▆▆▆▇▆▆▅▆▅▆▅▇▆▆▅▅▄▃▅▄▅▅▄▄▄▅▇▃▃▂▄▅▃▁▃▃

0,1
test_accuracy,0.37772
test_loss,1.33362
train_accuracy,0.66667
train_loss,1.32733


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113691495524512, max=1.0…

100%|██████████| 117/117 [01:36<00:00,  1.22it/s]
100%|██████████| 13/13 [00:10<00:00,  1.20it/s]
100%|██████████| 117/117 [00:00<00:00, 506.80it/s]
100%|██████████| 117/117 [00:00<00:00, 497.45it/s]
100%|██████████| 117/117 [00:00<00:00, 505.77it/s]
100%|██████████| 117/117 [00:00<00:00, 532.69it/s]
100%|██████████| 117/117 [00:00<00:00, 567.11it/s]
100%|██████████| 117/117 [00:00<00:00, 527.59it/s]
100%|██████████| 117/117 [00:00<00:00, 548.13it/s]
100%|██████████| 117/117 [00:00<00:00, 513.94it/s]
100%|██████████| 117/117 [00:00<00:00, 520.23it/s]
100%|██████████| 117/117 [00:00<00:00, 518.09it/s]
100%|██████████| 117/117 [00:00<00:00, 479.04it/s]
100%|██████████| 117/117 [00:00<00:00, 566.49it/s]
100%|██████████| 117/117 [00:00<00:00, 499.59it/s]
100%|██████████| 117/117 [00:00<00:00, 523.73it/s]
100%|██████████| 117/117 [00:00<00:00, 513.94it/s]
100%|██████████| 117/117 [00:00<00:00, 512.90it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁▂▂▅█▅▃▄▅▅▇▅▆▅▄▆
test_loss,██▇▇▆▆▇▄▄▃▃▃▂▂▂▁
train_accuracy,▃▁▃▂▄▂▃▂▅▃▅█▄▁▄▄▅▄▄▅▂▅▃▄▅▆▃▄▆▃▅▂▄▅▅▅▇▄▄█
train_loss,██▆▅▅▆▅▆▄▅▅▂▅▇▅▃▅▂▄▃▅▃▅▄▅▃▄▄▃▄▂▄▂▁▄▁▂▃▃▁

0,1
test_accuracy,0.34625
test_loss,1.34646
train_accuracy,0.66667
train_loss,1.32224


In [27]:
# for layer in range(6, 17):
#     batch_size = 32
#     args = ProbeTrainingArgs(layer=layer, batch_size=batch_size, model_name='gemma-2b-it_base_model')
#     trainer = LinearProbeTrainer(model, args, train_prompts, train_answers, test_prompts, test_answers)
#     trainer.load_linear_probe()
#     trainer.save_test_accuracy_to_file()

  0%|          | 0/13 [00:00<?, ?it/s]

100%|██████████| 13/13 [00:10<00:00,  1.25it/s]


In [8]:
# # for testing
# layer = 8
# batch_size = 32
# args = ProbeTrainingArgs(layer=layer, batch_size=batch_size, model_name='gemma-2b-it_base_model')
# trainer = LinearProbeTrainer(model, args, train_prompts[:128], train_answers[:128], test_prompts, test_answers)
# trainer.train()
# # trainer.save_linear_probe()
# trainer.calc_overall_accuracy()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myeutong[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 4/4 [00:02<00:00,  1.51it/s]
100%|██████████| 13/13 [00:10<00:00,  1.26it/s]
100%|██████████| 4/4 [00:00<00:00,  8.67it/s]
100%|██████████| 4/4 [00:00<00:00, 384.46it/s]
100%|██████████| 4/4 [00:00<00:00, 468.69it/s]
100%|██████████| 4/4 [00:00<00:00, 443.61it/s]
100%|██████████| 4/4 [00:00<00:00, 589.81it/s]
100%|██████████| 4/4 [00:00<00:00, 392.39it/s]
100%|██████████| 4/4 [00:00<00:00, 451.89it/s]
100%|██████████| 4/4 [00:00<00:00, 448.18it/s]
100%|██████████| 4/4 [00:00<00:00, 522.69it/s]
100%|██████████| 4/4 [00:00<00:00, 456.09it/s]
100%|██████████| 4/4 [00:00<00:00, 392.25it/s]
100%|██████████| 4/4 [00:00<00:00, 406.46it/s]
100%|██████████| 4/4 [00:00<00:00, 400.48it/s]
100%|██████████| 4/4 [00:00<00:00, 505.76it/s]
100%|██████████| 4/4 [00:00<00:00, 449.04it/s]
100%|██████████| 4/4 [00:00<00:00, 419.55it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
test_accuracy,▁▂▄▂▇▅▆▂▁▃█▆▄▅▅▂
test_loss,█▃▁▂▃▃▃▃▂▂▂▂▁▁▁▁
train_accuracy,▄▅▃▂▅▂▃▆▅▅▅▃▇▃▅▅▅▇▅▅▅█▇▅▇▅▁▇▅▆█▆▅▇▇█▅▇█▄
train_loss,█▄▇▇▃▄▅▂▂▃▂▃▁▃▄▂▂▂▃▂▃▂▂▂▂▂▂▁▂▂▂▂▂▂▁▁▂▁▁▂

0,1
test_accuracy,0.2615
test_loss,1.38527
train_accuracy,0.21875
train_loss,1.38477


0.27356746792793274

##### SAE intervention

In [20]:
# Attach the SAE feature intervention to the shared `model`, train a probe on the
# modified model, then always restore the model.
filtered_good_features = [12663, 4342, 5749, 10355, 1523, 15858, 12273, 14315, 4451, 1611, 10051, 16186, 7983, 6958, 1307, 11019, 6531, 12289]

ablate_params = {
    'features_to_ablate': filtered_good_features,
    'multiplier': 20,
    'intervention_method': 'scale_feature_activation',
}

modify_model(model, **ablate_params)

try:
    for layer in [14]: # range(8, 14):
        batch_size = 32
        args = ProbeTrainingArgs(layer=layer, batch_size=batch_size, model_name='gemma-2b-it_SAE_18features_multiplier20')
        trainer = LinearProbeTrainer(model, args, train_prompts, train_answers, test_prompts, test_answers)
        trainer.train()
        trainer.save_linear_probe()
        trainer.save_accuracy_to_file()
finally:
    # Remove the intervention hook even if training fails or is interrupted; otherwise
    # every later cell using `model` would silently run on the modified model.
    model.reset_hooks()

100%|██████████| 117/117 [01:40<00:00,  1.16it/s]
100%|██████████| 13/13 [00:11<00:00,  1.14it/s]
100%|██████████| 117/117 [00:00<00:00, 507.42it/s]
100%|██████████| 117/117 [00:00<00:00, 622.16it/s]
100%|██████████| 117/117 [00:00<00:00, 613.98it/s]
100%|██████████| 117/117 [00:00<00:00, 592.13it/s]
100%|██████████| 117/117 [00:00<00:00, 513.09it/s]
100%|██████████| 117/117 [00:00<00:00, 487.47it/s]
100%|██████████| 117/117 [00:00<00:00, 516.82it/s]
100%|██████████| 117/117 [00:00<00:00, 541.94it/s]
100%|██████████| 117/117 [00:00<00:00, 487.12it/s]
100%|██████████| 117/117 [00:00<00:00, 739.73it/s]
100%|██████████| 117/117 [00:00<00:00, 779.86it/s]
100%|██████████| 117/117 [00:00<00:00, 615.10it/s]
100%|██████████| 117/117 [00:00<00:00, 557.74it/s]
100%|██████████| 117/117 [00:00<00:00, 539.60it/s]
100%|██████████| 117/117 [00:00<00:00, 623.04it/s]
100%|██████████| 117/117 [00:00<00:00, 905.30it/s]


VBox(children=(Label(value='0.026 MB of 0.030 MB uploaded\r'), FloatProgress(value=0.856362539299438, max=1.0)…

0,1
test_accuracy,▁▃▄▅▇▅▆▅▃▆▇▆█▅▅█
test_loss,█▅▄▃▃▂▂▂▂▂▂▁▁▂▁▁
train_accuracy,▃▁▄▂▃▃▃▂▄▂▁▄▃▅▃▂▃▃▂▅▅▃▄▃▅▃▃▂▄▄▄▄▃▄▃▄▃▄▄█
train_loss,██▆▇▆▆▆▇▅▇▇▅▆▅▆▇▇▅▇▅▄▅▆▅▄▇▆▇▅▆▆▅▆▅▅▅▅▅▆▁

0,1
test_accuracy,0.44552
test_loss,1.21311
train_accuracy,1.0
train_loss,0.64506


##### RMU

In [11]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# The tokenizer and the TransformerLens architecture name come from the base gemma-2b-it model.
hf_model_name = "google/gemma-2b-it"
transformer_lens_model_name = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

# `hf_model_name` is re-pointed at the RMU checkpoint here; its weights are loaded with
# Hugging Face and then wrapped in a HookedTransformer using the base model's architecture name.
hf_model_name = 'eoinf/gemma_2b_it_rmu_s60_a1000'
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype='auto') #.to("cuda")
RMU_model = HookedTransformer.from_pretrained(transformer_lens_model_name, hf_model=hf_model, tokenizer=tokenizer)

config.json:   0%|          | 0.00/707 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]



Loaded pretrained model google/gemma-2b-it into HookedTransformer


In [22]:
# Train, save and score a probe on the RMU model, with the same settings as the other runs.
batch_size = 32
for layer in (14,):  # range(8, 14):
    args = ProbeTrainingArgs(
        layer=layer,
        batch_size=batch_size,
        model_name='gemma-2b-it_RMU_s60_a1000',
    )
    trainer = LinearProbeTrainer(
        RMU_model, args, train_prompts, train_answers, test_prompts, test_answers
    )
    trainer.train()
    trainer.save_linear_probe()
    trainer.save_accuracy_to_file()

100%|██████████| 117/117 [01:38<00:00,  1.19it/s]
100%|██████████| 13/13 [00:11<00:00,  1.18it/s]
100%|██████████| 117/117 [00:00<00:00, 518.46it/s]
100%|██████████| 117/117 [00:00<00:00, 530.82it/s]
100%|██████████| 117/117 [00:00<00:00, 579.99it/s]
100%|██████████| 117/117 [00:00<00:00, 845.18it/s]
100%|██████████| 117/117 [00:00<00:00, 662.23it/s]
100%|██████████| 117/117 [00:00<00:00, 516.96it/s]
100%|██████████| 117/117 [00:00<00:00, 526.97it/s]
100%|██████████| 117/117 [00:00<00:00, 587.41it/s]
100%|██████████| 117/117 [00:00<00:00, 513.80it/s]
100%|██████████| 117/117 [00:00<00:00, 533.87it/s]
100%|██████████| 117/117 [00:00<00:00, 656.01it/s]
100%|██████████| 117/117 [00:00<00:00, 749.39it/s]
100%|██████████| 117/117 [00:00<00:00, 517.39it/s]
100%|██████████| 117/117 [00:00<00:00, 501.61it/s]
100%|██████████| 117/117 [00:00<00:00, 504.86it/s]
100%|██████████| 117/117 [00:00<00:00, 505.98it/s]


VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_accuracy,▁▂▃▄▃▄▃▅▅▄▃▇▅▆▃█
test_loss,█▆▄▅▃▂▄▂▃▂▃▄▃▁▃▂
train_accuracy,▅▄▆▆▇▆▆▅▇▄▅▆▅▆▅▆▆▆▇▅▅▅▇▄▅▆▆▇█▇▅▇▆▆▆▇▇▆█▁
train_loss,██▆▄▄▄▂▆▄▇▃▄▇▂▇▃▁▅▄▅▅▆▃▇▅▃▅▂▃▂▅▃▃▃▂▂▃▃▁▆

0,1
test_accuracy,0.45278
test_loss,1.19774
train_accuracy,0.0
train_loss,1.31911
