## ESNLI alBERT multiclass
In this notebook, we examine the performance of interpretability techniques on the ESNLI dataset using ALBERT at the token level.

In [1]:
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from scipy.special import softmax
from helper import print_results, print_results_ap

We define the dataset and model paths, choose the transformer model, and state whether rationales are available in the dataset

In [2]:
# Relative locations used throughout the notebook:
#   data_path  - pickled dataset handed to Dataset(path=...)
#   model_path - directory passed to MyModel as the model location
#   save_path  - prefix under which the result pickles are written
data_path = '../datasets/esnli_multiclass.pickle'
model_path = 'Trained Models/'
save_path = 'Results/esnli_multiclass/'

In [3]:
# Identifiers handed to MyModel when the model is built.
model_name = 'albert'
dataset_name='esnli_albert_multiclass'
# Flag stating that the dataset ships with rationales.
existing_rationales = True

Load MyModel and its corresponding tokenizer

In [5]:
# Task configuration: single-label classification over three classes,
# with token-level (not sentence-level) explanations.
task = 'single_label'
sentence_level = False
labels = 3

# Two wrappers built from identical arguments except for the `attention` flag.
# `model` is used where attention / hidden states are needed; `model_no_attention`
# is used for the other explainers and for the evaluators.
model = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=True)
model_no_attention = MyModel(model_path, dataset_name, model_name, task, labels, cased=False, attention=False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer

# Report whether a GPU is visible and move the model there; fall back to the
# CPU instead of raising when CUDA is not available. `torch` is already
# imported in the notebook's import cell.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device = 'cuda' if cuda_available else 'cpu'
model.trainer.model.to(device)

True


AlbertForSequenceClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(30000, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768,

In [6]:
# Load the ESNLI multiclass data from the pickle at `data_path`.
# `dataset` is unpacked into its splits in the next cell; `label_names` is
# later passed to the explainer, the evaluators and the result printers.
esnli = Dataset(path = data_path)
dataset, label_names = esnli.load_esnli_multiclass()

In [7]:
# Unpack the seven-element dataset: texts and labels for the train / test /
# validation splits, plus the rationales that accompany the test split.
train_texts, test_texts, test_rationales, validation_texts, train_labels, test_labels, validation_labels = dataset

In [8]:
# Normalise every test rationale in place so that all three label slots hold a
# list of equal length: the annotated slot keeps its values (as a plain list),
# the other two slots are filled with zeros.
#
# Which slot is treated as annotated:
#   slots 0 and 1 empty -> slot 2
#   slots 0 and 2 empty -> slot 1
#   anything else       -> slot 0
#
# Emptiness is tested with len() rather than `== []`.
# NOTE(review): the list(...) conversions suggest the slots may be NumPy arrays;
# comparing an array with `[]` relies on deprecated elementwise-comparison
# behaviour - confirm against the dataset loader.
#
# NOTE: this cell is not idempotent. Once it has run, no slot is empty, so a
# second run takes the last branch and overwrites slots 1 and 2 with zeros.
for i in range(len(test_rationales)):
    entry = test_rationales[i]
    first_empty = len(entry[0]) == 0
    second_empty = len(entry[1]) == 0
    third_empty = len(entry[2]) == 0

    if first_empty and second_empty:
        source = 2
    elif first_empty and third_empty:
        source = 1
    else:
        source = 0

    annotated = list(entry[source])
    for slot in range(3):
        # A fresh zero list per slot so the slots never alias each other.
        entry[slot] = annotated if slot == source else [0] * len(annotated)

  
  import sys


In [9]:
# Alias (not a copy): both names refer to the same rationale object.
test_test_rationales = test_rationales

In [10]:
# Expand the word-level rationales to tokenizer-token level.
# Result: new_rationale[i][j] is a list of 0/1 flags for instance i and label j,
# with one flag per token the tokenizer produces for the space-separated words.
new_rationale = []
len_test = len(test_labels) # 2000
num_labels = len(np.unique(test_labels)) #3

for i in range(len_test):
    test_t = test_texts[i].split(' ')
    # Token count per word. This does not depend on the label, so it is
    # computed once per instance instead of once per (instance, label).
    pieces_per_word = [len(tokenizer.tokenize(word)) for word in test_t]

    rationale = []
    for j in range(num_labels):
        label_rational = []
        for k, n_pieces in enumerate(pieces_per_word):
            if n_pieces == 0:
                # The word yields no tokens, so its rationale value is never read.
                continue
            # Each token inherits the binarised rationale of its word.
            rationall = 1 if test_test_rationales[i][j][k] > 0 else 0
            label_rational.extend([rationall] * n_pieces)
        rationale.append(label_rational)
    new_rationale.append(rationale)

Then, we measure the performance of the model using accuracy and F1 score (both macro and micro)

In [11]:
# First element of each my_predict result, one per test text.
predictions = [model.my_predict(text)[0] for text in test_texts]

# Predicted class = index of the largest softmax probability.
pred_labels = [np.argmax(softmax(scores)) for scores in predictions]

# Displayed as (accuracy, macro F1, micro F1).
(
    accuracy_score(test_labels, pred_labels),
    f1_score(test_labels, pred_labels, average='macro'),
    f1_score(test_labels, pred_labels, average='micro'),
)

2000it [02:27, 13.77it/s]            

(0.7915, 0.7922547738739563, 0.7915)

2000it [02:40, 13.77it/s]

In [12]:
# Explainer bound to the model wrapper created with attention=False.
my_explainers = MyExplainer(label_names, model_no_attention) #model 2

# Two evaluators that differ only in `evaluation_level_all`:
#   my_evaluators  -> True  (results saved with the "(A)" suffix)
#   my_evaluatorsP -> False (results saved with the "(P)" suffix)
my_evaluators = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = True, tokenizer=tokenizer) #model 2
my_evaluatorsP = MyEvaluation(label_names, model_no_attention.my_predict, sentence_level = False, task = 'multi-class', evaluation_level_all = False, tokenizer=tokenizer) #model 2

# Metric name -> bound evaluator method, one table per evaluator.
evaluation =  {'F':my_evaluators.faithfulness, 'FTP': my_evaluators.faithful_truthfulness_penalty,
          'NZW': my_evaluators.nzw, 'AUPRC': my_evaluators.auprc}
# Every entry here uses my_evaluatorsP; AUPRC was previously bound to
# my_evaluators, which made the (P) table mix the two evaluators.
evaluationP = {'F':my_evaluatorsP.faithfulness, 'FTP': my_evaluatorsP.faithful_truthfulness_penalty,
          'NZW': my_evaluatorsP.nzw, 'AUPRC': my_evaluatorsP.auprc}

In [21]:
# Run Integrated Gradients (IG) on every test instance, score the explanations
# with both evaluators and pickle the results together with the timings.
import time
with warnings.catch_warnings():
    # RuntimeWarnings are silenced for the whole experiment loop.
    warnings.simplefilter("ignore", category=RuntimeWarning)

    # Output file prefix carries the current day_month_year.
    now = datetime.datetime.now()
    file_name = save_path + 'ESNLI_ALBERT_IG_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    # metrics / metricsP collect, per metric, one entry per instance; each entry
    # is a list with one value per technique.
    metrics = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'F':[], 'FTP':[], 'NZW':[], 'AUPRC' : []}
    time_r = []  # seconds spent producing each explanation
    my_explainers.neighbours = 2000  # NOTE(review): presumably only relevant for LIME - confirm
    techniques = [my_explainers.ig] #my_explainers.lime 
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        # Reset both evaluators before scoring a new instance.
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        # Only the first return value is kept. `_` ends up holding the third one
        # and is passed below wherever an argument slot is filled with a placeholder.
        prediction, _, _ = model_no_attention.my_predict(instance)
        # The instance is encoded as a batch of two identical texts; only the
        # first encoding is used for the attention mask and the token strings.
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        interpretations = []
        kk = 0
        for technique in techniques:
            ts = time.time()
            temp = technique(instance, prediction, tokens, mask, _, _)
            # Scale each returned vector by its maximum absolute value.
            interpretations.append([np.array(i)/np.max(abs(np.array(i))) for i in temp])
            time_r.append(time.time()-ts)
            kk = kk + 1
        # Score with the first evaluator table.
        for metric in metrics.keys():
            evaluated = []
            for interpretation in interpretations:
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metrics[metric].append(evaluated)
        # Hand a copy of the first evaluator's saved state to the second one,
        # then clear the first before scoring with the second table.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        my_evaluators.clear_states()
        for metric in metrics.keys():
            evaluatedP = []
            for interpretation in interpretations:
                evaluatedP.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
            metricsP[metric].append(evaluatedP)
# Persist both metric tables and the raw timings.
with open(file_name+'(A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'(P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# time_r.mean()
# time_r.mean(axis=1)

100%|██████████| 2000/2000 [53:14<00:00,  1.60s/it] 


We present the results for IG

In [22]:
# Can be fixed
# Explanation times in seconds, one entry per timed technique call in the IG loop.
print(time_r)

[1.08529377 0.92850542 0.93050051 ... 0.7705605  0.7933774  0.76258564]


In [23]:
# Report the IG scores collected in `metrics` (the "(A)" results).
print_results(file_name+'(A)', [' IG '], metrics, label_names) #[' LIME', ' IG  ']

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


F
 IG   0.035590000450611115 | 0.05574 0.11505 -0.06403
FTP
 IG   0.14647 | 0.16321 0.15109 0.12512
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.33724 | 0.40574 0.27728 0.32871


In [24]:
# Report the IG scores collected in `metricsP` (the "(P)" results).
print_results(file_name+'(P)', [' IG '], metricsP, label_names) #[' LIME', ' IG  ']

F
 IG   0.21689 | 0.2714 0.33006 0.04923
FTP
 IG   0.25859 | 0.33529 0.3828 0.05769
NZW
 IG   1.0 | 1.0 1.0 1.0
AUPRC
 IG   0.33724 | 0.40574 0.27728 0.32871


Then, we perform the experiments for the different attention setups!

In [10]:
# Enumerate every attention setup as a 4-element list [ci, ce, cp, cl].
# NOTE(review): ci / ce presumably select the head and layer aggregation - confirm
# which is which in MyExplainer.my_attention.
first_options = ['Mean', 'Multi'] + list(range(12))
second_options = ['Mean'] + list(range(12))
# Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
matrix_options = ['From', 'To', 'MeanColumns', 'MaxColumns']
# Selection: True: select layers per head, False: do not
selection_options = [False]

conf = [
    [ci, ce, cp, cl]
    for ci in first_options
    for ce in second_options
    for cp in matrix_options
    for cl in selection_options
]
len(conf)

728

In [14]:
# Evaluate every attention setup in `conf` on every test instance, score the
# resulting explanations with both evaluators and pickle results and timings.
import time 
with warnings.catch_warnings():
    # RuntimeWarnings are silenced for the whole experiment loop.
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    # Output file prefix carries the current day_month_year.
    file_name = save_path + 'ESNLI_ALBERT_ATTENTION_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    # Per metric: one entry per instance, each a list with one value per setup.
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []   # time_r[c] -> explanation times (s) of setup c, one per instance
    time_b = []   # per instance: total FTP scoring time with `evaluation`
    time_b2 = []  # per instance: total FTP scoring time with `evaluationP`
    for con in conf:
        time_r.append([])
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache() 
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        # Reset evaluator and explainer state before a new instance.
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        # The third return value lands in `_`, which is passed below wherever
        # an argument slot is filled with a placeholder.
        prediction, attention, _ = model.my_predict(instance)
        # The instance is encoded as a batch of two identical texts; only the
        # first encoding is used for the attention mask and the token strings.
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        interpretations = []
        kk = 0
        for con in conf:
            ts = time.time()
            # The setup is selected through the explainer's `config` attribute.
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            # Scale each returned vector by its maximum absolute value.
            interpretations.append([maxabs_scale(i) for i in temp])
            time_r[kk].append(time.time()-ts)
            kk = kk + 1
        # Score with the first evaluator table; `k` accumulates scoring time.
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        # Hand a copy of the first evaluator's saved state to the second one
        # before scoring with the second table.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)
# Persist both metric tables and the raw timings.
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# Timing summary: fastest / slowest / average mean time per setup, average
# total time per setup, and the mean per-instance FTP scoring times.
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [3:03:27<00:00,  5.50s/it]  


We present the results of the different attention setups

In [None]:
# Report the scores in `metrics` (the "(A)" results), labelled by attention setup.
print_results(file_name+' (A)', conf, metrics, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.0 | 0.00752 0.38622 -0.39374
['Mean', 'Mean', 'To', False]  -0.0 | -0.0859 0.45486 -0.36897
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.00342 0.27532 -0.2719


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.0741 0.26272 -0.18862
['Mean', 0, 'From', False]  -0.0 | 0.00487 0.38923 -0.3941
['Mean', 0, 'To', False]  -0.0 | -0.07279 0.44507 -0.37228
['Mean', 0, 'MeanColumns', False]  -0.0 | 0.00025 0.2827 -0.28295
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.01209 0.29677 -0.28467
['Mean', 1, 'From', False]  -0.0 | -0.01339 0.34741 -0.33402
['Mean', 1, 'To', False]  -0.0 | -0.08235 0.47149 -0.38915
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.02953 0.29809 -0.26856
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.03337 0.35032 -0.31695
['Mean', 2, 'From', False]  -0.0 | -0.01813 0.452 -0.43387
['Mean', 2, 'To', False]  -0.0 | -0.12118 0.30957 -0.18839
['Mean', 2, 'MeanColumns', False]  -0.0 | 0.00474 0.3034 -0.30814
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.0356 0.32019 -0.28459
['Mean', 3, 'From', False]  -0.0 | -0.06908 0.4045 -0.33542
['Mean', 3, 'To', False]  -0.0 | -0.05691 0.30327 -0.24636
['Mean', 3, 'MeanColumns', False]  -0.0 | -0.08019 0.2

In [None]:
# Report the scores in `metricsP` (the "(P)" results), labelled by attention setup.
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.61184 | 0.48452 1.23549 0.11552
['Mean', 'Mean', 'To', False]  0.58342 | 0.25213 1.37984 0.11828
['Mean', 'Mean', 'MeanColumns', False]  0.48123 | 0.40943 0.89639 0.13786
['Mean', 'Mean', 'MaxColumns', False]  0.35433 | 0.14947 0.7954 0.11811
['Mean', 0, 'From', False]  0.62487 | 0.50023 1.25195 0.12242
['Mean', 0, 'To', False]  0.61307 | 0.3088 1.38414 0.14629
['Mean', 0, 'MeanColumns', False]  0.51373 | 0.47208 0.92844 0.14069
['Mean', 0, 'MaxColumns', False]  0.50891 | 0.4306 0.95975 0.13637
['Mean', 1, 'From', False]  0.5209 | 0.3888 1.07462 0.09927
['Mean', 1, 'To', False]  0.59609 | 0.2641 1.4189 0.10527
['Mean', 1, 'MeanColumns', False]  0.43376 | 0.28295 0.91764 0.1007
['Mean', 1, 'MaxColumns', False]  0.50954 | 0.31529 1.09809 0.11524
['Mean', 2, 'From', False]  0.62759 | 0.41837 1.38343 0.08098
['Mean', 2, 'To', False]  0.38655 | 0.12897 0.89142 0.13925
['Mean', 2, 'MeanColumns', False]  0.53346 | 0.47002 0.99128 0.13908
['Mean', 2, 'Max

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation at this step)

In [15]:
# Summarise the best attention setups for the "(A)" results; the output lists
# Baseline, Max Across, Per Label Per Instance and Per Instance selections.
print_results_ap(metrics, label_names, conf)

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Baseline: -7.374841293206202e-10  and NZW: 1.0 and AUPRC: 0.6020347240644445
Max Across: -2.528158978461666e-10  and NZW: 0.9998916666666666 and AUPRC: 0.5092589421414682
Per Label Per Instance: 0.2646716823268673  and NZW:  0.9988386072449177 and AUPRC: 0.31096573110917985
Per Instance: 4.874575742127204e-08  and NZW:  0.998990825035562 and AUPRC: 0.2596425935642796


In [16]:
# Summarise the best attention setups for the "(P)" results; the output lists
# Baseline, Max Across, Per Label Per Instance and Per Instance selections.
print_results_ap(metricsP, label_names, conf)

Baseline: 0.6118434526199982  and NZW: 1.0 and AUPRC: 0.6020347240644445
Max Across: 0.6635175797975238  and NZW: 1.0 and AUPRC: 0.6037004655488016
Per Label Per Instance: 0.8633359336237785  and NZW:  0.997529437930103 and AUPRC: 0.563042594072341


  out=out, **kwargs)


Per Instance: 0.8633359336237785  and NZW:  0.997529437930103 and AUPRC: 0.4381637140602424


We repeat the process with attention scores that may take negative values (A*), i.e., by skipping the softmax function. In the attention setups, we exclude the multiplication option for heads and layers, as a few combinations reach +/-inf

In [17]:
# Enumerate the attention setups for the no-softmax run as [ci, ce, cp, cl].
# Compared with the previous enumeration, 'Multi' is not offered for ci.
first_options = ['Mean'] + list(range(12))
second_options = ['Mean'] + list(range(12))
# Matrix: From, To, MeanColumns, MeanRows, MaxColumns, MaxRows
matrix_options = ['From', 'To', 'MeanColumns', 'MaxColumns']
# Selection: True: select layers per head, False: do not
selection_options = [False]

conf = [
    [ci, ce, cp, cl]
    for ci in first_options
    for ce in second_options
    for cp in matrix_options
    for cl in selection_options
]
len(conf)

676

In [18]:
# Same experiment as the attention run, but the attention matrices are rebuilt
# from the hidden states as raw scaled dot-product scores (no softmax), so they
# may contain negative values.
import time
import math
with warnings.catch_warnings():
    # RuntimeWarnings are silenced for the whole experiment loop.
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    # Output file prefix carries the current day_month_year.
    file_name = save_path + 'ESNLI_ALBERT_A_ATTENTION_NO_SOFTMAX_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    # Per metric: one entry per instance, each a list with one value per setup.
    metrics = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    metricsP = {'FTP':[], 'F':[], 'NZW':[], 'AUPRC' : []}
    time_r = []   # time_r[c] -> explanation times (s) of setup c, one per instance
    time_b = []   # per instance: total FTP scoring time with `evaluation`
    time_b2 = []  # per instance: total FTP scoring time with `evaluationP`
    for con in conf:
        time_r.append([])

    # The same attention module (layer group 0, layer 0) is used for every
    # layer index, so it is looked up once instead of inside the loops.
    attention_module = model.trainer.model.base_model.encoder.albert_layer_groups[0].albert_layers[0].attention
    # Run the projections on whatever device the module's weights live on.
    attention_device = attention_module.key.weight.device

    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = new_rationale[ind]
        instance = test_texts[ind]
        # Reset evaluator and explainer state before a new instance.
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        my_explainers.save_states = {}
        # The second return value lands in `_`, which is passed below wherever
        # an argument slot is filled with a placeholder.
        prediction, _, hidden_states = model.my_predict(instance)
        # The instance is encoded as a batch of two identical texts; only the
        # first encoding is used for the attention mask and the token strings.
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        # attention[la][he] = Q_he @ K_he^T / sqrt(64) for layer input `la`.
        # 12 heads x 64 features span the 768-wide query/key projections.
        # NOTE(review): assumes hidden_states[la] is 2-D (tokens, 768) - confirm in my_predict.
        attention = []
        with torch.no_grad():  # inference only, no autograd graph needed
            for la in range(12):
                layer_input = torch.tensor(hidden_states[la]).to(attention_device)
                keys = attention_module.key(layer_input)
                queries = attention_module.query(layer_input)
                our_new_layer = []
                for he in range(12):
                    head = slice(he*64, (he+1)*64)
                    attention_scores = torch.matmul(queries[:, head], keys[:, head].transpose(-1, -2))
                    attention_scores = attention_scores / math.sqrt(64)
                    our_new_layer.append(attention_scores.cpu().detach().numpy())
                attention.append(our_new_layer)
        attention = np.array(attention)

        interpretations = []
        for kk, con in enumerate(conf):
            ts = time.time()
            # The setup is selected through the explainer's `config` attribute.
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens, mask, attention, _)
            # Scale each returned vector by its maximum absolute value.
            interpretations.append([maxabs_scale(i) for i in temp])
            # Record the time under its own setup. Appending to time_r itself
            # would mix floats with the per-setup lists and break the summary below.
            time_r[kk].append(time.time()-ts)
        # Score with the first evaluator table; `k` accumulates scoring time.
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluation[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b.append(k)
            metrics[metric].append(evaluated)
        # Hand a copy of the first evaluator's saved state to the second one
        # before scoring with the second table.
        my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
        for metric in metrics.keys():
            evaluated = []
            k = 0
            for interpretation in interpretations:
                tt = time.time()
                evaluated.append(evaluationP[metric](interpretation, _, instance, prediction, tokens, _, _, test_rational))
                k = k + (time.time()-tt)
            if metric == 'FTP':
                time_b2.append(k)
            metricsP[metric].append(evaluated)
# Persist both metric tables and the raw timings.
with open(file_name+' (A).pickle', 'wb') as handle:
    pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+' (P).pickle', 'wb') as handle:
    pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(file_name+'_TIME.pickle', 'wb') as handle:
    pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
time_r = np.array(time_r)
# Timing summary: fastest / slowest / average mean time per setup, average
# total time per setup, and the mean per-instance FTP scoring times.
time_r.mean(axis=1).min(),time_r.mean(axis=1).max(), time_r.mean(axis=1).mean(), time_r.sum(axis=1).mean(), np.mean(time_b), np.mean(time_b2)

100%|██████████| 2000/2000 [3:08:48<00:00,  5.66s/it]  


We present the results for the different attention setups

In [None]:
# Report the no-softmax scores in `metrics` (the "(A)" results), labelled by setup.
print_results(file_name+' (A)', conf, metrics, label_names)

FTP
['Mean', 'Mean', 'From', False]  -0.0 | -0.05917 0.27948 -0.22032
['Mean', 'Mean', 'To', False]  -0.0 | -0.09121 0.44959 -0.35838
['Mean', 'Mean', 'MeanColumns', False]  -0.0 | -0.06727 0.24761 -0.18034
['Mean', 'Mean', 'MaxColumns', False]  -0.0 | -0.06686 0.25434 -0.18748


  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


['Mean', 0, 'From', False]  0.0 | 0.0638 -0.27731 0.21351
['Mean', 0, 'To', False]  -0.0 | -0.07447 0.45224 -0.37777
['Mean', 0, 'MeanColumns', False]  -0.0 | -0.0705 0.27981 -0.2093
['Mean', 0, 'MaxColumns', False]  -0.0 | -0.06695 0.31228 -0.24533
['Mean', 1, 'From', False]  -0.0 | -0.06799 0.31561 -0.24762
['Mean', 1, 'To', False]  -0.0 | -0.0762 0.47127 -0.39507
['Mean', 1, 'MeanColumns', False]  -0.0 | -0.07058 0.28918 -0.2186
['Mean', 1, 'MaxColumns', False]  -0.0 | -0.06369 0.23973 -0.17604
['Mean', 2, 'From', False]  -0.0 | -0.07921 0.43442 -0.35521
['Mean', 2, 'To', False]  -0.0 | -0.07856 0.36256 -0.284
['Mean', 2, 'MeanColumns', False]  -0.0 | -0.07681 0.23676 -0.15996
['Mean', 2, 'MaxColumns', False]  -0.0 | -0.09018 0.26926 -0.17908
['Mean', 3, 'From', False]  0.0 | 0.02072 -0.08275 0.06203
['Mean', 3, 'To', False]  -0.0 | -0.06593 0.24501 -0.17908
['Mean', 3, 'MeanColumns', False]  -0.0 | -0.05362 0.16102 -0.1074
['Mean', 3, 'MaxColumns', False]  -0.0 | -0.06013 0.22153 -

In [None]:
# Report the no-softmax scores in `metricsP` (the "(P)" results), labelled by setup.
print_results(file_name+' (P)', conf, metricsP, label_names)

FTP
['Mean', 'Mean', 'From', False]  0.32483 | 0.0968 0.8258 0.05188
['Mean', 'Mean', 'To', False]  0.55051 | 0.21268 1.33886 0.1
['Mean', 'Mean', 'MeanColumns', False]  0.30008 | 0.08951 0.72948 0.08125
['Mean', 'Mean', 'MaxColumns', False]  0.32143 | 0.1127 0.75724 0.09434
['Mean', 0, 'From', False]  -0.38242 | -0.19614 -0.84183 -0.10929
['Mean', 0, 'To', False]  0.58499 | 0.27866 1.36593 0.11037
['Mean', 0, 'MeanColumns', False]  0.36409 | 0.1511 0.84048 0.10068
['Mean', 0, 'MaxColumns', False]  0.39923 | 0.17056 0.93089 0.09624
['Mean', 1, 'From', False]  0.3728 | 0.12118 0.92427 0.07294
['Mean', 1, 'To', False]  0.57879 | 0.24426 1.40602 0.0861
['Mean', 1, 'MeanColumns', False]  0.34082 | 0.0979 0.84795 0.07662
['Mean', 1, 'MaxColumns', False]  0.28942 | 0.08622 0.70497 0.07706
['Mean', 2, 'From', False]  0.51546 | 0.18724 1.28281 0.07633
['Mean', 2, 'To', False]  0.44596 | 0.16121 1.08019 0.09646
['Mean', 2, 'MeanColumns', False]  0.29761 | 0.09937 0.6941 0.09936
['Mean', 2, 'Max

We calculate the best attention setup using Optimus variations (we do not use the Optimus implementation script at this step)

In [19]:
# Summarise the best no-softmax attention setups for the "(A)" results.
print_results_ap(metrics, label_names, conf)

  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)


Baseline: 2.269624895430885e-09  and NZW: 1.0 and AUPRC: 0.530555284711656
Max Across: 3.784789111745586e-09  and NZW: 1.0 and AUPRC: 0.2379534509245289
Per Label Per Instance: 0.6161406672643608  and NZW:  1.0 and AUPRC: 0.3132952088979517
Per Instance: 1.1065884920544167e-07  and NZW:  1.0 and AUPRC: 0.282757079369867


In [20]:
# Summarise the best no-softmax attention setups for the "(P)" results.
print_results_ap(metricsP, label_names, conf)

Baseline: -0.22705476316599915  and NZW: 1.0 and AUPRC: 0.530555284711656
Max Across: 0.5849874408966941  and NZW: 1.0 and AUPRC: 0.4234155651651186
Per Label Per Instance: 0.8908513292088815  and NZW:  1.0 and AUPRC: 0.5077940669674452


  out=out, **kwargs)


Per Instance: 0.8908513292088815  and NZW:  1.0 and AUPRC: 0.43849421662844446
