## Loading

In [95]:
from datasets import load_dataset

ds = load_dataset("openlifescienceai/medmcqa", split="train", streaming=False)

rows_to_keep = 10_000

# Only keep the first 10,000 rows
ds = ds.select(range(rows_to_keep))

def format_question_text(example):
    """
    Transforms a dataset example into a formatted text string.
    Args:
        example: Dictionary containing the question data with keys:
                'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'exp'
    Returns:
        Dict with new 'text' key containing formatted string
    """
    # Option keys in order
    option_keys = ['opa', 'opb', 'opc', 'opd']

    # Build the formatted string components
    question = f"{example['question']}"
    # Strip ':' from the end of the question
    question = question.rstrip(':')
    # Add a period to the end of the question
    question = question + '?'
    options = "\nThe options are:\n" + "\n".join(example[key] for key in option_keys)

    # Get correct option using the cop index
    correct_idx = int(example['cop'])
    correct_option = f"\nCorrect option: {example[option_keys[correct_idx]]}"

    # Add explanation if available
    explanation = f"\nExplanation: {example['exp']}" if 'exp' in example else ""
    # Strip anything after and including 'Ref'
    explanation = explanation.split('Ref')[0]

    # Combine all components
    formatted_text = f"{question}{options}{correct_option}{explanation}"

    # Return dictionary with new text field
    example['text'] = formatted_text
    return example

# Function to transform the entire dataset
def transform_dataset(dataset):
    """
    Applies the formatting transformation to the entire dataset.
    Args:
        dataset: Huggingface dataset object
    Returns:
        Transformed dataset with new 'text' column
    """
    return dataset.map(
        format_question_text,
        desc="Formatting questions into text",
        num_proc=4  # Adjust based on your system
    )

transformed_ds = transform_dataset(ds)
transformed_ds.to_pandas()
# Drop all columns except text
transformed_ds = transformed_ds.remove_columns([col for col in transformed_ds.column_names if col != 'text'])
transformed_ds

Formatting questions into text (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 10000
})

In [1]:
import torch
import torch.nn as nn
import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi

# Define the SAE model
class JumpReLUSAE(nn.Module):
    def __init__(self, d_model, d_sae):
        super().__init__()
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

        # Dimensions
        self.d_model = d_model
        self.d_sae = d_sae

    def encode(self, input_acts):
        pre_acts = input_acts @ self.W_enc + self.b_enc
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        acts = self.encode(acts)
        recon = self.decode(acts)
        return recon

width='16k'
l0 = 71
layer = 20

# Load the SAE model
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename=f"layer_{layer}/width_{width}/average_l0_{l0}/params.npz",
    force_download=False,
)

params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cpu() for k, v in params.items()}

# Initialize and load the SAE model
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)
sae = sae.cpu()

# Load your data from Hugging Face
repo_id = "charlieoneill/gemma-medicine-sae"  # Replace with your repo

# Download the activation tensor and dataset
api = HfApi()
activation_file = hf_hub_download(repo_id=repo_id, filename="10000_128.pt")

# Load the tensors
activations = torch.load(activation_file)

  activations = torch.load(activation_file)


In [2]:
activations.shape

torch.Size([10000, 128, 2304])

In [3]:
sae.d_sae

16384

In [102]:
# Move to GPU
activations = activations.cpu()

# Process a batch of 32
# batch_size = 1
# batch_acts = activations[:batch_size]
batch_acts = activations[4].unsqueeze(0)

# Run through SAE
with torch.no_grad():
    recon = sae(batch_acts)

# Calculate variance explained
variance_explained = 1 - torch.mean((recon[:, 1:] - batch_acts[:, 1:].to(torch.float32)) **2) / (batch_acts[:, 1:].to(torch.float32).var())

# Calculate L0 sparsity
with torch.no_grad():
    encoded = sae.encode(batch_acts)
    l0_sparsity = (encoded > 0).float().mean()

print(f"Variance explained: {variance_explained.item():.4f}")
print(f"L0 sparsity: {l0_sparsity.item():.4f}")
    

Variance explained: 0.8750
L0 sparsity: 0.0043


In [103]:
activations.shape

torch.Size([10000, 128, 2304])

In [110]:
target_act = activations[4].unsqueeze(0)
print(target_act.shape)

sae_acts = sae.encode(target_act.to(torch.float32))
recon = sae.decode(sae_acts)

print(sae_acts.shape, recon.shape)

torch.Size([1, 128, 2304])
torch.Size([1, 128, 65536]) torch.Size([1, 128, 2304])


In [111]:
# Print MSE loss between target_act and recon
loss = torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2)
print(loss)

tensor(6.0132, grad_fn=<MeanBackward0>)


In [112]:
1 - torch.mean((recon[:, 1:] - target_act[:, 1:].to(torch.float32)) **2) / (target_act[:, 1:].to(torch.float32).var())

tensor(0.8750, grad_fn=<RsubBackward1>)

In [113]:
(sae_acts > 1).sum(-1)

tensor([[27429,    50,    66,    61,    66,    77,    74,    73,    74,    44,
            49,    50,    62,    56,    64,    45,    68,    77,    63,    59,
            46,    52,    52,    57,    62,    75,    60,    59,    56,    72,
            70,    67,    48,    82,    83,    49,    59,    53,    63,    54,
            52,    90,    63,    46,    51,    60,    38,    52,    82,    39,
            39,    46,    60,    67,    62,    49,    52,    68,    75,    63,
            66,    68,    69,    38,    65,    82,    71,    69,    84,    58,
            60,    77,    79,    64,    52,    65,    79,    63,    77,    56,
            78,    66,    76,    81,    68,    73,    80,    54,    56,    88,
            58,    70,    67,    59,    68,    67,    70,    70,    75,    52,
            81,    69,    88,    73,    63,    79,    62,    36,    76,    76,
            68,    75,    76,    70,    77,    91,    89,    92,    50,    77,
            73,    95,    63,    77,    59,    56,  

In [114]:
# values, inds = sae_acts.max(-1)

# inds, inds.shape

# First flatten the sequence dimension, but exclude first position
flat_acts = sae_acts[:, 1:, :].reshape(sae_acts.shape[0], -1)  # Note the 1: slice

# Get top 10 values and indices
top_values, top_indices = torch.topk(flat_acts, k=10, dim=-1)

# Convert flat indices back to (seq_pos, feature) pairs
# Add 1 to seq_pos since we excluded the first position
seq_pos = (top_indices // sae_acts.shape[-1]) + 1  # add 1 to account for skipped first position
feature_ids = top_indices % sae_acts.shape[-1]

# Print results
for i in range(10):
    print(f"Position {seq_pos[0][i]}, Feature {feature_ids[0][i]}: Activation {top_values[0][i]:.2f}")

Position 75, Feature 19451: Activation 160.45
Position 28, Feature 38187: Activation 151.08
Position 1, Feature 19802: Activation 129.82
Position 3, Feature 62431: Activation 128.68
Position 49, Feature 48639: Activation 126.09
Position 16, Feature 52421: Activation 125.92
Position 127, Feature 44592: Activation 122.17
Position 109, Feature 31548: Activation 118.08
Position 95, Feature 53551: Activation 117.55
Position 104, Feature 44466: Activation 116.45


In [115]:
print(transformed_ds[4]['text'])


Growth hormone has its effect on growth through??
The options are:
Directly
IG1-1
Thyroxine
Intranuclear receptors
Correct option: IG1-1
Explanation: Ans. is 'b' i.e., IGI-1GH has two major functions :-i) Growth of skeletal system :- The growth is mediated by somatomedins (IGF). Increased deposition of cailage (including chondroitin sulfate) and bone with increased proliferation of chondrocytes and osteocytes.ii) Metabolic effects :- Most of the metabolic effects are due to direct action of GH. These include gluconeogenesis, decreased peripheral utilization of glucose (decreased uptake), lipolysis and anabolic effect on proteins.


In [116]:
from IPython.display import IFrame
html_template = "https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id=f"{layer}-gemmascope-res-{width}", feature_idx=0):
    return html_template.format(sae_release, sae_id, feature_idx)

html = get_dashboard_html(sae_release = "gemma-2-2b", sae_id=f"{layer}-gemmascope-res-{width}", feature_idx=19451)
IFrame(html, width=1200, height=600)

In [118]:
import requests

#url = "https://www.neuronpedia.org/api/explanation/export?modelId=gpt2-small&saeId=7-res-jb"
url = f"https://www.neuronpedia.org/api/explanation/export?modelId=gemma-2-2b&saeId={layer}-gemmascope-res-{width}"
headers = {"Content-Type": "application/json"}

response = requests.get(url, headers=headers)

In [119]:
import pandas as pd

# convert to pandas
data = response.json()
explanations_df = pd.DataFrame(data)
# rename index to "feature"
explanations_df.rename(columns={"index": "feature"}, inplace=True)
# explanations_df["feature"] = explanations_df["feature"].astype(int)
explanations_df["description"] = explanations_df["description"].apply(
    lambda x: x.lower()
)
explanations_df

Unnamed: 0,modelId,layer,feature,description,explanationModelName,typeName
0,gemma-2-2b,20-gemmascope-res-65k,1863,sentence beginnings,gpt-4o-mini,oai_token-act-pair
1,gemma-2-2b,20-gemmascope-res-65k,2109,"the conjunction ""and""",gpt-4o-mini,oai_token-act-pair
2,gemma-2-2b,20-gemmascope-res-65k,2634,"the name ""ja rule.""",gpt-4o-mini,oai_token-act-pair
3,gemma-2-2b,20-gemmascope-res-65k,3339,references to pakistan,gpt-4o-mini,oai_token-act-pair
4,gemma-2-2b,20-gemmascope-res-65k,3408,"mentions of the name ""nick.""",gpt-4o-mini,oai_token-act-pair
...,...,...,...,...,...,...
74649,gemma-2-2b,20-gemmascope-res-65k,52527,references to obstacles or challenges,gpt-4o-mini,oai_token-act-pair
74650,gemma-2-2b,20-gemmascope-res-65k,40585,references to configuration options,gpt-4o-mini,oai_token-act-pair
74651,gemma-2-2b,20-gemmascope-res-65k,53745,"the term ""broad"" in various contexts",gpt-4o-mini,oai_token-act-pair
74652,gemma-2-2b,20-gemmascope-res-65k,53808,java import statements in code,gpt-4o-mini,oai_token-act-pair


In [121]:
values, feature_ids = sae_acts.max(-1)

In [122]:
activating_features = list(set(feature_ids[0].cpu().numpy()))

# Get the explanations for these features
explanations_df.loc[activating_features]

# Print the feature and explanation, one by one
for feature in activating_features:
    print(f"Feature {feature}:")
    print(explanations_df.loc[feature]["description"])
    print("\n")

Feature 16513:
terms related to rankings or titles within a police or law enforcement context


Feature 13699:
 references to medical terminology or conditions related to the brain or neurological structures


Feature 19844:
elements related to providing answers or explanations


Feature 28167:
 page references and letter indications in a document


Feature 43399:
references to development or related code elements


Feature 16521:
 financial terms related to bond markets and yields


Feature 43786:
references to leadership titles and positions within organizations


Feature 40458:
 expressions of necessity and responsibility regarding educational policies and practices


Feature 62226:
 punctuation marks and separate ideas or phrases


Feature 55701:
 technical terms related to programming and data structures


Feature 52632:
 elements related to user interface features and functionalities


Feature 6296:
 instances of marking or signaling changes in sections of text


Feature 12696:
r

In [None]:
# Let's get the dashboard for this feature.
html = get_dashboard_html(
    sae_release="gpt2-small",
    sae_id="7-res-jb",
    feature_idx=bible_features.feature.values[0],
)
IFrame(html, width=1200, height=600)