In [None]:
import shap
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch.nn.functional as F

In [None]:
# Load the token-classification (NER) checkpoint and its matching tokenizer
# (fetched from the Hugging Face Hub, or reused from the local cache).
# NOTE(review): the checkpoint name indicates a Swahili NER fine-tune, while the
# sample text explained later in this notebook is mostly Amharic — confirm this
# model choice is intended.
model_name = "mbeukman/xlm-roberta-base-finetuned-ner-swahili"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
# Switch to inference mode (disables dropout). `eval()` returns the module itself;
# the trailing ';' stops Jupyter from echoing its very long repr as cell output.
model.eval();

In [None]:
# Wrapper for SHAP: maps a batch of texts to a rectangular array of per-token scores.
def shap_ner_wrapper(texts, max_length=None, pad_value=0.0, *, ner_tokenizer=None, ner_model=None):
    """Score each text with the NER model, for use as a SHAP model function.

    Each text is run through the model on its own. For every token position
    the tokenizer produces (special tokens included), the highest softmax
    probability across the NER labels is kept as a confidence proxy.

    Texts that tokenize to different lengths (common when SHAP masks tokens)
    are padded with ``pad_value``. The result is therefore always a
    rectangular ``float32`` array rather than a ragged list, which
    ``np.array`` rejects.

    Args:
        texts: Iterable of strings (SHAP passes a list/array of masked
            variants). A bare ``str`` is treated as a single text.
        max_length: Output width. ``None`` (default) pads to the longest text
            in this call. An ``int`` pads or trims every row to exactly that
            many token positions, which keeps the output width stable across
            the separate batches SHAP evaluates.
        pad_value: Value written into padded positions.
        ner_tokenizer: Tokenizer to use; defaults to the notebook-level ``tokenizer``.
        ner_model: Token-classification model to use; defaults to the
            notebook-level ``model``.

    Returns:
        np.ndarray of shape ``(number_of_texts, width)`` and dtype ``float32``.
    """
    if ner_tokenizer is None:
        ner_tokenizer = tokenizer
    if ner_model is None:
        ner_model = model
    if isinstance(texts, str):
        # Iterating a str would score each character separately.
        texts = [texts]

    rows = []
    with torch.no_grad():
        for text in texts:
            # One sequence per call, so no padding is needed. truncation=True
            # keeps long texts within the model's maximum input length.
            inputs = ner_tokenizer(text, return_tensors="pt", truncation=True)
            logits = ner_model(**inputs).logits.squeeze(0)  # [seq_len, num_labels]
            probs = F.softmax(logits, dim=-1)  # [seq_len, num_labels]
            # Max probability per token as a simple confidence proxy.
            rows.append(probs.max(dim=-1).values.cpu().numpy())

    if max_length is None:
        width = max((len(row) for row in rows), default=0)
    else:
        width = max_length

    scores = np.full((len(rows), width), pad_value, dtype=np.float32)
    for i, row in enumerate(rows):
        kept = row[:width]
        scores[i, : len(kept)] = kept
    return scores

In [None]:
# Load SHAP's JavaScript once up front so the interactive text plot renders
# (important in Colab).
shap.initjs()

# Mixed Amharic/English example input.
# NOTE(review): the model loaded earlier is, by its checkpoint name, a Swahili
# NER fine-tune — confirm it is meant to be applied to Amharic text.
sample_text = "በ8420 ብር የተሰራ የእጄታ ወንበር ይዘዙ Call 8420 for order."

# Passing the tokenizer as the masker lets shap.Explainer build a text masker
# from it: SHAP masks tokens of the input and attributes each per-token score
# returned by `shap_ner_wrapper`.
explainer = shap.Explainer(shap_ner_wrapper, tokenizer)
shap_values = explainer([sample_text])

# Exactly one text was explained, so plot its (only) explanation.
shap.plots.text(shap_values[0])