In [1]:
from nnsight import LanguageModel

import gc
import itertools
import math
import os
import random
import sys
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

# Pick the best available backend: CUDA, then Apple MPS, then plain CPU.
if t.cuda.is_available():
    device = "cuda"
elif t.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## check memory usage

if t.cuda.is_available():
    gpu_id = 0  # Set to your target GPU ID
    # All three figures are byte counts; they are divided by 1024**2 for display.
    total_memory, allocated_memory, cached_memory = (
        t.cuda.get_device_properties(gpu_id).total_memory,
        t.cuda.memory_allocated(gpu_id),
        t.cuda.memory_reserved(gpu_id),
    )
    for report_line in (
        f"Total GPU Memory: {total_memory / 1024**2:.2f} MB",
        f"Allocated GPU Memory: {allocated_memory / 1024**2:.2f} MB",
        f"Cached GPU Memory: {cached_memory / 1024**2:.2f} MB",
    ):
        print(report_line)
elif t.backends.mps.is_available():
    # MPS (Metal Performance Shaders) on macOS: this cell only reports availability.
    for report_line in ("MPS is available.", "Memory information is not available for MPS."):
        print(report_line)
else:
    print("Neither CUDA nor MPS is available.")

Total GPU Memory: 45541.31 MB
Allocated GPU Memory: 0.00 MB
Cached GPU Memory: 0.00 MB


In [3]:
# Release cached accelerator memory. Memory still referenced by live objects
# (e.g. models or SAEs loaded earlier) is not freed; `del` those names first.
gc.collect()  # collect unreachable objects first so their storage returns to the cache
if t.cuda.is_available():
    t.cuda.empty_cache()
elif t.backends.mps.is_available():
    t.mps.empty_cache()


In [58]:
import json

# Processed prompt datasets; each file is parsed as one JSON document.
DATASET_DIR = Path("../dataset/processed")


def load_json(path: Path) -> Any:
    """Read and parse a single JSON file, decoding it as UTF-8 (the JSON standard encoding)."""
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


advbench_data = load_json(DATASET_DIR / "advbench.json")
alpaca_data = load_json(DATASET_DIR / "alpaca.json")

# Only the last expression of a cell is displayed, so report both sizes together.
{"advbench": len(advbench_data), "alpaca": len(alpaca_data)}

31323


In [84]:
def load_tensor(filename):
    """Load a tensor file written with ``t.save``.

    On CUDA machines the tensor is restored on the device recorded in the file
    (no ``map_location``). On MPS/CPU machines it is mapped to CPU so files saved
    from a GPU still load. The result is intentionally left on CPU there rather
    than moved to ``device``: the analysis in this notebook combines these tensors
    with ``.cpu()`` tensors, and moving only some of them would cause device
    mismatches. (A previous ``tensor.to(device, ...)`` call discarded its result,
    so it never moved anything; ``Tensor.to`` is not in-place.)

    ``weights_only=True`` restricts unpickling to tensors/plain containers, which
    is safer for files of unknown origin and silences torch's FutureWarning.
    """
    map_location = None if device == "cuda" else "cpu"
    return t.load(filename, map_location=map_location, weights_only=True)

def get_second_min(x):
    """Return the second-smallest *distinct* value of tensor ``x`` (as a 0-dim tensor).

    When the minimum of ``x`` is 0, this is the smallest non-zero value, which is
    how the notebook uses it to fill in zero firing fractions.

    Raises:
        ValueError: if ``x`` is empty or has fewer than two distinct values
            (there is no second minimum to return).
    """
    if x.numel() == 0:
        raise ValueError("get_second_min: input tensor is empty")
    min_value = t.min(x)
    above_min = x[x != min_value]
    if above_min.numel() == 0:
        raise ValueError("get_second_min: input has fewer than two distinct values")
    return t.min(above_min)


In [96]:
layer = 5

sae_name = "gemma-scope-2b-pt-res-canonical"
sae_id = f"layer_{layer}/width_16k/canonical"
sae_act_dir = f"../data/sae_acts/{sae_name}"


def firing_fraction(acts: Tensor) -> Tensor:
    """Per-latent fraction of rows (dim 0) with a strictly positive activation.

    Presumably ``acts`` is (n_samples, d_sae) -- TODO confirm against the saver.
    """
    return (acts > 0).sum(dim=0) / acts.shape[0]


def fill_zero_fracs(frac: Tensor) -> Tensor:
    """Replace exact zeros with half the smallest non-zero value so ratios stay finite.

    Never-firing latents thereby still rank below every latent that fired at
    least once. Tensors without zeros are returned unchanged.
    """
    is_zero = frac == 0
    if not is_zero.any():
        return frac
    return t.where(is_zero, get_second_min(frac) / 2, frac)


# Move every input to CPU so all tensors share a device: load_tensor may return
# tensors on the device they were saved from.
sae_act_advbench = load_tensor(f"{sae_act_dir}/{sae_id}_advbench.pt").cpu()
sae_act_alpaca = load_tensor(f"{sae_act_dir}/{sae_id}_alpaca_10000.pt").cpu()
frac_active = load_tensor(f"{sae_act_dir}/{sae_id}/frac_active.pt").cpu()

frac_active_advbench = fill_zero_fracs(firing_fraction(sae_act_advbench))
frac_active_alpaca = fill_zero_fracs(firing_fraction(sae_act_alpaca))

# NB: despite the names, these are ratios of firing frequencies (rate ratios),
# not odds ratios. frac_active is zero-filled only for the division, so latents
# that never fired in the reference data give a large finite ratio instead of inf.
odds_ratio = frac_active_advbench / fill_zero_fracs(frac_active)
odds_ratio_alpaca = frac_active_advbench / frac_active_alpaca

  tensor = t.load(filename)


In [97]:
# Per-latent firing fraction on the Alpaca activations (exact zeros were already
# replaced by half the smallest non-zero fraction where this tensor is defined).
frac_active_alpaca

tensor([1.7000e-03, 1.2000e-03, 2.3500e-02,  ..., 5.0000e-05, 5.0000e-05,
        1.0000e-04])

In [87]:
# NOTE(review): this cell ran (In [87]) before the cell that zero-fills
# frac_active_alpaca (In [96]). Once that cell has run no zeros remain, so this
# expression just returns the same values as frac_active_alpaca.
frac_active_alpaca.where(frac_active_alpaca != 0, frac_active.min())

tensor([0.0040, 0.0010, 0.0340,  ..., 0.0010, 0.0010, 0.0010])

## Latents that fire much more often on AdvBench than on Alpaca

In [98]:
# Top-50 latents by AdvBench/Alpaca firing-frequency ratio (values and indices).
or_alpaca_topk = odds_ratio_alpaca.topk(k=50)

In [99]:
# Display the top-50 ratio values together with their latent indices.
or_alpaca_topk

torch.return_types.topk(
values=tensor([903.8462, 692.3077, 657.6923, 538.4615, 538.4615, 461.5385, 461.5385,
        346.1539, 326.9231, 326.9231, 307.6923, 307.6923, 307.6923, 307.6923,
        307.6923, 307.6923, 288.4615, 269.2308, 269.2308, 269.2308, 269.2308,
        269.2308, 269.2308, 230.7692, 230.7692, 230.7692, 230.7692, 230.7692,
        230.7692, 230.7692, 230.7692, 230.7692, 230.7692, 211.5385, 211.5385,
        211.5385, 205.1282, 201.9231, 192.3077, 192.3077, 192.3077, 192.3077,
        173.0769, 173.0769, 173.0769, 173.0769, 173.0769, 173.0769, 164.5299,
        163.4615]),
indices=tensor([ 4809, 13709, 15484,  1594,  2620,  2273,  4871,  5817,  4336, 12850,
         8837,  5123, 14277,  2754,  2513, 15713,  4511, 15261,  6815,   773,
        15271,  4760, 10170,  3082,  2930, 14954, 12668,  7089, 15511,  9746,
        13920,  1225,  8399,  4178,   506,  7479, 10801, 10263, 12722,  5790,
        14367,   747, 12594,  2928,  9370,  8975, 13443,  9141,  8901, 12054]))

In [100]:
# AdvBench firing fraction of each top-50 ratio latent (many are rare latents).
frac_active_advbench.index_select(0, or_alpaca_topk.indices)

tensor([0.0904, 0.0346, 0.3288, 0.0269, 0.0269, 0.0231, 0.0231, 0.0173, 0.0327,
        0.0327, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0154, 0.0288, 0.0135,
        0.0135, 0.0135, 0.0135, 0.0135, 0.0808, 0.0231, 0.0115, 0.0231, 0.0115,
        0.0115, 0.0115, 0.0115, 0.0115, 0.0231, 0.0115, 0.0423, 0.0212, 0.0212,
        0.0615, 0.0404, 0.0192, 0.0096, 0.0385, 0.0096, 0.0346, 0.0346, 0.0173,
        0.0173, 0.0346, 0.0173, 0.1481, 0.0327])

In [110]:
# Keep latents that (a) are active on more than 30% of AdvBench samples and
# (b) fire over 10x more often on AdvBench than on Alpaca.
min_frac_advbench = 0.3
min_ratio_vs_alpaca = 10

mask_frac = frac_active_advbench.gt(min_frac_advbench)
mask_or_alpaca = odds_ratio_alpaca.gt(min_ratio_vs_alpaca)
combined_mask = t.logical_and(mask_frac, mask_or_alpaca)

# Single-argument torch.where is equivalent to nonzero(as_tuple=True).
indices = t.where(combined_mask)[0]

indices

tensor([ 3069,  3182,  3205,  4776,  5181,  5502,  7114,  9988, 14030, 15484])

In [114]:
# Reference firing fraction (from frac_active.pt) of the selected latents.
frac_active.index_select(0, indices)

tensor([0.0452, 0.0096, 0.0485, 0.0559, 0.0078, 0.0063, 0.0509, 0.0085, 0.0132,
        0.0048])

In [111]:
# AdvBench firing fraction of the selected latents (all above the 0.3 threshold).
frac_active_advbench.index_select(0, indices)

tensor([0.3885, 0.4635, 0.6173, 0.3346, 0.4538, 0.4077, 0.5423, 0.3462, 0.3385,
        0.3288])

In [113]:
# Alpaca firing fraction of the selected latents.
frac_active_alpaca.index_select(0, indices)

tensor([0.0309, 0.0258, 0.0289, 0.0157, 0.0037, 0.0043, 0.0389, 0.0151, 0.0130,
        0.0005])

In [112]:
# AdvBench/Alpaca firing-frequency ratio of the selected latents.
odds_ratio_alpaca.index_select(0, indices)

tensor([ 12.5716,  17.9636,  21.3601,  21.3131, 122.6611,  94.8122,  13.9411,
         22.9241,  26.0355, 657.6923])

Latent index 15484 (highest AdvBench/Alpaca ratio among the filtered latents, ≈658)