In [6]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import ast
import numpy as np
import pickle
from treeinterpreter import treeinterpreter as ti
import re
import os
from constants import base_path, model_list
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
tqdm.pandas()

In [7]:
# Load the validation split of the patient dataset.
diagnosis_df_valid = pd.read_csv(f"{base_path}\\input\\release_validate_patients")

# Map evidence codes to their English question text. Only binary evidences
# (those with an empty "possible-values" entry) become model features.
# NOTE(review): the JSON is read with the platform default encoding; switching
# to encoding="utf-8" is correct per the JSON spec, but confirm the training
# notebook (which produced the feature_importance keys) reads it the same way.
with open(f"{base_path}\\input\\release_evidences.json") as f:
    evidences = json.load(f)
evidences_list = [code for code, spec in evidences.items() if not spec["possible-values"]]
evidences_dict = {code: evidences[code]["question_en"] for code in evidences_list}
# Demographic features are labelled with their own column names.
evidences_dict.update({"AGE": "AGE", "SEX": "SEX"})
feature_columns = ["AGE", "SEX"] + evidences_list

In [8]:
# Condition metadata: its keys are the disease names, each of which has its own model.
with open(f"{base_path}\\input\\release_conditions.json") as f:
    disease_dict = json.load(f)
disease_list = [*disease_dict]

In [9]:
def data_proc(df):
    """Turn raw patient rows into the feature matrix used by the models.

    Builds one 0/1 indicator column per code in the global ``evidences_list``
    (1 when the code appears in the row's ``EVIDENCES`` list), encodes ``SEX``
    as F -> 0 / M -> 1, and keeps only ``feature_columns + ["PATHOLOGY"]``.
    The input frame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw patients with ``EVIDENCES`` (string repr of a Python list),
        ``AGE``, ``SEX`` and ``PATHOLOGY`` columns.

    Returns
    -------
    pandas.DataFrame
        New frame with columns ``feature_columns + ["PATHOLOGY"]``.
    """
    # Entries containing "@" are code/value pairs of non-binary evidences; drop them.
    # A set makes the per-code membership checks below O(1).
    binary_evidences = df["EVIDENCES"].apply(
        lambda raw: {code for code in ast.literal_eval(raw) if "@" not in code}
    )
    # Build all indicator columns in one go instead of inserting them one by one
    # into the caller's frame (which mutated it and fragmented it).
    indicators = pd.DataFrame(
        {
            e: binary_evidences.apply(lambda present, e=e: 1 if e in present else 0)
            for e in evidences_list
        },
        index=df.index,
    )
    processed = pd.concat(
        [df[["AGE"]], df["SEX"].map({'F': 0, 'M': 1}), indicators, df[["PATHOLOGY"]]],
        axis=1,
    )
    return processed[feature_columns + ["PATHOLOGY"]]

In [10]:
# Replace the raw validation rows with the model feature matrix and display it.
diagnosis_df_valid = diagnosis_df_valid.pipe(data_proc)
diagnosis_df_valid

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_199,E_121,E_120,E_142,E_195,E_183,E_224,E_223,E_5,PATHOLOGY
0,55,0,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,0,0,Anemia
1,10,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,Panic attack
2,68,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Influenza
3,13,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Anemia
4,48,1,0,1,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,Boerhaave
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132443,27,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Viral pharyngitis
132444,57,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,Acute pulmonary edema
132445,52,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,GERD
132446,10,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Epiglottitis


In [11]:
def sample_patient(df, n_sample, seed=0):
    """Draw ``n_sample`` distinct rows of ``df`` reproducibly.

    A private ``np.random.RandomState(seed)`` gives the same draw as
    ``np.random.seed(seed)`` followed by ``np.random.choice(...)``, without
    reseeding NumPy's global RNG as a side effect.
    To inspect one specific patient instead, use ``df.loc[[patient_id]]``.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to sample from; labels are drawn from ``df.index``.
    n_sample : int
        Number of distinct rows to draw (NumPy raises ValueError if it
        exceeds ``len(df)``).
    seed : int, default 0
        Seed for the private RNG.

    Returns
    -------
    tuple
        ``(sample_df, sample_idx)``: the selected rows (via ``df.loc``) and
        the numpy array of sampled index labels.
    """
    rng = np.random.RandomState(seed)
    sample_idx = rng.choice(df.index, size=n_sample, replace=False)
    sample_df = df.loc[sample_idx]
    return sample_df, sample_idx

In [12]:
# Draw a reproducible sample of validation patients to predict and explain.
n_patients, sampling_seed = 100, 0
sample_df, sample_idx = sample_patient(diagnosis_df_valid, n_sample=n_patients, seed=sampling_seed)
sample_df

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_199,E_121,E_120,E_142,E_195,E_183,E_224,E_223,E_5,PATHOLOGY
122890,56,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Pericarditis
5880,24,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Myocarditis
52220,59,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Sarcoidosis
38102,68,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Viral pharyngitis
19860,39,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,PSVT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120475,66,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Sarcoidosis
104198,26,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Acute dystonic reactions
70938,43,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,URTI
3508,1,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Pneumonia


## Tree-based models

In [34]:
def pred_explain_tree_model(x, idx, model_dict, model_name):
    """Predict every disease for one patient with a tree model and explain the top prediction(s).

    Each disease in the global ``disease_list`` has its own binary classifier in
    ``model_dict``; ``predict_proba(x)[0][1]`` is used as that disease's
    probability. Every disease with probability > 0 is saved, highest first,
    to ``diagnosis_prediction.json`` under
    ``<base_path>\\output\\patients\\patient<idx>\\<model_name>``.
    For every rank-1 disease (ties share rank 1) a horizontal bar chart of at
    most 10 treeinterpreter contributions is saved as ``<disease>.jpg`` in the
    same folder. Only positive contributions of features whose value is > 0
    are shown, and no chart is written when no feature qualifies.

    Parameters
    ----------
    x : pandas.DataFrame
        One-row feature frame; its column names are looked up in
        ``evidences_dict`` for the chart labels.
    idx : hashable
        Patient identifier used in the output folder and the JSON payload.
    model_dict : dict
        Disease name -> fitted classifier exposing ``predict_proba`` and
        accepted by ``treeinterpreter.predict``.
    model_name : str
        Output sub-folder name, also shown in the chart title.

    Returns
    -------
    dict
        Disease -> probability for every disease with probability > 0, in
        descending order of probability; ``{}`` when no model predicts > 0.
    """
    # create output path per patient
    img_path = f'{base_path}\\output\\patients\\patient{idx}\\{model_name}'
    os.makedirs(img_path, exist_ok=True)

    # predict: keep every disease with a non-zero positive-class probability
    pred_list = []
    for target_disease in disease_list:
        prediction_proba = model_dict[target_disease].predict_proba(x)[0][1]
        if prediction_proba > 0:
            pred_list.append({"disease": target_disease, "probability": prediction_proba})

    # Return all predictions but explain only rank 1; rank(method='min') gives
    # every tied maximum rank 1. Start empty so a patient without any non-zero
    # prediction yields {} rather than an unbound local at the return.
    pred_dict_all = {}
    pred_dict = {}
    if pred_list:
        pred_df = pd.DataFrame(pred_list).set_index('disease')
        pred_df['rank'] = pred_df['probability'].rank(method='min', ascending=False)
        pred_df = pred_df.sort_values(by="probability", ascending=False)
        pred_dict_all = pred_df["probability"].to_dict()
        pred_dict = pred_df.loc[pred_df["rank"] <= 1, "probability"].to_dict()

    diagnosis_prediction = {
        "patient_id": idx,
        "diagnosis_prediction": pred_dict_all
    }
    with open(f"{img_path}\\diagnosis_prediction.json", "w") as outfile:
        json.dump(diagnosis_prediction, outfile, indent=True)

    # Chart labels ("<question text>=<value>") depend only on x: build them once.
    symptoms_en = x.columns.map(evidences_dict)
    symptoms_values = [x[f].values[0] for f in x.columns]
    symptoms_df = pd.DataFrame({"symptoms_en": symptoms_en, "symptoms_values": symptoms_values})
    labels = (symptoms_df["symptoms_en"] + "=" + symptoms_df["symptoms_values"].astype(str)).to_numpy()

    for target_disease, target_proba in pred_dict.items():
        _, _, contributions = ti.predict(model_dict[target_disease], x)
        # Row 0 (the only patient), class column 1 (the class whose probability is reported above).
        contributions_values = contributions[0][:, 1]
        contributions_df = pd.DataFrame(
            {
                "contributions_values": contributions_values,
                "contributions_abs_values": abs(contributions_values),
                "symptoms_values": symptoms_values,
            },
            index=labels,
        )
        # Keep features that push towards the disease AND are present in the patient.
        keep = (contributions_df["contributions_values"] > 0) & (contributions_df["symptoms_values"] > 0)
        top_df = (
            contributions_df[keep]
            .sort_values(by="contributions_abs_values", ascending=False)
            .head(10)
            .sort_values(by="contributions_abs_values")
        )
        if top_df.empty:
            # pandas cannot draw a bar chart without data; skip this chart.
            continue

        fig, ax = plt.subplots()
        try:
            top_df["contributions_values"].plot.barh(ax=ax)
            ax.set_xlabel("Symptom Importance Score")
            ax.set_title(f"Probability of {target_disease}: {target_proba:.3f}\n({model_name})")
            fig.text(.01, .99, 'Symptoms with higher importance score support a positive diagnosis.')
            img_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', target_disease).replace(" ", "_")
            fig.savefig(f"{img_path}\\{img_filename}.jpg", bbox_inches='tight')
        finally:
            # Close (not just clear) so no empty figure lingers in the notebook output.
            plt.close(fig)
    return pred_dict_all

In [35]:
# Explain every sampled patient with each tree-based model.
# The feature slice does not change between rows or models, so take it once
# instead of re-slicing sample_df for every row.
tree_features = sample_df[feature_columns]
for model_name in model_list["tree-based"]:
    # NOTE(review): gradient_boost is skipped, presumably because treeinterpreter
    # does not support it — confirm.
    if model_name == "gradient_boost":
        continue
    print(f"Explaining {model_name}...")
    model_dict = {}
    for disease in disease_list:
        disease_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
        # pickle.load can execute arbitrary code: only load model files produced by this project.
        with open(f'{base_path}\\output\\diseases\\{disease_filename}\\{model_name}\\{disease_filename}_model.pkl', 'rb') as f:
            model_dict[disease] = pickle.load(f)
    # .loc[[label]] passes a one-row DataFrame (not a Series) to the explainer.
    # NOTE: each model overwrites this column; only the last model's predictions remain.
    sample_df["predicted_diagnosis"] = tree_features.progress_apply(
        lambda row: pred_explain_tree_model(tree_features.loc[[row.name]], row.name, model_dict, model_name),
        axis=1,
    )

Explaining decision_tree...


  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:24<00:00,  4.13it/s]


Explaining random_forest...


100%|██████████| 100/100 [01:16<00:00,  1.31it/s]


<Figure size 640x480 with 0 Axes>

## Logistic Regression

In [13]:
# Load, per disease, the fitted logistic-regression model and its saved
# feature importances (looked up later by the English evidence text from evidences_dict).
model_dict = {}
feature_importance = {}
for disease in disease_list:
    disease_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
    model_dir = f'{base_path}\\output\\diseases\\{disease_filename}\\logistic_regression'
    # pickle.load can execute arbitrary code: only load model files produced by this project.
    with open(f'{model_dir}\\{disease_filename}_model.pkl', 'rb') as f:
        model_dict[disease] = pickle.load(f)
    # json.load accepts the raw bytes and detects the UTF encoding itself.
    with open(f'{model_dir}\\feature_importance.json', 'rb') as f:
        feature_importance[disease] = json.load(f)

In [14]:
def pred_explain(x, idx):
    """Predict every disease for one patient with logistic regression and explain the top one.

    Uses the globals ``model_dict`` (disease -> fitted model with
    ``predict_proba``) and ``feature_importance`` (disease -> mapping from the
    English evidence text in ``evidences_dict`` to a standardized coefficient).
    The three most probable diseases with probability > 0 are written to
    ``diagnosis_logreg_prediction.json`` under
    ``<base_path>\\output\\patients\\patient<idx>\\logistic_regression``.
    For the single most probable disease, a bar chart of up to 10 non-zero
    contributions (standardized coefficient * feature value, largest magnitude)
    is saved there as ``<disease>.jpg``.

    Parameters
    ----------
    x : pandas.DataFrame
        One-row feature frame whose columns are keys of ``evidences_dict``.
    idx : hashable
        Patient identifier used in the output folder and the JSON payload.

    Returns
    -------
    dict
        Top-3 disease -> probability, highest first; ``{}`` when no model
        predicts a probability > 0.
    """
    # create output path per patient
    img_path = f'{base_path}\\output\\patients\\patient{idx}\\logistic_regression'
    os.makedirs(img_path, exist_ok=True)

    # predict: keep every disease with a non-zero positive-class probability
    pred_list = []
    for target_disease in disease_list:
        prediction_proba = model_dict[target_disease].predict_proba(x)[0][1]
        if prediction_proba > 0:
            pred_list.append({"disease": target_disease, "probability": prediction_proba})

    # return top 3 predictions
    pred_dict_all = {}
    if pred_list:
        pred_df = pd.DataFrame(pred_list).set_index('disease')
        pred_df = pred_df.sort_values(by="probability", ascending=False).head(3)
        pred_dict_all = pred_df["probability"].to_dict()
    diagnosis_prediction = {
        "patient_id": idx,
        "diagnosis_prediction": pred_dict_all
    }
    with open(f"{img_path}\\diagnosis_logreg_prediction.json", "w") as outfile:
        json.dump(diagnosis_prediction, outfile, indent=True)
    if not pred_list:
        return pred_dict_all

    # Explain the most probable disease (first in disease_list order on ties).
    # Its probability is already known, so no model is re-run.
    top = max(pred_list, key=lambda p: p["probability"])
    target_disease, target_proba = top["disease"], top["probability"]

    symptoms_values = [x[f].values[0] for f in x.columns]
    # Standardized coefficients saved with the model (not logreg_model.coef_);
    # importance = coefficient * feature value.
    model_coeffs = [feature_importance[target_disease][evidences_dict[f]] for f in x.columns]
    contributions_values = [coeff * value for coeff, value in zip(model_coeffs, symptoms_values)]
    contributions_df = pd.DataFrame(
        {
            "contributions_values": contributions_values,
            "contributions_abs_values": [abs(c) for c in contributions_values],
        },
        index=[f"{evidences_dict[f]}={value}" for f, value in zip(x.columns, symptoms_values)],
    )
    # Drop zero contributions, then keep the 10 largest magnitudes (ascending for barh).
    contributions_df = contributions_df[contributions_df["contributions_values"] != 0]
    contributions_df = (
        contributions_df
        .sort_values(by="contributions_abs_values", ascending=False)
        .head(10)
        .sort_values(by="contributions_abs_values")
    )
    if contributions_df.empty:
        # pandas cannot draw a bar chart without data; skip the chart.
        return pred_dict_all

    fig, ax = plt.subplots()
    try:
        contributions_df["contributions_values"].plot.barh(ax=ax)
        ax.set_xlabel("Symptom Importance Score")
        ax.set_title(f"Probability of {target_disease}: {target_proba:.3f}\n(logistic_regression)")
        fig.text(.01, .99, 'Symptoms with higher importance score support a positive diagnosis.')
        img_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', target_disease).replace(" ", "_")
        fig.savefig(f"{img_path}\\{img_filename}.jpg", bbox_inches='tight')
    finally:
        # Close (not just clear) so no empty figure lingers in the notebook output.
        plt.close(fig)
    return pred_dict_all

In [15]:
# Explain every sampled patient with logistic regression.
# Slice the features once; .loc[[label]] hands each patient over as a one-row DataFrame.
logreg_features = sample_df[feature_columns]
sample_df["predicted_diagnosis"] = logreg_features.progress_apply(
    lambda row: pred_explain(logreg_features.loc[[row.name]], row.name), axis=1
)
sample_df

100%|██████████| 100/100 [00:46<00:00,  2.16it/s]


Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_121,E_120,E_142,E_195,E_183,E_224,E_223,E_5,PATHOLOGY,predicted_diagnosis
122890,56,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Pericarditis,"{'Pericarditis': 0.9999869558820509, 'Spontane..."
5880,24,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Myocarditis,"{'Myocarditis': 0.9999570158410973, 'Inguinal ..."
52220,59,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 0.99770742368118, 'Chagas': 0...."
38102,68,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Viral pharyngitis,"{'Viral pharyngitis': 0.99999912062668, 'Chaga..."
19860,39,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,PSVT,"{'PSVT': 0.9996853763973577, 'Myocarditis': 0...."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120475,66,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 0.9999457437869111, 'Chagas': ..."
104198,26,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Acute dystonic reactions,{'Acute dystonic reactions': 0.999999998804219...
70938,43,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,URTI,"{'URTI': 0.9999999263117864, 'Ebola': 0.002426..."
3508,1,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,Pneumonia,"{'Pneumonia': 0.9999999999390268, 'Bronchiolit..."


<Figure size 640x480 with 0 Axes>