In [2]:
import numpy as np
import sys
import joblib
import matplotlib.pyplot as plt
from transformer_lens.components import TransformerBlock
import ridge_utils.npp
from ridge_utils.util import make_delayed
from ridge_utils.dsutils import make_word_ds
from ridge_utils.DataSequence import DataSequence
import warnings
import pickle
from configs import engram_dir
import os
from datasets import load_dataset
from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

from transformer_lens.utils import tokenize_and_concatenate
from huggingface_hub import login

from huggingface_hub import hf_hub_download

from sae_vis.data_config_classes import SaeVisConfig, SaeVisLayoutConfig
from sae_vis.data_storing_fns import SaeVisData
from configs import huggingface_token

# Authenticate against the Hugging Face Hub; the token is read from the local `configs` module.
login(token=huggingface_token)

# Loading the model forks the process after the tokenizer has already been used, which makes
# `huggingface/tokenizers` print a "process just got forked" warning and disable
# parallelism on its own. Fix the setting up front instead; setdefault keeps any value the
# user exported before starting the kernel.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Folder under `engram_dir` that holds the story and response files loaded below.
box_dir = os.path.join(engram_dir, 'huth_box/')

# Silence Python warnings unless the interpreter was started with explicit -W options.
if not sys.warnoptions:
    warnings.simplefilter("ignore")


In [3]:
# Load, in this order, the TextGrids (story annotations) and the TRFiles (TR information)
# from their joblib dumps under box_dir.
grids, trfiles = (
    joblib.load(os.path.join(box_dir, fname))
    for fname in ("grids_huge.jbl", "trfiles_huge.jbl")
)

In [4]:
# Build one word sequence per story, then trim surrounding whitespace from every word
# so later string matching against decoded tokens is not thrown off by padding.
wordseqs = make_word_ds(grids, trfiles)
for story in wordseqs:
    word_seq = wordseqs[story]
    word_seq.data = [word.strip() for word in word_seq.data]
print("Loaded text data")

Loaded text data


# Explore Participant Responses

In [5]:
# Responses for participant UTS03, keyed by story name.
response_path = os.path.join(box_dir, 'responses', 'full_responses', 'UTS03_responses.jbl')
resp_dict = joblib.load(response_path)

# Drop every part of the "can planet earth feed ten billion people" story.
to_pop = [x for x in resp_dict.keys() if 'canplanetearthfeedtenbillionpeoplepart' in x]
for story in to_pop:
    resp_dict.pop(story)

# Hold out a single story for testing; every remaining story is used for training.
test_stories = ["wheretheressmoke"]
train_stories = [t for t in resp_dict if t not in test_stories]
print("Loaded participant responses")


Loaded participant responses


# Load LLM

In [6]:
from sae_lens import HookedSAETransformer

# Load gemma-2-2b onto `device` through sae_lens' HookedSAETransformer.
model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)
# An earlier version of this cell loaded the same checkpoint with transformer_lens'
# plain HookedTransformer.from_pretrained instead.



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model gemma-2-2b into HookedTransformer


In [7]:
# Use the tokenizer attached to the model loaded above.
tokenizer = model.tokenizer

In [8]:
def override_to_local_attn(model, window_size=512):
    """Overwrite each transformer block's attention mask with a windowed causal mask.

    After the call, query position ``i`` may attend to key positions
    ``max(0, i - window_size)`` through ``i`` inclusive, i.e. to itself plus at
    most ``window_size`` preceding positions.

    Parameters
    ----------
    model
        Object exposing an iterable ``blocks``. Only entries that are
        ``TransformerBlock`` instances are touched; their ``attn.mask`` is
        replaced in place.
    window_size : int, default 512
        Number of preceding positions each query may look back on.

    Returns
    -------
    None

    Notes
    -----
    The mask is placed on the notebook-level ``device``.
    There may be a cleaner route via the model's own ``use_local_attn`` option.
    """
    # The mask depends only on (n_ctx, window_size), so build it once per
    # distinct context length instead of once per block.
    masks_by_n_ctx = {}
    for b in model.blocks:
        if not isinstance(b, TransformerBlock):
            continue
        n_ctx = b.attn.cfg.n_ctx
        if n_ctx not in masks_by_n_ctx:
            # tril keeps key <= query (causal); triu(-window_size) keeps
            # key >= query - window_size (look-back limit).
            causal = torch.ones((n_ctx, n_ctx), dtype=torch.bool).tril()
            masks_by_n_ctx[n_ctx] = causal.triu(-window_size)
        # copy=True gives every block its own tensor, even when `device` is the CPU.
        b.attn.mask = masks_by_n_ctx[n_ctx].to(device, copy=True)

# Apply the windowed attention mask to the loaded model, using the default window_size.
override_to_local_attn(model)

In [9]:
def find_word_boundaries(text_data, tokenizer):
    """Tokenize a story and locate, for every word, the token that completes it.

    The words are joined with single spaces, stripped and tokenized in one go.
    Tokens are then consumed left to right: they are accumulated and decoded
    until the decoded text contains the current word (substring match), at
    which point the current token index is recorded as that word's boundary
    and accumulation restarts for the next word.

    Parameters
    ----------
    text_data : sequence of str
        The story's words, in order. Empty strings are allowed.
    tokenizer
        Callable such that ``tokenizer(text)['input_ids']`` is a list of token
        ids, and that offers ``decode(list_of_ids) -> str``.

    Returns
    -------
    tokenized_story : list
        Token ids of the whole story, exactly as returned by the tokenizer.
    word_boundaries : list of int
        Index into ``tokenized_story`` of the last token of each word.
        An empty word owns no token: it reuses the boundary of the word before
        it, or ``1`` when it appears before the first real word.
        NOTE(review): the ``1`` presumably skips a BOS token at index 0 -- confirm
        for tokenizers that do not prepend one.
        The list is shorter than ``text_data`` if some word is never matched.
    """
    full_story = " ".join(text_data).strip()
    tokenized_story = tokenizer(full_story)['input_ids']

    word_boundaries = []  # In the tokenized story
    n_words = len(text_data)
    curr_word_idx = 0

    # Leading empty words (also covers an empty or all-empty text_data, which
    # previously raised IndexError).
    while curr_word_idx < n_words and text_data[curr_word_idx] == '':
        word_boundaries.append(1)
        curr_word_idx += 1
    if curr_word_idx == n_words:
        return tokenized_story, word_boundaries

    curr_word = text_data[curr_word_idx]
    curr_token_set = []
    for token_idx, token in enumerate(tokenized_story):
        curr_token_set.append(token)
        detokenized_chunk = tokenizer.decode(curr_token_set)
        if curr_word not in detokenized_chunk:
            continue

        word_boundaries.append(token_idx)
        curr_word_idx += 1
        curr_token_set = []

        # Consume the whole run of empty words here. Letting an empty word reach
        # the substring test above would match on any token ('' is in every
        # string) and swallow a token that belongs to the next real word.
        while curr_word_idx < n_words and text_data[curr_word_idx] == '':
            word_boundaries.append(token_idx)
            curr_word_idx += 1

        if curr_word_idx == n_words:
            break
        curr_word = text_data[curr_word_idx]

    return tokenized_story, word_boundaries


# Load features of interest

In [32]:
# Layer index; it selects both the feature pickle and the SAE id used in the cells below.
model_layer = 12
# Region tag that appears in the feature pickle's filename
# (presumably 'ac' stands for auditory cortex -- TODO confirm).
brain_region = 'ac'

In [33]:
# Read the feature-selection results saved for this layer / brain region.
# NOTE: pickle.load executes arbitrary code; only open pickles produced by this project.
with open(f'pickles/selected_features/direct_regression_L{model_layer}_{brain_region}.pkl', 'rb') as f:
    results = pickle.load(f)

# Pull out the three entries used by the rest of the notebook.
top_indices, descriptions, mean_loading = (
    results[key] for key in ('top_indices', 'descriptions', 'mean_loading')
)


In [34]:
# Identify the SAE to load: a release name plus an id that encodes the chosen layer.
release = "gemma-scope-2b-pt-res-canonical"
sae_id = f"layer_{model_layer}/width_16k/canonical"

# Only the first element of what from_pretrained returns (the SAE itself) is kept,
# then moved to the same device as the model.
sae = SAE.from_pretrained(release, sae_id)[0]
sae = sae.to(device)

In [35]:
import random

In [36]:
# Every chunk must have exactly this many tokens so the chunks can be stacked into one tensor.
CHUNK_LEN = 787

all_tokens = []
for train_story in train_stories:
    ws = wordseqs[train_story]
    text_data = ws.data
    tokenized_story, word_boundaries = find_word_boundaries(text_data, tokenizer)

    # A story shorter than one chunk cannot produce a full-length chunk; appending the
    # short slice would only surface later as a shape error when stacking.
    if len(tokenized_story) < CHUNK_LEN:
        print(f"Skipping {train_story}: only {len(tokenized_story)} tokens (< {CHUNK_LEN})")
        continue

    # Break into consecutive chunks of CHUNK_LEN tokens
    for start_idx in range(0, len(tokenized_story), CHUNK_LEN):
        chunk = tokenized_story[start_idx:start_idx + CHUNK_LEN]
        # The final chunk may come up short: replace it with the story's last
        # CHUNK_LEN tokens, which overlaps the previous chunk.
        if len(chunk) < CHUNK_LEN:
            chunk = tokenized_story[-CHUNK_LEN:]
        all_tokens.append(torch.tensor(chunk))

In [37]:
# Stack the equal-length chunks as rows, transpose so each chunk becomes a column,
# and move the result to `device`.
# NOTE(review): this rebinds `all_tokens` from a list to a tensor, so the cell cannot be
# re-run on its own without re-running the chunking cell first.
all_tokens = torch.vstack(all_tokens).T.to(device)

In [38]:
# Sanity check: rows are token positions within a chunk, columns are chunks.
all_tokens.shape

torch.Size([787, 301])

In [39]:
# Decode the first chunk (column 0) back to text to eyeball the tokenization.
tokenizer.decode(all_tokens[:,0])

"<bos>so i'm standing in the phone box and i put twenty pence in and i phone my boyfriend and he answers so i put a pound coin in and just as we're about to speak i suddenly feel the door behind me open a hand appears out of nowhere and i watch it press down the switchhook of the phone i hear the pound coin drop to the bottom and i watch this hand remove the coin earlier that evening i had actually been staying at my boyfriend's flat he lived on main university campus in birmingham city center he was away on work experience and he'd given me the keys to his flat which was great for me because i lived off main campus a another site in handsworth word a short bus ride away but after spending a few nights there and a really long day in the library i suddenly had this urge that i needed to go home i just needed to be in my own space i needed fresh clothes i wanted to sleep in my own bed and i really missed my flatmates but as soon as i thought it a voice in my head said don't go but i want

In [40]:
# Build the sae_vis data for the selected SAE features over all story chunks.
sae_vis_data = SaeVisData.create(
    sae=sae,
    model=model,
    tokens=all_tokens.T,  # transpose back so that each row is one chunk
    cfg=SaeVisConfig(features=top_indices),  # only the features loaded from the selection pickle
    verbose=True,
)

Forward passes to cache data for vis:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/30 [00:00<?, ?it/s]

In [41]:
# Write the feature-centric dashboard to HTML, opening on the first selected feature
# (int() turns the stored index into a plain Python int).
# NOTE(review): the filename is fixed, so runs for other layers or regions write to the same path.
sae_vis_data.save_feature_centric_vis(filename="demo_feature_vis.html", feature=int(top_indices[0]))
