In [6]:
import os
import torch
import shap
import numpy as np
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

# Paths and model setup.
# `__file__` is not defined inside a Jupyter kernel, so the model directory is resolved
# relative to the kernel's current working directory (typically the notebook's folder).
NOTEBOOK_DIR = os.getcwd()
MODEL_PATH = os.path.abspath(
    os.path.join(NOTEBOOK_DIR, "..", "..", "results", "cadec-absa", "cadec-absa-model")
)
MODEL_NAME = "distilbert-base-uncased"  # base checkpoint whose tokenizer is reused below
MAX_LENGTH = 256  # max tokens per text; longer inputs are truncated in predict_proba

if not os.path.isdir(MODEL_PATH):
    # Fail with a clear message instead of from_pretrained treating the path as a hub repo id.
    raise FileNotFoundError(f"Fine-tuned model directory not found: {MODEL_PATH}")

device = torch.device('cpu')
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH)
model.to(device)
model.eval()  # inference mode (disables dropout)
# NOTE(review): the tokenizer is loaded from the base checkpoint, not MODEL_PATH. This assumes
# fine-tuning did not change the vocabulary — confirm, or load it from MODEL_PATH if saved there.
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

def predict_proba(texts):
    """Return softmax class probabilities from the fine-tuned model for the given texts.

    Serves as the model function for ``shap.Explainer``, which may pass a NumPy
    array of strings. A single ``str``, a list, or any other iterable of texts is
    also accepted; every item is converted with ``str()`` before tokenization.

    Args:
        texts: a ``str``, an ``np.ndarray`` of strings, or any iterable of texts.

    Returns:
        np.ndarray of shape ``(len(texts), num_labels)``. Each row holds softmax
        probabilities (summing to 1), and column ``i`` corresponds to label id ``i``.
        Empty input yields an empty ``(0, num_labels)`` float32 array.
    """
    if isinstance(texts, np.ndarray):
        # tolist() on a 0-d array returns a bare str, which the next check wraps.
        texts = texts.tolist()
    if isinstance(texts, str):
        texts = [texts]
    texts = [str(t) for t in texts]
    if not texts:
        # Short-circuit rather than sending an empty batch through the tokenizer and model.
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    inputs = tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

# Example patient reviews to explain.
sample_texts = [
    "This medicine helped me a lot with my pain.",
    "I had severe side effects and had to stop taking it.",
    "No improvement after two weeks of use.",
    "Much better than my previous medication."
]

# Name each explanation output after the model's class labels, ordered by label id to
# match predict_proba's columns, so the text plots show class names instead of indices.
class_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
explainer = shap.Explainer(predict_proba, shap.maskers.Text(tokenizer), output_names=class_names)
shap_values = explainer(sample_texts)

In [7]:
# Token-level SHAP attributions for sample_texts[0].
shap.plots.text(shap_values=shap_values[0])

In [8]:
# Token-level SHAP attributions for sample_texts[1].
shap.plots.text(shap_values=shap_values[1])

In [9]:
# Token-level SHAP attributions for sample_texts[2].
shap.plots.text(shap_values=shap_values[2])

In [10]:
# Token-level SHAP attributions for sample_texts[3].
shap.plots.text(shap_values=shap_values[3])