# Generating Outputs for Neuronpedia Upload

We use Callum McDougall's `sae_vis` library for generating JSON data to upload to Neuronpedia.


## Set Up

In [None]:
from sae_lens.toolkit.pretrained_saes import download_sae_from_hf
import os

# Identifiers for the model and the SAE set handled by this notebook.
MODEL_ID = "gpt2-small"
SAE_ID = "res-jb"

# The helper returns a 3-tuple; only the middle element (the weights path) is
# kept. Presumably this downloads from the Hugging Face Hub -- confirm against
# `download_sae_from_hf`.
_, SAE_WEIGHTS_PATH, _ = download_sae_from_hf(
    "jbloom/GPT2-Small-SAEs-Reformatted",
    "blocks.0.hook_resid_pre",
)

# The SAE path is the directory that contains the weights file.
SAE_PATH = os.path.split(SAE_WEIGHTS_PATH)[0]

## Save JSON to neuronpedia_outputs

In [None]:
from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner

# Destination folder handed to the runner as its outputs directory.
NP_OUTPUT_FOLDER = "../../neuronpedia_outputs/my_outputs"

# Show which SAE directory is about to be processed.
print(SAE_PATH)

# Runner settings collected in one place so they are easy to scan and tweak.
# NOTE(review): start and end batch are both 1, so presumably only batch 1 is
# generated -- confirm against NeuronpediaRunner.
runner_settings = dict(
    sae_id=SAE_ID,
    sae_path=SAE_PATH,
    outputs_dir=NP_OUTPUT_FOLDER,
    sparsity_threshold=-5,
    n_batches_to_sample_from=2 ** 12,
    n_prompts_to_select=4096 * 6,
    n_features_at_a_time=24,
    start_batch_inclusive=1,
    end_batch_inclusive=1,
)

runner = NeuronpediaRunner(**runner_settings)
runner.run()

## Upload to Neuronpedia
#### This currently only works if you have admin access to the Neuronpedia database via localhost.

In [None]:
# Helpers that replace NaN values (which are not valid JSON) with -999 before upload
from decimal import Decimal
from typing import Any
import math
import json
import os
import requests

# Directory the runner reports as its outputs location; the upload code below
# scans it for "batch-*.json" files and reads "skipped_indexes.json" from it.
FEATURE_OUTPUTS_FOLDER = runner.outputs_dir

def nanToNeg999(obj: Any) -> Any:
    """Return a copy of ``obj`` with every NaN replaced by ``-999``.

    Dicts and lists are rebuilt recursively (dict keys are kept as-is, only
    values are visited). A ``float`` or ``Decimal`` that is NaN becomes the
    integer ``-999``. Any other value, including tuples and other containers,
    is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: nanToNeg999(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [nanToNeg999(item) for item in obj]
    if isinstance(obj, (float, Decimal)) and math.isnan(obj):
        return -999
    return obj


class NanConverter(json.JSONEncoder):
    """JSON encoder that passes its input through ``nanToNeg999`` before encoding.

    Only ``encode`` is overridden, so the substitution applies to callers that
    go through ``encode`` (as ``json.dumps(..., cls=NanConverter)`` does).
    """

    def encode(self, o: Any, *args: Any, **kwargs: Any) -> str:
        sanitized = nanToNeg999(o)
        return super().encode(sanitized, *args, **kwargs)


# Server info
host = "http://localhost:3000"

# Upload alive features: every "batch-*.json" file in the outputs folder is
# posted to the upload-features endpoint.
url = host + "/api/local/upload-features"
# os.listdir returns entries in arbitrary order; sorting makes the upload
# order reproducible between runs.
for file_name in sorted(os.listdir(FEATURE_OUTPUTS_FOLDER)):
    if not (file_name.startswith("batch-") and file_name.endswith(".json")):
        continue
    print("Uploading file: " + file_name)
    file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)
    # The context manager closes the file handle even if parsing fails.
    with open(file_path, "r") as f:
        data = json.load(f)

    # Replace NaNs. Applying the helper directly gives the same result as
    # dumping with NanConverter and re-parsing, without serialising the whole
    # batch an extra time.
    data = nanToNeg999(data)

    resp = requests.post(url, json=data)
    # Fail loudly instead of silently continuing past a rejected batch.
    resp.raise_for_status()

# Upload dead feature stubs
skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, "skipped_indexes.json")
with open(skipped_path, "r") as f:
    data = json.load(f)
url = host + "/api/local/upload-dead-features"
resp = requests.post(url, json=data)
resp.raise_for_status()

### TODO: Automatically validate the uploaded data