In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

[1718238558.786508] [b32dd02656d7:458961:f]        vfs_fuse.c:281  UCX  ERROR inotify_add_watch(/tmp) failed: No space left on device


In [2]:
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Display an HTML vis file.

    Outside Colab, opens `filename` with the default web browser. In Colab, starts a
    background HTTP server over `/content` on the current global `PORT`, and embeds
    `/{filename}` from that server as an iframe in the cell output. It then increments the
    global `PORT` so every vis gets its own port without the caller choosing one.

    Args:
        filename: HTML file to show; in Colab it is resolved relative to `/content`.
        height: iframe height in pixels (Colab only).
    '''
    global PORT

    if not COLAB:
        webbrowser.open(filename)
        return

    import functools

    # Capture the port now. The server thread may not run until after `PORT += 1` below,
    # and it must bind the same port the iframe points at.
    port = PORT
    directory = "/content"

    def serve() -> None:
        # Pass `directory` to the handler instead of os.chdir(), which would change the
        # working directory of the whole kernel, not just this server.
        handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # daemon=True: a thread blocked in serve_forever() must not keep the interpreter alive.
    thread = threading.Thread(target=serve, daemon=True)
    thread.start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

    PORT += 1

In [4]:
import os
import sys

# Prefer a local SAELens checkout (./SAELens) over any installed `sae_lens` package.
# NOTE: this only takes effect if `sae_lens` was not already imported in this kernel
# (Python serves repeated imports from sys.modules).

# Step 1: Determine the current working directory in an IPython environment
current_directory = os.getcwd()
print(f"Current directory: {current_directory}")

# Step 2: Construct the absolute path to your local sae_lens directory
local_sae_lens_path = os.path.join(current_directory, "SAELens")
print(f"Local sae_lens path: {local_sae_lens_path}")

# Steps 3-4: Only prepend the path when it exists, and only once, so re-running the cell
# does not keep stacking duplicate entries at the front of sys.path.
if os.path.exists(local_sae_lens_path):
    print(f"Path exists: {local_sae_lens_path}")
    if local_sae_lens_path not in sys.path:
        sys.path.insert(0, local_sae_lens_path)
else:
    print(f"Path does not exist: {local_sae_lens_path}")
print(f"sys.path: {sys.path[:5]}")  # Print only first 5 paths for brevity

# Step 5: Import the SAE class. Later cells need `SAE`, so report the failure and re-raise
# here instead of letting it surface later as a confusing NameError.
try:
    from sae_lens import SAE
    print("Imported SAE successfully.")
except ModuleNotFoundError as e:
    print(f"Failed to import SAE: {e}")
    raise

# Step 6: Other imports
from datasets import load_dataset
from transformer_lens import HookedTransformer

# Ensure you set the device (e.g., 'cpu' or 'cuda') before using it
device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("mistral-7b", device=device)


Current directory: /mnt/pccfs2/backed_up/max4c/home/my-research-code/sae-truth-analysis
Local sae_lens path: /mnt/pccfs2/backed_up/max4c/home/my-research-code/sae-truth-analysis/SAELens
Path exists: /mnt/pccfs2/backed_up/max4c/home/my-research-code/sae-truth-analysis/SAELens
sys.path: ['/mnt/pccfs2/backed_up/max4c/home/my-research-code/sae-truth-analysis/SAELens', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '']


  from .autonotebook import tqdm as notebook_tqdm


Imported SAE successfully.


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.65s/it]


Loaded pretrained model mistral-7b into HookedTransformer


In [5]:
# SAE.from_pretrained returns a triple:
#   sae      - the SAE module, loaded onto `device`
#   cfg_dict - whatever config was stored in the HF repo. This is NOT the same as `sae.cfg`,
#              but the SAE config can be extracted from it; useful for analysing the SAE
#              (e.g. instantiating an activation store).
#   sparsity - the feature sparsities stored in HF, returned for convenience.
# Other release / sae_id combinations are listed in sae_lens/pretrained_saes.yaml.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    device=device,
    release="mistral-7b-res-wg",
    sae_id="blocks.8.hook_resid_pre",  # matches a hook point here, but sae_ids won't always be one
)

In [6]:
# Show the loaded SAE's config. As the cell's last expression, it is rendered as the cell output.
sae.cfg

SAEConfig(d_in=4096, d_sae=65536, activation_fn_str='relu', apply_b_dec_to_input=False, finetuning_scaling_factor=False, context_size=256, model_name='mistral-7b', hook_name='blocks.8.hook_resid_pre', hook_layer=8, hook_head_index=None, prepend_bos=False, dataset_path='monology/pile-uncopyrighted', normalize_activations=False, dtype='float32', device='cuda', sae_lens_training_version=None)

In [8]:
from transformer_lens.utils import tokenize_and_concatenate

# pile-10k (10,000 documents), downloaded in full rather than streamed.
dataset = load_dataset("NeelNanda/pile-10k", split="train", streaming=False)

# Tokenize with the model's tokenizer and pack into rows of `sae.cfg.context_size` tokens.
# A BOS token is added only if the SAE config says it was trained with one.
# NOTE(review): `streaming=True` does not match the non-streamed dataset above. In
# transformer_lens this flag appears to only disable multiprocessing in the tokenizing
# `map`; confirm for the installed version.
token_dataset = tokenize_and_concatenate(
    tokenizer=model.tokenizer,  # type: ignore
    dataset=dataset,  # type: ignore
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
    streaming=True,
)

Map: 100%|██████████| 10000/10000 [00:12<00:00, 819.31 examples/s]


In [9]:
# Eval mode, as in the SAELens tutorial (its note: prevents an error if a dead-neuron mask
# for grads is expected).
sae.eval()

with torch.no_grad():
    # First 32 rows of the pre-tokenized dataset.
    batch_tokens = token_dataset[:32]["tokens"]

    # Cache ONLY the SAE's hook point, and stop the forward pass right after the block that
    # owns it. Caching every activation of mistral-7b for the whole batch is what ran out of
    # GPU memory. `prepend_bos` is kept as before; per the TransformerLens docs it only
    # applies to string inputs.
    _, cache = model.run_with_cache(
        batch_tokens,
        prepend_bos=True,
        names_filter=sae.cfg.hook_name,
        stop_at_layer=sae.cfg.hook_layer + 1,
    )

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # Free the cached activation; only the SAE outputs are used from here on.
    del cache

    # L0 = number of SAE features active at each token, skipping position 0 (the BOS slot
    # when a BOS token is prepended). Count on the boolean mask before casting to avoid
    # materialising a float copy of the full (batch, pos, d_sae) mask.
    l0 = (feature_acts[:, 1:] > 0).sum(-1).float()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy(), title="L0 per token (active SAE features)").show()

OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 