## BERT Model - Post Training Setup

Use BERT and run it for the sentence completion problem, state-of-the-art model. Then remove the last layer and modify it to return embeddings for the input instead of a single output: the word to complete the sentence.

In [2]:
from transformers import BertTokenizer, BertModel
import pandas as pd
import numpy as np
import nltk
import torch

In [3]:
model = BertModel.from_pretrained('bert-base-uncased',
                                  output_hidden_states = True,)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [10]:
texts = ["The river bank was flooded.",
         "The bank vault was robust.",
         "He had to bank on her for support.",
         "The bank was out of money.",
         "The bank teller was a man."]

In [5]:
def bert_text_preparation(text, tokenizer):
    marked_text = "[CLS] " + text + " [SEP]"
    tokenized_text = tokenizer.tokenize(marked_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    segments_ids = [1]*len(indexed_tokens)

    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])

    return tokenized_text, tokens_tensor, segments_tensors

In [6]:
def get_bert_embeddings(tokens_tensor, segments_tensors, model):
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        hidden_states = outputs[2][1:]

    token_embeddings = hidden_states[-1]
    token_embeddings = torch.squeeze(token_embeddings, dim=0)
    list_token_embeddings = [token_embed.tolist() for token_embed in token_embeddings]
    return list_token_embeddings

In [7]:
target_word_embeddings = []

for text in texts:
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    word_index = tokenized_text.index('bank')
    word_embedding = list_token_embeddings[word_index]

    target_word_embeddings.append(word_embedding)

In [12]:
for text in texts:
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    word_index = tokenized_text.index('bank')
    word_embedding = list_token_embeddings[word_index]
    print('Index: ', word_index)
    print(word_embedding[:5])
    print(len(word_embedding))
    print("="*100)

Index:  3
[-0.13987627625465393, -0.4429723620414734, 0.12858182191848755, -0.03672254830598831, 0.14131620526313782]
768
Index:  2
[0.6443396806716919, -0.7919419407844543, 0.02263902686536312, 0.0009034294635057449, 0.6554418206214905]
768
Index:  4
[-0.057423096150159836, -0.6665879487991333, -0.3368636965751648, -0.42104634642601013, 0.9046264886856079]
768
Index:  2
[0.7917051911354065, -0.5113241076469421, -0.030860211700201035, -0.06428904831409454, 1.1648433208465576]
768
Index:  2
[0.7716076374053955, -0.403091162443161, -0.29921817779541016, 0.05266457796096802, 0.361716628074646]
768


In [22]:
word_and_embeddings_dict = {'words': [], 'embeds': []}
for text in texts:
    tokenized_text, tokens_tensor, segments_tensors = bert_text_preparation(text, tokenizer)
    list_token_embeddings = get_bert_embeddings(tokens_tensor, segments_tensors, model)
    for tok in tokenized_text:
        word_index = tokenized_text.index(tok)
        word_embedding = list_token_embeddings[word_index]
        word_and_embeddings_dict['words'].append(tok)
        word_and_embeddings_dict['embeds'].append(word_embedding)

In [25]:
for word, embed in zip(word_and_embeddings_dict['words'], word_and_embeddings_dict['embeds']):
    print(word, embed[:5])

[CLS] [-0.37874600291252136, -0.0703376978635788, -0.36608144640922546, -0.10915069282054901, -0.24549156427383423]
the [-0.06094054505228996, -0.39486756920814514, -0.5337704420089722, 0.17299635708332062, 0.2710832357406616]
river [0.2703665792942047, 0.2709972560405731, -0.3092411458492279, 0.1405070573091507, 0.43297654390335083]
bank [-0.13987627625465393, -0.4429723620414734, 0.12858182191848755, -0.03672254830598831, 0.14131620526313782]
was [-0.3832331895828247, -0.6323017477989197, -0.4131511449813843, 0.024052325636148453, 0.2380470335483551]
flooded [0.32402217388153076, -0.5755363702774048, 0.16925807297229767, -0.15234218537807465, 0.8776127099990845]
. [0.5574973225593567, 0.1583925187587738, -0.570341944694519, 0.4329517185688019, -0.20079463720321655]
[SEP] [0.3168503940105438, 0.41758719086647034, -0.5019322633743286, 0.38555893301963806, 0.3209969103336334]
[CLS] [-0.3363654613494873, -0.3163451850414276, -0.0863780677318573, 0.0032097669318318367, -0.1854727864265442