## Test Interpretability techniques in HateXplain with BERT

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score, f1_score, accuracy_score
from scipy.special import softmax
from dataset import Dataset
from myModel import MyModel
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
from sklearn.preprocessing import maxabs_scale
import pickle
from tqdm import tqdm
import time
import torch
import datetime
import csv
import warnings

In [None]:
# Filesystem locations used by the notebook (absolute paths).
data_path = '/models/'  # handed to Dataset(path=...) when the data is loaded
model_path = '/models/'  # first argument of the MyModel constructor
save_path = '/results/HX/'  # prefix of the result files written by later cells

In [None]:
model_name = 'bert'  # forwarded to MyModel as its third argument
existing_rationales = True  # when True, the rationales of the test split are collected

In [None]:
# Build the project model wrapper and expose its tokenizer to the cells below.
# NOTE(review): the meaning of the trailing False flag is defined in
# myModel.py - confirm there.
task = 'single_label'
labels = 2
model = MyModel(model_path, 'bert_hx', model_name, task, labels, False)
tokenizer = model.tokenizer
max_sequence_len = tokenizer.max_len_single_sentence

In [None]:
# Load HateXplain through the project Dataset helper. The tokenizer is handed
# to the loader; how it is used is defined in dataset.py, not shown here.
hx = Dataset(path=data_path)
x, y, label_names, rationales = hx.load_hatexplain(tokenizer)

In [None]:
# Stratified 80/20 split. The index array travels through the split so that
# the rationales belonging to the test instances can be looked up afterwards.
# `_` receives the training indices; later cells pass `_` on to the
# evaluators, so the name has to stay bound here.
indices = np.arange(len(y))
(train_texts, test_texts,
 train_labels, test_labels,
 _, test_indexes) = train_test_split(
     x, y, indices, stratify=y, test_size=.2, random_state=42)
if existing_rationales:
    test_rationales = [rationales[position] for position in test_indexes]

# Carve a validation set worth 10% of the full dataset out of the training part.
size = (0.1 * len(y)) / len(train_labels)
train_texts, validation_texts, train_labels, validation_labels = train_test_split(
    list(train_texts),
    train_labels,
    stratify=train_labels,
    test_size=size,
    random_state=42)


In [None]:
# Wrap every test rationale as the two-element list [0, rationale].
test_test_rationales = [[0, rationale] for rationale in test_rationales]


In [None]:
# Keep only the first element of what model.my_predict returns for each test
# text (the evaluation cells unpack that same return value as
# prediction, attention, hidden_states).
predictions = [model.my_predict(text)[0] for text in test_texts]


In [None]:
# Predicted class index per instance: argmax over the softmax of the scores.
pred_labels = [np.argmax(softmax(scores)) for scores in predictions]


def average_precision_wrapper(y, y_pred, view):
    """Return the average precision of ``y_pred`` against ``y``.

    ``y_pred`` must provide ``.toarray()``; the dense result is what is
    scored. ``view`` is passed through as the ``average`` argument of
    ``average_precision_score``.
    """
    dense_scores = y_pred.toarray()
    return average_precision_score(y, dense_scores, average=view)


# Test-set scores, printed on one line in this order:
# macro average precision, accuracy, macro F1, binary F1.
macro_ap = average_precision_score(test_labels, pred_labels, average='macro')
accuracy = accuracy_score(test_labels, pred_labels)
macro_f1 = f1_score(test_labels, pred_labels, average='macro')
binary_f1 = f1_score(test_labels, pred_labels, average='binary')
print(macro_ap, accuracy, macro_f1, binary_f1)


In [None]:
# Explainer and evaluator helpers shared by every experiment below.
# NOTE(review): the meaning of the trailing False flag is defined in
# myEvaluation.py - confirm there.
my_explainers = MyExplainer(label_names, model)
my_evaluators = MyEvaluation(label_names, model.my_predict, False)

In [None]:
def print_results(name, techniques, metrics, class_names=None):
    """Write per-technique metric averages to ``name + '.csv'`` and print them.

    For every metric and technique, the scores of each label are averaged
    over the instances (NaN entries are ignored; a label without any valid
    score counts as 0), and the label means are then averaged into one
    overall value. One CSV row is written per (metric, technique) pair:
    technique, metric, overall mean, then one mean per label.

    Args:
        name: Output path without the ``.csv`` extension.
        techniques: One display value per technique, in the order of the
            second axis of the metric data.
        metrics: Mapping from metric name to a nested sequence indexed as
            ``[instance][technique][label]``.
        class_names: Sequence with one entry per label; only its length is
            used. Defaults to the module-level ``label_names``.
    """
    if class_names is None:
        class_names = label_names
    n_labels = len(class_names)

    # The csv module expects its file object to be opened with newline='';
    # otherwise platforms that translate '\n' produce blank rows.
    with open(name + '.csv', 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        for metric, values in metrics.items():
            print(metric)
            scores = np.asarray(values, dtype=float)
            for i, technique in enumerate(techniques):
                label_means = []
                for label in range(n_labels):
                    column = scores[:, i, label]
                    # Drop NaN scores directly instead of comparing string
                    # representations against a freshly computed NaN.
                    valid = column[~np.isnan(column)]
                    label_means.append(valid.mean() if valid.size else 0.0)
                overall = np.array(label_means).mean()
                writer.writerow([technique, metric, overall] + label_means)
                print(
                    technique, ' {} | {}'.format(
                        round(overall, 5),
                        ' '.join(str(round(m, 5)) for m in label_means)))


## Evaluating LIME and IG

In [None]:
# Evaluate the LIME and IG explanations of every test instance with the four
# metrics below. RuntimeWarnings raised while doing so are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    # Result file prefix: save_path + 'hx_BERT' + day_month_year (not zero-padded).
    file_name = save_path + 'hx_BERT' + '{}_{}_{}'.format(now.day, now.month, now.year)
    # metrics[name] collects one list per instance holding one entry per technique.
    metrics = {'F': [], 'FTP': [], 'NZW': [], 'AUPRC': []}
    evaluation = {
        'F': my_evaluators.faithfulness,
        'FTP': my_evaluators.faithful_truthfulness_penalty,
        'NZW': my_evaluators.nzw,
        'AUPRC': my_evaluators.auprc
    }

    techniques = [my_explainers.lime, my_explainers.ig]
    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = test_test_rationales[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        prediction, attention, hidden_states = model.my_predict(instance)
        # The instance is tokenized as a batch of two identical texts and the
        # first encoding of that batch is used.
        enc = model.tokenizer([instance, instance],
                              truncation=True,
                              padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        # Each element returned by a technique is rescaled with maxabs_scale.
        interpretations = [
            [maxabs_scale(i) for i in technique(instance, prediction, tokens,
                                                mask, attention, hidden_states)]
            for technique in techniques
        ]

        # `_` is whatever that name is bound to at module level (the split
        # cell binds it to the discarded training indices).
        # NOTE(review): confirm the evaluators ignore these two positional
        # arguments.
        for metric in metrics:
            metrics[metric].append([
                evaluation[metric](interpretation, _, instance, prediction,
                                   tokens, hidden_states, _, test_rational)
                for interpretation in interpretations
            ])

        # The pickle is rewritten after every instance - presumably as a
        # checkpoint for long runs.
        with open(file_name + '.pickle', 'wb') as handle:
            pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Display names, in the same order as the techniques evaluated above.
technique_labels = [' LIME', ' IG  ']
print_results(file_name, technique_labels, metrics)

# Testing the response time

In [None]:
# Time LIME and IG on (up to) the first 50 test instances.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    techniques = [my_explainers.lime, my_explainers.ig]
    # times[0]: token count per instance; times[1 + j]: seconds spent in
    # techniques[j] for that instance.
    times = [[] for slot in range(len(techniques) + 1)]
    # min(...) keeps the loop valid when the test set has fewer than 50 items.
    for ind in tqdm(range(min(50, len(test_texts)))):
        instance = test_texts[ind]
        prediction, attention, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance, instance],
                              truncation=True,
                              padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        times[0].append(len(tokens))
        for slot, technique in enumerate(techniques, start=1):
            # perf_counter is the clock intended for measuring short
            # durations; time.time() follows the adjustable wall clock.
            ts = time.perf_counter()
            technique(instance, prediction, tokens, mask, attention,
                      hidden_states)
            times[slot].append(time.perf_counter() - ts)
times

In [None]:
# Mean of times[2], i.e. the timings of the second technique (my_explainers.ig).
np.mean(times[2])

## Attention

In [None]:
# Every attention configuration is a list [first, second, matrix, selection].
# NOTE(review): the original comments label the first field "Layers", the
# third "Matrix" and the fourth "Selection" (select layers per head or not);
# the exact semantics live in my_explainers.my_attention - confirm there.
indexed = list(range(12))  # numeric options 0..11
aggregations = ['Mean', 'Multi', 'Sum']
matrix_views = ['From', 'To', 'MeanColumns', 'MeanRows', 'MaxColumns', 'MaxRows']

# Full grid with selection disabled: 15 * 14 * 6 = 1260 configurations.
conf = [
    [first, second, view, False]
    for first in aggregations + indexed
    for second in ['Mean', 'Sum'] + indexed
    for view in matrix_views
]
# Selection enabled: only the first field varies; the other two stay empty.
conf += [[first, '', '', True] for first in aggregations]

1263

In [None]:
# Evaluate the attention-based explanation of every configuration in conf on
# every test instance. RuntimeWarnings raised while doing so are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)

    now = datetime.datetime.now()

    # Result file prefix: save_path + 'hx_BERT_Attention_' + day_month_year
    # (not zero-padded).
    file_name = save_path + 'hx_BERT_Attention_' + '{}_{}_{}'.format(
        now.day, now.month, now.year)
    # metrics[name] collects one list per instance holding one entry per
    # configuration.
    metrics = {'F': [], 'FTP': [], 'NZW': [], 'AUPRC': []}
    evaluation = {
        'F': my_evaluators.faithfulness,
        'FTP': my_evaluators.faithful_truthfulness_penalty,
        'NZW': my_evaluators.nzw,
        'AUPRC': my_evaluators.auprc
    }

    for ind in tqdm(range(len(test_texts))):
        torch.cuda.empty_cache()
        test_rational = test_test_rationales[ind]
        instance = test_texts[ind]
        my_evaluators.clear_states()
        prediction, attention, hidden_states = model.my_predict(instance)
        # The instance is tokenized as a batch of two identical texts and the
        # first encoding of that batch is used.
        enc = model.tokenizer([instance, instance],
                              truncation=True,
                              padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens

        interpretations = []
        for con in conf:
            # The explainer reads its configuration from this attribute.
            my_explainers.config = con
            temp = my_explainers.my_attention(instance, prediction, tokens,
                                              mask, attention, hidden_states)
            interpretations.append([maxabs_scale(i) for i in temp])

        # `_` is whatever that name is bound to at module level (the split
        # cell binds it to the discarded training indices).
        # NOTE(review): confirm the evaluators ignore these two positional
        # arguments.
        for metric in metrics:
            metrics[metric].append([
                evaluation[metric](interpretation, _,
                                   instance, prediction,
                                   tokens, hidden_states, _,
                                   test_rational)
                for interpretation in interpretations
            ])

        # The pickle is rewritten after every instance - presumably as a
        # checkpoint for long runs.
        with open(file_name + '.pickle', 'wb') as handle:
            pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Summarise the attention run: print_results emits one CSV row per entry of
# conf for every metric, using the configuration list itself as the label.
print_results(file_name, conf, metrics)  #New my explain 2

# Calculating the response time

In [None]:
import time
# Time the attention explainer on (up to) the first 50 test instances.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    # times[0]: token count per instance; times[1]: seconds needed to run
    # ALL configurations in conf for that instance. The remaining lists are
    # allocated (len(conf) + 1 in total, as before) but are never filled.
    times = [[] for slot in range(len(conf) + 1)]

    # min(...) keeps the loop valid when the test set has fewer than 50 items.
    for ind in tqdm(range(min(50, len(test_texts)))):
        instance = test_texts[ind]
        prediction, attention, hidden_states = model.my_predict(instance)
        enc = model.tokenizer([instance, instance],
                              truncation=True,
                              padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        times[0].append(len(tokens))
        # perf_counter is the clock intended for measuring short durations;
        # time.time() follows the adjustable wall clock.
        ts = time.perf_counter()
        for con in conf:
            my_explainers.config = con
            my_explainers.my_attention(instance, prediction, tokens,
                                       mask, attention, hidden_states)
        times[1].append(time.perf_counter() - ts)


In [None]:
# Mean of times[1]: average seconds per instance for running every
# configuration in conf.
np.mean(times[1])

0.07553523540496826