# Loading and Analysing Pre-Trained Sparse Autoencoders

## Imports & Installs

In [None]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import plotly.express as px
import pandas as pd
import json
import numpy as np
import math
import gc
import pandas as pd
import random
import shutil

from functools import partial
from tqdm import tqdm
from faker import Faker

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

import torch
torch.set_grad_enabled(False);
from openai import AzureOpenAI
from datasets import load_dataset  
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig

import transformer_lens
from transformer_lens import utils
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate

from sae_lens import SAE
from sae_lens.config import DTYPE_MAP, LOCAL_SAE_MODEL_PATH
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

from causallearn.search.ConstraintBased.FCI import fci
from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
from causallearn.utils.cit import kci
from causallearn.utils.GraphUtils import GraphUtils


## Set Up

In [None]:
# Imports of functions and classes are mostly kept close to where they are
# used, so that it is clear where each one comes from.

# Pick the compute device in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays an HTML file. Outside Colab the file is opened with `webbrowser`. In Colab a
    background HTTP server is started on the next free value of the global `PORT` and the file is
    embedded as an iframe, so that each vis gets a unique port without the caller choosing one.

    Args:
        filename: Path of the HTML file (in Colab, relative to `/content`).
        height: Height of the iframe in pixels (Colab only).
    '''
    global PORT

    if not COLAB:
        webbrowser.open(filename)
        return

    # Reserve the port up front and hand it to the thread explicitly. Reading the global
    # `PORT` from inside the thread raced with the increment below and could bind the
    # server to a different port than the one the iframe points at.
    port = PORT
    PORT += 1

    def serve(directory, port):
        # NOTE: chdir affects the whole process, not just this thread.
        os.chdir(directory)

        # Create a handler for serving files
        handler = http.server.SimpleHTTPRequestHandler

        # Create a socket server with the handler
        with socketserver.TCPServer(("", port), handler) as httpd:
            print(f"Serving files from {directory} on port {port}")
            httpd.serve_forever()

    # Daemon thread: the server runs forever and must not keep the process alive on shutdown.
    thread = threading.Thread(target=serve, args=("/content", port), daemon=True)
    thread.start()

    output.serve_kernel_port_as_iframe(port, path=f"/{filename}", height=height, cache_in_notebook=True)

# Loading a pretrained Sparse Autoencoder

Below we load a TransformerLens model, a pretrained SAE, and a dataset from Hugging Face.

from vllm import LLM, SamplingParams

class VLLMGenerator:
    """Callable wrapper that samples several completions per prompt with vLLM.

    The vLLM engine is created lazily on the first call and then reused.
    Building a new `LLM` on every call reloads the model weights each time.
    """

    def __init__(self, model_path):
        """Remember the checkpoint location; the engine itself is not loaded yet."""
        self.model_path = model_path
        self._llm = None

    def _get_llm(self):
        """Return the cached engine, creating it on first use."""
        if getattr(self, "_llm", None) is None:
            self._llm = LLM(model=self.model_path, gpu_memory_utilization=0.3)
        return self._llm

    def __call__(self, prompt, sample_size):
        """Generate `sample_size` completions for each prompt.

        Args:
            prompt: Prompt (or prompts) forwarded unchanged to `LLM.generate`.
            sample_size: Number of completions sampled per prompt.

        Returns:
            A list with one dict per request output: `{"prompt": ..., "output": [texts]}`.
        """
        sampling_params = SamplingParams(n=sample_size, best_of=sample_size, temperature=1.1, top_p=0.95)
        outputs = self._get_llm().generate(prompt, sampling_params)
        return [
            {
                "prompt": request_output.prompt,
                "output": [response.text for response in request_output.outputs],
            }
            for request_output in outputs
        ]

In [None]:
# Select which base model the rest of the notebook works with.
base_model = 'GEMMA-2B-IT'
#base_model = 'GEMMA-2B-CHN'
#base_model = 'GPT2-SMALL'
#base_model = 'MISTRAL-7B'
#base_model = 'LLAMA3-8B-IT'
#base_model = 'LLAMA3-8B-IT-HELPFUL'
#base_model = 'LLAMA3-8B-IT-CHN'
#base_model = 'LLAMA3-8B-IT-FICTION'

# 4-bit NF4 quantisation settings, used for the Llama-3 8B checkpoints below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    #bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
    )

# base_model -> (path components below LOCAL_SAE_MODEL_PATH, is_quantized_llama).
# When the flag is set the checkpoint is loaded with `bnb_config` and its token
# embeddings are resized to the tokenizer length after loading.
HF_CHECKPOINTS = {
    'GPT2-SMALL': (("openai-community", "gpt2"), False),
    'GEMMA-2B': (("google", "gemma-2b"), False),
    'GEMMA-2B-IT': (("google", "gemma-2b-it"), False),
    'GEMMA-2B-CHN': (("ccrains", "larson-gemma-2b-chinese-v0.1/"), False),
    'MISTRAL-7B': (("mistralai", "Mistral-7B-v0.1/"), False),
    'LLAMA3-8B-IT': (("meta-llama", "Meta-Llama-3-8B-Instruct"), True),
    'LLAMA3-8B-IT-HELPFUL': (("meta-llama", "meta-llama-3-8b-instruct-helpfull"), True),
    'LLAMA3-8B-IT-FICTION': (("meta-llama", "Meta-Llama-3-8B-Instruct_fictional_chinese_v1"), True),
    'LLAMA3-8B-IT-CHN': (("hfl", "llama-3-chinese-8b-instruct-v3/"), True),
}

# Fail with a clear message here instead of a NameError on `hf_tokenizer` further down.
if base_model not in HF_CHECKPOINTS:
    raise ValueError(f"Unknown model: {base_model}")

checkpoint_parts, is_quantized_llama = HF_CHECKPOINTS[base_model]
checkpoint_path = os.path.join(LOCAL_SAE_MODEL_PATH, *checkpoint_parts)
load_kwargs = {"quantization_config": bnb_config} if is_quantized_llama else {}

hf_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **load_kwargs)
hf_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, padding_side='left')
if is_quantized_llama:
    hf_model.resize_token_embeddings(len(hf_tokenizer))
#vllm_generator = VLLMGenerator(os.path.join(LOCAL_SAE_MODEL_PATH, "google", "gemma-2b-it"))

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Batched generation pads prompts, so fall back to EOS when no pad token is defined.
if hf_tokenizer.pad_token is None:
    hf_tokenizer.pad_token = hf_tokenizer.eos_token
        

# prompt = "GPT2 is a model developed by OpenAI."
# input_ids = hf_tokenizer(prompt, return_tensors="pt").input_ids
# #input_ids = hf_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# gen_tokens = hf_model.generate(
#     input_ids,
#     do_sample=True,
#     temperature=0.9,
#     max_length=100,
# )
# gen_text = hf_tokenizer.batch_decode(gen_tokens)[0]
# vllm_generator(prompt, 5)

# base_model -> (TransformerLens model name, SAE release, SAE id).
# See other release options in sae_lens/pretrained_saes.yaml; the SAE id won't
# always be a hook point.
SAE_SETUPS = {
    'GPT2-SMALL': ("gpt2-small", "gpt2-small-res-jb", "blocks.8.hook_resid_pre"),
    'GEMMA-2B': ("gemma-2b", "gemma-2b-res-jb", "blocks.12.hook_resid_post"),
    'GEMMA-2B-IT': ("gemma-2b-it", "gemma-2b-it-res-jb", "blocks.12.hook_resid_post"),
    'MISTRAL-7B': ("mistral-7b", "mistral-7b-res-wg", "blocks.16.hook_resid_pre"),
}

# Models used without an SAE: they are driven through a plain HF text-generation pipeline.
PIPELINE_ONLY_MODELS = ['LLAMA3-8B-IT', 'LLAMA3-8B-IT-HELPFUL', 'LLAMA3-8B-IT-FICTION', 'LLAMA3-8B-IT-CHN', 'GEMMA-2B-CHN']

if base_model in SAE_SETUPS:
    tl_model_name, sae_release, sae_id = SAE_SETUPS[base_model]
    model = HookedTransformer.from_pretrained(tl_model_name, tokenizer=hf_tokenizer, hf_model=hf_model, default_padding_side='left', device=device)
    sae, original_cfg_dict, sparsity = SAE.from_pretrained(
        release=sae_release,
        sae_id=sae_id,
        device=device,
    )
    # The GPT-2 branch used to bind this as `cfg_dict` and the others as
    # `original_cfg_dict`; expose both names so either spelling works for every model.
    cfg_dict = original_cfg_dict
elif base_model in PIPELINE_ONLY_MODELS:
    model = pipeline("text-generation", model=hf_model, tokenizer=hf_tokenizer)
    sae = None
else:
    raise ValueError(f"Unknown model: {base_model}")

In [None]:
# Load the local copy of the pile-10k dataset.
dataset = load_dataset(
    path=os.path.join(LOCAL_SAE_MODEL_PATH, "NeelNanda/pile-10k"),
    split="train",
    streaming=False,
)

# Arguments shared by both cases; the SAE, when present, additionally dictates
# the context length and whether a BOS token is prepended.
tokenize_kwargs = dict(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
)
if sae:
    tokenize_kwargs.update(
        max_length=sae.cfg.context_size,
        add_bos_token=sae.cfg.prepend_bos,
    )

token_dataset = tokenize_and_concatenate(**tokenize_kwargs)

In [None]:
# Experiment configuration shared by the cells below.

# Number of players to generate/load. The string 'PREDEFINED' instead selects the
# hand-written player set in `generate_new_players`. Also part of the players file name.
NUM_PLAYERS = 250 #'PREDEFINED'
# Which value dimensions to evaluate. Options noted by the author: 'ALL', 'SMALLSET'
# or an int such as 100; a later cell currently rejects anything other than 'ALL'.
NUM_VALUE_DIM = 'ALL'#'ALL', 'SMALLSET', 100
# When True, answers starting with 'unsure' are accepted and scored as 0.
ALLOW_UNSURE_ANSWER = False
# Upper bound on the number of prompts sent to the model in one generation call.
MAX_QUESTIONS_PER_BATCH = 8
# 'value': store the mean score per value dimension; 'question': store one score per question.
SCORE_GRANULARITY = 'value' # 'question'
# When True, (re)create the players JSON file before it is loaded.
GENERATE_NEW_PLAYERS = False
# 'COLLECT': record the top-activated SAE features per player; 'FIX': record the
# features listed in FIXED_SAE_FEATURES (which must then be defined).
SAE_FEATURE_SOURCE = 'COLLECT' # 'FIX', 'COLLECT'
#FIXED_SAE_FEATURES = [13562]


def generate_new_player():
    """Create one synthetic player: a fake identity plus an LLM-written bio.

    A Faker profile supplies name, sex, job and birthdate. Random responsibility and
    aggression levels are only used to steer the bio request; they are not part of
    the returned trait string.

    Returns:
        Tuple `(name, trait)` where `trait` has the form
        'Gender: ...; Job: ...; DOB: ...; bio: ...'.
    """
    profile = Faker().profile()
    name = profile['name']
    gender = {'F': 'female', 'M': 'male'}.get(profile['sex'], 'unknown')
    job = profile['job']
    levels = ['low', 'medium', 'high']
    responsibility = random.choice(levels)
    aggression = random.choice(levels)
    bio_request_traits = f'Gender: {gender}; Responsibility: {responsibility}; Aggression: {aggression}; Job: {job}'

    # Credentials and endpoint are taken from the environment.
    client = AzureOpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        api_version="2024-02-01",
        azure_endpoint=os.environ.get("OPENAI_BASE_URL"),
        #base_url=os.environ.get("OPENAI_BASE_URL"),
    )

    request = {
        "role": "user",
        "content": f"Generate a short bio for {name} with the following traits: {bio_request_traits}",
    }
    chat_completion = client.chat.completions.create(
        messages=[request],
        model="gpt-35-turbo-0125"#"gpt-4",
    )
    bio = chat_completion.choices[0].message.content

    return name, f'Gender: {gender}; Job: {job}; DOB: {profile["birthdate"]}; bio: {bio}'

def generate_new_players(players_file):
    """Build the player set selected by `NUM_PLAYERS` and save it as JSON.

    Args:
        players_file: Path of the JSON file to write.

    Returns:
        Dict mapping player name to `{'trait': <description>}`.

    Raises:
        ValueError: If `NUM_PLAYERS` is neither an int nor 'PREDEFINED'.
    """
    if isinstance(NUM_PLAYERS, int) and not isinstance(NUM_PLAYERS, bool):
        players = {}
        # Names are the dict keys, so keep sampling until enough distinct names exist.
        while len(players) < NUM_PLAYERS:
            name, trait = generate_new_player()
            if name not in players:
                players[name] = {'trait': trait}
    elif NUM_PLAYERS == 'PREDEFINED':
        players = {
            'John': {'trait': 'playerbio: John is a 35-year old man, who has been abused by his parents since childhood.'},
            'Mary': {'trait': 'playerbio: Mary is a 25-year old woman, who has been well treadted by her parents since childhood.'},
            'Alice': {'trait': 'playerbio: Alice is a 30-year old woman, who has been well trained harshly to be an elite by her parents since childhood.'},
            'Bob': {'trait': 'playerbio: Bob is a 40-year old man, who was born in a millionare family.'},
            'Eve': {'trait': 'playerbio: Eve is a 20-year old woman, who was born in a poor family.'},
            'David': {'trait': 'playerbio: David is a 45-year old man, who was born in a middle-class family.'},
            'Lucy': {'trait': 'playerbio: Lucy is a 35-year old woman, who was born and raised in a small town.'},
            'Tom': {'trait': 'playerbio: Tom is a 30-year old man, who was born and raised in a big city.'},
            'Sara': {'trait': 'playerbio: Sara is a 25-year old woman, who was born and raised in a village.'},
            'Peter': {'trait': 'playerbio: Peter is a 40-year old man, who was born and raised in a farm.'},
            'Lily': {'trait': 'playerbio: Lily is a 20-year old woman, who was born and raised in a forest.'},
            'Paul': {'trait': 'playerbio: Paul is a 45-year old man, whose parents died when he was 10.'},
            'Jenny': {'trait': 'playerbio: Jenny is a 35-year old woman, who was raised by her grandparents.'},
        }
    else:
        # An explicit error instead of an `assert`, which is stripped under `python -O`.
        raise ValueError(f"NUM_PLAYERS must be an int or 'PREDEFINED', got {NUM_PLAYERS!r}")

    # Save players in json
    with open(players_file, 'w') as file:
        json.dump(players, file)
    return players

def restore_players(players_file):
    """Read a players dict back from the JSON file at `players_file`."""
    with open(players_file, 'r') as handle:
        return json.load(handle)

# Optionally (re)create the players file; it is loaded from disk in a later cell.
if GENERATE_NEW_PLAYERS:
    generate_new_players(f'players{NUM_PLAYERS}.json')


In [None]:
def generate_question(name, trait, a, q, qi, allow_unsure):
    """Build the role-play prompt asking whether a player agrees with a value statement.

    Args:
        name: Player name (not used by the current template).
        trait: Player description. Only the text after the last 'bio: ' marker is used;
            the whole string is used when the marker is absent.
        a: Agreement sign of the item (not used by the current template).
        q: Question form of the item (not used by the current template).
        qi: Self-statement that is quoted in the prompt.
        allow_unsure: When true, the instruction also offers 'unsure' as an answer.
            Previously this flag was ignored, so the model was never told about the
            option even when the caller accepted it.

    Returns:
        The prompt string, ending with an empty "Answer:" section.
    """
    if allow_unsure:
        instruct = "Give intuitive and quick Answer to the Question from the Context. Just choose an answer from 'yes', 'no' and 'unsure'. No other explanations."
    else:
        # The unbalanced quote after 'yes is kept byte-for-byte so the default prompt is unchanged.
        instruct = "Give intuitive and quick Answer to the Question from the Context. Just choose an answer from 'yes or 'no'. No other explanations."

    bio = trait.split('bio: ')[-1]
    context = f'You are role-playing according to the following bio: {bio}'
    ques = f'Will you agree with the following self-statement of value? "{qi}"'

    return f"Instruction:\n{instruct}\n\nContext:\n{context}\n\nQuestion:\n{ques}\n\nAnswer:\n"

def generate_question_norole(a, q, qi, allow_unsure):
    """Build a yes/no prompt for question `q` without any role-play context.

    `a`, `qi` and `allow_unsure` are accepted for signature parity with
    `generate_question` but are not used.
    """
    instruction = "Answer the following Question using 'yes' or 'no' according to your values. No other explanations."
    return f"Instruction:\n{instruction}\n\nQuestion:\n{q}\n\nAnswer:\n"

In [None]:
# Load the player set matching the configured NUM_PLAYERS from disk.
players = restore_players(f'players{NUM_PLAYERS}.json')

In [None]:
# Load the value-orientation questionnaire and group its items by value dimension.
df_valuebench = pd.read_csv(os.path.join(LOCAL_SAE_MODEL_PATH, 'value_data/value_orientation.csv'))
grouped = df_valuebench.groupby('value')

# Only the full set of value dimensions is currently supported. Raise explicitly
# rather than `assert`, which is stripped under `python -O`.
if NUM_VALUE_DIM != 'ALL':
    raise ValueError(f"NUM_VALUE_DIM must be 'ALL', got {NUM_VALUE_DIM!r}")

# Disabled subset selection, kept for reference. It used to live in a bare string
# literal, which the notebook echoed as the cell's output.
# if NUM_VALUE_DIM != 'ALL':
#     if NUM_VALUE_DIM == 'SMALLSET':
#         #smallset = ['Achievement', 'Benevolence', 'Conformity', 'Hedonism', 'Power', 'Security', 'Self-Direction', 'Stimulation', 'Tradition', 'Universalism']
#         #smallset = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Long Term Orientation', 'Indulgence', 'Corruption', 'Economic Vals', 'Ethical Vals', 'Migration', 'Political Cul', 'Political Int', 'Science', 'Feminist',]
#         smallset = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Long Term Orientation', 'Indulgence', 'Economic', 'Political', 'Scientific Understanding', 'Achievement', 'Benevolence', 'Conformity', 'Hedonism', 'Security', 'Self-Direction', 'Stimulation', 'Tradition', 'Universalism']
#         grouped = [group for group in grouped if group[0] in smallset]
#     else:
#         grouped = random.sample(list(grouped), NUM_VALUE_DIM)

In [None]:
# Short tag identifying each base model inside the results file name.
MODEL_FILE_TAGS = {
    'GPT2-SMALL': 'gpt2small',
    'GEMMA-2B-IT': 'gemma2bit',
    'GEMMA-2B': 'gemma2b',
    'GEMMA-2B-CHN': 'gemma2bchn',
    'MISTRAL-7B': 'mistral7b',
    'LLAMA3-8B-IT': 'llama38bit',
    'LLAMA3-8B-IT-HELPFUL': 'llama38bithelp',
    'LLAMA3-8B-IT-FICTION': 'llama38bitfiction',
    'LLAMA3-8B-IT-CHN': 'llama38bitchn',
}

if base_model not in MODEL_FILE_TAGS:
    raise ValueError('Invalid base model')

# Assemble the name piece by piece: model tag and sizes first, then one suffix per
# active option, then the score granularity, and finally the extension.
name_parts = [
    'answers_valuebench_features_' + MODEL_FILE_TAGS[base_model],
    '_players' + str(NUM_PLAYERS),
    '_valuedims' + str(NUM_VALUE_DIM),
]
if ALLOW_UNSURE_ANSWER:
    name_parts.append('_unsure')
if sae:
    name_parts.append('_sae')
if SAE_FEATURE_SOURCE == 'FIX':
    name_parts.append('_featurefix')
name_parts.append('_' + SCORE_GRANULARITY)

answer_valuebench_features_csv = ''.join(name_parts) + '.csv'

In [None]:
# Per-player sets of collected SAE feature ids, and per-value standard deviations of scores.
feature_intersections = []
stds = []

# Answer prefixes accepted from the model; any other answer re-generates the whole batch.
valid_prefixes = ('yes', 'no', 'unsure') if ALLOW_UNSURE_ANSWER else ('yes', 'no')

# Token id looked up in the tokenised prompt to choose the position whose SAE features
# are recorded. Presumably the id of the word "value(s)" in each tokenizer — TODO confirm.
VALUE_TOKEN_IDS = {
    'GPT2-SMALL': 3815,
    'GEMMA-2B-IT': 1618, #1261
    'GEMMA-2B': 4035,
}


def get_cache_feature(prompt, feature_ids):
    """Sum the SAE activation of each requested feature over all tokens of `prompt`.

    Args:
        prompt: Text run through `model` with a BOS token prepended.
        feature_ids: SAE feature indices to read.

    Returns:
        Dict mapping each feature id to its activation summed over token positions.
    """
    _, cache = model.run_with_cache(prompt, prepend_bos=True)
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    return {sf: sum(token[sf].item() for token in feature_acts[0]) for sf in feature_ids}


def get_cached_feature_of_value(prompt, token_id):
    """Return the SAE activations at the first occurrence of `token_id` in `prompt`.

    Args:
        prompt: Text run through `model` with a BOS token prepended.
        token_id: Token id whose first position in `model.tokenizer(prompt)` is inspected.

    Returns:
        Tuple `(feature_of_value, ti)`: the activation vector at that position and the
        indices of its strongest 0.5% features, restricted to non-zero activations.

    Raises:
        ValueError: If `token_id` does not occur in the tokenised prompt.
    """
    _, cache = model.run_with_cache(prompt, prepend_bos=True)
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    del cache

    index_of_value = model.tokenizer(prompt, return_tensors="pt").input_ids[0].tolist().index(token_id)
    feature_of_value = feature_acts[0][index_of_value]
    topk_values, topk_indice = torch.topk(feature_of_value, int(0.005 * len(feature_of_value)))

    # Drop features that are not active at all. The earlier version cut the list at the
    # position of feature *index* 0 rather than at zero *values*. That only matched the
    # intent when top-k happened to list the zero-valued features starting from index 0.
    ti = topk_indice[topk_values != 0].tolist()
    return feature_of_value, ti


with torch.no_grad():
    player_count = 0
    for name, char in players.items():
        player_count += 1
        print(f'Processing player {player_count} of {NUM_PLAYERS}')

        trait = char['trait']
        for value_name, group in grouped:
            groupagreementall = group['agreement']
            groupquestionall = group['question']
            groupitemall = group['item']

            scores = []
            # Ask the questions of this value dimension in batches of at most MAX_QUESTIONS_PER_BATCH.
            for start in range(0, len(groupagreementall), MAX_QUESTIONS_PER_BATCH):
                batch = slice(start, start + MAX_QUESTIONS_PER_BATCH)
                groupagreement = groupagreementall.iloc[batch]
                groupquestion = groupquestionall.iloc[batch]
                groupitem = groupitemall.iloc[batch]

                questions = [
                    generate_question(name, trait, a, q, qi, ALLOW_UNSURE_ANSWER)
                    #generate_question_norole(a, q, qi, ALLOW_UNSURE_ANSWER)
                    for a, q, qi in zip(groupagreement, groupquestion, groupitem)
                ]
                answers = list(groupagreement)

                if sae:
                    # Tokenise once per batch; only the generation itself is retried.
                    questions_tokens = model.tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
                    # Decoded *padded* prompts, so their length matches the prefix of the decoded generations.
                    questions_padded = model.tokenizer.batch_decode(questions_tokens.input_ids)

                # Re-generate until every answer in the batch parses.
                # NOTE: there is no retry cap, so this relies on generation eventually
                # producing a valid answer for every prompt.
                while True:
                    if sae:
                        gen_tokens = model.generate(questions_tokens.input_ids, max_new_tokens=20, verbose=False)
                        gen_texts = model.tokenizer.batch_decode(gen_tokens)
                        gen_answers = [
                            gen_text.lower()[len(question):].strip()
                            for question, gen_text in zip(questions_padded, gen_texts)
                        ]
                    else:
                        gen_texts = model(questions, max_new_tokens=5)
                        gen_answers = [
                            gen_text[0]["generated_text"][len(question):].lower()
                            for question, gen_text in zip(questions, gen_texts)
                        ]

                    if all(gen_answer.startswith(valid_prefixes) for gen_answer in gen_answers):
                        break

                # 'yes' scores the item's agreement value, 'no' its negation, 'unsure' scores 0.
                for ga, answer in zip(gen_answers, answers):
                    if ga.startswith('yes'):
                        scores.append(answer)
                    elif ga.startswith('no'):
                        scores.append(-answer)
                    elif ALLOW_UNSURE_ANSWER and ga.startswith('unsure'):
                        scores.append(0)
                    else:
                        raise ValueError('Invalid answer')
            assert len(scores) == len(groupagreementall)

            if SCORE_GRANULARITY == 'question':
                for q, s in zip(groupquestionall, scores):
                    players[name][q] = s
            elif SCORE_GRANULARITY == 'value':
                players[name][value_name] = sum(scores) / len(scores)
                print(value_name, scores, 'std: ', np.std(scores))
                stds.append(np.std(scores))
            else:
                raise ValueError('Invalid SCORE_GRANULARITY')

        if sae:
            # Features are read from the last prompt that was asked for this player.
            if SAE_FEATURE_SOURCE == 'COLLECT':
                if base_model not in VALUE_TOKEN_IDS:
                    raise ValueError('Invalid base model')
                token_id = VALUE_TOKEN_IDS[base_model]

                feature_of_value, ti = get_cached_feature_of_value(questions_padded[-1], token_id)
                for ttii in ti:
                    players[name][ttii] = feature_of_value[ttii].item()
                feature_intersections.append(set(ti))
            elif SAE_FEATURE_SOURCE == 'FIX':
                results = get_cache_feature(questions_padded[-1], FIXED_SAE_FEATURES)
                for sf in FIXED_SAE_FEATURES:
                    players[name][sf] = results[sf]
            else:
                raise ValueError('Invalid SAE_FEATURE_SOURCE')

    if sae and SAE_FEATURE_SOURCE == 'COLLECT':
        # `set.intersection(*[])` raises, so handle an empty player set explicitly.
        feature_intersection = set.intersection(*feature_intersections) if feature_intersections else set()
        feature_union = set().union(*feature_intersections)
    print('stds: ', np.mean(stds))

In [None]:
# Export one row per player (everything stored on the player except 'trait').
#
# The rows are assembled into a single DataFrame so that every row shares the same
# column layout. Appending row by row without a header wrote each player's feature
# columns in that player's own order, misaligning them with the first row's header.
rows = [
    {key: value for key, value in char.items() if key != 'trait'}
    for name, char in players.items()
]

if rows:
    table = pd.DataFrame(rows)
    if sae and SAE_FEATURE_SOURCE == 'COLLECT':
        feature_columns = list(feature_union)
        for fu in feature_columns:
            if fu not in table.columns:
                table[fu] = 0
        # A feature that was not collected for a player counts as activation 0.
        table[feature_columns] = table[feature_columns].fillna(0)
    table.to_csv(answer_valuebench_features_csv, index=False)

In [None]:
def get_valid_d_columns(answer_valuebench_features_csv):
    """Return the CSV columns whose name starts with a digit (the SAE feature columns).

    Only the header row is read. The previous per-column std/mean computation was
    never used and has been removed.

    Args:
        answer_valuebench_features_csv: Path of the results CSV.

    Returns:
        List of column names, in file order, whose first character is 0-9.
    """
    header = pd.read_csv(answer_valuebench_features_csv, nrows=0)
    digits = set('0123456789')
    #d_columns_valid = [d for d in d_columns if avgs[d] > 1]
    return [column for column in header.columns if column[0] in digits]

def _pc_edges(graph, valid_columns):
    """Translate the adjacency matrix of a PC result into a list of named edges.

    Only node pairs where both names are in `valid_columns` are reported. Each edge is
    `[source, target, 1, kind]` with kind 'single-arrow', 'no-arrow' or 'double-arrow'.
    """
    names = [node.name for node in graph.nodes]
    kept = set(valid_columns)
    adjacency = graph.graph
    edges = []
    for n1, name1 in enumerate(names):
        if name1 not in kept:
            continue
        for n2 in range(n1 + 1, len(names)):
            name2 = names[n2]
            if name2 not in kept:
                continue

            ends = (int(adjacency[n1][n2]), int(adjacency[n2][n1]))
            if ends == (-1, 1):
                edges.append([name1, name2, 1, 'single-arrow'])
            elif ends == (1, -1):
                edges.append([name2, name1, 1, 'single-arrow'])
            elif ends == (-1, -1):
                edges.append([name1, name2, 1, 'no-arrow'])
            elif ends == (1, 1):
                edges.append([name1, name2, 1, 'double-arrow'])
            elif ends != (0, 0):
                raise ValueError('Invalid edge')
    return edges


def deal_with_csv(answer_valuebench_features_csv, pdy_name, ci_on_all_v, v_smallset, d_columns_valid, row_num, method='pc'):
    """Run causal discovery over the value (and SAE feature) columns of a results CSV.

    Args:
        answer_valuebench_features_csv: Path of the results CSV.
        pdy_name: Output path of the PNG graph rendering ('pc' only); `None` skips it.
        ci_on_all_v: If true, the search runs on all value columns; otherwise only on
            the selected ones. Reported edges are restricted to the selection either way.
        v_smallset: 'ALL' to select every value column, or a collection of column names.
            Any other string falls back to the module-level `small_set` (legacy behaviour).
        d_columns_valid: List of feature column names to include.
        row_num: Number of rows sampled (without replacement) from the CSV.
        method: 'pc', 'fci' or 'es'.

    Returns:
        For 'pc', a list of `[source, target, 1, kind]` edges; for 'fci', the edges
        returned by `fci`; for 'es', `None`.

    Raises:
        ValueError: If `method` is unknown, or the PC graph has an unexpected edge encoding.
    """
    if method not in ('pc', 'fci', 'es'):
        raise ValueError(f'Unknown causal discovery method: {method!r}')

    data_csv = pd.read_csv(answer_valuebench_features_csv)

    digits = set('0123456789')
    v_columns_all = [v for v in data_csv.columns if v[0] not in digits]
    if isinstance(v_smallset, str):
        if v_smallset == 'ALL':
            v_columns_valid = v_columns_all
        else:
            # Legacy path: a non-'ALL' string selects the module-level `small_set`.
            v_columns_valid = [v for v in v_columns_all if v in small_set]
    else:
        # Honour the argument itself instead of silently reading the global.
        wanted = set(v_smallset)
        v_columns_valid = [v for v in v_columns_all if v in wanted]

    if ci_on_all_v:
        ci_dimensions = v_columns_all + d_columns_valid
    else:
        ci_dimensions = v_columns_valid + d_columns_valid

    valid_columns = v_columns_valid + d_columns_valid

    data = data_csv[ci_dimensions].to_numpy()
    rows = np.random.choice(data.shape[0], row_num, replace=False)
    data = data[rows]
    print(data.shape)

    if method == 'pc':
        #g = pc(data, 0.05, kci, kernelZ='Polynomial', node_names=ci_dimensions)
        g = pc(data, 0.01, node_names=ci_dimensions)
        graph = g.G
        edges = _pc_edges(graph, valid_columns)

        if pdy_name is not None:
            # ':' is replaced with '-' in the labels used for the rendered graph.
            columns_concerned_vis = [label.replace(':','-') for label in ci_dimensions]
            pdy = GraphUtils.to_pydot(graph, labels=columns_concerned_vis)
            pdy.write_png(pdy_name)
        return edges

    if method == 'fci':
        g, edges = fci(data)
        return edges

    # method == 'es': the search runs (verbosely) but its result is not returned.
    dag_est, search_stats = bic_exact_search(data, verbose=True)
    return None

deal_with_csv('useless_data/answers_valuebench_features_gemma2bit_players400_valuedimsALL.csv', None, True, 'ALL', [], 400)



# answer_valuebench_features_csv_gemma2bit = "answers_valuebench_features_gemma2bit_players400_valuedimsALL.csv"
# d_columns_valid_gemma2bit = get_valid_d_columns(answer_valuebench_features_csv_gemma2bit)

#answer_valuebench_features_csv_llama38bit = "answers_valuebench_features_llama38bit_players250_valuedimsALL.csv"
d_columns_valid = get_valid_d_columns(answer_valuebench_features_csv)

#value_dims = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Long-term Orientation', 'Indulgence', 'Corruption', 'Economic Vals', 'Ethical Vals', 'Migration', 'Political Cul', 'Political Int', 'Science', 'Feminist',]
# 'Security', 'Social Capital', 'Social Vals',]
#small_set = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Long Term Orientation', 'Indulgence', 'Economic', 'Political', 'Scientific Understanding', 'Achievement', 'Benevolence', 'Conformity', 'Hedonism', 'Security', 'Self-Direction', 'Stimulation', 'Tradition', 'Universalism', 'Fairness']

# Value dimensions whose questions are dumped to HTML below.
small_set = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Indulgence','Psychosocial flourishing', 'Liveliness', ]

df_valuebench = pd.read_csv(os.path.join(LOCAL_SAE_MODEL_PATH, 'value_data/value_orientation.csv'))
grouped = df_valuebench.groupby('value')

# Start from an empty 'valuebench' output directory.
if os.path.exists('valuebench'):
    shutil.rmtree('valuebench')
os.mkdir('valuebench')

for value_name in small_set:
    if value_name not in grouped.groups:
        print(f'No questions found for value "{value_name}"; skipping')
        continue
    # Write only this value's questions. Zipping over the SeriesGroupBy objects
    # directly yielded (group_name, Series) tuples for *all* groups in every file.
    group = grouped.get_group(value_name)
    with open(os.path.join('valuebench','value_questions_' + value_name + '.html'), 'w') as f:
        for question, answer in zip(group['question'], group['agreement']):
            f.write(f'<p>Question: {question}</p>')
            f.write(f'<p>Answer: {answer}</p>')
            f.write(f'<p>==============================</p>')

# edges0 = [
#     ["Indulgence", "Hedonism", 1, 'no-arrow'],
#     ["Indulgence", "Stimulation", 1, 'no-arrow'],
#     ["Hedonism", "Stimulation", 1, 'no-arrow'],
#     ["Political", "Economic", 1, 'no-arrow'],
#     ["Individualism", "Self-Direction", 1, 'no-arrow'],
#     ["Power Distance", "Conformity", 1, 'no-arrow'],
#     ["Conformity", "Tradition", 1, 'no-arrow'],
#     ["Conformity", "Uncertainty Avoidance", 1, 'no-arrow'],
#     ["Uncertainty Avoidance", "Security", 1, 'no-arrow'],
# ]

# Edge list for the small value set, plus one manually added feature -> value edge.
edges1 = deal_with_csv(answer_valuebench_features_csv, "gemma2bit-smallset-fixed.png", False, small_set, d_columns_valid, 13)
edges1.append(["13562", "Psychosocial flourishing", 1, 'single-arrow'])
# edges1 = deal_with_csv(answer_valuebench_features_csv_llama38bit, "llama38bit-smallset-30.png", False, small_set, d_columns_valid_llama38bit, 30)
# edges11 = deal_with_csv(answer_valuebench_features_csv_llama38bit, "llama38bit-smallset-30_1.png", False, small_set, d_columns_valid_llama38bit, 30)
# edges2 = deal_with_csv(answer_valuebench_features_csv_llama38bit, "llama38bit-smallset-250.png", False, small_set, d_columns_valid_llama38bit, 250)
# edges21 = deal_with_csv(answer_valuebench_features_csv_llama38bit, "llama38bit-smallset-250_1.png", False, small_set, d_columns_valid_llama38bit, 250)
# edges3 = deal_with_csv(answer_valuebench_features_csv_llama38bit, "llama38bit-all4small-250.png", True, small_set, d_columns_valid_llama38bit, 250)
# edges4 = deal_with_csv(answer_valuebench_features_csv_llama38bitfirst, "llama38bit-smallset-250-first.png", False, small_set, d_columns_valid_llama38bit, 250)
# edges5 = deal_with_csv(answer_valuebench_features_csv_llama38bitchn, "llama38bit-smallset-250-chn.png", False, small_set, d_columns_valid_llama38bit, 250)


# Map every graph node to a link target: value nodes point at the local question
# pages, feature nodes at the matching Neuronpedia page.
nodes = {}
for entity in small_set:
    # Stored as a plain string (no trailing comma), matching the feature nodes below;
    # a trailing comma here would wrap the path in a 1-tuple and serialise it as a list.
    nodes[entity] = os.path.join('valuebench', 'value_questions_' + entity + '.html')
for feature in d_columns_valid:
    nodes[feature] = 'https://www.neuronpedia.org/' + sae.cfg.model_name + '/' + str(sae.cfg.hook_layer) + '-res-jb/' + str(feature)



# Named edge sets; the key labels the experiment that produced the edges.
edges = {
    'gemma2bit-smallset-fixed': edges1,
    # 'human_annotated': edges0,
    # 'gemma2bit-unsure-smallset-30': edges1
    #'llama38bit_cismall_30_trial1': edges1,
    #'llama38bit_cismall_30_trial2': edges11,
    #'llama38bit_cismall_250_trial1': edges2,
    #'llama38bit_cismall_250_trial2': edges21,
    #'llama38bit_ciall4small_250': edges3,
    #'llama38bit_cismall_250_lessqa': edges4,
    #'llama38bit_cismall_250_chn': edges5,
}

json_object = {
    'nodes': nodes,
    'edges': edges
    }

# Use a context manager so the file is flushed and closed deterministically.
with open('data1.json', 'w') as f:
    json.dump(json_object, f)



In [None]:
small_set = ['Power Distance', 'Individualism', 'Uncertainty Avoidance', 'Masculinity',  'Long Term Orientation', 'Indulgence', 'Economic', 'Political', 'Scientific Understanding', 'Achievement', 'Benevolence', 'Conformity', 'Hedonism', 'Security', 'Self-Direction', 'Stimulation', 'Tradition', 'Universalism']

# Token id that is searched for in every prompt, keyed by base model.
# NOTE(review): presumably the tokenizer id of the word "value" — confirm per tokenizer.
_VALUE_TOKEN_IDS = {
    'GPT2-SMALL': 3815,
    'GEMMA-2B-IT': 1618,  # 1261
    'GEMMA-2B': 4035,
    'MISTRAL-7B': 3815,  # TBD (same placeholder value as GPT2-SMALL)
}

# Number of leading blocks run by hand before the SAE output is spliced in.
# NOTE(review): hard-coded; presumably has to match the layer of sae.cfg.hook_name — confirm.
_SPLIT_LAYER = 13


def computed_cache(batch_tokens):
    """Run the embedding and the first `_SPLIT_LAYER` blocks by hand; return the result."""
    hidden = model.hook_embed(model.embed(batch_tokens))
    for block in model.blocks[:_SPLIT_LAYER]:
        hidden = block(hidden)
    return hidden


def computed_modified_cache(sae_out):
    """Feed `sae_out` through the remaining blocks, the final norm and the unembedding.

    Runs every block from `_SPLIT_LAYER` to the end of the model rather than
    stopping at a hard-coded layer 18, so models with a different depth are
    handled as well.
    """
    hidden = sae_out
    for block in model.blocks[_SPLIT_LAYER:]:
        hidden = block(hidden)
    return model.unembed(model.ln_final(hidden))


with torch.no_grad():
    # activation store can give us tokens.

    #value_of_interest = 'Calmness'
    #value_of_interest = 'Causality:Interactionism'
    #value_of_interest_ref = 'Authority'
    # feature_of_interest = 15509 #3376

    #value_of_interest = 'Financial Prosperity'
    value_of_interest = 'Decisiveness'
    feature_of_interest = 10096

    df_valuebench = pd.read_csv(os.path.join(LOCAL_SAE_MODEL_PATH, 'value_data/value_orientation.csv'))
    grouped = df_valuebench.groupby('value')
    name = 'Nicholas Davis'
    trait = "Gender: male; Job: Financial risk analyst; DOB: 1942-03-14; bio: Nicholas Davis is a male financial risk analyst with a medium level of responsibility. Known for his calm and collected demeanor, Nicholas approaches his work with a low level of aggression, always seeking to find strategic solutions to potential financial risks. With a keen eye for detail and strong analytical skills, he is dedicated to ensuring that his clients make informed decisions to protect their investments."

    for value_name, group in grouped:
        if value_name != value_of_interest:
            continue

        # Loop-invariant checks, done once per value instead of once per trial.
        assert sae
        if base_model not in _VALUE_TOKEN_IDS:
            raise ValueError('Invalid base model')
        token_id = _VALUE_TOKEN_IDS[base_model]

        for trial in range(10):
            # Sample one questionnaire row; `.iloc` makes the positional slice explicit.
            randomno = random.randint(0, len(group) - 1)
            sampled = group.iloc[randomno:randomno + 1]

            questions = []
            answers = []
            for a, q, qi in zip(sampled['agreement'], sampled['question'], sampled['item']):
                questions.append(generate_question(name, trait, a, q, qi, allow_unsure))
                answers.append(a)

            batch_tokens = model.tokenizer(questions, return_tensors="pt", padding=True).input_ids#[0].tolist().index(3815)
            # Position of `token_id` in the last prompt; raises ValueError when the token
            # is missing. Only consumed by the commented-out intervention below.
            index_of_value = [tks.index(token_id) for tks in batch_tokens.tolist()][-1]
            #logits, cache = model.run_with_cache(batch_tokens, prepend_bos=True)
            logits, cache = model.run_with_cache(batch_tokens)

            # Use the SAE
            feature_acts = sae.encode(cache[sae.cfg.hook_name])
            #feature_acts[0][index_of_value][feature_of_interest] = 0#feature_acts[0][index_of_value][feature_of_interest] * 10
            sae_out = sae.decode(feature_acts)

            # gen_tokens = model.generate(batch_tokens, max_new_tokens=5, verbose=False)
            # # gen_tokens_f = model.forward(questions_tokens.input_ids, return_type='logits', attention_mask=questions_tokens.attention_mask)
            # # gen_tokens_hf = hf_model.generate(questions_tokens.input_ids, attention_mask=questions_tokens.attention_mask, max_new_tokens=10)
            # gen_texts = model.tokenizer.batch_decode(gen_tokens)
            # gen_answer = gen_texts[-1].lower()[len(questions[-1]):].strip()

            # Sanity check: the hand-rolled partial forward pass must reproduce the cached activation.
            pre = computed_cache(batch_tokens)
            assert torch.all(pre == cache[sae.cfg.hook_name])

            cmc = computed_modified_cache(sae_out)
            #cmc = computed_modified_cache(pre)

            # Compare the greedy next token at the last position with and without the SAE splice.
            modified_answer = model.tokenizer.batch_decode([torch.argmax(cmc[0][-1])])[0].strip().lower()
            original_answer = model.tokenizer.batch_decode([torch.argmax(logits[0][-1])])[0].strip().lower()

            print('standard_answer: ', answers[0])
            print('original answer: ', original_answer)
            print('modified answer: ', modified_answer)
            print('\n')
            # Drop the large tensors before the next trial.
            del  cache, logits, feature_acts, sae_out #pre, cmc
            gc.collect()
        print('=========================')
        '''
        # save some room
        del cache
        activations = []
        for iovn in range(len(index_of_value)):
            feature_of_value = feature_acts[iovn][index_of_value[iovn]]
            pass
            activations.append(feature_of_value[feature_of_interest].item())
            
        px.histogram(activations, nbins=100).show()
        #px histogram x label more fine grained
        
        
        # # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
        # l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
        # print("average l0", l0.mean().item())
        # px.histogram(l0.flatten().cpu().numpy()).show()
        '''

Note that while the mean L0 is 64, it varies with the specific activation.

To estimate reconstruction performance, we calculate the CE loss of the model with and without the SAE being used in place of the activations. This will vary depending on the tokens.


# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that ignores the incoming activation and returns `sae_out` instead."""
    return sae_out


def zero_abl_hook(activation, hook):
    """Forward hook that returns an all-zero tensor shaped like the activation."""
    return torch.zeros_like(activation)


# The preceding cell deletes `sae_out` at the end of every trial, so rebuild the
# reconstruction for the surviving `batch_tokens` instead of relying on a stale
# (or missing) global.
_logits, cache = model.run_with_cache(batch_tokens)
sae_out = sae.decode(sae.encode(cache[sae.cfg.hook_name]))
del _logits, cache

# Loss of the unmodified model, with the SAE reconstruction spliced in, and with
# the hooked activation zeroed out.
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

## Specific Capability Test

When studying a specific task, it is important to validate that the model still performs that task when the reconstructed activation is used in place of the original one.

# Indirect-object prompt: baseline behaviour of the unmodified model first.
example_prompt = "When John and Mary went to the shops, John gave the bag to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Cache the activations for the prompt and reconstruct the hooked one with the SAE.
logits, cache = model.run_with_cache(example_prompt, prepend_bos=True)
tokens = model.to_tokens(example_prompt)
sae_out = sae(cache[sae.cfg.hook_name])


def reconstr_hook(activations, hook, sae_out):
    """Forward hook: discard the incoming activations and hand back `sae_out`."""
    return sae_out


def zero_abl_hook(mlp_out, hook):
    """Forward hook: hand back zeros with the same shape/dtype as the input."""
    return torch.zeros_like(mlp_out)


hook_name = sae.cfg.hook_name

# One hook list for the reconstruction, reused for the loss and for the prompt test.
reconstruction_hooks = [(hook_name, partial(reconstr_hook, sae_out=sae_out))]
zero_ablation_hooks = [(hook_name, zero_abl_hook)]


def _hooked_loss(fwd_hooks):
    """Loss on `tokens` with the given forward hooks installed."""
    return model.run_with_hooks(tokens, fwd_hooks=fwd_hooks, return_type="loss").item()


print("Orig", model(tokens, return_type="loss").item())
print("reconstr", _hooked_loss(reconstruction_hooks))
print("Zero", _hooked_loss(zero_ablation_hooks))


# Repeat the prompt test with the SAE reconstruction spliced into the model.
with model.hooks(fwd_hooks=reconstruction_hooks):
    utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

# Generating Feature Interfaces

Feature dashboards are an important part of SAE Evaluation. They work by:
- 1. Collecting feature activations over a larger number of examples.
- 2. Aggregating feature specific statistics (such as max activating examples).
- 3. Representing that information in a standardized way

# Features to build dashboards for: indices 0-9 plus feature 14057.
test_feature_idx_gpt = [*range(10), 14057]

# Dashboard settings; the hook point is fixed to the block-12 residual stream output.
feature_vis_config_gpt = SaeVisConfig(
    hook_point="blocks.12.hook_resid_post",
    features=test_feature_idx_gpt,
    batch_size=2048,
    minibatch_size_tokens=128,
    verbose=True,
)

# Collect the dashboard data over the first 10000 rows of the token dataset.
_vis_tokens = token_dataset[:10000]["tokens"]  # type: ignore
sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model, # type: ignore
    tokens=_vis_tokens,
    cfg=feature_vis_config_gpt,
)

# Save one feature-centric HTML page per feature and show it inline.
for feature in test_feature_idx_gpt:
    filename = f"{feature}_feature_vis_demo_gpt.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display_vis_inline(filename)

Since feature dashboards only need to be generated once per sparse autoencoder, everyone can share the same dashboards for publicly available pre-trained SAEs. Neuronpedia hosts such dashboards, which we can load via the integration.



# Features to look up on Neuronpedia: indices 0-9 plus feature 14057.
test_feature_idx_gpt = [*range(10), 14057]

# Request a Neuronpedia quick list for those features at the SAE's hook layer.
# NOTE(review): the original comment suggests the call opens the list by itself — confirm.
_quick_list_options = {
    "layer": sae.cfg.hook_layer,
    "model": "gemma-2b-it",
    "dataset": "res-jb",
    "name": "A quick list we made",
}
neuronpedia_quick_list = get_neuronpedia_quick_list(
    test_feature_idx_gpt,
    **_quick_list_options,
)

#if COLAB:
  # If you're on colab, click the link below
print(neuronpedia_quick_list)