Things we might want to do:
- Train the probes to be like the ones trained in refusal is mediated by a single direction

  - Calculate the difference in means vectors for different layers / positions from the end (think)

  - Perform selection based on `min(bypass_score)` of directions for which `induce_score` > 0, `kl_score` < 0.1

    - Where `bypass_score` is the average metric for directional ablation in the should-refuse dataset

    - `induce_score` is the average metric for direction ablation in the should-not-refuse dataset

    - `kl_score` is just the KL divergence on the should-not-refuse dataset under directional ablation

  - I would need to think about how to adapt this to our situation. Importantly:

    - Positions from end is not principled here. Seems likely that the model 'knows' the {attribute} earlier. Seems like some sort of position selection probe might be best here (look at earliest position which exceeds some threshold?)

    - We're typically dealing with multiple attributes. This is probably the most importabt difference?  Probably makes sense to the average of all mean vectors as some sort of centre and then have each subcategory be a displacement from this centre.

    - The result isn't binary model behaviour so idk if it will be as clean. More importantly, this means no clean metric. Maybe we can just try appending the suffix to the prompt every time and then have our metric be based on this at the end position.

- Look at all these probes on `lmsys-1m` + `suffix`and measure different scores for these different. I think it actually makes more sense to do this with the refusal type probes because then I can do positional selection as well eventually (fingers crossed)
  - Maybe let's try and think a bit more about how we could train these.
    - You can't really use the metrics above because the metrics are relatively quite expensive
    - I think one possible thing that makes sense is to train them to find the positions which get the highest attribution for the metric or maybe our final position final layer directions? Attribution seems more complicated in this world and I'll have to think about it harder.


I think that for now we my priorities should basically be:
1.   Use my activation cache function to get the means for all the different subclasses, start with just doing stuff on the final position for now.
2.   Get the central and displacement vectors from this
3.   Figure out what the relevant metric is going to be.
4.   Implement `bypass_score`, `induce_score`, `kl_score` for all these probes.
Here, I'm going to have to figure out what the pricipled comparison. I don't know if you want to project back to the centre? Seems unclear if ablation is actually prinicpled in this setting. Also unclear what the dataset is because it's non-binary.
5.  Use all this to do basic checks on what I've found
6.  Now move on to finding the relevant postions instead of just layers. Train a probe at each layer to predict the attribution score (actually just predicting the argmax might actually be more principled because of things like softmaxes in attention, but maybe it's also just fine?), probably should also not use the final position here.
7.  Then revisit the selection process used in 4 on the top k positions from our position selection probes.



### Imports

In [1]:
%pip install transformer_lens huggingface_hub

Collecting transformer_lens
  Downloading transformer_lens-2.8.1-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.7.1->transformer_lens)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets>=2.7.1->transformer_lens)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Col

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
import transformer_lens.utils as utils
from typing import Dict, List, Optional, Tuple, Union
from datasets import load_dataset
import numpy as np
from dataclasses import dataclass
from tqdm.auto import tqdm
import requests
import zipfile
import os
import json
import gc
from pathlib import Path
from functools import partial
import copy
import matplotlib.pyplot as plt
from fancy_einsum import einsum
import plotly.express as px

from huggingface_hub import login, hf_hub_download
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Refusal based metrics

### Loading datasets

In [None]:
# Label mappings
LABEL_MAPS = {
    "age": {
        "child": 0,
        "adolescent": 1,
        "adult": 2,
        "older adult": 3,
    },
    "gender": {
        "male": 0,
        "female": 1,
    },
    "socioeconomic": {
        "low": 0,
        "middle": 1,
        "high": 2
    },
    "education": {
        "someschool": 0,
        "highschool": 1,
        "collegemore": 2
    }
}

In [None]:
DATA_MAPS = {
    "age": 2,
    "gender": 4,
    "socioeconomic": 2,
    "education": 3
}

In [None]:
def load_dataset(
    attribute: str,
    remove_suffix: bool = False
) -> Tuple[List[str], List[int]]:
    """Load dataset for a given attribute"""

    texts = []
    labels = []
    label_map = LABEL_MAPS[attribute]

    if attribute == "education":
        data_paths = [Path(f"dataset/openai_{attribute}_three_classes_{i}.zip") for i in range(1, DATA_MAPS[attribute]+1)]
    else:
        data_paths = [Path(f"dataset/openai_{attribute}_{i}.zip") for i in range(1, DATA_MAPS[attribute]+1)]


    for data_path in data_paths:
        if not data_path.exists():
            print("Downloading dataset...")
            url = f"https://github.com/yc015/TalkTuner-chatbot-llm-dashboard/raw/main/data/dataset/{data_path.name}"
            response = requests.get(url)
            if response.status_code != 200:
                raise ValueError(f"Download failed: HTTP {response.status_code}")

            data_path.parent.mkdir(parents=True, exist_ok=True)
            with open(data_path, "wb") as f:
                f.write(response.content)

        # Extract if needed
        extract_path = data_path.parent / data_path.stem

        if not extract_path.exists():
            print("Extracting files...")
            with zipfile.ZipFile(data_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)

        # Process txt files
        for txt_file in extract_path.glob("*.txt"):
            # Extract label from filename (e.g., "conversation_107_age_adolescent.txt" -> "adolescent")
            label = txt_file.stem.split('_')[-1]

            if label in label_map:
                with open(txt_file) as f:
                    text = f.read().strip()

                if text.rfind("\n\n### Assistant:") > text.rfind("\n\n### Human:"):
                    text = text[:text.rfind("\n\n### Assistant:")]

                if not remove_suffix:
                    text += f"\n\n### Assistant: I think the {attribute} of this user is"

                texts.append(text)
                labels.append(label_map[label])

    return texts, labels

### Activation Dataset

In [None]:
class ActivDataset(Dataset):
    def __init__(self, model: HookedTransformer, texts: List[str], labels: List[int], num_classes: int, batch_size: int):
        self.model = model
        self.labels = torch.tensor(labels)
        self.num_classes = num_classes
        self.len = len(texts)
        self.batch_size = batch_size
        self.text_loader = DataLoader(texts, batch_size=batch_size, shuffle=False)

        self.activations = self._get_activations(texts)

    def _get_activations(self, texts: List[str]) -> List[torch.Tensor]:

        activations = [[] for _ in range(self.model.cfg.n_layers)]

        def hook_fn(resid: torch.Tensor, hook: HookPoint, layer: int, posns: torch.Tensor) -> torch.Tensor:
            activations[layer].append(resid.clone().to('cpu')[torch.arange(resid.shape[0]), posns, :])
            return resid


        for batch in tqdm(self.text_loader):

            tokens = self.model.to_tokens(batch, prepend_bos=True, padding_side='right')
            final_posns = (tokens.clone().to('cpu') != self.model.tokenizer.pad_token_id).sum(dim=-1) - 1

            with torch.no_grad():
                self.model.run_with_hooks(tokens,
                                          fwd_hooks=[(f'blocks.{layer}.hook_resid_post', partial(hook_fn, layer=layer, posns = final_posns)) for layer in range(self.model.cfg.n_layers)],
                                          return_type=None)


        for layer in range(self.model.cfg.n_layers):

            activations[layer] = torch.cat(activations[layer], dim=0)

        # shape: [n_layers, n_dataset, d_model]
        return torch.stack(activations)



    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activations[:,idx], self.labels[idx]

### Problem set-up

In [None]:
model = HookedTransformer.from_pretrained("gemma-2-9b", device='cpu', dtype=torch.bfloat16)
model.to('cuda')



config.json:   0%|          | 0.00/856 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/4.84G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/2.38G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2-9b into HookedTransformer
Moving model to device:  cuda


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-41): 42 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (

In [None]:
dataset_name = 'gender'
remove_suffix = False

texts, labels = load_dataset(dataset_name, remove_suffix)
num_classes = len(LABEL_MAPS[dataset_name])

Downloading dataset...
Extracting files...
Downloading dataset...
Extracting files...
Downloading dataset...
Extracting files...
Downloading dataset...
Extracting files...


### Getting activations and directions

In [None]:
activ_dataset = ActivDataset(model, texts, labels, num_classes, batch_size=5)

  0%|          | 0/380 [00:00<?, ?it/s]

In [None]:
num_examples = torch.zeros(num_classes)
total_resids = torch.zeros(num_classes, model.cfg.n_layers, model.cfg.d_model)

for activ, label in tqdm(activ_dataset):
    num_examples[label] += 1
    total_resids[label] += activ

mean_sc_resids = total_resids / num_examples[:, None, None]
sc_cents = mean_sc_resids.mean(dim=0)
sc_disps = mean_sc_resids - sc_cents[None, :, :]

  0%|          | 0/1900 [00:00<?, ?it/s]

In [None]:
sc_cents.shape, sc_disps.shape

(torch.Size([42, 3584]), torch.Size([2, 42, 3584]))

In [None]:
## Save the subcategory centres and displacements
torch.save(sc_cents, 'sc_cents.pt')
torch.save(sc_disps, 'sc_disps.pt')

## Get metrics

In [None]:
## Load the subcategory centres and displacements
sc_cents = torch.load('sc_cents.pt').to('cuda')
sc_disps = torch.load('sc_disps.pt').to('cuda')

  sc_cents = torch.load('sc_cents.pt').to('cuda')
  sc_disps = torch.load('sc_disps.pt').to('cuda')


In [None]:
sc_cents.shape, sc_disps.shape

(torch.Size([26, 2304]), torch.Size([2, 26, 2304]))

#### Finding metrics

Let's start with gender because the token forcing response means that we can use a very simple metric (i.e. just log prob diff male or female) instead of doing the whole greedy decoding thing

In [None]:
utils.test_prompt(texts[0], ' male', model, prepend_bos=True, top_k=10)

Tokenized prompt: ['<bos>', '###', ' Human', ':', ' Hi', ' there', '!', ' I', "'", 'm', ' excited', ' to', ' try', ' out', ' this', ' AI', ' assistant', '.', ' How', ' can', ' you', ' help', ' me', ' today', '?', '\n\n', '###', ' Assistant', ':', ' Hello', '!', ' I', "'", 'm', ' here', ' to', ' assist', ' you', ' with', ' any', ' questions', ' or', ' tasks', ' you', ' may', ' have', '.', ' Just', ' let', ' me', ' know', ' what', ' you', ' need', ' help', ' with', ',', ' and', ' I', "'", 'll', ' do', ' my', ' best', ' to', ' assist', ' you', '.', '\n\n', '###', ' Human', ':', ' Great', '!', ' As', ' a', ' guy', ',', ' I', ' often', ' struggle', ' with', ' finding', ' good', ' hairstyle', ' ideas', '.', ' Do', ' you', ' have', ' any', ' suggestions', ' for', ' a', ' trendy', ' men', "'", 's', ' haircut', '?', '\n\n', '###', ' Assistant', ':', ' Absolutely', '!', ' There', ' are', ' plenty', ' of', ' trendy', ' men', "'", 's', ' haircuts', ' to', ' choose', ' from', '.', ' Some', ' popula

Top 0th token. Logit: 26.25 Prob: 24.22% Token: | male|
Top 1th token. Logit: 25.62 Prob: 12.99% Token: | "|
Top 2th token. Logit: 25.25 Prob:  8.89% Token: | a|
Top 3th token. Logit: 24.88 Prob:  6.13% Token: | '|
Top 4th token. Logit: 24.75 Prob:  5.40% Token: | not|
Top 5th token. Logit: 24.50 Prob:  4.20% Token: | irrelevant|
Top 6th token. Logit: 24.25 Prob:  3.27% Token: | Male|
Top 7th token. Logit: 24.00 Prob:  2.55% Token: | important|
Top 8th token. Logit: 23.50 Prob:  1.54% Token: | human|
Top 9th token. Logit: 23.38 Prob:  1.37% Token: | [|


In [None]:
labels[0]

0

### Benchmarking directions

In [None]:
completion_strs = [[' male', ' Male', ' masculine'], [' female', ' Female', ' feminine']]
completion_toks = torch.tensor([[9202, 24417, 73303], [73303, 30870, 53012]])

def gender_metrics(logits: torch.Tensor, labels: torch.Tensor):
  # logits: [batch, seq, d_vocab]
  # only gender because for other things we're probably going to have to do longer generations
  logits = logits.to('cpu')
  correct_toks = completion_toks[labels]
  # Need to come back here, do more general version
  incorrect_toks = completion_toks[[1 - x for x in labels]]
  return logits[:, correct_toks].sum(dim=-1) - logits[:, incorrect_toks].sum(dim=-1)

In [None]:
def proj_to_centre_hook(resid:torch.Tensor, hook: HookPoint, labels:int, posns: torch.Tensor, sc_cent: torch.Tensor, sc_disp: torch.Tensor):
  # resid: [batch, seq, d_model]
  batch_inds = torch.arange(resid.shape[0], device=resid.device)
  resid[batch_inds, posns, :] -= einsum('batch d_model_A, batch d_model_A, batch d_model_B -> batch d_model_B', resid[batch_inds, posns, :] - sc_cent[None, :], sc_disp[labels], sc_disp[labels])
  return resid

In [None]:
def get_kl(new_logits: torch.Tensor, base_logits: torch.Tensor):
    base_logits = base_logits.cpu()
    new_logits = new_logits.cpu()
    base_logprobs = F.log_softmax(base_logits, dim=-1)
    new_logprobs = F.log_softmax(new_logits, dim=-1)
    return (base_logprobs.exp() * (base_logprobs - new_logprobs)).sum(dim=-1)

In [None]:
top_n = 256
text_label_loader = DataLoader(list(zip(texts[:top_n], labels[:top_n])), batch_size=1)

base_logits_store = []
for text_batch, labels_batch in tqdm(text_label_loader):
    with torch.no_grad():
        tokens = model.to_tokens(text_batch, prepend_bos=True, padding_side='right')
        final_posns = ((tokens.cpu() != model.tokenizer.pad_token_id).sum(dim=-1) - 1)
        base_logits_store.append(model(tokens, return_type='logits')[torch.arange(tokens.shape[0]), final_posns, :].cpu())
        del tokens
        del final_posns
        del text_batch
        del labels_batch

        torch.cuda.empty_cache()
        gc.collect()

base_logits = torch.cat(base_logits_store, dim=0)

  0%|          | 0/256 [00:00<?, ?it/s]

In [None]:
all_effect_metrics = torch.zeros(model.cfg.n_layers)
all_kl_metrics = torch.zeros(model.cfg.n_layers)

for layer in tqdm(range(model.cfg.n_layers)):
    final_logits_store = []
    for text_batch, labels_batch in tqdm(text_label_loader):
        with torch.no_grad():
            tokens = model.to_tokens(text_batch, prepend_bos=True, padding_side='right')
            final_posns = ((tokens.cpu() != model.tokenizer.pad_token_id).sum(dim=-1) - 1)

            final_logits_store.append(model.run_with_hooks(
                tokens,
                fwd_hooks=[(f'blocks.{layer}.hook_resid_post', partial(proj_to_centre_hook, labels=labels_batch, posns=final_posns.cuda(), sc_cent=sc_cents[layer], sc_disp=sc_disps[:, layer]))],
                return_type='logits'
            )[torch.arange(tokens.shape[0]), final_posns, :].cpu()
            )

            del tokens
            del final_posns
            del text_batch
            del labels_batch

            torch.cuda.empty_cache()
            gc.collect()

    logits_store = torch.cat(final_logits_store, dim=0)
    all_effect_metrics[layer] = gender_metrics(logits_store, labels).mean()
    all_kl_metrics[layer] = get_kl(logits_store, base_logits).mean()

  0%|          | 0/42 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

  0%|          | 0/256 [00:00<?, ?it/s]

In [None]:
px.line(all_kl_metrics.float().numpy(), title='KL Metrics')

In [None]:
px.line(all_effect_metrics.float().numpy(), title='Effect Metrics')

In [None]:
all_base_metrics.mean(), all_effect_metrics.mean(), all_kl_metrics.mean()

(tensor(2.2344, dtype=torch.bfloat16), tensor(0.), tensor(11.9633))

Next to do:
- Track over more examples.
- Try to get metrics for the other subclasses
- Attribution on the final token later layer direction to find important positions. Get measures of where the important positions are.
- Probe for important positions.
- Find the measuring probes (probably want something different from just the attribution / gradient direction, difference is means at constant layer might just work)

# Finding vague prompts and generalisation

### Set-up

In [3]:
model = HookedTransformer.from_pretrained("gemma-2-9b-it", device='cpu', dtype=torch.bfloat16)
model.to('cuda')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/857 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]



Loaded pretrained model gemma-2-9b-it into HookedTransformer
Moving model to device:  cuda


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-41): 42 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (

In [4]:
ds = load_dataset("lmsys/lmsys-chat-1m")

README.md:   0%|          | 0.00/8.88k [00:00<?, ?B/s]

(…)-00000-of-00006-4feeb3f83346a0e9.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00006-4030672591c2f478.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

(…)-00002-of-00006-1779b7cec9462180.parquet:   0%|          | 0.00/250M [00:00<?, ?B/s]

(…)-00003-of-00006-2fa862bfed56af1f.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

(…)-00004-of-00006-18f4bdd50c103e71.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00005-of-00006-fe1acc5d10a9f0e2.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

In [5]:
NUM_CLASSES = {
    "gender": 2,
    "age": 4,
    "socioeconomic": 3,
    "education": 3
}

In [6]:
class LinearProbes(nn.Module):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.probe = nn.Linear(input_dim, num_classes)

        nn.init.xavier_uniform_(self.probe.weight)
        nn.init.zeros_(self.probe.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.probe(x.to(dtype=torch.float)))

    def get_grouped_params(self):
        decay = []
        no_decay = []

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if 'bias' in name:
                no_decay.append(param)
            else:
                decay.append(param)

        return [
            {'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}
        ]

In [7]:
class ActivDataset(Dataset):
    def __init__(self, model: HookedTransformer, texts: List[str], batch_size: int):
        self.model = model
        self.len = len(texts)
        self.batch_size = batch_size
        self.text_loader = DataLoader(texts, batch_size=batch_size, shuffle=False)

        self.activations = self._get_activations(texts)

    def _get_activations(self, texts: List[str]) -> List[torch.Tensor]:
        activations = [[] for _ in range(self.model.cfg.n_layers)]

        def hook_fn(resid: torch.Tensor, hook: HookPoint, layer: int, posns: torch.Tensor) -> torch.Tensor:
            activations[layer].append(resid[torch.arange(resid.shape[0]), posns, :].clone().to('cpu'))
            return resid


        for batch in tqdm(self.text_loader):
            tokens = self.model.to_tokens(batch, prepend_bos=True, padding_side='right')
            final_posns = (tokens.clone().to('cpu') != self.model.tokenizer.pad_token_id).sum(dim=-1) - 1

            with torch.no_grad():
                self.model.run_with_hooks(tokens,
                                          fwd_hooks=[(f'blocks.{layer}.hook_resid_post', partial(hook_fn, layer=layer, posns = final_posns)) for layer in range(self.model.cfg.n_layers)],
                                          return_type=None)

            del tokens
            del final_posns
            del batch

            torch.cuda.empty_cache()
            gc.collect()

        for layer in range(self.model.cfg.n_layers):
            activations[layer] = torch.cat(activations[layer], dim=0)

        # shape: [n_layers, n_dataset, d_model]
        return torch.stack(activations)



    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx: int) -> torch.Tensor:
        activation = self.activations[:,idx]

        return activation

In [8]:
def load_probes_from_huggingface(dataset_name: str, input_dim: int, layer:int, repo_id: str = "thorsley/user_modelling_probes_gemma-9b-it"):
    try:
        # Initialize the probe
        num_classes = NUM_CLASSES[dataset_name]
        probe = LinearProbes(input_dim, num_classes)

        # Construct filenames
        weights_file = f"collected_{dataset_name}_probe_weights.pt"
        biases_file = f"collected_{dataset_name}_probe_biases.pt"

        # Download files from HuggingFace
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_file)
        biases_path = hf_hub_download(repo_id=repo_id, filename=biases_file)

        # Load tensors
        weights = torch.load(weights_path, weights_only=True)
        biases = torch.load(biases_path, weights_only=True)

        # Set the weights and biases
        with torch.no_grad():
            probe.probe.weight.copy_(weights[layer])
            probe.probe.bias.copy_(biases[layer])

        return probe

    except Exception as e:
        raise Exception(f"Error loading probes for dataset {dataset_name}: {str(e)}")

### Gender

In [85]:
attribute = 'gender'
num_classes = NUM_CLASSES[attribute]
max_interact = 20
max_examples = 3000
gender_chats = []
for i in tqdm(range(max_examples)):
    chat = ds['train'][i]['conversation']
    if len(chat) > max_interact:
        chat = chat[:max_interact]
    if chat[-1]['role'] == 'assistant':
        chat = chat[:-1]
    chat.append({'role': 'assistant', 'content': f"I think the {attribute} of this user is"})
    gender_chats.append(model.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)[:-14])

  0%|          | 0/3000 [00:00<?, ?it/s]

In [86]:
gender_chats[0]

'<bos><start_of_turn>user\nhow can identity protection services help protect me against identity theft<end_of_turn>\n<start_of_turn>model\nI think the gender of this user is'

In [87]:
gender_activ_dataset = ActivDataset(model, gender_chats, batch_size=1)

  0%|          | 0/3000 [00:00<?, ?it/s]

In [88]:
layer = 37

probe = load_probes_from_huggingface(attribute, model.cfg.d_model, layer=layer)
probe_results = probe(gender_activ_dataset[:][layer])

In [89]:
px.histogram(probe_results.detach().cpu().numpy(), barmode='overlay')

In [90]:
px.histogram(probe_results.sum(dim=-1).detach().cpu().numpy(), barmode='overlay')

In [91]:
probe_normed = probe_results / probe_results.sum(dim=-1, keepdim=True)
probe_ent = -(probe_normed * probe_normed.log()).sum(dim=-1)

In [92]:
px.histogram((probe_ent/np.log(num_classes)).detach().cpu().numpy(), barmode='overlay')

In [93]:
_, ent_sort_inds = probe_ent.sort(descending=True)

In [96]:
ind = 0
gender_chats[ent_sort_inds[ind]], probe_normed[ent_sort_inds[ind]]

('<bos><start_of_turn>user\nprovide me the complete pathophysiology of lung cancers<end_of_turn>\n<start_of_turn>model\nI think the gender of this user is',
 tensor([0.5000, 0.5000], grad_fn=<SelectBackward0>))

In [97]:
_, fem_sort_inds = probe_normed[:,1].sort(descending=True)

In [139]:
ind = -2
gender_chats[fem_sort_inds[ind]], probe_results[fem_sort_inds[ind]]

("<bos><start_of_turn>user\nNAME_1 is aware of what he is trying to achieve, understands what the next steps look like for him and sets a good example for the other grades. He  is passionate about client service, understands the competitive landscape and any impact it may or may not have on the firm. NAME_1 demonstrated good business acumen, ability to create XLos oportunities and working with diverse set of people. He seems like a go to person for a number of partners. NAME_1 gave good examples related client leadership, team leadership and showed general awareness of people matters. It was great to hear about how he took feedback on board and improved his rating from Valued to Outstanding within 1 year. \nBased on the information presented and the discussions during the panel, the panelists felt he is a great candidate and demonstrated readiness for the next grade.  \nNAME_1 has a clear development plan for himself as he moves into the Director role. One area suggested by the panel i

### Age

In [251]:
attribute = 'age'
num_classes = NUM_CLASSES[attribute]
max_interact = 20
max_examples = 1000
age_chats = []
for i in tqdm(range(max_examples)):
    chat = ds['train'][i]['conversation']
    if len(chat) > max_interact:
        chat = chat[:max_interact]
    if chat[-1]['role'] == 'assistant':
        chat = chat[:-1]
    chat.append({'role': 'assistant', 'content': f"I think the {attribute} of this user is"})
    age_chats.append(model.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [252]:
age_activ_dataset = ActivDataset(model, age_chats, batch_size=1)

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [247]:
layer = 37

probe = load_probes_from_huggingface(attribute, model.cfg.d_model, layer=layer)
probe_results = probe(age_activ_dataset[:][layer])

In [248]:
px.histogram(probe_results.detach().cpu().numpy(), barmode='overlay')

In [249]:
probe_normed = probe_results / probe_results.sum(dim=-1, keepdim=True)
probe_ent = -(probe_normed * probe_normed.log()).sum(dim=-1)

In [250]:
px.histogram((probe_ent/np.log(num_classes)).detach().cpu().numpy(), barmode='overlay')

In [73]:
_, ent_sort_inds = probe_ent.sort(descending=True)

In [244]:
ind = 0
age_chats[ent_sort_inds[ind]], probe_normed[ent_sort_inds[ind]]

('<bos><start_of_turn>user\nhelp me to explain Theoretical solution for tubular steel sections<end_of_turn>\n<start_of_turn>model\nI think the age of this user is<end_of_turn>\n',
 tensor([0.3334, 0.3332, 0.3334], grad_fn=<SelectBackward0>))

### Education

In [108]:
attribute = 'education'
num_classes = NUM_CLASSES[attribute]
max_interact = 20
max_examples = 1000
education_chats = []
for i in tqdm(range(max_examples)):
    chat = ds['train'][i]['conversation']
    if len(chat) > max_interact:
        chat = chat[:max_interact]
    if chat[-1]['role'] == 'assistant':
        chat = chat[:-1]
    chat.append({'role': 'assistant', 'content': f"I think the {attribute} of this user is"})
    education_chats.append(model.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [109]:
education_activ_dataset = ActivDataset(model, education_chats, batch_size=1)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [110]:
layer = 37

probe = load_probes_from_huggingface(attribute, model.cfg.d_model, layer=layer)
probe_results = probe(education_activ_dataset[:][layer])

collected_education_probe_weights.pt:   0%|          | 0.00/1.81M [00:00<?, ?B/s]

collected_education_probe_biases.pt:   0%|          | 0.00/1.88k [00:00<?, ?B/s]

In [111]:
px.histogram(probe_results.detach().cpu().numpy(), barmode='overlay')

In [170]:
probe_normed = probe_results / probe_results.sum(dim=-1, keepdim=True)
probe_ent = -(probe_normed * probe_normed.log()).sum(dim=-1)

In [171]:
px.histogram((probe_ent/np.log(num_classes)).detach().cpu().numpy(), barmode='overlay')

In [145]:
_, ent_sort_inds = probe_ent.sort(descending=True)

In [155]:
ind = -6
education_chats[ent_sort_inds[ind]], probe_normed[ent_sort_inds[ind]]

("<bos><start_of_turn>user\nwrite a performance review for a junior data scientist<end_of_turn>\n<start_of_turn>model\nPerformance Review for Junior Data Scientist\n\nEmployee: [Employee Name]\nPosition: Junior Data Scientist\nReview Period: [Review Period]\n\nPerformance Summary:\n\n1. Technical Skills: [Employee Name] has demonstrated a strong foundation in data science, showing an impressive grasp of machine learning algorithms, data manipulation, and statistical analysis. They have shown proficiency in using Python and R for data analysis and have a good understanding of SQL. There's still room for improvement in mastering more advanced techniques, specifically in deep learning algorithms.\n\n2. Project Management: [Employee Name] has been consistent in meeting project deadlines and has shown the capability to handle multiple tasks simultaneously. There is a noticeable improvement in their ability to prioritize tasks over the review period. \n\n3. Problem Solving: [Employee Name] h

In [176]:
_, subcat_sort_inds = probe_normed[:,0].sort(descending=True)

In [177]:
ind = 9
education_chats[subcat_sort_inds[ind]], probe_normed[subcat_sort_inds[ind]]

('<bos><start_of_turn>user\nwrite me a summary of Atlas Shrugged as a bedtime story<end_of_turn>\n<start_of_turn>model\nI think the education of this user is<end_of_turn>\n',
 tensor([9.8923e-01, 1.5614e-05, 1.0750e-02], grad_fn=<SelectBackward0>))

Might be tracking "how should I respond" instead of "what do I think the user is?"

### Socioeconomic

In [180]:
attribute = 'socioeconomic'
num_classes = NUM_CLASSES[attribute]
max_interact = 20
max_examples = 1000
socioecon_chats = []
for i in tqdm(range(max_examples)):
    chat = ds['train'][i]['conversation']
    if len(chat) > max_interact:
        chat = chat[:max_interact]
    if chat[-1]['role'] == 'assistant':
        chat = chat[:-1]
    chat.append({'role': 'assistant', 'content': f"I think the {attribute} of this user is"})
    socioecon_chats.append(model.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))

  0%|          | 0/1000 [00:00<?, ?it/s]

In [181]:
socioecon_activ_dataset = ActivDataset(model, socioecon_chats, batch_size=1)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [190]:
layer = 37

probe = load_probes_from_huggingface(attribute, model.cfg.d_model, layer=layer)
probe_results = probe(socioecon_activ_dataset[:][layer])

In [191]:
px.histogram(probe_results.detach().cpu().numpy(), barmode='overlay')

In [192]:
probe_normed = probe_results / probe_results.sum(dim=-1, keepdim=True)
probe_ent = -(probe_normed * probe_normed.log()).sum(dim=-1)

In [240]:
px.histogram((probe_ent/np.log(num_classes)).detach().cpu().numpy(), barmode='overlay')

In [241]:
_, ent_sort_inds = probe_ent.sort(descending=True)

In [243]:
ind = 0
socioecon_chats[ent_sort_inds[ind]], probe_normed[ent_sort_inds[ind]]

('<bos><start_of_turn>user\nhelp me to explain Theoretical solution for tubular steel sections<end_of_turn>\n<start_of_turn>model\nI think the socioeconomic of this user is<end_of_turn>\n',
 tensor([0.3334, 0.3332, 0.3334], grad_fn=<SelectBackward0>))