In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_metric, Dataset, DatasetDict
from einops import rearrange
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)

from custom_bert import BertForSequenceClassification
from helper import VizHelper

comet_ml is installed but `COMET_API_KEY` is not set.


In [4]:
# Fixed token length of every encoded example: preprocess_text passes this value
# as `max_length` together with padding="max_length" and truncation=True.
max_seq_length = 128

In [11]:
# Directory (relative to the working directory) that is put on sys.path in the
# next cell. The saved debugger output of exp.get_soc shows hiex/soc_api.py being
# executed from a directory with this name.
SOC_DIR = "./contextualizing-hate-speech-models-with-explanations/"

In [13]:
# Put SOC_DIR on the import path so its modules can be imported.
# The membership test keeps the cell idempotent: re-running it does not append
# the same entry a second time. (`sys` is imported in the notebook's import cell.)
if SOC_DIR not in sys.path:
    sys.path.append(SOC_DIR)

# AMI18

In [5]:
# Local checkpoint directory loaded by both model classes further down
# (presumably BERT fine-tuned on AMI18, going by the directory name — TODO confirm).
model_name = "./bert-base-cased_ami18/"
# Identifier passed to AutoTokenizer.from_pretrained for the tokenizer.
tokenizer_name = "bert-base-cased"

In [6]:
# Notebook-level tokenizer: used by preprocess_text and handed to VizHelper.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def preprocess_text(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook-level tokenizer.

    The tokenizer is called with ``padding="max_length"``, ``truncation=True`` and
    ``max_length=max_seq_length``, so every encoded sequence has the same length.

    Args:
        examples: Mapping with a ``"text"`` entry (a batch when used with
            ``map(..., batched=True)``).

    Returns:
        Whatever the tokenizer call returns for the given texts.
    """
    texts = examples["text"]
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_seq_length,
    }
    return tokenizer(texts, **encode_options)

In [7]:
# Read the three splits from tab-separated files.
train = pd.read_csv("data/miso_train.tsv", sep="\t")
validation = pd.read_csv("data/miso_dev.tsv", sep="\t")
test = pd.read_csv("data/miso_test.tsv", sep="\t")

# Wrap the frames as a DatasetDict and expose the target column as "label".
raw_datasets = DatasetDict(
    train=Dataset.from_pandas(train),
    validation=Dataset.from_pandas(validation),
    test=Dataset.from_pandas(test),
)
raw_datasets = raw_datasets.rename_column("misogynous", "label")

# Tokenize every split. `remove_columns` expects column names, so pass the
# explicit list rather than the Features mapping. Every original column is
# removed -- including "label" -- so proc_datasets only holds tokenizer outputs.
proc_datasets = raw_datasets.map(
    preprocess_text,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
# Return PyTorch tensors when indexing the processed datasets.
proc_datasets.set_format("pt")

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [8]:
# Load the same checkpoint through two classes and call .eval() on each:
# - `model`: the stock transformers sequence-classification class, handed to VizHelper.
# - `effective_model`: the project-local BertForSequenceClassification from custom_bert,
#   passed as `effective_model=` to exp.compare_attentions later in the notebook
#   (what it computes differently is not visible here — TODO confirm in custom_bert).
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
effective_model = BertForSequenceClassification.from_pretrained(model_name).eval()

*** Calling custom BertForSequenceClassification ***


In [9]:
# Helper object used by every analysis cell below; it receives the stock model,
# the tokenizer, and both the raw and the tokenized test split.
exp = VizHelper(model, tokenizer, raw_datasets["test"], proc_datasets["test"])

Attention

In [None]:
# Show attention for example 21 at head 3, layer 10
# (whether head/layer are zero-based is decided inside VizHelper — TODO confirm).
exp.show_attention(idx=21, head=3, layer=10)

In [None]:
# Same example, head and layer as the previous cell, using VizHelper's
# "effective attention" view for side-by-side reading.
exp.show_effective_attention(idx=21, head=3, layer=10)

In [None]:
# Build the attention comparison figure for example 21, head 7, layer 2.
idx, head, layer = 21, 7, 2
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories; make sure plots/ exists so the
# cell also works on a fresh checkout.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

In [None]:
# Build the attention comparison figure for example 21, head 1, layer 9.
idx, head, layer = 21, 1, 9
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories; make sure plots/ exists so the
# cell also works on a fresh checkout.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

Classification

In [None]:
# Run VizHelper.classify on example 21 (the saved notebook holds no output for
# this cell, so the return format is not visible here).
exp.classify(21)

In [None]:
# Attribution table for example 21. In the saved output of the later
# "Final table" cells, compute_table yields one row per method
# (G, GxI, IntegratedGradients, KernelSHAP, SOC) and one column per token.
results = exp.compute_table(21)

In [None]:
# Export the table to an Excel file in the current working directory.
results.to_excel("table.xlsx")

In [None]:
# Build the attention comparison figure for example 21, head 1, layer 2, this time
# passing the custom-BERT `effective_model` and a larger font size.
idx, head, layer = 21, 1, 2
fig = exp.compare_attentions(idx, head, layer, effective_model=effective_model, fontsize=18)
# savefig does not create missing directories; make sure plots/ exists so the
# cell also works on a fresh checkout.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

In [None]:
# get_gradient returns two values for example 21; going by the names they are
# the gradient and the input embeddings — TODO confirm in helper.VizHelper.
grad, embeds = exp.get_gradient(21)

Gradients

In [None]:
# Same call as the `grad, embeds = ...` cell above, left as the last expression
# so the notebook displays the raw return value.
exp.get_gradient(21)

Final table

In [52]:
%%capture
# Recompute the attribution table for example 21; the %%capture magic on the
# first line of this cell hides whatever the explainers print while running.
table = exp.compute_table(21)

In [53]:
# Display the table: one row per attribution method, one column per token.
table

tokens,[CLS],i,miss,my,stupid,pretty,s,##tan,##k,dumb,whore,s.1,##kan,##k.1,trick,bitch,ass,friends,[SEP]
G,0.047945,0.013699,-0.061644,-0.10274,-0.082192,0.013699,0.006849,-0.027397,-0.030822,-0.020548,0.171233,0.023973,-0.054795,-0.054795,-0.089041,0.0,0.041096,0.006849,0.150685
GxI,0.004382,-0.055667,-0.033224,-0.034306,-0.013459,0.003522,0.011157,0.0192,0.024633,-0.000851,-0.003492,0.014394,0.083351,0.061475,-0.075811,-0.113489,-0.201125,-0.139274,0.107186
IntegratedGradients,-0.02481,-0.04167,0.079333,-0.055123,-0.076697,0.03804,-0.047656,-0.126776,-0.033091,-0.112115,0.030842,-0.045686,-0.031556,-0.034987,-0.08623,-0.090326,-0.030227,0.002293,-0.012541
KernelSHAP,0.099068,0.082946,-0.047284,-0.036151,0.059429,0.043475,0.019928,-0.00811,-0.058207,0.053965,0.110713,-0.097884,0.087608,0.093683,0.017766,-0.035715,0.034638,0.007814,0.344866
SOC,0.05681,-0.00285,0.001385,0.00559,0.046292,0.083087,0.015206,-0.022462,-0.002206,0.016573,0.16212,0.026735,0.011044,-0.002935,0.008057,0.320324,0.103327,-0.034178,0.078819


In [54]:
# Sum each method's scores across all tokens (one total per row).
table.sum(axis=1)

G                     -0.047945
GxI                   -0.341400
IntegratedGradients   -0.698984
KernelSHAP             0.772549
SOC                    0.870739
dtype: float64

In [55]:
# Export the final table for example 21 to the current working directory.
# The original f-string had no placeholders, so a plain literal is used.
table.to_excel("table_21.xlsx")

SHAP

In [None]:
# KernelSHAP attributions for example 21.
# NOTE(review): the variable name `shap` would shadow the `shap` package if that
# package is ever imported in this notebook.
shap = exp.get_kernel_shap(21)

In [67]:
# Integrated Gradients attributions for example 21.
ig = exp.get_integrated_gradients(21)

In [68]:
# Display the result: the saved output is a float64 tensor with 19 scores,
# matching the 19 tokens ([CLS] ... [SEP]) of the example.
ig

tensor([-0.0248, -0.0417,  0.0793, -0.0551, -0.0767,  0.0380, -0.0477, -0.1268,
        -0.0331, -0.1121,  0.0308, -0.0457, -0.0316, -0.0350, -0.0862, -0.0903,
        -0.0302,  0.0023, -0.0125], dtype=torch.float64,
       grad_fn=<DivBackward0>)

In [44]:
# SOC attributions for example 21.
# NOTE(review): the saved output shows this call stopping in ipdb at a leftover
# breakpoint() in hiex/soc_api.py (line 90) and ending with BdbQuit. Remove that
# breakpoint before running the notebook non-interactively.
exp.get_soc(21)

> /home/dauin_user/gattanasio/benchmarking-xai-private/contextualizing-hate-speech-models-with-explanations/hiex/soc_api.py(91)word_level_explanation_bert()
     89 
     90         breakpoint()
---> 91         self.output_file.write(s + '\n')
     92 
     93     def hierarchical_explanation_bert(self, input_ids, input_mask, segment_ids, label=None):


ipdb>  inp


array([  101,   178,  5529,  1139,  4736,  2785,   188,  5108,  1377,
       14908, 20239,   188,  8752,  1377,  7959,  7979,  3919,  2053,
         102,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0]

ipdb>  s


--Return--
None
> /home/dauin_user/gattanasio/benchmarking-xai-private/contextualizing-hate-speech-models-with-explanations/hiex/soc_api.py(91)word_level_explanation_bert()
     89 
     90         breakpoint()
---> 91         self.output_file.write(s + '\n')
     92 
     93     def hierarchical_explanation_bert(self, input_ids, input_mask, segment_ids, label=None):


ipdb>  print(s


*** SyntaxError: unexpected EOF while parsing


ipdb>  print(s)


i 0.000642	 miss -0.043734	 my 0.076152	 stupid 0.214686	 pretty 0.332197	 s 0.267287	 ##tan -0.088214	 ##k 0.006056	 dumb 0.152184	 whore 0.883939	 s 0.206264	 ##kan 0.023679	 ##k 0.035208	 trick 0.035117	 bitch 1.342909	 ass 0.341583	 friends -0.251832	


ipdb>  scores


[0.0006419807905331254, -0.043733783066272736, 0.07615158706903458, 0.21468639373779297, 0.3321974277496338, 0.26728692650794983, -0.08821425586938858, 0.006056016776710749, 0.1521843671798706, 0.8839393854141235, 0.206263929605484, 0.023678705096244812, 0.035207539796829224, 0.03511699661612511, 1.3429094552993774, 0.34158265590667725, -0.2518318295478821]


ipdb>  inp_length


19


ipdb>  self.tokenizer.batch_decode(inp[:inp_length])


['[CLS]', 'i', 'miss', 'my', 'stupid', 'pretty', 's', '##tan', '##k', 'dumb', 'whore', 's', '##kan', '##k', 'trick', 'bitch', 'ass', 'friends', '[SEP]']


ipdb>  self.tokenizer.decode(inp[:inp_length])


'[CLS] i miss my stupid pretty stank dumb whore skank trick bitch ass friends [SEP]'


ipdb>  len(scores)


17


ipdb>  exit


BdbQuit: 

In [21]:
# Check which device the model held by VizHelper is on (saved output: CPU).
exp.model.device

device(type='cpu')

In [17]:
# Debugging aid: fetch example 21 through VizHelper's private accessor.
item = exp._get_item(21)

In [18]:
# Inspect the fields of the item; the saved output lists input_ids,
# token_type_ids and attention_mask (no label).
item.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])