Imports and environment setup

In [1]:
from dotenv import load_dotenv
from utils.models import MiniPileDataset
from utils.interp import count_non_zero_feature_activations, plot_feature_activation_histogram
import os

# Enable automatic reloading of modules when they change
%load_ext autoreload
%autoreload 2


# Load environment variables from .env file
load_dotenv()

# Access the OpenAI API key from the environment variables
openai_api_key = os.getenv("OPENAI_API_KEY")


In [2]:
# Load the model from the pickle file
import pickle 
from utils.sae import SparseAutoencoder, SparseAutoencoderConfig
import json

# Load the pickled dataset object (later cells read its `.sentences` and
# `.embeddings`).
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
file_name = "files/all_sentences_with_embeddings_20240707_132959.pkl"
with open(file_name, "rb") as dataset_file:
    mini_pile_dataset = pickle.load(dataset_file)

# Read the JSON configuration saved alongside the model.
config_path = "sae/20240708_195600_config.json"
with open(config_path, "r") as config_file:
    config = json.loads(config_file.read())

# Rebuild the sparse autoencoder from the configuration (the sparse layer is
# 8x the model dimension), then restore its pickled state dict.
sae_config = SparseAutoencoderConfig(
    d_model=config["dimensions"],
    d_sparse=8 * config["dimensions"],
    sparsity_alpha=config["sparsity_alpha"],
)
model = SparseAutoencoder(sae_config)
model_path = "sae/20240708_195600_sae.pkl"
with open(model_path, "rb") as weights_file:
    model_state_dict = pickle.load(weights_file)
model.load_state_dict(model_state_dict)

Interpret the feature activations

In [11]:
# Sparsity check: per the recorded output, this reports the average number of
# non-zero elements over the first 100 samples.
count_non_zero_feature_activations(model, mini_pile_dataset)

Average Non-Zero Elements for first 100 samples: 12.8100004196167


In [None]:
# Presumably draws a histogram of the model's feature activations over the
# dataset -- TODO confirm the exact axes/binning in utils.interp.
plot_feature_activation_histogram(model, mini_pile_dataset)

Automated interpretation pipeline

In [None]:
import numpy as np
from utils.ai import OpenAIClient
from utils.features import Feature, FeatureSample
import os
import json
from pprint import pprint
from datetime import datetime

# Automated interpretation pipeline: record every SAE feature's activation on
# every sentence, then ask the LLM client to label each feature and score that
# label against the strongest and weakest activating sentences.

# Number of highest / lowest activating sentences kept per feature.
SAMPLES_PER_SIDE = 50


def _build_samples(activations, positions):
    """Wrap the sentences at `positions` with their activation values.

    Args:
        activations: 1-D array holding one feature's activation per sentence.
        positions: iterable of sentence indices (plain ints) to wrap, in order.

    Returns:
        A list of FeatureSample objects in the same order as `positions`.
    """
    return [
        FeatureSample(text=mini_pile_dataset.sentences[i], act=activations[i])
        for i in positions
    ]


# Each run writes into its own timestamped output folder.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
folder_name = f"features/sae_features_{timestamp}"
os.makedirs(folder_name, exist_ok=True)

ai = OpenAIClient(openai_api_key)

n = len(mini_pile_dataset)
# Rows are features, columns are sentences. The row count must match the
# d_sparse given to SparseAutoencoderConfig (8 * config["dimensions"]).
feature_registry = np.zeros((config["dimensions"] * 8, n))

for i in range(n):
    embedding = mini_pile_dataset.embeddings[i]
    # Element 1 of forward()'s result is taken as the feature activations.
    feature_activations = model.forward(embedding)[1]
    feature_registry[:, i] = feature_activations.detach().numpy()

for index, feature in enumerate(feature_registry):
    # Rank sentence indices by activation, highest first. A stable sort keeps
    # ties in dataset order, which is the same ordering a stable descending
    # sort over all samples would give, but only the samples that are actually
    # kept get built.
    order = np.argsort(-feature, kind="stable").tolist()
    high_act_samples = _build_samples(feature, order[:SAMPLES_PER_SIDE])
    low_act_samples = _build_samples(feature, order[-SAMPLES_PER_SIDE:])

    try:
        interpretation = ai.get_interpretation(high_act_samples, low_act_samples)
        label = interpretation["label"]
        reasoning = interpretation["reasoning"]
        attributes = interpretation["attributes"]

        high_act_score = ai.score_interpretation(high_act_samples, attributes)['percent']
        low_act_score = ai.score_interpretation(low_act_samples, attributes)['percent']
    except Exception as e:
        # Best effort: one failing feature must not abort the whole run, but
        # say which feature was skipped so it can be retried.
        print(f"Skipping feature {index} due to error: {e}")
        continue

    labelled_feature = Feature(
        index=index,
        label=label,
        attributes=attributes,
        reasoning=reasoning,
        # Gap between how well the label fits the high vs. low samples.
        confidence=abs(high_act_score - low_act_score),
        # Fraction of sentences on which this feature is non-zero.
        density=(np.count_nonzero(feature) / len(feature)),
        high_act_samples=high_act_samples,
        low_act_samples=low_act_samples,
    )

    # Write this feature to its own JSON file.
    with open(os.path.join(folder_name, f"feature_{index}.json"), "w") as json_file:
        json.dump(labelled_feature.dict(), json_file, indent=4)

    # Report progress.
    print(f"Processed feature {index}: {label}")