In [55]:
from transformers import AutoTokenizer, AutoModel
import torch

Load pre-trained SciBERT model and tokenizer

In [56]:
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Define input text

In [57]:
text = "Ibuprofen and aspirin are commonly used pain relievers."
print(text)

Ibuprofen and aspirin are commonly used pain relievers.


Tokenize text and convert to tensor

In [58]:
input_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
print(input_ids)

tensor([[  102,  6749,   251, 16639,   117,   137, 15858,   220,  4531,   501,
          2675, 18464,  1559,   205,   103]])


Run text through SciBERT model

In [59]:
with torch.no_grad():
    outputs = model(input_ids)

Extract contextualized word embeddings

In [60]:
last_hidden_states = outputs[0]
contextualized_embeddings = last_hidden_states[0]

Print contextualized embeddings for each token

In [61]:
for i, token in enumerate(tokenizer.tokenize(text)):
    print("{}: {}".format(token, contextualized_embeddings[i].tolist()))

ib: [-1.3792442083358765, -0.6788857579231262, -1.38141667842865, 0.3280603885650635, 0.08109791576862335, -0.7253592014312744, 0.48939192295074463, -0.38629329204559326, 0.48549965023994446, 0.35323214530944824, -0.14118917286396027, -0.3727581799030304, -0.1135111153125763, -0.11998011916875839, -0.1281462162733078, -0.4569450616836548, -2.861097574234009, -0.3100034296512604, 1.0678069591522217, -1.4393365383148193, -0.8098611235618591, 0.20574893057346344, -0.513880729675293, -0.08407843858003616, 0.7659168839454651, 0.76183021068573, 0.07167866826057434, -0.44636425375938416, -0.46413296461105347, 0.7201201319694519, 0.644421398639679, -0.6340055465698242, -0.37077596783638, -0.6436165571212769, 0.12185201048851013, 0.25162070989608765, -0.14171606302261353, -0.9345234036445618, -0.13391265273094177, -0.36855247616767883, 0.2599562108516693, 0.7158434391021729, 0.9598493576049805, -0.2235429286956787, -0.041905906051397324, -0.4176367223262787, 0.543834924697876, 0.010009899735450