## LLMs

In [None]:
#%pip install shap %pip install xgboost
import pandas as pd
import shap
import matplotlib.pyplot as plt
import openpyxl
import numpy as np

In [None]:
# Configuration for the LLM analysis: workbook path, ground-truth column,
# per-model prediction columns and columns to discard.
# NOTE(review): the next cell reassigns excel_path, ground_truth_col,
# model_prediction_col_list and columns_to_drop (with different prediction
# column names) before any of them is used, so the values set here appear to
# be superseded - confirm whether this cell is still needed.
excel_path= r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\LLM_X_test_dataset.xlsx"
ground_truth_col = 'Inhospital Mortalit(TRUE)'
model_prediction_col_list = ['Mixtral-8x7B-Instruct-v0.1 ', 	'Llama-3-70B',	'Mistral-7B-Instruct',	'-Llama-3-8B',	'gpt-4o-2024-05-13_outcome',	'gpt-3.5-turbo-0125_outcome',	'gpt-4-turbo-2024-04-09_outcome',	'gpt-4-0613_outcome']
columns_to_drop= ['patient medical hidtory']

In [None]:
import pandas as pd
import shap
import matplotlib.pyplot as plt
import openpyxl
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.preprocessing import StandardScaler
						

# Load the dataset (one Excel sheet holding input features, the ground-truth
# column and one prediction column per LLM).
excel_path = r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\LLM_X_test_dataset.xlsx"
data = pd.read_excel(excel_path)

# Specify the relevant columns
ground_truth_col = 'Inhospital Mortalit(TRUE)'
model_prediction_col_list = ['Mixtral-8x7B', 'Llama-3-70B', 'Mistral-7B', 'Llama3-8B', 'GPT4o', 'GPT3.5', 'GPT4T', 'GPT4']
columns_to_drop = ['patient medical hidtory']

# Drop unnecessary columns
data = data.drop(columns=columns_to_drop)

# Every remaining column that is neither a model prediction nor the ground
# truth is treated as an input feature.
feature_columns = [col for col in data.columns if col not in model_prediction_col_list + [ground_truth_col]]
# Standardize each feature column whose values are not all in {0, 1};
# columns containing only 0/1 are left untouched. The scaler is re-fitted
# for every column.
scaler = StandardScaler()
for col in feature_columns:
    unique_values = data[col].unique()
    if not set(unique_values).issubset({0, 1}):
        data[[col]] = scaler.fit_transform(data[[col]])


# Accumulates two rows per model: '<model>__mean' and '<model>__std' of the
# absolute SHAP values, with one column per feature.
all_shap_stats_df = pd.DataFrame()
for model_col in model_prediction_col_list:
    # Extract model predictions and input features
    X = data[feature_columns]
    y = data[model_col]

    # Train a surrogate model on the features to reproduce this model's
    # predictions (the ground-truth column is not used as a target here).
    surrogate_model = XGBClassifier(random_state=42)  # Added random_state for reproducibility
    surrogate_model.fit(X, y)


    explainer = shap.Explainer(surrogate_model)
    shap_values = explainer(X)  # Use .shap_values(X) for tree-based explainers if using older SHAP versions

    # Extract feature names and their SHAP values (absolute values)
    shap_values_df = pd.DataFrame(shap_values.values, columns=feature_columns).abs()

    # Calculate the mean and standard deviation of SHAP values for each feature
    shap_values_stats = shap_values_df.describe().loc[['mean', 'std']]
    shap_values_stats = shap_values_stats.rename(index={'mean': f'{model_col}__mean', 'std': f'{model_col}__std'})

    all_shap_stats_df = pd.concat([all_shap_stats_df, shap_values_stats])

    # Global importance: bar plot of the mean absolute SHAP value per feature
    # (despite the variable names, this is a bar plot, not a waterfall plot).
    # NOTE(review): the figure created here is not passed to shap.plots.bar;
    # presumably SHAP draws on the current matplotlib figure - confirm the
    # saved PDF is not blank in the environment used.
    fig_waterfall, ax_waterfall = plt.subplots()
    shap.plots.bar(shap_values.abs.mean(0), max_display=10)
    fig_waterfall.savefig(f'LLM___{model_col}_global.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_waterfall)

    # Per-input view: layered violin plot of the SHAP values
    # (same NOTE(review) about the implicit current figure applies).
    fig_violin, ax_violin = plt.subplots()
    shap.plots.violin(shap_values, max_display=10, plot_type='layered_violin', plot_size=(7,7))
    fig_violin.savefig(f'LLM___{model_col}_perinput.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_violin)

    

In [None]:
# Display the accumulated per-model mean/std |SHAP| table.
all_shap_stats_df

In [None]:
def transform_SHAP_scores_to_impact_percentage(df):
    """Turn per-model mean/std |SHAP| rows into percentage impact scores.

    Args:
        df: DataFrame with one column per feature and, per model, a row
            ending in ``__mean`` and a row ending in ``__std``. It is not
            modified.

    Returns:
        A new DataFrame indexed by feature. It contains the input rows as
        columns, the cross-model averages ``ALL_average_mean`` /
        ``ALL_average_std``, and for every ``*_mean`` / ``*_std`` column pair
        a ``*_impact_score`` column (share of the column total, in percent)
        and a ``*_impact_score_std`` column (std rescaled by the same factor).
    """
    # Work on a copy: appending the ALL_average_* rows must not leak into the
    # caller's frame.
    stats = df.copy()

    # Identifying rows ending with '__mean' and '__std'
    mean_rows = stats.filter(regex='__mean$', axis=0)
    std_rows = stats.filter(regex='__std$', axis=0)

    # Average across models, per feature, added as two extra rows
    stats.loc['ALL_average_mean'] = mean_rows.mean(axis=0)
    stats.loc['ALL_average_std'] = std_rows.mean(axis=0)

    # From here on: one row per feature, one column per statistic
    stats = stats.T
    mean_cols = stats.filter(regex='_mean$')
    std_cols = stats.filter(regex='_std$')

    impact_scores = {}
    impact_stds = {}
    # Mean and std columns are paired by position (same model order in both).
    for mean_col, std_col in zip(mean_cols, std_cols):
        total_impact = stats[mean_col].abs().sum()
        # Relative impact of each feature, in percent of the column total
        impact_scores[mean_col] = (stats[mean_col].abs() / total_impact) * 100

        # The std is rescaled by the same constant factor as the mean.
        # Using 100 / total directly (instead of impact / |mean|) avoids a
        # 0/0 = NaN for features whose mean |SHAP| is exactly zero.
        impact_stds[std_col] = stats[std_col] * (100 / total_impact)

    # Add new columns; only the trailing suffix is renamed so that model
    # names that themselves contain '_mean' or '_std' are left intact.
    for col, scores in impact_scores.items():
        stats[col[:-len('_mean')] + '_impact_score'] = scores

    for col, stds in impact_stds.items():
        stats[col[:-len('_std')] + '_impact_score_std'] = stds

    # Sort columns by their first and last '_'-separated pieces
    sorted_columns = sorted(stats.columns, key=lambda x: x.split('_')[0] + x.split('_')[-1])
    return stats[sorted_columns]

# Convert the LLM surrogate SHAP statistics into percentage impact scores
# and display the resulting per-feature table.
LLM_shap_stat = transform_SHAP_scores_to_impact_percentage(all_shap_stats_df)
LLM_shap_stat
# NOTE(review): the disabled export below refers to CML_shap_stat although this
# cell builds LLM_shap_stat - looks copy-pasted; adjust before re-enabling.
#CML_shap_stat.to_excel(r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\CML_shap_stat.xlsx")


In [None]:
def plot_top_10_means(df, mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#FF1F5b', dpi=450):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels

    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')

    plt.show()


    
 
# Display labels for the dataset column names that can appear among the
# plotted features.
naming_dic = {
    'Age': 'Demographic - Age',
    'CR': 'Lab - Cr',
    'sodium': 'Lab - Na',
    'MCV': 'Lab - MCV',
    'Hemoglobin': 'Lab - Hb',
    'alkaline phosphatase': 'Lab - ALP',
    ' Lymphocyte count': 'Lab - Lymphocyte',
    ' Neutrophils percentage': 'Lab - Neutrophil',
    'diastolic Blood pressure': 'VS - Diastolic BP',
    'O2 saturation without supply': 'VS - O2 Saturation',
    'PT': 'Lab - PT',
    'potassium': 'Lab - K',
    'ESR': 'Lab - ESR',
}

# Impact scores averaged over all LLM surrogates, saved as SVG.
plot_top_10_means(
    LLM_shap_stat,
    mean_column_name='ALL_average_impact_score',
    std_column_name='ALL_average_impact_score_std',
    naming_dict=naming_dic,
    save_plt_name="MAIN__LLM_MeanImpact.svg",
)


In [None]:
def plot_top_10_means(df, mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#F2Acca', dpi=450):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels

    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')

    plt.show()


    
 
# Display labels for the dataset column names that can appear among the
# plotted features.
naming_dic = {
    'Age': 'Demographic - Age',
    'CR': 'Lab - Cr',
    'sodium': 'Lab - Na',
    'MCV': 'Lab - MCV',
    'Hemoglobin': 'Lab - Hb',
    'alkaline phosphatase': 'Lab - ALP',
    ' Lymphocyte count': 'Lab - Lymphocyte',
    ' Neutrophils percentage': 'Lab - Neutrophil',
    'diastolic Blood pressure': 'VS - Diastolic BP',
    'O2 saturation without supply': 'VS - O2 Saturation',
}

# Impact scores of the GPT4 surrogate, saved as SVG.
plot_top_10_means(
    LLM_shap_stat,
    mean_column_name='GPT4__impact_score',
    std_column_name='GPT4__impact_score_std',
    naming_dict=naming_dic,
    save_plt_name="MAIN__LLM_GPT4Impact.svg",
)


In [None]:
# List the available statistic columns (useful for picking plot inputs).
LLM_shap_stat.columns

In [None]:
def plot_top_10_means(df, mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#e3a857', dpi=450):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels

    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')

    plt.show()


    
 
# Display labels for the dataset column names that can appear among the
# plotted features.
naming_dic = {
    'Age': 'Demographic - Age',
    'CR': 'Lab - Cr',
    'sodium': 'Lab - Na',
    'MCV': 'Lab - MCV',
    'Hemoglobin': 'Lab - Hb',
    'alkaline phosphatase': 'Lab - ALP',
    ' Lymphocyte count': 'Lab - Lymphocyte',
    ' Neutrophils percentage': 'Lab - Neutrophil',
    'diastolic Blood pressure': 'VS - Diastolic BP',
    'O2 saturation without supply': 'VS - O2 Saturation',
    'PT': 'Lab - PT',
    'potassium': 'Lab - K',
    'PTT': 'Lab - PTT',
    'Temperature': 'VS - Temp',
}

# Impact scores of the Mistral-7B surrogate, saved as SVG.
plot_top_10_means(
    LLM_shap_stat,
    mean_column_name='Mistral-7B__impact_score',
    std_column_name='Mistral-7B__impact_score_std',
    naming_dict=naming_dic,
    save_plt_name="MAIN__LLM_mistralImpact.svg",
)


## CMLS

In [None]:
import pandas as pd
import shap
import matplotlib.pyplot as plt
import openpyxl
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

# Load the dataset (one Excel sheet holding input features, the ground-truth
# column and prediction columns of the classical ML models).
excel_path = r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\CML_X_test_dataset.xlsx"
data = pd.read_excel(excel_path)

# Specify the relevant columns
ground_truth_col = 'y_true'
model_prediction_col_list = ['logistic regression_y_predicted', 'SVM_y_predicted', 'Decision tree_y_predicted', 'knn_y_predicted', 'Random forest_y_predicted', 'XGboost_y_predicted', 'neural net_y_predicted']
columns_to_drop = ['logistic regression_y_predicted2',	'SVM_y_predicted2',	'Decision tree_y_predicted2',	'knn_y_predicted2',	'Random forest_y_predicted2',	'XGboost_y_predicted2',	'neural net_y_predicted2']

# Drop unnecessary columns
data = data.drop(columns=columns_to_drop)

# Every remaining column that is neither a model prediction nor the ground
# truth is treated as an input feature. Unlike the LLM section, no feature
# scaling is applied here.
feature_columns = [col for col in data.columns if col not in model_prediction_col_list + [ground_truth_col]]

# Accumulates two rows per model: '<model>__mean' and '<model>__std' of the
# absolute SHAP values, with one column per feature.
all_shap_stats_df = pd.DataFrame()
for model_col in model_prediction_col_list:
    # Extract model predictions and input features
    X = data[feature_columns]
    y = data[model_col]

    # Train a surrogate model on the features to reproduce this model's
    # predictions (the ground-truth column is not used as a target here).
    surrogate_model = XGBClassifier(random_state=42)  # Added random_state for reproducibility
    surrogate_model.fit(X, y)


    explainer = shap.Explainer(surrogate_model)
    shap_values = explainer(X)  # Use .shap_values(X) for tree-based explainers if using older SHAP versions

    # Extract feature names and their SHAP values (absolute values)
    shap_values_df = pd.DataFrame(shap_values.values, columns=feature_columns).abs()

    # Calculate the mean and standard deviation of SHAP values for each feature
    shap_values_stats = shap_values_df.describe().loc[['mean', 'std']]
    shap_values_stats = shap_values_stats.rename(index={'mean': f'{model_col}__mean', 'std': f'{model_col}__std'})

    all_shap_stats_df = pd.concat([all_shap_stats_df, shap_values_stats])

    # Global importance: bar plot of the mean absolute SHAP value per feature
    # (despite the variable names, this is a bar plot, not a waterfall plot).
    # NOTE(review): the figure created here is not passed to shap.plots.bar;
    # presumably SHAP draws on the current matplotlib figure - confirm the
    # saved PDF is not blank in the environment used.
    fig_waterfall, ax_waterfall = plt.subplots()
    shap.plots.bar(shap_values.abs.mean(0), max_display=10)
    fig_waterfall.savefig(f'CML___{model_col}_global.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_waterfall)

    # Per-input view: layered violin plot of the SHAP values
    # (same NOTE(review) about the implicit current figure applies).
    fig_violin, ax_violin = plt.subplots()
    shap.plots.violin(shap_values, max_display=10, plot_type='layered_violin', plot_size=(7,7))
    fig_violin.savefig(f'CML___{model_col}_perinput.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_violin)

    

In [None]:
def transform_SHAP_scores_to_impact_percentage(df):
    """Turn per-model mean/std |SHAP| rows into percentage impact scores.

    Args:
        df: DataFrame with one column per feature and, per model, a row
            ending in ``__mean`` and a row ending in ``__std``. It is not
            modified.

    Returns:
        A new DataFrame indexed by feature. It contains the input rows as
        columns, the cross-model averages ``ALL_average_mean`` /
        ``ALL_average_std``, and for every ``*_mean`` / ``*_std`` column pair
        a ``*_impact_score`` column (share of the column total, in percent)
        and a ``*_impact_score_std`` column (std rescaled by the same factor).
    """
    # Work on a copy: appending the ALL_average_* rows must not leak into the
    # caller's frame.
    stats = df.copy()

    # Identifying rows ending with '__mean' and '__std'
    mean_rows = stats.filter(regex='__mean$', axis=0)
    std_rows = stats.filter(regex='__std$', axis=0)

    # Average across models, per feature, added as two extra rows
    stats.loc['ALL_average_mean'] = mean_rows.mean(axis=0)
    stats.loc['ALL_average_std'] = std_rows.mean(axis=0)

    # From here on: one row per feature, one column per statistic
    stats = stats.T
    mean_cols = stats.filter(regex='_mean$')
    std_cols = stats.filter(regex='_std$')

    impact_scores = {}
    impact_stds = {}
    # Mean and std columns are paired by position (same model order in both).
    for mean_col, std_col in zip(mean_cols, std_cols):
        total_impact = stats[mean_col].abs().sum()
        # Relative impact of each feature, in percent of the column total
        impact_scores[mean_col] = (stats[mean_col].abs() / total_impact) * 100

        # The std is rescaled by the same constant factor as the mean.
        # Using 100 / total directly (instead of impact / |mean|) avoids a
        # 0/0 = NaN for features whose mean |SHAP| is exactly zero.
        impact_stds[std_col] = stats[std_col] * (100 / total_impact)

    # Add new columns; only the trailing suffix is renamed so that model
    # names that themselves contain '_mean' or '_std' are left intact.
    for col, scores in impact_scores.items():
        stats[col[:-len('_mean')] + '_impact_score'] = scores

    for col, stds in impact_stds.items():
        stats[col[:-len('_std')] + '_impact_score_std'] = stds

    # Sort columns by their first and last '_'-separated pieces
    sorted_columns = sorted(stats.columns, key=lambda x: x.split('_')[0] + x.split('_')[-1])
    return stats[sorted_columns]

# Convert the classical-ML surrogate SHAP statistics into percentage impact
# scores and display the resulting per-feature table.
CML_shap_stat = transform_SHAP_scores_to_impact_percentage(all_shap_stats_df)
CML_shap_stat
# Optional export of the table to Excel (disabled).
#CML_shap_stat.to_excel(r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\CML_shap_stat.xlsx")


In [None]:
def plot_top_10_means(df,  mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#009ADE', dpi=450):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels

    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')

    plt.show()


    
 
# Display labels for the CML feature columns. Keys are the dataset column
# names and are left exactly as spelled there; only the display labels are
# corrected: myalgia was labelled as dyspnea (which gave two bars the same
# label), and 'Lab - Crr' was a typo for 'Lab - Cr'.
naming_dic = {
    'Demographic_Age': 'Demographic - Age',
    'VS_O2satwithoutsupp': 'VS - O2 Saturation',
    'Symptom_LOC': 'Symptom - LOC',
    'LAB_LYMPHH_1': 'Lab - Lymphocyte',
    'Symptom_Dyspnea': 'Symptom - Dyspnea',
    'Symptom_Mylagia': 'Symptom - Myalgia',
    'Demographic_Gender': 'Demographic - Gender',
    'LAB_CR_1': 'Lab - Cr',
    'LAB_PLT_1': 'Lab - PLT',
    'VS_RR': 'VS - RR',
}
       
#plot_top_10_means(CML_shap_stat, 'ALL_impact_score', 'ALL_impact_score_std',naming_dict=naming_dic, save_plt_name="MAIN__CML_MeanImpact.svg")


In [None]:
def plot_top_10_means(df, mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#00deb3', dpi=300):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels
    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')
    plt.show()

    
 
# Display labels for the CML feature columns. Keys are the dataset column
# names and are left exactly as spelled there; only the display labels are
# corrected: myalgia was labelled as dyspnea (which gave two bars the same
# label), 'Lab - Crr' was a typo for 'Lab - Cr', and ' Lab - Na' carried a
# stray leading space.
naming_dic = {
    'Demographic_Age': 'Demographic - Age',
    'VS_O2satwithoutsupp': 'VS - O2 Saturation',
    'Symptom_LOC': 'Symptom - LOC',
    'LAB_LYMPHH_1': 'Lab - Lymphocyte',
    'Symptom_Dyspnea': 'Symptom - Dyspnea',
    'Symptom_Mylagia': 'Symptom - Myalgia',
    'Demographic_Gender': 'Demographic - Gender',
    'LAB_CR_1': 'Lab - Cr',
    'LAB_PLT_1': 'Lab - PLT',
    'VS_RR': 'VS - RR',
    'Symptom_Caugh': 'Symptom - Cough',
    'LAB_NA_First': 'Lab - Na',
}
       
# Plot the top features for the XGBoost column of the CML table and save as SVG.
# NOTE(review): this passes the raw '__mean' / '__std' columns, whereas the
# plotting function formats bar values with a '%' sign and the LLM calls pass
# '__impact_score' columns - confirm whether 'XGboost_y_predicted__impact_score'
# and 'XGboost_y_predicted__impact_score_std' were intended.
plot_top_10_means(CML_shap_stat, 'XGboost_y_predicted__mean', 'XGboost_y_predicted__std',naming_dict=naming_dic, 
                  save_plt_name="MAIN__CML_XGBoostImpac.svg"
                  )


## Fine-tuned LLM

In [None]:
import pandas as pd
import shap
import matplotlib.pyplot as plt
import openpyxl
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier

# Load the dataset (one Excel sheet holding input features, the ground-truth
# column and the fine-tuned model's prediction column).
excel_path = r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\finetuned_X_ex_features.xlsx"
data = pd.read_excel(excel_path)

# Specify the relevant columns
ground_truth_col = 'y_true'
model_prediction_col_list = ['Fine-Tuned Mistral-7B']
columns_to_drop = []

# Drop unnecessary columns (none for this dataset, so this is a no-op copy)
data = data.drop(columns=columns_to_drop)

# Every column that is neither a model prediction nor the ground truth is
# treated as an input feature. No feature scaling is applied here.
feature_columns = [col for col in data.columns if col not in model_prediction_col_list + [ground_truth_col]]

# Accumulates two rows per model: '<model>__mean' and '<model>__std' of the
# absolute SHAP values, with one column per feature.
all_shap_stats_df = pd.DataFrame()
for model_col in model_prediction_col_list:
    # Extract model predictions and input features
    X = data[feature_columns]
    y = data[model_col]

    # Train a surrogate model on the features to reproduce this model's
    # predictions (the ground-truth column is not used as a target here).
    surrogate_model = XGBClassifier(random_state=42)  # Added random_state for reproducibility
    surrogate_model.fit(X, y)


    explainer = shap.Explainer(surrogate_model)
    shap_values = explainer(X)  # Use .shap_values(X) for tree-based explainers if using older SHAP versions

    # Extract feature names and their SHAP values (absolute values)
    shap_values_df = pd.DataFrame(shap_values.values, columns=feature_columns).abs()

    # Calculate the mean and standard deviation of SHAP values for each feature
    shap_values_stats = shap_values_df.describe().loc[['mean', 'std']]
    shap_values_stats = shap_values_stats.rename(index={'mean': f'{model_col}__mean', 'std': f'{model_col}__std'})

    all_shap_stats_df = pd.concat([all_shap_stats_df, shap_values_stats])

    # Global importance: bar plot of the mean absolute SHAP value per feature
    # (despite the variable names, this is a bar plot, not a waterfall plot).
    # NOTE(review): the figure created here is not passed to shap.plots.bar;
    # presumably SHAP draws on the current matplotlib figure - confirm the
    # saved PDF is not blank in the environment used.
    # NOTE(review): the 'CML___' file prefix in this fine-tuned LLM section
    # looks copy-pasted from the CML section - confirm the intended file name.
    fig_waterfall, ax_waterfall = plt.subplots()
    shap.plots.bar(shap_values.abs.mean(0), max_display=10)
    fig_waterfall.savefig(f'CML___{model_col}_global.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_waterfall)

    # Per-input view: layered violin plot of the SHAP values
    # (same NOTE(review) remarks about the implicit figure and the prefix apply).
    fig_violin, ax_violin = plt.subplots()
    shap.plots.violin(shap_values, max_display=10, plot_type='layered_violin', plot_size=(7,7))
    fig_violin.savefig(f'CML___{model_col}_perinput.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig_violin)

    

In [None]:
def transform_SHAP_scores_to_impact_percentage(df):
    """Turn per-model mean/std |SHAP| rows into percentage impact scores.

    Args:
        df: DataFrame with one column per feature and, per model, a row
            ending in ``__mean`` and a row ending in ``__std``. It is not
            modified.

    Returns:
        A new DataFrame indexed by feature. It contains the input rows as
        columns, the cross-model averages ``ALL_average_mean`` /
        ``ALL_average_std``, and for every ``*_mean`` / ``*_std`` column pair
        a ``*_impact_score`` column (share of the column total, in percent)
        and a ``*_impact_score_std`` column (std rescaled by the same factor).
    """
    # Work on a copy: appending the ALL_average_* rows must not leak into the
    # caller's frame.
    stats = df.copy()

    # Identifying rows ending with '__mean' and '__std'
    mean_rows = stats.filter(regex='__mean$', axis=0)
    std_rows = stats.filter(regex='__std$', axis=0)

    # Average across models, per feature, added as two extra rows
    stats.loc['ALL_average_mean'] = mean_rows.mean(axis=0)
    stats.loc['ALL_average_std'] = std_rows.mean(axis=0)

    # From here on: one row per feature, one column per statistic
    stats = stats.T
    mean_cols = stats.filter(regex='_mean$')
    std_cols = stats.filter(regex='_std$')

    impact_scores = {}
    impact_stds = {}
    # Mean and std columns are paired by position (same model order in both).
    for mean_col, std_col in zip(mean_cols, std_cols):
        total_impact = stats[mean_col].abs().sum()
        # Relative impact of each feature, in percent of the column total
        impact_scores[mean_col] = (stats[mean_col].abs() / total_impact) * 100

        # The std is rescaled by the same constant factor as the mean.
        # Using 100 / total directly (instead of impact / |mean|) avoids a
        # 0/0 = NaN for features whose mean |SHAP| is exactly zero.
        impact_stds[std_col] = stats[std_col] * (100 / total_impact)

    # Add new columns; only the trailing suffix is renamed so that model
    # names that themselves contain '_mean' or '_std' are left intact.
    for col, scores in impact_scores.items():
        stats[col[:-len('_mean')] + '_impact_score'] = scores

    for col, stds in impact_stds.items():
        stats[col[:-len('_std')] + '_impact_score_std'] = stds

    # Sort columns by their first and last '_'-separated pieces
    sorted_columns = sorted(stats.columns, key=lambda x: x.split('_')[0] + x.split('_')[-1])
    return stats[sorted_columns]

# Convert the fine-tuned LLM surrogate SHAP statistics into percentage impact
# scores and display the resulting per-feature table.
fine_LLM = transform_SHAP_scores_to_impact_percentage(all_shap_stats_df)
fine_LLM
# NOTE(review): the disabled export below refers to CML_shap_stat although this
# cell builds fine_LLM - looks copy-pasted; adjust before re-enabling.
#CML_shap_stat.to_excel(r"C:\Users\LEGION\Documents\GIT\Tehran_COVID_Cohort\DO_NOT_PUBLISH\CML_shap_stat.xlsx")


In [None]:
def plot_top_10_means(df, mean_column_name, std_column_name, naming_dict=None, plot_title=None, save_plt_name=None,
                      color='#28b925', dpi=300):
    """Plot the ten largest values of one column as horizontal bars with error bars.

    Args:
        df: DataFrame whose index labels name the bars.
        mean_column_name: Column whose ten largest values become the bar lengths.
        std_column_name: Column supplying the error-bar length of each bar.
        naming_dict: Optional mapping from index label to display label. Labels
            missing from it are reported with ``print`` and shown unchanged.
        plot_title: Optional axes title.
        save_plt_name: Optional output path; the figure is saved there if given.
        color: Colour used for the bars and their error bars.
        dpi: Resolution used when saving the figure.
    """
    # None instead of a shared mutable {} default
    naming_dict = {} if naming_dict is None else naming_dict

    # Extract the means and standard deviations
    means = df[mean_column_name]
    stds = df[std_column_name]

    # Find the top 10 rows with the highest means
    top_10_indices = means.nlargest(10).index
    top_10_means = means.loc[top_10_indices]
    top_10_stds = stds.loc[top_10_indices]

    # Translate index labels to display labels where a translation exists
    top_10_labels = []
    for index in top_10_indices:
        if index in naming_dict:
            top_10_labels.append(naming_dict[index])
        else:
            print(f"{index} not in naming_dict")
            top_10_labels.append(index)

    # Two-row xerr: zero lower error, std as upper error, so the error bar
    # extends only beyond the end of each bar.
    error = [np.zeros(len(top_10_stds)), top_10_stds]
    fig, ax = plt.subplots()
    bars = ax.barh(top_10_labels, top_10_means, xerr=error, capsize=0, color=color, ecolor=color)

    # Annotate each bar with the mean value, right-aligned at the bar end
    for bar, mean in zip(bars, top_10_means):
        ax.text(bar.get_width(), bar.get_y() + bar.get_height()/2, f'{mean:.2f}%',
                va='center', ha='right', color='white', fontsize=10)

    ax.set_xlabel('Mean Impact Score and SD')
    if plot_title:
        ax.set_title(plot_title)

    plt.xticks(rotation=45)  # Rotate x-axis tick labels
    plt.tight_layout()  # Adjust layout to make room for the rotated labels
    if save_plt_name:
        # Save through the figure handle rather than the implicit current figure
        fig.savefig(save_plt_name, dpi=dpi, bbox_inches='tight')
    plt.show()

    
 
# Display labels for the fine-tuned LLM feature columns. Keys are the dataset
# column names and are left exactly as spelled there; only display-label typos
# are corrected ('Symptomp' -> 'Symptom', 'VS- ' -> 'VS - ', 'Pule' -> 'Pulse').
naming_dic = {
    'Age': 'Demographic - Age',
    'CR': 'Lab - Cr',
    'sodium': 'Lab - Na',
    'MCV': 'Lab - MCV',
    'Hemoglobin': 'Lab - Hb',
    'alkaline phosphatase': 'Lab - ALP',
    ' Lymphocyte count': 'Lab - Lymphocyte',
    ' Neutrophils percentage': 'Lab - Neutrophil',
    'diastolic Blood pressure': 'VS - Diastolic BP',
    'O2 saturation without supply': 'VS - O2 Saturation',
    'PT': 'Lab - PT',
    'potassium': 'Lab - K',
    'PTT': 'Lab - PTT',
    'Temperature': 'VS - Temp',
    'loss of consciousness': 'Symptom - LOC',
    'Systolic Blood pressure': 'VS - Systolic BP',
    'pulse rate': 'VS - Pulse Rate',
    'Dyspnea': 'Symptom - Dyspnea',
}
       
# Plot the impact scores of the fine-tuned Mistral-7B surrogate and save as SVG.
plot_top_10_means(fine_LLM, 'Fine-Tuned Mistral-7B__impact_score', 'Fine-Tuned Mistral-7B__impact_score_std',naming_dict=naming_dic, 
                  save_plt_name="MAIN__fineLLM_FinemistralImpac.svg"
                  )
