In [100]:
import numpy as np
import re
import os
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
from collections import defaultdict
import pickle

In [81]:
# We need to access COHA and get

coha_dir = '/Users/gabriellachronis/Box Sync/src/lsa-predication/coha/data'

output_path = './data/test_output'

In [82]:
# Target words: we want to collect tokens of each of these words from COHA

target_words = ['net', 'virtual', 'disk', 'card', 'optical', 'virus',
           'signal', 'mirror', 'energy', 'compact', 'leaf',
           'brick', 'federal', 'sphere', 'coach', 'spine']

decades = [decade for decade in np.arange(1910, 2009, 10)]

buffer_size=1024
sequence_length=128

In [83]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


In [84]:
tokenizer.tokenize("this is a sentence")

['this', 'is', 'a', 'sentence']

In [85]:
tokenizer.tokenize("this is an overlong sentence with contentiousness because of its many word pieces")

['this',
 'is',
 'an',
 'over',
 '##long',
 'sentence',
 'with',
 'contentious',
 '##ness',
 'because',
 'of',
 'its',
 'many',
 'word',
 'pieces']

In [86]:
# make sure our targets aren't made of more than one word piece
tokenizer.tokenize(' '.join(target_words))

['net',
 'virtual',
 'disk',
 'card',
 'optical',
 'virus',
 'signal',
 'mirror',
 'energy',
 'compact',
 'leaf',
 'brick',
 'federal',
 'sphere',
 'coach',
 'spine']

In [87]:
# when we encode the target words they'll get the cls and sep tokens added
tokenizer.encode(' '.join(target_words))

[101,
 5658,
 7484,
 9785,
 4003,
 9380,
 7865,
 4742,
 5259,
 2943,
 9233,
 7053,
 5318,
 2976,
 10336,
 2873,
 8560,
 102]

In [88]:
# so we have to take it away
tokenizer.encode(' '.join(target_words))[1:-1]

[5658,
 7484,
 9785,
 4003,
 9380,
 7865,
 4742,
 5259,
 2943,
 9233,
 7053,
 5318,
 2976,
 10336,
 2873,
 8560]

In [95]:
def get_context(token_ids, target_position, sequence_length=128):
    """
    Given a text containing a target word, return the sentence snippet which surrounds the target word
    (and the target word's position in the snippet).

    :param token_ids: list of token ids (for an entire line of text)
    :param target_position: index of the target word's position in `tokens`
    :param sequence_length: desired length for output sequence (e.g. 128, 256, 512)
    :return: (context_ids, new_target_position)
                context_ids: list of token ids for the output sequence
                new_target_position: index of the target word's position in `context_ids`
    """
    # -2 as [CLS] and [SEP] tokens will be added later; /2 as it's a one-sided window
    window_size = int((sequence_length - 2) / 2)
    context_start = max([0, target_position - window_size])
    padding_offset = max([0, window_size - target_position])
    padding_offset += max([0, target_position + window_size - len(token_ids)])

    context_ids = token_ids[context_start:target_position + window_size]
    #print(token_ids[target_position])
    #print(context_ids)
    #print(tokenizer.convert_ids_to_tokens(context_ids))
    #print('next')
    
    context_ids += padding_offset * [0]

    new_target_position = target_position - context_start

    return context_ids, new_target_position

In [110]:
# build word-index vocabulary for target words
i2w = {}
for t, t_id in zip(target_words, tokenizer.encode(' '.join(target_words))[1:-1]): # use [1:-1] to not include cls and sep
    i2w[t_id] = t
    
print(i2w)

# buffers for batch processing
batch_input_ids = []
batch_tokens = []
batch_pos = []
batch_snippets = []
batch_decades = []

# here is where we'll store our final list we are collecting
usages = defaultdict(list)  # w -> (vector, sentence, word_position, decade)


{5658: 'net', 7484: 'virtual', 9785: 'disk', 4003: 'card', 9380: 'optical', 7865: 'virus', 4742: 'signal', 5259: 'mirror', 2943: 'energy', 9233: 'compact', 7053: 'leaf', 5318: 'brick', 2976: 'federal', 10336: 'sphere', 2873: 'coach', 8560: 'spine'}


we need these lists to be the same length before we zip them! We accidentally had our zipper on wrong like when you button a shirt and theres a leftover buttton

In [97]:
len(tokenizer.encode(' '.join(target_words))[1:-1])

16

In [98]:
len(target_words)

16

In [None]:
for T, decade in enumerate(decades):
    # one time interval at a time
    print('Decade {}...'.format(decade))


    ### gabriella changes
    ### my coha is organized differently. 
    ### the decades have random numbers for the alphabet index places , so i have to use regex
    ### to ignore that. 
    print(coha_dir)
    print(decade)
    my_regex = r'text_' + re.escape(str(decade)) + 's.*'

    #print("running through decade ", decade)

    # iterate through directories
    for decade_dir in os.listdir(coha_dir):

        if re.match(my_regex, decade_dir):
            # get all the text files for that decade
            # iterate through text files for this decade
            this_decade_files = os.listdir(os.path.join(coha_dir, decade_dir))
            for filename in tqdm(this_decade_files):
                #print(filename)
                with open(os.path.join(coha_dir, decade_dir, filename), 'r') as f:
                    lines = f.readlines()
                    #print("gets here")

                    # get the usages from this file
                    for L, line in enumerate(lines):
                        #print("gets to line: ", L)
                        
                        #print(len(tokens))
                        for token in tokens:

                            # tokenize line and convert to token ids
                            tokens = tokenizer.encode(line)

                            for pos, token in enumerate(tokens):
                                #print(token)
                                # store usage info of target words only
                                if token in i2w:
                                    context_ids, pos_in_context = get_context(tokens, pos, sequence_length)

                                    input_ids = [101] + context_ids + [102]


                                    # convert later to save storage space
                                    snippet = tokenizer.convert_ids_to_tokens(context_ids)
                                    #print(i2w[token])
                                    #print(' '.join(snippet))

                                    # add usage info to buffers
                                    batch_input_ids.append(input_ids)
                                    batch_tokens.append(i2w[token])
                                    batch_pos.append(pos_in_context)
                                    batch_snippets.append(snippet)
                                    batch_decades.append(decade)

                                # if the buffers are full...             or if we're at the end of the dataset
                                if (len(batch_input_ids) >= buffer_size) or (L == len(lines) - 1 and T == len(decades) - 1):

#                                     with torch.no_grad():
#                                         # collect list of input ids into a single batch tensor
#                                         input_ids_tensor = torch.tensor(batch_input_ids)
#                                         if torch.cuda.is_available():
#                                             input_ids_tensor = input_ids_tensor.to('cuda')

#                                         # run usages through language model
#                                         outputs = model(input_ids_tensor,  output_hidden_states=True )
#                                         print(len(outputs.hidden_states)) # items in the tuple = 1 + num layers
#                                         if torch.cuda.is_available():
#                                             hidden_states = [l.detach().cpu().clone().numpy() for l in outputs[2]]
#                                         else:
#                                             print("fjekl")
#                                             hidden_states = [l.clone().numpy() for l in outputs.hidden_states]

#                                         # get usage vectors from hidden states
#                                         hidden_states = np.stack(hidden_states)  # (13, B, |s|, 768)
#                                         print('Expected hidden states size: (13, B, |s|, 768). Got {}'.format(hidden_states.shape))
#                                         # usage_vectors = np.sum(hidden_states, 0)  # (B, |s|, 768)
#                                         # usage_vectors = hidden_states.view(hidden_states.shape[1],
#                                         #                                    hidden_states.shape[2],
#                                         #                                    -1)
#                                         usage_vectors = np.sum(hidden_states[1:, :, :, :], axis=0)
#                                         # usage_vectors = hidden_states.reshape((hidden_states.shape[1], hidden_states.shape[2], -1))
#                                         print("makes usage vectors")
#                                         print(usage_vectors.shape)

                                    if output_path and os.path.exists(output_path):
                                        with open(output_path, 'rb') as f:
                                            usages = pickle.load(f)

                                    # store usage tuples in a dictionary: lemma -> (vector, snippet, position, decade)
                                    for b in np.arange(len(batch_input_ids)):
                                        #usage_vector = usage_vectors[b, batch_pos[b]+1, :] # get the right position
                                        usages[batch_tokens[b]].append(
                                            (batch_snippets[b], batch_pos[b], batch_decades[b]))

                                    # finally, empty the batch buffers
                                    batch_input_ids, batch_tokens, batch_pos, batch_snippets, batch_decades = [], [], [], [], []

                                    # and store data incrementally
                                    if output_path:
                                        with open(output_path, 'wb') as f:
                                            pickle.dump(usages, file=f)


Decade 1910...
/Users/gabriellachronis/Box Sync/src/lsa-predication/coha/data
1910


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3355/3355 [13:22:27<00:00, 14.35s/it]


Decade 1920...
/Users/gabriellachronis/Box Sync/src/lsa-predication/coha/data
1920


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 11557/11557 [3:01:22<00:00,  1.06it/s]


Decade 1930...
/Users/gabriellachronis/Box Sync/src/lsa-predication/coha/data
1930


  7%|██████████▊                                                                                                                                                          | 677/10352 [04:13<1:20:34,  2.00it/s]