# Tutorial: Zero-Shot Text Classification

In this short tutorial, we show how to use *ferret* to apply and evaluate different explainability approaches on the task of zero-shot text classification.

We will use `MoritzLaurer/mDeBERTa-v3-base-mnli-xnli` as the model checkpoint.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from ferret import (
    Benchmark,
    GradientExplainer,
    IntegratedGradientExplainer,
    LIMEExplainer,
    SHAPExplainer,
)

device = "cuda:0" if torch.cuda.is_available() else "cpu"


In [3]:
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)


In [4]:
ig = IntegratedGradientExplainer(model, tokenizer, multiply_by_inputs=True)
g = GradientExplainer(model, tokenizer, multiply_by_inputs=True)
l = LIMEExplainer(model, tokenizer)

No helper provided. Using default 'text-classification' helper.


In [5]:
bench = Benchmark(
    model, tokenizer, task_name="zero-shot-text-classification", explainers=[ig, g, l]
)

Overriding helper for explainer <ferret.explainers.gradient.IntegratedGradientExplainer object at 0x7fb9acb88be0>
Overriding helper for explainer <ferret.explainers.gradient.GradientExplainer object at 0x7fb9acb89690>
Overriding helper for explainer <ferret.explainers.lime.LIMEExplainer object at 0x7fb9acb88f70>


In [8]:
sequence_to_classify = (
    "Amanda ha cucinato la più buona torta pecan che abbia mai provato!"
)
candidate_labels = ["politics", "economy", "bakery", "science", "informatics"]
sample = (sequence_to_classify, candidate_labels)

In [9]:
sample

('Amanda ha cucinato la più buona torta pecan che abbia mai provato!',
 ['politics', 'economy', 'bakery', 'science', 'informatics'])

When scoring with a zero-shot classifier based on an NLI model, we need to specify the candidate labels through the `options` argument. You can pass an arbitrary number of options.
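To make the role of `options` more concrete, here is a minimal sketch of how an NLI-based zero-shot classifier is commonly set up: each option becomes a hypothesis paired with the input text, and the per-option entailment scores are normalized into a distribution. The `"This is {}"` template is inferred from the tokens shown in the explanation table later in this tutorial; the helper names and the entailment scores are hypothetical, and ferret's internal normalization may differ.

```python
import math


def build_nli_pairs(premise, options, template="This is {}"):
    """Pair the premise with one hypothesis per candidate option."""
    return [(premise, template.format(option)) for option in options]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


premise = "Amanda ha cucinato la più buona torta pecan che abbia mai provato!"
options = ["politics", "economy", "bakery", "science", "informatics"]
pairs = build_nli_pairs(premise, options)

# Hypothetical entailment scores, one per (premise, hypothesis) pair.
# A real run would obtain them from the NLI model.
entailment_scores = [0.01, 0.01, 0.85, 0.01, 0.02]
probs = dict(zip(options, softmax(entailment_scores)))
print(probs)
```

Because every option is scored through its own premise-hypothesis pair, the number of options is not fixed by the model.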

In [10]:
# get the prediction from our model
bench.score(sample[0], options=candidate_labels, return_probs=True)

{'politics': 0.15809638798236847,
 'economy': 0.15844039618968964,
 'bakery': 0.3655945956707001,
 'science': 0.158456489443779,
 'informatics': 0.15941213071346283}
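If you prefer not to read the top option off the output by eye, the returned dictionary can be reduced with plain Python. The sketch below copies the scores printed above into a literal so that it runs on its own; the variable names are illustrative.

```python
# Scores copied from the output of bench.score(...) above.
scores = {
    "politics": 0.15809638798236847,
    "economy": 0.15844039618968964,
    "bakery": 0.3655945956707001,
    "science": 0.158456489443779,
    "informatics": 0.15941213071346283,
}

# Pick the option with the highest probability.
best_option = max(scores, key=scores.get)
print(best_option)  # bakery
```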

Since we know the model uses NLI to perform the classification task, we can now explain the `entailment` class for the most likely option, `bakery`.

In [11]:
exp = bench.explain(sample[0], target="entailment", target_option="bakery")


In [12]:
# show explanations
bench.show_table(exp)

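| Explainer | ▁_0 | Amanda | ▁ha | ▁_1 | cucina | to_0 | ▁la | ▁p | iù | ▁buon | a_0 | ▁tort | a_1 | ▁pe | can | ▁che | ▁_2 | abbia | ▁mai | ▁prova | to_1 | ! | [SEP] | ▁This | ▁is | ▁_3 | baker | y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Integrated Gradient (x Input) | 0.12 | 0.05 | -0.03 | -0.02 | -0.01 | 0.0 | 0.01 | -0.0 | -0.01 | -0.01 | -0.01 | 0.0 | 0.01 | -0.02 | -0.01 | -0.02 | 0.13 | -0.04 | 0.01 | -0.0 | 0.0 | -0.04 | -0.01 | -0.18 | -0.06 | 0.0 | 0.12 | -0.05 |
| Gradient (x Input) | -0.04 | -0.03 | -0.02 | -0.01 | 0.03 | 0.01 | 0.0 | 0.01 | 0.03 | 0.03 | -0.0 | -0.01 | -0.0 | -0.11 | 0.04 | 0.01 | -0.03 | 0.01 | 0.02 | -0.01 | 0.0 | -0.01 | -0.11 | -0.02 | -0.0 | 0.12 | 0.16 | 0.02 |
| LIME | -0.01 | -0.0 | 0.03 | 0.03 | -0.02 | 0.04 | 0.04 | -0.03 | 0.02 | -0.01 | -0.03 | 0.21 | 0.11 | 0.03 | 0.04 | -0.01 | 0.0 | 0.01 | -0.02 | -0.02 | -0.01 | 0.02 | 0.06 | -0.0 | -0.1 | -0.04 | 0.01 | -0.05 |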
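With this many tokens, the table can be hard to scan. As a small reading aid, the sketch below ranks the tokens by their LIME score, using the values shown in the table above copied into plain lists. It does not rely on any ferret API; the variable names are illustrative.

```python
# Tokens and LIME scores copied from the explanation table above.
tokens = [
    "▁_0", "Amanda", "▁ha", "▁_1", "cucina", "to_0", "▁la", "▁p", "iù",
    "▁buon", "a_0", "▁tort", "a_1", "▁pe", "can", "▁che", "▁_2", "abbia",
    "▁mai", "▁prova", "to_1", "!", "[SEP]", "▁This", "▁is", "▁_3", "baker", "y",
]
lime_scores = [
    -0.01, -0.0, 0.03, 0.03, -0.02, 0.04, 0.04, -0.03, 0.02,
    -0.01, -0.03, 0.21, 0.11, 0.03, 0.04, -0.01, 0.0, 0.01,
    -0.02, -0.02, -0.01, 0.02, 0.06, -0.0, -0.1, -0.04, 0.01, -0.05,
]
assert len(tokens) == len(lime_scores)

# Sort (token, score) pairs from most to least positive attribution.
ranked = sorted(zip(tokens, lime_scores), key=lambda pair: pair[1], reverse=True)
top_tokens = [tok for tok, _ in ranked[:3]]
print(top_tokens)  # ['▁tort', 'a_1', '[SEP]']
```

For this sample, LIME assigns its two largest scores to the word pieces of "torta".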


In [13]:
# evaluate explanations and show faithfulness metrics
bench.show_evaluation_table(bench.evaluate_explanations(exp, target="entailment"))

                                                                                                                                                         

| Explainer | aopc_compr | aopc_suff | taucorr_loo |
|---|---|---|---|
| Integrated Gradient (x Input) | 0.36 | 0.63 | 0.05 |
| Gradient (x Input) | 0.59 | 0.79 | -0.19 |
| LIME | 0.72 | 0.7 | 0.43 |
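To give some intuition for what a comprehensiveness-style score such as `aopc_compr` measures, here is a simplified, self-contained sketch: remove the top-attributed tokens at several fractions and average the resulting drop in the predicted probability. The function name, the removal fractions, and the toy `predict_proba` model are all assumptions made for illustration; this is not ferret's implementation, whose thresholds and removal strategy may differ.

```python
def aopc_comprehensiveness(tokens, attributions, predict_proba, fractions):
    """Average probability drop after removing the top-attributed tokens.

    Simplified illustration only; not ferret's actual code.
    """
    base = predict_proba(tokens)
    # Token indices sorted from highest to lowest attribution.
    ranked = sorted(range(len(tokens)), key=lambda i: attributions[i], reverse=True)
    drops = []
    for fraction in fractions:
        k = max(1, round(fraction * len(tokens)))
        removed = set(ranked[:k])
        reduced = [t for i, t in enumerate(tokens) if i not in removed]
        drops.append(base - predict_proba(reduced))
    return sum(drops) / len(drops)


def toy_predict_proba(tokens):
    """Hypothetical model: only 'torta' and 'pecan' raise the probability."""
    return 0.1 + 0.4 * ("torta" in tokens) + 0.4 * ("pecan" in tokens)


words = ["Amanda", "ha", "cucinato", "la", "torta", "pecan"]
faithful = [0.0, 0.0, 0.0, 0.0, 0.5, 0.4]    # highlights the tokens the toy model uses
unfaithful = [0.5, 0.4, 0.0, 0.0, 0.0, 0.0]  # highlights tokens the toy model ignores

good = aopc_comprehensiveness(words, faithful, toy_predict_proba, fractions=(0.2, 0.4))
bad = aopc_comprehensiveness(words, unfaithful, toy_predict_proba, fractions=(0.2, 0.4))
print(round(good, 2), round(bad, 2))
```

In this toy setting, the attribution that points at the tokens the model actually relies on gets the higher score, which is the behavior a comprehensiveness metric is meant to reward.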
