In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
import os

# Placeholder credential. `setdefault` only fills the variable in when it is
# missing, so a real key already exported in the shell/kernel environment is
# not clobbered by this placeholder. Do not commit a real key here.
# NOTE(review): this is set before the neuron_explainer imports, presumably
# because the API client reads it early -- keep this ordering unless confirmed.
os.environ.setdefault("BAIDU_API_KEY", "BAIDUTOKEN")

import logging

# Route DEBUG-and-above records from all loggers to debug.log; filemode='w'
# truncates the file on every run, so it only holds the latest session.
logging.basicConfig(
    level=logging.DEBUG,
    filename='debug.log',
    filemode='w',
    format='%(name)s - %(levelname)s - %(message)s',
)

from neuron_explainer.activations.activation_records import calculate_max_activation
from neuron_explainer.activations.activations import ActivationRecordSliceParams, load_neuron
from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator
from neuron_explainer.explanations.explainer import TokenActivationPairExplainer
from neuron_explainer.explanations.prompt_builder import PromptFormat
from neuron_explainer.explanations.scoring import simulate_and_score
from neuron_explainer.explanations.simulator import LogprobFreeExplanationTokenSimulator

# Model name handed to TokenActivationPairExplainer (explanation generation).
EXPLAINER_MODEL_NAME = "gpt-3.5-turbo"
# Model name handed to LogprobFreeExplanationTokenSimulator (activation simulation).
SIMULATOR_MODEL_NAME = "gpt-3.5-turbo"


# Disabled request smoke test, kept for manual debugging.
# NOTE(review): `client` is not defined anywhere in this cell; it has to be
# created before these two lines can be re-enabled.
# test_response = await client.make_request(prompt="test 123<|endofprompt|>", max_tokens=2)
# print("Response:", test_response["choices"][0]["text"])

# Pick the neuron to study and load its record. The two values are passed
# positionally to load_neuron.
# NOTE(review): upstream neuron_explainer orders them (layer index, neuron index) -- confirm for this fork.
LAYER_INDEX = 9
NEURON_INDEX = 6236
neuron_record = load_neuron(LAYER_INDEX, NEURON_INDEX)

# A single example per split keeps the number of model requests small.
slice_params = ActivationRecordSliceParams(n_examples_per_split=1)

# Training records feed the explainer; validation records are used for scoring.
train_activation_records = neuron_record.train_activation_records(activation_record_slice_params=slice_params)
print(f"sample: {train_activation_records}")
valid_activation_records = neuron_record.valid_activation_records(activation_record_slice_params=slice_params)

# Generate an explanation for the neuron.
explainer = TokenActivationPairExplainer(
    model_name=EXPLAINER_MODEL_NAME,
    prompt_format=PromptFormat.HARMONY_V4,
    max_concurrent=1,
)
explanation = await explainer.generate_explanations(
    all_activation_records=train_activation_records,
    max_activation=calculate_max_activation(train_activation_records),
    num_samples=1,
)
print(f"{explanation=}")

# Simulate and score the explanation.
simulator = UncalibratedNeuronSimulator(
    LogprobFreeExplanationTokenSimulator(
        SIMULATOR_MODEL_NAME,
        explanation,
        # max_concurrent=1,
        # prompt_format=PromptFormat.INSTRUCTION_FOLLOWING,
    )
)
scored_simulation = await simulate_and_score(simulator, valid_activation_records)
print(f"score={scored_simulation.get_preferred_score():.2f}")


sample: [ActivationRecord(dataclass_name='ActivationRecord', tokens=[' God', ' give', ' two', ' men', ' or', ' two', ' women', ' a', ' "', 'right', '"', ' to', ' marry', ' one', ' another', ' and', ' then', ' adopt', ' children', ' with', ' the', ' approval', ' of', ' the', ' state', '?', ' If', ' two', ' people', ' of', ' the', ' same', ' sex', ' do', ' have', ' a', ' right', ' to', ' marry', ' and', ' take', ' custody', ' of', ' children', ',', ' then', ',', ' as', ' this', ' column', ' argued', ' last', ' week', ',', ' children', ' cannot', ' be', ' deemed', ' to', ' have', ' a', ' right', ' to', ' a'], activations=[-0.0964, -0.0085, -0.0848, -0.0367, -0.0457, 4.0117, -0.1632, -0.1575, -0.041, -0.0516, -0.0504, -0.1318, -0.0297, -0.1462, -0.124, -0.1594, -0.0454, -0.0203, -0.0343, -0.069, -0.0657, -0.0082, -0.0753, -0.1268, -0.0182, -0.0632, -0.0465, -0.1132, -0.0455, -0.0956, -0.0772, -0.0711, -0.0262, -0.0328, -0.0709, -0.1166, -0.1128, -0.1631, -0.0934, -0.1503, -0.0115, -0.0307,