## ESNLI DistilBERT multiclass
In this notebook we examine the performance of interpretability techniques in the ESNLI dataset using DistilBERT on token level 

In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from scipy.special import softmax
from helper import print_results, print_results_ap

Loading model and dataset, defining transformer model, and if rationales are available in the dataset

In [27]:
data_path = '../datasets/esnli_multiclass.pickle'
model_path = 'Trained Models/'
save_path = 'Results/esnli_multiclass/'

In [28]:
model_name = 'bert'
dataset_name='esnli_bert_uncased_multiclass'
existing_rationales = True

Load MyModel, and the subsequent tokenizer

In [29]:
task = 'single_label'
sentence_level = False
labels = 3

model = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=True)
model_no_attention = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer

import torch
torch.cuda.is_available()
model.trainer.model.to('cuda')

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [30]:
esnli = Dataset(path = data_path)
dataset, label_names = esnli.load_esnli_multiclass()

In [6]:
train_texts, test_texts, test_rationales, validation_texts, train_labels, test_labels, validation_labels = dataset

In [7]:
for i in range(len(test_rationales)):
    if (test_rationales[i][0] == []) & ((test_rationales[i][1] == [])):
        test_rationales[i][0] = [0] * len(test_rationales[i][2])
        test_rationales[i][1] = [0] * len(test_rationales[i][2])
        test_rationales[i][2] = list(test_rationales[i][2])
    
    elif (test_rationales[i][0] == []) & (test_rationales[i][2] == []):
        test_rationales[i][0] = [0] * len(test_rationales[i][1])
        test_rationales[i][1] = list(test_rationales[i][1])
        test_rationales[i][2] = [0] * len(test_rationales[i][1])
    else:
        test_rationales[i][0] = list(test_rationales[i][0])
        test_rationales[i][1] = [0] * len(test_rationales[i][0])
        test_rationales[i][2] = [0] * len(test_rationales[i][0])

  
  import sys


In [8]:
test_test_rationales = test_rationales

In [9]:
new_rationale = []
len_test = len(test_labels) # 2000
num_labels = len(np.unique(test_labels)) #3

for i in range(len_test):
    rationale = []
    test_t = test_texts[i].split(' ')
    for j in range(num_labels):
        label_rational = []
        for k in range(len(test_t)):
            for r in tokenizer.tokenize(test_t[k]):
                #if r == '.':
                #    print(r)
                #print(r)
                rationall = 1 if test_test_rationales[i][j][k] > 0 else 0
                label_rational.append(rationall)
        rationale.append(label_rational)
    new_rationale.append(rationale)

Then, we measure the performance of the model using f1 score (both macro)

In [10]:
predictions = []
for test_text in test_texts:
    outputs = model.my_predict(test_text)
    predictions.append(outputs[0])

pred_labels = []
for prediction in predictions:
    pred_labels.append(np.argmax(softmax(prediction)))

accuracy_score(test_labels, pred_labels), f1_score(test_labels, pred_labels, average='macro'), f1_score(test_labels, pred_labels, average='micro')

1998it [01:50, 19.84it/s]            

(0.79, 0.7906131321349256, 0.79)

In [20]:
my_explainers = MyExplainer(label_names, model_no_attention) #model 2

my_evaluators = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = True, tokenizer=tokenizer) #model 2
my_evaluatorsP = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = False, tokenizer=tokenizer) #model 2
evaluation =  {'F':my_evaluators.faithfulness, 'FTP': my_evaluators.faithful_truthfulness_penalty, 
          'NZW': my_evaluators.nzw, 'AUPRC': my_evaluators.auprc}
evaluationP = {'F':my_evaluatorsP.faithfulness, 'FTP': my_evaluatorsP.faithful_truthfulness_penalty, 
          'NZW': my_evaluatorsP.nzw, 'AUPRC': my_evaluators.auprc}

In [13]:
import time
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    file_name = save_path + 'ESNLI_BERT_IG_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    my_explainers.neighbours = 2000
    techniques = [my_explainers.ig] #my_explainers.lime 
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # model_no_attention.predict xwris attention + hidden states
        prediction, _, _ = model_no_attention.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
    
        interpretations = []
        kk = 0
        for technique in techniques:
            ts = time.time()
            temp = technique(instance, prediction, tokens, mask, _, _)
            interpretations.append([np.array(i)/np.max(abs(np.array(i))) for i in temp])
            time_r.append(time.time()-ts) #time_r[kk]  ??
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            for interpretation in interpretations:
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        my_evaluators.clear_states()
        for metric in metrics.keys():
            evaluatedP = []
            for interpretation in interpretations:
                evaluatedP.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metricsP[metric].append(evaluatedP)
with open(file_name+'(A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'(P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean()
# time_r.mean(axis=1)

100%|██████████| 2000/2000 [45:17<00:00,  1.36s/it]


We present the results for IG

In [14]:
print(time_r)

[0.85099864 0.88100028 0.86551261 ... 0.6155057  0.58299923 0.62200022]


In [15]:
print_results(file_name+'(A)', [' IG '], metrics, label_names) #[' LIME', ' IG  ']

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


F
 IG   0.0778999999165535 | 0.10586 0.17021 -0.04238
FTP
 IG   0.27867 | 0.20669 0.31325 0.31608
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.28952 | 0.39346 0.21631 0.25879


In [16]:
print_results(file_name+'(P)', [' IG '], metricsP, label_names) #[' LIME', ' IG  ']

F
 IG   0.32585 | 0.42847 0.46657 0.08251
FTP
 IG   0.45627 | 0.53232 0.72844 0.10805
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.28952 | 0.39346 0.21631 0.25879


Then, we perform the experiments for the different attention setups!

In [21]:
conf = []
for ci in ['Mean', 'Multi'] + list(range(12)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

728

In [18]:
import time 
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'ESNLI_BERT_ATTENTION_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    time_b = []
    time_b2 = []
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, attention, _ = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts) #again kk
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [3:01:45<00:00,  5.45s/it]  


In [23]:
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

(0.00022001981735229493,
 0.0007000494003295898,
 0.0003511602275974148,
 0.03511602275974148,
 3.3583151388168333,
 0.7960365486145019)

We present the results of the different attention setups

In [20]:
print_results(file_name+' (A)', conf, metrics, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.0 | 0.09273 0.21854 -0.31127
['Mean', 'Mean', 'To', False]  0.0 | 0.07141 0.27918 -0.35059


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MeanColumns', False]  0.0 | 0.03307 0.16009 -0.19316
['Mean', 'Mean', 'MaxColumns', False]  0.0 | 0.07213 0.17864 -0.25077
['Mean', 0, 'From', False]  0.0 | 0.08476 0.20121 -0.28597
['Mean', 0, 'To', False]  0.0 | 0.06563 0.19881 -0.26444
['Mean', 0, 'MeanColumns', False]  0.0 | 0.01178 0.17668 -0.18846
['Mean', 0, 'MaxColumns', False]  0.0 | 0.02545 0.18375 -0.2092
['Mean', 1, 'From', False]  0.0 | 0.07599 0.24167 -0.31766
['Mean', 1, 'To', False]  0.0 | 0.03825 0.24253 -0.28078
['Mean', 1, 'MeanColumns', False]  0.0 | 0.01696 0.16713 -0.18409
['Mean', 1, 'MaxColumns', False]  0.0 | 0.03191 0.1818 -0.21371
['Mean', 2, 'From', False]  0.0 | 0.06574 0.22087 -0.28661
['Mean', 2, 'To', False]  0.0 | 0.05191 0.23862 -0.29054
['Mean', 2, 'MeanColumns', False]  0.0 | 0.02123 0.15908 -0.18031
['Mean', 2, 'MaxColumns', False]  0.0 | 0.02076 0.19565 -0.2164
['Mean', 3, 'From', False]  0.0 | 0.03499 0.22689 -0.26188
['Mean', 3, 'To', False]  0.0 | 0.0318 0.25402 -0.28582
['Mean

In [21]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.48841 | 0.53343 0.78862 0.14318
['Mean', 'Mean', 'To', False]  0.48397 | 0.45781 0.89879 0.09531
['Mean', 'Mean', 'MeanColumns', False]  0.30909 | 0.30171 0.5307 0.09487
['Mean', 'Mean', 'MaxColumns', False]  0.43093 | 0.49894 0.65651 0.13733
['Mean', 0, 'From', False]  0.43802 | 0.48288 0.71269 0.1185
['Mean', 0, 'To', False]  0.37951 | 0.39852 0.65521 0.08482
['Mean', 0, 'MeanColumns', False]  0.2846 | 0.2296 0.5451 0.07909
['Mean', 0, 'MaxColumns', False]  0.32209 | 0.30655 0.57587 0.08384
['Mean', 1, 'From', False]  0.49207 | 0.48986 0.84365 0.14272
['Mean', 1, 'To', False]  0.38836 | 0.32436 0.75859 0.08214
['Mean', 1, 'MeanColumns', False]  0.29565 | 0.25024 0.54084 0.09587
['Mean', 1, 'MaxColumns', False]  0.35853 | 0.33772 0.61792 0.11994
['Mean', 2, 'From', False]  0.44418 | 0.43867 0.76338 0.13049
['Mean', 2, 'To', False]  0.4117 | 0.37844 0.76431 0.09236
['Mean', 2, 'MeanColumns', False]  0.27612 | 0.22884 0.51468 0.08483
['Mean', 2, 'M

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation at this step)

In [22]:
print_results_ap(metrics, label_names, conf)

Baseline: 1.7904045653457008e-09  and NZW: 1.0 and AUPRC: 0.5143698842594536
Max Across: 3.801971182326724e-09  and NZW: 1.0 and AUPRC: 0.21021442803198664
Per Label Per Instance: 0.13814768851869455  and NZW:  0.9999065428036017 and AUPRC: 0.33224915327692833
Per Instance: 6.602431183259465e-08  and NZW:  1.0 and AUPRC: 0.25596580440267364


In [23]:
print_results_ap(metricsP, label_names, conf)

Baseline: 0.48840942841050233  and NZW: 1.0 and AUPRC: 0.5143698842594536
Max Across: 0.6152760762772669  and NZW: 1.0 and AUPRC: 0.6144129195180188
Per Label Per Instance: 0.21391192095043218  and NZW:  0.9997640225831596 and AUPRC: 0.510107054098222


  out=out, **kwargs)


Per Instance: 0.8757410205711826  and NZW:  0.9997640225831596 and AUPRC: 0.44275881644281134


We repeat the process with Attention Scores with negative values (A*), thus by skipping the Softmax function. In the attention setups, we exclude the multiplication option in heads and layers, as a few combinations reach +/-inf

In [24]:
conf = []
for ci in ['Mean'] + list(range(12)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

676

In [25]:
import time 
import math
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'ESNLI_BERT_A_ATTENTION_NO_SOFTMAX_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC': []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC': []}
    time_r = []
    time_b = []
    time_b2 = []
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, _, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        attention = []
        for la in range(12):
            our_new_layer = []
            bob = model.trainer.model.base_model.encoder.layer[la].attention
            has = hidden_states[la]
            aaa = bob.self.key(torch.tensor(has).to('cuda'))
            bbb = bob.self.query(torch.tensor(has).to('cuda'))
            for he in range(12):
                attention_scores = torch.matmul(bbb[:,he*64:(he+1)*64], aaa[:,he*64:(he+1)*64].transpose(-1, -2))
                attention_scores = attention_scores / math.sqrt(64)
                our_new_layer.append(attention_scores.cpu().detach().numpy())
            attention.append(our_new_layer)
        attention = np.array(attention)
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts) #kk ?
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)        
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [3:00:23<00:00,  5.41s/it]  


In [26]:
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

(0.00022507190704345703,
 0.0007402300834655762,
 0.00034327915081611046,
 0.034327915081611045,
 3.1141790914535523,
 0.7140869212150573)

We present the results for the different attention setups

In [26]:
try:
    time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)
except:
    print('Failure')

Failure


In [27]:
print_results(file_name+' (A)', conf, metrics, label_names)

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


FTP
['Mean', 'Mean', 'From', False]  -0.0 | 0.04476 -0.07724 0.03248
['Mean', 'Mean', 'To', False]  0.0 | 0.10772 0.28936 -0.39708
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.01874 -0.19929 0.21803
['Mean', 'Mean', 'MaxColumns', False]  0.0 | 0.03961 0.12405 -0.16366
['Mean', 0, 'From', False]  -0.0 | 0.04741 -0.08849 0.04108
['Mean', 0, 'To', False]  0.0 | 0.13327 0.27117 -0.40444
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.01546 -0.16422 0.17968
['Mean', 0, 'MaxColumns', False]  0.0 | 0.02191 0.13542 -0.15733
['Mean', 1, 'From', False]  -0.0 | 0.03536 -0.07534 0.03998
['Mean', 1, 'To', False]  0.0 | 0.08277 0.29052 -0.37329
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.02855 -0.18822 0.21677
['Mean', 1, 'MaxColumns', False]  0.0 | 0.01029 0.13257 -0.14286
['Mean', 2, 'From', False]  -0.0 | 0.03765 -0.05031 0.01267
['Mean', 2, 'To', False]  0.0 | 0.07879 0.28882 -0.36761
['Mean', 2, 'MeanColumns', False]  -0.0 | -0.03376 -0.16892 0.20268
['Mean', 2, 'MaxColumns', False]  0.0 | 0

In [28]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.06021 | 0.04493 -0.19689 -0.02867
['Mean', 'Mean', 'To', False]  0.57683 | 0.61057 0.98583 0.1341
['Mean', 'Mean', 'MeanColumns', False]  -0.28362 | -0.21156 -0.59487 -0.04443
['Mean', 'Mean', 'MaxColumns', False]  0.31752 | 0.3507 0.47142 0.13044
['Mean', 0, 'From', False]  -0.0632 | 0.05619 -0.22374 -0.02205
['Mean', 0, 'To', False]  0.59525 | 0.69356 0.95701 0.13516
['Mean', 0, 'MeanColumns', False]  -0.22661 | -0.16012 -0.49118 -0.02853
['Mean', 0, 'MaxColumns', False]  0.27864 | 0.27332 0.46175 0.10084
['Mean', 1, 'From', False]  -0.06964 | 0.01106 -0.19378 -0.02619
['Mean', 1, 'To', False]  0.52875 | 0.50081 0.96253 0.12293
['Mean', 1, 'MeanColumns', False]  -0.2757 | -0.22532 -0.56672 -0.03506
['Mean', 1, 'MaxColumns', False]  0.28717 | 0.24931 0.48031 0.1319
['Mean', 2, 'From', False]  -0.03243 | 0.04265 -0.11869 -0.02127
['Mean', 2, 'To', False]  0.52718 | 0.49648 0.96106 0.124
['Mean', 2, 'MeanColumns', False]  -0.28679 | -0.29513 -0.51

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation script at this step)

In [29]:
print_results_ap(metrics, label_names, conf)

Baseline: -2.25035865405084e-09  and NZW: 1.0 and AUPRC: 0.4812618549438754
Max Across: 3.688347590310078e-09  and NZW: 1.0 and AUPRC: 0.23577532897699918
Per Label Per Instance: 0.2982475903316168  and NZW:  1.0 and AUPRC: 0.3194813907163208
Per Instance: 1.3905891512218613e-07  and NZW:  1.0 and AUPRC: 0.2652827538983935


In [30]:
print_results_ap(metricsP, label_names, conf)

Baseline: -0.06020976199586173  and NZW: 1.0 and AUPRC: 0.4812618549438754
Max Across: 0.5952454225928835  and NZW: 1.0 and AUPRC: 0.5899297440472324
Per Label Per Instance: 0.21485771021814767  and NZW:  1.0 and AUPRC: 0.4875491463520363


  out=out, **kwargs)


Per Instance: 0.8791171621491044  and NZW:  1.0 and AUPRC: 0.4399415061840175
