In [None]:
# Imports

import os
import torch as t
from nnsight import LanguageModel
import datasets
import anthropic
from tqdm import tqdm
import re
import ast
import pickle
from collections import defaultdict
from circuitsvis.activations import text_neuron_activations
from transformers import AutoTokenizer
import random
import json

import experiments.utils as utils
from experiments.autointerp import (
    get_max_activating_prompts,
    highlight_top_activations,
    compute_dla,
    format_examples,
    get_autointerp_inputs_for_all_saes,
)
import experiments.llm_autointerp.llm_utils as llm_utils

# Master debug switch for this notebook.
DEBUGGING = True

# Tracer options follow the debug switch: `scan` and `validate` are both
# enabled while debugging and both disabled otherwise.
tracer_kwargs = {"scan": bool(DEBUGGING), "validate": bool(DEBUGGING)}

# IPython autoreload in mode 2: re-import edited modules before each cell runs,
# so changes to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Read the Anthropic API key from a text file one directory up and publish it
# through the ANTHROPIC_API_KEY environment variable.
with open("../anthropic_api_key.txt") as key_file:
    api_key = key_file.read().strip()

os.environ["ANTHROPIC_API_KEY"] = api_key

In [None]:
# Create the Anthropic client and send one small request as a smoke test.
# NOTE(review): presumably the client picks up ANTHROPIC_API_KEY from the
# environment set in the previous cell — confirm against the SDK docs.
client = anthropic.Anthropic()

message = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=50,
    temperature=0,
    system="You are a world-class poet. Respond only with short poems.",
    messages=[
        {"role": "user", "content": [{"type": "text", "text": "Why is the ocean salty?"}]},
    ],
)
print(message.content)

In [None]:
# Profession names in class-id order: the position in this list is the integer
# id stored in `profession_dict`.
_profession_names = [
    "accountant",
    "architect",
    "attorney",
    "chiropractor",
    "comedian",
    "composer",
    "dentist",
    "dietitian",
    "dj",
    "filmmaker",
    "interior_designer",
    "journalist",
    "model",
    "nurse",
    "painter",
    "paralegal",
    "pastor",
    "personal_trainer",
    "photographer",
    "physician",
    "poet",
    "professor",
    "psychologist",
    "rapper",
    "software_engineer",
    "surgeon",
    "teacher",
    "yoga_teacher",
]

profession_dict = {name: class_id for class_id, name in enumerate(_profession_names)}

# The paired classes are not numbered: each one maps to its own label string.
profession_dict.update(
    {
        paired_label: paired_label
        for paired_label in (
            "male / female",
            "professor / nurse",
            "male_professor / female_nurse",
            "biased_male / biased_female",
        )
    }
)

# Concepts the LLM is asked to score; passed to the system-prompt builder and
# to the JSON verification step.
chosen_class_names = ["gender", "professor", "nurse", "accountant"]
chosen_class_names += ["architect", "attorney", "dentist", "filmmaker"]

# Directory prefix used when opening the manual-label JSON files.
PROMPT_DIR = "llm_autointerp"

In [None]:
# Bounds of the scoring scale; both are passed to the system-prompt builder
# and to the JSON verification step.
min_scale, max_scale = 0, 4

In [None]:
# Load the hand-labelled few-shot features and render them into the text block
# that is prepended to every test prompt.

# os.path.join supplies the path separator that the bare f"{PROMPT_DIR}..."
# concatenation omitted (PROMPT_DIR has no trailing slash, so the old path
# resolved to "llm_autointerpmanual_labels_few_shot.json").
few_shot_path = os.path.join(PROMPT_DIR, "manual_labels_few_shot.json")
with open(few_shot_path, "r", encoding="utf-8") as f:
    few_shot_manual_labels = json.load(f)

# Quick overview: one line per labelled feature with its scores.
for label in few_shot_manual_labels:
    print(label, few_shot_manual_labels[label]["per_class_scores"])

few_shot_examples = "Here's a few examples of how to perform the task:\n\n"

for i, selected_index in enumerate(few_shot_manual_labels):
    example_prompts = few_shot_manual_labels[selected_index]["example_prompts"]
    tokens_string = few_shot_manual_labels[selected_index]["tokens_string"]
    per_class_scores = few_shot_manual_labels[selected_index]["per_class_scores"]
    chain_of_thought = few_shot_manual_labels[selected_index]["chain_of_thought"]

    # Keep only the text before the "Example 4:" marker of the first prompt
    # string, so each few-shot feature shows a shortened set of examples.
    example_prompts = example_prompts[0].split("Example 4:")[0]

    few_shot_examples += f"\n\n<<BEGIN EXAMPLE FEATURE {i}>>\n"
    few_shot_examples += f"Promoted tokens: {tokens_string}\n"
    few_shot_examples += f"Example prompts: {example_prompts}\n"
    few_shot_examples += f"Chain of thought: {chain_of_thought}\n\n"
    few_shot_examples += "```json\n"
    # NOTE(review): if per_class_scores is a dict, this renders its Python repr
    # (single-quoted keys) inside a ```json fence — confirm whether
    # json.dumps(per_class_scores) is wanted here.
    few_shot_examples += f"{per_class_scores}\n"
    few_shot_examples += "```"
    few_shot_examples += f"\n<<END EXAMPLE FEATURE {i}>>\n\n"

print(few_shot_examples)

In [None]:
# Size of the few-shot block: character count first, then the token count
# reported by llm_utils.count_tokens.
print(len(few_shot_examples))
print(llm_utils.count_tokens(few_shot_examples))

In [None]:
# Build the system prompt for the scoring task.
# The helper is called through `llm_utils` because no bare `build_system_prompt`
# name is imported or defined in the visible cells, so the unqualified call
# raised NameError on a fresh kernel.
# NOTE(review): assumes llm_utils exposes build_system_prompt — confirm against the module.
system_prompt = llm_utils.build_system_prompt(
    concepts=chosen_class_names, min_scale=min_scale, max_scale=max_scale
)

# print(llm_utils.count_tokens(system_prompt))
print(system_prompt[0]['text'])

In [None]:
# Print tensors with plain decimals rather than scientific notation.
t.set_printoptions(sci_mode=False)

# Settings for the test run and for how examples are displayed.
current_idx = 0
number_of_test_examples = 10  # upper bound used by the evaluation loop

displayed_prompts = 10
num_top_emphasized_tokens = 5
include_activations = True

In [None]:
# Load the manually labelled test features.
# os.path.join supplies the path separator that the bare f"{PROMPT_DIR}..."
# concatenation omitted (PROMPT_DIR has no trailing slash).
test_labels_path = os.path.join(PROMPT_DIR, "manual_labels_adam_corr.json")
with open(test_labels_path, "r", encoding="utf-8") as f:
    manual_test_labels = json.load(f)



In [None]:

# Turn every manually labelled test feature into a tuple of
# (prompt text, class index, per-class scores, chain of thought).
test_prompts = []

for example_feature, entry in manual_test_labels.items():
    example_prompts = entry["example_prompts"]
    tokens_string = entry["tokens_string"]
    per_class_scores = entry["per_class_scores"]
    chain_of_thought = entry["chain_of_thought"]
    class_index = entry["class_index"]

    # Only the first element of example_prompts goes into the prompt; the text
    # ends at "Chain of thought:" so the model continues from there.
    llm_prompt = (
        "Okay, now here's the real task.\n"
        f"Promoted tokens: {tokens_string}\n"
        f"Example prompts: {example_prompts[0]}\n"
        "Chain of thought:"
    )

    test_prompts.append((llm_prompt, class_index, per_class_scores, chain_of_thought))

In [None]:
# Inspect the first test tuple: (prompt text, class index, per-class scores, chain of thought).
print(test_prompts[0])

In [None]:
# Preview the full user message for one test feature: the few-shot block
# followed directly by that feature's prompt text.
test_idx = 1

test_prompt = f"{few_shot_examples}{test_prompts[test_idx][0]}"
print(test_prompt)

In [None]:
# Show the manual labels for the previewed feature: the class index on one
# line, then the per-class scores on the next.
print(*test_prompts[test_idx][1:3], sep="\n")

In [None]:
# Original for-loop implementation for testing.
#
# For each test feature, send the few-shot block plus the feature's prompt to
# the model, then pass the reply through llm_utils' JSON extraction and
# verification helpers. Each `results` entry is
# (raw response text, extract_and_validate_json output, verification flag, verification message).

results = []

# Never index past the end of test_prompts: the label file may contain fewer
# features than number_of_test_examples, which previously raised IndexError.
num_examples_to_run = min(number_of_test_examples, len(test_prompts))

for i in range(num_examples_to_run):
    test_prompt = few_shot_examples + test_prompts[i][0]

    message = client.messages.create(
        # model="claude-3-5-sonnet-20240620",
        model="claude-3-haiku-20240307",
        max_tokens=500,
        temperature=0,
        system=system_prompt[0]['text'],
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": test_prompt
                    }
                ]
            }
        ]
    )
    # Only the first content block of the reply is used.
    llm_response = message.content[0].text

    json_response = llm_utils.extract_and_validate_json(llm_response)
    good_json, verification_message = llm_utils.verify_json_response(
        json_response, min_scale, max_scale, chosen_class_names
    )
    results.append((llm_response, json_response, good_json, verification_message))
    print(i, good_json, verification_message)

In [None]:
# Show the raw text of the most recent API reply. `message` is whatever was
# last assigned in an earlier cell (the final loop iteration, if the loop ran).
llm_response = message.content[0].text

print(llm_response)

In [None]:
# Re-run extraction and verification on the response printed above and show
# every intermediate value, one per line.
json_response = llm_utils.extract_and_validate_json(llm_response)
good_json, verification_message = llm_utils.verify_json_response(
    json_response, min_scale, max_scale, chosen_class_names
)
print(
    json_response,
    chosen_class_names,
    f"Good json: {good_json}",
    verification_message,
    sep="\n",
)