In [20]:
import torch
from transformers import BertTokenizer, BertModel
from scipy.spatial.distance import cosine

# Example from TDS

In [21]:
import pandas as pd
import numpy as np

In [35]:
model = BertModel.from_pretrained('bert-base-uncased',
           output_hidden_states = True,)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [24]:
def bert_text_preparation(text, tokenizer):
    """Preparing the input for BERT
    
    Takes a string argument and performs
    pre-processing like adding special tokens,
    tokenization, tokens to ids, and tokens to
    segment ids. All tokens are mapped to seg-
    ment id = 1.
    
    Args:
        text (str): Text to be converted
        tokenizer (obj): Tokenizer object
            to convert text into BERT-re-
            adable tokens and ids
        
    Returns:
        list: List of BERT-readable tokens
        obj: Torch tensor with token ids
        obj: Torch tensor segment ids
    
    
    """
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    # Convert inputs to PyTorch tensors
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors

In [25]:
def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    """Get embeddings from an embedding model
    
    Args:
        tokens_tensor (obj): Torch tensor size [n_tokens]
            with token ids for each token in text
        segments_tensors (obj): Torch tensor size [n_tokens]
            with segment ids for each token in text
        model (obj): Embedding model to generate embeddings
            from token and segment ids
    
    Returns:
        list: List of list of floats of size
            [n_tokens, n_embedding_dimensions]
            containing embeddings for each token
    
    """
    
    # Gradient calculation id disabled
    # Model is in inference mode
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        # Removing the first hidden state
        # The first state is the input state
        hidden_states = outputs[2][1:]

    # Getting embeddings from the final BERT layer
    token_embeddings = hidden_states[-1]
    # Collapsing the tensor into 1-dimension
    token_embeddings = torch.squeeze(token_embeddings, dim=0)
    # Converting torchtensors to lists
    list_token_embeddings = [token_embed.tolist() for token_embed in token_embeddings]

    return list_token_embeddings

In [49]:
def get_embedding_distances(texts, keyword):
    target_word_embeddings = []

    for text in texts:
        tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
        list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)

        # Find the position of keyword in list of tokens
        word_index = tokenized_text.index(keyword)
        # Get the embedding for bank
        word_embedding = list_token_embeddings[word_index]

        target_word_embeddings.append(word_embedding)
        
    list_of_distances = []
    
    for text1, embed1 in zip(texts, target_word_embeddings):
        for text2, embed2 in zip(texts, target_word_embeddings):
            cos_dist = 1 - cosine(embed1, embed2)
            list_of_distances.append([text1, text2, cos_dist])
            
    return list_of_distances
    
    

In [57]:
texts = ["bank",
         "The river bank was flooded.",
         "The bank vault was robust.",
         "He had to bank on her for support.",
         "The bank was out of money.",
         "The bank teller was a man."]

In [58]:
texts2 = ["It is a term used to describe this concept.",
         "I am currently in my second term of university."]

In [59]:
get_embedding_distances(texts, 'bank')

[['bank', 'bank', 1.0],
 ['bank', 'The river bank was flooded.', 0.33806328331314606],
 ['bank', 'The bank vault was robust.', 0.4940982832447629],
 ['bank', 'He had to bank on her for support.', 0.2561400022657283],
 ['bank', 'The bank was out of money.', 0.4699417027149666],
 ['bank', 'The bank teller was a man.', 0.4660202688464421],
 ['The river bank was flooded.', 'bank', 0.33806328331314606],
 ['The river bank was flooded.', 'The river bank was flooded.', 1.0],
 ['The river bank was flooded.',
  'The bank vault was robust.',
  0.5233254855624423],
 ['The river bank was flooded.',
  'He had to bank on her for support.',
  0.3315835872442887],
 ['The river bank was flooded.',
  'The bank was out of money.',
  0.5121607211837422],
 ['The river bank was flooded.',
  'The bank teller was a man.',
  0.5192740469056842],
 ['The bank vault was robust.', 'bank', 0.4940982832447629],
 ['The bank vault was robust.',
  'The river bank was flooded.',
  0.5233254855624423],
 ['The bank vault w