In [1]:
# Enable IPython's autoreload extension so that edits to imported modules
# are picked up without restarting the kernel (mode 2 = reload all modules
# before each cell execution).
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the directory two levels up importable (presumably the repository root
# that contains the `interpreto` package -- confirm against the notebook's location).
# The membership check keeps this cell idempotent: re-running it does not pile up
# duplicate "../.." entries in sys.path.
if "../.." not in sys.path:
    sys.path.append("../..")

In [3]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from interpreto.attributions.methods.occlusion import OcclusionExplainer
from interpreto.attributions.perturbations.base import GranularityLevel
from interpreto.commons.model_wrapping.classification_inference_wrapper import ClassificationInferenceWrapper

from interpreto.visualizations.attributions.classification_highlight import SingleClassAttributionVisualization

In [4]:
# Inputs: two short reviews; the second ends with a long out-of-vocabulary-looking
# word so that word-level grouping is visible in the result.
test_sentences = ["Best movie ever", "Worst movie ever verylongword"]

# Load the tokenizer and the sequence-classification model from the same checkpoint.
model_name = "textattack/bert-base-uncased-imdb"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Wrap the model for batched inference and build a word-level occlusion explainer.
inference_wrapper = ClassificationInferenceWrapper(model=model, batch_size=4)
exp = OcclusionExplainer(
    granularity_level=GranularityLevel.WORD,
    inference_wrapper=inference_wrapper,
    tokenizer=tokenizer,
)

# One explanation object per input sentence.
explaination = exp.explain(test_sentences)

# Show the raw attribution tensor next to the elements it scores.
for elem in explaination:
    print(elem.attributions, elem.elements)

tensor([[2.5734, 0.6498, 0.5065]], grad_fn=<SqueezeBackward1>) ['best', 'movie', 'ever']
tensor([[4.1983, 1.0871, 0.7224, 0.4840]], grad_fn=<SqueezeBackward1>) ['worst', 'movie', 'ever', 'verylongword']


In [5]:
# Reorganize the attributions as (l, c), where l is the number of elements
# (len(elem.elements)); c is presumably the number of explained classes -- confirm.
print("Before:")
for elem in explaination:
    print(elem.attributions.shape)
for elem in explaination:
    # Transpose only while the leading axis is not yet the element axis.
    # An unconditional `.T` would flip the tensors back to (c, l) every time this
    # cell is re-run, silently breaking the visualization cell below.
    # NOTE: when l == c the two layouts cannot be told apart by shape; the tensor
    # is then left untouched.
    if elem.attributions.shape[0] != len(elem.elements):
        elem.attributions = elem.attributions.T
print("After:")
for elem in explaination:
    print(elem.attributions.shape)

Before:
torch.Size([1, 3])
torch.Size([1, 4])
After:
torch.Size([3, 1])
torch.Size([4, 1])


In [6]:
# Build the single-class highlight view from the explanations computed above.
# The custom CSS sets the font size of `.common-word-style` to 1.5em.
viz = SingleClassAttributionVisualization(
    attribution_output_list=explaination,
    css=".common-word-style {font-size: 1.5em}",
)

# Render the visualization in the notebook output.
viz.display()