In this file, we first load the fMRI data for each voxel we want to predict.
- We then calculate a text description for each voxel, which we hope will correspond to a semantic concept
- Next, we visualize these concepts and select a couple of the best ones
- We visualize the best ones
- We then validate that the best ones are spatially close together

In [71]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from tqdm import tqdm
import pickle as pkl
import os
import cortex # brain viz library
from data_utils import neuro
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from model_utils import suffix
import string

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Load the word lists through the project helper. Later cells index the result
# as a 2D array (rows of words); presumably one row per voxel -- TODO confirm.
word_lists = neuro.fetch_data()

In [49]:
# Keep only the first n_words entries of every row so the prompts stay short.
n_words = 15
word_lists = word_lists[:, :n_words]
# Last expression of the cell: shown as the cell output as a sanity check.
word_lists.shape

(10000, 15)

In [104]:
# Language model used to name each word list; the trailing comment keeps a
# smaller alternative checkpoint handy.
checkpoint = 'EleutherAI/gpt-j-6B' # 'gpt2-medium'
device = 'cuda'
# Per-checkpoint output directory. The checkpoint name contains a '/', so this
# resolves to a nested folder.
save_dir = f'/home/chansingh/mntv1/fmri/logits_{checkpoint}'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Use EOS as the pad token so prompts can be padded to a common length when
# they are tokenized in batches (padding='longest' below).
tokenizer.pad_token = tokenizer.eos_token
# The model is moved to the device once here; no second .to(device) is needed.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, output_hidden_states=False).to(device)
batch_size = 10 # make sure this divides into word_lists.shape[0]

In [105]:
def make_prompt_from_word_list(word_list):
    """Build the category-naming prompt for one list of words.

    The words are joined with ', ' and wrapped between a fixed lead-in sentence
    and a fixed question, so that the model's next token is its one-word guess
    at the shared semantic category.

    Args:
        word_list: iterable of strings to enumerate in the prompt.

    Returns:
        The full prompt string (it deliberately ends without trailing
        punctuation or whitespace, right where the answer should start).
    """
    lead_in = 'The following list of words are all part of the same semantic category: '
    question = '.\nThe semantic category they all belong to, in one word, is'
    enumerated = ', '.join(word_list)
    return lead_in + enumerated + question

# Run one example end to end. Besides being a smoke test, the resulting
# next_token_logits is what the batched cell uses to read off the vocab size.
s = make_prompt_from_word_list(word_lists[0])
# Move the tokenized inputs to the model's device, matching the batched loop.
ex_inputs = tokenizer([s], padding='longest', return_tensors='pt').to(device)
# squeeze() drops the singleton dimensions of the single-prompt result.
next_token_logits = suffix.get_next_token_logits(ex_inputs, model).squeeze().detach().cpu().numpy()

In [122]:
# Run the model over every word list, batch by batch, and keep the next-token
# logits of each prompt in one (n_examples, vocab_size) array.
os.makedirs(save_dir, exist_ok=True)
n_examples = word_lists.shape[0]
# Vocabulary size is read from the single-example logits computed earlier.
vocab_size = next_token_logits.shape[-1]
all_logits = np.zeros((n_examples, vocab_size))

batch_starts = range(0, n_examples, batch_size)
for i in tqdm(batch_starts):
    stop = i + batch_size
    # One prompt per word list in the current batch.
    s = list(map(make_prompt_from_word_list, word_lists[i:stop]))
    # NOTE(review): prompts of different lengths get padded here; confirm that
    # suffix.get_next_token_logits reads the logits at each prompt's last real
    # token rather than at a pad position.
    ex_inputs = tokenizer(s, padding='longest', return_tensors='pt').to(device)
    next_token_logits = suffix.get_next_token_logits(ex_inputs, model).detach().cpu().numpy()
    all_logits[i:stop] = next_token_logits

 19%|█▊        | 186/1000 [01:33<07:21,  1.85it/s]

In [None]:
# Persist the logits next to the other per-checkpoint artifacts. Uses
# os.path.join (no `oj` alias is imported in this notebook) and a context
# manager so the file handle is closed once the dump finishes.
with open(os.path.join(save_dir, 'all_logits.pkl'), 'wb') as f:
    pkl.dump(all_logits, f)

# Decode top tokens

In [118]:
# Inspect the predictions for the first word list only.
next_token_logits = all_logits[0]

# decode top tokens: rank every vocabulary id from highest to lowest logit,
# then decode each id to its text form.
top_k_inds = np.argsort(next_token_logits)[::-1]
top_decoded_tokens = np.array(
    [tokenizer.decode(ind) for ind in top_k_inds])

# remove nonsense
STOPWORDS = suffix.get_stopwords()
PHRASING_STOPWORDS = ['called']

# Loop-invariant lookups, built once instead of once per vocabulary entry.
input_words = {str(word) for word in word_lists[0]}
input_plurals = {word + 's' for word in input_words}


def _is_disallowed(token):
    """Return True when a decoded token should be dropped from the candidates.

    A token is dropped when it is blank, punctuation only, at most two
    characters long (counted on the raw token, including any leading space),
    a stopword, or a restatement of one of the input words (the word itself,
    its singular, or its plural, using a naive trailing-'s' rule).
    """
    stripped = token.strip()
    word = stripped.lower()
    return (
        # general
        token.isspace()
        or all(c in string.punctuation for c in stripped)
        or len(token) <= 2

        # keywords
        or word in STOPWORDS
        or word in PHRASING_STOPWORDS

        # restates one of the inputs
        or word in input_words  # token is an input word
        or word + 's' in input_words  # token is the singular of an input word
        or word in input_plurals  # token is the plural of an input word
    )


disallowed_idxs = np.array(
    [_is_disallowed(token) for token in top_decoded_tokens],
    dtype=bool)
top_k_inds = top_k_inds[~disallowed_idxs]
top_decoded_tokens = top_decoded_tokens[~disallowed_idxs]
print(top_decoded_tokens[:20])

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chansingh/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


[' material' ' metal' ' shape' ' surface' ' simply' ' fabric' ' materials'
 ' packaging' ' box' ' something' ' CON' ' plastic' ' card' ' colour'
 ' container' ' table' ' wood' ' part' ' cloth' ' Material']
