In [1]:
import os

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from variables import *
import lightgbm as lgb
import pandas as pd
import shap

  from .autonotebook import tqdm as notebook_tqdm


# __Functions__

In [2]:
def excludeOutcomes(df, Outcomes):
    """Return a modelling frame for one outcome column.

    Starting from ``df`` (which is left untouched), this:

    1. drops the columns listed under ``COLUMNS_TO_EXCLUDE_BY_OUTCOME[Outcomes]``;
    2. drops the ``'onda'``, ``'dataadm'`` and ``'direto_cti'`` columns;
    3. removes the rows where the ``Outcomes`` column is missing;
    4. casts the ``Outcomes`` column to ``int``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; must contain every column dropped above plus ``Outcomes``.
    Outcomes : str
        Name of the outcome column, also used as the key into
        ``COLUMNS_TO_EXCLUDE_BY_OUTCOME``.

    Returns
    -------
    pandas.DataFrame
        A new, independent frame with the remaining columns and rows.
    """
    outcome_specific_columns = COLUMNS_TO_EXCLUDE_BY_OUTCOME[Outcomes]
    fixed_columns = ['onda', 'dataadm', 'direto_cti']

    cleaned = (
        df.drop(columns=outcome_specific_columns)
        .drop(columns=fixed_columns)
        .dropna(subset=[Outcomes])
        # Detach from the filtered selection so the column assignment below
        # operates on an independent frame.
        .copy()
    )
    cleaned[Outcomes] = cleaned[Outcomes].astype(int)

    return cleaned

def shapPlot(df, target_col, classifier, wave, outcome, sociodemographicVariables):
    """Fit a LightGBM classifier and save a SHAP beeswarm plot of its test split.

    The frame is split 80/20 (stratified on the target, ``random_state=42``),
    an ``LGBMClassifier`` is fitted on the training part, and SHAP values are
    computed on the held-out part with ``shap.TreeExplainer``. The beeswarm
    plot is written to disk as a PNG; nothing is displayed or returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature columns plus the target column.
    target_col : str
        Name of the column in ``df`` used as the label.
    classifier : str
        Label used only in the output file name.
    wave : str
        Label used only in the output file name.
    outcome : str
        Label used only in the output file name.
    sociodemographicVariables : bool
        When truthy, feature names are mapped through
        ``VARIABLES_TO_RENAME_SOCIODEMOGRAPHIC`` and the plot is saved under
        ``graph/sociodemographic/``; otherwise
        ``VARIABLES_TO_RENAME_NO_SOCIODEMOGRAPHIC`` and
        ``graph/noSociodemographic/`` are used.

    Returns
    -------
    None
    """
    X = df.drop(columns=[target_col])
    y = df[target_col]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    clf = lgb.LGBMClassifier(random_state=42, verbose=-1)
    clf.fit(X_train, y_train)
    explainer = shap.TreeExplainer(clf)
    shapVals = explainer(X_test)

    # Pick the display-name mapping and the output path for the chosen variant.
    if sociodemographicVariables:
        renameMap = VARIABLES_TO_RENAME_SOCIODEMOGRAPHIC
        saveDir = f"graph/sociodemographic/shap_{classifier}_{wave}_{outcome}_sociodemographic.png"
    else:
        renameMap = VARIABLES_TO_RENAME_NO_SOCIODEMOGRAPHIC
        saveDir = f"graph/noSociodemographic/shap_{classifier}_{wave}_{outcome}_no_sociodemographic.png"

    # Features without an entry in the mapping keep their original name.
    shapVals.feature_names = [renameMap.get(feat, feat) for feat in shapVals.feature_names]

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(os.path.dirname(saveDir), exist_ok=True)

    # NOTE(review): beeswarm's own default plot sizing may override this
    # figsize -- confirm against the installed shap version.
    fig = plt.figure(figsize=(12, 8))
    try:
        shap.plots.beeswarm(shapVals, show=False)
        plt.tight_layout()
        plt.savefig(saveDir, format='png')
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls in a loop do not accumulate open figures.
        plt.close(fig)


# __SHAP plot__

In [3]:
# Wave labels to iterate over when producing plots.
waves = ["Onda 2 e 3"]
# Outcome (target) columns to model, one plot per entry.
outcomes = ["intercorrencia_3_5_6_13_16"]
# Model label embedded in the saved plot file names.
classifier = 'LightGBM'

## __With sociodemographic variables__

In [4]:
# Load the processed dataset (per its file name, the variant with sociodemographic columns).
base_covid = pd.read_parquet(
    "datasets/banco_completo_REGISTRO_COVID_28_08_processado_cardiopatia_sociodemographic.parquet"
)

# Rows are reported as patients and columns as variables.
nPacientes, nVariaveis = base_covid.shape
print(f"Number of patients: {nPacientes}\nNumber of variables: {nVariaveis}")

Number of patients: 16957
Number of variables: 68


In [5]:
# Produce one SHAP plot per (wave, outcome) pair from the currently loaded
# base_covid frame, with the sociodemographic flag set to True.
for wave in waves:
    # The wave filter does not depend on the outcome, so build it once per wave.
    if wave == 'Onda 2 e 3':
        df_wave = base_covid[base_covid['onda'].isin(['Onda 2', 'Onda 3'])].copy()
    else:
        # Fail loudly rather than hitting a NameError or silently reusing the
        # rows selected for a previous wave.
        raise ValueError(f"Unsupported wave label: {wave!r}")

    for outcome in outcomes:
        df_wave_outcome = excludeOutcomes(df_wave, outcome)

        shapPlot(df_wave_outcome, outcome, classifier, wave, outcome, True)


## __Without sociodemographic variables__

In [6]:
# Load the processed dataset (per its file name, the variant without sociodemographic columns).
# Note: this rebinds base_covid, replacing the frame loaded earlier in the notebook.
base_covid = pd.read_parquet(
    "datasets/banco_completo_REGISTRO_COVID_28_08_processado_cardiopatia_no_sociodemographic.parquet"
)

# Rows are reported as patients and columns as variables.
nPacientes, nVariaveis = base_covid.shape
print(f"Number of patients: {nPacientes}\nNumber of variables: {nVariaveis}")

Number of patients: 16957
Number of variables: 61


In [7]:
# Produce one SHAP plot per (wave, outcome) pair from the currently loaded
# base_covid frame, with the sociodemographic flag set to False.
for wave in waves:
    # The wave filter does not depend on the outcome, so build it once per wave.
    if wave == 'Onda 2 e 3':
        df_wave = base_covid[base_covid['onda'].isin(['Onda 2', 'Onda 3'])].copy()
    else:
        # Fail loudly rather than hitting a NameError or silently reusing the
        # rows selected for a previous wave.
        raise ValueError(f"Unsupported wave label: {wave!r}")

    for outcome in outcomes:
        df_wave_outcome = excludeOutcomes(df_wave, outcome)

        shapPlot(df_wave_outcome, outcome, classifier, wave, outcome, False)
