In [1]:
import os
from pathlib import Path

import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from interpreto.attributions.base import InferenceModes
from interpreto.attributions.methods import (
    IntegratedGradients,
    KernelShap,
    Lime,
    Occlusion,
    Saliency,
    SmoothGrad,
    Sobol,
)
from interpreto.commons.granularity import Granularity
from interpreto.visualizations.attributions.classification_highlight import (
    MultiClassAttributionVisualization,
    SingleClassAttributionVisualization,
)

In [2]:
# Location of the test fixtures (dataset + fine-tuned classifier).
# Override with the INTERPRETO_TEST_DATA_DIR environment variable instead of
# editing this cell; the default is the original hard-coded location.
DATA_DIR = Path(os.environ.get("INTERPRETO_TEST_DATA_DIR", "/data/fanny.jourdan/interpreto_test"))

dataset_path = str(DATA_DIR / "labeled_noise_text_dataset.txt")
model_name = str(DATA_DIR / "distilbert_trivial_classifier")
# NOTE(review): the tokenizer is loaded from a different checkpoint than the
# model; presumably their vocabularies are compatible — TODO confirm.
tokenizer_name = "hf-internal-testing/tiny-random-distilbert"

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
df = pd.read_csv(dataset_path)

# The explanation cells below read df["text"]; fail here with a clear message
# rather than with a KeyError deep inside the explainer loop.
assert "text" in df.columns, f"expected a 'text' column in {dataset_path}, got {list(df.columns)}"

In [None]:
# One explainer per attribution method, all sharing the same model, tokenizer
# and batch size. Occlusion, Lime, Sobol and KernelShap are explicitly set to
# word granularity; the gradient-based methods use the library's default.
list_explainers = [
    Occlusion(model=model, batch_size=4, tokenizer=tokenizer, granularity=Granularity.WORD),
    IntegratedGradients(model=model, batch_size=4, tokenizer=tokenizer, n_interpolations=10),
    SmoothGrad(model=model, batch_size=4, tokenizer=tokenizer, n_interpolations=50, noise_level=0.01),
    Saliency(model=model, batch_size=4, tokenizer=tokenizer),
    Lime(
        model=model,
        batch_size=4,
        tokenizer=tokenizer,
        n_perturbations=20,
        granularity=Granularity.WORD,
        distance_function=Lime.distance_functions.HAMMING,
        inference_mode=InferenceModes.SOFTMAX,
    ),
    Sobol(
        model=model,
        batch_size=4,
        tokenizer=tokenizer,
        n_token_perturbations=16,
        granularity=Granularity.WORD,
        sobol_indices_order=Sobol.sobol_indices_orders.TOTAL_ORDER,
    ),
    KernelShap(
        model=model,
        batch_size=4,
        tokenizer=tokenizer,
        n_perturbations=20,
        granularity=Granularity.WORD,
        inference_mode=InferenceModes.SOFTMAX,
    ),
]

# Seed for torch's global RNG. It is reset before every explanation so that each
# method's result is repeatable and independent of which explainers ran before it.
# NOTE(review): this only fixes randomness drawn from torch; if a method samples
# with another RNG (e.g. numpy/scipy) that is not covered here — TODO confirm.
SEED = 0

k = 10  # positional row of `df` whose text is explained
text = df["text"].iloc[k]

for explainer in list_explainers:
    print(f"Explaining with {explainer.__class__.__name__}")
    torch.manual_seed(SEED)
    attribution_outputs = explainer.explain(model_inputs=[text])

    # A single input is passed, so the first attribution output is the one shown.
    viz = SingleClassAttributionVisualization(
        attribution_output=attribution_outputs[0],
        css=".common-word-style { margin-right: 0.3em }",
    )
    viz.display()

In [None]:
# Explain the same row for two target classes at once and render them together.

# Seed for torch's global RNG, reset before each method so results are repeatable
# and independent of execution order.
# NOTE(review): only torch randomness is fixed; other RNGs (if any) are not — TODO confirm.
SEED = 0

k = 10  # positional row of `df` whose text is explained
text = df["text"].iloc[k]

for explainer in list_explainers:
    print(f"Explaining with {explainer.__class__.__name__}")
    torch.manual_seed(SEED)
    # targets holds class indices 0 and 1 for the single input; presumably one row
    # per element of model_inputs — TODO confirm against the explain() contract.
    attribution_outputs = explainer.explain(model_inputs=[text], targets=torch.tensor([[0, 1]]))

    viz = MultiClassAttributionVisualization(
        attribution_output=attribution_outputs[0],
        # Display names for target indices 0 and 1, in that order.
        # NOTE(review): placeholder labels; model.config.id2label may be a better source.
        class_names=["A", "B"],
        css=".common-word-style { margin-right: 0.3em }",
    )
    viz.display()