# Experiment with BERT first

## Get the contextual embedding of a target word in a sentence

In [9]:
from transformers import BertTokenizer, BertModel
import torch

In [21]:
# Load the tokenizer and the encoder from the SAME checkpoint, so that the
# token ids the tokenizer produces index the model's own vocabulary.
MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# The model is only used for inference in this notebook, so put it in
# evaluation mode explicitly (disables dropout). `.eval()` returns the module.
model = BertModel.from_pretrained(MODEL_NAME).eval()

In [115]:
# The word whose embedding we want, and a sentence that contains it.
# The sentence deliberately contains the target word twice, so the following
# cells exercise the case of multiple occurrences in one sentence.
target_word = "puppeteer"
test_sent = "The puppeteer show and the puppeteer came late."

In [116]:
# Encode the sentence with the special tokens ([CLS] ... [SEP]) added, then
# add a leading batch dimension so the model sees a batch of size 1.
input_ids = torch.tensor(tokenizer.encode(test_sent, add_special_tokens=True)).unsqueeze(0)
print(tokenizer.convert_ids_to_tokens(input_ids[0]))
print(input_ids)

# Pure inference: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    outputs = model(input_ids)
# The last hidden-state is the first element of the model output.
# Presumably shaped (1, seq_len, hidden_size) -- TODO confirm if relied upon.
last_hidden_states = outputs[0]

['[CLS]', 'the', 'puppet', '##eer', 'show', 'and', 'the', 'puppet', '##eer', 'came', 'late', '.', '[SEP]']
tensor([[  101,  1996, 13997, 11510,  2265,  1998,  1996, 13997, 11510,  2234,
          2397,  1012,   102]])


In [122]:
# Locate the target word inside the encoded sentence. A word may be split into
# several WordPiece sub-tokens (e.g. "puppeteer" -> "puppet", "##eer"), so we
# search for the whole sub-token sequence and record where each match STARTS.
# USE REPRESENTATION OF FIRST SUB-TOKEN AS EMBEDDING (Following Devlin et al 2019):
# the start index of a match is exactly the position of that first sub-token.
target_ids = tokenizer.encode(target_word, add_special_tokens=False)
if not target_ids:
    # An empty sub-token sequence would trivially "match" at every position.
    raise ValueError(f"target word {target_word!r} produced no sub-tokens")
inputs_list = input_ids[0].tolist()
target_len = len(target_ids)
# matches[i] is True when the sub-token sequence starts at position i.
matches = [inputs_list[i:i + target_len] == target_ids
           for i in range(len(inputs_list) - target_len + 1)]
matches_id = [i for i, is_match in enumerate(matches) if is_match]
print(matches_id)

[2, 7]


In [123]:
# The same word can occur more than once in a sentence, so show the
# last-layer vector at the first sub-token of every occurrence found above.
for start in matches_id:
    print(last_hidden_states[0][start])

tensor([ 4.3486e-01, -1.4866e-02,  7.8103e-01,  4.4384e-01,  5.2585e-01,
         1.1567e-01, -1.4227e-01, -9.4531e-02, -4.3484e-01, -2.5195e-01,
         1.2881e-01, -9.9183e-01,  6.5990e-01,  9.5503e-01, -1.0791e+00,
        -2.1023e-01, -6.6997e-01,  2.1093e-01, -4.4637e-01,  7.1933e-01,
        -2.9007e-01, -6.6554e-01, -4.7834e-01,  2.0128e-01,  1.1098e-01,
        -1.0652e-01,  1.6658e-01,  3.4497e-01,  7.8254e-02,  2.0630e-01,
        -4.2258e-02,  3.2207e-01,  3.7641e-01, -3.2900e-02, -4.0481e-01,
         7.6611e-02, -7.8441e-02,  2.0920e-01, -5.8172e-01,  2.2798e-01,
        -2.4296e-01, -2.2732e-01,  5.5855e-01, -5.3747e-02, -1.4880e-01,
        -2.9582e-01,  8.3967e-01, -4.6689e-01,  5.5735e-01,  1.2176e-01,
        -6.8455e-01,  4.1413e-01, -5.4875e-01, -2.6541e-01,  8.5599e-02,
         4.1755e-01,  6.4972e-01, -6.8411e-01,  3.3127e-01, -6.9191e-02,
        -2.8705e-01,  9.1383e-01,  1.4944e-01, -5.7320e-01, -5.7987e-01,
        -1.2310e-01,  1.2783e-01,  1.4330e+00, -1.0