In [32]:
import os
import re
import sys
import json
from glob import glob
from collections import Counter, defaultdict
import torch
import pickle
import numpy as np
from tqdm import tqdm
from transformers import BertModel,BertTokenizer

In [112]:
in_dir = os.path.join('data','preprocessed')
sample_file = os.path.join('data','preprocessed','sample_target_index.dict')
outdir = os.path.join('data','preprocessed')
out_file = 'cofea_sampled_vectors'
batch_size = 200
layers = '10,11'

In [3]:
# collect index of tokens in the documents
files = sorted(glob(os.path.join(in_dir, '*_tokenized.jsonlist')))
docs = []
for file in files:
    with open(file) as f:
        docs.append(f.readlines())

In [4]:
# get sample index
with open(sample_file,'rb') as f:
    sample_index = pickle.load(f)

In [35]:
#layers
layers = [int(layer) for layer in layers.split(',')]

# load the model
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# move the model to the GPU
device = 'cuda'
if torch.cuda.is_available():
    model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
def get_context(token_ids, target_position, sequence_length=128):
    """
    Given a text containing a target word, return the sentence snippet which surrounds the target word
    (and the target word's position in the snippet).
    :param token_ids: list of token ids (for an entire line of text)
    :param target_position: index of the target word's position in `tokens`
    :param sequence_length: desired length for output sequence (e.g. 128, 256, 512)
    :return: (context_ids, new_target_position)
                context_ids: list of token ids for the output sequence
                new_target_position: index of the target word's position in `context_ids`
    """
    # -2 as [CLS] and [SEP] tokens will be added later; /2 as it's a one-sided window
    window_size = int((sequence_length - 2) / 2)
    context_start = max([0, target_position - window_size])
    padding_offset = max([0, window_size - target_position])
    padding_offset += max([0, target_position + window_size - len(token_ids)])

    context_ids = token_ids[context_start:target_position + window_size]
    context_ids += padding_offset * [0]

    new_target_position = target_position - context_start

    return context_ids, new_target_position

In [98]:
# Just put the output into some lists for now
token_index_list = []
to_encode = []
vectors_by_layer = defaultdict(list)

for target_word in sample_index:
    for line_index,example_info in enumerate(sample_index[target_word]):
        file_id,doc_id,index = example_info
        doc = json.loads(docs[file_id][doc_id])
        tokens = doc['tokens']
        # now we get the context
        context_ids, pos_in_context = get_context(tokens, index)
        input_ids = tokenizer.convert_tokens_to_ids(
            ['[CLS]']+context_ids+['[SEP]'])
        to_encode.append(input_ids)
        token_index_list.append(pos_in_context+1) #increment because we add CLS
        # we reached the batch limit and wil, now extract BERT embeddings
        if len(to_encode) == batch_size or (line_index == (len(sample_index[target_word])) and len(to_encode)>1):
            input_tensors = torch.tensor(to_encode)
            input_tensors = input_tensors.to(device)
            n_rows, n_tokens = input_tensors.shape
            with torch.no_grad():
                try:
                    # run usages through language model
                    outputs = model(input_tensors,output_hidden_states=True)
                    hidden_states = outputs[2]
                    vectors_np = {layer: hidden_states[layer].detach().cpu().numpy() for layer in layers}
                    # save the first token of the target word in each example
                    for row in np.arange(len(token_index_list)):
                        pos = token_index_list[row]
                        for layer in layers:
                            vectors_by_layer[layer].append(
                                np.array(vectors_np[layer][row, pos, :].copy(), dtype=np.float32))                        
                        
                except Exception as e:
                        print(len(to_encode))
                        raise e
            
            
            to_encode = []
            token_index_list = []
    #save the usages here
    # Need to decide how to save the data
    #output_path = 
    #if output_path:
    #    with open(output_path, 'wb') as f:
    #        pickle.dump(average_usage, file=f)
    vectors_by_layer = defaultdict(list)
    break
            

full
199


In [110]:
print(vectors_by_layer['recess'][10][0])

[ 1.17574163e-01 -5.93570411e-01 -5.96952498e-01 -4.01325554e-01
 -5.14881492e-01 -1.96219668e-01  4.63072807e-01  3.31362009e-01
 -2.04144821e-01  9.26141441e-01 -3.42350274e-01 -3.42669524e-02
  1.81638934e-02 -8.49014297e-02 -6.55438542e-01  1.06888974e+00
  6.12637252e-02 -3.88078034e-01 -4.99110609e-01 -5.03895164e-01
  8.45007718e-01  8.13613117e-01 -1.02336578e-01  8.88154209e-01
 -1.03573360e-01  5.21511793e-01  8.25831201e-03  8.77125487e-02
 -2.97151297e-01  2.49129936e-01 -4.18124974e-01 -4.37515020e-01
  2.33467281e-01 -2.49734581e-01  3.24852824e-01 -8.24593544e-01
  4.43717629e-01 -5.67906439e-01  7.41582036e-01  2.02471352e+00
  4.49637115e-01  3.72828901e-01 -7.71571219e-01 -9.35842395e-01
  3.62390250e-01  2.15125680e-01  6.19129241e-01  1.73287109e-01
 -9.30756181e-02  4.28732261e-02 -9.08623457e-01  1.74443528e-01
  2.58580893e-01  5.10623455e-01  3.88850778e-01 -3.26730728e-01
  8.41153979e-01  1.52976289e-01  8.11290324e-01 -1.40084207e+00
 -8.99712503e-01 -2.27234

In [111]:
input_tensors = torch.tensor([to_encode[0]])
input_tensors = input_tensors.to(device)
with torch.no_grad():
    outputs = model(input_tensors,output_hidden_states=True)
    hidden_states = outputs[2]
    vectors_np = {layer: hidden_states[layer].detach().cpu().numpy() for layer in layers}
    print(np.array(vectors_np[10][0,token_index_list[0],:], dtype=np.float32 ))

[ 1.17573828e-01 -5.93570471e-01 -5.96952438e-01 -4.01324987e-01
 -5.14881074e-01 -1.96220517e-01  4.63072360e-01  3.31364870e-01
 -2.04144925e-01  9.26140189e-01 -3.42350125e-01 -3.42663154e-02
  1.81637872e-02 -8.49002004e-02 -6.55437946e-01  1.06889021e+00
  6.12647161e-02 -3.88078839e-01 -4.99110878e-01 -5.03893733e-01
  8.45007718e-01  8.13612878e-01 -1.02336742e-01  8.88155460e-01
 -1.03573672e-01  5.21512270e-01  8.25842749e-03  8.77127275e-02
 -2.97150999e-01  2.49131233e-01 -4.18125004e-01 -4.37514871e-01
  2.33469144e-01 -2.49734119e-01  3.24852645e-01 -8.24595809e-01
  4.43717599e-01 -5.67906022e-01  7.41582274e-01  2.02471375e+00
  4.49638158e-01  3.72829497e-01 -7.71571338e-01 -9.35843170e-01
  3.62390965e-01  2.15124652e-01  6.19129360e-01  1.73286498e-01
 -9.30756852e-02  4.28721830e-02 -9.08621788e-01  1.74443260e-01
  2.58580267e-01  5.10624230e-01  3.88849497e-01 -3.26731712e-01
  8.41154873e-01  1.52976125e-01  8.11291456e-01 -1.40084171e+00
 -8.99712026e-01 -2.27234

In [114]:
layers = '10,11'
layers = [int(layer) for layer in layers.split(',')]