# XAI on fine-tuned SciBERT

In [1]:
# Execute ipynb_util.py in this kernel's namespace to import its variables and settings.
# Later cells use DATA_DIR_PATH and CHECKPOINT_PATH without defining them, so they are
# expected to come from this script (NOTE(review): confirm against ipynb_util.py).
%run ipynb_util.py

### Load the data

In [2]:
from datasets import load_dataset

# Locations of the task-1 JSONL splits inside the data directory.
# note: if the data is not in the folder yet, fetch it with the download-data.sh script
data_files = dict(
    train=str(DATA_DIR_PATH / "task1-train.jsonl"),
    test=str(DATA_DIR_PATH / "task1-test.jsonl"),
)
dataset = load_dataset("json", data_files=data_files)
# Alternative kept for reference (disabled): a stratified split of the train file
#dataset = dataset["train"].train_test_split(test_size=0.3, stratify_by_column="SDG", seed=SEED)

train_split = dataset["train"]

# Show one raw record so the available fields are visible.
example = train_split[0]
print("Example instance:\t", example)

# Distinct values of the "SDG" column; later used to size the classification head.
labels = set(train_split["SDG"])
print(train_split)

Example instance:	 {'ID': 'oai:www.zora.uzh.ch:126666', 'TITLE': 'Identifying phrasemes via interlingual association measures - A data-driven approach on dependency-parsed and word-aligned parallel corpora', 'ABSTRACT': 'In corpus linguistics, statistical association measures play a major role in identifying collocations such as ‘play’ and ‘role’ in ‘play a role’.  Those two words that appear considerably more often in the same context than one would expect from a random distribution are collocates.  They typically constitute meaning beyond the bare combination of both words’ semantics.\r\nWe employ the same association measures on interlingual word co-occurrences based on statistical word alignment and combine them with intralingual association measures on syntactical dependency relations in order to identify phrasemes.  Support verb constructions exemplify our approach.  They are characterized by the respective verb contributing little to the semantics of the whole construction, whic

### Load the fine-tuned model

In [3]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# load the best model
MODEL_PATH = CHECKPOINT_PATH + "/allenai/scibert_scivocab_uncased-ft-task1/checkpoint-532"

# `labels` is a set and set iteration order is not guaranteed, so enumerate a sorted
# copy to make the class-index <-> label mapping deterministic across runs.
# NOTE(review): this mapping must match the one used during fine-tuning — confirm.
ordered_labels = sorted(labels)

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH,
    num_labels=len(ordered_labels),
    id2label={i: label for i, label in enumerate(ordered_labels)},
    label2id={label: i for i, label in enumerate(ordered_labels)},
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# inference only: switch off dropout and other training-time behaviour
model.eval()

pred = transformers.pipeline(
    "text-classification",
    model=model,
    batch_size=8,
    tokenizer=tokenizer,
    # use the first GPU when one is available, otherwise fall back to the CPU (-1)
    device=0 if torch.cuda.is_available() else -1,
    top_k=None,     # equal to return_all_scores=True
)

In [4]:
import pandas as pd
from datasets import Dataset

# Join title and abstract into the single text field that is passed to the classifier.
train_split = dataset["train"]
texts = [
    f"{title} {abstract}"
    for title, abstract in zip(train_split["TITLE"], train_split["ABSTRACT"])
]
data = pd.DataFrame({"text": texts, "label": train_split["SDG"]})

# SHAP is only run on a small sample: the first `sample_size` rows
sample_size = 10
sample_data = data.head(sample_size)

# wrap the sample in a Dataset (intended for parallel gpu processing)
sample_dataset = Dataset.from_pandas(sample_data)

In [5]:
# Sanity check: show the columns and the number of rows of the sample dataset.
print(sample_dataset)

Dataset({
    features: ['text', 'label'],
    num_rows: 10
})


In [6]:
# mean number of characters per text, over the whole data
average_text_length = data["text"].str.len().mean()
print(average_text_length)


def count_tokens(text):
    """Return the number of input ids the tokenizer produces for `text`."""
    return len(tokenizer(text)["input_ids"])


# mean number of tokens per text, over the whole data
average_tokens = data["text"].map(count_tokens).mean()
print(average_tokens)

# token count of the first text, shown as the cell result
count_tokens(data["text"][0])

1425.3232558139534
264.56976744186045


178

## SHAP

In [7]:
import shap
from matplotlib.colors import ListedColormap

# The explainer's output columns follow the model's class indices, so the display names
# are derived from the model's own id2label mapping instead of a hand-written list.
# The pipeline reports labels 0..17; the previous list named index 0 "SDG 1" and put
# "non-relevant" last, shifting every name by one.
# Assumes label 0 is the non-relevant class and label k is SDG k — TODO confirm
# against the task description.
output_names = [
    "non-relevant" if int(label) == 0 else f"SDG {int(label)}"
    for label in (model.config.id2label[i] for i in range(model.config.num_labels))
]

explainer = shap.Explainer(pred, output_names=output_names)
# Explain only the small sample; this is the slow step of the notebook.
shap_values = explainer(sample_dataset["text"])

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  20%|██        | 2/10 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  40%|████      | 4/10 [00:25<00:21,  3.55s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  50%|█████     | 5/10 [00:35<00:30,  6.19s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  60%|██████    | 6/10 [00:41<00:24,  6.14s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  70%|███████   | 7/10 [00:52<00:23,  7.93s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  80%|████████  | 8/10 [00:58<00:14,  7.41s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  90%|█████████ | 9/10 [01:06<00:07,  7.57s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 11it [01:15,  8.40s/it]                        


In [8]:
# Sum the attributions of sample 0 for output index 4
# (the middle axis is presumably the token axis — TODO confirm).
shap_values[0, :, 4].values.sum()

-0.0024012872017920017

In [9]:
# Sanity check on sample 0: do output indices 0 and 1 receive different explanations?
attributions_out0 = shap_values[0, :, 0]
attributions_out1 = shap_values[0, :, 1]

# element-wise comparison of the attribution values
print(attributions_out1.values == attributions_out0.values)
# comparison of the corresponding base values
print(attributions_out1.base_values == attributions_out0.base_values)

[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False Fa

In [27]:
# Render the SHAP text plot for every output of sample 0, with xmin/xmax pinned to -1 and 1.
shap.plots.text(shap_values[0, :, :], xmin=-1, xmax=1)

In [45]:
# Inspect the first seven entries of sample 0: the explained inputs (presumably the
# token strings) next to their SHAP values for output index 0.
first_explanation = shap_values[0, :, :]

# Only the last expression of a cell is displayed, so the inputs must be printed
# explicitly; as a bare expression their value was silently discarded.
print(first_explanation.data[:7])
first_explanation.values[:7, 0]

array([0.00040599, 0.00040599, 0.00040599, 0.00040599, 0.00040599,
       0.00046725, 0.00046725])

In [11]:
# Display the sample rows (text and label) whose texts were passed to the explainer.
sample_data

Unnamed: 0,text,label
0,Identifying phrasemes via interlingual associa...,0
1,Synthesis of tripeptide derivatized cyclopenta...,0
2,Intelligence test items varying in capacity de...,0
3,Copy number increases of transposable elements...,14
4,Linguistics : An Interdisciplinary Journal of ...,0
5,Bose–Einstein condensation of triplons with a ...,0
6,"Turning a blind eye, but not the other cheek: ...",0
7,"A convenient access to 1,2-diferrocenyl-substi...",0
8,The arguments of utility: Preference reversals...,0
9,Non-Abelian chiral spin liquid on a simple non...,0


In [12]:
# Run the classification pipeline on the first text; since the pipeline was created
# with top_k=None it returns a score for every label, not only the best one.
pred(data["text"][0])

[[{'label': 0, 'score': 0.9962840676307678},
  {'label': 2, 'score': 0.00040405127219855785},
  {'label': 3, 'score': 0.0003473605611361563},
  {'label': 4, 'score': 0.0002902108244597912},
  {'label': 5, 'score': 0.0002729483530856669},
  {'label': 10, 'score': 0.00026633351808413863},
  {'label': 12, 'score': 0.00024927948834374547},
  {'label': 14, 'score': 0.00022910060943104327},
  {'label': 6, 'score': 0.00019695048104040325},
  {'label': 9, 'score': 0.00019046761735808104},
  {'label': 8, 'score': 0.0001864801743067801},
  {'label': 16, 'score': 0.0001817433803807944},
  {'label': 7, 'score': 0.0001695550890872255},
  {'label': 15, 'score': 0.00016195762145798653},
  {'label': 13, 'score': 0.0001610693143447861},
  {'label': 11, 'score': 0.00015242990048136562},
  {'label': 17, 'score': 0.0001354479609290138},
  {'label': 1, 'score': 0.00012075644917786121}]]

In [13]:
# TODO: plot the top words impacting a specific class (not implemented in this cell yet).
# For now only inspect the shape of the explanation object. The displayed result is
# (10, None, 18): 10 explained samples and 18 outputs; the middle entry is None,
# presumably because the number of tokens differs per sample — TODO confirm.
shap_values.shape

(10, None, 18)