In [36]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_metric, Dataset, DatasetDict
from einops import rearrange
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)

from custom_bert import BertForSequenceClassification
from helper import VizHelper

In [38]:
# Token length that preprocess_text pads/truncates every example to.
max_seq_length = 128

# AMI18

In [9]:
# Local checkpoint directory handed to from_pretrained when the models are loaded
# (presumably a bert-base-cased model fine-tuned on AMI18 — confirm).
model_name = "./bert-base-cased_ami18/"
# Identifier used to load the tokenizer.
tokenizer_name = "bert-base-cased"

In [10]:
# Tokenizer shared by preprocess_text and the visualisation helper.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def preprocess_text(examples):
    """Tokenize the ``text`` field of ``examples`` to a fixed length.

    Relies on the notebook-level ``tokenizer`` and ``max_seq_length``. Every
    sequence is padded to, and truncated at, ``max_seq_length`` tokens.

    Args:
        examples: Mapping with a ``"text"`` entry (used with ``batched=True``
            in ``map``, so this is a batch of texts).

    Returns:
        Whatever ``tokenizer`` returns for the given texts.
    """
    fixed_length_options = dict(
        padding="max_length",
        truncation=True,
        max_length=max_seq_length,
    )
    return tokenizer(examples["text"], **fixed_length_options)

In [11]:
# Read the three tab-separated splits into DataFrames.
train = pd.read_csv("data/miso_train.tsv", sep="\t")
validation = pd.read_csv("data/miso_dev.tsv", sep="\t")
test = pd.read_csv("data/miso_test.tsv", sep="\t")

# Wrap the splits in a DatasetDict and expose the target column as "label".
raw_datasets = DatasetDict(
    train=Dataset.from_pandas(train),
    validation=Dataset.from_pandas(validation),
    test=Dataset.from_pandas(test)
)
raw_datasets = raw_datasets.rename_column("misogynous", "label")

# Tokenize in batches and keep only the tokenizer outputs.
# `remove_columns` is documented as a str or list of str, so pass the explicit
# column-name list rather than the Features mapping (which only worked because
# iterating a mapping yields its keys).
# NOTE(review): this drops every original column, including "label"; confirm
# that downstream code takes labels from `raw_datasets`.
proc_datasets = raw_datasets.map(
    preprocess_text,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
# Return torch tensors when indexing the processed splits.
proc_datasets.set_format("pt")

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [12]:
# Stock classification model loaded from the local checkpoint, in eval mode.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = model.eval()

# Same checkpoint loaded through the custom BertForSequenceClassification from
# custom_bert, in eval mode; passed as `effective_model` to compare_attentions.
effective_model = BertForSequenceClassification.from_pretrained(model_name)
effective_model = effective_model.eval()

*** Calling custom BertForSequenceClassification ***


In [57]:
# Explanation/visualisation helper bound to the stock model, the tokenizer and
# both the raw and the tokenized test split.
exp = VizHelper(
    model,
    tokenizer,
    raw_datasets["test"],
    proc_datasets["test"],
)

Attention

In [None]:
# Show attention for idx 21 at head 3, layer 10
# (idx presumably indexes the test split given to VizHelper — confirm).
exp.show_attention(idx=21, head=3, layer=10)

In [None]:
# Same idx/head/layer as the show_attention cell, via show_effective_attention
# (presumably the "effective attention" variant — see helper.VizHelper).
exp.show_effective_attention(idx=21, head=3, layer=10)

In [None]:
# Compare attentions for one (example, head, layer) triple and save the figure.
idx, head, layer = 21, 7, 2
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

In [None]:
# Compare attentions for one (example, head, layer) triple and save the figure.
idx, head, layer = 21, 1, 9
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

Classification

In [None]:
# Run the helper's classify on example 21 (the example used throughout this notebook).
exp.classify(21)

In [None]:
# Table for example 21; the result supports .to_excel (used in the next cell).
results = exp.compute_table(21)

In [None]:
# Export the table to an Excel file in the current working directory.
results.to_excel("table.xlsx")

In [None]:
# Compare attentions for one (example, head, layer) triple, this time passing
# the custom-BERT `effective_model` and a larger font size, then save the figure.
idx, head, layer = 21, 1, 2
fig = exp.compare_attentions(idx, head, layer, effective_model=effective_model, fontsize=18)
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig(f"plots/comp_attentions_{idx}_{head}_{layer}.png", bbox_inches='tight')

In [None]:
# get_gradient returns a pair for example 21, unpacked here as (grad, embeds)
# (presumably the gradient and the input embeddings — confirm in helper.VizHelper).
grad, embeds = exp.get_gradient(21)

Gradients

In [None]:
# Display the raw return value of get_gradient for example 21.
exp.get_gradient(21)

Final table

In [74]:
%%capture
# %%capture (which must stay on the first line of the cell) silences the
# output produced while the table for example 21 is computed.
table = exp.compute_table(21)

In [75]:
# Display the table: one row per method (G, GxI, IntegratedGradients, KernelSHAP),
# one column per token.
table

tokens,[CLS],i,miss,my,stupid,pretty,s,##tan,##k,dumb,whore,s.1,##kan,##k.1,trick,bitch,ass,friends,[SEP]
G,-0.046693,-0.027237,-0.027237,-0.007782,0.015564,0.0,-0.019455,0.085603,0.070039,-0.015564,0.023346,0.027237,-0.077821,-0.035019,-0.171206,0.0,-0.108949,0.210117,-0.031128
GxI,0.004382,-0.055667,-0.033224,-0.034306,-0.013459,0.003522,0.011157,0.0192,0.024633,-0.000851,-0.003492,0.014394,0.083351,0.061475,-0.075812,-0.11349,-0.201125,-0.139274,0.107186
IntegratedGradients,-0.02481,-0.04167,0.079333,-0.055123,-0.076697,0.03804,-0.047656,-0.126776,-0.033091,-0.112115,0.030842,-0.045686,-0.031556,-0.034987,-0.08623,-0.090326,-0.030227,0.002293,-0.012541
KernelSHAP,0.030795,-0.030487,-0.103036,-0.044664,0.067001,0.018243,0.049915,0.048692,0.043373,0.033584,0.111427,0.072649,0.066315,0.040431,0.027878,0.169449,0.009331,0.078815,0.275172


In [77]:
# Export the final table for example 21 to the current working directory.
# Plain string literal: the previous f-string had no placeholders.
table.to_excel("table_21.xlsx")

SHAP

In [None]:
# KernelSHAP result for example 21.
# NOTE(review): the name `shap` would shadow the `shap` package if it were ever
# imported in this notebook.
shap = exp.get_kernel_shap(21)

In [67]:
# Integrated-gradients result for example 21.
ig = exp.get_integrated_gradients(21)

In [68]:
# Display the result; as the output shows, it is a float64 tensor that still
# carries a grad_fn (i.e. it is attached to the autograd graph).
ig

tensor([-0.0248, -0.0417,  0.0793, -0.0551, -0.0767,  0.0380, -0.0477, -0.1268,
        -0.0331, -0.1121,  0.0308, -0.0457, -0.0316, -0.0350, -0.0862, -0.0903,
        -0.0302,  0.0023, -0.0125], dtype=torch.float64,
       grad_fn=<DivBackward0>)