In [None]:
import torch

import nltk
from nltk.tokenize import word_tokenize, sent_tokenize

from transformers import BertTokenizer, BertModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from operator import itemgetter
import numpy as np

# Load the pretrained tokenizer for the "bert-base-uncased" checkpoint.
# `tok` is the module-level tokenizer that bert_text_preparation calls.
tok = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
import time
# Wall-clock start of the run; the elapsed-time cell near the end subtracts this value.
start_time = time.time()

In [None]:
# Load the pretrained BERT encoder. output_hidden_states=True is requested so the
# model's outputs include the per-layer hidden states that get_bert_embeddings
# reads from outputs[2].
model = BertModel.from_pretrained('bert-base-uncased',
           output_hidden_states = True,)

In [None]:
# Read the whole source text, closing the file handle as soon as it is consumed,
# then split the text into sentences for per-sentence embedding.
with open("../texts/deephaven.txt") as text_file:
    fp = text_file.read()
sentences = sent_tokenize(fp)

In [None]:
def bert_text_preparation(text, max_length=512):
    """Tokenize ``text`` and build the tensors fed to the BERT model.

    The text is tokenized with the module-level tokenizer ``tok`` and wrapped
    in ``[CLS]`` / ``[SEP]`` markers. Truncation is applied to the text tokens
    only, so the closing ``[SEP]`` marker is always present even for inputs
    longer than ``max_length``.

    Args:
        text: the sentence to encode.
        max_length: maximum total number of tokens, including the two markers.

    Returns:
        A tuple ``(tokenized_text, tokens_tensor, segments_tensor)`` where
        ``tokenized_text`` is the list of token strings (markers included),
        ``tokens_tensor`` is a ``(1, len(tokenized_text))`` tensor of token ids
        and ``segments_tensor`` is an all-ones tensor of the same shape.
    """
    # Reserve two slots for the markers so truncation can never drop [SEP].
    body_budget = max(max_length - 2, 0)
    body_tokens = tok.tokenize(text)[:body_budget]
    tokenized_text = ["[CLS]"] + body_tokens + ["[SEP]"]

    indexed_tokens = tok.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1] * len(indexed_tokens)

    # Both tensors carry a leading batch dimension of size one.
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensor = torch.tensor([segments_ids])
    return tokenized_text, tokens_tensor, segments_tensor

def get_bert_embeddings(tokens_tensor, segments_tensor, model):
    """Return one contextual vector per token: the sum of the last four hidden layers.

    Args:
        tokens_tensor: ``(1, seq_len)`` tensor of token ids (a single sequence).
        segments_tensor: ``(1, seq_len)`` tensor passed to the model as its
            attention mask. In Hugging Face ``BertModel.forward`` the second
            positional parameter is ``attention_mask`` (not ``token_type_ids``),
            so the keyword is spelled out here to make that binding explicit.
        model: callable whose output, indexed with ``[2]``, yields the tuple of
            per-layer hidden states, each shaped ``(1, seq_len, hidden_size)``.

    Returns:
        A list of ``seq_len`` numpy arrays, each of shape ``(hidden_size,)``.
    """
    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask=segments_tensor)

    hidden_states = outputs[2]

    # (num_layers, 1, seq_len, hidden) -> (num_layers, seq_len, hidden)
    layer_stack = torch.squeeze(torch.stack(hidden_states, dim=0), dim=1)

    # Sum the last four layers for every token in a single vectorised step.
    summed = layer_stack[-4:].sum(dim=0)
    return list(summed.detach().numpy())

In [None]:
from collections import OrderedDict
# Parallel lists: context_tokens[i] is the token string whose contextual
# embedding is context_embeddings[i]. Both are rebuilt from scratch on every run
# of this cell.
context_embeddings = []
context_tokens = []

for sentence in sentences:
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(sentence)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    # The first and last entries of tokenized_text are skipped (the [CLS]/[SEP]
    # wrappers added by bert_text_preparation). Embeddings are aligned with
    # tokenized_text by position, so each token is looked up by its own index.
    for position in range(1, len(tokenized_text) - 1):
        context_tokens.append(tokenized_text[position])
        context_embeddings.append(list_token_embeddings[position])

In [None]:
from sklearn.neighbors import NearestNeighbors
def get_neighbors(word, unique=False, k=50):
    """Print the nearest contextual token embeddings to the first occurrence of ``word``.

    The query vector is the embedding of the first entry in ``context_tokens``
    equal to ``word``; neighbours are searched among all of
    ``context_embeddings``. Each result is printed as ``distance token`` with
    the distance rounded to three decimals.

    Args:
        word: token string to look up in ``context_tokens``.
        unique: when true, a token string is printed only the first time it
            appears among the neighbours.
        k: number of neighbours requested; capped at the number of stored
            embeddings, because ``kneighbors`` rejects a larger value.

    Returns:
        None. Results (or a "not in vocab" message) are printed.
    """
    # Check the vocabulary first so an unknown word does not pay for a fit.
    if word not in context_tokens:
        print("error: {0} not in vocab".format(word))
        return

    n_neighbors = min(k, len(context_embeddings))
    nn = NearestNeighbors(n_neighbors=n_neighbors,
                          algorithm='ball_tree').fit(context_embeddings)

    w_idx = context_tokens.index(word)
    distances, indices = nn.kneighbors([context_embeddings[w_idx]])

    seen = set()
    for distance, neighbor_idx in zip(distances[0], indices[0]):
        neighbor = context_tokens[neighbor_idx]
        if unique and neighbor in seen:
            continue
        print(np.round(distance, 3), neighbor)
        seen.add(neighbor)

In [None]:
# Print the nearest neighbours of the first "schooner" token, keeping repeated tokens.
get_neighbors("schooner",unique=False)

In [None]:
# Report the seconds elapsed since start_time was recorded, rounded to 4 decimals.
print("elapsed time: %s (seconds)" % np.round((time.time() - start_time),4))

In [None]:
# Position of the first "schooner" entry in context_tokens (the same entry
# get_neighbors uses as its query); list.index raises ValueError if it is absent.
context_tokens.index("schooner")