In [19]:
import spacy
import numpy as np
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from transformers import RobertaTokenizer, RobertaModel
import torch
import re
from spacy.tokenizer import Tokenizer
from spacy.training import Alignment
device_number = 2
device = torch.device(f"cuda:{device_number}" if torch.cuda.is_available() else "cpu")
print(device)
import os
from interp_funcs import separate_words_commas_periods, group_tokens, is_punctuation


dataset = 'pereira'
basePath = '/home3/ebrahim/what-is-brainscore/'
model_str = 'gpt2-large'
untrained = True

if dataset == 'pereira':
    pereira_path = f"{basePath}{dataset}_data/sentences_ordered.txt"
    with open(pereira_path, "r") as file:
        # Read the contents line by line into a list
        experiment_txt = [line.strip() for line in file]
    data_labels = np.load(f"{basePath}data_processed/{dataset}/data_labels_{dataset}.npy")
    
if dataset == 'fedorenko':
    fed_path = f"{basePath}{dataset}_data/sentences_ordered.txt"
    with open(fed_path, "r") as file:
        # Read the contents line by line into a list
        experiment_txt = [line.strip() for line in file]
    data_labels = np.load(f"{basePath}data_processed/{dataset}/data_labels_{dataset}.npy")
    
if dataset == 'blank':
    blank_data = np.load(f"{basePath}{dataset}_data/story_data_dict.npz")
    experiment_txt = []
    data_labels = []
    for key, val in blank_data.items():
        experiment_txt.extend(val)
        data_labels.extend(np.repeat(key, len(val)))
    

if 'gpt' in model_str:
    
    model = GPT2LMHeadModel.from_pretrained(model_str)
    
    tokenizer = GPT2Tokenizer.from_pretrained(f"{model_str}")
    
    if untrained:
        config = GPT2Config.from_pretrained(model_str)
        model = GPT2LMHeadModel(config)
        model_str += '-untrained'
        
    model.eval()
    model = model.to(device)  
    
    embedding_matrix = model.transformer.wte 
    positional_matrix = model.transformer.wpe
    
    
elif 'roberta' in model_str:
    tokenizer = RobertaTokenizer.from_pretrained(model_str)
    model = RobertaModel.from_pretrained(model_str)
    model.eval()
    model = model.to(device)    
    embedding_matrix = model.get_input_embeddings().weight.data
    positional_matrix = model.embeddings.position_embeddings.weight.data

    
def split_multipunc_tokens(toks):
    
    import string
    
    new_tokens = []
    
    for s in toks:
        
        if all(char in string.punctuation for char in s) and len(s) > 1:
            print("Splitting token: ", s)
            for char in s:
                new_tokens.append(char)
        else:
            new_tokens.append(s)
            
    return new_tokens

def get_word_level_static_reps(previous_text, current_text, embedding_matrix, 
                               positional_matrix, tokenizer, model_str, dataset, 
                               max_context_size=512):
    
    '''
    :param int start_word: where to start obtaining activations from 
    :param str previous_text: story to obtain activations for 
    :param torch tensor embedding_matrix: static embedding matrix
    :param torch tensor positional_matrix: static positional matrix
    :param tokenizer: tokenize sentence
    :param str model_str: model used to generate activations
    :param str dataset: neural dataset 
    '''
    
    curr_tokens = tokenizer.tokenize(current_text)
    curr_tokens = split_multipunc_tokens(curr_tokens)
    num_ct = len(curr_tokens)
    prev_tokens = tokenizer.tokenize(previous_text)
    tokens = prev_tokens + curr_tokens
    
    tokens = tokens[-max_context_size:]
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    
    # append start and end tokens for roberta
    if model_str == 'roberta-large':
        token_ids.insert(0, 0)
        token_ids.append(2)
        
    tensor_input = torch.tensor([token_ids])
    tensor_input = tensor_input.to(device)    

    with torch.no_grad():
        
        if 'gpt' in model_str:
            static_embed = embedding_matrix(tensor_input)
            static_pos = positional_matrix.weight[np.arange(len(tensor_input[0])), :].unsqueeze(0)   
            
        elif 'roberta' in model_str:
            static_embed  = embedding_matrix[tensor_input, :]
            static_pos = positional_matrix[np.arange(len(tensor_input[0])), :].unsqueeze(0)
            
        static_embed_pos = torch.squeeze(static_embed + static_pos) # ctx_size x embed_size 
        static_pos = torch.squeeze(static_pos)
        static_embed = torch.squeeze(static_embed)
        
        outputs = model(tensor_input, output_hidden_states=True, output_attentions=True)
        outputs = outputs.hidden_states
        # number of layers x context size x embedding size
        outputs = torch.stack(outputs).squeeze()
        
    # remove <s> and </s> tokens because we only want to sum across 
    # words/punctuation marks for bert-style models.
    if 'roberta' in model_str:
        static_embed_pos = static_embed_pos[1:-1]
        static_pos = static_pos[1:-1]
        static_embed = static_embed[1:-1]
        outputs = outputs[:, 1:-1, :]
        
    # only take tokens corresponding to the current text
    static_embed_pos = static_embed_pos[-num_ct:]
    static_pos = static_pos[-num_ct:]
    static_embed = static_embed[-num_ct:]
    outputs = outputs[:, -num_ct:]
    
    tokens_curr_cleaned = [t.replace("Ġ", '') for t in curr_tokens] 
    
    words = separate_words_commas_periods(current_text)
    
    align = Alignment.from_strings(words, tokens_curr_cleaned)
    tokens_to_words_alignment = align.y2x.data
    
    # list of lists of length len words, 
    # each element contains the token indices that map to a word
    tokens_to_word_list = group_tokens(tokens_to_words_alignment)
    
    assert len(tokens_to_word_list) == len(words), print("Alignment failed")

    activity_word_level_embed = []  
    activity_word_level_pos = []
    activity_word_level_pos_embed = []
    activity_word_level = []
    
    for idx, w in enumerate(tokens_to_word_list):
        
        is_punc = is_punctuation(words[idx])
                
        if len(w) > 1:
            # take the mean of tokens within a word if it has multiple tokens 
            word_activity_embed = torch.squeeze(torch.mean(static_embed[w], axis=0))
            word_activity_pos = torch.squeeze(torch.mean(static_pos[w], axis=0))
            word_activity_pos_embed = torch.squeeze(torch.mean(static_embed_pos[w], axis=0))
            word_activity = torch.squeeze(torch.mean(outputs[:, w], axis=1))
        else:
            word_activity_embed = torch.squeeze(static_embed[w])
            word_activity_pos = torch.squeeze(static_pos[w])
            word_activity_pos_embed = torch.squeeze(static_embed_pos[w])
            word_activity = torch.squeeze(outputs[:, w])
            
        # don't add punctuation to static embeddings
        if is_punc:
            activity_word_level.append(word_activity.cpu().detach().numpy())
        else:
            activity_word_level.append(word_activity.cpu().detach().numpy())
            activity_word_level_pos_embed.append(word_activity_pos_embed.cpu().detach().numpy())
            activity_word_level_pos.append(word_activity_pos.cpu().detach().numpy())
            activity_word_level_embed.append(word_activity_embed.cpu().detach().numpy())
            
    activity_word_level_pos_embed = np.array(activity_word_level_pos_embed)
    activity_word_level_pos = np.array(activity_word_level_pos)
    activity_word_level_embed = np.array(activity_word_level_embed)
    activity_word_level = np.array(activity_word_level)
    
    if dataset == 'pereira' or dataset == 'blank':
        
        activity_sent = activity_word_level[-1]
        activity_sent_sp = np.sum(activity_word_level, axis=0)
        static_activity_pos_embed = np.sum(activity_word_level_pos_embed, axis=0)
        static_activity_pos = np.sum(activity_word_level_pos, axis=0)
        static_activity_embed = np.sum(activity_word_level_embed, axis=0)
        
    elif dataset == 'fedorenko':
        
        activity_sent = activity_word_level
        activity_sent_sp = None
        static_activity_pos_embed = activity_word_level_pos_embed
        static_activity_pos = activity_word_level_pos
        static_activity_embed = activity_word_level_embed
    
    return static_activity_pos_embed, static_activity_pos, static_activity_embed, activity_sent, activity_sent_sp


cuda:2


In [20]:
current_text_arr = []
static_embed_activity = []
static_pos_embed_activity = []
static_pos_activity = []
contextual_activity = []
contextual_activity_sp = []
previous_text =  ' ' 
current_passage = data_labels[0]
total_words = 0 

for txt, dl in zip(experiment_txt, data_labels):
    
    
    # remove right spaces
    txt = txt.rstrip()
    
    if dl != current_passage:
        current_text = txt
        current_passage = dl
    else:
        current_text = f' {txt}'

    if dataset == 'fedorenko':
        current_text = current_text.replace('.', '')
        
    current_text_arr.append(current_text)

    static_pos_embed_rep, static_pos_rep, static_embed_rep, contextual_rep, contextual_rep_sp = get_word_level_static_reps(previous_text, 
                current_text, embedding_matrix, positional_matrix, tokenizer, model_str=model_str, dataset=dataset)

    previous_text += current_text
                
    static_pos_embed_activity.append(static_pos_embed_rep)
    static_pos_activity.append(static_pos_rep)
    static_embed_activity.append(static_embed_rep)
    contextual_activity.append(contextual_rep)
    contextual_activity_sp.append(contextual_rep_sp)

    
if dataset == 'pereira' or dataset == 'blank':
    
    contextual_activity_stacked = np.stack(contextual_activity)
    contextual_activity_stacked_sp = np.stack(contextual_activity_sp)
    static_pos_embed_activity_stacked = np.stack(static_pos_embed_activity)
    static_pos_activity_stacked = np.stack(static_pos_activity)
    static_embed_activity_stacked = np.stack(static_embed_activity)
    
    contextual_dict = {}
    contextual_dict_sp = {}
    for ln in range(contextual_activity_stacked.shape[1]):
        contextual_dict[f'layer_{ln}'] = contextual_activity_stacked[:, ln]
        contextual_dict_sp[f'layer_{ln}'] = contextual_activity_stacked_sp[:, ln]
    
elif dataset == 'fedorenko':
    
    static_embed_activity_stacked = np.vstack(static_embed_activity)
    static_pos_activity_stacked = np.vstack(static_pos_activity)
    static_pos_embed_activity_stacked = np.vstack(static_pos_embed_activity)
    contextual_activity_stacked = np.vstack(contextual_activity)
    
    contextual_dict = {}
    for ln in range(contextual_activity_stacked.shape[1]):
        contextual_dict[f'layer_{ln}'] = contextual_activity_stacked[:, ln]
    
    contextual_dict_sp = None
    
np.savez(f'{basePath}data_processed/{dataset}/X_{model_str}', **contextual_dict)
np.savez(f'{basePath}data_processed/{dataset}/X_{model_str}-static', **{'layer1': static_pos_embed_activity_stacked})
np.savez(f'{basePath}data_processed/{dataset}/X_{model_str}-static-pos', **{'layer1': static_pos_activity_stacked})
np.savez(f'{basePath}data_processed/{dataset}/X_{model_str}-static-embed', **{'layer1': static_embed_activity_stacked})
if contextual_dict_sp is not None:
    np.savez(f'{basePath}data_processed/{dataset}/X_{model_str}-sp', **contextual_dict_sp)

243-astronaut-1
243-astronaut-2
243-beekeeping-0
243-beekeeping-1
243-beekeeping-2
243-blindness-0
243-blindness-1
243-blindness-2
243-bone_fracture-0
243-bone_fracture-1
243-bone_fracture-2
243-castle-0
243-castle-1
243-castle-2
243-computer_graphics-0
243-computer_graphics-1
243-computer_graphics-2
243-dreams-0
243-dreams-1
243-dreams-2
243-gambling-0
243-gambling-1
243-gambling-2
243-hurricane-0
243-hurricane-1
243-hurricane-2
243-ice_cream-0
243-ice_cream-1
243-ice_cream-2
243-infection-0
243-infection-1
243-infection-2
243-law_school-0
243-law_school-1
243-law_school-2
243-lawn_mower-0
243-lawn_mower-1
243-lawn_mower-2
243-opera-0
243-opera-1
243-opera-2
243-owl-0
243-owl-1
243-owl-2
243-painter-0
243-painter-1
243-painter-2
243-pharmacist-0
243-pharmacist-1
243-pharmacist-2
243-polar_bear-0
243-polar_bear-1
243-polar_bear-2
243-pyramid-0
243-pyramid-1
243-pyramid-2
243-rock_climbing-0
243-rock_climbing-1
243-rock_climbing-2
243-skiing-0
243-skiing-1
243-skiing-2
243-stress-0
243-

In [21]:
current_text_arr

[' Astronauts train a long time for their spacewalks.',
 ' Much of their training is conducted underwater.',
 ' They may spend 8 to 10 hours in the pool for every hour they will spend floating in space.',
 ' Astronauts practice to be able to perform construction and repair work on the outside of the space station.',
 'The commanders of shuttle flights are always pilots, and many have backgrounds as military test pilots.',
 ' Other astronauts are trained as doctors, engineers, and scientists, who can run experiments in space.',
 ' Early crews were all young men, but astronauts now are much more diverse.',
 'The team of astronauts floated out together to the exterior of the space shuttle.',
 ' They carried tools needed to repair the broken part on the huge telescope.',
 ' One astronaut loosened the bolts on the pipe, while the other fitted the replacement part into place.',
 'Beekeeping encourages the conservation of local habitats.',
 " It is in every beekeeper's interest to conserve lo