In [2]:
import pandas as pd
import json
import matplotlib.pyplot as plt
import ast
import numpy as np
import pickle
from treeinterpreter import treeinterpreter as ti
import re
import os
from constants import base_path
import warnings
warnings.filterwarnings("ignore")
from tqdm import tqdm
tqdm.pandas()

In [4]:
# Load the raw validation split; EVIDENCES is parsed into features later by data_proc.
diagnosis_df_valid = pd.read_csv(os.path.join(base_path, "input", "release_validate_patients"))

# Map evidence codes to their English question text.
# JSON is read explicitly as UTF-8 so the platform default codepage (e.g. cp1252 on
# Windows) cannot garble or reject non-ASCII characters in the file.
with open(os.path.join(base_path, "input", "release_evidences.json"), encoding="utf-8") as f:
    evidences = json.load(f)
# only binary symptoms (no "possible-values") and no antecedents
evidences_list = [
    code for code, meta in evidences.items()
    if not meta["possible-values"] and not meta["is_antecedent"]
]
evidences_dict = {code: evidences[code]["question_en"] for code in evidences_list}
# AGE and SEX are model features too; they are labelled with their own names in plots.
evidences_dict["AGE"] = "AGE"
evidences_dict["SEX"] = "SEX"
feature_columns = ["AGE", "SEX"] + evidences_list

In [5]:
# Disease names are the keys of the conditions file; one model per disease is loaded later.
# Read as UTF-8 explicitly so non-ASCII text in the file decodes the same on every platform.
with open(os.path.join(base_path, "input", "release_conditions.json"), encoding="utf-8") as f:
    disease_dict = json.load(f)
disease_list = list(disease_dict)

In [6]:
def data_proc(df, evidence_codes=None, columns=None):
    """Convert raw patient records into a binary symptom feature table.

    Parameters
    ----------
    df : pd.DataFrame
        Raw patients with an ``EVIDENCES`` column holding a stringified Python
        list of evidence codes, plus ``SEX`` ('F'/'M'), ``PATHOLOGY`` and any other
        columns named in ``columns`` (e.g. ``AGE``). ``df`` is not modified.
    evidence_codes : list of str, optional
        Evidence codes to encode as 0/1 columns. Defaults to the global
        ``evidences_list``.
    columns : list of str, optional
        Output feature columns, in order. Defaults to the global ``feature_columns``.

    Returns
    -------
    pd.DataFrame
        ``columns + ["PATHOLOGY"]``: each evidence-code column is 1 if the code is
        among the patient's binary evidences, else 0; ``SEX`` is mapped F -> 0,
        M -> 1 (any other value becomes NaN).
    """
    if evidence_codes is None:
        evidence_codes = evidences_list
    if columns is None:
        columns = feature_columns

    # Entries containing "@" are code/value pairs (non-binary evidences); keep only
    # bare codes. A set makes the membership tests below O(1).
    binary_evidences = df["EVIDENCES"].apply(
        lambda raw: {code for code in ast.literal_eval(raw) if "@" not in code}
    )
    # Build every indicator column in one go: inserting ~100 columns one by one
    # fragments the frame (pandas PerformanceWarning) and would mutate the caller's df.
    indicators = pd.DataFrame(
        [[int(code in present) for code in evidence_codes] for present in binary_evidences],
        index=df.index,
        columns=evidence_codes,
    )
    # Indicator columns take precedence over any same-named columns already in df.
    base = df.drop(columns=indicators.columns.intersection(df.columns)).assign(
        SEX=df["SEX"].map({'F': 0, 'M': 1})
    )
    combined = pd.concat([base, indicators], axis=1)
    return combined[list(columns) + ["PATHOLOGY"]]

In [7]:
# Replace the raw frame with its feature table. NOTE: the result no longer has an
# EVIDENCES column, so re-running this cell needs the data-loading cell re-run first.
diagnosis_df_valid = diagnosis_df_valid.pipe(data_proc)
diagnosis_df_valid

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_193,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY
0,55,0,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,0,0,Anemia
1,10,0,0,1,0,0,0,1,0,0,...,0,0,0,0,1,1,0,0,0,Panic attack
2,68,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Influenza
3,13,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Anemia
4,48,1,0,1,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,Boerhaave
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132443,27,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Viral pharyngitis
132444,57,1,0,1,0,0,0,0,0,0,...,0,0,0,1,0,0,0,0,0,Acute pulmonary edema
132445,52,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,GERD
132446,10,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Epiglottitis


## Random Forest

In [8]:
# Load one pickled random-forest model per disease.
# SECURITY: pickle.load can execute arbitrary code; only load model files you created.
model_dict = {}
for disease in disease_list:
    # Folder/file names keep only alphanumerics, spaces, newlines and dots, with spaces
    # turned into underscores. Raw string: '\.' in a normal literal is an invalid escape.
    disease_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
    model_path = os.path.join(
        base_path, "output", "diseases", disease_filename, f"{disease_filename}_model.pkl"
    )
    with open(model_path, 'rb') as f:
        model_dict[disease] = pickle.load(f)

In [9]:
def sample_patient(df, n_sample, seed=0):
    """Draw ``n_sample`` distinct rows (patients) from ``df`` reproducibly.

    A local ``np.random.RandomState(seed)`` is used instead of ``np.random.seed``.
    The draw is identical to seeding NumPy's global legacy generator with the
    same seed, but the global generator is left untouched.

    To inspect one specific patient instead, use ``df.loc[[patient_id]]``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to sample from.
    n_sample : int
        Number of rows to draw without replacement (NumPy raises ValueError if it
        exceeds ``len(df)``).
    seed : int, default 0
        Seed for the local random generator.

    Returns
    -------
    (pd.DataFrame, np.ndarray)
        The sampled rows and their index labels, both in draw order.
    """
    rng = np.random.RandomState(seed)
    sample_idx = rng.choice(df.index, size=n_sample, replace=False)
    return df.loc[sample_idx], sample_idx


In [10]:
def pred_explain(x, idx, models=None):
    """Score one patient with every per-disease random forest and explain the top diagnosis.

    For every disease in ``disease_list`` the model's ``predict_proba(x)[0][1]``
    (column 1, treated as P(disease)) is computed. All non-zero probabilities are
    saved, highest first, to
    ``<base_path>/output/patients/patient<idx>/diagnosis_prediction.json``.

    Each disease tied for the highest probability is explained with treeinterpreter.
    Its 10 largest per-feature contributions (by magnitude) are plotted and saved as
    ``<disease>.jpg`` in the same folder.

    Parameters
    ----------
    x : pd.DataFrame
        Single-row frame of model features for one patient.
    idx : hashable
        Patient identifier; used in the folder name and as ``patient_id`` in the JSON.
    models : dict, optional
        disease name -> fitted random-forest classifier. Defaults to the global
        ``model_dict`` (which the logistic-regression section later rebinds).

    Returns
    -------
    dict
        ``{disease: probability}`` for every disease with probability > 0, highest
        first; ``{}`` if no model gives a positive probability.
    """
    if models is None:
        models = model_dict

    # One output folder per patient for the JSON predictions and explanation plots.
    img_path = os.path.join(base_path, "output", "patients", f"patient{idx}")
    os.makedirs(img_path, exist_ok=True)

    # Score every per-disease model; keep diseases with a non-zero probability.
    pred_list = []
    for target_disease in disease_list:
        prediction_proba = models[target_disease].predict_proba(x)[0][1]
        if prediction_proba > 0:
            pred_list.append({"disease": target_disease, "probability": prediction_proba})

    # Nothing positive to save or explain (previously this path hit an unbound variable).
    if not pred_list:
        return {}

    # All predictions, highest first. The stable sort keeps disease_list order among
    # ties, so the saved output is deterministic.
    pred_df = (
        pd.DataFrame(pred_list)
        .set_index("disease")
        .sort_values(by="probability", ascending=False, kind="mergesort")
    )
    pred_dict_all = pred_df["probability"].to_dict()
    # Explain only the top rank; method="min" gives every tied disease rank 1.
    is_top = pred_df["probability"].rank(method="min", ascending=False) <= 1
    pred_dict = pred_df.loc[is_top, "probability"].to_dict()

    diagnosis_prediction = {
        "patient_id": idx,
        "diagnosis_prediction": pred_dict_all,
    }
    with open(os.path.join(img_path, "diagnosis_prediction.json"), "w") as outfile:
        json.dump(diagnosis_prediction, outfile, indent=True)

    # Bar labels "<question text>=<value>" are the same for every explained disease.
    symptoms_df = pd.DataFrame({
        "symptoms_en": x.columns.map(evidences_dict),
        "symptoms_values": [str(x[f].values[0]) for f in x.columns],
    })
    labels = symptoms_df["symptoms_en"] + "=" + symptoms_df["symptoms_values"]

    for target_disease, probability in pred_dict.items():
        # ti.predict returns (prediction, bias, contributions); take row 0 (the only
        # patient) and column 1, the same class column used for the probability.
        _, _, contributions = ti.predict(models[target_disease], x)
        contributions_values = contributions[0][:, 1]
        contributions_df = pd.DataFrame({
            "contributions_values": contributions_values,
            "contributions_abs_values": abs(contributions_values),
        })
        contributions_df.index = labels
        # 10 largest contributions by magnitude, ascending so the largest bar is on top.
        contributions_df = (
            contributions_df.sort_values(by="contributions_abs_values", ascending=False)
            .head(10)
            .sort_values(by="contributions_abs_values")
        )
        fig, ax = plt.subplots()
        contributions_df["contributions_values"].plot.barh(ax=ax)
        ax.set_xlabel("Symptom Importance Score")
        ax.set_title(f"Probability of {target_disease}: {probability:.3f}")
        fig.text(.01, .99, 'Symptoms with bars pointing to the right support a positive diagnosis.\nSymptoms with bars pointing to the left do not support a positive diagnosis.')
        img_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', target_disease).replace(" ", "_")
        fig.savefig(os.path.join(img_path, f"{img_filename}.jpg"), bbox_inches='tight')
        # Close (not just clear) so figures don't pile up or echo "<Figure ...>" in the output.
        plt.close(fig)
    return pred_dict_all

In [11]:
# Select random patients
sample_df, sample_idx = sample_patient(diagnosis_df_valid, 100, seed=0)
sample_df

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_193,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY
122890,56,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Pericarditis
5880,24,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Myocarditis
52220,59,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Sarcoidosis
38102,68,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Viral pharyngitis
19860,39,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,PSVT
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120475,66,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Sarcoidosis
104198,26,1,0,0,0,0,0,0,0,0,...,1,1,1,0,0,0,0,0,0,Acute dystonic reactions
70938,43,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,URTI
3508,1,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Pneumonia


In [12]:
# Score and explain every sampled patient with the random forests; pred_explain also
# writes a JSON file and explanation plots per patient.
# Slice the feature columns once rather than re-slicing the whole sample inside the
# per-row lambda.
sample_features = sample_df[feature_columns]
sample_df["predicted_diagnosis"] = sample_features.progress_apply(
    lambda row: pred_explain(sample_features.loc[[row.name]], row.name), axis=1
)
sample_df

100%|██████████| 100/100 [02:24<00:00,  1.45s/it]


Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY,predicted_diagnosis
122890,56,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Pericarditis,{'Pericarditis': 1.0}
5880,24,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Myocarditis,"{'Myocarditis': 1.0, 'Spontaneous pneumothorax..."
52220,59,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 1.0, 'Chagas': 0.18, 'Stable a..."
38102,68,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Viral pharyngitis,"{'Viral pharyngitis': 1.0, 'Acute otitis media..."
19860,39,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,PSVT,"{'PSVT': 1.0, 'Atrial fibrillation': 0.02, 'Pe..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120475,66,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 1.0, 'Pulmonary neoplasm': 0.0..."
104198,26,1,0,0,0,0,0,0,0,0,...,1,1,0,0,0,0,0,0,Acute dystonic reactions,{'Acute dystonic reactions': 1.0}
70938,43,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,URTI,"{'URTI': 1.0, 'Influenza': 0.09, 'Ebola': 0.03..."
3508,1,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,Pneumonia,"{'Pneumonia': 1.0, 'Bronchiolitis': 0.17, 'Cro..."


<Figure size 432x288 with 0 Axes>

In [13]:
# Inspect one sampled patient (index label 1774) together with its predicted diagnoses.
PATIENT_ID = 1774
sample_df.loc[[PATIENT_ID]]

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY,predicted_diagnosis
1774,38,1,0,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,Influenza,"{'Influenza': 1.0, 'Pneumonia': 0.04, 'HIV (in..."


## Logistic Regression

In [14]:
# Load one pickled logistic-regression model per disease. This rebinds model_dict,
# replacing the random-forest models loaded earlier.
# SECURITY: pickle.load can execute arbitrary code; only load model files you created.
model_dict = {}
for disease in disease_list:
    # Folder/file names keep only alphanumerics, spaces, newlines and dots, with spaces
    # turned into underscores. Raw string: '\.' in a normal literal is an invalid escape.
    disease_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
    model_path = os.path.join(
        base_path, "output", "diseases", disease_filename, f"{disease_filename}_logreg_model.pkl"
    )
    with open(model_path, 'rb') as f:
        model_dict[disease] = pickle.load(f)

In [26]:
def pred_explain(x, idx, models=None):
    """Score one patient with every per-disease logistic regression and explain the top diagnosis.

    NOTE: this redefines the random-forest ``pred_explain`` defined earlier in the
    notebook; whichever cell ran last wins.

    For every disease in ``disease_list`` the model's ``predict_proba(x)[0][1]``
    is treated as P(disease). The top 3 non-zero probabilities are saved to
    ``<base_path>/output/patients/patient<idx>/diagnosis_logreg_prediction.json``.
    For the single most probable disease, the 10 largest non-zero
    ``coefficient * feature value`` terms are plotted and saved as
    ``<disease>_logreg.jpg`` in the same folder.

    Parameters
    ----------
    x : pd.DataFrame
        Single-row frame of model features for one patient.
    idx : hashable
        Patient identifier; used in the folder name and as ``patient_id`` in the JSON.
    models : dict, optional
        disease name -> fitted classifier exposing ``predict_proba`` and ``coef_``.
        Defaults to the global ``model_dict``.

    Returns
    -------
    dict
        Up to 3 ``{disease: probability}`` pairs, highest first; ``{}`` if no model
        gives a positive probability (this path used to return ``None``).
    """
    if models is None:
        models = model_dict

    # One output folder per patient for the JSON predictions and explanation plot.
    img_path = os.path.join(base_path, "output", "patients", f"patient{idx}")
    os.makedirs(img_path, exist_ok=True)

    # Score every per-disease model; keep diseases with a non-zero probability.
    pred_list = []
    for target_disease in disease_list:
        prediction_proba = models[target_disease].predict_proba(x)[0][1]
        if prediction_proba > 0:
            pred_list.append({"disease": target_disease, "probability": prediction_proba})
    if not pred_list:
        return {}

    # Top 3, highest first. The stable sort keeps disease_list order among ties, so
    # the first row is the earliest listed disease with the maximum probability.
    pred_df = (
        pd.DataFrame(pred_list)
        .set_index("disease")
        .sort_values(by="probability", ascending=False, kind="mergesort")
        .head(3)
    )
    pred_dict_all = pred_df["probability"].to_dict()
    diagnosis_prediction = {
        "patient_id": idx,
        "diagnosis_prediction": pred_dict_all,
    }
    with open(os.path.join(img_path, "diagnosis_logreg_prediction.json"), "w") as outfile:
        json.dump(diagnosis_prediction, outfile, indent=True)

    # Explain only the most probable disease; its probability is reused from pred_df
    # instead of calling predict_proba again.
    target_disease = pred_df.index[0]
    probability = pred_df["probability"].iloc[0]
    logreg_model = models[target_disease]

    symptoms_values = [x[f].values[0] for f in x.columns]
    # Feature importance = coefficient * feature value (the intercept is not included).
    contributions_values = [
        coef * value for coef, value in zip(logreg_model.coef_[0], symptoms_values)
    ]
    contributions_df = pd.DataFrame({
        "contributions_values": contributions_values,
        "contributions_abs_values": [abs(v) for v in contributions_values],
    })
    symptoms_df = pd.DataFrame({
        "symptoms_en": x.columns.map(evidences_dict),
        "symptoms_values": [str(v) for v in symptoms_values],
    })
    contributions_df.index = symptoms_df["symptoms_en"] + "=" + symptoms_df["symptoms_values"]
    # Drop zero terms (features whose value is 0 contribute nothing), keep the 10 largest
    # by magnitude, ascending so the largest bar is drawn on top.
    contributions_df = (
        contributions_df[contributions_df["contributions_values"] != 0]
        .sort_values(by="contributions_abs_values", ascending=False)
        .head(10)
        .sort_values(by="contributions_abs_values")
    )

    fig, ax = plt.subplots()
    contributions_df["contributions_values"].plot.barh(ax=ax)
    ax.set_xlabel("Symptom Importance Score")
    ax.set_title(f"Probability of {target_disease}: {probability:.3f}")
    fig.text(.01, .99, 'Symptoms with bars pointing to the right support a positive diagnosis.\nSymptoms with bars pointing to the left do not support a positive diagnosis.')
    img_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', target_disease).replace(" ", "_")
    fig.savefig(os.path.join(img_path, f"{img_filename}_logreg.jpg"), bbox_inches='tight')
    # Close (not just clear) so figures don't pile up or echo "<Figure ...>" in the output.
    plt.close(fig)
    return pred_dict_all

In [25]:
# Score and explain every sampled patient with the logistic regressions; pred_explain
# also writes a JSON file and an explanation plot per patient.
# Slice the feature columns once rather than re-slicing the whole sample inside the
# per-row lambda.
sample_features = sample_df[feature_columns]
sample_df["predicted_diagnosis"] = sample_features.progress_apply(
    lambda row: pred_explain(sample_features.loc[[row.name]], row.name), axis=1
)
sample_df

100%|██████████| 100/100 [00:23<00:00,  4.25it/s]


Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY,predicted_diagnosis
122890,56,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Pericarditis,"{'Pericarditis': 0.9990761652821807, 'Spontane..."
5880,24,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Myocarditis,"{'Myocarditis': 0.9998757226678112, 'Spontaneo..."
52220,59,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 0.9960149219519636, 'Chagas': ..."
38102,68,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Viral pharyngitis,"{'Viral pharyngitis': 0.9992683354575539, 'Acu..."
19860,39,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,PSVT,"{'PSVT': 0.9939807395272716, 'Myocarditis': 0...."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
120475,66,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Sarcoidosis,"{'Sarcoidosis': 0.999999119850657, 'Chagas': 0..."
104198,26,1,0,0,0,0,0,0,0,0,...,1,1,0,0,0,0,0,0,Acute dystonic reactions,{'Acute dystonic reactions': 0.999999999949320...
70938,43,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,URTI,"{'URTI': 0.9999916376187971, 'Influenza': 0.07..."
3508,1,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,Pneumonia,"{'Pneumonia': 0.9999999995334292, 'Bronchiolit..."


<Figure size 432x288 with 0 Axes>

In [23]:
# Inspect the same sampled patient (index label 1774) with its logistic-regression diagnoses.
PATIENT_ID = 1774
sample_df.loc[[PATIENT_ID]]

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,PATHOLOGY,predicted_diagnosis
1774,38,1,0,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,Influenza,"{'Influenza': 0.9994008842096015, 'Ebola': 0.0..."
