In [10]:
import torch
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
model = AutoModel.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Some weights of the model checkpoint at google/bert_uncased_L-2_H-128_A-2 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
aliases = ['united states', 'united states of america', 'u.s.', 'u.s.a.', 'u.s', 'u.s.a', 'us', 'usa', 'america', 'american', 'americans', 'u.s']
inputs = tokenizer(aliases, return_tensors="pt", padding=True, truncation=True).to(device)
model.to(device)
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

In [8]:
print(inputs['input_ids'].shape)

torch.Size([12, 8])


In [6]:
print(outputs.keys())

odict_keys(['last_hidden_state', 'pooler_output'])


In [7]:
print(outputs.pooler_output.shape) 

torch.Size([12, 128])


In [10]:
avg_embedding = torch.mean(outputs.pooler_output, dim=0)
print(avg_embedding.shape)

torch.Size([128])


In [1]:
import os
import joblib

In [33]:
import os

data_dir = '../data/wikidata5m'
embeddings_file = 'wikidata5m_entity_embeddings_dim-128.pkl'
embeddings_path = os.path.join(data_dir, embeddings_file)

with open(embeddings_path, 'rb') as f:
    embeddings = joblib.load(f)

print(len(embeddings))

1000


In [34]:
import random

test_file = 'test_entity_embeddings.pkl'
test_path = os.path.join(data_dir, test_file)

test_embeddings = {}

for i in range(1000000):
    random_id = f'Q{random.randint(10000, 99999)}'
    random_embedding = torch.randn(128).numpy()
    test_embeddings[random_id] = random_embedding
    
with open(test_path, 'wb') as f:
    joblib.dump(test_embeddings, f)
    
with open(test_path, 'rb') as f:
    test_embeddings = joblib.load(f)
    
print(len(test_embeddings))

89999


In [28]:
embeddings.keys()

dict_keys(['Q80129', 'Q62717', 'Q89477', 'Q33873', 'Q33283', 'Q12394', 'Q83107', 'Q16496', 'Q56465', 'Q16842', 'Q13360', 'Q21133', 'Q43655', 'Q21908', 'Q63173', 'Q72893', 'Q25096', 'Q13182', 'Q33351', 'Q20846', 'Q71230', 'Q21698', 'Q83054', 'Q66889', 'Q68237', 'Q78536', 'Q38968', 'Q13034', 'Q86933', 'Q44324', 'Q43854', 'Q90975', 'Q78853', 'Q67804', 'Q26657', 'Q30770', 'Q43651', 'Q52066', 'Q91095', 'Q56973', 'Q21980', 'Q51099', 'Q63531', 'Q96604', 'Q19031', 'Q62803', 'Q42345', 'Q32355', 'Q32040', 'Q48553', 'Q68251', 'Q21368', 'Q87847', 'Q35668', 'Q71464', 'Q73089', 'Q89883', 'Q99034', 'Q26686', 'Q57152', 'Q24350', 'Q86976', 'Q77266', 'Q14910', 'Q71418', 'Q34166', 'Q82037', 'Q96269', 'Q91890', 'Q70164', 'Q71324', 'Q55527', 'Q26072', 'Q67763', 'Q48190', 'Q35970', 'Q51429', 'Q52600', 'Q98990', 'Q98612', 'Q44446', 'Q91904', 'Q94404', 'Q61414', 'Q56670', 'Q46896', 'Q43256', 'Q26904', 'Q58043', 'Q22098', 'Q42368', 'Q48539', 'Q18502', 'Q39334', 'Q47821', 'Q38676', 'Q22439', 'Q97844', 'Q72914',

In [30]:
embeddings['Q80129']

array([ 0.20112823, -1.450661  ,  1.1963005 , -0.6425147 , -1.6208289 ,
        1.5463616 , -1.4418917 , -0.4097485 , -0.8300499 , -1.948085  ,
        0.2312974 ,  1.0275522 , -0.22805353, -0.7385629 , -0.8833031 ,
        0.16548386,  0.2086623 , -0.4788725 ,  0.29602957, -0.2173062 ,
        0.23286653,  0.06512939,  0.5301847 ,  0.07341131, -1.9372576 ,
       -0.11398292, -0.577861  , -0.8262021 , -0.8494098 ,  1.4643928 ,
       -0.43943974,  0.04750703,  0.16258834,  2.1225817 ,  1.0598441 ,
        0.15834662, -0.76255405, -1.3511376 ,  0.17397967, -0.37576386,
       -0.74411047, -1.870567  ,  0.9663696 ,  0.804181  , -0.6139993 ,
        0.8406251 ,  0.6854169 , -0.396833  , -2.3775034 , -1.0109088 ,
        0.65014076,  0.07670992,  2.0412107 ,  1.0906005 , -0.46683833,
       -0.43985477, -0.820184  , -0.32026583,  0.8586788 ,  2.0500412 ,
        0.19995806, -0.5528527 , -0.91998243,  1.3814806 ,  0.01663729,
        1.2966276 , -1.3215276 , -0.76077044, -1.6607566 ,  1.57