In [1]:
import matplotlib.pyplot as plt
import json
import torch
import pickle
from typing import Optional
from matplotlib.colors import Normalize
import numpy as np
import os
import random
import datasets

import einops
import dictionary_learning.interp as interp
from circuitsvis.activations import text_neuron_activations
from collections import namedtuple
from nnsight import LanguageModel
from tqdm import tqdm

import experiments.utils as utils
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer

from experiments.autointerp import highlight_top_activations

DEBUGGING = True

if DEBUGGING:
    tracer_kwargs = dict(scan=True, validate=True)
else:
    tracer_kwargs = dict(scan=False, validate=False)

%load_ext autoreload
%autoreload 2

In [2]:
DEVICE = "cuda"
model_name = "EleutherAI/pythia-70m-deduped"
model_dtype = torch.bfloat16
model = LanguageModel(
    model_name,
    device_map=DEVICE,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=model_dtype,
)


context_length = 128

dataset = datasets.load_dataset("georgeyw/dsir-pile-100k", streaming=False)

data = model.tokenizer(dataset["train"]["contents"][:10000], return_tensors="pt", padding="max_length", truncation=True, max_length=context_length).to(DEVICE).data

Repo card metadata block was not found. Setting CardData to empty.


In [3]:
print(type(data))
batch_size = 250

batched_data = utils.batch_inputs(data, batch_size)

<class 'dict'>


In [4]:
import einops
import dictionary_learning.interp as interp
from circuitsvis.activations import text_neuron_activations
from collections import namedtuple
from nnsight import LanguageModel
from tqdm import tqdm

import experiments.utils as utils
from dictionary_learning.utils import hf_dataset_to_generator
from dictionary_learning.buffer import ActivationBuffer


def get_max_activating_prompts(
    model,
    submodule,
    tokenized_inputs_bL: list[list[dict]],
    dim_indices: torch.Tensor,
    batch_size: int,
    dictionary=None,
    n_inputs: int = 512,
    k: int = 30,
):

    assert n_inputs % batch_size == 0

    feature_count = dim_indices.shape[0]

    device = model.device

    max_activating_indices_FK = torch.zeros((feature_count, k), device=device, dtype=torch.int)
    max_activations_FK = torch.zeros((feature_count, k), device=device, dtype=torch.float32)
    max_tokens_FKL = torch.zeros((feature_count, k, context_length), device=device, dtype=torch.int)
    max_activations_FKL = torch.zeros((feature_count, k, context_length), device=device, dtype=torch.float32)

    for i, inputs in tqdm(enumerate(tokenized_inputs_bL), total=len(tokenized_inputs_bL)):

        batch_offset = i * batch_size
        inputs_BL = inputs['input_ids']

        with torch.no_grad(), model.trace(inputs, **tracer_kwargs):
            activations_BLD = submodule.output
            if type(activations_BLD.shape) == tuple:
                activations_BLD = activations_BLD[0]
            activations_BLF = dictionary.encode(activations_BLD)
            activations_BLF = activations_BLF[:, :, dim_indices].save()

        activations_FBL = einops.rearrange(activations_BLF.value, 'B L F -> F B L')
        # Use einops to find the max activation per input
        activations_FB = einops.reduce(activations_FBL, 'F B L -> F B', 'max')
        tokens_FBL = einops.repeat(inputs_BL, 'B L -> F B L', F=feature_count)
        
        # Keep track of input indices
        indices_B = torch.arange(batch_offset, batch_offset + batch_size, device=device)
        indices_FB = einops.repeat(indices_B, 'B -> F B', F=feature_count)

        # Concatenate current batch activations and indices with the previous ones
        combined_activations_FB = torch.cat([max_activations_FK, activations_FB], dim=1)
        combined_indices_FB = torch.cat([max_activating_indices_FK, indices_FB], dim=1)
        combined_activations_FBL = torch.cat([max_activations_FKL, activations_FBL], dim=1)
        combined_tokens_FBL = torch.cat([max_tokens_FKL, tokens_FBL], dim=1)

        # Sort and keep top k activations for each dimension
        topk_activations_FK, topk_indices_FK = torch.topk(combined_activations_FB, k, dim=1)
        max_activations_FK = topk_activations_FK

        feature_indices_F1 = torch.arange(feature_count, device=device)[:, None]
        max_activating_indices_FK = combined_indices_FB[feature_indices_F1, topk_indices_FK]
        max_activations_FKL = combined_activations_FBL[feature_indices_F1, topk_indices_FK]
        max_tokens_FKL = combined_tokens_FBL[feature_indices_F1, topk_indices_FK]
            

    return max_tokens_FKL, max_activations_FKL

dictionaries_path = "../dictionary_learning/dictionaries"

# Current recommended way to generate graphs. You can copy paste ae_sweep_paths directly from bib_intervention.py
ae_sweep_paths = {
    "pythia70m_sweep_standard_ctx128_0712": {"resid_post_layer_3": {"trainer_ids": [6]}},
    # "pythia70m_sweep_gated_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [9]}},
    # "pythia70m_sweep_topk_ctx128_0730": {"resid_post_layer_3": {"trainer_ids": [10]}},
    # "gemma-2-2b_sweep_topk_ctx128_0817": {"resid_post_layer_12": {"trainer_ids": [2]}}, 
}
sweep_name = list(ae_sweep_paths.keys())[0]
submodule_trainers = ae_sweep_paths[sweep_name]

filter_class_ids = []
# filter_class_ids = [-4, -2]

ae_group_paths = utils.get_ae_group_paths(dictionaries_path, sweep_name, submodule_trainers)
ae_paths = utils.get_ae_paths(ae_group_paths)

# TODO
# Add bias_in_bios dataset option
# Cosine sim with probes
# Vector per class probe

ae_path = ae_paths[0]
submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

filename_counter = ""
class_id = -2

node_effects_filename = f"{ae_path}/node_effects{filename_counter}.pkl"

with open(node_effects_filename, "rb") as f:
    node_effects = pickle.load(f)

effects = node_effects[class_id]

print(effects.shape)

k = 500
top_k_values, top_k_indices = torch.topk(effects, k)
torch.set_printoptions(sci_mode=False)
print(top_k_values)
print(top_k_indices)

all_indices = torch.arange(0, effects.shape[0])
all_indices = top_k_indices
torch.cuda.empty_cache()
import gc
gc.collect()
max_tokens_FKL, max_activations_FKL = get_max_activating_prompts(model, submodule, batched_data, all_indices, batch_size, dictionary, 1000, 30)

Loading dictionary from ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_6
torch.Size([16384])
tensor([    1.3203,     0.3594,     0.2402,     0.1455,     0.0762,     0.0605,
            0.0503,     0.0464,     0.0403,     0.0337,     0.0198,     0.0150,
            0.0128,     0.0124,     0.0119,     0.0107,     0.0101,     0.0099,
            0.0093,     0.0091,     0.0085,     0.0085,     0.0084,     0.0079,
            0.0076,     0.0075,     0.0073,     0.0072,     0.0069,     0.0069,
            0.0069,     0.0068,     0.0067,     0.0067,     0.0065,     0.0056,
            0.0055,     0.0055,     0.0053,     0.0052,     0.0050,     0.0049,
            0.0049,     0.0048,     0.0047,     0.0047,     0.0046,     0.0045,
            0.0045,     0.0045,     0.0042,     0.0041,     0.0040,     0.0040,
            0.0039,     0.0039,     0.0039,     0.0038,     0.0038,     0.0038,
            0.0038,     0.0038,     0.0037,     0.0036

  0%|          | 0/40 [00:00<?, ?it/s]You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 40/40 [00:07<00:00,  5.11it/s]


In [5]:
with open("trained_bib_probes/pythia-70m-deduped/probes_ctx_len_128.pkl", "rb") as f:
    probes = pickle.load(f)

probe_vec = probes[class_id].net.weight.squeeze()
print(probe_vec.shape)

for i in range(10):
    sae_feat_idx = top_k_indices[i]
    decoder_vec = dictionary.decoder.weight[: ,sae_feat_idx].squeeze()
    cos_sim = torch.nn.functional.cosine_similarity(probe_vec, decoder_vec, dim=0)
    print(f"Feature {sae_feat_idx} has cosine similarity {cos_sim}")

torch.Size([512])
Feature 7265 has cosine similarity 0.6078369617462158
Feature 3597 has cosine similarity -0.5437126159667969
Feature 4648 has cosine similarity 0.6239863038063049
Feature 5923 has cosine similarity 0.15908387303352356
Feature 1104 has cosine similarity 0.2630188465118408
Feature 10316 has cosine similarity 0.13465353846549988
Feature 3767 has cosine similarity -0.1064731627702713
Feature 11238 has cosine similarity -0.2627173662185669
Feature 11797 has cosine similarity 0.3990252614021301
Feature 10602 has cosine similarity -0.602030873298645


In [6]:
gc.collect()
torch.cuda.empty_cache()

In [7]:
model.gpt_neox

GPTNeoXModel(
  (embed_in): Embedding(50304, 512)
  (emb_dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-5): 6 x GPTNeoXLayer(
      (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (post_attention_dropout): Dropout(p=0.0, inplace=False)
      (post_mlp_dropout): Dropout(p=0.0, inplace=False)
      (attention): GPTNeoXAttention(
        (rotary_emb): GPTNeoXRotaryEmbedding()
        (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
        (dense): Linear(in_features=512, out_features=512, bias=True)
        (attention_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp): GPTNeoXMLP(
        (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
        (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
        (act): GELUActivation()
      )
    )
  )
  (final_layer_norm): LayerNor

In [8]:
model.embed_out.weight.norm(dim=1, keepdim=True).shape

torch.Size([50304, 1])

In [9]:
## Direct Logit Attribution

from experiments.autointerp import compute_dla

In [10]:
compute_dla(torch.tensor([0, 1, 2]), dictionary.decoder.weight, model.embed_out.weight, 10)

tensor([[  233, 18713,   109, 15911, 15533,   349,   477,   234,     1,   126],
        [   99,   110,   211,   228,   117,  5980,  7633,  5808,  1760,  3059],
        [  113, 24384,   243,    96,   117,   220,   112,   116,    99,   120]],
       device='cuda:0')

In [11]:
# D is top k dla values


formatted_tokens = {}
top_dla_FD = compute_dla(top_k_indices, dictionary.decoder.weight, model.embed_out.weight, return_topk_tokens=3)

for feat_idx in range(10):
    sae_feat_idx = top_k_indices[feat_idx].item()
    top_dla_D = top_dla_FD[feat_idx].tolist()
    print(f"Feature index: {feat_idx}")
    print(sae_feat_idx)

    encoded_tokens_KL = max_tokens_FKL[feat_idx].tolist()
    activations_KL = max_activations_FKL[feat_idx]

    activations_KL11 = [activations_KL[k, :, None, None] for k in range(activations_KL.shape[0])]
    def _list_decode(x):
        if isinstance(x, int):
            return model.tokenizer.decode(x)
        else:
            return [_list_decode(y) for y in x]

    decoded_tokens_KL = _list_decode(encoded_tokens_KL)
    decoded_dla_D = _list_decode(top_dla_D)
    # print(decoded_dla_D)
    # print(decoded_tokens_KL)

    formatted_tokens_KL = highlight_top_activations(decoded_tokens_KL, activations_KL, top_n=5, include_activations=False)
    formatted_tokens[feat_idx] = (formatted_tokens_KL, decoded_dla_D)
    top_contexts = text_neuron_activations(decoded_tokens_KL, activations_KL11)
    # display(top_contexts)

Feature index: 0
7265
Feature index: 1
3597
Feature index: 2
4648
Feature index: 3
5923
Feature index: 4
1104
Feature index: 5
10316
Feature index: 6
3767
Feature index: 7
11238
Feature index: 8
11797
Feature index: 9
10602


In [12]:
s, dla = formatted_tokens[0]
# dla_tokens = ", ".join([f"<< {t}>>" for t in dla])
for sentence in s:
    print("".join(sentence))
    print(f'\nTop logits: {dla}')
    break

the broker and said yes.

She << told>> her younger sister she was going to America for work, << but>> << to>> << keep>> it a secret from her parents, who would never grant her permission to work abroad. You Mi told her parents she was going to Seoul to be a golf caddy -- one of the few legal women's jobs that bring hefty tips from rich men.

She planned to << tell>> them the truth after she paid off her debts.

You Mi was instructed to take passport photos and give them to a man named Kevin in Seoul. The broker drove her to the city, and two days later, You Mi had

Top logits: [' her', ' herself', ' she']


In [13]:
from experiments.explainers.simple.prompt_builder import build_prompt
system_prompt, messages = build_prompt(
    examples="".join(s[0]),
    cot=False,
    top_logits=dla,
    concept="gender",
)

In [14]:
messages

[{'role': 'user',
  'content': '\nExample 1:  and he was <<over the moon>> to find\nExample 2:  we\'ll be laughing <<till the cows come home>>! Pro\nExample 3:  thought Scotland was boring, but really there\'s more <<than meets the eye>>! I\'d\n\nTop_logits: ["elated", "joyful", "story", "thrilled", "spider"]\n'},
 {'role': 'assistant',
  'content': '\n(Part 2)\nSIMILAR TOKENS: "elated", "joyful", "thrilled".\n- The top logits list contains words that are strongly associated with positive emotions.\n\n[yes/no DECISION]: no\n'},
 {'role': 'user',
  'content': '\nExample 1:  a river is wide but the ocean is wid<<er>>. The ocean\nExample 2:  every year you get tall<<er>>," she\nExample 3:  the hole was small<<er>> but deep<<er>> than the\n\nTop_logits: ["apple", "running", "book", "wider", "quickly"]\n'},
 {'role': 'assistant',
  'content': '\n(Part 2)\nSIMILAR TOKENS: None\n- The top logits list contains unrelated nouns and adverbs.\n\n[yes/no DECISION]: no\n'},
 {'role': 'user',
  'cont

In [15]:
import os

# Securely input the API key
api_key = input("Enter your API key: ")

# Set the API key as an environment variable
os.environ['ANTHROPIC_API_KEY'] = api_key

In [16]:


import anthropic

client = anthropic.Anthropic()

message = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1000,
    temperature=0,
    system=system_prompt,
    messages=messages,
)
print(message.content)



[TextBlock(text='After analyzing the examples and the Top_logits, I can provide the following assessment:\n\nThe neuron seems to activate on verbs of communication, particularly "told" and "tell," when they are used in the context of sharing information, especially secretive or important information. The activation appears to occur just before these verbs.\n\nThe Top_logits list provides additional insight:\nSIMILAR TOKENS: "her", "herself", "she"\n\nThese tokens are all feminine pronouns, which suggests that the neuron might be predicting feminine subjects or objects following the communication verbs.\n\nConsidering both the activation examples and the Top_logits, it appears that this neuron might be related to the concept of gender, specifically focusing on female subjects or objects in the context of communication or information sharing.\n\nThe neuron seems to activate on scenarios where a woman (indicated by "She" or "her") is telling something, often in a context that implies secr

In [22]:
print(system_prompt)

print(messages)

print(message.content)

[{'type': 'text', 'text': "You are a meticulous AI researcher conducting an important investigation into a certain neuron in a language model. Your task is to analyze the neuron and decide whether its behavior is related to the concept of gender.\n\n(Part 2) Tokens that the neuron boosts in the next token prediction\n\nYou will also be shown a list called Top_logits. The logits promoted by the neuron shed light on how the neuron's activation influences the model's predictions or outputs. Look at this list of Top_logits and refine your hypotheses from part 1. It is possible that this list is more informative than the examples from part 1.\n\nPay close attention to the words in this list and write down what they have in common. Then look at what they have in common, as well as patterns in the tokens you found in Part 1, to produce a single explanation for what features of text cause the neuron to activate. Propose your explanation in the following format:\n[yes/no DECISION]: <your decisi

In [17]:
llm_out = message.content[0].text.lower()[-10:]

if 'yes' in llm_out and 'no' in llm_out:
    decision = -1
elif 'yes' in llm_out:
    decision = 1
elif 'no' in llm_out:
    decision = 0
else:
    decision = -1

print(decision)

1


In [49]:
llm_out

'(part 1)\nactivating tokens: "told", "but", "to", "keep", "tell".\nprevious tokens: no interesting patterns.\n\nstep 1:\n- the activating tokens are mostly verbs related to communication ("told", "tell") and function words ("but", "to", "keep").\n- the previous tokens don\'t show any particular pattern.\n\nstep 2:\n- the text examples involve communication, particularly secretive or deceptive communication.\n- there\'s a narrative about a woman (you mi) planning to work abroad and keeping it secret from her parents.\n\n(part 2)\nsimilar tokens: "her", "herself", "she".\n- the top logits list contains exclusively feminine pronouns.\n\n[explanation]: this neuron appears to activate on tokens related to communication, particularly in the context of secrets or deception. it also seems to have a strong association with feminine pronouns in its predictions. this suggests the neuron may be capturing some aspect of narrative or dialogue involving women, especially in situations where informat

In [None]:

thresholds = [0.1, 0.05, 0.025, 0.01, 0.001]
top_ns = [1, 10, 100, 500]


for i, ae_path in enumerate(ae_paths):
    node_effects_filename = f"{ae_path}/node_effects.pkl"

    with open(node_effects_filename, "rb") as f:
        node_effects = pickle.load(f)

    print(node_effects.keys())

    effects = node_effects[-2][ae_path]

    print(f"\nEffects for {ae_path}")
    for theshold in thresholds:
        above_threshold = effects[effects > theshold]
        count_above_threshold = above_threshold.shape[0]
        avg_above_threshold = above_threshold.mean().item()
        print(
            f"Threshold {theshold}: {count_above_threshold} nodes above threshold, {avg_above_threshold:.3f} average"
        )

    for top_n in top_ns:
        top_k = torch.topk(effects, top_n)
        avg_top_k = top_k.values.mean().item()
        print(f"Top {top_n}: {avg_top_k:.3f} average")

dict_keys([-4, -2, 0, 1, 2])


TypeError: new(): invalid data type 'str'

In [None]:
for i, ae_path in enumerate(ae_paths):
    node_effects_filename = f"{ae_path}/node_effects.pkl"

    with open(node_effects_filename, "rb") as f:
        node_effects = pickle.load(f)

    print(node_effects.keys())
    print(node_effects[-2].keys())

    effects = node_effects[-4][ae_path]
    print(effects.shape)

    # Create histogram
    plt.figure(figsize=(10, 6))
    plt.hist(effects, bins=100)
    plt.ylim(0, 10)
    plt.title(f'Histogram for {ae_path.split("/")[-3]}')
    plt.xlabel("Effect Value")
    plt.ylabel("Frequency")

    # Display the plot
    plt.show()

In [None]:
# def examine_dimension(
#     model,
#     submodule,
#     buffer,
#     feat_idx: int,
#     n_inputs: int,
#     context_length: int,
#     batch_size: int,
#     dictionary=None,
#     max_length: int = 128,
#     k: int = 30,
# ):


#     def _list_decode(x):
#         if isinstance(x, int):
#             return model.tokenizer.decode(x)
#         else:
#             return [_list_decode(y) for y in x]

#     # if dim_indices is None:
#         # dim_indices = random.randint(0, activations.shape[-1] - 1)

#     assert n_inputs % batch_size == 0
#     n_iters = n_inputs // batch_size

#     device = model.device

#     activations = torch.zeros((n_inputs, context_length), device=device)
#     tokens = torch.zeros((n_inputs, context_length), dtype=torch.long, device=device)

#     for i in tqdm(range(n_iters), desc="Collecting activations"):
#         inputs_BL = buffer.tokenized_batch(batch_size=batch_size)

#         with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
#             tokens_BL = model.input[1][
#                 "input_ids"
#             ].save()  # if you're getting errors, check here; might only work for pythia models
#             activations_BLD = submodule.output
#             if type(activations_BLD.shape) == tuple:
#                 activations_BLD = activations_BLD[0]
#             if dictionary is not None:
#                 activations_BLF = dictionary.encode(activations_BLD)
#             activations_BL = activations_BLF[:, :, feat_idx].save()

#         activations[i * batch_size : (i + 1) * batch_size] = activations_BL.value
#         tokens[i * batch_size : (i + 1) * batch_size] = tokens_BL.value

#     token_mean_acts = {}
#     for ctx in tokens:
#         for tok in ctx:
#             if tok.item() in token_mean_acts:
#                 continue
#             idxs = (tokens == tok).nonzero(as_tuple=True)
#             token_mean_acts[tok.item()] = activations[idxs].mean().item()
#     top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
#     top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

#     flattened_acts = einops.rearrange(activations, "b n -> (b n)")
#     topk_indices = torch.argsort(flattened_acts, dim=0, descending=True)[:k]
#     batch_indices = topk_indices // activations.shape[1]
#     token_indices = topk_indices % activations.shape[1]
#     tokens = [
#         tokens[batch_idx, : token_idx + 1].tolist()
#         for batch_idx, token_idx in zip(batch_indices, token_indices)
#     ]
#     activations = [
#         activations[batch_idx, : token_id + 1, None, None]
#         for batch_idx, token_id in zip(batch_indices, token_indices)
#     ]
#     decoded_tokens = _list_decode(tokens)
#     top_contexts = text_neuron_activations(decoded_tokens, activations)

#     top_affected = interp.feature_effect(
#         model, submodule, dictionary, feat_idx, tokens, max_length=max_length, k=k
#     )
#     top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]

#     return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])(
#         top_contexts, top_tokens, top_affected
#     )

# # DEVICE = "cuda"
# # model_name = "EleutherAI/pythia-70m-deduped"
# # model = LanguageModel(model_name, device_map=DEVICE, dispatch=True)

# # ae_path = ae_paths[2]
# # submodule, dictionary, config = utils.load_dictionary(model, ae_path, DEVICE)

# # context_length = config['buffer']['ctx_len']

# # data = hf_dataset_to_generator("monology/pile-uncopyrighted")
# # buffer = ActivationBuffer(
# #     data,
# #     model,
# #     submodule,
# #     d_submodule=512,
# #     ctx_len=context_length,
# #     refresh_batch_size=128, # decrease to fit on smaller GPUs
# #     n_ctxs=512, # decrease to fit on smaller GPUs
# #     device=DEVICE
# # )

# feat_idx = 0
# sae_feat_idx = top_k_indices[feat_idx].item()
# print(sae_feat_idx)

# n_inputs = 1024
# batch_size = 256

# torch.cuda.empty_cache()
# torch.set_grad_enabled(False)

# out = examine_dimension(
#     model,
#     submodule,
#     buffer,
#     sae_feat_idx,
#     n_inputs,
#     context_length,
#     batch_size,
#     dictionary,
#     max_length=context_length,
#     k=30,
# )

# print(f'\n\ntop activating tokens for feature {sae_feat_idx}')
# for token in out.top_tokens:
#     print(token)
# print(f'\n\ntop affected tokens for feature {sae_feat_idx}')
# for token in out.top_affected:
#     print(token)

# out.top_contexts

In [None]:
def get_max_activating_prompts_old(
    model,
    submodule,
    inputs_bL: list[str],
    dim_indices: torch.Tensor,
    batch_size: int,
    dictionary=None,
    n_inputs: int = 512,
    k: int = 30,
):

    assert n_inputs % batch_size == 0
    n_iters = n_inputs // batch_size

    dim_count = dim_indices.shape[0]

    device = model.device

    max_activating_indices_FK = torch.zeros((dim_count, k), device=device, dtype=torch.int)
    max_activations_FK = torch.zeros((dim_count, k), device=device, dtype=torch.float32)

    for i in range(n_iters):

        batch_offset = i * batch_size

        inputs_BL = inputs_bL[batch_offset : batch_offset + batch_size]

        with torch.no_grad(), model.trace(inputs_BL, **tracer_kwargs):
            activations_BLD = submodule.output
            if type(activations_BLD.shape) == tuple:
                activations_BLD = activations_BLD[0]
            activations_BLF = dictionary.encode(activations_BLD)
            activations_BLF = activations_BLF[:, :, dim_indices].save()

        # Use einops to find the max activation per input
        activations_FB = einops.reduce(activations_BLF.value, 'B L F _> F B', 'max')
        
        # Keep track of input indices
        indices_B = torch.arange(batch_offset, batch_offset + batch_size, device=device)
        indices_FB = einops.repeat(indices_B, 'B -> F B', F=dim_count)

        # Concatenate current batch activations and indices with the previous ones
        combined_activations_FK = torch.cat([max_activations_FK, activations_FB], dim=1)
        combined_indices_FK = torch.cat([max_activating_indices_FK, indices_FB], dim=1)

        # Sort and keep top k activations for each dimension
        topk_activations_FK, topk_indices = torch.topk(combined_activations_FK, k, dim=1)
        max_activations_FK = topk_activations_FK
        max_activating_indices_FK = torch.gather(combined_indices_FK, 1, topk_indices)

    return max_activating_indices_FK, max_activations_FK


