# Setup

In [1]:
import plotly.io as pio
try:
    import google.colab
    print("Running as a Colab notebook")
    pio.renderers.default = "colab"
    %pip install transformer-lens fancy-einsum
    %pip install -U kaleido # kaleido only works if you restart the runtime. Required to write figures to disk (final cell)
except:
    print("Running as a Jupyter notebook")
    pio.renderers.default = "vscode"
    from IPython import get_ipython
    ipython = get_ipython()

Running as a Jupyter notebook


In [2]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import einops
from typing import List, Union, Optional
from functools import partial
import pandas as pd
from pathlib import Path
import urllib.request
from bs4 import BeautifulSoup
from tqdm import tqdm
from datasets import load_dataset
import os
import json

# Tokenizer parallelism is switched off; background: https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Gradients are disabled globally for everything that follows.
torch.set_grad_enabled(False)
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [3]:
!pip install circuitsvis
import circuitsvis as cv


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [4]:
# The setup cell selects "colab" when running on Colab; do not overwrite that choice.
# Everywhere else the renderer is (re)set to "vscode", exactly as before.
if pio.renderers.default != "colab":
    pio.renderers.default = "vscode"

def imshow(tensor, renderer=None, **kwargs):
    """Show `tensor` as a heatmap on the "RdBu" colour scale with its midpoint at 0.

    `renderer` is passed to the figure's `show`; any extra keyword arguments are
    forwarded to `px.imshow`.
    """
    array = utils.to_numpy(tensor)
    figure = px.imshow(
        array,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show `tensor` as a line plot (values on the y axis).

    `renderer` is passed to the figure's `show`; any extra keyword arguments are
    forwarded to `px.line`.
    """
    values = utils.to_numpy(tensor)
    figure = px.line(y=values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` are the display labels for the x, y and colour
    dimensions. `renderer` is passed to the figure's `show`; any extra keyword
    arguments are forwarded to `px.scatter`.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [5]:
# Keyword options handed to `HookedTransformer.from_pretrained` when loading gpt2-large.
load_options = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device,
)
model = HookedTransformer.from_pretrained("gpt2-large", **load_options)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


## Find Activating Examples

In [11]:
from datasets import load_dataset

# Prefer loading the dataset via `load_dataset`; if that fails (e.g. when working
# offline) fall back to a local pickle of the same dataset.
try:
    dataset = load_dataset("NeelNanda/pile-10k", split="train")
except Exception:  # not a bare `except`, so Ctrl-C / kernel interrupts still propagate
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load a dataset.pkl you created yourself.
    with open("dataset.pkl", "rb") as f:
        dataset = pickle.load(f)

# To (re)create the offline copy, run this once after the dataset has loaded:
# import pickle
# with open("dataset.pkl", "wb") as f:
#     pickle.dump(dataset, f)

In [8]:
# Read the saved per-neuron activation records back from disk.
with open("neuron_max_acts.json", "r") as f:
    neuron_max_acts_load = json.loads(f.read())

In [35]:
neuron = (35, 3724)
num_samples = 20  # how many of the highest-activating examples to keep
neuron_acts = neuron_max_acts_load[str(neuron)]

# Tag every [activation, token position] record with the index of its example in
# the dataset, then rank the records from highest to lowest activation.
neuron_acts_35_3724 = [record + [row] for row, record in enumerate(neuron_acts)]
sorted_acts_35_3724 = sorted(neuron_acts_35_3724, key=lambda record: record[0], reverse=True)
print(sorted_acts_35_3724)

# Split the top records into parallel lists: activation, token position, dataset index.
top_records = sorted_acts_35_3724[:num_samples]
top_examples_acts = [record[0] for record in top_records]
top_examples_pos = [record[1] for record in top_records]
top_examples_indices = [record[2] for record in top_records]

# Look up the raw text of each top example.
examples = []
for index in top_examples_indices:
    examples += [dataset[index]['text']]
print(examples)

[[11.053892135620117, 99, 2659], [10.527321815490723, 757, 5653], [10.241874694824219, 451, 1342], [9.469478607177734, 84, 2715], [8.823227882385254, 176, 5500], [8.733474731445312, 341, 7286], [8.57863998413086, 158, 1037], [8.51498794555664, 32, 6527], [8.206207275390625, 748, 1969], [8.162208557128906, 186, 6202], [8.10352611541748, 1018, 3811], [8.00257682800293, 665, 9757], [7.9617600440979, 284, 8387], [7.916937828063965, 996, 7516], [7.8989362716674805, 34, 4679], [7.798364639282227, 76, 8403], [7.797216892242432, 214, 9910], [7.782955646514893, 523, 8259], [7.721066474914551, 119, 1328], [7.719846725463867, 461, 4224], [7.645995140075684, 208, 3216], [7.5494208335876465, 369, 7414], [7.444610595703125, 800, 6907], [7.428597450256348, 65, 3606], [7.390169143676758, 698, 4952], [7.371682167053223, 411, 8686], [7.347747802734375, 656, 9169], [7.282724857330322, 103, 6739], [7.26452112197876, 292, 5437], [7.196075916290283, 47, 7589], [7.170181751251221, 126, 2810], [7.157305717468

In [73]:
class ActivatingDataset:
    """Pairs a text dataset with per-neuron prompt markers loaded from a JSON file.

    Attributes:
        data: the dataset passed to the constructor, or the "NeelNanda/pile-10k"
            train split when none is given.
        markers: dict mapping a neuron key to its list of prompt records. String keys
            of the form "(layer, index)" are converted to integer tuples; keys of any
            other type are kept unchanged. `remove_prompts_longer_than` reads the
            'start' and 'end' entries of each record.
    """

    def __init__(self, json_file, dataset=None):
        """Load the markers from `json_file` and attach `dataset` (loaded if None)."""
        if dataset is None:
            self.data = load_dataset("NeelNanda/pile-10k", split="train")
        else:
            self.data = dataset
        with open(json_file, "r") as f:
            raw_markers = json.load(f)

        self.markers = {self._neuron_str_to_tuple(key): value for key, value in raw_markers.items()}

    @staticmethod
    def _neuron_str_to_tuple(key):
        """Turn a "(layer, index)" string into a tuple of ints; return other keys as-is."""
        if isinstance(key, str):
            # Drop the surrounding parentheses; int() tolerates the space after the
            # comma, so both "(35, 3724)" and "(35,3724)" are accepted.
            return tuple(int(part) for part in key[1:-1].split(","))
        return key

    def remove_prompts_longer_than(self, length=100):
        """Drop, for every neuron, the prompts whose span (end - start) exceeds `length`.

        A prompt whose span is exactly `length` is kept: it is not "longer than"
        `length`. (The previous strict `<` comparison removed it as well.)
        """
        for neuron, prompts in self.markers.items():
            # Only values of existing keys are replaced, which is safe while iterating.
            self.markers[neuron] = [
                prompt for prompt in prompts if prompt['end'] - prompt['start'] <= length
            ]
        
# Build the marker dataset on top of the already-loaded `dataset`.
data = ActivatingDataset(json_file='neuron_20_examples_2.json', dataset=dataset)
# data.remove_prompts_longer_than(100)

In [74]:
original_lengths = []

# Span (end - start) of every marked prompt, pooled over all neurons.
lengths = [
    example['end'] - example['start']
    for neuron_examples in data.markers.values()
    for example in neuron_examples
]
# Number of marked prompts per neuron, in the iteration order of data.markers.
neuron_num_of_examples = [len(neuron_examples) for neuron_examples in data.markers.values()]

# How many prompts span more than 100.
over_100 = sum(1 for x in lengths if x > 100)

# cum[i] = fraction of prompts with length <= i, for i = 0 .. max(lengths).
# np.bincount tallies every length in a single pass (the previous per-length
# list.count loop was quadratic) and includes the largest length, so the curve
# now ends at 1.0 instead of stopping one bin short.
# Assumes the spans are non-negative integers (np.bincount rejects anything else).
cum = np.cumsum(np.bincount(lengths)) / len(lengths)

print(cum)

line(cum)

# scatter(np.arange(len(lengths)), lengths, title="Truncated Lengths (80% Activation Recovered)")
scatter(np.arange(len(neuron_num_of_examples)), neuron_num_of_examples)


[0.    0.    0.006 0.045 0.08  0.136 0.186 0.234 0.282 0.33  0.355 0.386
 0.413 0.444 0.46  0.482 0.511 0.532 0.549 0.563 0.576 0.586 0.603 0.615
 0.627 0.635 0.65  0.659 0.663 0.667 0.669 0.676 0.678 0.687 0.689 0.691
 0.695 0.699 0.703 0.705 0.709 0.712 0.718 0.72  0.723 0.726 0.732 0.739
 0.742 0.745 0.751 0.755 0.759 0.76  0.766 0.768 0.771 0.775 0.781 0.784
 0.785 0.79  0.792 0.797 0.797 0.798 0.8   0.803 0.805 0.81  0.81  0.812
 0.813 0.814 0.815 0.816 0.819 0.823 0.824 0.826 0.828 0.83  0.83  0.831
 0.831 0.832 0.832 0.833 0.834 0.836 0.836 0.839 0.841 0.844 0.844 0.844
 0.847 0.85  0.851 0.851 0.852 0.854 0.855 0.855 0.856 0.857 0.858 0.86
 0.86  0.861 0.861 0.861 0.861 0.861 0.861 0.861 0.864 0.865 0.865 0.865
 0.867 0.867 0.869 0.869 0.869 0.871 0.871 0.872 0.873 0.874 0.874 0.874
 0.874 0.874 0.875 0.875 0.875 0.875 0.875 0.875 0.876 0.876 0.877 0.878
 0.878 0.878 0.878 0.88  0.882 0.883 0.884 0.886 0.886 0.886 0.886 0.886
 0.886 0.886 0.887 0.887 0.887 0.888 0.889 0.89  0.8

In [209]:
neuron_20_examples = {}
num_samples = 20         # top-activating examples kept per neuron
recovery_fraction = 0.8  # a truncated prompt must exceed this share of the original activation


def measure_last_token_act(tokens, layer, neuron_index):
    """Run `model` on `tokens` and return one neuron's `mlp.hook_post` activation at the last position.

    `tokens` must hold a single, non-empty sequence: the captured value is read with `.item()`.
    """
    captured = []

    def caching_hook(act, hook):
        captured.append(act[:, -1, neuron_index])  # act is indexed [batch, position, neuron]

    model.run_with_hooks(tokens, fwd_hooks=[(f"blocks.{layer}.mlp.hook_post", caching_hook)])
    return captured[-1].item()


for neuron_str in tqdm(neuron_max_acts_load):
    # neuron_str is '(layer, neuron_index)'. Make it a tuple
    neuron = tuple([int(x) for x in neuron_str[1:-1].split(", ")])
    layer, neuron_index = neuron

    # Tag each [activation, token position] record with the index of its example in
    # the dataset, then keep the `num_samples` records with the highest activation.
    neuron_acts = [x + [i] for i, x in enumerate(neuron_max_acts_load[neuron_str])]
    sorted_acts = sorted(neuron_acts, key=lambda x: x[0], reverse=True)
    top_records = sorted_acts[:num_samples]
    top_examples_indices = [x[2] for x in top_records]
    top_examples_pos = [x[1] for x in top_records]
    top_examples_acts = [x[0] for x in top_records]

    neuron_20_examples[neuron_str] = top_records

    examples = []
    for index in top_examples_indices:
        examples.append(dataset[index]['text'])

    for example_index, example in tqdm(enumerate(examples), total=len(examples)):
        # Tokenize the example and cut it off at the recorded token position.
        example_tokens = model.to_tokens(example, prepend_bos=False)
        example_tokens = example_tokens[:, :top_examples_pos[example_index]]

        record = top_records[example_index]
        original_act = record[0]
        threshold = recovery_fraction * original_act
        n_tokens = example_tokens.shape[1]

        if n_tokens == 0:
            # Nothing to truncate and nothing to run the model on: keep start 0 and
            # fall back to the original activation so the record keeps its layout.
            record.insert(1, 0)
            record.append(original_act)
            continue

        # Binary search for a late start position whose suffix still exceeds the
        # threshold. `start` only ever moves to positions that passed the test;
        # position 0 is taken to pass without being probed.
        # NOTE(review): this presumes recovery is monotone in the start position - confirm.
        start = 0
        end = n_tokens - 1
        start_act = None  # activation measured for the suffix beginning at `start`
        mid = (start + end) // 2
        while start != mid:
            act = measure_last_token_act(example_tokens[:, mid:], layer, neuron_index)
            if act > threshold:
                start, start_act = mid, act
            else:
                end = mid
            mid = (start + end) // 2

        if start_act is None:
            # `start` never moved (or the loop never ran for a 1-2 token prompt), so the
            # suffix at `start` has not been measured yet. Previously this case either
            # raised IndexError on an empty cache or stored the activation of a
            # rejected probe instead of the chosen start.
            start_act = measure_last_token_act(example_tokens[:, start:], layer, neuron_index)

        # Record layout afterwards:
        # [activation, truncated start, token position, dataset index, activation of truncated prompt]
        record.insert(1, start)
        record.append(start_act)

# Save the neuron_20_examples
with open("neuron_20_examples.json", "w") as f:
    json.dump(neuron_20_examples, f)

20it [01:30,  4.51s/it]00:00<?, ?it/s]
20it [00:52,  2.60s/it]01:30<1:13:43, 90.27s/it]
20it [01:37,  4.87s/it]02:22<54:15, 67.82s/it]  
20it [01:07,  3.36s/it]04:00<1:04:06, 81.85s/it]
20it [00:38,  1.94s/it]05:08<58:19, 76.08s/it]  
20it [01:06,  3.31s/it]05:46<46:57, 62.62s/it]
20it [00:51,  2.60s/it]06:53<46:49, 63.86s/it]
20it [00:53,  2.67s/it]07:45<42:58, 59.97s/it]
20it [00:53,  2.65s/it]08:38<40:31, 57.89s/it]
20it [00:48,  2.42s/it]09:31<38:31, 56.37s/it]
20it [01:28,  4.42s/it][10:20<35:57, 53.94s/it]
20it [01:16,  3.81s/it][11:48<41:55, 64.50s/it]
20it [01:06,  3.31s/it][13:04<43:05, 68.04s/it]
20it [01:58,  5.93s/it][14:10<41:37, 67.49s/it]
20it [00:52,  2.61s/it][16:09<49:46, 82.96s/it]
20it [00:52,  2.62s/it][17:01<42:59, 73.70s/it]
20it [01:09,  3.49s/it][17:54<38:08, 67.32s/it]
20it [00:54,  2.75s/it][19:04<37:25, 68.05s/it]
20it [01:15,  3.79s/it][19:59<34:11, 64.11s/it]
20it [01:08,  3.43s/it][21:14<34:56, 67.64s/it]
20it [00:46,  2.34s/it][22:23<33:57, 67.93s/it]
20

In [43]:
original_lengths = []
lengths = []
for neuron_str in neuron_20_examples:
    for example in neuron_20_examples[neuron_str]:
        # Records are [activation, truncated start, token position, ...]: the prompt
        # originally ran from token 0 up to example[2] and now starts at example[1].
        truncated_start, token_position = example[1], example[2]
        lengths.append(token_position - truncated_start)
        # The original length is the token position itself; example[1] (used here
        # before) is the truncation start, not a length.
        original_lengths.append(token_position)

scatter(np.arange(len(lengths)), lengths, title="Truncated Lengths (80% Activation Recovered)")
scatter(np.arange(len(lengths)), original_lengths, title="Original Lengths")

NameError: name 'neuron_20_examples' is not defined

In [219]:
for neuron_str in neuron_20_examples:
    neuron = tuple([int(x) for x in neuron_str[1:-1].split(", ")])
    neuron_acts = neuron_20_examples[neuron_str]
    num_samples = 20
    # Records already have the layout [activation, start, token position, dataset index, ...].
    # Rank a sorted copy only: nothing is appended to the records or written back to
    # neuron_20_examples, so re-running this cell no longer grows the stored records.
    sorted_acts = sorted(neuron_acts, key=lambda x: x[0], reverse=True) # Sort by activation
    top_records = sorted_acts[:num_samples]
    top_examples_indices = [x[3] for x in top_records] # Dataset indices of the top examples
    top_examples_start = [x[1] for x in top_records] # Truncated start positions of the top examples
    top_examples_pos = [x[2] for x in top_records] # Token positions of the top examples
    top_examples_acts = [x[0] for x in top_records] # Activations of the top examples

    examples = []
    for index in top_examples_indices:
        examples.append(dataset[index]['text'])

    for example_index, example in tqdm(enumerate(examples), total=len(examples)):
        # Show only the slice of the prompt between the truncated start and the token position.
        example_tokens = model.to_tokens(example, prepend_bos=False)
        example_tokens = example_tokens[:, top_examples_start[example_index]:top_examples_pos[example_index]]
        example_str = model.to_string(example_tokens)
        print(example_str)
    break  # only the first neuron is inspected


20it [00:00, 364.27it/s]

[" restrict CV access to legitimate companies only, it cannot be held responsible for how CVs are used by third parties once they have been downloaded from our database.\n\nIdentifiable information is anything that is unique to a user (i.e. email addresses, telephone numbers and CV files).\n\nNiche Jobs Ltd may from time to time send email-shots on behalf of third parties to users. Users can unsubscribe from mailshots using the unsubscribe link in the email or by contacting Niche Jobs Ltd via the Contact Us page on the website.\n\nNon-identifiable information\n\nNiche Jobs Ltd may also collect information (via cookies) about users and how they interact with the site, for purposes of performance measuring and statistics. This information is aggregated, so is not identifiable on an individual user basis.\n\nUsers may choose to accept or deny cookies from Niche Jobs Ltd, but users should be aware that if cookies are not permitted it may adversely affect a user’s experience of the site.\n\




In [232]:
# Check for induction: do the final tokens of each truncated prompt already occur near its start?
for example_index, example in tqdm(enumerate(examples)):
    start_pos = top_examples_start[example_index]
    end_pos = top_examples_pos[example_index]
    example_tokens = model.to_tokens(example, prepend_bos=False)[:, start_pos:end_pos]
    token_list = example_tokens.tolist()[0]
    first_ten, last_three = token_list[:10], token_list[-3:]
    # Count how many of the last three tokens also appear among the first ten.
    num_induced = sum(1 for token in last_three if token in first_ten)
    print(num_induced, len(token_list), model.to_string(first_ten), "@@@", model.to_string(last_three))



20it [00:00, 261.02it/s]

3 662  restrict CV access to legitimate companies only, it cannot @@@  to legitimate companies
2 63 ire can only work if you are follwoing @@@ cersire can
3 6  for one reason and one reason @@@  and one reason
3 5  site is for informational purposes @@@  for informational purposes
3 122  government. It can only exist until the majority discovers @@@ . It can
1 16 known result "$X$ is ${\rm T @@@ $ if and
3 10  Software that is described herein is for illustrative purposes @@@  illustrative purposes
3 93  becuase we are the only species that kills @@@  we are the
3 352  providing information about graduate programs only. Questions regarding the @@@  about graduate programs
3 6  is intended for general information purposes @@@  general information purposes
3 10  Disclaimer: The Body is designed for educational purposes @@@  for educational purposes
3 5  is intended for entertainment purposes @@@  for entertainment purposes
1 12  directory, ("Find A Doctor"), is provided for @@@  for refe




In [None]:
for example_index, example in enumerate(examples):
    # Tokenize the example
    example_tokens = model.to_tokens(example, prepend_bos=False)
    # Truncate at this example's own token position. (A leftover global `index` was
    # used here before, so every example was cut at the same, unrelated position.)
    example_tokens = example_tokens[:, :top_examples_pos[example_index]]
    print(example_tokens.shape)

    layer, neuron_index = neuron
    neurons = [neuron]

    # One entry per forward pass: the hooked neuron's activation at the last position.
    # Reset for every example, so after the loop it holds the last example's values.
    cache = []
    def return_caching_hook(neuron):
        """Build a hook that appends the given neuron's last-position activation to `cache`."""
        layer, neuron_index = neuron
        def caching_hook(act, hook):
            cache.append(act[:, -1, neuron_index]) # act is indexed [batch, position, neuron]
        return caching_hook

    hooks = [
        (f"blocks.{hook_layer}.mlp.hook_post", return_caching_hook((hook_layer, hook_neuron)))
        for hook_layer, hook_neuron in neurons
    ]
    print(hooks)

    # Run the model on every suffix of the prompt, from the full prompt down to the last token.
    # for i in tqdm(range(example_tokens.shape[1]-1, -1, -1)):
    for i in tqdm(range(example_tokens.shape[1])):
        new_tokens = example_tokens[:, i:]
        model.run_with_hooks(
            new_tokens,
            fwd_hooks=hooks,
        )

In [155]:
# Plot the recorded activations in the reverse of the order they were collected.
cache_numbers = [act.item() for act in reversed(cache)]
line(
    cache_numbers,
    title="Neuron 35, 3724 activations",
    labels={"y":"Activation", "x":"Token position"},
)

In [146]:
i = 295
# Decode the prompt from token position i to its end.
suffix_tokens = example_tokens[:, i:]
print(model.to_string(suffix_tokens))

[' found it made the whole bed too high with the wider mattress, I phoned the store and they came back the "same day" and replaced the box spring with a thinner one and there were no additional charges. Great shop, nice people, good quality products I would highly recommend.\nThank you.\n\nTerence McFall\n\n20:17 23 Sep 16\n\nWe purchased two twin Natural XL Ironman beds in March with the help and guidance of the great staff at M and N Beds in Parksville. We noticed very soon after that our sleep patterns improved as did freedom from morning back aches. Yes, these beds cost somewhat more than many regular beds and that is because they are unlike many regular beds.\nVery importantly is the after sales contact made by this Company to ensure your satisfaction with the beds. I have no hesitation in recommending M and N Beds\nTerry M. Qualicum Beach\n\nMelanie Trudeau\n\n21:52 15 Oct 16\n\nGreat customer service! They were able to answer all my questions! We were there for an hour and the r

In [203]:
index = 2
# Pick the example at `index`, tokenize it, and keep only the tokens before its recorded position.
example = examples[index]
cutoff = top_examples_pos[index]
example_tokens = model.to_tokens(example, prepend_bos=False)[:, :cutoff]
print(example_tokens.shape)

torch.Size([1, 451])


In [119]:
index = 14
# Pick the example at `index`, tokenize it, and keep only the tokens before its recorded position.
example = examples[index]
cutoff = top_examples_pos[index]
example_tokens = model.to_tokens(example, prepend_bos=False)[:, :cutoff]
print(example_tokens.shape)

layer, neuron_index = neuron
neurons = [neuron]

# One entry per forward pass: the hooked neuron's activation at the last position.
cache = []

def return_caching_hook(neuron):
    """Build a hook that appends the given neuron's last-position activation to `cache`."""
    _, neuron_index = neuron
    def caching_hook(act, hook):
        last_position_act = act[:, -1, neuron_index]  # act is indexed [batch, position, neuron]
        cache.append(last_position_act)
    return caching_hook

hooks = [
    (f"blocks.{hook_layer}.mlp.hook_post", return_caching_hook((hook_layer, hook_neuron)))
    for hook_layer, hook_neuron in neurons
]
print(hooks)

# Run the model once per suffix of the prompt, starting from the full prompt.
for i in tqdm(range(example_tokens.shape[1])):
    new_tokens = example_tokens[:, i:]
    model.run_with_hooks(new_tokens, fwd_hooks=hooks)


torch.Size([1, 34])
[('blocks.35.mlp.hook_post', <function return_caching_hook.<locals>.caching_hook at 0x356d07880>)]


100%|██████████| 34/34 [00:10<00:00,  3.34it/s]


In [124]:
# One activation per forward pass, in the order they were collected.
cache_numbers = list(map(lambda act: act.item(), cache))
# Label each point with the token, using px.line
token_labels = model.to_str_tokens(example_tokens[0])
line(
    cache_numbers,
    title="Neuron 35, 3724 activations",
    labels={"y":"Activation", "x":"Token position"},
    hover_data={"x": token_labels},
)

In [110]:
i = 0
# Decode the prompt from token position i to its end (i = 0 gives the whole prompt).
suffix_tokens = example_tokens[:, i:]
print(model.to_string(suffix_tokens))

['Youth unemployment in Victoria hits 15-year high\n\nHenrietta Cook, state political reporter\n\nYouth unemployment has leapt to a 15-year']


In [60]:
batch_size = 4
# Split the dataset texts into consecutive batches of `batch_size`.
batched_texts = []
for batch_start in range(0, len(dataset_text_list), batch_size):
    batched_texts.append(dataset_text_list[batch_start: batch_start + batch_size])
print(len(batched_texts))

# One growing list of results per neuron.
neuron_max_acts = {neuron: [] for neuron in neurons}


def return_caching_hook(neuron):
    """Build a hook that stores the given neuron's activations in the global `cache` dict."""
    layer, neuron_index = neuron
    def caching_hook(act, hook):
        # `cache` is looked up when the hook runs, so it writes into the current batch's dict.
        cache[(layer, neuron_index)] = act[:, :, neuron_index] # act is indexed [batch, position, neuron]
    return caching_hook


for texts in tqdm(batched_texts):
    model.reset_hooks()

    cache = {}  # filled by the hooks during this batch's forward pass

    hooks = [
        (f"blocks.{hook_layer}.mlp.hook_post", return_caching_hook((hook_layer, hook_neuron)))
        for hook_layer, hook_neuron in neurons
    ]
    print(hooks)

    model.run_with_hooks(model.to_tokens(texts), fwd_hooks=hooks)
    # cache_to_tuples is defined outside this cell; presumably it reduces each cached
    # tensor to per-text (activation, position) pairs - TODO confirm.
    cache = cache_to_tuples(cache)

    for key in cache.keys():
        neuron_max_acts[key].extend(cache[key])



2500


  0%|          | 0/2500 [00:00<?, ?it/s]

[('blocks.0.mlp.hook_post', <function return_caching_hook.<locals>.caching_hook at 0x2974e0160>), ('blocks.1.mlp.hook_post', <function return_caching_hook.<locals>.caching_hook at 0x2c775caf0>)]


  0%|          | 0/2500 [00:33<?, ?it/s]


In [67]:
print(neuron_max_acts)
# Stringify every neuron key before saving; json.load later returns them as strings.
neuron_max_acts_json = {}
for key, value in neuron_max_acts.items():
    neuron_max_acts_json[str(key)] = value
with open("neuron_max_acts.json", "w") as f:
    f.write(json.dumps(neuron_max_acts_json))

{(0, 0): [(0.4795704782009125, 488), (1.505552887916565, 152), (1.5584756135940552, 137), (-0.0, 22)], (1, 1): [(-0.0, 566), (0.20460055768489838, 810), (2.280916929244995, 76), (0.30561938881874084, 66)]}


In [70]:
with open("neuron_max_acts.json", "r") as f:
    neuron_max_acts_load = json.loads(f.read())

print(neuron_max_acts_load.keys())
# Keys are strings of the form '(layer, neuron)'. Each value is a list with one
# [activation, index] entry per example; other cells in this notebook read the
# second item as a token position.

dict_keys(['(31, 3621)', '(31, 364)', '(31, 2918)', '(31, 4378)', '(31, 988)', '(31, 2658)', '(31, 2692)', '(31, 4941)', '(31, 2415)', '(31, 1407)', '(32, 4964)', '(32, 2412)', '(32, 4282)', '(32, 3151)', '(32, 1155)', '(32, 1386)', '(32, 3582)', '(32, 4882)', '(32, 3477)', '(32, 406)', '(33, 1202)', '(33, 524)', '(33, 1582)', '(33, 4446)', '(33, 204)', '(33, 4900)', '(33, 2322)', '(33, 3278)', '(33, 1299)', '(33, 52)', '(34, 4012)', '(34, 4262)', '(34, 320)', '(34, 5095)', '(34, 2599)', '(34, 2442)', '(34, 4494)', '(34, 4199)', '(34, 727)', '(34, 4410)', '(35, 4518)', '(35, 48)', '(35, 5014)', '(35, 3724)', '(35, 3360)', '(35, 885)', '(35, 4924)', '(35, 274)', '(35, 2369)', '(35, 4638)'])


In [95]:
# For each neuron, report the largest activation among its entries.
for key, x in neuron_max_acts_load.items():
    # Each entry is a two-element record; take the one whose first element is largest.
    max_value, max_index = max(x, key=lambda entry: entry[0])
    print(key, max_value)

(31, 3621) 11.057353973388672
(31, 364) 10.807672500610352
(31, 2918) 12.039180755615234
(31, 4378) 9.483403205871582
(31, 988) 11.906610488891602
(31, 2658) 12.757360458374023
(31, 2692) 8.618489265441895
(31, 4941) 9.564465522766113
(31, 2415) 12.56452465057373
(31, 1407) 11.221527099609375
(32, 4964) 10.279662132263184
(32, 2412) 14.07355785369873
(32, 4282) 9.611800193786621
(32, 3151) 12.704906463623047
(32, 1155) 17.4642391204834
(32, 1386) 8.08731460571289
(32, 3582) 8.37304973602295
(32, 4882) 7.985595226287842
(32, 3477) 10.776716232299805
(32, 406) 10.953438758850098
(33, 1202) 9.49435806274414
(33, 524) 10.278265953063965
(33, 1582) 6.070120811462402
(33, 4446) 6.540103912353516
(33, 204) 6.799570560455322
(33, 4900) 6.093514442443848
(33, 2322) 9.797901153564453
(33, 3278) 7.231164932250977
(33, 1299) 7.962460517883301
(33, 52) 7.98359489440918
(34, 4012) 13.326713562011719
(34, 4262) 12.96543025970459
(34, 320) 11.070517539978027
(34, 5095) 11.418562889099121
(34, 2599) 14

In [100]:
# Scatter plot one neuron's activation for every text, in dataset order.
neuron_str = "(35, 3724)"
x = neuron_max_acts_load[neuron_str]
text_indices = list(range(len(x)))
activations = [entry[0] for entry in x]
scatter(
    x=text_indices,
    y=activations,
    xaxis="Text index",
    yaxis="Activation",
    caxis="",
    title=f"Activations for neuron {neuron_str}",
)