In [7]:
import shap
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
from sentiment import sentiment_analysis
from processing import process_data_to_list

In [8]:
posts = process_data_to_list("data/posts", "content")
comms = process_data_to_list("data/comments", "body")
data = posts+comms

Processing complete. Total posts and comments 28
Processing complete. Total posts and comments 6705


In [9]:
data = data[:10]

In [10]:
# Load pre-trained tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

def predict_sentiment(text):
    #numpy array handling
    if isinstance(text, np.ndarray):
        text = text.tolist()
    
    # Tokenize and encode text for the model
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Predict using the model
    with torch.no_grad():
        outputs = model(**inputs)
    # Return softmax probabilities
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    return probabilities.detach().cpu().numpy()

def shap_word_importance(texts):
    # Initialize SHAP explainer using the prediction function and tokenizer
    explainer = shap.Explainer(predict_sentiment, tokenizer)
    # Calculate SHAP values
    shap_values = explainer(texts)
    
    # Initialize list to store importance scores for each text
    importance_scores = []
    
    # Iterate over SHAP values for each text
    for shap_value in shap_values:
        # Extract SHAP values for each token
        # Note: Each token's SHAP value is stored in shap_value.values, where we sum over the output classes (axis=-1)
        token_importance = np.sum(shap_value.values, axis=-1).tolist()
        # Store the summed importance scores for each word
        importance_scores.append(token_importance)
    
    return importance_scores

In [12]:
# Process each text for sentiment analysis:
for text in data:
    sentiment = predict_sentiment(text)
    print(sentiment)  # This will print the sentiment probabilities for each text.

# Calculate SHAP word importance for the texts
importance_scores = shap_word_importance(data)
for i, scores in enumerate(importance_scores):
    print(f"Text {i+1} word importance scores: {scores}")

[[0.99699867 0.00300132]]
[[0.99887687 0.00112306]]
[[0.989924   0.01007606]]
[[0.9948891  0.00511092]]
[[0.99866736 0.00133267]]
[[0.9975968  0.00240323]]
[[0.9779279  0.02207205]]
[[0.98703206 0.01296791]]
[[0.00459804 0.995402  ]]
[[0.9982503  0.00174962]]


  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  10%|██▏                   | 1/10 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  30%|████▏         | 3/10 [01:44<03:20, 28.64s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  40%|█████▌        | 4/10 [03:22<05:45, 57.66s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  50%|███████       | 5/10 [05:15<06:31, 78.26s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  60%|████████▍     | 6/10 [07:03<05:55, 88.75s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  70%|█████████▊    | 7/10 [08:56<04:50, 96.67s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  80%|██████████▍  | 8/10 [10:43<03:20, 100.16s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer:  90%|███████████▋ | 9/10 [12:31<01:42, 102.51s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 100%|█████████████| 10/10 [14:03<00:00, 99.12s/it]

  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 11it [16:40, 100.10s/it]                          

Text 1 word importance scores: [-1.0186340659856796e-09, 6.111804395914078e-10, 2.5902409106492996e-09, 4.918547347187996e-09, -3.92901711165905e-09, -1.8335413187742233e-09, 1.076841726899147e-09, 1.4551915228366852e-10, -2.5902409106492996e-09, -2.5902409106492996e-09, 4.511093720793724e-09, -8.440110832452774e-10, -2.240994945168495e-09, -2.240994945168495e-09, -2.240994945168495e-09, 1.4842953532934189e-09, 3.2887328416109085e-09, 1.0273652151226997e-08, 1.6007106751203537e-09, 6.7229848355054855e-09, -2.8558133635669947e-09, 8.88879489110983e-10, 8.88879489110983e-10, 3.6501054034767527e-10, -2.4774635676294565e-09, -2.4774635676294565e-09, -1.7207639757543802e-09, -1.7207639757543802e-09, -2.2094657978199983e-09, -4.304941590704825e-09, -9.943808623735784e-11, -6.861228030174971e-09, 2.1989560866630953e-10, -3.5053946906629663e-09, -7.502320735139345e-10, -1.2158933608216738e-09, 1.4616590403304652e-09, 5.575000412705311e-09, -2.1249470958870087e-09, -1.892116452233139e-09, -3.17




In [14]:
len(importance_scores)

10

In [15]:
sentiments = sentiment_analysis(data)

Post 1, sentiment: NEGATIVE, score: 0.99700
Post 2, sentiment: NEGATIVE, score: 0.99888
Post 3, sentiment: NEGATIVE, score: 0.98992
Post 4, sentiment: NEGATIVE, score: 0.99489
Post 5, sentiment: NEGATIVE, score: 0.99867
Post 6, sentiment: NEGATIVE, score: 0.99760
Post 7, sentiment: NEGATIVE, score: 0.97793
Post 8, sentiment: NEGATIVE, score: 0.98703
Post 9, sentiment: POSITIVE, score: 0.99540
Post 10, sentiment: NEGATIVE, score: 0.99825


In [16]:
sentiments

[-0.99699866771698,
 -0.9988768696784973,
 -0.9899240136146545,
 -0.9948890805244446,
 -0.9986673593521118,
 -0.997596800327301,
 -0.9779279232025146,
 -0.9870320558547974,
 0.9954019784927368,
 -0.9982503056526184]

In [17]:
def combine_sentiment_and_importance(sentiments, importance_scores):
    combined_scores = []

    for sentiment, word_scores in zip(sentiments, importance_scores):
        adjusted_scores = [sentiment * score for score in word_scores]
        combined_scores.append(adjusted_scores)

    return combined_scores

In [18]:
final = combine_sentiment_and_importance(sentiments, importance_scores)

In [19]:
final

[[1.0155768066788529e-09,
  -6.093460840073117e-10,
  -2.5824667369833687e-09,
  -4.903785152249318e-09,
  3.91722482576129e-09,
  1.8280382520219352e-09,
  -1.0736097670605016e-09,
  -1.4508240095412184e-10,
  2.5824667369833687e-09,
  2.5824667369833687e-09,
  -4.497554429577777e-09,
  8.414779255339067e-10,
  2.2342689746934763e-09,
  2.2342689746934763e-09,
  2.2342689746934763e-09,
  -1.4798404897320427e-09,
  -3.2788622615631535e-09,
  -1.0242817507361002e-08,
  -1.5959064104953402e-09,
  -6.702806924080429e-09,
  2.847242118724641e-09,
  -8.862116664045999e-10,
  -8.862116664045999e-10,
  -3.639150224292872e-10,
  2.4700278762439243e-09,
  2.4700278762439243e-09,
  1.7155993912824907e-09,
  1.7155993912824907e-09,
  2.2028344567927726e-09,
  4.292021030532127e-09,
  9.913963949897193e-11,
  6.840635204986845e-09,
  -2.19235628877125e-10,
  3.4948738364131526e-09,
  7.479803777719401e-10,
  1.2122440608251299e-09,
  -1.4572721158659533e-09,
  -5.558267983988809e-09,
  2.118569423