<a href="https://colab.research.google.com/github/kaifbilal/Neuronpedia_API_LLM/blob/main/Neuronpedia_API_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch sae-lens streamlit pyngrok requests --quiet

In [None]:
!pip install --upgrade --force-reinstall numpy
!pip install --upgrade transformers


In [None]:
import warnings

from google.colab import userdata

# Hide UserWarning noise from the libraries loaded below.
warnings.filterwarnings("ignore", category=UserWarning)

# API credentials are read from the Colab secret store; nothing is hard-coded.
NEURONPEDIA_API_KEY = userdata.get('neuronpedia-API')
# Misspelled alias kept on purpose: the request cells further down read this name.
NUERONPEDIA_API_KEY = NEURONPEDIA_API_KEY
NGROK_API_KEY = userdata.get('ngrok')
print('API key successfully retrieved')

# Prompt reused by every demonstration cell in the notebook.
TEST_PROMPT = 'The spinning top at the end of movie inception means'
print(f"Test prompt: {TEST_PROMPT}")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sae_lens import SAE
import requests

# Load the GPT-2 tokenizer and causal language model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Return per-layer hidden states on every forward pass; get_features reads them.
model.config.output_hidden_states = True
# Use the GPU when one is available. (The device string must be "cuda"; the
# previous "cude" typo raised as soon as a GPU runtime was attached.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Test the model: extend the test prompt by exactly one generated token.
# Tokenise once and reuse both tensors instead of running the tokenizer twice.
encoded_prompt = tokenizer(TEST_PROMPT, return_tensors="pt").to(device)
input_ids = encoded_prompt.input_ids
attention_mask = encoded_prompt.attention_mask

# Explicitly set the pad token id, falling back to the EOS token when the
# tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Generate text
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # Ask for one new token directly rather than computing prompt length + 1.
        max_new_tokens=1,
        # generate() does not read the tokenizer, so hand it the pad id explicitly.
        pad_token_id=tokenizer.pad_token_id,
    )
    generated_text = tokenizer.decode(outputs[0])
    print(generated_text)


In [None]:
def get_top_k_tokens(prompt_text, k=3):
    """Return the k most probable next tokens for a prompt.

    The prompt is run through the notebook-level ``model`` and the logits at
    the final sequence position are converted to probabilities with a softmax.

    Args:
        prompt_text: Text whose continuation should be predicted.
        k: Number of candidate tokens to return.

    Returns:
        A list of ``(decoded_token, probability)`` tuples in the order produced
        by ``torch.topk`` (highest probability first).
    """
    encoded_prompt = tokenizer(prompt_text, return_tensors="pt").to(device)

    # Pure inference: no gradients are needed for any of these steps.
    with torch.no_grad():
        last_position_logits = model(**encoded_prompt).logits[0, -1, :]
        next_token_probs = torch.softmax(last_position_logits, dim=-1)
        best_probs, best_ids = torch.topk(next_token_probs, k)

    return [
        (tokenizer.decode([token_id]), token_prob)
        for token_id, token_prob in zip(best_ids.tolist(), best_probs.tolist())
    ]

In [None]:
# Preview the model's most likely continuations of the test prompt.
suggestions = get_top_k_tokens(TEST_PROMPT, k=3)
print(f"Prompt: {TEST_PROMPT}")
# Derive the count from the result so the label cannot drift from the k passed above.
print(f"Top {len(suggestions)} next token suggestions (with probabilities):")
for token, prob in suggestions:
    # repr() makes leading spaces and newlines inside a token visible.
    print(f"{repr(token)} with probability {prob*100:.2f}%")

In [None]:
# Load the pretrained sparse autoencoder for the residual stream entering
# block 11, together with the config and sparsity objects returned alongside it.
sae, cfg, sparsity = SAE.from_pretrained(
    device=str(device),
    sae_id="blocks.11.hook_resid_pre",
    release="gpt2-small-res-jb",
)

# Keep the SAE on the same device as the language model and switch it to
# evaluation mode; it is only used for inference here.
sae.to(device)
sae.eval()

print("SAE successfully loaded")

In [None]:
def get_features(prompt_text, k=3):
  """Return Neuronpedia ids of the k strongest SAE features for a prompt.

  The prompt is run through the notebook-level ``model``; the hidden state at
  index 11 for the final token is encoded with the notebook-level ``sae`` and
  the k largest activations are selected.

  NOTE(review): the SAE release was presumably trained on activations from a
  different model wrapper; confirm that Hugging Face hidden states are a
  faithful input for it.

  Args:
    prompt_text: Text to analyse.
    k: Number of features to return. Values <= 0 yield an empty list and
      values above the SAE width are clamped to it.

  Returns:
    A list of ``"gpt2-small/11-res-jb/<index>"`` strings, strongest first.
  """
  # The previous slice-based selection returned *every* feature for k == 0.
  if k <= 0:
    return []

  inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
  with torch.no_grad():
    # Request hidden states explicitly so this does not depend on an earlier
    # cell having mutated model.config.
    outputs = model(**inputs, output_hidden_states=True)
    resid_pre_11 = outputs.hidden_states[11]
    # Keep a leading batch axis of 1: (1, hidden_dim).
    final_token_activation = resid_pre_11[0, -1, :].unsqueeze(0)
    sparse_activations = sae.encode(final_token_activation).squeeze(0)

  top_k = min(k, sparse_activations.numel())
  top_feat_indices = torch.topk(sparse_activations, top_k).indices.tolist()
  return [f"gpt2-small/11-res-jb/{index}" for index in top_feat_indices]

In [None]:
# Compute and print the Neuronpedia ids of the five features that get_features
# ranks highest for the test prompt. `features` is reused by the cells below.
features = get_features(TEST_PROMPT, k=5)
for feature in features:
    print(feature)

In [None]:
# Probe the Neuronpedia feature endpoint with the top-ranked feature.
# The header name must be spelled "Authorization"; the previous misspelling
# meant the key was never sent under a recognised header.
# NOTE(review): confirm the expected auth scheme against the Neuronpedia API docs.
headers = {"Authorization": f"Bearer {NUERONPEDIA_API_KEY}"}
# Use HTTPS so the API key is not transmitted in clear text.
url = f"https://www.neuronpedia.org/api/feature/{features[0]}"
# A timeout keeps the cell from hanging indefinitely on a stalled connection.
res = requests.get(url, headers=headers, timeout=30)
if res.status_code == 200:
    print(res.json())
else:
    print(f"Request failed with status code {res.status_code}")

In [None]:
def _max_activating_token(act):
  """Return the token at ``maxValueTokenIndex`` of one activation record, or None."""
  tokens = act.get("tokens")
  if not tokens:
    return None
  # A missing index falls back to the last token.
  token_index = act.get("maxValueTokenIndex", -1)
  if not isinstance(token_index, int) or not -len(tokens) <= token_index < len(tokens):
    return None
  return tokens[token_index]


def fetch_feature_info(feature_id: "str | int", model="gpt2-small", layer="l1-res-jb"):
  """Fetch interpretability metadata for one SAE feature from Neuronpedia.

  Args:
    feature_id: Either a full feature path such as
      ``"gpt2-small/11-res-jb/123"`` (the format produced by ``get_features``)
      or a bare feature index, which is combined with ``model`` and ``layer``.
    model: Neuronpedia model id; only used when ``feature_id`` is a bare index.
    layer: Neuronpedia layer/SAE id; only used when ``feature_id`` is a bare index.

  Returns:
    A dict with the keys ``index``, ``layer``, ``model``, ``explanation``,
    ``max_activation``, ``positive_strings``, ``positive_values``,
    ``negative_strings``, ``negative_values`` and ``examples``. When the
    request fails the same keys are returned holding ``None`` / empty lists,
    and a message is printed.
  """
  feature_path = str(feature_id)
  if "/" not in feature_path:
    feature_path = f"{model}/{layer}/{feature_path}"

  # Returned as-is on any failure so callers can always index the same keys.
  feature_info = {
      "index": None,
      "layer": None,
      "model": None,
      "explanation": None,
      "max_activation": None,
      "positive_strings": [],
      "positive_values": [],
      "negative_strings": [],
      "negative_values": [],
      "examples": [],
  }

  # Only send credentials when a key is actually configured.
  headers = {}
  if NUERONPEDIA_API_KEY:
    headers["Authorization"] = f"Bearer {NUERONPEDIA_API_KEY}"
  # HTTPS keeps the key out of clear text.
  url = f"https://www.neuronpedia.org/api/feature/{feature_path}"

  try:
    res = requests.get(url, headers=headers, timeout=30)
  except requests.RequestException as exc:
    print(f"Error fetching feature info for {feature_path}: {exc}")
    return feature_info
  if res.status_code != 200:
    print(f"Error fetching feature info for {feature_path}: {res.status_code}")
    return feature_info

  try:
    data = res.json()
  except ValueError:
    data = None
  if not isinstance(data, dict):
    print(f"Error fetching feature info for {feature_path}: unexpected response body")
    return feature_info

  explanations = data.get("explanations") or []
  feature_info.update({
      "index": data.get("index"),
      "layer": data.get("layer"),
      "model": data.get("modelId"),
      "explanation": explanations[0].get("description") if explanations else None,
      "max_activation": data.get("maxActApprox"),
      "positive_strings": data.get("pos_str", []),
      "positive_values": data.get("pos_values", []),
      "negative_strings": data.get("neg_str", []),
      "negative_values": data.get("neg_values", []),
      "examples": [
          {
              "tokens": act.get("tokens"),
              "max_value": act.get("maxValue"),
              "max_token": _max_activating_token(act),
          }
          for act in data.get("activations") or []
      ],
  })

  return feature_info


In [None]:
# Fetch the first feature's metadata and display only its explanation text
# (the cell's last expression is rendered as the output).
fetch_feature_info(features[0])['explanation']

In [None]:
%%writefile app.py
import streamlit as st
from model_utils import get_top_k_tokens
from sae_utils import get_features, fetch_feature_info

st.set_page_config(layout="wide", page_title="Interpretable AI Writing Assistant")

st.title("Interpretable AI Writing Assistant")

# Split layout into two columns: writing (left), analysis (right)
left, right = st.columns([2.5, 1])

with left:
    prompt = st.text_area("Write something:", height=300, placeholder="e.g. The wizard opened the ancient...")

with right:
    if prompt.strip():
        st.markdown("### Top Predictions")
        preds = get_top_k_tokens(prompt, k=3)
        for token, prob in preds:
            st.markdown(f"- **{token.strip()}** ({prob*100:.2f}%)")

        st.markdown("---")
        st.markdown("### Top Activated Features")

        feature_ids = get_features(prompt, k=3)

        for fid in feature_ids:
            info = fetch_feature_info(fid)

            if info["explanation"]:
                # Display feature ID and explanation in salmon red kinda colour
                st.markdown(
                    f"<span style='color:#FA8072'><b>{info['index']}: {info['explanation']}</b></span>",
                    unsafe_allow_html=True
                )

                # Wrap first 5 examples inside a collapsible section
                if info["examples"]:
                    with st.expander("See example(s)"):
                        for example in info["examples"][:5]:
                            st.markdown(f"`{example['tokens']}`")



In [None]:
%%writefile model_utils.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Resolve the compute device first, then load GPT-2 and place it there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True)
model.to(device)

def get_top_k_tokens(prompt_text, k=3):
    """Predict the most likely next tokens for ``prompt_text``.

    Args:
        prompt_text: Prompt to continue.
        k: Number of candidate tokens to return.

    Returns:
        A list of ``(decoded_token, probability)`` pairs in the order produced
        by ``torch.topk`` (highest probability first).
    """
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)

    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        all_logits = model(**encoded).logits  # (1, seq_len, vocab_size)

    # Scores for the *next* token sit at the last sequence position of the
    # single prompt in the batch.
    final_step_logits = all_logits[0, -1, :]
    top = torch.topk(torch.softmax(final_step_logits, dim=-1), k)

    return [
        (tokenizer.decode([token_id]), probability)
        for token_id, probability in zip(top.indices.tolist(), top.values.tolist())
    ]


In [None]:
%%writefile sae_utils.py
import torch
from sae_lens import SAE
from transformers import AutoTokenizer, AutoModelForCausalLM
import requests
import os

# Resolve the compute device once. The model, the SAE and the tokenised inputs
# (see get_features) must all live on it; previously the model and SAE were
# pinned to the CPU while inputs were moved to CUDA when available, which
# raised a device-mismatch error on GPU runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model/tokenizer for internal use
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2", output_hidden_states=True)
model.to(device)
model.eval()

# Load SAE
sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.11.hook_resid_pre",
    device=str(device)
)
sae.eval()

# Read at import time; None when the environment variable is not set.
NEURONPEDIA_API_KEY = os.environ.get("NEURONPEDIA_API_KEY")

def get_features(prompt_text, k=3):
    """
    Extract top-k SAE features from the hidden state of the prompt.

    Parameters:
        prompt_text (str): Text to analyse.
        k (int): Number of features to return. Values <= 0 yield an empty
            list and values above the SAE width are clamped to it.

    Returns:
        list[str]: Feature ids of the form "gpt2-small/11-res-jb/<index>",
        strongest activation first.
    """
    # The previous slice-based selection returned *every* feature for k == 0.
    if k <= 0:
        return []

    # Step 1: Run prompt through model to get hidden states.
    # Inputs are moved to wherever the model weights actually live rather than
    # to the module-level `device`, so the two can never disagree.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.hidden_states  # 0 = embeddings, then one entry per layer

    # Step 2: Grab hidden state at resid_pre of layer 11
    resid_pre_11 = hidden_states[11]  # shape: [1, seq_len, hidden_dim]
    final_token_activation = resid_pre_11[0, -1, :].unsqueeze(0)  # shape: [1, hidden_dim]

    # Step 3: Encode this activation through the SAE
    with torch.no_grad():
        sparse_activations = sae.encode(final_token_activation).squeeze(0)  # shape: [num_features]

    # Step 4: Extract top-k activated feature indices
    top_k = min(k, sparse_activations.numel())
    top_feat_indices = torch.topk(sparse_activations, top_k).indices.tolist()
    return [f"gpt2-small/11-res-jb/{index}" for index in top_feat_indices]

def _max_activating_token(act):
    """Return the token at ``maxValueTokenIndex`` of one activation record, or None."""
    tokens = act.get("tokens")
    if not tokens:
        return None
    # A missing index falls back to the last token.
    token_index = act.get("maxValueTokenIndex", -1)
    if not isinstance(token_index, int) or not -len(tokens) <= token_index < len(tokens):
        return None
    return tokens[token_index]


def fetch_feature_info(feature_id: "str | int", model='gpt2-small', layer='0-res-jb'):
    """
    Fetches interpretability information about a single SAE feature from Neuronpedia API.

    Parameters:
        feature_id (str | int): Either a full feature path such as
            "gpt2-small/11-res-jb/123" (the format produced by get_features)
            or a bare feature index, which is combined with `model` and `layer`.
        model (str): Neuronpedia model ID, used only for bare indices
            (default is 'gpt2-small').
        layer (str): Neuronpedia SAE layer ID, used only for bare indices
            (default is '0-res-jb').

    Returns:
        dict: Keys `index`, `layer`, `model`, `explanation`, `max_activation`,
        `positive_strings`, `positive_values`, `negative_strings`,
        `negative_values` and `examples`. When the request fails the same keys
        are returned holding None / empty lists (and a message is printed), so
        callers can index the result without checking for errors first.
    """
    feature_path = str(feature_id)
    if "/" not in feature_path:
        feature_path = f"{model}/{layer}/{feature_path}"

    # Returned as-is on any failure; same schema as a successful lookup.
    feature_info = {
        "index": None,
        "layer": None,
        "model": None,
        "explanation": None,
        "max_activation": None,
        "positive_strings": [],
        "positive_values": [],
        "negative_strings": [],
        "negative_values": [],
        "examples": [],
    }

    # Only send credentials when a key is configured (avoids "Bearer None").
    headers = {}
    if NEURONPEDIA_API_KEY:
        headers["Authorization"] = f"Bearer {NEURONPEDIA_API_KEY}"
    url = f"https://www.neuronpedia.org/api/feature/{feature_path}"

    try:
        # The timeout keeps a stalled connection from freezing the app.
        res = requests.get(url, headers=headers, timeout=30)
    except requests.RequestException as exc:
        print(f"Error fetching feature info for {feature_path}: {exc}")
        return feature_info
    if res.status_code != 200:
        print(f"Error fetching feature info for {feature_path}: {res.status_code}")
        return feature_info

    try:
        data = res.json()
    except ValueError:
        data = None
    if not isinstance(data, dict):
        print(f"Error fetching feature info for {feature_path}: unexpected response body")
        return feature_info

    explanations = data.get("explanations") or []

    # Parse remaining relevant info
    feature_info.update({
        "index": data.get("index"),
        "layer": data.get("layer"),
        "model": data.get("modelId"),
        "explanation": explanations[0].get("description") if explanations else None,
        "max_activation": data.get("maxActApprox"),
        "positive_strings": data.get("pos_str", []),
        "positive_values": data.get("pos_values", []),
        "negative_strings": data.get("neg_str", []),
        "negative_values": data.get("neg_values", []),
        "examples": [
            {
                "tokens": act.get("tokens"),
                "max_value": act.get("maxValue"),
                "max_token": _max_activating_token(act),
            }
            for act in data.get("activations") or []
        ],
    })

    return feature_info

In [76]:
# Shell command: terminate any running ngrok processes before the launch cell
# below opens a tunnel (presumably to clear a stale tunnel from an earlier run).
!killall ngrok

In [None]:
from pyngrok import conf, ngrok, process
import os

# Authenticate ngrok with the token stored in the Colab secret store.
from google.colab import userdata
ngrok.set_auth_token(userdata.get('ngrok'))

# Export the Neuronpedia key as an environment variable: sae_utils.py reads
# NEURONPEDIA_API_KEY from os.environ when it is imported by the Streamlit app.
NEURONPEDIA_API_KEY = userdata.get('neuronpedia-API')
os.environ["NEURONPEDIA_API_KEY"] = NEURONPEDIA_API_KEY  # value comes from the Colab secret fetched above

# Start tunnel, specifying the address with the port
# (8501 must match the --server.port passed to Streamlit below).
public_url = ngrok.connect(addr="8501")  # Pass port as part of addr
print(f"Streamlit app running at: {public_url}")

# Run Streamlit in the background; its output is redirected to /content/logs.txt.
!streamlit run app.py --server.port=8501 &> /content/logs.txt &