### Load the SAEs from the Hugging Face Hub.

In [None]:
def install_dependencies():
    ! rm -rf sae || True
    ! git clone https://github.com/amirabdullah19852020/sae.git
    ! cd sae && pip install .
    ! git clone https://github.com/withmartian/TinySQL.git
    ! cd TinySQL && pip install .

install_dependencies()

In [1]:
import json
import os

os.environ["SAE_DISABLE_TRITON"] = "1"

import psutil
import re

from copy import deepcopy
from dataclasses import dataclass
from IPython.display import display, HTML
from typing import Callable
from math import ceil
from pathlib import Path

import nnsight
import numpy as np
import plotly.graph_objects as go
import sae
import torch
import torch.fx

from datasets import load_dataset
from huggingface_hub import snapshot_download
import matplotlib.pyplot as plt
from nnsight import NNsight, LanguageModel
from plotly.subplots import make_subplots
from sae import Sae
from sae.sae_interp import GroupedSaeOutput, SaeOutput, SaeCollector, LoadedSAES, sql_tagger
from sae.sae_plotting import plot_layer_curves, plot_layer_features

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from TinySQL import sql_interp_model_location
from TinySQL.training_data.fragments import field_names, table_names

Triton disabled, using eager implementation of SAE decoder.


In [None]:
def process_info():
    """Print the current process's resident (RSS) and virtual (VMS) memory in MB."""
    bytes_per_mb = 1024 ** 2
    usage = psutil.Process(os.getpid()).memory_info()
    print(f"RSS: {usage.rss / bytes_per_mb:.2f} MB")  # Resident Set Size
    print(f"VMS: {usage.vms / bytes_per_mb:.2f} MB")

process_info()

In [None]:
# Random seed for sampling, and the short name of the SAE repository.
seed, repo = 42, "sql_interp_saes"

In [None]:
# Hub repository holding the SAEs, and the local folder it is downloaded into.
repo_name = "withmartian/sql_interp_saes"
cache_dir = "working_directory"

# Change `syn` (and the model_num / cs_num arguments below) to work with
# another model alias.
syn = True

full_model_name = sql_interp_model_location(model_num=1, cs_num=1, synonym=syn)
# Name of this model's SAE folder in the hub repo: "saes_<model-name>_syn=<syn>".
model_alias = f"saes_{full_model_name.split('/')[1]}_syn={syn}"
print(model_alias)

# Re-set the seed so this cell is self-contained (same value as the earlier cell).
seed = 42

process_info()

In [None]:
# Fetch only the files under `<model_alias>/` from the hub repo into `cache_dir`.
repo_path = Path(
    snapshot_download(
        repo_name,
        local_dir=cache_dir,
        allow_patterns=f"{model_alias}/*",
    )
)

In [None]:
# Echo the local download directory (the notebook displays the value).
cache_dir

In [None]:
def format_example(example):
    """Add Alpaca-style ``prompt`` and ``response`` fields to one dataset row.

    The prompt embeds the English question as the instruction, the CREATE
    statement as context, and the target SQL as the response. ``response``
    holds the SQL alone. The row is modified in place and also returned.
    """
    sql = example["sql_statement"]
    parts = (example["english_prompt"], example["create_statement"], sql)
    example["prompt"] = "### Instruction: {} ### Context: {} ### Response: {}".format(*parts)
    example["response"] = sql
    return example

In [None]:
# Load the k=128 SAEs for `model_alias` from `cache_dir`. `format_example` is
# passed as the dataset mapper and `sql_tagger` as the function tagger.
loaded_saes = LoadedSAES.load_from_path(
    model_alias=model_alias,
    k=128,
    cache_dir=cache_dir,
    dataset_mapper=format_example,
    function_tagger=sql_tagger,
    store_activations=False,
)

In [None]:
# Collect SAE outputs over a seeded sample of 100 examples. Per the flag name,
# averaged_representations_only=False means outputs are not limited to averages.
sae_collector = SaeCollector(
    loaded_saes,
    seed=seed,
    sample_size=100,
    averaged_representations_only=False,
)
# Optional: sae_collector.get_texts()

### Maximally activating latents

In [None]:
def compute_and_sort_weights(acts, indices):
    """
    Average each index's activation over all rows and sort descending.

    Every (score, idx) pair in a row adds ``score`` to ``idx``'s running
    total. Totals are then divided by the number of rows, not by how often
    ``idx`` occurs, so rows where an index is absent count as zero.

    Parameters:
    acts (list of list of float): Nested list of scores, one inner list per row.
    indices (list of list of int): Nested list of indices aligned with ``acts``.

    Returns:
    list of tuple: ``(index, mean_weight)`` pairs sorted by mean weight in
    descending order. Ties keep first-seen order. Empty input returns ``[]``.
    """
    from collections import defaultdict

    totals = defaultdict(int)
    num_rows = 0
    for act_row, idx_row in zip(acts, indices):
        num_rows += 1
        for score, idx in zip(act_row, idx_row):
            totals[idx] += score

    # num_rows >= 1 whenever totals is non-empty, so no division by zero.
    averaged = {idx: total / num_rows for idx, total in totals.items()}

    # sorted() is stable even with reverse=True, so ties keep insertion order.
    return sorted(averaged.items(), key=lambda item: item[1], reverse=True)

In [None]:
# Inspect one encoded example (index 4) and its SAE output at the
# 'transformer.h.0.attn' layer.
single_element = sae_collector.encoded_set[4]["encoding"]
single_sae_output = single_element.sae_outputs_by_layer['transformer.h.0.attn']

In [None]:
# Show the SAE output selected above.
print(single_sae_output)

In [None]:
# Uncomment to print the example's raw text:
#print(single_element.text)
# Indices tagged "TABLE", then the full index -> tag mapping for the example.
print(single_element.search_indices_with_tag("TABLE"))
single_element.tags_by_index

In [None]:
# List the attributes/methods available on the SAE output object.
help(single_sae_output)

In [None]:
# Handles for the tracing cell below: the language model (traced with nnsight),
# the first training prompt, and the matching tokenizer.
language_model = loaded_saes.language_model
one_prompt = sae_collector.mapped_dataset['train']['prompt'][0]
tokenizer = loaded_saes.tokenizer

In [None]:
# Tokenize one prompt and capture element 0 of the first transformer block's
# output during an nnsight trace.
inputs = tokenizer(one_prompt, return_tensors="pt")

with language_model.trace() as tracer:
    with tracer.invoke(inputs) as invoker:
        layer_output = language_model.transformer.h[0].output[0].save()

layer_output

# Sketch of patching a block's output and reading predicted tokens
# (`model`, `layer_idx`, `modified_output` are not defined in this cell):
# model.transformer.h[layer_idx].output = (modified_output,) + model.transformer.h[layer_idx].output[1:]
# final_output = model.lm_head.output.argmax(dim=-1).save()

In [None]:
def get_log_probs(language_model, tokenizer, text, device="cuda"):
    """Return the mean log-probability of the tokens after "response" in ``text``.

    Tokenizes ``text`` and runs the underlying model
    (``language_model._model``) without gradients. Each token that follows
    the first token whose normalized form ("Ġ" stripped, lower-cased) is
    "response" is scored with the log-probability the model assigned to it
    from the preceding position. Returns the mean of those values as a 0-d
    tensor; the result is ``nan`` if no tokens follow "response".

    Parameters:
    language_model: wrapper whose ``_model`` returns an object with ``.logits``
        when called with the tokenizer outputs.
    tokenizer: tokenizer matching the model.
    text (str): Full prompt including the "### Response:" section.
    device (str): Device the inputs are moved to (default "cuda", as before).

    Raises:
    ValueError: if no token normalizes to "response".
    """
    tokenized = tokenizer(text, return_tensors="pt").to(device)
    input_ids = tokenized["input_ids"].cpu()

    # Derive tokens from the ids actually fed to the model, so positions line
    # up with the logits. Previously the global `one_prompt` was tokenized
    # instead of `text`, and special tokens could shift the index.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    simple_tokens = [token.replace("Ġ", "").lower() for token in tokens]
    response_index = simple_tokens.index("response")

    with torch.no_grad():
        logits = language_model._model(**tokenized).logits.cpu()

    # Assumes a single sequence: (1, seq, vocab) -> (seq, vocab). Promote to
    # float32 for a numerically stable log_softmax.
    logprobs = torch.log_softmax(logits.float(), dim=-1)[0]

    # The logits at position i score token i+1, so pair logprobs[:-1] with
    # input_ids[1:]. The old code scored token i with position i's logits.
    next_token_logprobs = logprobs[:-1].gather(-1, input_ids[0, 1:].unsqueeze(-1)).squeeze(-1)

    # next_token_logprobs[j] scores token j+1, so slicing from response_index
    # covers exactly the tokens after "response".
    return next_token_logprobs[response_index:].mean()

In [None]:
# Display the loaded language model.
loaded_saes.language_model

In [None]:
tag = "RESPONSE_FIELD"

def get_sorted_weights_by_layer(sae_collector, tag):
    """Aggregate, per layer, the SAE outputs the collector returns for ``tag``.

    Returns a dict keyed by layer name. Each value holds the concatenated
    ``top_acts`` and ``top_indices`` across all returned outputs, plus
    ``sorted_weights`` computed from them by ``compute_and_sort_weights``.
    """
    tagged_outputs = sae_collector.get_all_sae_outputs_for_tag(tag)
    per_layer = {}
    for layer_name in sae_collector.layers:
        acts, idxs = [], []
        for output in tqdm(tagged_outputs):
            acts.extend(output[layer_name].top_acts)
            idxs.extend(output[layer_name].top_indices)

        ranked = compute_and_sort_weights(acts, idxs)
        per_layer[layer_name] = {"top_acts": acts, "top_indices": idxs, "sorted_weights": ranked}
    return per_layer


sorted_weights = get_sorted_weights_by_layer(sae_collector, tag)

In [None]:
# Plot the top 5 aggregated features per layer for `tag`.
plot_layer_features(sorted_weights, tag, top_n=5)

**Get the maximum weight of a feature in an element.**

In [None]:
# Pick one encoded example (index 2) to inspect a single feature.
element = sae_collector.encoded_set[2]["encoding"]

In [None]:
# SAE hook and feature id examined in the following cells.
target_layer, target_feature = "transformer.h.0.attn", 11707

In [None]:
# SAE output of `element` at `target_layer`.
# NOTE(review): this rebinds `single_element`, which earlier held a whole
# encoding rather than a per-layer output; a distinct name would be clearer.
single_element = element.sae_outputs_by_layer[target_layer]

In [None]:
# Max weight of `target_feature` at `target_layer` within this example
# (as the helper's name indicates).
element.get_max_weight_of_feature(target_layer, target_feature)

In [None]:
# Map the feature to attention heads.
# NOTE(review): `activations` is not used later in this section.
activations = loaded_saes.map_to_attention_head(layer_name=target_layer, feature_num=target_feature)

In [None]:
# Examples that most strongly activate the feature, plus visualizations of
# them (displayed in a later cell).
texts, visualizations = sae_collector.get_maximally_activating_datasets(target_layer, target_feature)

In [None]:
# Inspect the tags and raw text of the fifth returned example.
print(texts[4][0]["encoding"].tags_by_index)
texts[4][0]["encoding"].text

In [None]:
# Render each visualization inline in the notebook.
for visualization in visualizations:
    display(visualization)

### Monitor Reconstruction Errors.

In [None]:
# Average reconstruction error for every (k, layer) combination, per the helper's name.
reconstruction_error_by_k_and_layer = sae_collector.get_avg_reconstruction_error_for_all_k_and_layers()

In [None]:
# reconstruction_error_by_k_and_layer

In [None]:
# Plot the reconstruction-error curves computed above.
plot_layer_curves(reconstruction_error_by_k_and_layer)

### Monitor Ablation Errors.

In [None]:
def zero_heads(model, prompt_text, target_layer, heads_per_layer=None, n_heads=16):
    """Run ``model`` on ``prompt_text``, keeping only selected head slices.

    For each layer in ``target_layer``, the block's output is split along its
    last dimension into ``n_heads`` equal slices. Every slice whose index is
    NOT in ``heads_per_layer[layer]`` is zeroed, so despite the name the listed
    heads are the ones kept. The argmax token at the last position is printed.

    Parameters:
    model: nnsight language model exposing ``tokenizer``, ``transformer.h[i]``
        and ``lm_head``.
    prompt_text (str): Prompt to run.
    target_layer (int or iterable of int): Index or indices into
        ``model.transformer.h``. The previous body ignored this argument and
        read an undefined global ``target_layers`` instead.
    heads_per_layer (mapping of int -> collection of int): Heads to keep for
        each layer; required. Previously read from an undefined global.
    n_heads (int): Number of heads the hidden size is split into (default 16,
        as before).

    Returns:
    The saved argmax token ids from ``lm_head``.

    Raises:
    ValueError: if ``heads_per_layer`` is not given.
    """
    from numbers import Integral

    if heads_per_layer is None:
        raise ValueError("heads_per_layer is required: map each layer index to the heads to keep")

    layer_indices = [target_layer] if isinstance(target_layer, Integral) else list(target_layer)
    inputs = model.tokenizer(prompt_text, return_tensors="pt")

    with model.trace() as tracer:
        with tracer.invoke(inputs) as invoker:
            for layer_idx in layer_indices:
                keep = set(heads_per_layer[layer_idx])
                # NOTE(review): h[i].output[0] is the block's output; if the
                # concatenated per-head attention outputs are intended, a hook
                # on the attention output projection's input is presumably
                # needed. Confirm against the model architecture.
                layer_output = model.transformer.h[layer_idx].output[0]

                # (..., nh*dh) -> (..., nh, dh); equivalent to the old einops
                # 'b s (nh dh) -> b s nh dh' (einops was never imported).
                output_reshaped = layer_output.unflatten(-1, (n_heads, -1))

                for head_idx in range(n_heads):
                    if head_idx not in keep:
                        output_reshaped[:, :, head_idx, :] = 0

                # (..., nh, dh) -> (..., nh*dh)
                modified_output = output_reshaped.flatten(-2)

                model.transformer.h[layer_idx].output = (modified_output,) + model.transformer.h[layer_idx].output[1:]

            final_output = model.lm_head.output.argmax(dim=-1).save()

    print("Modified Output:", model.tokenizer.decode(final_output[0][-1]))
    return final_output

In [None]:
def zero_heads(model, prompt_text, target_layer, heads_per_layer=None, n_heads=16):
    """Run ``model`` on ``prompt_text``, keeping only selected head slices.

    For each layer in ``target_layer``, the block's output is split along its
    last dimension into ``n_heads`` equal slices. Every slice whose index is
    NOT in ``heads_per_layer[layer]`` is zeroed, so despite the name the listed
    heads are the ones kept. The argmax token at the last position is printed.

    Parameters:
    model: nnsight language model exposing ``tokenizer``, ``transformer.h[i]``
        and ``lm_head``.
    prompt_text (str): Prompt to run.
    target_layer (int or iterable of int): Index or indices into
        ``model.transformer.h``. The previous body ignored this argument and
        read an undefined global ``target_layers`` instead.
    heads_per_layer (mapping of int -> collection of int): Heads to keep for
        each layer; required. Previously read from an undefined global.
    n_heads (int): Number of heads the hidden size is split into (default 16,
        as before).

    Returns:
    The saved argmax token ids from ``lm_head``.

    Raises:
    ValueError: if ``heads_per_layer`` is not given.
    """
    from numbers import Integral

    if heads_per_layer is None:
        raise ValueError("heads_per_layer is required: map each layer index to the heads to keep")

    layer_indices = [target_layer] if isinstance(target_layer, Integral) else list(target_layer)
    inputs = model.tokenizer(prompt_text, return_tensors="pt")

    with model.trace() as tracer:
        with tracer.invoke(inputs) as invoker:
            for layer_idx in layer_indices:
                keep = set(heads_per_layer[layer_idx])
                # NOTE(review): h[i].output[0] is the block's output; if the
                # concatenated per-head attention outputs are intended, a hook
                # on the attention output projection's input is presumably
                # needed. Confirm against the model architecture.
                layer_output = model.transformer.h[layer_idx].output[0]

                # (..., nh*dh) -> (..., nh, dh); equivalent to the old einops
                # 'b s (nh dh) -> b s nh dh' (einops was never imported).
                output_reshaped = layer_output.unflatten(-1, (n_heads, -1))

                for head_idx in range(n_heads):
                    if head_idx not in keep:
                        output_reshaped[:, :, head_idx, :] = 0

                # (..., nh, dh) -> (..., nh*dh)
                modified_output = output_reshaped.flatten(-2)

                model.transformer.h[layer_idx].output = (modified_output,) + model.transformer.h[layer_idx].output[1:]

            final_output = model.lm_head.output.argmax(dim=-1).save()

    print("Modified Output:", model.tokenizer.decode(final_output[0][-1]))
    return final_output