In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import warnings

In [None]:
warnings.filterwarnings("ignore")

In [None]:
# Root of the project checkout on the mounted research drive; all SHAP inputs are read below it.
my_path = '~/mounts/research/husdatalake/disease/scripts/Preleukemia/oona_git'

In [None]:
# Disease whose final-model SHAP results are loaded and plotted (used in file names).
disease = "MDS"

### Load SHAP values computed in R (run `SHAP/SHAP_final_model.R` first)

In [None]:
# Outputs written by SHAP/SHAP_final_model.R for the selected disease.
# `perc` is each feature's share of the summed mean SHAP score, in percent.
shap = pd.read_csv(f'{my_path}/results/final_model/SHAP/{disease}_final_model_shap_scores.csv').assign(
    perc=lambda frame: 100 * frame['mean_shap_score'] / frame['mean_shap_score'].sum()
)
values = pd.read_csv(f'{my_path}/results/final_model/SHAP/{disease}_final_model_feature_values.csv')
xtrain = pd.read_csv(f'{my_path}/results/final_model/SHAP/{disease}_x_train.csv')


In [None]:
# Label each feature's risk direction: '+' when observations with a positive
# contribution in `values` have, on average, a higher feature value in `xtrain`
# than observations with a negative contribution (higher value -> higher
# predicted risk); otherwise '-'.
# NOTE(review): `values` is presumably the per-observation SHAP matrix loaded
# above (one column per feature, rows aligned with `xtrain`) — confirm. It is
# rebound to a plain list in a later cell, so this cell cannot be re-run after that.
risk_labels = []
for feature in shap['names']:
    contrib = values[feature]
    pos_mean = xtrain[feature].loc[contrib[contrib > 0].index].mean()
    neg_mean = xtrain[feature].loc[contrib[contrib < 0].index].mean()
    # Compare the signed means. Taking abs(neg_mean) would flip the direction
    # for features that can be negative (presumably the Δ-change features).
    # If either group is empty its mean is NaN and the comparison yields '-'.
    risk_labels.append('+' if pos_mean > neg_mean else '-')

# Assign the whole column at once: the chained form `shap['risk'].loc[i] = ...`
# does not write back to `shap` under pandas copy-on-write.
shap['risk'] = risk_labels

shap['names_clean'] = shap['names']

### Manually assign readable display names to the top features

In [None]:
# Human-readable labels for the first 10 rows of `shap`, in row order.
# NOTE(review): these labels are hand-matched to the MDS results; re-check them
# against shap['names'] whenever `disease` or the model changes.
display_names = ['Age', 'B-NEUT -1y', 'B-EOS', 'E-RDW', 'E-MCV', 'E-RDW -5y', 'B-NEUT', 'B-HB Δ5y', 'B-PLT', 'B-PLT Δ5y']
# Write through a single .loc call: the chained form
# `shap['names_clean'].loc[:9] = ...` does not modify `shap` under copy-on-write.
shap.loc[shap.index[:len(display_names)], 'names_clean'] = display_names

In [None]:
# Number of features to plot: the first n rows of `shap`
# (presumably already sorted by importance by the R script — TODO confirm).
n = 10

categories = list(shap.head(n)['names_clean'].to_numpy())
# NOTE(review): this rebinds `values`, which previously held the frame read from
# *_final_model_feature_values.csv, so the risk-direction cell cannot be re-run
# after this one.
values = list(shap.head(n)['perc'].to_numpy())

In [None]:
# Bar colour per feature by risk direction: red for '+', blue for '-'.
# Positional .iloc[:n] gives exactly n entries, matching `categories`/`values`;
# the previous label-based .loc[:n] is end-inclusive and returned n + 1 colours.
risk_colors = {'-': 'cornflowerblue', '+': 'indianred'}
colors = shap['risk'].iloc[:n].map(risk_colors).tolist()
# Fail loudly instead of passing NaN colours to matplotlib.
assert all(isinstance(c, str) for c in colors), "shap['risk'] contains values other than '+'/'-'"

In [None]:
# Horizontal bar chart of the top-n features' percentage share of the summed
# mean SHAP score, coloured by risk direction, with the top feature at the top.
fig, ax = plt.subplots(figsize=(4, 6))

bars = ax.barh(categories, values, color=colors)
# Annotate each bar at its right end with its percentage.
for bar in bars:
    width = bar.get_width()
    ax.text(
        width,
        bar.get_y() + bar.get_height() / 2,
        f'{width:.1f} %',
        ha='left',
        va='center',
        color='black',
        fontsize=10)
ax.set_frame_on(False)
# Fixed x-range; bars near 85 % would push their labels outside the axes.
ax.set_xlim(0, 85)
ax.set_xticks([])
ax.tick_params(axis='y', colors='black')
ax.invert_yaxis()
# Title follows the `disease` setting (was hard-coded to 'MDS').
ax.set_title(disease, loc='left', fontsize=12)
fig.tight_layout()
fig.savefig('results/final_model/plots/' + disease + '_shap_scores.png')