# Setup

In [43]:
# Standard-ish set of imports copy-pasted from ARENA notebooks

from nnsight import LanguageModel

import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias
import json

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

device = "cuda" if t.cuda.is_available() else "mps" if t.backends.mps.is_available() else "cpu"

In [44]:
## check memory usage

if t.cuda.is_available():
    gpu_id = 0  # Set to your target GPU ID
    total_memory = t.cuda.get_device_properties(gpu_id).total_memory
    allocated_memory = t.cuda.memory_allocated(gpu_id)
    cached_memory = t.cuda.memory_reserved(gpu_id)

    print(f"Total GPU Memory: {total_memory / 1024**2:.2f} MB")
    print(f"Allocated GPU Memory: {allocated_memory / 1024**2:.2f} MB")
    print(f"Cached GPU Memory: {cached_memory / 1024**2:.2f} MB")
elif t.backends.mps.is_available():
    # MPS (Metal Performance Shaders) for Mac
    print("MPS is available.")
    # Note: As of now, PyTorch doesn't provide direct memory management functions for MPS
    print("Memory information is not available for MPS.")
else:
    print("Neither CUDA nor MPS is available.")

Total GPU Memory: 45589.06 MB
Allocated GPU Memory: 30.89 MB
Cached GPU Memory: 492.00 MB


In [45]:
# Clear out GPU memory to avoid out-of-memory errors
# Re-run this cell whenever memory usage gets high.

gc.collect()
t.cuda.empty_cache()


# Helper functions

In [46]:
# Fairly uninteresting helper functions

def load_tensor(filename):
    if device == "mps":
        tensor = t.load(filename, map_location="cpu")
        tensor.to(device, dtype=t.float32)
    else:
        tensor = t.load(filename, map_location="cpu")
        tensor.to(device, dtype=t.float32)
    return tensor

def get_second_min(x):
    min_value = t.min(x)
    mask = x != min_value
    second_min_value = t.min(x[mask])

    return second_min_value

def clear_memory_after(f):
    def func(*args, **kwargs):
        result = f(*args, **kwargs)
        gc.collect()
        t.cuda.empty_cache()
        return result
    return func

@clear_memory_after
def get_activations(prompts, sae_name, sae_id):
    t.set_grad_enabled(False)
    gemma2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gemma-2-2b-it", device=device)
    gemma2_sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_name,
        sae_id=sae_id,
        device=str(device),
    )

    all_sae_acts_post = []

    for prompt in tqdm(prompts):
        # Get top activations on final token
        _, cache = gemma2.run_with_cache_with_saes(
            prompt,
            saes=[gemma2_sae],
            stop_at_layer=gemma2_sae.cfg.hook_layer + 1,
        )
        sae_acts_post = cache[f"{gemma2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :]
        all_sae_acts_post.append(sae_acts_post)

    return t.stack(all_sae_acts_post)

def get_frac_active(sae_activations):
    return (sae_activations > 0).sum(dim=(0)) / sae_activations.shape[0]

def smooth_frac_active(frac_active):
    # Smooth the vector to avoid zeroes by arbitrarily adding half the second
    # smallest value.  This allows dividing by the vector without errors.
    # This is arbitrary; find a better way...
    return t.where(frac_active == 0, get_second_min(frac_active)/2, frac_active)

def get_relative_activation(unsmoothed_frac_active_interest, unsmoothed_frac_active_baseline):
    # Rough statistic: a latent is "interesting" if its activation is big compared to
    # the activation in some baseline.
    return smooth_frac_active(unsmoothed_frac_active_interest)/smooth_frac_active(unsmoothed_frac_active_baseline)

# Experiments

In [47]:
# Find and display interesting latents, comparing activations in some set of harmful data to some baseline.

# See https://www.neuronpedia.org/api-doc for instructions on how to get a Neuronpedia API key
# Save it in this file.
with open('/workspace/neuronpedia-api', 'r') as f:
    neuronpedia_headers = {"X-Api-Key": f.read().strip()}

# Cache calls to fetch explanations from Neuronpedia to make re-runs quick, and since
# some latents appear for multiple sets of harms.
# Note: these are auto-interpretation, so take with a grain of salt, but they have some value
# for quickly getting a sense of something.
try:
    if EXPLANATION_CACHE:
        print("Cache EXPLANATION_CACHE already exists, not overwriting it to avoid repeated API calls!")
except NameError:
    EXPLANATION_CACHE = {}

def display_interesting_latents(harm_activations, baseline_activations, title="Interesting latents", min_harm_frac_active=0.3, num_latents=20, fetch_explanations_from_neuronpedia=True):
    global EXPLANATION_CACHE

    harm_frac_active = get_frac_active(harm_activations)
    baseline_frac_active = get_frac_active(baseline_activations)
    ratio = get_relative_activation(harm_frac_active, baseline_frac_active)
    table = Table("Latent", "Frac harm", "Frac base", "Ratio", "Explanation",
                title=title, show_lines=True)
    for latent in t.where(harm_frac_active>min_harm_frac_active, ratio, 0).topk(k=num_latents).indices:
        path = f"feature/gemma-2-2b/{layer}-gemmascope-res-16k/{latent}"
        if fetch_explanations_from_neuronpedia:
            if path in EXPLANATION_CACHE:
                explanation = EXPLANATION_CACHE[path]
            else:
                response = requests.get(f"https://www.neuronpedia.org/api/{path}", headers=neuronpedia_headers)
                explanations = response.json().get('explanations', [])
                explanation = explanations[0].get('description', "(unknown)") if explanations else "(unknown)"
                EXPLANATION_CACHE[path] = explanation
        else:
            explanation = "Not fetched"
        table.add_row(f"[link=https://www.neuronpedia.org/{path}]Neuronpedia {latent}[/]",
                    f"{harm_frac_active[latent].item():.4f}", f"{baseline_frac_active[latent].item():.4f}", f"{ratio[latent].item():.2f}",
                    explanation)
    rprint(table)


Cache EXPLANATION_CACHE already exists, not overwriting it to avoid repeated API calls!


In [48]:
layer = 5

sae_name = "gemma-scope-2b-pt-res-canonical"
sae_id = f"layer_{layer}/width_16k/canonical"

sae_act_advbench = load_tensor(f'../data/sae_acts/{sae_name}/{sae_id}_advbench.pt')
sae_act_alpaca = load_tensor(f'../data/sae_acts/{sae_name}/{sae_id}_alpaca_10000.pt')



  tensor = t.load(filename, map_location="cpu")


In [49]:

display_interesting_latents(sae_act_advbench, sae_act_alpaca, title="Interesting latents more common in advbench than alpaca")

In [50]:
len(EXPLANATION_CACHE)

94

In [51]:

with open('/workspace/refusal_direction/dataset/processed/strongreject.json', 'r') as file:
    strongreject_raw = json.load(file)

strongreject = defaultdict(list)
for item in strongreject_raw:
  strongreject[item['category']].append(item['instruction'])

strongreject.keys()

dict_keys(['Disinformation and deception', 'Hate, harassment and discrimination', 'Illegal goods and services', 'Non-violent crimes', 'Sexual content', 'Violence'])

In [52]:
strongreject_activations = {}
for category in strongreject:
    print(f"Loading activations for {category}")
    strongreject_activations[category] = get_activations(strongreject[category], "gemma-scope-2b-pt-res-canonical", f"layer_{layer}/width_16k/canonical")



Loading activations for Disinformation and deception


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 50/50 [00:03<00:00, 13.32it/s]


Loading activations for Hate, harassment and discrimination


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.82s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 50/50 [00:03<00:00, 13.29it/s]


Loading activations for Illegal goods and services


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.78s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 50/50 [00:03<00:00, 14.02it/s]


Loading activations for Non-violent crimes


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.77s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 59/59 [00:04<00:00, 14.24it/s]


Loading activations for Sexual content


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.79s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 50/50 [00:03<00:00, 13.91it/s]


Loading activations for Violence


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.61s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer


100%|██████████| 54/54 [00:04<00:00, 13.28it/s]


In [53]:
for category in strongreject_activations:
  display_interesting_latents(strongreject_activations[category].cpu(), sae_act_alpaca, min_harm_frac_active=0.1, title=f"Interesting latents more common in strongreject's \"{category}\" than alpaca")