# Setup

In [1]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido # kaleido only works if you restart the runtime. Required to write figures to disk (final cell)
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

Running as a Jupyter notebook


In [2]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
from datasets import load_dataset
import os
import json

os.environ["TOKENIZERS_PARALLELISM"] = "false" # https://stackoverflow.com/q/62691279
torch.set_grad_enabled(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
!pip install circuitsvis
import circuitsvis as cv


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
pio.renderers.default='vscode'

def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

# Neuron Exploration

In [5]:
import activating_dataset, prompt_generators
from importlib import reload
reload(activating_dataset)
reload(prompt_generators)
# Load Dataset
from activating_dataset import ActivatingDataset
from prompt_generators import ExplanationPromptGen

# Load Neuron
try:
    dataset = load_dataset("NeelNanda/pile-10k", split="train")
except: # this is a hack to let me work offline
    import pickle
    with open("data/dataset.pkl", "rb") as f:
        dataset = pickle.load(f)

data = ActivatingDataset('data/neuron_20_examples_2.json', dataset)
data.remove_prompts_longer_than(100)
neurons = list(data.markers.keys())

Using custom data configuration NeelNanda--pile-10k-72f566e9f7c464ab
Found cached dataset parquet (/Users/clementneo/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [6]:
import json

# Load head_attribution_dict
with open("data/head_attribution_dict.json", "r") as f:
    head_attribution_dict = json.load(f)
    head_attribution_dict = {tuple([int(x) for x in k[1:-1].split(", ")]):v for k,v in head_attribution_dict.items()}

# Load neuron finder results
with open("data/neuron_finder_results.json", "r") as f:
    neuron_finder_results = json.load(f)

neuron_to_token = {}

for layer in neuron_finder_results.keys():
    for neuron_ind in neuron_finder_results[layer].keys():
        neuron_to_token[(int(layer), int(neuron_ind))] = neuron_finder_results[layer][neuron_ind][0]

print(neuron_to_token) # a dictionary of (layer, index) -> " token"
print(head_attribution_dict) # a dictionary of (layer, index) -> {'prompt_1':[head_1, head_2, head_3], 'prompt_2':[...]}

{(31, 3621): ' only', (31, 364): ' number', (31, 2918): ' go', (31, 4378): ' together', (31, 988): ' called', (31, 2658): ' first', (31, 2692): ' used', (31, 4941): ' within', (31, 2415): ' way', (31, 1407): ' out', (32, 4964): ' too', (32, 2412): ' will', (32, 4282): ' right', (32, 3151): ' over', (32, 1155): ' out', (32, 1386): ' once', (32, 3582): ' her', (32, 4882): ' class', (32, 3477): ' use', (32, 406): ' much', (33, 1202): ' so', (33, 524): ' state', (33, 1582): ' RandomRedditor', (33, 4446): ' about', (33, 204): ' following', (33, 4900): ' of', (33, 2322): ' after', (33, 3278): ' around', (33, 1299): ' last', (33, 52): ' by', (34, 4012): ' off', (34, 4262): ' down', (34, 320): ' back', (34, 5095): ' well', (34, 2599): ' there', (34, 2442): ' up', (34, 4494): ' no', (34, 4199): ' after', (34, 727): ' under', (34, 4410): ' as', (35, 4518): ' issue', (35, 48): ' close', (35, 5014): ' won', (35, 3724): ' high', (35, 3360): ' very', (35, 885): ' His', (35, 4924): ' of', (35, 274): 

In [8]:
neurons = list(head_attribution_dict.keys())

gpt_4_prompts_dict = {}
nh_to_pos_neg_examples = {}

for neuron in neurons:
# if True: 
#     neuron = neurons[3]
    token = neuron_to_token[neuron]
    trunc_prompts = list(head_attribution_dict[neuron].keys())
    trunc_prompts = [trunc_prompt + token for trunc_prompt in trunc_prompts] # Add back the token -- experimential
    top_heads = list(head_attribution_dict[neuron].values())

    num_prompts = len(top_heads)

    from collections import Counter
    flattened_heads = [x for y in top_heads for x in y]
    head_count = Counter(flattened_heads)

    PERCENT_TO_KEEP = 0.25
    relevant_heads = [x for x in head_count if head_count[x] > PERCENT_TO_KEEP * num_prompts]

    for head in relevant_heads:

        positive_examples = []
        negative_examples = []
        for i, example in enumerate(trunc_prompts):
            if head in top_heads[i]:
                positive_examples.append(trunc_prompts[i])
            else:
                negative_examples.append(trunc_prompts[i])

        # prompt_gen = ExplanationPrompt(neuron, trunc_prompts, top_heads, neuron_to_token, shots_dict)
        prompt_gen = ExplanationPromptGen(token, positive_examples, negative_examples)
        prompt = prompt_gen.get_prompt(shots=0)
        gpt_4_prompts_dict[(*neuron, head)] = prompt
        nh_to_pos_neg_examples[(*neuron, head)] = (positive_examples, negative_examples)

save_files = False

In [9]:
import pickle
if save_files:
    with open("data/categorised_prompts_1.pkl", "wb") as f:
        pickle.dump(nh_to_pos_neg_examples, f)
    save_files = False

In [109]:
print(gpt_4_prompts_dict[(31, 3621, 468)])

We are studying attention heads in a transformer architecture neural network. Each attention head looks for some particular thing in a short document.
This attention head in particular helps to predict that the last token is " only", but it is only active in some documents and not others.
Look at the documents and explain what makes the attention head active, taking into consideration the inactive examples.

Examples where the attention head is active: """
*ire can only work if you are follwoing a proper diet according to your body needs. That is not easy thing for most of the people. To stay healthy, However, It requires you to work hard and stay focused on the end goal. One of the best way to st... moreExcersire can only
* becuase we are the only species that kills for revenge, sport and greed, as well as fashion. I'm not condeming what you said either, I'm just pointing out different sides.  nyan nyan percent  But do you wear leather? Eat meat? Use chemicals to clean? Animals were a

In [159]:
i = 0

In [178]:
keys = list(gpt_4_prompts_dict.keys())
print(keys[i])
prompt_i = gpt_4_prompts_dict[keys[i]]
print(prompt_i)
i += 1

(31, 4378, 88)
We are studying attention heads in a transformer architecture neural network. Each attention head looks for some particular thing in a short document.
This attention head in particular helps to predict that the last token is " together", but it is only active in some documents and not others.
Look at the documents and explain what makes the attention head active, taking into consideration the inactive examples.

Examples where the attention head is active: """
* in turn leads to the increase of Rho activity. Taken together
*in expression. Taken together
* health settings. Taken together
* synthesized ester conjugates. Taken together
* limits of the nuclear pores. Taken together
"""
Examples where the attention head is inactive: """
* used together. That is the reason why both statements which are incrementing c1 and c2 have to be synchronized. But why do they say in the next sentence that there is no reson to prevent an update of c1 from being interleaved with an update 

In [146]:
# Dump the dictionary
gpt_4_prompts_dict_str_key = {str(k):v for k,v in gpt_4_prompts_dict.items()}
# Convert the key, a tuple to a string like (1, 2) -> "(1, 2)"

# Save results
with open("data/head_explanation_1_prompts.json", "w") as f:
    json.dump(gpt_4_prompts_dict_str_key, f)

In [114]:
import json
import os
filename = "data/head_explanation_1.jsonl"

if os.path.isfile(filename):
    raise Exception("File already exists!")

jobs = [
            {"model":"gpt-4",
            "messages":[{"role": "user", "content": gpt_4_prompt}],
            "max_tokens":200, 
        } for gpt_4_prompt in gpt_4_prompts_dict.values()]

with open(filename, "w") as f:
    for job in jobs:
        json_string = json.dumps(job)
        f.write(json_string + "\n")

In [147]:
import time
now = int(time.time())
print(now)

#########################################################################################
# WHEN YOU UPDATE THE CELL, REMEMBER TO UPDATE THE JSONL FILE NAME IF YOU'VE CHANGED IT #
#########################################################################################
current_time = 1686792857

if current_time + 20 < now: # Sanity check to make sure you don't spam this cell
    raise Exception("Update the current_time variable to be able to run this cell! Copy and paste the number above.")
else:
    print("all gucci")
    !python3 api_request_parallel_processor.py --requests_filepath data/head_explanation_1.jsonl --request_url https://api.openai.com/v1/chat/completions --max_requests_per_minute 100 --max_tokens_per_minute 20000
    # ^ This is very scary because the stdout looks like it sends repeated requests for the same thing so just run it in terminal


1686793228


Exception: Update the current_time variable to be able to run this cell! Copy and paste the number above.

In [150]:
prompts_to_neuron_head = {v:k for k, v in gpt_4_prompts_dict.items()}

In [157]:
# Load the results and make the {neuron_head: response} dictionary

with open('data/head_explanation_1_results.jsonl', 'r') as json_file:
    raw_explanations = list(json_file)

prompt_to_response = {}
neuron_head_to_response = {}
neuron_head_to_prompt = gpt_4_prompts_dict
neuron_heads = list(neuron_head_to_prompt.keys())

for i, json_str in enumerate(raw_explanations):
    result = json.loads(json_str)
    prompt = result[0]["messages"][0]["content"]
    response = result[1]["choices"][0]["message"]["content"]
    prompt_to_response[prompt] = response

for neuron_head in neuron_heads:
    prompt = neuron_head_to_prompt[neuron_head]
    response = prompt_to_response[prompt]
    neuron_head_to_response[neuron_head] = response

with open("data/head_explanation_1_nh_to_exp.json", "w") as f:
    nh_to_response_str_key = {str(k):v for k,v in neuron_head_to_response.items()}
    json.dump(nh_to_response_str_key, f)