In [1]:
import pandas as pd
import json
import ast
import numpy as np
import re
import pickle
import os
from tqdm import tqdm
tqdm.pandas()
import warnings
warnings.filterwarnings("ignore")

In [2]:
base_path = f"{os.path.dirname(os.getcwd())}\\data"

In [3]:
diagnosis_df_valid = pd.read_csv(f"{base_path}\\input\\release_validate_patients")

In [None]:
diagnosis_df_valid.info(verbose=True, null_counts=True)

In [4]:
with open(f"{base_path}\\input\\release_conditions.json") as f:
  disease_dict = json.load(f)
disease_list = list(disease_dict.keys())

In [5]:
with open(f"{base_path}\\input\\release_evidences.json") as f:
  evidences = json.load(f)
evidences_list = []
evidences_dict = {}
for e in evidences.keys():
  # only binary symptoms and no antecedents
  if (not evidences[e]["possible-values"]) and (not evidences[e]["is_antecedent"]):
    evidences_list.append(e)
    evidences_dict[e] = evidences[e]["question_en"]
evidences_dict["AGE"] = "AGE"
evidences_dict["SEX"] = "SEX"
evidences_dict["RANDOM"] = "RANDOM"
feature_columns = ["AGE", "SEX"] + evidences_list + ["RANDOM"]

In [6]:
model_dict = {}
for disease in disease_list:
    disease_filename = re.sub('[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
    with open(f'{base_path}\\output\\diseases\\{disease_filename}\\{disease_filename}_model.pkl', 'rb') as f:
        model_dict[disease] = pickle.load(f)

In [7]:
def data_proc(df, seed):
    df["binary_evidences"] = df["EVIDENCES"].apply(lambda x: [d for d in ast.literal_eval(x) if "@" not in d])
    for e in evidences_list:
        df[e] = df["binary_evidences"].apply(lambda x: 1 if e in x else 0)
    df["SEX"] = df["SEX"].map({'F': 0, 'M': 1})
    # add RANDOM column for explainability
    np.random.seed(seed)
    df["RANDOM"] = np.random.choice([0, 1], df.shape[0])
    df = df[feature_columns + ["PATHOLOGY"]]
    return df

In [8]:
# 1% is 6min
diagnosis_df_valid = data_proc(diagnosis_df_valid, seed=0)
# sample x% of the validation dataset
# diagnosis_df_valid = diagnosis_df_valid.sample(frac=0.01, random_state=1)
diagnosis_df_valid

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_168,E_180,E_67,E_171,E_111,E_182,E_103,E_23,RANDOM,PATHOLOGY
0,55,0,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,0,0,Anemia
1,10,0,0,1,0,0,0,1,0,0,...,0,0,0,1,1,0,0,0,1,Panic attack
2,68,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,1,Influenza
3,13,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Anemia
4,48,1,0,1,0,0,0,0,1,0,...,0,0,0,0,0,0,0,0,1,Boerhaave
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132443,27,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,Viral pharyngitis
132444,57,1,0,1,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,1,Acute pulmonary edema
132445,52,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,GERD
132446,10,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,Epiglottitis


In [9]:
def predict_diagnosis(x):
    pred_list = []
    for target_disease in disease_list:
        rf_model = model_dict[target_disease]
        prediction = rf_model.predict_proba(x)
        prediction_proba = prediction[0][1]
        if prediction_proba>0:
            pred_list.append({
                "disease": target_disease,
                "probability": prediction_proba
            })
    if pred_list:
        pred_df = pd.DataFrame(pred_list).set_index('disease')
        # return only top 3 - allows ties
        pred_df['rank'] = pred_df['probability'].rank(method='min', ascending=False)
        pred_df = pred_df.sort_values(by="rank")
        pred_df = pred_df[pred_df["rank"]<=3][["probability"]]
        pred_dict = pred_df.to_dict()["probability"]
        return pred_dict
    else:
        return {}

In [10]:
diagnosis_df_valid["predicted_diagnosis"] = diagnosis_df_valid[feature_columns].progress_apply(lambda x : predict_diagnosis([x]), axis=1)

  0%|          | 0/132448 [00:00<?, ?it/s]

100%|██████████| 132448/132448 [7:05:55<00:00,  5.18it/s]  


In [11]:
diagnosis_df_valid

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_180,E_67,E_171,E_111,E_182,E_103,E_23,RANDOM,PATHOLOGY,predicted_diagnosis
0,55,0,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,0,Anemia,"{'Anemia': 1.0, 'SLE': 0.01}"
1,10,0,0,1,0,0,0,1,0,0,...,0,0,1,1,0,0,0,1,Panic attack,"{'Panic attack': 1.0, 'Myocarditis': 0.01, 'SL..."
2,68,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,0,1,Influenza,"{'Influenza': 1.0, 'Pneumonia': 0.01}"
3,13,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,Anemia,"{'Anemia': 1.0, 'Boerhaave': 0.01, 'PSVT': 0.0..."
4,48,1,0,1,0,0,0,0,1,0,...,0,0,0,0,0,0,0,1,Boerhaave,"{'Boerhaave': 1.0, 'Possible NSTEMI / STEMI': ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132443,27,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,Viral pharyngitis,"{'Acute otitis media': 1.0, 'Viral pharyngitis..."
132444,57,1,0,1,0,0,0,0,0,0,...,0,1,0,0,0,0,0,1,Acute pulmonary edema,"{'Acute pulmonary edema': 1.0, 'Myocarditis': ..."
132445,52,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,GERD,{'GERD': 1.0}
132446,10,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,Epiglottitis,"{'Epiglottitis': 1.0, 'Croup': 0.1, 'Larygospa..."


In [12]:
diagnosis_df_valid["is_hit"] = diagnosis_df_valid[["PATHOLOGY", "predicted_diagnosis"]].progress_apply(lambda x: x["PATHOLOGY"] in x["predicted_diagnosis"], axis=1)

100%|██████████| 132448/132448 [00:01<00:00, 92721.66it/s]


In [13]:
diagnosis_df_valid

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_67,E_171,E_111,E_182,E_103,E_23,RANDOM,PATHOLOGY,predicted_diagnosis,is_hit
0,55,0,0,1,0,0,1,0,0,1,...,0,0,0,0,0,0,0,Anemia,"{'Anemia': 1.0, 'SLE': 0.01}",True
1,10,0,0,1,0,0,0,1,0,0,...,0,1,1,0,0,0,1,Panic attack,"{'Panic attack': 1.0, 'Myocarditis': 0.01, 'SL...",True
2,68,0,1,1,0,1,0,0,0,0,...,0,0,0,0,0,0,1,Influenza,"{'Influenza': 1.0, 'Pneumonia': 0.01}",True
3,13,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,Anemia,"{'Anemia': 1.0, 'Boerhaave': 0.01, 'PSVT': 0.0...",True
4,48,1,0,1,0,0,0,0,1,0,...,0,0,0,0,0,0,1,Boerhaave,"{'Boerhaave': 1.0, 'Possible NSTEMI / STEMI': ...",True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
132443,27,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,1,Viral pharyngitis,"{'Acute otitis media': 1.0, 'Viral pharyngitis...",True
132444,57,1,0,1,0,0,0,0,0,0,...,1,0,0,0,0,0,1,Acute pulmonary edema,"{'Acute pulmonary edema': 1.0, 'Myocarditis': ...",True
132445,52,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,GERD,{'GERD': 1.0},True
132446,10,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,1,Epiglottitis,"{'Epiglottitis': 1.0, 'Croup': 0.1, 'Larygospa...",True


In [14]:
acc = diagnosis_df_valid.is_hit.value_counts().to_dict()
acc["hit_rate"] = acc[True]/len(diagnosis_df_valid)
acc

{True: 131888, False: 560, 'hit_rate': 0.9957719255858903}

In [15]:
with open(f"{base_path}\\output\\error_analysis\\validation_metric.json", "w") as outfile: 
    json.dump(acc, outfile, indent=True)

In [16]:
diagnosis_df_valid[diagnosis_df_valid["is_hit"]==False]

Unnamed: 0,AGE,SEX,E_91,E_53,E_159,E_129,E_154,E_155,E_210,E_140,...,E_67,E_171,E_111,E_182,E_103,E_23,RANDOM,PATHOLOGY,predicted_diagnosis,is_hit
135,62,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,1,Viral pharyngitis,"{'Cluster headache': 1.0, 'Acute laryngitis': ...",False
146,57,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,1,Viral pharyngitis,"{'Acute laryngitis': 1.0, 'Acute otitis media'...",False
331,64,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,1,Chronic rhinosinusitis,"{'Viral pharyngitis': 1.0, 'Acute otitis media...",False
450,79,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,Chronic rhinosinusitis,"{'Viral pharyngitis': 0.9990000000000001, 'Acu...",False
475,33,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,URTI,"{'Chronic rhinosinusitis': 0.998, 'Viral phary...",False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
131346,50,1,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,Viral pharyngitis,"{'Acute otitis media': 1.0, 'Chronic rhinosinu...",False
131855,77,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,Bronchitis,"{'Viral pharyngitis': 1.0, 'Acute otitis media...",False
132081,29,0,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,URTI,"{'Acute rhinosinusitis': 0.99, 'Viral pharyngi...",False
132318,58,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,URTI,"{'Chronic rhinosinusitis': 0.9971428571428572,...",False


In [17]:
diagnosis_df_valid[diagnosis_df_valid["is_hit"]==False].to_excel(f"{base_path}\\output\\error_analysis\\validation_miss.xlsx")