In [1]:
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
import numpy as np
from operator import itemgetter

import torch
from transformers import AutoTokenizer, GPT2Model

In [2]:
# Read the whole source text. The context manager closes the file handle as
# soon as the text is in memory instead of leaving it to the garbage collector.
with open("../texts/deephaven.txt") as text_file:
    fp = text_file.read()

# Split the raw text into sentences; each one is embedded separately later.
sentences = sent_tokenize(fp)

In [3]:
# Tokenizer for GPT-2 XL; the end-of-text token doubles as the padding token.
tok = AutoTokenizer.from_pretrained("gpt2-xl")
tok.pad_token = tok.eos_token

# GPT-2 XL weights, configured so that forward passes also return the
# hidden states (read later via output['hidden_states']).
model = GPT2Model.from_pretrained(
    "gpt2-xl",
    low_cpu_mem_usage=True,
    output_hidden_states=True,
)

In [4]:
def get_sentence_embeddings(sentence):
    """Embed one sentence with the module-level model; one vector per token.

    Parameters
    ----------
    sentence : str
        Text to encode. The reshaping below squeezes the batch axis and then
        treats the result as 3-D, so one sentence per call is expected.

    Returns
    -------
    tokenized_text : list of str
        Decoded, whitespace-stripped text of each token id, in input order.
    input_ids : torch.Tensor
        Token ids produced by the tokenizer, on the model's device.
    vectors : list of numpy.ndarray
        One array per token: the mean of that token's hidden states over the
        last four entries of the model's ``hidden_states`` output.
    """
    device = next(model.parameters()).device
    inp_tok = tok(sentence,
                  padding=True,
                  return_tensors="pt").to(device)
    input_ids = inp_tok["input_ids"]

    # Inference only: without no_grad every call would record an autograd
    # graph across the whole model, which costs memory and is never used.
    with torch.no_grad():
        output = model(input_ids)

    # return tokenized text for indexing
    tokenized_text = [tok.decode(token_id).strip() for token_id in input_ids[0]]

    # extract hidden states: stack the per-layer tensors, drop the batch axis,
    # then put the token axis first so iteration yields one token at a time
    embs = torch.stack(output['hidden_states'], dim=0)
    embs = torch.squeeze(embs, dim=1)
    embs = embs.permute(1, 0, 2)

    # mean embeddings in the last four layers; .cpu() is required before
    # .numpy() whenever the model lives on an accelerator
    vectors = [torch.mean(t[-4:], dim=0).detach().cpu().numpy() for t in embs]

    return tokenized_text, input_ids, vectors

In [5]:
from collections import OrderedDict
# Parallel, index-aligned stores: context_tokens[i] is the token text whose
# contextual embedding is context_embeddings[i].
context_embeddings = []
context_tokens = []

for sentence in sentences:
    tokenized_text, ids, list_token_embeddings = get_sentence_embeddings(sentence)

    # Pair every token with the embedding at its own position. The previous
    # approach counted occurrences inside the trimmed list but looked the
    # n-th occurrence up in the untrimmed list, so a token that also appeared
    # as the (dropped) first token was given the wrong position's embedding.
    #
    # The first and last token of each sentence are still skipped, exactly as
    # before.
    # NOTE(review): this trim looks like it was meant to remove boundary
    # special tokens; if the tokenizer adds none, real tokens are being
    # discarded here -- confirm whether the trim is intended.
    for position in range(1, len(tokenized_text) - 1):
        context_tokens.append(tokenized_text[position])
        context_embeddings.append(list_token_embeddings[position])

In [6]:
from sklearn.neighbors import NearestNeighbors
def get_neighbors(word, unique=False, k=50):
    """Print the nearest stored token embeddings to the first ``word`` token.

    Parameters
    ----------
    word : str
        Token text to look up in ``context_tokens``. Only its first
        occurrence is used as the query point.
    unique : bool, default False
        When True, each distinct token text is printed at most once (its
        closest occurrence), so fewer than ``k`` lines may be printed.
    k : int, default 50
        Number of neighbours to retrieve. Capped at the number of stored
        embeddings, because asking for more makes the search raise.

    Returns
    -------
    None
        Results are printed as ``<distance rounded to 3 places> <token>``;
        an error line is printed when ``word`` is not a stored token.
    """
    # Check the vocabulary first so an unknown word does not pay for
    # building the search index.
    if word not in context_tokens:
        print("error: {0} not in vocab".format(word))
        return

    w_idx = context_tokens.index(word)

    n_neighbors = min(k, len(context_embeddings))
    nn = NearestNeighbors(n_neighbors=n_neighbors,
                          algorithm='ball_tree').fit(context_embeddings)
    distances, indices = nn.kneighbors([context_embeddings[w_idx]])

    seen = set()
    for dist, neighbor_idx in zip(distances[0], indices[0]):
        neighbor = context_tokens[neighbor_idx]
        # With unique=True, repeated token texts are suppressed after their
        # first (closest) appearance.
        if not (unique and neighbor in seen):
            print(np.round(dist, 3), neighbor)
        seen.add(neighbor)

In [None]:
# Report wall-clock seconds since `start_time`. Nothing visible in this
# notebook imports `time` or sets `start_time`, so import locally and guard
# the lookup to keep a fresh "Run All" from stopping here with a NameError.
import time

if "start_time" in globals():
    print("elapsed time: %s (seconds)" % np.round((time.time() - start_time),4))
else:
    print("elapsed time: unavailable (start_time was not set)")

In [12]:
# Show the 25 stored token embeddings closest to the first "village" token.
get_neighbors(word="village", k=25)

0.0 village
200.748 village
221.168 village
228.722 garden
237.215 village
252.319 sunset
253.323 ures
253.734 farm
255.051 age
256.784 Parish
257.324 garden
260.977 oor
261.124 farm
261.276 pasture
261.503 farm
263.217 Parish
264.119 village
264.297 river
264.413 town
266.214 York
266.983 country
267.381 country
267.814 country
267.86 river
268.167 village


In [11]:
# Position of the first "schooner" entry in context_tokens (the same index
# addresses its vector in context_embeddings). list.index raises ValueError
# when the token is absent.
context_tokens.index("schooner")