In [44]:
import json
from sentence_transformers import SentenceTransformer
import pandas as pd

import json
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer

from sklearn.neighbors import NearestNeighbors


In [14]:
datafile = "../../data/ent_extraction/bi.json"

In [15]:
import os
os.getcwd()

'/Users/yiyichen/Documents/experiments/Creole-NLU-NLG-Suite/wikipedia/ZS_BERT/model'

In [67]:
# Read the annotated sentences. The encoding is passed explicitly so the result
# does not depend on the platform's default locale encoding
# (assumes the file is UTF-8, the standard JSON interchange encoding).
with open(datafile, encoding="utf-8") as f:
    bi = json.load(f)

In [69]:
bi[0]

{'tokens': ['Kakao',
  'tri',
  'kam',
  'long',
  'Andes',
  'bigfala',
  'hil',
  ',',
  'Amazon',
  'reva',
  '.'],
 'edgeSet': [{'left': [2], 'right': [4]},
  {'left': [2], 'right': [8]},
  {'left': [4], 'right': [8]}]}

In [17]:
data = bi

# Load data and property embeddings

In [19]:
prop_list_path='../resources/property_list.html'
sentence_embedder='bert-base-nli-mean-tokens'

prop_list = pd.read_html(prop_list_path, flavor="html5lib")[0]


In [20]:
prop_list.dropna(subset="description", inplace=True)

In [21]:
len(prop_list)

7174

In [22]:
encoder = SentenceTransformer(sentence_embedder)

In [23]:
# Two-way lookup between the property IDs (the "ID" column of prop_list) and
# their position in that column.
id2property = dict(enumerate(prop_list["ID"].tolist()))
property2id = {prop: idx for idx, prop in id2property.items()}

In [24]:
sentence_embeddings = encoder.encode(prop_list.description.to_list())

In [26]:
# Map each property ID to the float32 sentence embedding of its description.
pid2vec = {
    pid: embedding.astype('float32')
    for pid, embedding in zip(prop_list.ID, sentence_embeddings)
}

In [27]:
pid2vec["P6"].shape

(768,)

In [29]:
def mark_wiki_entity(edge, sent_len):
    """Build 0/1 position masks for the two entities of an edge.

    Args:
        edge: mapping with 'left' and 'right' entries, each holding the
            positions to mark for the first and second entity respectively.
        sent_len: length of the mask vectors to create.

    Returns:
        Tuple ``(marked_e1, marked_e2)`` of 1-D ``torch.long`` tensors of
        length ``sent_len`` holding 1 at the listed positions and 0 elsewhere.
    """
    left_positions = edge['left']
    right_positions = edge['right']

    def _mask(positions):
        # Start from an all-zero vector and flag the requested positions.
        flags = np.zeros(sent_len, dtype=int)
        flags[positions] += 1
        return torch.tensor(flags, dtype=torch.long)

    return _mask(left_positions), _mask(right_positions)


In [30]:
class WikiDataset(Dataset):
    """Dataset that turns entity-annotated sentences into BERT model inputs.

    Every record in ``data`` is read through its ``"tokens"`` list and its
    ``"edgeSet"`` list; only the first edge of each record is used.
    """

    def __init__(self, data, tokenizer="bert-base-multilingual-cased"):
        """Store the records and load the (cased) BERT tokenizer named by ``tokenizer``."""
        self.data = data
        self.len = len(data)
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer, do_lower_case=False)

    def __len__(self):
        """Number of records, as counted when the dataset was created."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(token ids, segment ids, entity-1 mask, entity-2 mask)`` for record ``idx``.

        All four items are 1-D ``torch.long`` tensors as long as the id
        sequence, which includes the ``[CLS]`` and ``[SEP]`` markers.
        """
        record = self.data[idx]
        subword_tokens = self.tokenizer.tokenize(" ".join(record["tokens"]))
        input_ids = self.tokenizer.convert_tokens_to_ids(["[CLS]", *subword_tokens, "[SEP]"])
        seq_len = len(input_ids)

        # NOTE(review): the edge positions are used directly as indices into the
        # id sequence above (which starts with [CLS]). If they were produced as
        # indices into record["tokens"], they may not line up with the tokenizer
        # output — confirm this matches how the model was trained.
        first_edge = record["edgeSet"][0]
        e1_mask, e2_mask = mark_wiki_entity(first_edge, seq_len)

        # Single-sentence input: every position belongs to segment 0.
        segment_ids = torch.zeros(seq_len, dtype=torch.long)
        return (torch.tensor(input_ids), segment_ids, e1_mask, e2_mask)

In [31]:
def create_mini_batch(samples):
    """Collate dataset items into zero-padded batch tensors.

    Args:
        samples: sequence of items whose first four elements are tensors
            (presumably 1-D, as produced by ``WikiDataset``): token ids,
            segment ids, entity-1 mask and entity-2 mask.

    Returns:
        ``(tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors)``.
        The first four are the per-sample tensors padded with zeros to the
        longest item and stacked batch-first; ``masks_tensors`` is a
        ``torch.long`` tensor with 1 wherever the padded token id is non-zero.
    """
    def _pad_field(position):
        # Gather one field from every sample and pad it to a common length.
        return pad_sequence([sample[position] for sample in samples], batch_first=True)

    tokens_tensors, segments_tensors, marked_e1, marked_e2 = (
        _pad_field(position) for position in range(4)
    )

    # Padding uses id 0, so every non-zero id is treated as a real token.
    masks_tensors = (tokens_tensors != 0).long()

    return tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors


In [32]:
dataset =  WikiDataset(data)

In [33]:
dataloader = DataLoader(dataset, batch_size=16, collate_fn=create_mini_batch)

In [34]:
dataloader

<torch.utils.data.dataloader.DataLoader at 0x28970b490>

# loading model

In [5]:
from transformers import BertModel, BertConfig, BertPreTrainedModel, BertTokenizer
from model import ZSBert
import torch

In [3]:
model_path = "best_f1_0.7081677743338072_wiki_epoch_4_m_5_alpha_0.4_gamma_7.5"

In [35]:
# torch.load() restores the whole model object, so the `model` module that provides ZSBert must be importable (e.g. located in the same folder).

In [7]:
# Load the saved model onto the CPU via map_location.
# NOTE(review): torch.load relies on pickle and can execute arbitrary code while
# unpickling — only load checkpoints that come from a trusted source.
model = torch.load(model_path,map_location=torch.device('cpu'))

In [36]:
model.eval()

ZSBert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [45]:
# Stack the per-property vectors (in pid2vec insertion order) into one 2-D array.
attr = np.array(list(pid2vec.values()))

In [52]:
pid2vec.keys()

dict_keys(['P6', 'P10', 'P14', 'P15', 'P16', 'P17', 'P18', 'P19', 'P20', 'P21', 'P22', 'P25', 'P26', 'P27', 'P30', 'P31', 'P35', 'P36', 'P37', 'P38', 'P39', 'P40', 'P41', 'P47', 'P50', 'P51', 'P53', 'P54', 'P57', 'P58', 'P59', 'P61', 'P65', 'P66', 'P69', 'P78', 'P81', 'P84', 'P85', 'P86', 'P87', 'P88', 'P91', 'P92', 'P94', 'P97', 'P98', 'P101', 'P102', 'P103', 'P105', 'P106', 'P108', 'P109', 'P110', 'P111', 'P112', 'P113', 'P114', 'P115', 'P117', 'P118', 'P119', 'P121', 'P122', 'P123', 'P126', 'P127', 'P128', 'P129', 'P131', 'P134', 'P135', 'P136', 'P137', 'P138', 'P140', 'P141', 'P143', 'P144', 'P149', 'P150', 'P154', 'P155', 'P156', 'P157', 'P158', 'P159', 'P161', 'P162', 'P163', 'P166', 'P167', 'P169', 'P170', 'P171', 'P172', 'P175', 'P176', 'P177', 'P178', 'P179', 'P180', 'P181', 'P183', 'P184', 'P185', 'P186', 'P189', 'P190', 'P193', 'P194', 'P195', 'P196', 'P197', 'P199', 'P200', 'P201', 'P205', 'P206', 'P207', 'P208', 'P209', 'P210', 'P212', 'P213', 'P214', 'P215', 'P217', 'P218

In [54]:
id2property

{0: 'P6',
 1: 'P10',
 2: 'P14',
 3: 'P15',
 4: 'P16',
 5: 'P17',
 6: 'P18',
 7: 'P19',
 8: 'P20',
 9: 'P21',
 10: 'P22',
 11: 'P25',
 12: 'P26',
 13: 'P27',
 14: 'P30',
 15: 'P31',
 16: 'P35',
 17: 'P36',
 18: 'P37',
 19: 'P38',
 20: 'P39',
 21: 'P40',
 22: 'P41',
 23: 'P47',
 24: 'P50',
 25: 'P51',
 26: 'P53',
 27: 'P54',
 28: 'P57',
 29: 'P58',
 30: 'P59',
 31: 'P61',
 32: 'P65',
 33: 'P66',
 34: 'P69',
 35: 'P78',
 36: 'P81',
 37: 'P84',
 38: 'P85',
 39: 'P86',
 40: 'P87',
 41: 'P88',
 42: 'P91',
 43: 'P92',
 44: 'P94',
 45: 'P97',
 46: 'P98',
 47: 'P101',
 48: 'P102',
 49: 'P103',
 50: 'P105',
 51: 'P106',
 52: 'P108',
 53: 'P109',
 54: 'P110',
 55: 'P111',
 56: 'P112',
 57: 'P113',
 58: 'P114',
 59: 'P115',
 60: 'P117',
 61: 'P118',
 62: 'P119',
 63: 'P121',
 64: 'P122',
 65: 'P123',
 66: 'P126',
 67: 'P127',
 68: 'P128',
 69: 'P129',
 70: 'P131',
 71: 'P134',
 72: 'P135',
 73: 'P136',
 74: 'P137',
 75: 'P138',
 76: 'P140',
 77: 'P141',
 78: 'P143',
 79: 'P144',
 80: 'P149',
 81: 

In [61]:
prop_list.reset_index(inplace=True)

In [46]:
attr.shape

(7174, 768)

In [63]:
# Fit the nearest-neighbour index once, outside the loop: the property
# embeddings in `attr` are the same for every batch.
# NOTE(review): the negated inner product is not a true distance metric, which
# scikit-learn's ball tree expects from a callable — the neighbour returned may
# not always be the exact maximum-inner-product property; confirm this is acceptable.
tree = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', metric=lambda a, b: -(a@b))
tree.fit(attr)

data_len = 0
preds = []
# The loop variable is `batch` rather than `data` so that the `data` list
# loaded earlier in the notebook is not overwritten.
for batch in dataloader:
    tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors = [t for t in batch if t is not None]
    with torch.no_grad():
        outputs, out_relation_emb = model(input_ids=tokens_tensors,
                                          token_type_ids=segments_tensors,
                                          e1_mask=marked_e1,
                                          e2_mask=marked_e2,
                                          attention_mask=masks_tensors)
    print(out_relation_emb.shape)

    # Row index (into `attr`) of the closest property embedding for each sample.
    predictions = tree.kneighbors(out_relation_emb.detach().cpu().numpy(), 1, return_distance=False).flatten()
    print(predictions)
    data_len += len(predictions)
    preds += [id2property[i] for i in predictions]

        
        

torch.Size([16, 768])
[1453 5079 1056 1453 3688 6666 3660 5605 2087  277 3996  277 2087 2087
 6415 6666]
torch.Size([16, 768])
[ 277 3688 1365 4930  277 2827 6415 3688 4930 5332 2827 2087 2087 2087
 2087 2087]
torch.Size([16, 768])
[2087  652 2773 2087 6817  360 1174 3996 2087 3660 5570 2743 1453 2087
 2087 1453]
torch.Size([16, 768])
[3688 3688 4792 4930 3660 6415  652  969  277 2087 3688  360  652  277
 1365 2087]
torch.Size([16, 768])
[1056 2827  578 6415 1056 3996 2087 3688  360 2743 3996 3996  360 3996
 6740 4792]
torch.Size([16, 768])
[ 360 2087  360  652  360  277 3688 3688 2087 3660 3688 6415  277  277
  277  277]
torch.Size([16, 768])
[ 360  277 2087  360 5327 4792  360 4792 1151   75  360 4383  277  969
 6415  652]
torch.Size([16, 768])
[2087 2087 3996 5603 2087 2827 6571 6415  652 3996 1745 2087 4930 6740
  360 1453]
torch.Size([16, 768])
[ 360 3783 1453  360  652 3688 3688 2827  360 3688 4792 6415 3688 3688
 2743  652]
torch.Size([16, 768])
[2743 1453  360 2087 4930 3996 54

In [64]:
preds

['P1855',
 'P5713',
 'P1362',
 'P1855',
 'P4254',
 'P7418',
 'P4224',
 'P6288',
 'P2559',
 'P465',
 'P4568',
 'P465',
 'P2559',
 'P2559',
 'P7149',
 'P7418',
 'P465',
 'P4254',
 'P1750',
 'P5554',
 'P465',
 'P3365',
 'P7149',
 'P4254',
 'P5554',
 'P5997',
 'P3365',
 'P2559',
 'P2559',
 'P2559',
 'P2559',
 'P2559',
 'P2559',
 'P886',
 'P3303',
 'P2559',
 'P7575',
 'P560',
 'P1534',
 'P4568',
 'P2559',
 'P4224',
 'P6251',
 'P3270',
 'P1855',
 'P2559',
 'P2559',
 'P1855',
 'P4254',
 'P4254',
 'P5401',
 'P5554',
 'P4224',
 'P7149',
 'P886',
 'P1269',
 'P465',
 'P2559',
 'P4254',
 'P560',
 'P886',
 'P465',
 'P1750',
 'P2559',
 'P1362',
 'P3365',
 'P805',
 'P7149',
 'P1362',
 'P4568',
 'P2559',
 'P4254',
 'P560',
 'P3270',
 'P4568',
 'P4568',
 'P560',
 'P4568',
 'P7497',
 'P5401',
 'P560',
 'P2559',
 'P560',
 'P886',
 'P560',
 'P465',
 'P4254',
 'P4254',
 'P2559',
 'P4224',
 'P4254',
 'P7149',
 'P465',
 'P465',
 'P465',
 'P465',
 'P560',
 'P465',
 'P2559',
 'P560',
 'P5992',
 'P5401',
 'P560

In [50]:
data_len

176

In [51]:
pid2vec.keys()

dict_keys(['P6', 'P10', 'P14', 'P15', 'P16', 'P17', 'P18', 'P19', 'P20', 'P21', 'P22', 'P25', 'P26', 'P27', 'P30', 'P31', 'P35', 'P36', 'P37', 'P38', 'P39', 'P40', 'P41', 'P47', 'P50', 'P51', 'P53', 'P54', 'P57', 'P58', 'P59', 'P61', 'P65', 'P66', 'P69', 'P78', 'P81', 'P84', 'P85', 'P86', 'P87', 'P88', 'P91', 'P92', 'P94', 'P97', 'P98', 'P101', 'P102', 'P103', 'P105', 'P106', 'P108', 'P109', 'P110', 'P111', 'P112', 'P113', 'P114', 'P115', 'P117', 'P118', 'P119', 'P121', 'P122', 'P123', 'P126', 'P127', 'P128', 'P129', 'P131', 'P134', 'P135', 'P136', 'P137', 'P138', 'P140', 'P141', 'P143', 'P144', 'P149', 'P150', 'P154', 'P155', 'P156', 'P157', 'P158', 'P159', 'P161', 'P162', 'P163', 'P166', 'P167', 'P169', 'P170', 'P171', 'P172', 'P175', 'P176', 'P177', 'P178', 'P179', 'P180', 'P181', 'P183', 'P184', 'P185', 'P186', 'P189', 'P190', 'P193', 'P194', 'P195', 'P196', 'P197', 'P199', 'P200', 'P201', 'P205', 'P206', 'P207', 'P208', 'P209', 'P210', 'P212', 'P213', 'P214', 'P215', 'P217', 'P218

In [None]:
def evaluate(preds, y_attr, y_label, idxmap, num_train_y, dist_func='inner'):
    """Score predicted relation embeddings by nearest-neighbour lookup.

    Every row of ``preds`` is matched to its nearest row of ``y_attr`` under
    ``dist_func``; the matched row index, shifted by ``num_train_y``, is the
    predicted label handed to ``compute_macro_PRF`` together with ``y_label``.

    Args:
        preds: predicted embeddings, one row per sample.
        y_attr: candidate attribute embeddings to search.
        y_label: gold labels passed to ``compute_macro_PRF``.
        idxmap: not used by this function; kept so existing callers keep working.
        num_train_y: offset added to every matched row index.
        dist_func: 'inner', 'cosine' or 'euclidian'. The correctly spelled
            'euclidean' is accepted as an alias of 'euclidian'.

    Returns:
        ``(p_macro, r_macro, f_macro)`` as returned by ``compute_macro_PRF``.

    Raises:
        AssertionError: if ``dist_func`` is not one of the supported names.
        ValueError: same condition, when assertions are disabled (``python -O``).
    """
    assert dist_func in ['inner', 'euclidian', 'euclidean', 'cosine']
    # NOTE(review): the negated inner product / cosine similarity used below are
    # not true distance metrics, which scikit-learn's ball tree expects from a
    # callable — the neighbour returned may not always be the exact best match.
    if dist_func == 'inner':
        tree = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', metric=lambda a, b: -(a@b))
    elif dist_func in ('euclidian', 'euclidean'):
        # Library defaults are used for the Euclidean case.
        tree = NearestNeighbors(n_neighbors=1)
    elif dist_func == 'cosine':
        tree = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', metric=lambda a, b: -((a@b) / (( (a@a) **.5) * ( (b@b) ** .5) )))
    else:
        # Only reachable when the assert above has been stripped; fail clearly
        # instead of hitting an unbound `tree` below.
        raise ValueError(f"unsupported dist_func: {dist_func!r}")
    tree.fit(y_attr)
    predictions = tree.kneighbors(preds, 1, return_distance=False).flatten() + num_train_y
    p_macro, r_macro, f_macro = compute_macro_PRF(predictions, y_label)
    return p_macro, r_macro, f_macro

In [None]:
def _model_device(model):
    """Return the device of the model's first parameter, else the best available device."""
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        return first_param.device
    return torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")


def extract_relation_emb(model, testloader):
    """Run the model over a loader and collect the relation embeddings.

    Args:
        model: callable model exposing ``eval()``; it is called with the keyword
            arguments ``input_ids``, ``token_type_ids``, ``e1_mask``, ``e2_mask``
            and ``attention_mask`` and must return ``(outputs, relation_embedding)``.
        testloader: iterable of batches. After dropping ``None`` entries, the
            first five tensors of a batch are token ids, segment ids,
            entity-1 mask, entity-2 mask and attention mask; any further
            entries (e.g. a gold relation embedding) are ignored.

    Returns:
        The per-batch relation embeddings concatenated along the first
        dimension, or ``None`` when ``testloader`` yields no batches.

    Raises:
        ValueError: if a batch holds fewer than five tensors.
    """
    model.eval()
    # Inputs must live on the same device as the model's weights, so follow the
    # model instead of picking a device independently of it.
    device = _model_device(model)

    batch_embs = []
    for data in testloader:
        tensors = [t.to(device) for t in data if t is not None]
        # Accept both 5-tensor batches and 6-tensor batches that carry an
        # extra relation embedding at the end.
        tokens_tensors, segments_tensors, marked_e1, marked_e2, masks_tensors = tensors[:5]

        with torch.no_grad():
            _, out_relation_emb = model(input_ids=tokens_tensors,
                                        token_type_ids=segments_tensors,
                                        e1_mask=marked_e1,
                                        e2_mask=marked_e2,
                                        attention_mask=masks_tensors)
        batch_embs.append(out_relation_emb)

    if not batch_embs:
        return None
    # Concatenate once at the end instead of re-copying the result every batch.
    return torch.cat(batch_embs)