In [None]:
import scipy as sp

import torch
import shap
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import transformers
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

from utilities import *

# Emotion label names, passed to SHAP as ``output_names``; the order presumably
# has to match the model's output columns -- TODO confirm against training code.
classes = "anger fear joy sadness surprise".split()

def create_shap_pngs(
    number_of_examples,
    model_name,
    model_path,
    dataset_path,
    *,
    output_dir=".",
):
    """Explain a saved multi-label emotion classifier with SHAP and save plots.

    Takes the first ``number_of_examples`` entries of the dataset's ``"text"``
    column and explains the model's per-class sigmoid scores with a SHAP text
    masker. It then writes:

    * ``shap_text_plot.html``: the interactive token-level SHAP text plot.
    * ``shap_bar_<label>.png``: one bar plot per label in ``classes``,
      showing each token's mean SHAP value over the explained texts.

    Args:
        number_of_examples: How many texts, from the start of the dataset's
            ``"text"`` column, to explain.
        model_name: Hugging Face model id/path used to load the tokenizer.
        model_path: Sub-directory of ``./results`` that holds
            ``best_model.pth``.
        dataset_path: Path handed to ``load_and_split_dataset``.
        output_dir: Directory for the HTML and PNG files. It is created if it
            does not exist. Defaults to the current directory.

    Raises:
        ValueError: If no texts were selected (e.g. ``number_of_examples <= 0``).
    """
    from pathlib import Path

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the first of the two returned splits is used.
    dataset, _ = load_and_split_dataset(dataset_path, 1.0)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # The checkpoint is a pickled full module, not a state_dict. map_location
    # lets a GPU-saved checkpoint load on a CPU-only machine and vice versa.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses
    # pickled modules; pass weights_only=False there (only for trusted files,
    # since unpickling can execute arbitrary code).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(
        f"./results/{model_path}/best_model.pth", map_location=device
    )
    model.eval()

    # Convert the selected texts to a plain list of strings.
    texts = dataset["text"][:number_of_examples]
    texts = texts.tolist() if not isinstance(texts, list) else texts
    if not texts:
        raise ValueError("No texts selected for SHAP explanation.")

    print("Texts to explain:", texts)

    class CustomPipelineWrapper:
        """Callable mapping raw text(s) to per-class sigmoid scores for SHAP."""

        def __init__(self, model, tokenizer):
            self.model = model
            self.tokenizer = tokenizer
            # Tokenized inputs must live on the same device as the weights.
            self.device = next(model.parameters()).device

        def __call__(self, texts):
            # The SHAP masker hands over numpy arrays of strings.
            if isinstance(texts, np.ndarray):
                texts = texts.tolist()
            if isinstance(texts, str):
                texts = [texts]
            elif not isinstance(texts, list):
                raise ValueError("Input must be a str, List[str], or numpy.ndarray.")

            tokens = self.tokenizer(
                texts, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**tokens)

            # A custom head may return the logits tensor directly; Hugging Face
            # models return a ModelOutput that carries it in ``.logits``.
            logits = getattr(outputs, "logits", outputs)
            # Independent sigmoid per class: the task is multi-label (one text
            # can carry several emotions at once).
            return torch.sigmoid(logits).cpu().numpy()

    pred = CustomPipelineWrapper(model, tokenizer)

    # Text masker that hides tokens produced by the same tokenizer.
    masker = shap.maskers.Text(tokenizer)
    explainer = shap.Explainer(pred, masker, output_names=classes)

    shap_values = explainer(texts)
    print("SHAP values shape:", shap_values.values.shape)

    # Interactive token-level plot: shown inline (in a notebook) and saved as HTML.
    shap.plots.text(shap_values)
    html = shap.plots.text(shap_values, display=False)
    (out_dir / "shap_text_plot.html").write_text(html, encoding="utf-8")

    # shap.plots.text renders HTML rather than a Matplotlib figure, so saving
    # the current figure after it produced an empty PNG. Bar plots are real
    # Matplotlib figures: one per class, averaging token SHAP values over the
    # explained texts. Close each figure so they don't accumulate.
    for label in classes:
        shap.plots.bar(shap_values[:, :, label].mean(0), show=False)
        plt.savefig(
            out_dir / f"shap_bar_{label}.png", dpi=300, bbox_inches="tight"
        )
        plt.close()


# Explain the first 4 texts of the English track-A training CSV using the
# saved roberta-base classification checkpoint.
create_shap_pngs(
    number_of_examples=4,
    model_name="roberta-base",
    model_path="classification/roberta-base_2024-12-16_19-13-00",
    dataset_path="data/public_data/train/track_a/eng.csv",
)





                        id                                               text  \
0  eng_train_track_a_00001                                 but not very happy   
1  eng_train_track_a_00002  well shes not gon na last the whole song like ...   
2  eng_train_track_a_00003  she sat at her papas recliner sofa only to mov...   
3  eng_train_track_a_00004                      yes the oklahoma city bombing   
4  eng_train_track_a_00005                        they were dancing to bolero   

   Anger  Fear  Joy  Sadness  Surprise  
0      0     0    1        1         0  
1      0     0    1        0         0  
2      0     0    0        0         0  
3      1     1    0        1         1  
4      0     0    1        0         0  
Texts to explain: ['but not very happy', 'well shes not gon na last the whole song like that so since im behind her and the audience cant see below my torso pretty much i use my hand to push down on the lid and support her weight', 'she sat at her papas recliner sofa

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  50%|█████     | 2/4 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 5it [00:45, 15.11s/it]                       


SHAP values: .values =
array([array([[ 2.88709998e-08,  0.00000000e+00, -3.35276127e-08,
                3.35276127e-08, -3.72529030e-09],
              [-4.68167476e-04,  3.83670330e-02, -2.67782919e-02,
                7.77946040e-03,  2.24027000e-02],
              [ 8.18606000e-02,  1.67158045e-01, -2.38014657e-01,
                1.30067507e-01,  8.46481668e-02],
              [ 7.17438245e-02,  1.87192965e-01, -2.13623162e-01,
                1.31243758e-01,  5.68680158e-02],
              [ 8.02131915e-02,  2.65044458e-02,  5.53238187e-02,
                2.28040135e-01, -1.39640989e-01],
              [ 2.46800482e-08, -1.49011612e-08,  2.04890966e-08,
                2.98023224e-08,  2.04890966e-08]])               ,
       array([[-1.88179564e-03, -2.82176537e-03,  9.91513021e-04,
               -3.37548275e-03, -3.21683311e-03],
              [-1.13096571e-03,  4.12322814e-03, -8.83670431e-03,
               -2.49414658e-03,  1.21632754e-03],
              [ 5.98781160e-04, 

<Figure size 640x480 with 0 Axes>