## HATEXPLAIN BERT multiclass
In this notebook we evaluate token-level interpretability techniques on the HateXplain dataset, using a BERT multiclass classifier.

In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from scipy.special import softmax
from helper import print_results, print_results_ap

Define the dataset, model and results paths, the transformer model to use, and whether the dataset provides human rationales.

In [2]:
import os

# Input / output locations, relative to this notebook's directory.
data_path = '../datasets/hatexplain.json'  # NOTE(review): not passed to Dataset (which is given '../'); confirm
model_path = 'Trained Models/'
save_path = 'Results/hx_multiclass/'
# The experiment cells pickle their results into save_path only after hours of compute;
# create the directory up front so the final open(..., 'wb') cannot fail on a missing folder.
os.makedirs(save_path, exist_ok=True)

In [3]:
# Checkpoint selection: architecture name and the name of the fine-tuned checkpoint to load.
model_name, dataset_name = 'bert', 'hx_bert_uncased_multiclass/'
# HateXplain ships human rationales, so explanations can be scored against them.
existing_rationales = True

Load the fine-tuned model with MyModel (one instance with attention outputs, one without) and its tokenizer.

In [4]:
# Task setup: single-label classification with 3 classes, explained at token (not sentence) level.
task = 'single_label'
sentence_level = False
labels = 3

# Two wrappers around the same checkpoint. `model` is built with attention=True (its
# my_predict output is later unpacked as prediction, attention, hidden_states), while
# `model_no_attention` (attention=False) backs the explainer and the evaluators.
model = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=True)
model_no_attention = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer

# torch is already imported in the imports cell. Fall back to CPU instead of crashing
# when no GPU is visible.
print(torch.cuda.is_available())
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): only `model` is moved explicitly; confirm MyModel places model_no_attention itself.
# The trailing ';' suppresses the very long module repr in the cell output.
model.trainer.model.to(device);

True


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [5]:
# Load HateXplain through the project's Dataset helper. The tokenizer is passed in,
# presumably so the rationales come back aligned with the model's tokens (TODO confirm).
# Outputs: texts (x), class labels (y), class names, and per-instance rationales.
# NOTE(review): `data_path` from the config cell is not used; Dataset is pointed at '../'.
hx = Dataset(path='../') #data_path
x, y, label_names, rationales = hx.load_hatexplain_multiclass(tokenizer)

In [7]:
# Draw a stratified 8,000 / 2,000 train/test sample. Positional indices are split alongside
# the texts so each test instance can be mapped back to its row in the full dataset.
indices = np.arange(len(y))
(train_texts, test_texts,
 train_labels, test_labels,
 _, test_indexes) = train_test_split(x, y, indices,
                                     stratify=y,
                                     train_size=8000,
                                     test_size=2000,
                                     random_state=42)
if existing_rationales:
    # Keep only the human rationales of the sampled test instances, in test order.
    test_rationales = [rationales[idx] for idx in test_indexes]

# Carve a stratified 1,000-instance validation set out of the training sample.
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
    list(train_texts), train_labels,
    stratify=train_labels, test_size=1000, random_state=42,
)

In [8]:
def expand_rationale(rationale, label, n_classes=3):
    """Lay a flat per-token rationale out as one row per class.

    Row ``label`` holds the original rationale object and every other row is a
    zero list of the same length. For label 0 every row is the single-element
    placeholder ``[0]`` instead — presumably because that class carries no human
    rationale (TODO confirm against label_names).

    Args:
        rationale: per-token rationale list for one test instance.
        label: gold class index of that instance.
        n_classes: number of rows to produce (one per class).

    Returns:
        A list of ``n_classes`` lists.
    """
    if label == 0:
        return [[0] for _row in range(n_classes)]
    return [rationale if row == label else [0] * len(rationale) for row in range(n_classes)]


# Build the per-class rationales as a new list (one row per class, `labels` rows) instead
# of overwriting entries by index. NOTE(review): the name is still rebound, so re-running
# this cell without re-running the split cell expands the rationales twice.
if existing_rationales:
    test_rationales = [expand_rationale(rationale, label, labels)
                       for rationale, label in zip(test_rationales, test_labels)]

In [9]:
# NOTE(review): alias of test_rationales (same list object, not a copy); it is not used
# anywhere in the visible notebook — candidate for removal.
test_test_rationales = test_rationales

Next, we measure the model's performance on the test set: accuracy, macro-averaged F1 and micro-averaged F1.

In [10]:
# Raw model output (first element returned by my_predict) for every test text.
predictions = [model.my_predict(test_text)[0] for test_text in test_texts]
# Predicted class index per instance (softmax is monotonic, so it does not move the argmax).
pred_labels = [np.argmax(softmax(logits)) for logits in predictions]

(accuracy_score(test_labels, pred_labels),
 f1_score(test_labels, pred_labels, average='macro'),
 f1_score(test_labels, pred_labels, average='micro'))

1999it [02:10, 14.92it/s]            

(0.6725, 0.6258645986637291, 0.6725)

2000it [02:30, 14.92it/s]

In [11]:
# The explainer and both evaluators are backed by the model created with attention=False.
my_explainers = MyExplainer(label_names, model_no_attention)

# The two evaluators share every setting except `evaluation_level_all`: metrics from
# `my_evaluators` are pickled as "(A)", those from `my_evaluatorsP` as "(P)".
shared_eval_kwargs = {'sentence_level': False, 'task': 'multi-class', 'tokenizer': tokenizer}
my_evaluators = MyEvaluation(label_names, model_no_attention.my_predict,
                             evaluation_level_all=True, **shared_eval_kwargs)
my_evaluatorsP = MyEvaluation(label_names, model_no_attention.my_predict,
                              evaluation_level_all=False, **shared_eval_kwargs)

# Metric key -> name of the evaluator method that computes it.
metric_methods = {'F': 'faithfulness', 'FTP': 'faithful_truthfulness_penalty',
                  'NZW': 'nzw', 'AUPRC': 'auprc'}
evaluation = {key: getattr(my_evaluators, method) for key, method in metric_methods.items()}
# NOTE(review): AUPRC for the P setting is deliberately wired to the evaluation_level_all=True
# evaluator, exactly as before (AUPRC comes out identical in the (A) and (P) results);
# confirm whether my_evaluatorsP.auprc was intended.
evaluationP = {key: getattr(my_evaluators if key == 'AUPRC' else my_evaluatorsP, method)
               for key, method in metric_methods.items()}

In [12]:
# Ground-truth rationales consumed by the experiment loops (an alias, not a copy).
new_rationale = test_rationales
len_test = len(test_labels)                # 2000 test instances
num_labels = np.unique(test_labels).size   # 3 classes

In [13]:
import time


def _max_abs_normalise(values):
    """Scale an attribution vector by its largest absolute value (result lies in [-1, 1]).

    Same as ``x / max(|x|)`` except that an all-zero attribution is returned as zeros
    instead of 0/0 NaNs; the RuntimeWarning filter below would otherwise hide that
    division and let NaNs flow silently into the metrics.
    """
    arr = np.array(values)
    peak = np.max(np.abs(arr))
    if peak > 0:
        return arr / peak
    return np.zeros(arr.shape)


with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()
    file_name = save_path + 'HX_BERT_IG_' + str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'F': [], 'FTP': [], 'NZW': [], 'AUPRC': []}
    metricsP = {'F': [], 'FTP': [], 'NZW': [], 'AUPRC': []}
    my_explainers.neighbours = 2000  # presumably only read by LIME (commented out below)
    techniques = [my_explainers.ig]  # my_explainers.lime
    # One timing list per technique: time_r[k][i] = seconds technique k spent on instance i
    # (same layout as the attention experiments; previously all techniques shared one flat list).
    time_r = [[] for _technique in techniques]
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # Prediction from the model built without attention. Its third output used to be
        # captured in `_` and forwarded to the explainer/evaluators; it is kept under an
        # explicit name so those calls receive exactly the same arguments as before.
        prediction, _attention_out, hidden_out = model_no_attention.my_predict(instance)
        enc = model.tokenizer([instance, instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        interpretations = []
        for kk, technique in enumerate(techniques):
            ts = time.time()
            temp = technique(instance, prediction, tokens, mask, hidden_out, hidden_out)
            interpretations.append([_max_abs_normalise(i) for i in temp])
            time_r[kk].append(time.time() - ts)

        # (A) scores: all-labels evaluator.
        for metric in metrics:
            metrics[metric].append([
                evaluation[metric](interpretation, hidden_out, instance, prediction, tokens,
                                   hidden_out, hidden_out, test_rational)
                for interpretation in interpretations
            ])
        # (P) scores reuse whatever the all-labels evaluator cached for this instance.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        my_evaluators.clear_states()
        for metric in metrics:
            metricsP[metric].append([
                evaluationP[metric](interpretation, hidden_out, instance, prediction, tokens,
                                    hidden_out, hidden_out, test_rational)
                for interpretation in interpretations
            ])
with open(file_name+'(A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'(P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# Mean explanation time (seconds per instance), one entry per technique.
time_r.mean(axis=1)

100%|██████████| 2000/2000 [51:03<00:00,  1.53s/it] 


We present the results for IG

In [14]:
# Explanation times (seconds per instance) recorded by the IG run above.
print(time_r)

[0.773772   0.69116282 0.86951041 ... 0.77419639 0.64850807 0.73752379]


In [15]:
# IG scores from the evaluation_level_all=True evaluator, i.e. the "(A)" pickle.
print_results(file_name+'(A)', [' IG '], metrics, label_names) #[' LIME', ' IG  ']

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


F
 IG   0.09623999893665314 | 0.0052 0.19664 0.08689
FTP
 IG   0.27415 | 0.38641 0.30177 0.13428
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.50764 | 0.0 0.81061 0.71232


In [16]:
# IG scores from the evaluation_level_all=False evaluator, i.e. the "(P)" pickle.
print_results(file_name+'(P)', [' IG '], metricsP, label_names) #[' LIME', ' IG  ']

F
 IG   0.304 | 4e-05 0.4299 0.48206
FTP
 IG   0.47586 | 0.13969 0.64981 0.63808
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.50764 | 0.0 0.81061 0.71232


Next, we run the experiments for the different attention setups.

In [17]:
# Every attention-explanation setup, as [ci, ce, cp, cl]:
#   ci - 'Mean', 'Multi' or a single index 0..11
#   ce - 'Mean' or a single index 0..11
#   cp - attention-matrix reduction (MeanRows / MaxRows exist but are not explored here)
#   cl - per-head layer selection switch (always False here)
conf = [
    [ci, ce, cp, cl]
    for ci in ['Mean', 'Multi'] + list(range(12))
    for ce in ['Mean'] + list(range(12))
    for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']
    for cl in [False]
]
len(conf)

728

In [18]:
import time

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    file_name = save_path + 'HX_BERT_ATTENTION_' + str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP': [], 'F': [], 'NZW': [], 'AUPRC': []}
    metricsP = {'FTP': [], 'F': [], 'NZW': [], 'AUPRC': []}
    time_r = [[] for _con in conf]  # time_r[c][i]: seconds to build explanation c for instance i
    time_b = []   # per instance: total time spent scoring FTP with the (A) evaluator
    time_b2 = []  # per instance: total time spent scoring FTP with the (P) evaluator
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # NOTE(review): the evaluators cache in `saved_state`; confirm `save_states` is the
        # attribute MyExplainer actually reads.
        my_explainers.save_states = {}
        # The third output used to be captured in `_` and forwarded to the explainer and
        # evaluators; it is kept under an explicit name so those calls are unchanged.
        prediction, attention, hidden_out = model.my_predict(instance)
        enc = model.tokenizer([instance, instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        # One max-abs-scaled explanation per attention configuration.
        interpretations = []
        for kk, con in enumerate(conf):
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, hidden_out)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time() - ts)

        # (A) scores: all-labels evaluator.
        for metric in metrics:
            evaluated = []
            elapsed = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, hidden_out, instance, prediction,
                                                    tokens, hidden_out, hidden_out, test_rational))
                elapsed += time.time() - tt
            if metric == 'FTP':
                time_b.append(elapsed)
            metrics[metric].append(evaluated)
        # (P) scores reuse whatever the all-labels evaluator cached for this instance.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics:
            evaluated = []
            elapsed = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, hidden_out, instance, prediction,
                                                     tokens, hidden_out, hidden_out, test_rational))
                elapsed += time.time() - tt
            if metric == 'FTP':
                time_b2.append(elapsed)
            metricsP[metric].append(evaluated)
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# Per-config mean explanation time: min / max / mean, mean total per config, mean FTP scoring time (A) and (P).
time_r.mean(axis=1).min(), time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [3:05:42<00:00,  5.57s/it]  


We present the results of the different attention setups

In [21]:
# Per-configuration scores from the evaluation_level_all=True evaluator, i.e. the "(A)" pickle.
print_results(file_name+' (A)', conf, metrics, label_names)

  avg = a.mean(axis)


FTP
['Mean', 'Mean', 'From', False]  0.0 | -0.3609 0.27715 0.08374
['Mean', 'Mean', 'To', False]  -0.0 | -0.23994 0.16208 0.07785


  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.35734 0.27143 0.08591
['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.26826 0.19127 0.07698
['Mean', 0, 'From', False]  0.0 | -0.32251 0.24699 0.07552
['Mean', 0, 'To', False]  -0.0 | -0.16106 0.10165 0.05941
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.27892 0.18423 0.09469
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.23592 0.14746 0.08846
['Mean', 1, 'From', False]  0.0 | -0.31345 0.25648 0.05697
['Mean', 1, 'To', False]  0.0 | -0.1984 0.13414 0.06426
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.3288 0.2661 0.0627
['Mean', 1, 'MaxColumns', False]  0.0 | -0.27367 0.20194 0.07173
['Mean', 2, 'From', False]  0.0 | -0.257 0.20435 0.05265
['Mean', 2, 'To', False]  -0.0 | -0.14399 0.08757 0.05642
['Mean', 2, 'MeanColumns', False]  0.0 | -0.23928 0.20314 0.03614
['Mean', 2, 'MaxColumns', False]  0.0 | -0.23268 0.18895 0.04373
['Mean', 3, 'From', False]  0.0 | -0.27259 0.23125 0.04134
['Mean', 3, 'To', False]  -0.0 | -0.22576 0.16647 0.05929


In [22]:
# Per-configuration scores from the evaluation_level_all=False evaluator, i.e. the "(P)" pickle.
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.33694 | -0.13424 0.59635 0.5487
['Mean', 'Mean', 'To', False]  0.21073 | -0.11909 0.32062 0.43066
['Mean', 'Mean', 'MeanColumns', False]  0.31954 | -0.14475 0.58004 0.52333
['Mean', 'Mean', 'MaxColumns', False]  0.20957 | -0.13614 0.38863 0.37621
['Mean', 0, 'From', False]  0.30615 | -0.11978 0.53128 0.50694
['Mean', 0, 'To', False]  0.10569 | -0.1133 0.17433 0.25604
['Mean', 0, 'MeanColumns', False]  0.27023 | -0.13068 0.38262 0.55876
['Mean', 0, 'MaxColumns', False]  0.22862 | -0.11886 0.29953 0.50519
['Mean', 1, 'From', False]  0.25213 | -0.12249 0.54881 0.33007
['Mean', 1, 'To', False]  0.16124 | -0.10441 0.26026 0.32785
['Mean', 1, 'MeanColumns', False]  0.26361 | -0.13269 0.5727 0.35082
['Mean', 1, 'MaxColumns', False]  0.23394 | -0.12686 0.42532 0.40337
['Mean', 2, 'From', False]  0.19932 | -0.11294 0.43396 0.27695
['Mean', 2, 'To', False]  0.08209 | -0.10888 0.14009 0.21505
['Mean', 2, 'MeanColumns', False]  0.17128 | -0.10391 0.44221 0.17

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation at this step)

In [23]:
# Best-setup selection variants (Baseline / Max Across / Per Label Per Instance / Per Instance), (A) scores.
print_results_ap(metrics, label_names, conf)

Baseline: 4.4072027508927175e-10  and NZW: 1.0 and AUPRC: 0.48775859839588137
Max Across: 1.6000400402763997e-09  and NZW: 1.0 and AUPRC: 0.3762103795365783


  out=out, **kwargs)


Per Label Per Instance: 0.07480946701102213  and NZW:  0.9999912280701754 and AUPRC: 0.5059339454361295
Per Instance: 5.4451271211860185e-08  and NZW:  1.0 and AUPRC: 0.33726484833868703


In [24]:
# Best-setup selection variants (Baseline / Max Across / Per Label Per Instance / Per Instance), (P) scores.
print_results_ap(metricsP, label_names, conf)

Baseline: 0.33693715549556363  and NZW: 1.0 and AUPRC: 0.48775859839588137
Max Across: 0.37126288687236036  and NZW: 1.0 and AUPRC: 0.5406796832893724
Per Label Per Instance: 0.08487613434769825  and NZW:  1.0 and AUPRC: 0.5110871592142282
Per Instance: 0.458248650373068  and NZW:  1.0 and AUPRC: 0.500125421188902


We repeat the process with the raw attention scores (A*), i.e. the scores before the softmax is applied, which can therefore be negative. In these attention setups we exclude the multiplication ('Multi') option for heads and layers, because a few combinations overflow to +/-inf.

In [25]:
# Same setup grid as for the softmax attention, minus the 'Multi' option for ci
# (products of raw, possibly negative scores can overflow to +/-inf).
conf = [
    [ci, ce, cp, cl]
    for ci in ['Mean'] + list(range(12))
    for ce in ['Mean'] + list(range(12))
    for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']
    for cl in [False]
]
len(conf)

676

In [26]:
import time
import math

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    file_name = save_path + 'HX_BERT_A_ATTENTION_NO_SOFTMAX_' + str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP': [], 'F': [], 'NZW': [], 'AUPRC': []}
    metricsP = {'FTP': [], 'F': [], 'NZW': [], 'AUPRC': []}
    time_r = [[] for _con in conf]  # time_r[c][i]: seconds to build explanation c for instance i
    time_b = []   # per instance: total time spent scoring FTP with the (A) evaluator
    time_b2 = []  # per instance: total time spent scoring FTP with the (P) evaluator

    bert_layers = model.trainer.model.base_model.encoder.layer
    # Run the Q/K projections on whichever device the model lives on (was hard-coded to 'cuda').
    model_device = next(model.trainer.model.parameters()).device

    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # NOTE(review): the evaluators cache in `saved_state`; confirm `save_states` is the
        # attribute MyExplainer actually reads.
        my_explainers.save_states = {}
        # The second output (post-softmax attention) used to be captured in `_` and forwarded
        # to the explainer/evaluators; it is kept under an explicit name so those calls are unchanged.
        prediction, softmax_attention, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance, instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        # Rebuild the raw scores A* = Q K^T / sqrt(d_head) for every layer and head, feeding
        # hidden_states[la] into layer la's key/query projections. Assuming 2-D per-layer hidden
        # states (seq, hidden), the result is (layers, heads, seq, seq) — TODO confirm.
        attention = []
        with torch.no_grad():  # scores are only read back, so no autograd graph is needed
            for la, bert_layer in enumerate(bert_layers):
                self_attn = bert_layer.attention.self
                head_size = self_attn.attention_head_size
                layer_input = torch.tensor(hidden_states[la]).to(model_device)
                keys = self_attn.key(layer_input)
                queries = self_attn.query(layer_input)
                layer_scores = []
                for he in range(self_attn.num_attention_heads):
                    cols = slice(he * head_size, (he + 1) * head_size)
                    scores = torch.matmul(queries[:, cols], keys[:, cols].transpose(-1, -2))
                    scores = scores / math.sqrt(head_size)
                    layer_scores.append(scores.detach().cpu().numpy())
                attention.append(layer_scores)
        attention = np.array(attention)

        # One max-abs-scaled explanation per attention configuration.
        interpretations = []
        for kk, con in enumerate(conf):
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, softmax_attention)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time() - ts)

        # (A) scores: all-labels evaluator.
        for metric in metrics:
            evaluated = []
            elapsed = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, softmax_attention, instance, prediction,
                                                    tokens, softmax_attention, softmax_attention, test_rational))
                elapsed += time.time() - tt
            if metric == 'FTP':
                time_b.append(elapsed)
            metrics[metric].append(evaluated)
        # (P) scores reuse whatever the all-labels evaluator cached for this instance.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics:
            evaluated = []
            elapsed = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, softmax_attention, instance, prediction,
                                                     tokens, softmax_attention, softmax_attention, test_rational))
                elapsed += time.time() - tt
            if metric == 'FTP':
                time_b2.append(elapsed)
            metricsP[metric].append(evaluated)
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# Per-config mean explanation time: min / max / mean, mean total per config, mean FTP scoring time (A) and (P).
time_r.mean(axis=1).min(), time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [2:50:54<00:00,  5.13s/it]  


We present the results for the different attention setups

In [28]:
# Per-configuration scores for the pre-softmax (A*) attention, (A) evaluator.
print_results(file_name+' (A)', conf, metrics, label_names)

  avg = a.mean(axis)


FTP
['Mean', 'Mean', 'From', False]  0.0 | 0.04507 -0.01472 -0.03036
['Mean', 'Mean', 'To', False]  -0.0 | -0.27309 0.20973 0.06335
['Mean', 'Mean', 'MeanColumns', False]  0.0 | -0.05002 0.05078 -0.00076


  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.21917 0.14933 0.06984
['Mean', 0, 'From', False]  0.0 | 0.05331 -0.03442 -0.01889
['Mean', 0, 'To', False]  0.0 | -0.23369 0.17662 0.05707
['Mean', 0, 'MeanColumns', False]  0.0 | -0.0913 0.06307 0.02823
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.20379 0.12365 0.08014
['Mean', 1, 'From', False]  -0.0 | 0.06693 -0.02704 -0.03989
['Mean', 1, 'To', False]  0.0 | -0.27819 0.20944 0.06874
['Mean', 1, 'MeanColumns', False]  0.0 | -0.08326 0.08638 -0.00312
['Mean', 1, 'MaxColumns', False]  0.0 | -0.26582 0.18951 0.07632
['Mean', 2, 'From', False]  0.0 | 0.05896 -0.02375 -0.0352
['Mean', 2, 'To', False]  -0.0 | -0.17512 0.11622 0.0589
['Mean', 2, 'MeanColumns', False]  0.0 | 0.09543 -0.04716 -0.04827
['Mean', 2, 'MaxColumns', False]  0.0 | -0.15077 0.10308 0.04769
['Mean', 3, 'From', False]  0.0 | 0.09649 -0.04432 -0.05217
['Mean', 3, 'To', False]  -0.0 | -0.24404 0.18654 0.0575
['Mean', 3, 'MeanColumns', False]  -0.0 | 0.08043 -0.03131 -0.0491

In [29]:
# Per-configuration scores for the pre-softmax (A*) attention, (P) evaluator.
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.03161 | 0.08477 0.02501 -0.01496
['Mean', 'Mean', 'To', False]  0.22429 | -0.12513 0.43505 0.36295
['Mean', 'Mean', 'MeanColumns', False]  0.11536 | 0.03782 0.16979 0.13846
['Mean', 'Mean', 'MaxColumns', False]  0.1579 | -0.12691 0.29124 0.30938
['Mean', 0, 'From', False]  0.01642 | 0.06704 -0.03665 0.01887
['Mean', 0, 'To', False]  0.16538 | -0.12021 0.35821 0.25813
['Mean', 0, 'MeanColumns', False]  0.15071 | -0.00364 0.17404 0.28172
['Mean', 0, 'MaxColumns', False]  0.17345 | -0.12016 0.23071 0.40979
['Mean', 1, 'From', False]  -0.00285 | 0.08973 -0.00296 -0.09531
['Mean', 1, 'To', False]  0.25308 | -0.11732 0.43923 0.43733
['Mean', 1, 'MeanColumns', False]  0.14333 | 0.03551 0.25581 0.13867
['Mean', 1, 'MaxColumns', False]  0.23149 | -0.12601 0.39724 0.42323
['Mean', 2, 'From', False]  -0.003 | 0.07966 0.00266 -0.09131
['Mean', 2, 'To', False]  0.11134 | -0.11361 0.20532 0.24232
['Mean', 2, 'MeanColumns', False]  -0.05345 | 0.08394 -0.05204 -0

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation script at this step)

In [30]:
# Best-setup selection variants for the pre-softmax (A*) attention, (A) scores.
print_results_ap(metrics, label_names, conf)

Baseline: 3.395101214054286e-10  and NZW: 1.0 and AUPRC: 0.4760584705564292
Max Across: 3.49152621456165e-09  and NZW: 1.0 and AUPRC: 0.37683445127322307


  out=out, **kwargs)


Per Label Per Instance: 0.1627142405610311  and NZW:  1.0 and AUPRC: 0.49439289072731957
Per Instance: 1.1529691161700495e-07  and NZW:  1.0 and AUPRC: 0.3449200520197541


In [31]:
# Best-setup selection variants for the pre-softmax (A*) attention, (P) scores.
print_results_ap(metricsP, label_names, conf)

Baseline: 0.0316075957298731  and NZW: 1.0 and AUPRC: 0.4760584705564292
Max Across: 0.37459280317939286  and NZW: 1.0 and AUPRC: 0.5429974825673622
Per Label Per Instance: 0.10619502229250351  and NZW:  1.0 and AUPRC: 0.5018245017412385
Per Instance: 0.530070352900451  and NZW:  1.0 and AUPRC: 0.48945109674639836
