In [2]:
from lime.lime_text import LimeTextExplainer
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
import shap
import numpy as np
from src.shap_consistency_change import calculate_consistency_change_shap
from src.lime_consistency_change import calculate_consistency_change_lime
random_state = 0  # global seed, passed as random_state to the stochastic steps below (e.g. train_test_split, LogisticRegression)

In [3]:
# https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews
data = pd.read_csv('../data/IMDB Dataset.csv')
X = data["review"]
# binary target: True for positive reviews, False otherwise
y = data["sentiment"] == "positive"
# hold out 20% of the reviews for evaluation (and, below, for explanation)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=random_state
)

## Example positive review

In [4]:
# First review of the full dataset, printed as raw text (it still contains <br /> markup).
print(X.iloc[0])

One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked. They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO. Trust me, this is not a show for the faint hearted or timid. This show pulls no punches with regards to drugs, sex or violence. Its is hardcore, in the classic use of the word.<br /><br />It is called OZ as that is the nickname given to the Oswald Maximum Security State Penitentary. It focuses mainly on Emerald City, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. Em City is home to many..Aryans, Muslims, gangstas, Latinos, Christians, Italians, Irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.<br /><br />I would say the main appeal of the show is due to the fac

In [5]:
pipeline = Pipeline(
    steps=[
        # Bag-of-words TF-IDF features; English stop words are removed and terms
        # occurring in more than 70% of the documents are ignored.
        # NOTE(review): reviews contain raw `<br />` markup, so with the default
        # token_pattern "br" presumably becomes a feature — consider stripping HTML first.
        ('tfidf', TfidfVectorizer(stop_words='english', max_df=0.7)),
        # Linear classifier on top of the sparse TF-IDF matrix.
        ('clf', LogisticRegression(solver='liblinear', random_state=random_state)),
    ]
)

In [6]:
# Fit on the training split (Pipeline.fit returns the pipeline itself), then
# predict the held-out test split and report accuracy plus per-class metrics.
y_pred = pipeline.fit(X_train, y_train).predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
print(classification_report(y_test, y_pred))

Accuracy: 0.89
              precision    recall  f1-score   support

       False       0.90      0.88      0.89      5035
        True       0.88      0.90      0.89      4965

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000


## Define the explanation dataset (a subset of the test split)

In [7]:
# Number of held-out reviews to explain; SHAP and LIME run once per review,
# so this size drives the runtime of the explanation cells below.
N_EXPLAIN = 1000
# `.iloc` makes the positional slice explicit: X_test keeps the shuffled integer
# index produced by train_test_split, so a label-based slice would select other rows.
X_explain = X_test.iloc[:N_EXPLAIN]

## Calculate SHAP explanations

In [8]:
# takes about 2.5 minutes
masker = shap.maskers.Text()


def predict_positive_proba(texts):
    """Return the predicted probability of the positive class for each text.

    Column 1 of ``predict_proba`` corresponds to class ``True`` because the
    pipeline was fit on a boolean target (classes are sorted: False, True).
    """
    return pipeline.predict_proba(texts)[:, 1]


# Calculate SHAP values for the explanation samples. The explainer is seeded with
# the notebook-wide random_state (instead of a hard-coded 0) so a single setting
# controls reproducibility.
explainer = shap.Explainer(predict_positive_proba, masker, seed=random_state, silent=True)
shap_values = explainer(X_explain)

## Calculate the consistency change of the SHAP explanations and its mean over the explanation dataset

In [9]:
# takes about 1.4 minutes
def consistency_change_scoring_function(s):
    """Score one review string: a length-1 array holding the positive-class probability."""
    return pipeline.predict_proba([s]).T[1]


consistency_change_shap = [
    calculate_consistency_change_shap(consistency_change_scoring_function, review, shap_values[idx])
    for idx, review in enumerate(X_explain)
]
# average within each sample first (axis=1), then across the explained samples
mean_consistency_change_shap = np.mean(np.mean(consistency_change_shap, axis=1))
mean_consistency_change_shap

np.float64(0.053704888976680284)

## Calculate LIME explanations

In [10]:
# takes about 2.5 minutes
class_names = ["0", "1"]
# Seed LIME's perturbation sampling with the notebook-wide random_state so the
# explanations, and the consistency numbers derived from them, are reproducible.
lime_explainer = LimeTextExplainer(class_names=class_names, random_state=random_state)
lime_explanations = [
    lime_explainer.explain_instance(
        review, pipeline.predict_proba, top_labels=2, num_features=300, num_samples=1000
    )
    for review in X_explain
]

## Calculate the consistency change of the LIME explanations and its mean over the explanation dataset

In [11]:
# Pair each explained review with its LIME explanation (same order as X_explain).
consistency_change_lime = [
    calculate_consistency_change_lime(consistency_change_scoring_function, review, explanation)
    for review, explanation in zip(X_explain, lime_explanations)
]
# average within each sample first (axis=1), then across the explained samples
mean_consistency_change_lime = np.mean(np.mean(consistency_change_lime, axis=1))
mean_consistency_change_lime

np.float64(0.12069535457373529)