In [5]:
from nnsight import LanguageModel

import gc
import itertools
import math
import os
import random
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

device = "cuda" if t.cuda.is_available() else "mps" if t.backends.mps.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
## check memory usage

if t.cuda.is_available():
    gpu_id = 0  # Set to your target GPU ID
    total_memory = t.cuda.get_device_properties(gpu_id).total_memory
    allocated_memory = t.cuda.memory_allocated(gpu_id)
    cached_memory = t.cuda.memory_reserved(gpu_id)

    print(f"Total GPU Memory: {total_memory / 1024**2:.2f} MB")
    print(f"Allocated GPU Memory: {allocated_memory / 1024**2:.2f} MB")
    print(f"Cached GPU Memory: {cached_memory / 1024**2:.2f} MB")
elif t.backends.mps.is_available():
    # MPS (Metal Performance Shaders) for Mac
    print("MPS is available.")
    # Note: As of now, PyTorch doesn't provide direct memory management functions for MPS
    print("Memory information is not available for MPS.")
else:
    print("Neither CUDA nor MPS is available.")

Total GPU Memory: 45589.06 MB
Allocated GPU Memory: 0.00 MB
Cached GPU Memory: 0.00 MB


In [7]:
# del gemma2
# del gemma2_sae
import gc
gc.collect()

t.cuda.empty_cache()


# Output tensor of SAE activations for advbench

In [3]:
import json

# Read from advbench.json file
with open('/workspace/refusal_direction/dataset/processed/advbench.json', 'r') as file:
    advbench_data = json.load(file)

len(advbench_data)

# Read from advbench.json file
with open('/workspace/refusal_direction/dataset/processed/alpaca.json', 'r') as file:
    alpaca_data = json.load(file)

print(len(alpaca_data))

31323


In [8]:
def save_sae_activations(sae_name, sae_ids, data, suffix, max_data_rows=None):
    t.set_grad_enabled(False)
    gemma2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gemma-2-2b-it", device=device)

    for sae_id in sae_ids:
        print(f"Calculating activation fraction for {sae_id}")
        gemma2_sae, cfg_dict, sparsity = SAE.from_pretrained(
            release=sae_name,
            sae_id=sae_id,
            device=str(device),
        )
    
        sum_sae_nonzero_acts_post_last = t.zeros(gemma2_sae.cfg.d_sae).to(device)
        sum_sae_nonzero_acts_post_all = t.zeros(gemma2_sae.cfg.d_sae).to(device)
        last_token_count = 0
        all_token_count = 0
    
        for item in tqdm(data[:max_data_rows]):
            prompt = item['instruction']
            
            # Get top activations on final token
            _, cache = gemma2.run_with_cache_with_saes(
                prompt,
                saes=[gemma2_sae],
                stop_at_layer=gemma2_sae.cfg.hook_layer + 1,
            )
            cache_post = cache[f"{gemma2_sae.cfg.hook_name}.hook_sae_acts_post"]

            last_token_count += 1
            sum_sae_nonzero_acts_post_last += (cache_post[0, -1, :] > 0)

            all_token_count += cache_post.shape[1]
            sum_sae_nonzero_acts_post_all += (cache_post[0, :, :] > 0).sum(dim=0)
    
        # Stack all sae_acts_post tensors
        frac_active_last = sum_sae_nonzero_acts_post_last / last_token_count
        frac_active_all = sum_sae_nonzero_acts_post_all / all_token_count
    
        print(f"SAE ID: {sae_id}")
        print(f"{last_token_count=}")
        print(f"{all_token_count=}")
        print(f"Shape of frac_active_last tensor: {frac_active_last.shape}")
        print(f"Shape of frac_active_all tensor: {frac_active_all.shape}")
        print(f"Total number of non-zero activations at last token: {(sum_sae_nonzero_acts_post_last != 0).sum().item()}")
        print(f"Total number of non-zero activations at all tokens: {(sum_sae_nonzero_acts_post_all != 0).sum().item()}")
    
        # Create directory if it doesn't exist
        directory = f'/workspace/refusal_direction/data/sae_frac_active/{sae_name}/{sae_id}'
        os.makedirs(directory, exist_ok=True)
    
        # Save the stacked_sae_acts_post tensor
        t.save(frac_active_last.cpu(), f'{directory}_{suffix}_{last_token_count}_last.pt')
        t.save(frac_active_all.cpu(), f'{directory}_{suffix}_{all_token_count}_all.pt')
    
        # Print confirmation message
        print(f"Fractions of SAEs active saved to '{directory}'")
        print("---")
    
        del sum_sae_nonzero_acts_post_last,frac_active_last,sum_sae_nonzero_acts_post_all,frac_active_all
        del gemma2_sae
        gc.collect()
        t.cuda.empty_cache()
    

In [9]:
sae_name = "gemma-scope-2b-pt-res-canonical"
sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in range(11)]  # 0 through 10
#sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in [5]]  # only layer 5

save_sae_activations(sae_name, sae_ids, alpaca_data, suffix = "alpaca", max_data_rows=10000)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.54s/it]


Loaded pretrained model gemma-2-2b-it into HookedTransformer
Calculating activation fraction for layer_0/width_16k/canonical


100%|██████████| 10000/10000 [07:52<00:00, 21.18it/s]


SAE ID: layer_0/width_16k/canonical
last_token_count=10000
all_token_count=163840000
Shape of frac_active_last tensor: torch.Size([16384])
Shape of frac_active_all tensor: torch.Size([16384])
Total number of non-zero activations at last token: 15010
Total number of non-zero activations at all tokens: 16379
Fractions of SAEs active saved to '/workspace/refusal_direction/data/sae_frac_active/gemma-scope-2b-pt-res-canonical/layer_0/width_16k/canonical'
---
Calculating activation fraction for layer_1/width_16k/canonical


 50%|█████     | 5024/10000 [04:19<04:08, 20.01it/s]