In [None]:
import corpus
import random
import spacy
import torch
import numpy as np

# Upper bound applied as a slice on each document body before processing
# (`x.body[:max_seq_length]`), so it limits the body's length rather than a
# token count -- presumably characters, if `body` is a string (TODO confirm).
max_seq_length = 1024

In [None]:
from transformers import BertModel, BertTokenizer

# spaCy English pipeline; its `doc.ents` are used for entity extraction.
nlp = spacy.load('en_core_web_sm')

# Pretrained uncased BERT tokenizer.
bert_tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    max_length=512,
)

# Matching BERT encoder, configured to also return the hidden states of
# every layer (needed to read the second-to-last layer later on).
bert_model = BertModel.from_pretrained(
    'bert-base-uncased',
    output_hidden_states=True,
)

In [None]:
# Load the test split of the news-clustering corpus, restricted to one language.
lang = 'eng'
dataset = corpus.Corpus()
dataset.load_corpora(
    r"../news-clustering/dataset/dataset.test.json",
    r"../news-clustering/dataset/clustering.test.json",
    {lang},
)
input_data = dataset.documents

# Shuffle in place with a dedicated, seeded generator so that
# "Restart Kernel -> Run All" always selects the same documents, without
# touching the state of the global `random` module.
shuffle_seed = 42
random.Random(shuffle_seed).shuffle(input_data)


In [None]:
def get_entity_spacy(sentence):
    """Return the distinct named-entity strings spaCy finds in `sentence`.

    Args:
        sentence: Text to analyse; it is converted with `str()` before being
            passed to the module-level spaCy pipeline `nlp`.

    Returns:
        list[str]: The text of each entity in `doc.ents`, with duplicates
        removed, in order of first appearance.
    """
    doc = nlp(str(sentence))
    # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
    # result is stable across kernel restarts (a set's iteration order for
    # strings depends on per-process hash randomisation).
    return list(dict.fromkeys(ent.text for ent in doc.ents))

<b>After identifying the entities in the text, replace each entity term with [MASK] tokens and locate the index positions of those tokens. The input is encoded with the [CLS] and [SEP] special tokens. Return the list of masked token indices.</b> 

In [None]:
def get_word_indeces(tokenizer, text, word):
    """Return the token positions that `word` occupies in the encoded `text`.

    Every whole-word occurrence of `word` in `text` is replaced by as many
    '[MASK]' placeholders as `word` has tokens; the masked text is then
    encoded and the positions holding the tokenizer's mask id are returned.

    Args:
        tokenizer: Object providing `tokenize(str)`, `encode(str)` and a
            `mask_token_id` attribute.
        text: The full text to search.
        word: The term to locate (may consist of several words).

    Returns:
        numpy.ndarray: Indices into `tokenizer.encode(masked_text)` at which
        a mask token appears; empty when `word` yields no tokens or does not
        occur in `text`.
    """
    import re

    word_tokens = tokenizer.tokenize(word)
    if not word_tokens:
        # Nothing to locate. Replacing such a `word` with an empty mask string
        # would only delete characters from `text` (e.g. all of its spaces).
        return np.array([], dtype=np.intp)

    masks_str = ' '.join(['[MASK]'] * len(word_tokens))

    # Match `word` only when it is not glued to other word characters, so that
    # e.g. "US" is not masked inside "USA"; a partial replacement would change
    # the number of tokens and shift every later index.
    pattern = r'(?<!\w)' + re.escape(word) + r'(?!\w)'
    # A function replacement keeps `masks_str` literal (no escape processing).
    text_masked = re.sub(pattern, lambda _match: masks_str, text)

    input_ids = tokenizer.encode(text_masked)
    mask_token_indeces = np.where(np.array(input_ids) == tokenizer.mask_token_id)[0]

    return mask_token_indeces

In [None]:
def get_maskembedding(b_model, b_tokenizer, text, ents=''):
    '''
    Uses the provided model and tokenizer to produce a sentence embedding for
    the provided `text`.

    If `ents` is given, the token positions of each entity (as computed by
    `get_word_indeces`) are collected, sorted and printed; they are reported
    only and do not influence the returned embedding.

    Args:
        b_model: Callable model with an `eval()` method whose output holds the
            per-layer hidden states at index 2 (for `BertModel` this requires
            `output_hidden_states=True`).
        b_tokenizer: Tokenizer providing `encode_plus`.
        text: The text to embed.
        ents: Iterable of entity strings; '' (the default) or an empty
            collection means "no entities".

    Returns:
        numpy.ndarray: The mean over all tokens of the second-to-last layer's
        hidden states for the first (only) sequence in the batch.
    '''
    print ('\n text is ', text)

    # Bind the list up front so the prints below also work when no entities
    # are supplied (it used to be defined only inside the `if`).
    word_indeces = []
    if ents:
        for e in ents:
            word_indeces.extend(get_word_indeces(b_tokenizer, text, e))
        word_indeces.sort()
    print ('\n ents are ', ents)
    print ('\n indices are ', word_indeces)

    # Encode the text, adding the (required!) special tokens, and converting to
    # PyTorch tensors.
    encoded_dict = b_tokenizer.encode_plus(
                        text,                      # Sentence to encode.
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        return_tensors = 'pt',     # Return pytorch tensors.
                        max_length=512)

    print ('\n encoded is', encoded_dict)
    input_ids = encoded_dict['input_ids']

    b_model.eval()

    # Single forward pass, without gradient tracking: only the hidden states
    # are needed, not a backward graph.
    with torch.no_grad():
        outputs = b_model(input_ids)

        # With `output_hidden_states = True` set in `from_pretrained`, the
        # third item holds the hidden states from all layers. See:
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        hidden_states = outputs[2]

    # For bert-base-uncased `hidden_states` has shape
    # [13 x 1 x <sentence length> x 768].

    # Select the embeddings from the second to last layer (first batch item).
    # `token_vecs` is a tensor with shape [<sent length> x 768]
    token_vecs = hidden_states[-2][0]

    # Average all token vectors, then convert to a numpy array.
    sentence_embedding = torch.mean(token_vecs, dim=0)
    sentence_embedding = sentence_embedding.detach().numpy()

    print ('sentence_embedding', sentence_embedding)
    return sentence_embedding

In [None]:
# Keep the run small: only the first ten documents are processed.
input_data = input_data[:10]

for i, d in enumerate(input_data):
    x = corpus.Document(d)
    print('\nINSIDE DOCUMENT ', x.id)

    # Entity extraction and the encoder both receive the same truncated body.
    excerpt = x.body[:max_seq_length]
    ents = get_entity_spacy(excerpt)
    get_maskembedding(bert_model, bert_tokenizer, excerpt, ents=ents)