In [6]:
# Enable IPython's autoreload so edits to local modules such as `helper` and
# `custom_bert` are picked up before each cell runs, without restarting the kernel.
# Re-running this cell prints an "already loaded" notice (see output); it is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# NOTE(review): `load_metric` no longer exists in recent `datasets` releases
# (superseded by the `evaluate` library) and is unused in the visible cells.
from datasets import load_metric, Dataset, DatasetDict
from einops import rearrange
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)

from helper import VizHelper
from custom_bert import BertForSequenceClassification

In [8]:
# Token length every example is padded/truncated to by preprocess_text.
max_seq_length: int = 128

# AMI18

In [9]:
# Local classifier checkpoint (the directory name suggests an AMI18 fine-tune)
# and the identifier handed to AutoTokenizer.from_pretrained.
model_name, tokenizer_name = "./bert-base-cased_ami18/", "bert-base-cased"

In [10]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)


def preprocess_text(examples):
    """Tokenize one batch of examples for the classifier.

    Each text in ``examples["text"]`` is truncated and padded to exactly
    ``max_seq_length`` tokens, so every row has the same length.

    Args:
        examples: A batch as passed by ``Dataset.map(..., batched=True)``;
            only the ``"text"`` field is read.

    Returns:
        The tokenizer's output for the batch.
    """
    batch_texts = examples["text"]
    return tokenizer(
        batch_texts,
        max_length=max_seq_length,
        truncation=True,
        padding="max_length",
    )

In [11]:
# Load the tab-separated train/dev/test splits and wrap them in a DatasetDict,
# renaming the target column to the conventional "label".
train = pd.read_csv("data/miso_train.tsv", sep="\t")
validation = pd.read_csv("data/miso_dev.tsv", sep="\t")
test = pd.read_csv("data/miso_test.tsv", sep="\t")

raw_datasets = DatasetDict(
    train=Dataset.from_pandas(train),
    validation=Dataset.from_pandas(validation),
    test=Dataset.from_pandas(test),
).rename_column("misogynous", "label")

# Tokenize every split and drop ALL original columns, including "label".
# The processed splits hold only the tokenizer outputs; labels remain
# available in `raw_datasets`. `remove_columns` expects column names, so pass
# `column_names` rather than the `features` mapping.
proc_datasets = raw_datasets.map(
    preprocess_text,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
proc_datasets.set_format("pt")

  0%|          | 0/4 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [12]:
# The same checkpoint is loaded twice: once with the stock transformers class
# and once with the project's custom BERT class (later passed as
# `effective_model` to `compare_attentions`). Both are put in eval mode.
# The trailing `;` keeps the cell from displaying the model repr.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval();
effective_model = BertForSequenceClassification.from_pretrained(model_name)
effective_model.eval();

*** Calling custom BertForSequenceClassification ***


In [13]:
# Visualisation helper bound to the stock model and the test split.
exp = VizHelper(
    model,
    tokenizer,
    raw_datasets["test"],   # untokenized test split
    proc_datasets["test"],  # tokenized, torch-formatted test split
)

Attention

In [None]:
# Attention map for example 21 (presumably an index into the test split), layer 10, head 3.
exp.show_attention(layer=10, head=3, idx=21)

In [None]:
# Effective-attention counterpart of the previous cell (same example, layer and head).
exp.show_effective_attention(layer=10, head=3, idx=21)

In [None]:
idx, head, layer = 21, 7, 2
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories; ensure plots/ exists first.
plot_path = Path("plots") / f"comp_attentions_{idx}_{head}_{layer}.png"
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight")

In [None]:
idx, head, layer = 21, 1, 9
fig = exp.compare_attentions(idx, head, layer)
# savefig does not create missing directories; ensure plots/ exists first.
plot_path = Path("plots") / f"comp_attentions_{idx}_{head}_{layer}.png"
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight")

Classification

In [None]:
# Classification output for example 21 (same example inspected above).
exp.classify(21)

In [None]:
# Summary table for example 21; exported to Excel in the next cell.
# NOTE(review): the "Final table" section repeats this exact call as `table`.
results = exp.compute_table(21)

In [None]:
# Write the table to table.xlsx in the working directory (needs an Excel
# writer engine such as openpyxl installed).
results.to_excel("table.xlsx")

In [None]:
idx, head, layer = 21, 1, 2
fig = exp.compare_attentions(idx, head, layer, effective_model=effective_model, fontsize=18)
# savefig does not create missing directories; ensure plots/ exists first.
# NOTE(review): the filename does not record that effective_model was used, so a
# plain comparison with the same idx/head/layer would overwrite this file.
plot_path = Path("plots") / f"comp_attentions_{idx}_{head}_{layer}.png"
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight")

In [None]:
# Gradient computation for example 21; the result is unpacked into two values.
# NOTE(review): this sits before the "Gradients" heading, and the same call is
# repeated in that section.
grad, embeds = exp.get_gradient(21)

Gradients

In [None]:
# Recomputes exp.get_gradient(21) (already stored in `grad`, `embeds`) to display its raw return value.
exp.get_gradient(21)

Final table

In [None]:
# Final table for example 21.
# NOTE(review): identical call to `results = exp.compute_table(21)` earlier;
# `table` is not used in the visible cells.
table = exp.compute_table(21)

SHAP

In [23]:
# Kernel SHAP attributions for example 21 (the progress bar in the output counts to 200).
# NOTE(review): the name `shap` would shadow the `shap` package if it is ever
# imported here; consider renaming (e.g. `shap_values`) together with its uses below.
shap = exp.get_kernel_shap(21)



Kernel Shap attribution:   0%|          | 0/200 [00:00<?, ?it/s][A[A

Kernel Shap attribution:   2%|▎         | 5/200 [00:00<00:20,  9.46it/s][A[A

Kernel Shap attribution:   5%|▌         | 10/200 [00:01<00:20,  9.44it/s][A[A

Kernel Shap attribution:   8%|▊         | 15/200 [00:01<00:20,  9.03it/s][A[A

Kernel Shap attribution:  10%|█         | 20/200 [00:02<00:20,  8.78it/s][A[A

Kernel Shap attribution:  12%|█▎        | 25/200 [00:02<00:19,  9.11it/s][A[A

Kernel Shap attribution:  15%|█▌        | 30/200 [00:03<00:19,  8.79it/s][A[A

Kernel Shap attribution:  18%|█▊        | 35/200 [00:03<00:18,  9.12it/s][A[A

Kernel Shap attribution:  20%|██        | 40/200 [00:04<00:17,  9.03it/s][A[A

Kernel Shap attribution:  22%|██▎       | 45/200 [00:04<00:16,  9.25it/s][A[A

Kernel Shap attribution:  25%|██▌       | 50/200 [00:05<00:15,  9.44it/s][A[A

Kernel Shap attribution:  28%|██▊       | 55/200 [00:05<00:15,  9.53it/s][A[A

Kernel Shap attribution:  30%|███   

In [29]:
# Length of axis 1 of the attributions; output is 128 == max_seq_length, so
# presumably one entry per token position (TODO confirm against VizHelper).
shap[0, :, 0].shape

torch.Size([128])

In [35]:
# Attributions at position 120 along the last axis; the output is one constant value.
# NOTE(review): with padding="max_length", position 120 is likely a padding
# token for a short input, yet its attribution is nonzero. Confirm whether
# padding positions should be excluded from the SHAP feature mask / baseline.
shap[0, 120, :]

tensor([0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298, 0.0298,
        0.0298, 0.0298, 0.0298, 0.0298, 