## HATEXPLAIN alBERT multiclass
In this notebook we examine the performance of interpretability techniques in the HateXplain dataset using alBERT on token level 

In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from scipy.special import softmax
from helper import print_results, print_results_ap

Loading model and dataset, defining transformer model, and if rationales are available in the dataset

In [2]:
data_path = '../datasets/hatexplain.json'
model_path = 'Trained Models/'
save_path = 'Results/hx_multiclass/'

In [3]:
model_name = 'albert'
dataset_name='hx_albert_v2_multiclass'
existing_rationales = True

Load MyModel, and the subsequent tokenizer

In [4]:
task = 'single_label'
sentence_level = False
labels = 3

model = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=True)
model_no_attention = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer

import torch
print(torch.cuda.is_available())
model.trainer.model.to('cuda')

True


AlbertForSequenceClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768,

In [5]:
hx = Dataset(path='../') #data_path
x, y, label_names, rationales = hx.load_hatexplain_multiclass(tokenizer)

In [6]:
indices = np.arange(len(y))
train_texts, test_texts, train_labels, test_labels, _, test_indexes = train_test_split(
    x, y, indices, stratify=y, train_size=8000, test_size=2000, random_state=42)
if existing_rationales:
    test_rationales = [rationales[x] for x in test_indexes]

# size = (0.1 * len(y)) / len(train_labels)
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
    list(train_texts),
    train_labels,
    stratify=train_labels,
    test_size=1000,
    random_state=42)

In [7]:
for i, label in enumerate(test_labels):
    
    if label == 0:
        token_length = len(tokenizer.tokenize(test_texts[i]))
        test_rationales[i] = [[0] * token_length,
                              [0] * token_length,
                              [0] * token_length]
    elif label == 1:
        test_rationales[i] = [[0] * len(test_rationales[i]), 
                            test_rationales[i], 
                            [0] * len(test_rationales[i])]
    else:
        test_rationales[i] = [[0] * len(test_rationales[i]),  
                            [0] * len(test_rationales[i]),
                            test_rationales[i]]

In [9]:
test_test_rationales = test_rationales

Then, we measure the performance of the model using f1 score (both macro)

In [10]:
predictions = []
for test_text in test_texts:
    outputs = model.my_predict(test_text)
    predictions.append(outputs[0])

pred_labels = []
for prediction in predictions:
    pred_labels.append(np.argmax(softmax(prediction)))

accuracy_score(test_labels, pred_labels), f1_score(test_labels, pred_labels, average='macro'), f1_score(test_labels, pred_labels, average='micro')

1999it [02:01, 19.59it/s]            

(0.636, 0.6333142297681306, 0.636)

2000it [02:20, 19.59it/s]

In [11]:
my_explainers = MyExplainer(label_names, model_no_attention) #model 2

my_evaluators = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = True, tokenizer=tokenizer) #model 2
my_evaluatorsP = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = False, tokenizer=tokenizer) #model 2
evaluation =  {'F':my_evaluators.faithfulness, 'FTP': my_evaluators.faithful_truthfulness_penalty, 
          'NZW': my_evaluators.nzw, 'AUPRC': my_evaluators.auprc}
evaluationP = {'F':my_evaluatorsP.faithfulness, 'FTP': my_evaluatorsP.faithful_truthfulness_penalty, 
          'NZW': my_evaluatorsP.nzw, 'AUPRC': my_evaluators.auprc}

In [12]:
new_rationale = test_rationales
len_test = len(test_labels) # 2000
num_labels = len(np.unique(test_labels)) #3

In [13]:
import time
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    file_name = save_path + 'HX_ALBERT_IG_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    my_explainers.neighbours = 2000
    techniques = [my_explainers.ig] #my_explainers.lime 
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # model_no_attention.predict xwris attention + hidden states
        prediction, _, _ = model_no_attention.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
    
        interpretations = []
        kk = 0
        for technique in techniques:
            ts = time.time()
            temp = technique(instance, prediction, tokens, mask, _, _)
            interpretations.append([np.array(i)/np.max(abs(np.array(i))) for i in temp])
            time_r.append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            for interpretation in interpretations:
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        my_evaluators.clear_states()
        for metric in metrics.keys():
            evaluatedP = []
            for interpretation in interpretations:
                evaluatedP.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metricsP[metric].append(evaluatedP)
        with open(file_name+'(A).pickle', 'wb') as handle:
            pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(file_name+'(P).pickle', 'wb') as handle:
            pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
        with open(file_name+'_TIME.pickle', 'wb') as handle:
            pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean()
# time_r.mean(axis=1)

100%|██████████| 2000/2000 [49:33<00:00,  1.49s/it]


We present the results for IG

In [14]:
print(time_r)

[0.91900158 0.74653006 0.77351475 ... 0.96453094 0.82851529 0.89100122]


In [15]:
print_results(file_name+'(A)', [' IG '], metrics, label_names) #[' LIME', ' IG  ']

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


F
 IG   0.07083000242710114 | 0.02915 0.137 0.04634
FTP
 IG   0.17387 | 0.21717 0.15893 0.1455
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.46402 | 0.0 0.72849 0.66357


In [16]:
print_results(file_name+'(P)', [' IG '], metricsP, label_names) #[' LIME', ' IG  ']

F
 IG   0.259 | 0.11648 0.42517 0.23535
FTP
 IG   0.31372 | 0.19325 0.45277 0.29513
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.46402 | 0.0 0.72849 0.66357


Then, we perform the experiments for the different attention setups!

In [17]:
conf = []
for ci in ['Mean', 'Multi'] + list(range(12)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

728

In [18]:
import time 
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'HX_ALBERT_ATTENTION_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    time_b = []
    time_b2 = []
    torch.cuda.empty_cache()
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, attention, _ = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)
with open(file_name+' (A).pickle', 'wb') as handle: 
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [3:32:53<00:00,  6.39s/it]  


We present the results of the different attention setups

In [20]:
print_results(file_name+' (A)', conf, metrics, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.0 | -0.15487 0.23264 -0.07777
['Mean', 'Mean', 'To', False]  -0.0 | -0.08796 0.15098 -0.06301
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.11421 0.19663 -0.08242


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.11706 0.20273 -0.08567
['Mean', 0, 'From', False]  -0.0 | -0.16355 0.23832 -0.07477
['Mean', 0, 'To', False]  -0.0 | -0.07256 0.14812 -0.07557
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.12548 0.21731 -0.09183
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.14917 0.23039 -0.08122
['Mean', 1, 'From', False]  -0.0 | -0.15287 0.23141 -0.07853
['Mean', 1, 'To', False]  -0.0 | -0.07862 0.1653 -0.08668
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.12387 0.20828 -0.08441
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.13361 0.22169 -0.08808
['Mean', 2, 'From', False]  -0.0 | -0.10189 0.1656 -0.06371
['Mean', 2, 'To', False]  -0.0 | -0.07089 0.13961 -0.06872
['Mean', 2, 'MeanColumns', False]  -0.0 | -0.09928 0.175 -0.07571
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.10192 0.18671 -0.0848
['Mean', 3, 'From', False]  -0.0 | -0.11106 0.21048 -0.09942
['Mean', 3, 'To', False]  -0.0 | -0.09142 0.16399 -0.07256
['Mean', 3, 'MeanColumns', False]  -0.0 | -0.107

In [21]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.4076 | 0.12941 0.7759 0.31748
['Mean', 'Mean', 'To', False]  0.27776 | 0.14083 0.47028 0.22217
['Mean', 'Mean', 'MeanColumns', False]  0.36941 | 0.1684 0.66257 0.27726
['Mean', 'Mean', 'MaxColumns', False]  0.37926 | 0.16367 0.69581 0.27829
['Mean', 0, 'From', False]  0.40702 | 0.12277 0.7799 0.3184
['Mean', 0, 'To', False]  0.28158 | 0.16597 0.4701 0.20867
['Mean', 0, 'MeanColumns', False]  0.39382 | 0.16885 0.73626 0.27636
['Mean', 0, 'MaxColumns', False]  0.40865 | 0.14821 0.76646 0.31127
['Mean', 1, 'From', False]  0.40859 | 0.13729 0.76895 0.31952
['Mean', 1, 'To', False]  0.3312 | 0.20628 0.55294 0.23439
['Mean', 1, 'MeanColumns', False]  0.3665 | 0.14216 0.70092 0.25643
['Mean', 1, 'MaxColumns', False]  0.38016 | 0.13724 0.74366 0.25958
['Mean', 2, 'From', False]  0.33013 | 0.15597 0.54945 0.28496
['Mean', 2, 'To', False]  0.27366 | 0.15354 0.45636 0.21108
['Mean', 2, 'MeanColumns', False]  0.313 | 0.14014 0.56265 0.23623
['Mean', 2, 'MaxCo

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation at this step)

In [22]:
print_results_ap(metrics, label_names, conf)

Baseline: -4.3892760352252464e-09  and NZW: 1.0 and AUPRC: 0.4083045762248556
Max Across: -2.521474055013536e-09  and NZW: 1.0 and AUPRC: 0.34122078959494023


  out=out, **kwargs)


Per Label Per Instance: 0.08330903337846222  and NZW:  0.9999005202080832 and AUPRC: 0.42350639999973555
Per Instance: 4.255653588784251e-08  and NZW:  0.9999716758457907 and AUPRC: 0.33833027765027035


In [23]:
print_results_ap(metricsP, label_names, conf)

Baseline: 0.40759752983780134  and NZW: 1.0 and AUPRC: 0.4083045762248556
Max Across: 0.43254128431480554  and NZW: 1.0 and AUPRC: 0.42165235805649565
Per Label Per Instance: 0.11553557486842579  and NZW:  0.9996187363834422 and AUPRC: 0.42140278488519045
Per Instance: 0.5616676684156298  and NZW:  0.9996187363834422 and AUPRC: 0.4131630979029362


We repeat the process with Attention Scores with negative values (A*), thus by skipping the Softmax function. In the attention setups, we exclude the multiplication option in heads and layers, as a few combinations reach +/-inf

In [24]:
conf = []
for ci in ['Mean'] + list(range(12)):
    for ce in ['Mean'] + list(range(12)):
        for cp in ['From', 'To', 'MeanColumns', 'MaxColumns']: # Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
            for cl in [False]: # Selection: True: select layers per head, False: do not
                conf.append([ci, ce, cp, cl])
len(conf)

676

In [25]:
import time 
import math
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    
    file_name = save_path + 'HX_ALBERT_A_ATTENTION_NO_SOFTMAX_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []
    time_b = []
    time_b2 = []
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        prediction, _, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        attention = []
        # If in theory we had more layer groups we also needed to iterate over that and handle it
        for la in range(12):
            our_new_layer = []
            bob =  model.trainer.model.base_model.encoder.albert_layer_groups[0].albert_layers[0].attention
            has = hidden_states[la]
            aaa = bob.key(torch.tensor(has).to('cuda'))
            bbb = bob.query(torch.tensor(has).to('cuda'))
            for he in range(12):
                attention_scores = torch.matmul(bbb[:,he*64:(he+1)*64], aaa[:,he*64:(he+1)*64].transpose(-1, -2))
                attention_scores = attention_scores / math.sqrt(64)
                our_new_layer.append(attention_scores.cpu().detach().numpy())
            attention.append(our_new_layer)
        attention = np.array(attention)
        
        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r.append(time.time()-ts)
            kk = kk + 1
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)        
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [3:13:24<00:00,  5.80s/it]  


We present the results for the different attention setups

In [27]:
print_results(file_name+' (A)', conf, metrics, label_names)

  avg = a.mean(axis)


FTP
['Mean', 'Mean', 'From', False]  0.0 | 0.0284 -0.08269 0.05429
['Mean', 'Mean', 'To', False]  0.0 | 0.19108 -0.212 0.02092
['Mean', 'Mean', 'MeanColumns', False]  0.0 | 0.01946 -0.06486 0.0454
['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.10685 0.18509 -0.07824


  ret = ret.dtype.type(ret / rcount)


['Mean', 0, 'From', False]  0.0 | 0.02716 -0.0682 0.04104
['Mean', 0, 'To', False]  0.0 | 0.09303 -0.16969 0.07667
['Mean', 0, 'MeanColumns', False]  0.0 | 0.05255 -0.11391 0.06136
['Mean', 0, 'MaxColumns', False]  0.0 | -0.03572 0.03656 -0.00083
['Mean', 1, 'From', False]  0.0 | -0.10427 0.10454 -0.00027
['Mean', 1, 'To', False]  -0.0 | 0.13462 -0.16223 0.02761
['Mean', 1, 'MeanColumns', False]  0.0 | -0.05365 0.05528 -0.00163
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.1411 0.23119 -0.09009
['Mean', 2, 'From', False]  0.0 | 0.0489 -0.12307 0.07417
['Mean', 2, 'To', False]  0.0 | 0.18847 -0.25147 0.063
['Mean', 2, 'MeanColumns', False]  0.0 | 0.04131 -0.08113 0.03983
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.09091 0.16612 -0.0752
['Mean', 3, 'From', False]  0.0 | 0.036 -0.12353 0.08754
['Mean', 3, 'To', False]  0.0 | 0.11539 -0.17417 0.05878
['Mean', 3, 'MeanColumns', False]  0.0 | 0.06806 -0.12795 0.05989
['Mean', 3, 'MaxColumns', False]  -0.0 | -0.10297 0.20118 -0.09822
['Mean', 4

In [28]:
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.16203 | -0.12313 -0.24457 -0.1184
['Mean', 'Mean', 'To', False]  -0.2552 | 0.14658 -0.70106 -0.21112
['Mean', 'Mean', 'MeanColumns', False]  -0.12008 | -0.08443 -0.21037 -0.06545
['Mean', 'Mean', 'MaxColumns', False]  0.3133 | 0.12533 0.60371 0.21085
['Mean', 0, 'From', False]  -0.15296 | -0.12754 -0.18476 -0.14658
['Mean', 0, 'To', False]  -0.29575 | -0.12265 -0.56453 -0.20007
['Mean', 0, 'MeanColumns', False]  -0.20156 | -0.10669 -0.36284 -0.13515
['Mean', 0, 'MaxColumns', False]  0.0161 | -0.07508 0.15402 -0.03064
['Mean', 1, 'From', False]  0.13371 | -0.09069 0.38885 0.10297
['Mean', 1, 'To', False]  -0.11776 | 0.21686 -0.57762 0.00748
['Mean', 1, 'MeanColumns', False]  0.08316 | -0.03177 0.2155 0.06575
['Mean', 1, 'MaxColumns', False]  0.38585 | 0.12473 0.77125 0.26156
['Mean', 2, 'From', False]  -0.21101 | -0.12173 -0.3919 -0.11939
['Mean', 2, 'To', False]  -0.23395 | 0.12049 -0.77688 -0.04545
['Mean', 2, 'MeanColumns', False]  -0.10217 | -

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation script at this step)

In [29]:
print_results_ap(metrics, label_names, conf)

Baseline: 5.223943069787573e-09  and NZW: 1.0 and AUPRC: 0.4033626966683332
Max Across: 6.1304735315959036e-09  and NZW: 1.0 and AUPRC: 0.34079738846005264


  out=out, **kwargs)


Per Label Per Instance: 0.20570542822141635  and NZW:  1.0 and AUPRC: 0.40901653960225454
Per Instance: 9.995030304185726e-08  and NZW:  1.0 and AUPRC: 0.342002154061314


In [30]:
print_results_ap(metricsP, label_names, conf)

Baseline: -0.16203262902746185  and NZW: 1.0 and AUPRC: 0.4033626966683332
Max Across: 0.4132143707592697  and NZW: 1.0 and AUPRC: 0.4790332744680912
Per Label Per Instance: 0.1258216243330492  and NZW:  1.0 and AUPRC: 0.41837965774019664
Per Instance: 0.6082511899874095  and NZW:  1.0 and AUPRC: 0.4123446416980466
