In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import numpy as np
import pickle
# context = notebook, import tqdm
from tqdm import tqdm


import sys
sys.path.append('../')
from src import shap

import warnings
warnings.filterwarnings('ignore')



In [2]:

# Train one TF-IDF + random-forest classifier per (dataset, target column)
# pair and cache the SHAP values of its held-out split for the plotting cell.
folder = '../data/processed/'

datasets = ['dv2', 'dv3', 'dv2_humanized']
target_columns = ['source', 'belief', 'agree2']
rs = 12345  # Random state for downsampling, shuffling, splitting and the forest
target_names = ['0', '1']

df = pd.read_csv(folder + 'updated_moral_turing.csv')
# rename identification to belief
df = df.rename(columns={'identification': 'belief'})
data = {}

for ds in tqdm(datasets):

    for col_name in target_columns:
        df_ = df[df['dataset'] == ds].copy()

        # Binary target: 'agree2' is compared against the number 1, the other
        # columns against the string 'AI'.
        if col_name == 'agree2':
            df_['binary_label'] = (df_[col_name] == 1).astype(int)
        else:
            df_['binary_label'] = (df_[col_name] == 'AI').astype(int)

        df_1 = df_[df_['binary_label'] == 1]
        df_0 = df_[df_['binary_label'] != 1]
        if len(df_0) == 0 or len(df_1) == 0:
            # Balancing would yield an empty frame and the vectorizer would
            # fail with an unrelated "empty vocabulary" error; fail clearly here.
            raise ValueError(
                f"dataset={ds!r}, column={col_name!r}: only one class present "
                f"({len(df_1)} positive, {len(df_0)} negative); cannot balance")

        # Downsample the majority class to the size of the minority class
        if len(df_1) > len(df_0):
            df_1_downsampled = df_1.sample(n=len(df_0), random_state=rs)
            df_balanced = pd.concat([df_1_downsampled, df_0])
        else:
            df_0_downsampled = df_0.sample(n=len(df_1), random_state=rs)
            df_balanced = pd.concat([df_0_downsampled, df_1])

        df2 = df_balanced.sample(frac=1, random_state=rs).reset_index(drop=True)
        X = df2['response']
        y = df2['binary_label']

        # Split the raw text BEFORE vectorising so the vocabulary, the min_df
        # filter and the IDF weights are learned from the training split only.
        # Fitting on all rows leaks test-set statistics into the features.
        X_train_text, X_test_text, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=rs)

        # Clean up stop words
        vectorizer = TfidfVectorizer(
            stop_words='english',
            # Only include words with letters (exclude numbers)
            token_pattern=r'\b[a-zA-Z][a-zA-Z]+\b',
            min_df=3,
            max_features=1000
        )
        X_train_dense = vectorizer.fit_transform(X_train_text).toarray()
        X_test_dense = vectorizer.transform(X_test_text).toarray()

        # Train a classifier
        clf = RandomForestClassifier(n_estimators=100, random_state=rs)
        clf.fit(X_train_dense, y_train)

        y_pred = clf.predict(X_test_dense)
        report = classification_report(
            y_test, y_pred, target_names=target_names)

        # SHAP values for the held-out split. The second argument is presumably
        # the background dataset, as in upstream shap. Verify against src/shap.
        e3 = shap.TreeExplainer(clf, X_train_dense)
        sv3 = e3.shap_values(X_test_dense)

        data.setdefault(ds, {})[col_name] = {
            'shap': sv3,
            'feature_names': vectorizer.get_feature_names_out(),
            'dense': X_test_dense,
            # Stored so model quality can be inspected alongside the SHAP plots.
            'report': report,
        }

# save the data
with open(folder + 'shap_data.pkl', 'wb') as f:
    pickle.dump(data, f)



In [None]:
import pickle
import sys

import matplotlib.pyplot as plt
import seaborn as sns

# Make the project root importable BEFORE importing the local shap fork.
if '../' not in sys.path:
    sys.path.append('../')
from src import shap

sns.set_context('talk')
sns.set_palette('viridis')

# One sequential colormap per dataset (dark -> the given hex colour).
cmap_green = sns.color_palette('dark:#5f9e6e', as_cmap=True)
cmap_blue = sns.color_palette('dark:#c44e52', as_cmap=True)    # NOTE: #c44e52 is a red hue
cmap_orange = sns.color_palette('dark:#854e9e', as_cmap=True)  # NOTE: #854e9e is a purple hue


def _shorten_tick_label(text):
    """Shorten a tick label by one character and collapse zero to "0".

    A label with more than 4 characters (sign excluded) loses its final
    character, e.g. "0.025" -> "0.02". A label whose digits read "0.00" or
    "0.0" becomes "0". Both ASCII "-" and matplotlib's Unicode minus
    (U+2212) are treated as the sign, so negative labels are shortened
    exactly like positive ones.
    """
    unsigned = text.replace("-", "").replace("\u2212", "")
    if len(unsigned) > 4:
        text = text[:-1]
        unsigned = unsigned[:-1]
    if unsigned in ("0.00", "0.0"):
        return "0"
    return text


# load data if exists
folder = '../data/processed/'
shap_path = folder + 'shap_data.pkl'
try:
    # Written by the training cell above. Only unpickle files you generated
    # yourself: pickle.load can execute arbitrary code.
    with open(shap_path, 'rb') as f:
        data = pickle.load(f)
except FileNotFoundError as err:
    print("File not found")
    print('Please run the previous cell to generate the data')
    raise FileNotFoundError(
        f"{shap_path} not found; run the previous cell to generate it") from err


# NOTE(review): shap's legacy summary_plot resizes the current figure when
# plot_size is given, so this canvas is presumably overridden by
# plot_size=(12, 12). Confirm against src/shap.
plt.figure(figsize=(100, 100), dpi=300)

sources = ['dv2', 'dv3', 'dv2_humanized']
target_columns = ['source', 'belief', 'agree2']
class_name = '1'
color_bar = False  # colour bars are currently disabled on every panel

# 3x3 grid: one row per target column, one column per dataset.
count = 0
for col_name in target_columns:
    for source, cmap in zip(sources, [cmap_blue, cmap_green, cmap_orange]):
        count += 1
        print(f"Source: {source}, column: {col_name}")

        entry = data[source][col_name]
        shap_values = entry['shap']
        X_test_dense = entry['dense']
        feature_names = entry['feature_names']

        print(f"SHAP values for class '{class_name}':")

        plt.subplot(3, 3, count)

        # Index 1 on the last axis selects the positive class (binary_label == 1);
        # assumes shap_values is (samples, features, classes). TODO confirm.
        ax = shap.summary_plot(shap_values=shap_values[:, :, 1], features=X_test_dense,
                               feature_names=feature_names,
                               max_display=10, class_names=[f'Not {source}', source],
                               plot_size=(12, 12), color=cmap, cmap=cmap,
                               color_bar=color_bar, alpha=1, show=False)
        if col_name in ('source', 'belief'):
            ax.set_xlabel(f"importance in predicting AI \n ({col_name})")
        else:
            ax.set_xlabel("importance in predicting \nagreement")

        # Shorten long tick labels (see _shorten_tick_label).
        try:
            ax.set_xticklabels(
                [_shorten_tick_label(t.get_text()) for t in ax.get_xticklabels()])
        except ValueError:
            # Best-effort cosmetic step: keep matplotlib's labels if rejected.
            pass

        # Dataset names as column titles on the top row only.
        if count in [1, 2, 3]:
            ax.set_title(f"{source.replace('_', ' ')}")

        # set color bar label
        if color_bar:
            try:
                cbar = ax.collections[0].colorbar
                cbar.set_label('Word frequency')
                # set larger bar width
                cbar.ax.set_aspect(10)
            except AttributeError:
                pass

plt.tight_layout()

# reduce white space between subplots
plt.subplots_adjust(wspace=0.55, hspace=0.4)

SyntaxError: 'return' outside function (697329217.py, line 33)

In [7]:
# Sanity check: the 'dv2' entry should hold one sub-dict per target column
# ('source', 'belief', 'agree2'), as built by the training loop.
data['dv2'].keys()

dict_keys(['source', 'belief', 'agree2'])