## HATEXPLAIN DistilBERT multiclass
In this notebook we examine the performance of interpretability techniques in the HateXplain dataset using DistilBERT on token level 

In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from scipy.special import softmax
from helper import print_results, print_results_ap

Loading model and dataset, defining transformer model, and if rationales are available in the dataset

In [2]:
data_path = '../datasets/hatexplain.json'
model_path = 'Trained Models/'
save_path = 'Results/hx_multiclass/'

In [3]:
model_name = 'distilbert'
dataset_name='hx_distilbert_uncased_multiclass'
existing_rationales = True

Load MyModel, and the subsequent tokenizer

In [4]:
task = 'single_label'
sentence_level = False
labels = 3

model = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=True)
model_no_attention = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer

import torch
print(torch.cuda.is_available())
model.trainer.model.to('cuda')

True


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [5]:
hx = Dataset(path='../') #data_path
x, y, label_names, rationales = hx.load_hatexplain_multiclass(tokenizer)

In [6]:
indices = np.arange(len(y))
train_texts, test_texts, train_labels, test_labels, _, test_indexes = train_test_split(
    x, y, indices, stratify=y, train_size=8000, test_size=2000, random_state=42)
if existing_rationales:
    test_rationales = [rationales[x] for x in test_indexes]

# size = (0.1 * len(y)) / len(train_labels)
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
    list(train_texts),
    train_labels,
    stratify=train_labels,
    test_size=1000,
    random_state=42)

In [7]:
for i, label in enumerate(test_labels):
    
    if label == 0:
        token_length = len(tokenizer.tokenize(test_texts[i]))
        test_rationales[i] = [[0] * token_length,
                              [0] * token_length,
                              [0] * token_length]
    elif label == 1:
        test_rationales[i] = [[0] * len(test_rationales[i]), 
                            test_rationales[i], 
                            [0] * len(test_rationales[i])]
    else:
        # print(len(test_rationales[i]))
        test_rationales[i] = [[0] * len(test_rationales[i]),  
                            [0] * len(test_rationales[i]),
                            test_rationales[i]]
        # print(test_rationales[i])

In [9]:
test_test_rationales = test_rationales

Then, we measure the performance of the model using accuracy and f1 score (both macro and micro)

In [11]:
predictions = []
for test_text in test_texts:
    outputs = model.my_predict(test_text)
    predictions.append(outputs[0])

pred_labels = []
for prediction in predictions:
    pred_labels.append(np.argmax(softmax(prediction)))

accuracy_score(test_labels, pred_labels), f1_score(test_labels, pred_labels, average='macro'), f1_score(test_labels, pred_labels, average='micro')

1999it [01:16, 25.78it/s]            

(0.675, 0.6493770303746486, 0.675)

2000it [01:30, 25.78it/s]

In [12]:
my_explainers = MyExplainer(label_names, model_no_attention) #model 2

my_evaluators = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = True, tokenizer=tokenizer) #model 2
my_evaluatorsP = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = False, tokenizer=tokenizer) #model 2
evaluation =  {'F':my_evaluators.faithfulness, 'FTP': my_evaluators.faithful_truthfulness_penalty, 
          'NZW': my_evaluators.nzw, 'AUPRC': my_evaluators.auprc}
evaluationP = {'F':my_evaluatorsP.faithfulness, 'FTP': my_evaluatorsP.faithful_truthfulness_penalty, 
          'NZW': my_evaluatorsP.nzw, 'AUPRC': my_evaluators.auprc}

In [13]:
new_rationale = test_rationales
len_test = len(test_labels) # 2000
num_labels = len(np.unique(test_labels)) #3

In [14]:
import time
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    file_name = save_path + 'HX_DISTILBERT_IG_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    my_explainers.neighbours = 2000
    techniques = [my_explainers.ig] #my_explainers.lime 
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # model_no_attention.predict xwris attention + hidden states
        prediction, _, _ = model_no_attention.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
    
        interpretations = []
        kk = 0
        for technique in techniques:
            ts = time.time()
            temp = technique(instance, prediction, tokens, mask, _, _)
            interpretations.append([np.array(i)/np.max(abs(np.array(i))) for i in temp])
            time_r.append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            for interpretation in interpretations:
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        my_evaluators.clear_states()
        for metric in metrics.keys():
            evaluatedP = []
            for interpretation in interpretations:
                evaluatedP.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metricsP[metric].append(evaluatedP)
with open(file_name+'(A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'(P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean()
# time_r.mean(axis=1)

100%|██████████| 2000/2000 [33:53<00:00,  1.02s/it]


We present the results for IG

In [None]:
print(time_r)

In [15]:
print_results(file_name+'(A)', [' IG '], metrics, label_names) #[' LIME', ' IG  ']

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


[0.55500221 0.41477299 0.4761622  ... 0.91950703 0.46599889 0.70622802]
F
 IG   0.10010000318288803 | 0.02309 0.18756 0.08965
FTP
 IG   0.28003 | 0.37066 0.29178 0.17765
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.48061 | 0.0 0.76706 0.67479


In [16]:
print_results(file_name+'(P)', [' IG '], metricsP, label_names) #[' LIME', ' IG  ']

F
 IG   0.29381 | 0.0181 0.46416 0.39915
FTP
 IG   0.46698 | 0.13716 0.73039 0.53339
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.48061 | 0.0 0.76706 0.67479


Then, we perform the experiments for the different attention setups!

In [17]:
conf = []
for ci in ['Mean', 'Multi'] + list(range(6)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

416

In [18]:
# This is saved in the wrong file. ESNLI distilbert instead of HX
import time 
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'HX_DISTILBERT_ATTENTION_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    time_b = []
    time_b2 = []
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, attention, _ = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [1:46:15<00:00,  3.19s/it]  


We present the results of the different attention setups

In [20]:
print_results(file_name+' (A)', conf, metrics, label_names)

  avg = a.mean(axis)


FTP
['Mean', 'Mean', 'From', False]  -0.0 | -0.37115 0.29974 0.07141
['Mean', 'Mean', 'To', False]  -0.0 | -0.23552 0.17527 0.06025
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.35437 0.28957 0.0648


  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.28597 0.2195 0.06647
['Mean', 0, 'From', False]  -0.0 | -0.32611 0.27155 0.05456
['Mean', 0, 'To', False]  0.0 | -0.17408 0.12549 0.04859
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.32891 0.26874 0.06017
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.32041 0.25259 0.06782
['Mean', 1, 'From', False]  -0.0 | -0.34984 0.29116 0.05869
['Mean', 1, 'To', False]  0.0 | -0.2371 0.17709 0.06001
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.33003 0.27722 0.05281
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.28592 0.22705 0.05887
['Mean', 2, 'From', False]  -0.0 | -0.25088 0.19163 0.05925
['Mean', 2, 'To', False]  -0.0 | -0.20934 0.17247 0.03688
['Mean', 2, 'MeanColumns', False]  -0.0 | -0.22703 0.17183 0.05521
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.20121 0.14808 0.05313
['Mean', 3, 'From', False]  0.0 | -0.30868 0.25881 0.04987
['Mean', 3, 'To', False]  0.0 | -0.18595 0.14143 0.04452
['Mean', 3, 'MeanColumns', False]  -0.0 | -0.21208 0.17026 0.041

In [21]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.35721 | -0.13506 0.72258 0.48412
['Mean', 'Mean', 'To', False]  0.23692 | -0.09076 0.40222 0.3993
['Mean', 'Mean', 'MeanColumns', False]  0.3206 | -0.14457 0.69391 0.41246
['Mean', 'Mean', 'MaxColumns', False]  0.2273 | -0.14416 0.49527 0.33079
['Mean', 0, 'From', False]  0.31314 | -0.11466 0.65246 0.40161
['Mean', 0, 'To', False]  0.09225 | -0.12759 0.23626 0.16807
['Mean', 0, 'MeanColumns', False]  0.29159 | -0.13762 0.63292 0.37948
['Mean', 0, 'MaxColumns', False]  0.29528 | -0.13654 0.59458 0.42779
['Mean', 1, 'From', False]  0.32356 | -0.13785 0.69281 0.41571
['Mean', 1, 'To', False]  0.19191 | -0.12388 0.38343 0.31617
['Mean', 1, 'MeanColumns', False]  0.29561 | -0.13721 0.66376 0.36028
['Mean', 1, 'MaxColumns', False]  0.23363 | -0.1409 0.52072 0.32106
['Mean', 2, 'From', False]  0.21116 | -0.1145 0.43023 0.31775
['Mean', 2, 'To', False]  0.1826 | -0.09765 0.40103 0.24443
['Mean', 2, 'MeanColumns', False]  0.19231 | -0.11464 0.38176 0.30982

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation at this step)

In [22]:
print_results_ap(metrics, label_names, conf)

Baseline: -8.401053834076558e-10  and NZW: 1.0 and AUPRC: 0.4976827811741507
Max Across: 8.502313482890619e-10  and NZW: 1.0 and AUPRC: 0.34405537629712696


  out=out, **kwargs)


Per Label Per Instance: 0.07532674240924414  and NZW:  0.9998822714455601 and AUPRC: 0.5059490927355618
Per Instance: 4.81831971381926e-08  and NZW:  1.0 and AUPRC: 0.33821274794971234


In [23]:
print_results_ap(metricsP, label_names, conf)

Baseline: 0.35721040928648756  and NZW: 1.0 and AUPRC: 0.4976827811741507
Max Across: 0.37928798450153667  and NZW: 0.9951533677955776 and AUPRC: 0.5306406707171761
Per Label Per Instance: 0.0846136868204036  and NZW:  0.9993326634362983 and AUPRC: 0.5124626864090573
Per Instance: 0.45464188690003376  and NZW:  0.9993326634362983 and AUPRC: 0.506091565390738


We repeat the process with Attention Scores with negative values (A*), thus by skipping the Softmax function. In the attention setups, we exclude the multiplication option in heads and layers, as a few combinations reach +/-inf

In [24]:
conf = []
for ci in ['Mean'] + list(range(6)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

364

In [25]:
import time 
import math
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'HX_DISTILBERT_A_ATTENTION_NO_SOFTMAX_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    time_b = []
    time_b2 = []
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, _, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        attention = []
        for la in range(6):
            our_new_layer = []
            bob =  model.trainer.model.base_model.transformer.layer[la].attention
            has = hidden_states[la]
            aaa = bob.k_lin(torch.tensor(has).to('cuda'))
            bbb = bob.q_lin(torch.tensor(has).to('cuda'))
            for he in range(12):
                bbb = bbb / math.sqrt(64)
                attention_scores = torch.matmul(bbb[:,he*64:(he+1)*64], aaa[:,he*64:(he+1)*64].transpose(-1, -2))
                our_new_layer.append(attention_scores.cpu().detach().numpy())
            attention.append(our_new_layer)
        attention = np.array(attention)
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)        
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [1:34:27<00:00,  2.83s/it]


We present the results for the different attention setups

In [28]:
print_results(file_name+' (A)', conf, metrics, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.0 | -0.08708 0.11729 -0.03021
['Mean', 'Mean', 'To', False]  -0.0 | -0.23664 0.19247 0.04417
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.11547 0.10871 0.00676


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.31737 0.25241 0.06496
['Mean', 0, 'From', False]  -0.0 | -0.08407 0.11601 -0.03194
['Mean', 0, 'To', False]  -0.0 | -0.22435 0.18154 0.0428
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.10424 0.09967 0.00458
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.3168 0.25261 0.06419
['Mean', 1, 'From', False]  -0.0 | -0.11417 0.12016 -0.00598
['Mean', 1, 'To', False]  -0.0 | -0.27253 0.20476 0.06777
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.21245 0.18461 0.02784
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.24905 0.18232 0.06673
['Mean', 2, 'From', False]  -0.0 | -0.02649 0.02592 0.00057
['Mean', 2, 'To', False]  -0.0 | -0.21796 0.16448 0.05349
['Mean', 2, 'MeanColumns', False]  -0.0 | -0.09112 0.07658 0.01454
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.19008 0.13669 0.05339
['Mean', 3, 'From', False]  -0.0 | -0.09599 0.10887 -0.01289
['Mean', 3, 'To', False]  -0.0 | -0.29008 0.22657 0.06351
['Mean', 3, 'MeanColumns', False]  -0.0 | -0.11491 0.0889

In [29]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.13418 | 0.03456 0.3781 -0.01011
['Mean', 'Mean', 'To', False]  0.17112 | -0.12617 0.43332 0.20622
['Mean', 'Mean', 'MeanColumns', False]  0.15744 | -8e-05 0.32782 0.14458
['Mean', 'Mean', 'MaxColumns', False]  0.27581 | -0.15191 0.59206 0.38728
['Mean', 0, 'From', False]  0.12844 | 0.0351 0.37506 -0.02483
['Mean', 0, 'To', False]  0.15948 | -0.12157 0.40549 0.19451
['Mean', 0, 'MeanColumns', False]  0.14644 | 0.0045 0.30558 0.12924
['Mean', 0, 'MaxColumns', False]  0.27683 | -0.1513 0.59415 0.38764
['Mean', 1, 'From', False]  0.16734 | 0.01621 0.35965 0.12615
['Mean', 1, 'To', False]  0.21066 | -0.14723 0.45117 0.32803
['Mean', 1, 'MeanColumns', False]  0.23307 | -0.05448 0.49116 0.26253
['Mean', 1, 'MaxColumns', False]  0.17733 | -0.14286 0.38437 0.29046
['Mean', 2, 'From', False]  0.10426 | 0.04669 0.12923 0.13685
['Mean', 2, 'To', False]  0.17736 | -0.11726 0.36104 0.28829
['Mean', 2, 'MeanColumns', False]  0.14346 | -0.00012 0.24289 0.18761
['

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation script at this step)

In [30]:
print_results_ap(metrics, label_names, conf)

Baseline: -8.006535550294144e-10  and NZW: 1.0 and AUPRC: 0.47289123763381014
Max Across: 2.8669819510338903e-09  and NZW: 0.8900092757597925 and AUPRC: 0.33964938802684025


  out=out, **kwargs)


Per Label Per Instance: 0.16714746077705525  and NZW:  0.8140255774972855 and AUPRC: 0.5078218861606594
Per Instance: 1.068244483799427e-07  and NZW:  0.8952231636938027 and AUPRC: 0.35481340540243783


In [31]:
print_results_ap(metricsP, label_names, conf)

Baseline: 0.13418335331416734  and NZW: 1.0 and AUPRC: 0.47289123763381014
Max Across: 0.3817369953731638  and NZW: 1.0 and AUPRC: 0.5306406707171761
Per Label Per Instance: 0.1064779314296576  and NZW:  0.9245871620370695 and AUPRC: 0.5081212479667969
Per Instance: 0.5346651352112782  and NZW:  0.9245871620370695 and AUPRC: 0.504486176460791
