In [9]:
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch
import re

In [3]:
# Checkpoint shared by the tokenizer and the encoder; change it in one place only.
MODEL_NAME = 'bert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_hidden_states=True makes every forward pass return all layers in `.hidden_states`;
# .eval() switches the model to inference mode (e.g. dropout off) for deterministic embeddings.
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True).eval()

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Two clinical sentences that both mention "diplopia". Punctuation appears to
# have been replaced by spaces (hence the runs of double/triple spaces).
sent1 = (
    'Horizontal diplopia  worse in left gaze and with distance fixation  '
    'suggests limitation of left lateral rectus movement '
)
sent2 = (
    'If the horizontal diplopia had been worse at near e g   while reading  '
    'the right medial rectus would be culpable since a near vision task like '
    'reading requires convergence and active medial rectus contraction '
)

In [5]:
# Encode each sentence separately, returning PyTorch tensors.
tok1, tok2 = (tokenizer(text, return_tensors='pt') for text in (sent1, sent2))

In [7]:
# Word index of "diplopia" in each sentence, counted the way tok*.word_ids() counts words
sent1_idxs = [1]
sent2_idxs = [3]

# For every requested word: an np.where result listing the token positions of its sub-words
tok1_ids = [np.where(np.array(tok1.word_ids()) == word_idx) for word_idx in sent1_idxs]
tok2_ids = [np.where(np.array(tok2.word_ids()) == word_idx) for word_idx in sent2_idxs]

with torch.no_grad():
    out1, out2 = model(**tok1), model(**tok2)

# Final layer only; squeeze() drops size-1 dimensions (presumably the batch axis)
states1, states2 = (out.hidden_states[-1].squeeze() for out in (out1, out2))

# If a word was split into several sub-word tokens, only its FIRST token is kept
embs1 = states1[[match[0][0] for match in tok1_ids]]
embs2 = states2[[match[0][0] for match in tok2_ids]]

In [8]:
# Cosine similarity of the two "diplopia" embeddings, each flattened into a single row
torch.cosine_similarity(embs1.flatten()[None, :], embs2.flatten()[None, :])

tensor([0.9484])

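## Utility functions

Reusable helpers wrapping the steps above: `get_word_embedding` extracts a word's contextual last-layer embedding, and `get_similarity` compares two embeddings with cosine similarity.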

In [38]:
def get_word_embedding(model: "AutoModel", sentence: str, word: str, index: int = None,
                       tokenizer=None, pooling: str = "first") -> torch.Tensor:
    """Return the contextual last-layer embedding of one word in ``sentence``.

    The word is located using the tokenizer's own word segmentation
    (``BatchEncoding.word_ids()`` / ``word_to_chars()``). The index found by the
    search is therefore the same index used to select hidden states, even when the
    sentence contains punctuation (which BERT treats as separate words) or
    leading whitespace.

    Args:
        model: Hugging Face model whose forward pass accepts the tokenizer
            output plus ``output_hidden_states=True``.
        sentence: Input text; runs of spaces are collapsed before tokenizing.
        word: Word to look up, compared case-insensitively against each
            tokenizer word. Ignored when ``index`` is given.
        index: Optional tokenizer word index to use instead of searching.
        tokenizer: Fast tokenizer matching ``model`` (``word_ids`` requires a
            fast tokenizer). Defaults to the notebook-level ``tokenizer`` so
            existing calls keep working.
        pooling: ``"first"`` (default) returns the embedding of the word's first
            sub-word token; ``"mean"`` averages all of its sub-word tokens.

    Returns:
        Tensor of shape ``(1, hidden_size)``.

    Raises:
        ValueError: the word is not found, the index maps to no token, no
            tokenizer is available, or ``pooling`` is unknown.
    """
    if pooling not in ("first", "mean"):
        raise ValueError(f"Error: unknown pooling {pooling!r}; use 'first' or 'mean'.")
    if tokenizer is None:
        # Fall back to the notebook-level tokenizer (the previous, implicit behaviour).
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise ValueError("Error: no tokenizer given and no global `tokenizer` defined.")

    clean_sentence = re.sub(' +', ' ', sentence)  # collapse runs of spaces
    tokens = tokenizer(clean_sentence, return_tensors='pt')
    # Word index for each token position; None for special tokens such as [CLS]/[SEP].
    word_ids = tokens.word_ids()

    if index is None:
        target = word.lower()
        # dict.fromkeys de-duplicates word indices while keeping sentence order.
        for word_idx in dict.fromkeys(w for w in word_ids if w is not None):
            span = tokens.word_to_chars(word_idx)
            if clean_sentence[span.start:span.end].lower() == target:
                index = word_idx
                break
    if index is None:
        raise ValueError("Error: word not found in provided sentence.")

    positions = [pos for pos, word_idx in enumerate(word_ids) if word_idx == index]
    if not positions:
        raise ValueError(f"Error: word index {index} does not map to any token.")

    with torch.no_grad():
        # Ask for hidden states explicitly so the model need not have been
        # loaded with output_hidden_states=True.
        output = model(**tokens, output_hidden_states=True)

    # Last layer, first (only) batch element -> assumes HF layout (batch, seq_len, hidden).
    hidden_states = output.hidden_states[-1][0]

    if pooling == "mean":
        return hidden_states[positions].mean(dim=0, keepdim=True)
    # List index keeps the leading dimension: shape (1, hidden_size).
    return hidden_states[positions[:1]]

In [44]:
# Sanity check of the helper above, reusing the already-loaded model (no reload):
# a single word should give one embedding row of the model's hidden size.
_sanity_emb = get_word_embedding(model, "this is a sentence", "sentence")
assert _sanity_emb.shape == (1, model.config.hidden_size)

In [40]:
def get_similarity(emb1: torch.Tensor, emb2: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between two embeddings, each flattened to a single vector.

    Args:
        emb1: First embedding, any shape.
        emb2: Second embedding; should have the same number of elements as ``emb1``.

    Returns:
        Tensor of shape ``(1,)`` holding the similarity.
    """
    # reshape(1, -1) makes each input one row, so cosine_similarity (dim=1)
    # compares the embeddings as wholes rather than row by row.
    return torch.cosine_similarity(emb1.reshape(1, -1), emb2.reshape(1, -1))

In [43]:
# Contrast "sentence" in one context with "this" in another
emb1, emb2 = (
    get_word_embedding(model, "this is a sentence", "sentence"),
    get_word_embedding(model, "this is yet another sentence", "this"),
)

get_similarity(emb1, emb2)

tensor([0.6752])