#https://github.com/BramVanroy/bert-for-inference/blob/master/introduction-to-bert.ipynb 

In [23]:
import torch
from transformers import BertModel, BertTokenizer

In [24]:
# Initialize the tokenizer with a pretrained model
model_name = 'bert-base-german-cased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name, output_hidden_states=True)

In [25]:
# Convert the string "granola bars" to tokenized vocabulary IDs
granola_ids = tokenizer.encode('Ich bin ein Minion')
# Print the IDs
print('granola_ids', granola_ids)
# Convert the IDs to the actual vocabulary item
# Notice how the subword unit (suffix) starts with "##" to indicate 
# that it is part of the previous string
print('granola_tokens', tokenizer.convert_ids_to_tokens(granola_ids))

granola_ids [3, 1671, 4058, 39, 14156, 23, 4]
granola_tokens ['[CLS]', 'Ich', 'bin', 'ein', 'Mini', '##on', '[SEP]']


In [26]:
# Convert the list of IDs to a tensor of IDs 
granola_ids = torch.LongTensor(granola_ids)
# Print the IDs
print('granola_ids', granola_ids)

granola_ids tensor([    3,  1671,  4058,    39, 14156,    23,     4])


In [27]:
# Set the device to GPU (cuda) if available, otherwise stick with CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = model.to(device)
granola_ids = granola_ids.to(device)

model.eval()
print(device)

cpu


In [28]:
print(granola_ids.size())
# unsqueeze IDs to get batch size of 1 as added dimension
granola_ids = granola_ids.unsqueeze(0)
print(granola_ids.size())

print(type(granola_ids))
with torch.no_grad():
    out = model(input_ids=granola_ids)

# the output is a tuple
print(type(out))
# the tuple contains three elements as explained above)
print(len(out))
# we only want the hidden_states
hidden_states = out[2]
print(len(hidden_states))
print(hidden_states.shape)


torch.Size([7])
torch.Size([1, 7])
<class 'torch.Tensor'>
<class 'tuple'>
3
13


AttributeError: 'tuple' object has no attribute 'shape'

In [18]:
print(model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30000, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [19]:
sentence_embedding = torch.mean(hidden_states[-1], dim=1).squeeze()
print(sentence_embedding)
print(sentence_embedding.size())

tensor([-1.7320e-01,  4.0907e-02,  4.5438e-01, -4.2997e-02,  1.8002e-01,
         5.4788e-01, -4.8228e-02, -4.2350e-01, -8.4088e-01,  7.3957e-01,
        -3.6593e-01,  3.5973e-01,  8.3447e-02, -3.2965e-01,  8.1233e-01,
         8.9030e-02,  5.6575e-01, -5.2960e-01,  4.3412e-01,  6.6688e-02,
        -8.7260e-02,  3.8725e-01, -4.9636e-01, -3.0553e-01, -9.4124e-01,
         5.3892e-01,  4.1416e-02, -3.7828e-01, -4.5504e-01,  2.9940e-01,
         7.2828e-04,  2.3506e-02, -3.7628e-02, -3.0512e-01, -8.8771e-03,
        -4.5097e-01, -3.8860e-01, -7.2443e-01,  5.2398e-01,  1.8769e-02,
        -5.7824e-02,  7.2235e-01, -7.1317e-01, -1.7501e-01,  1.5490e-01,
         4.5173e-02, -3.7515e-01, -1.8814e-01, -6.6256e-01, -1.8524e-02,
        -3.7364e-01, -4.2808e-01,  7.0594e-01, -1.9603e-01,  3.3559e-02,
         1.5797e-01, -4.0003e-02,  5.2094e-01, -5.2013e-01,  1.0825e-01,
         3.2536e-01,  7.7148e-01, -1.2824e-01, -1.2871e-01,  2.3662e-01,
        -3.7407e-01, -1.7652e-02, -2.4922e-01, -2.2

In [20]:
# get last four layers
last_four_layers = [hidden_states[i] for i in (-1, -2, -3, -4)]
# cast layers to a tuple and concatenate over the last dimension
cat_hidden_states = torch.cat(tuple(last_four_layers), dim=-1)
print(cat_hidden_states.size())

# take the mean of the concatenated vector over the token dimension
cat_sentence_embedding = torch.mean(cat_hidden_states, dim=1).squeeze()
print(cat_hidden_states.shape)
print(cat_sentence_embedding)
print(cat_sentence_embedding.size())

torch.Size([1, 7, 3072])
torch.Size([1, 7, 3072])
tensor([-0.1732,  0.0409,  0.4544,  ..., -0.5578, -0.2156, -0.0674])
torch.Size([3072])
