In [1]:
import os
import string
import pickle
from collections import OrderedDict

import numpy as np
from tqdm import tqdm
from itertools import combinations, product

import nltk
from nltk.tokenize import sent_tokenize
from nltk.metrics.distance import jaccard_distance

import torch

# Extend the punctuation set with three typographic characters (right single
# quote, en dash, right double quote). This rebinds the attribute on the shared
# `string` module, so every later reader of `string.punctuation` in this
# process sees the extended value.
string.punctuation += '’–”'


def get_predictions(sentence, attentions_types, use_bert, use_lmms, lr_bin, lr_multi, tokenizer, encoder, nlp, threshold_bin=0.5):
    """Label the candidate triplets of ``sentence`` that pass the binary filter.

    Candidates are the ``(embedding, triplet)`` pairs produced by ``get_embeddings``.
    Each embedding is reshaped to a single-row matrix and scored by ``lr_bin``;
    the value in column 1 of ``predict_proba`` is used as the confidence. Only
    when that confidence is strictly greater than ``threshold_bin`` is
    ``lr_multi.predict`` asked for a label.

    Returns:
        A list of ``(predicted_label, triplet, binary_conf)`` tuples in the
        order the candidates were produced.
    """
    candidates = get_embeddings(sentence, attentions_types, use_bert, use_lmms, tokenizer, encoder, nlp)
    accepted = []
    for candidate in tqdm(candidates, total=len(candidates), leave=False):
        features = candidate[0].reshape(1, -1)
        confidence = lr_bin.predict_proba(features)[0][1]

        # Skip everything the binary model is not confident about.
        if not confidence > threshold_bin:
            continue

        label = list(lr_multi.predict(features))[0]
        accepted.append((label, candidate[1], confidence))
    return accepted

def process(sentence, tokenizer, nlp, return_pt=True):
    """Turn a sentence into encoder inputs with noun chunks kept as single units.

    The sentence is parsed with ``nlp``. Every noun chunk reported by
    ``doc.noun_chunks`` is emitted as one entry (its full text); tokens outside
    any chunk are emitted individually. Each entry is then split into sub-token
    ids by ``tokenizer`` and the ids are wrapped with the tokenizer's CLS and
    SEP ids.

    Args:
        sentence: Text to process.
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)``
            returning a mapping with an ``'input_ids'`` list; must expose
            ``cls_token_id`` and ``sep_token_id``.
        nlp: Callable returning an iterable document whose items have ``.text``
            and which exposes ``noun_chunks`` (items with ``.text``, ``.start``,
            ``.end``).
        return_pt: When true, the three output lists are converted to ``long``
            tensors with a leading batch axis of size 1.

    Returns:
        ``(outputs, tokenid2word_mapping, token2id, noun_chunks, sentence_mapping)``:
        * ``outputs`` - dict with ``input_ids``, ``attention_mask``, ``token_type_ids``.
        * ``tokenid2word_mapping`` - for every sub-token (special tokens
          excluded) the id of the entry it came from.
        * ``token2id`` - entry text -> id; identical texts share one id.
        * ``noun_chunks`` - chunk texts in document order.
        * ``sentence_mapping`` - the emitted entries in order.
    """
    doc = nlp(sentence)

    # Sets give O(1) membership tests inside the token loop below.
    chunk_starts = set()
    chunk_ends = set()
    noun_chunks = []
    for chunk in doc.noun_chunks:
        noun_chunks.append(chunk.text)
        chunk_starts.add(chunk.start)
        chunk_ends.add(chunk.end)

    sentence_mapping = []
    token2id = {}
    inside_chunk = False
    next_chunk = 0  # chunks are consumed in the order noun_chunks yielded them
    for idx, token in enumerate(doc):
        if idx in chunk_starts:
            # A start takes precedence over an end at the same index, so two
            # back-to-back chunks are both kept as chunks.
            inside_chunk = True
            sentence_mapping.append(noun_chunks[next_chunk])
            token2id.setdefault(sentence_mapping[-1], len(token2id))
            next_chunk += 1
        elif idx in chunk_ends:
            inside_chunk = False

        if not inside_chunk:
            sentence_mapping.append(token.text)
            token2id.setdefault(token.text, len(token2id))

    token_ids = []
    tokenid2word_mapping = []
    for entry in sentence_mapping:
        subtoken_ids = tokenizer(str(entry), add_special_tokens=False)['input_ids']
        tokenid2word_mapping.extend([token2id[entry]] * len(subtoken_ids))
        token_ids.extend(subtoken_ids)

    seq_len = len(token_ids) + 2  # +2 for the CLS and SEP ids
    outputs = {
        'input_ids': [tokenizer.cls_token_id] + token_ids + [tokenizer.sep_token_id],
        'attention_mask': [1] * seq_len,
        'token_type_ids': [0] * seq_len,
    }

    if return_pt:
        outputs = {
            key: torch.tensor(value, dtype=torch.long).unsqueeze(0)
            for key, value in outputs.items()
        }

    return outputs, tokenid2word_mapping, token2id, noun_chunks, sentence_mapping


def deduplication(pred_list):
    """Keep the most confident prediction for every (head, tail) pair.

    Args:
        pred_list: Iterable of ``(label, triplet, confidence)`` items. Only the
            first two elements of ``triplet`` form the deduplication key.

    Returns:
        ``(prediction, predicted_labels)`` - the surviving triplets and their
        labels, both sorted by descending confidence. Ties keep first-seen
        order; on equal confidence for the same pair the earlier item wins.
    """
    best_by_pair = {}
    for candidate in pred_list:
        pair = (candidate[1][0], candidate[1][1])
        incumbent = best_by_pair.get(pair)
        # Replace only on a strictly higher confidence.
        if incumbent is None or candidate[2] > incumbent[2]:
            best_by_pair[pair] = candidate

    ranked = list(best_by_pair.values())
    ranked.sort(key=lambda item: item[2], reverse=True)

    return [item[1] for item in ranked], [item[0] for item in ranked]


def compare_triplets(targets, predict, dist_thresh=0.2):
    """Tell whether ``predict`` matches at least one triplet in ``targets``.

    A target matches when every position paired by ``zip(target, predict)``
    matches. Two strings match when one is a substring of the other
    (case-sensitive) or when the Jaccard distance between the character sets
    of their lower-cased forms is below ``dist_thresh``.

    Returns:
        ``True`` if any target matches, otherwise ``False`` (also for an empty
        ``targets``).
    """
    def _fields_match(gold, guess):
        # The distance is computed first, for every pair, exactly once.
        distance = jaccard_distance(set(gold.lower()), set(guess.lower()))
        return bool(guess in gold or distance < dist_thresh or gold in guess)

    per_target = []
    for gold_triplet in targets:
        # List (not generator) so every position is evaluated before all().
        field_hits = [_fields_match(gold, guess) for gold, guess in zip(gold_triplet, predict)]
        per_target.append(all(field_hits))
    return any(per_target)


def _precision_recall_f1(n_tp, n_fp, n_fn):
    """Return (precision, recall, f1); all zeros when there are no true positives."""
    # With n_tp == 0 each metric is either 0 or an undefined 0/0, so report 0s.
    if n_tp == 0:
        return 0, 0, 0
    precision = n_tp / (n_tp + n_fp)
    recall = n_tp / (n_tp + n_fn)
    return precision, recall, 2 * (precision * recall) / (precision + recall)


def compute_metrics(df, lr_bin, lr_multi, best_attentions, tokenizer, encoder, nlp):
    """Score triplet extraction on every row of ``df`` and return corpus-level metrics.

    For each row, ``row.text`` is split into sentences; each sentence is run
    through ``get_predictions`` (binary threshold 0.7) and ``deduplication``,
    and the surviving triplets of all sentences are pooled. Pooled predictions
    are compared with the first three items of every entry in ``row.target``
    using ``compare_triplets``.

    Side effects:
        * Adds the columns ``precision``, ``recall``, ``f1``, ``tps``, ``fps``
          and ``preds`` to ``df`` in place.
        * Writes ``df`` to ``../data/meta/trex_data_long_parsed.csv``.

    Args:
        df: DataFrame with a ``text`` column and a ``target`` column holding
            the string form of a list of target tuples.
        lr_bin, lr_multi: Classifiers forwarded to ``get_predictions``.
        best_attentions: Attention selection flags forwarded to ``get_predictions``.
        tokenizer, encoder, nlp: Forwarded to ``get_predictions``.

    Returns:
        ``(precision, recall, f1)`` over all rows; ``(0, 0, 0)`` when there is
        no true positive.
    """
    tp, fp, fn = 0, 0, 0
    prs, recs, f1s = [], [], []
    # Plain per-row lists: keying by row.text would merge rows with equal text.
    tps_per_row, fps_per_row, preds_per_row = [], [], []

    for row in tqdm(df.itertuples(), total=df.shape[0]):
        collected_predictions = []
        for cut_text in sent_tokenize(row.text):
            try:
                predictions = get_predictions(cut_text, best_attentions, False, False, lr_bin, lr_multi,
                                              tokenizer=tokenizer, encoder=encoder, nlp=nlp, threshold_bin=0.7)
            except Exception as err:
                # Best effort: a failing sentence contributes no predictions,
                # but the failure is reported instead of silently hidden.
                tqdm.write(f'get_predictions failed, sentence skipped: {err!r}')
                predictions = []
            filtered_predictions, _ = deduplication(predictions)
            collected_predictions.extend(filtered_predictions)

        preds_per_row.append(list(set(collected_predictions)))

        # NOTE(security): eval executes arbitrary code; row.target must come
        # from a trusted file. Consider ast.literal_eval if the column only
        # ever holds Python literals.
        targets = eval(row.target)
        target_triplets = [target[:3] for target in targets]

        tp_predicts, fp_predicts = [], []
        for predict in collected_predictions:
            if compare_triplets(target_triplets, predict):
                tp_predicts.append(predict)
            else:
                fp_predicts.append(predict)

        # A target is missed only if no prediction from ANY sentence of the
        # row matches it, hence the pooled list rather than the last sentence's.
        fn_local = sum(
            1 for target in target_triplets
            if not compare_triplets(collected_predictions, target)
        )
        tp_local, fp_local = len(tp_predicts), len(fp_predicts)

        tp += tp_local
        fp += fp_local
        fn += fn_local

        tps_per_row.append(tp_predicts)
        fps_per_row.append(fp_predicts)

        precision_local, recall_local, f1_local = _precision_recall_f1(tp_local, fp_local, fn_local)
        prs.append(precision_local)
        recs.append(recall_local)
        f1s.append(f1_local)

    assert len(prs) == len(recs) == len(f1s) == len(tps_per_row) == len(fps_per_row) == len(preds_per_row)

    df['precision'] = prs
    df['recall'] = recs
    df['f1'] = f1s
    df['tps'] = tps_per_row
    df['fps'] = fps_per_row
    df['preds'] = preds_per_row

    df.to_csv('../data/meta/trex_data_long_parsed.csv', index=False)

    return _precision_recall_f1(tp, fp, fn)

def get_vectorname(attentions_types, use_bert, use_lmms):
    """Build the name suffix that identifies a feature configuration.

    Args:
        attentions_types: Indexable of six flags, one per attention direction
            in the order ``h-r, r-t, h-t, r-h, t-r, t-h``; a direction is
            included when its flag equals 1.
        use_bert: Append ``_bert`` when truthy.
        use_lmms: Append ``_lmms`` when truthy.

    Returns:
        The enabled direction labels joined by ``_`` followed by the optional
        suffixes, e.g. ``'h-r_r-t_bert'``.
    """
    labels = ('h-r', 'r-t', 'h-t', 'r-h', 't-r', 't-h')
    enabled = [labels[pos] for pos in range(len(labels)) if attentions_types[pos] == 1]

    suffix = ('_bert' if use_bert else '') + ('_lmms' if use_lmms else '')
    return '_'.join(enabled) + suffix



def load_lr_models(vector_names):
    """Load the pickled multi-class and binary models for one configuration.

    Reads ``./logreg_models/lr_multi_<vector_names>.pkl`` first and then
    ``./logreg_models/lr_bin_<vector_names>.pkl`` (paths are relative to the
    current working directory).

    Returns:
        ``(lr_bin, lr_multi)`` - note the binary model comes first.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    # NOTE(security): unpickling can execute arbitrary code; only load model
    # files from a trusted source.
    lr_multi = _unpickle(f'./logreg_models/lr_multi_{vector_names}.pkl')
    lr_bin = _unpickle(f'./logreg_models/lr_bin_{vector_names}.pkl')

    return lr_bin, lr_multi


def compress_attention(attention, tokenid2word_mapping, operator=np.mean):
    """Collapse a sub-token attention matrix to one row/column per word run.

    Consecutive positions that carry the same id in ``tokenid2word_mapping``
    are treated as one group. Rows are grouped and reduced first, then the
    same grouping is applied to the columns.

    Args:
        attention: 2-D array indexed ``[row, column]`` by sub-token position.
        tokenid2word_mapping: Id of the owning word for every sub-token position.
        operator: Reduction called as ``operator(stacked_rows, 0)``.

    Returns:
        Array with one row and one column per run of equal ids.
    """
    def _merge_runs(matrix):
        # Group the leading-axis slices of `matrix` into runs of equal word id,
        # then reduce every run to a single vector.
        groups = []
        last_id = -1
        for position, vector in enumerate(matrix):
            word_id = tokenid2word_mapping[position]
            if word_id == last_id:
                groups[-1].append(vector)
            else:
                groups.append([vector])
                last_id = word_id
        return np.array([operator(np.array(group), 0) for group in groups])

    rows_merged = _merge_runs(attention)
    # Transposing lets the same row-wise routine merge the columns.
    both_merged = _merge_runs(rows_merged.T)
    return both_merged.T

def get_outputs(sentence, tokenizer, encoder, nlp, use_cuda=True):
    """Run ``encoder`` on ``sentence`` and return its attentions with the word mappings.

    The sentence is prepared by ``process`` (noun chunks merged, tensors with a
    batch axis), moved to the device the encoder lives on, and passed through
    the encoder with ``output_attentions=True`` under ``torch.no_grad()``.

    Args:
        sentence: Text to encode.
        tokenizer, nlp: Forwarded to ``process``.
        encoder: Model called as ``encoder(**inputs, output_attentions=True)``.
        use_cuda: Only consulted when ``encoder`` has no ``device`` attribute;
            CUDA is then used if requested and available, otherwise the CPU.

    Returns:
        ``(attentions, tokenid2word_mapping, token2id, sentence_mapping)`` where
        ``attentions`` is element 2 of the encoder output and the remaining
        items come unchanged from ``process``.
    """
    inputs, tokenid2word_mapping, token2id, _noun_chunks, sentence_mapping = process(
        sentence, nlp=nlp, tokenizer=tokenizer, return_pt=True)

    # Put the inputs where the model is instead of unconditionally calling
    # .cuda(), which fails on CPU-only hosts.
    device = getattr(encoder, 'device', None)
    if device is None:
        device = torch.device('cuda' if use_cuda and torch.cuda.is_available() else 'cpu')
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = encoder(**inputs, output_attentions=True)

    # NOTE(review): index 2 is presumably the attentions entry for the BERT
    # encoder this is used with - confirm if a different encoder is plugged in.
    return outputs[2], tokenid2word_mapping, token2id, sentence_mapping


def get_embeddings(sentence, attentions_types, use_bert, use_lmms, tokenizer, encoder, nlp):
    """Build an attention-based feature vector for every candidate triplet.

    Attention maps from every layer and head are obtained through
    ``get_outputs``, stripped of their first and last positions (the special
    tokens added by ``process``) and compressed to word level with
    ``compress_attention``. Candidate ``(head, tail, relation)`` triplets are
    all combinations of distinct entries where head and tail carry a tag from
    ``head_tail_pos`` and the relation a tag from ``rel_pos``; each entry is
    POS-tagged on its own with ``nltk.pos_tag``.

    Args:
        sentence: Text to analyse.
        attentions_types: Six flags selecting, in order, the head->rel,
            rel->tail, head->tail, rel->head, tail->rel and tail->head
            attention slices; a slice is used when its flag equals 1.
        use_bert, use_lmms: Accepted but not used by this function.
        tokenizer, encoder, nlp: Forwarded to ``get_outputs``.

    Returns:
        List of ``(triplet_emb, triplet)`` where ``triplet_emb`` concatenates
        the selected slices (one value per layer/head each) and ``triplet`` is
        the ``(head, tail, relation)`` tuple of entry texts.
    """
    rel_pos = ['NN', 'NNP', 'NNS', 'JJR', 'JJS', 'MD', 'POS', 'VB', 'VBG', 'VBD', 'VBN', 'VBP', 'VBZ']
    head_tail_pos = ['NN', 'NNP', 'NNS', 'PRP']

    att, tokenid2word_mapping, token2id, sentence_mapping = get_outputs(
        sentence, tokenizer, encoder, nlp, use_cuda=torch.cuda.is_available())

    head_matrices = []
    for layer in att:
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop the head axis of a single-head model.
        for head in layer.squeeze(0):
            attention_matrix = head.detach().cpu().numpy()[1:-1, 1:-1]
            head_matrices.append(compress_attention(attention_matrix, tokenid2word_mapping))

    new_matr = np.stack(head_matrices)

    words = [token for token in sentence_mapping if token not in string.punctuation]

    # Tag every distinct entry once instead of twice per occurrence.
    pos_of = {word: nltk.pos_tag([word])[0][1] for word in set(words)}
    nn_words = [word for word in words if pos_of[word] in head_tail_pos]
    other_words = [word for word in words if pos_of[word] in rel_pos]

    triplets = [
        (head, tail, rel)
        for head, tail, rel in product(nn_words, nn_words, other_words)
        if head != tail and head != rel and tail != rel
    ]

    # First position of every entry, equivalent to sentence_mapping.index(...).
    # NOTE(review): compress_attention merges adjacent sub-tokens with equal
    # ids and process gives identical texts the same id, so two identical
    # neighbouring entries would shift these positions - verify on such input.
    first_pos = {}
    for position, entry in enumerate(sentence_mapping):
        first_pos.setdefault(entry, position)

    sent_embeddings = []
    for triplet in triplets:
        head_ind = first_pos[triplet[0]]
        tail_ind = first_pos[triplet[1]]
        rel_ind = first_pos[triplet[2]]

        # Order must stay aligned with the flags in attentions_types.
        directions = [
            new_matr[:, head_ind, rel_ind],
            new_matr[:, rel_ind, tail_ind],
            new_matr[:, head_ind, tail_ind],
            new_matr[:, rel_ind, head_ind],
            new_matr[:, tail_ind, rel_ind],
            new_matr[:, tail_ind, head_ind],
        ]
        selected = tuple(vec for i, vec in enumerate(directions) if attentions_types[i] == 1)

        triplet_emb = np.concatenate(selected, axis=0).squeeze()
        sent_embeddings.append((triplet_emb, triplet))

    return sent_embeddings

In [2]:
import pandas as pd

import torch
from transformers import BertTokenizer, BertModel
from nltk.tokenize import sent_tokenize

import en_core_web_sm

from utils_inference import (get_vectorname, load_lr_models, compute_metrics)

import warnings

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
# Hide FutureWarnings whose message starts with "Passing".
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)

# spaCy pipeline passed to the utilities as `nlp`.
nlp = en_core_web_sm.load()

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cased BERT tokenizer plus encoder; the encoder is moved to `device` and
# switched to evaluation mode.
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
encoder = BertModel.from_pretrained("bert-base-cased").to(device).eval()

# Evaluation data consumed by compute_metrics.
df = pd.read_csv('../data/meta/trex_data_long.csv')



In [4]:
# One flag per attention direction; only the last one is switched off.
best_attentions = [1, 1, 1, 1, 1, 0]

# The vector name selects which pickled logistic-regression models to load.
vectorname = get_vectorname(best_attentions, use_bert=False, use_lmms=False)
lr_bin, lr_multi = load_lr_models(vectorname)

prec, rec, f1 = compute_metrics(
    df, lr_bin, lr_multi, best_attentions,
    tokenizer=tokenizer, encoder=encoder, nlp=nlp,
)

# Metrics on one line, followed by an empty line.
print(prec, rec, f1, end='\n\n')


