In [1]:
import numpy as np
import sys
import joblib
import matplotlib.pyplot as plt
from transformer_lens.components import TransformerBlock
import ridge_utils.npp
from ridge_utils.util import make_delayed
from ridge_utils.dsutils import make_word_ds
from ridge_utils.DataSequence import DataSequence
import warnings
import pickle
from configs import engram_dir
import os
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

from transformer_lens.utils import tokenize_and_concatenate
from huggingface_hub import login
from configs import huggingface_token

login(token=huggingface_token)
device = "cuda" if torch.cuda.is_available() else "cpu"

box_dir = os.path.join(engram_dir, 'huth_box/')
if not sys.warnoptions:
    warnings.simplefilter("ignore")


In [2]:
grids = joblib.load(os.path.join(box_dir, "grids_huge.jbl")) # Load TextGrids containing story annotations
trfiles = joblib.load(os.path.join(box_dir, "trfiles_huge.jbl")) # Load TRFiles containing TR information

In [3]:
wordseqs = make_word_ds(grids, trfiles)
for story in wordseqs.keys():
    wordseqs[story].data = [i.strip() for i in wordseqs[story].data]
print("Loaded text data")

Loaded text data


# Explore Participant Responses

In [4]:
response_path = os.path.join(box_dir, 'responses', 'full_responses', 'UTS03_responses.jbl')
resp_dict = joblib.load(response_path)
to_pop = [x for x in resp_dict.keys() if 'canplanetearthfeedtenbillionpeoplepart' in x]
for story in to_pop:
    del resp_dict[story]
train_stories = list(resp_dict.keys())
train_stories = [t for t in train_stories if t != "wheretheressmoke"]
test_stories = ["wheretheressmoke"]
print("Loaded participant responses")


Loaded participant responses


# Load LLM

In [5]:
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
tokenizer = model.tokenizer



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gemma-2-2b into HookedTransformer


In [6]:
def override_to_local_attn(model, window_size=512):
    for b in model.blocks:  # Possibly a cleaner way by correctly using 'use_local_attn'
        if isinstance(b, TransformerBlock):
            n_ctx = b.attn.cfg.n_ctx
            attn_mask = torch.zeros((n_ctx, n_ctx)).bool()
            for i in range(n_ctx):
                start_idx = max(0, i-window_size)
                attn_mask[i, start_idx:i+1] = True
            b.attn.mask = attn_mask.to(device)

override_to_local_attn(model)

In [7]:
def find_word_boundaries(text_data, tokenizer):
    full_story = " ".join(text_data).strip()
    tokenized_story = tokenizer(full_story)['input_ids']

    word_boundaries = []  # In the tokenized story
    curr_word_idx = 0
    curr_word = text_data[curr_word_idx]
    curr_token_set = []

    if curr_word == '':
        curr_word_idx += 1
        curr_word = text_data[curr_word_idx]
        word_boundaries.append(1)

    for token_idx, token in enumerate(tokenized_story):
        curr_token_set.append(token)
        detokenized_chunk = tokenizer.decode(curr_token_set)
        if curr_word in detokenized_chunk:
            word_boundaries.append(token_idx)
            curr_word_idx += 1
            if curr_word_idx == len(text_data):
                break
            curr_word = text_data[curr_word_idx]
            curr_token_set = []

            if curr_word == '':  # Edge case
                word_boundaries.append(token_idx)
                curr_word_idx += 1
                if curr_word_idx == len(text_data):
                    break
                curr_word = text_data[curr_word_idx]

    return tokenized_story, word_boundaries


# Load features of interest

In [8]:
model_layer = 12
brain_region = 'broca'

In [9]:
with open(f'pickles/selected_features/direct_regression_L{model_layer}_{brain_region}.pkl', 'rb') as f:
    results = pickle.load(f)
top_indices = results['top_indices']
descriptions = results['descriptions']
mean_loading = results['mean_loading']


In [10]:
def get_top_activating_examples():
    release = "gemma-scope-2b-pt-res-canonical"
    sae_id = f"layer_{model_layer}/width_16k/canonical"
    sae = SAE.from_pretrained(release, sae_id)[0].to(device)

    results = {idx: [] for idx in top_indices}

    for train_story in train_stories:
        ws = wordseqs[train_story]
        text_data = ws.data
        tokenized_story, word_boundaries = find_word_boundaries(text_data, tokenizer)
        hook_key = f'blocks.{model_layer}.hook_resid_post'
        with torch.no_grad():
            _, cache = model.run_with_cache(
                torch.tensor(tokenized_story).to(device),
                prepend_bos=True,
                names_filter=lambda name: name==hook_key,
            )
        llm_response = cache[hook_key][0, word_boundaries, :]
        with torch.no_grad():
            feature_acts = sae.encode(llm_response).cpu().numpy()
        n_samples, n_features = feature_acts.shape
        for idx in top_indices:
            for sample in range(n_samples):
                val = feature_acts[sample, idx]
                results[idx].append((train_story, sample, val))
    return results

In [30]:
results = get_top_activating_examples()

In [31]:
def process_results(results):
    for i in results.keys():
        examples = sorted(results[i], key=lambda x: abs(x[2]), reverse=True)
        processed_examples = []
        for example in examples:
            story, sample, val = example
            add_to_processed = True
            for prev_examples in processed_examples:
                if prev_examples[0] == story:
                    if abs(prev_examples[1] - sample) < 5:
                        add_to_processed = False
                        break
            if add_to_processed:
                processed_examples.append(example)
        results[i] = processed_examples
    return results

In [32]:
results = process_results(results)

In [53]:
i = top_indices[4]
print(i)
descriptions[np.argmax(top_indices==i)]

1033


'scientific terms and measurements related to biological and chemical studies'

In [54]:
printed_examples = 0
skipped_examples = 0
vals = []
for example_idx, example in enumerate(results[i]):
    train_story, sample, val = example
    if sample < 10:
        skipped_examples += 1
        continue
    ws = wordseqs[train_story]
    text_data = ws.data
    tokenized_story, word_boundaries = find_word_boundaries(text_data, tokenizer)
    start_idx = max(0, sample-10)
    print(tokenizer.decode(tokenized_story[start_idx:sample+1]))
    printed_examples += 1
    vals.append(val)
    if printed_examples > 15:
        break

 on it called miss connections and that's what funny
 moment and we're like you know hitting each other
 forgave me and told me she loved me and felt
 a second if i'm gonna bob my head in
 but i do remember seeing it swing from side to side
 kind of a milkshake ice cream sandwich which he's
re coming in about a ten twelve degree angle he'
 she's googling me and i don't know
 almost died i almost died and i am so glad we
 is now setting so i'm thinking right we got
 by four across the prairie just like you know bouncing across
 homosexuality thing and i was completely thrown and thought that maybe
m comforted to to discover that london exa is exactly
 the guy goes no no no so my mother and my
m saying you can't have them and i won
 the greatest website ever so i started doing this bit all


In [36]:
print(skipped_examples)

93


In [37]:
print(vals)

[14.985634, 14.227369, 13.567108, 12.950641, 11.625395, 11.555357, 10.855405, 9.583924, 9.558023, 9.501002, 9.378521, 9.339999, 9.28685, 8.784471, 8.752684, 8.576482]
