In [8]:
import pickle
import string
import numpy as np
import pandas as pd
from tqdm import tqdm
from itertools import combinations, product

import torch
from transformers import BertTokenizer, BertModel

import nltk
import en_core_web_sm
from nltk.metrics.distance import jaccard_distance

import warnings
warnings.simplefilter('ignore')

In [57]:
string.punctuation += '’'
string.punctuation += '–'
nlp = en_core_web_sm.load()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_model = BertModel.from_pretrained("bert-base-cased")
bert_model = bert_model.to(device)

In [58]:
def process(sentence, tokenizer, nlp, return_pt=True):
    doc = nlp(sentence)
    tokens = list(doc)

    chunk2id = {}

    start_chunk = []
    end_chunk = []
    noun_chunks = []
    for chunk in doc.noun_chunks:
        noun_chunks.append(chunk.text)
        start_chunk.append(chunk.start)
        end_chunk.append(chunk.end)

    sentence_mapping = []
    token2id = {}
    mode = 0 # 1 in chunk, 0 not in chunk       
    chunk_id = 0
    for idx, token in enumerate(doc):
        if idx in start_chunk:
            mode = 1
            sentence_mapping.append(noun_chunks[chunk_id])
            if sentence_mapping[-1] not in token2id:
                token2id[sentence_mapping[-1]] = len(token2id)
            chunk_id += 1
        elif idx in end_chunk:
            mode = 0

        if mode == 0:
            sentence_mapping.append(token.text)
            if sentence_mapping[-1] not in token2id:
                token2id[sentence_mapping[-1]] = len(token2id)


    token_ids = []
    tokenid2word_mapping = []

    for token in sentence_mapping:
        subtoken_ids = tokenizer(str(token), add_special_tokens=False)['input_ids']
        tokenid2word_mapping += [ token2id[token] ]*len(subtoken_ids)
        token_ids += subtoken_ids

    tokenizer_name = str(tokenizer.__str__)
    if 'GPT2' in tokenizer_name:
        outputs = {
            'input_ids': token_ids,
            'attention_mask': [1]*(len(token_ids)),
        }

    else:
        outputs = {
            'input_ids': [tokenizer.cls_token_id] + token_ids + [tokenizer.sep_token_id],
            'attention_mask': [1]*(len(token_ids)+2),
            'token_type_ids': [0]*(len(token_ids)+2)
        }

    if return_pt:
        for key, value in outputs.items():
            outputs[key] = torch.from_numpy(np.array(value)).long().unsqueeze(0)
    
    return outputs, tokenid2word_mapping, token2id, noun_chunks, sentence_mapping


def compress_attention(attention, tokenid2word_mapping, operator=np.mean):

    new_index = []
    
    prev = -1
    for idx, row in enumerate(attention):
        token_id = tokenid2word_mapping[idx]
        if token_id != prev:
            new_index.append( [row])
            prev = token_id
        else:
            new_index[-1].append(row)

    new_matrix = []
    for row in new_index:
        new_matrix.append(operator(np.array(row), 0))

    new_matrix = np.array(new_matrix)

    attention = np.array(new_matrix).T

    prev = -1
    new_index=  []
    for idx, row in enumerate(attention):
        token_id = tokenid2word_mapping[idx]
        if token_id != prev:
            new_index.append( [row])
            prev = token_id
        else:
            new_index[-1].append(row)

    
    new_matrix = []
    for row in new_index:
        new_matrix.append(operator(np.array(row), 0))
    
    new_matrix = np.array(new_matrix)
    
    return new_matrix.T


def get_outputs(sentence, tokenizer, encoder, nlp, use_cuda=False):

    tokenizer_name = str(tokenizer.__str__)
    inputs, tokenid2word_mapping, token2id, tokens, sentence_mapping = process(sentence, nlp=nlp, tokenizer=tokenizer, return_pt=True)
    id2token = {value: key for key, value in token2id.items()}
    for key in inputs.keys():
        inputs[key] = inputs[key].to(device)
    outputs = encoder(**inputs, output_attentions=True)
    
    return outputs[2], tokenid2word_mapping, token2id, sentence_mapping


def get_embeddings(sentence):
    rel_pos = ['NN', 'NNP', 'NNS', 'JJR', 'JJS', 'MD', 'POS', 'VB', 'VBG', 'VBD', 'VBN', 'VBP', 'VBZ']
    head_tail_pos = ['NN', 'NNP', 'NNS', 'PRP']

    use_cuda = True    
    att, tokenid2word_mapping, token2id, sentence_mapping = get_outputs(sentence, bert_tokenizer, bert_model, nlp, use_cuda=use_cuda)
    
    new_matr = []
    
    for layer in att:
        for head in layer.squeeze():
            attn = head.cpu()
            attention_matrix = attn.detach().numpy()
            attention_matrix = attention_matrix[1:-1, 1:-1]
            
            merged_attention = compress_attention(attention_matrix, tokenid2word_mapping)
            
            new_matr.append(merged_attention)
    
    new_matr = np.stack(new_matr)
    
    words = [token for token in sentence_mapping if token not in string.punctuation]
    
    nn_words = [word for word in words if nltk.pos_tag([word])[0][1] in head_tail_pos]
    other_words = [word for word in words if nltk.pos_tag([word])[0][1] in rel_pos]
    
    triplets = [triplet for triplet in list(product(nn_words, nn_words, other_words)) 
                if triplet[0] != triplet[1] and triplet[0] != triplet[2] and triplet[1] != triplet[2]]
    
    sent_embeddings = []
    
    for triplet in triplets:
       
        head_ind = sentence_mapping.index(triplet[0])
        tail_ind = sentence_mapping.index(triplet[1])
        rel_ind = sentence_mapping.index(triplet[2])   

        head_rel_emb = new_matr[:, head_ind, rel_ind]
        rel_tail_emb = new_matr[:, rel_ind, tail_ind]

        triplet_emb = np.concatenate((head_rel_emb, rel_tail_emb), axis=0).squeeze()
        sent_embeddings.append((triplet_emb, triplet))
        
    
    return sent_embeddings


def deduplication(pred_list):
    pred_max_conf = {}
    filtered_pred = {}

    for ind, pred in enumerate(pred_list):
        pred_triplet = (pred[1][0], pred[1][1])

        if pred_triplet not in filtered_pred.keys():
            pred_max_conf[pred_triplet] = pred[2]
            filtered_pred[pred_triplet] = pred

        elif pred_triplet in filtered_pred and pred[2] > pred_max_conf[pred_triplet]:
            pred_max_conf[pred_triplet] = pred[2]
            filtered_pred[pred_triplet] = pred
    
    sorted_pred = sorted(list(filtered_pred.values()), key=lambda x: x[2], reverse=True)
    prediction = [el[1] for el in sorted_pred]
    
    return prediction


def get_predictions(sentence, threshold_bin=0.5):
    pred_list = []
    emb_sent = get_embeddings(sentence)
    for emb in emb_sent:
        binary_conf = lr_bin.predict_proba(emb[0].reshape(1, -1))[0][1]
        if binary_conf > threshold_bin:
            predicted_label = list(lr_multi.predict(emb[0].reshape(1, -1)))[0]
            triplet = emb[1]
            pred_list.append((predicted_label, triplet, binary_conf))
    return pred_list


def compare_triplets(targets, predict, dist_thresh=0.4):
    compare_result = []
    for target in targets:
        sub_compare = []
        for target, predict_ in zip(target, predict):
            answer =  False
            dist = jaccard_distance(set(target.lower()), set(predict_.lower()))
            if predict_ in target or dist < dist_thresh or target in predict_:
                answer = True
            sub_compare.append(answer)
        sub_compare = all(sub_compare)
        compare_result.append(sub_compare)
    return any(compare_result)    


def compute_logreg_nm(dataset):
    fp, tp, fn = 0, 0, 0
    tp_predicts_dict = {}
    for row in dataset.itertuples():
        
        try:
            predictions = get_predictions(row.text, threshold_bin=0.7)
        except IndexError:
            continue

        filtered_predictions = deduplication(predictions)
        print(filtered_predictions)
        target_triplets = [target[:3] for target in eval(row.target)]
        print('!!! target',target_triplets)
        tp_predicts = []
        for predict in filtered_predictions:
            
            score_bool = compare_triplets(target_triplets, predict)

            if score_bool:
                tp_predicts.append(predict)
                tp += 1
            else:
                fp += 1
                
        if len(tp_predicts):
            tp_predicts_dict[row.text] = tp_predicts
        
        for target in target_triplets:
            score_bool = compare_triplets(filtered_predictions, target)
            if not score_bool:
                fn += 1

    try:        
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * (precision * recall) / (precision + recall)
        return precision, recall, f1, tp_predicts_dict
    
    except ZeroDivisionError:
        return 0, 0, 0, tp_predicts_dict


def compute_csv_default(dataset, sample_size=100, filename='filename'):
    rels, prs, rcls, f1s = [], [], [], []
    sizes, labels, preds = [], [], []

    for rel in tqdm(sorted(lr_multi.classes_, key=lambda x: int(x[1:]))):
        mono_tr_subset = dataset[dataset.rel == rel]

        if not mono_tr_subset.empty:
            if mono_tr_subset.shape[0] > sample_size:
                mono_tr_subset = mono_tr_subset.sample(sample_size)

            label = get_title(rel)
            size = mono_tr_subset.shape[0]
            try:
                precision, recall, f1, pred_dict = compute_logreg_nm(mono_tr_subset)
            except:
                print(f'rel {rel} skipped, check soon')
                continue

            rels.append(rel)
            prs.append(precision)
            rcls.append(recall)
            f1s.append(f1)
            sizes.append(size)
            labels.append(label)
            preds.append(pred_dict)

            scoring_result = pd.DataFrame({'rel': rels,
                                           'label': labels,
                                           'size': sizes, 
                                           'precision': prs, 
                                           'recall': rcls, 
                                           'f1': f1s,
                                           'tps': preds})

            scoring_result.to_csv(f'{filename}.csv', index=False)

def get_title(relation_id):
    return RELATIONS[RELATIONS.relation==relation_id].title.values[0]

In [50]:
get_predictions('hi my name is Sanzhar i am from Almaty')

[('P17', ('hi', 'my name', 'i'), 0.5986786860214062),
 ('P276', ('Sanzhar', 'my name', 'i'), 0.5261964351536904),
 ('P276', ('Almaty', 'my name', 'i'), 0.6590005936830667)]

In [51]:
deduplication(get_predictions('hi my name is Sanzhar i am from Almaty'))

[('Almaty', 'my name', 'i'),
 ('hi', 'my name', 'i'),
 ('Sanzhar', 'my name', 'i')]

In [10]:
with open('logreg_multi.pkl', 'rb') as file:
    lr_multi = pickle.load(file)
    
with open('logreg_bin.pkl', 'rb') as file:
    lr_bin = pickle.load(file)

In [17]:
RELATIONS = pd.read_csv('../data/meta/relations_docred.csv')

In [53]:
dataset = pd.read_csv('../data/train-val-test/valid.csv')

In [59]:
mono_tr_subset = dataset[dataset.rel == 'P19']
precision, recall, f1, pred_dict = compute_logreg_nm(mono_tr_subset)

[('He', 'County Down', 'born'), ('He', 'Margaret Garrett', 'born'), ('He', 'William Kidd', 'born'), ('He', 'the son', 'born')]
!!! target [('He', 'County Down', 'born in')]
[]
!!! target [('Jon Juaristi', 'Bilbao', 'born in')]
[('He', 'a son', 'born'), ('He', 'Haapsalu', 'born'), ('He', 'Field', 'born'), ('He', 'Agneta von Dellwig', 'born'), ('He', 'Marshal Carl Horn', 'born')]
!!! target [('He', 'Haapsalu', 'born in')]
[('Trier', 'Rome', 'moved'), ('He', 'Rome', 'moved'), ('Germany', 'Rome', 'moved'), ('He', 'Germany', 'born'), ('a lawyer', 'Rome', 'moved'), ('He', 'Trier', 'born'), ('born', 'Rome', 'moved'), ('He', 'a lawyer', 'born')]
!!! target [('He', 'Trier', 'born at')]
[('He', 'Milan', 'born'), ('Politecnico di Milano', 'Milan', 'born'), ('he', 'Milan', 'born'), ('studied architecture', 'Milan', 'born')]
!!! target [('He', 'Milan', 'born in')]
[('Pierre-Gilles de Gennes', 'Paris', 'born')]
!!! target [('Pierre-Gilles de Gennes', 'Paris, France', 'born in')]
[('He', 'Chalkis', '

In [60]:
precision, recall, f1

(0.12195121951219512, 0.8333333333333334, 0.21276595744680848)

In [61]:
compute_csv_default(dataset, 100)

  0%|          | 0/61 [00:00<?, ?it/s]

[('Braúnas', 'Brazil', 'a municipality'), ('Minas Gerais', 'Brazil', 'is'), ('the Southeast region', 'Brazil', 'is')]
!!! target [('municipality', 'Brazil', 'state'), ('Minas Gerais', 'Brazil', 'state')]
[]
!!! target [('Mariehamn', 'Finland', 'land')]
[('India', 'Karnataka', 'a village'), ('the southern state', 'Karnataka', 'a village')]
!!! target [('Karnataka', 'India', 'state')]
[('Nebriinae', 'California', 'the US state'), ('Nebria darlingtoni', 'California', 'the US state'), ('ground beetle', 'California', 'the US state'), ('a species', 'California', 'the US state'), ('Nebria darlingtoni', 'ground beetle', 'a species')]
!!! target [('California', 'US', 'state')]


  2%|▏         | 1/61 [00:02<02:58,  2.98s/it]

rel P17 skipped, check soon
rel P19 skipped, check soon
[]
!!! target [('He', 'Erlangen', 'died in')]
[('He', 'Michigan', 'died')]
!!! target [('He', 'Marshall, Michigan', 'died in')]
[('he', 'master', 'an appointment'), ('master', 'preacher', 'became'), ('Kitzingen', 'preacher', 'became'), ('Nuremberg', 'preacher', 'became'), ('the Sebaldus school', 'preacher', 'became'), ('he', 'preacher', 'became'), ('Nuremberg', 'master', 'an appointment'), ('preacher', 'master', 'an appointment'), ('Kitzingen', 'master', 'an appointment'), ('an appointment', 'preacher', 'became'), ('the Sebaldus school', 'master', 'an appointment')]
!!! target [('Sebaldus', 'Nuremberg', 'died in')]
[]
!!! target [('He', 'Graz', 'died in')]
[]
!!! target [('He', 'Tehran', 'died in')]
[]
!!! target [('He', 'Saint-Anselme', 'died in')]
[]
!!! target [('He', 'Rotherham', 'died in')]
[('He', 'Tarn', 'born'), ('He', 'Labarthe', 'born'), ('He', 'et', 'born'), ('He', 'Garonne', 'born'), ('He', 'Montauban', 'born')]
!!! ta

  5%|▍         | 3/61 [00:04<01:09,  1.20s/it]

[]
!!! target [('He', 'Caserta', 'died in')]
rel P20 skipped, check soon


  8%|▊         | 5/61 [00:04<00:34,  1.61it/s]

rel P22 skipped, check soon
rel P25 skipped, check soon


 11%|█▏        | 7/61 [00:04<00:19,  2.73it/s]

rel P26 skipped, check soon
[('Sogn', 'og', 'the county'), ('Fjordane', 'og', 'the county')]
!!! target [('Sogn og Fjordane', 'county', 'is a')]
rel P31 skipped, check soon


 15%|█▍        | 9/61 [00:05<00:28,  1.79it/s]

rel P36 skipped, check soon
[('Laodamas', 'the son', 'kills'), ('Laodamas', 'Eteocles', 'the son')]
!!! target [('Eteocles', 'Laodamas', 'son')]
rel P40 skipped, check soon





IndexError: index 0 is out of bounds for axis 0 with size 0