In [1]:
import ast
import json
import os
import re
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

from constants import base_path

## Random Forest

In [3]:
# Load the random-forest validation results and keep only the rows whose
# prediction did not match the ground truth (is_matched == False).
# os.path.join replaces hard-coded Windows "\\" separators so the path also works on POSIX.
error_df = pd.read_csv(
    os.path.join(base_path, "output", "error_analysis", "validation_df_all_patients.csv")
)
error_df = error_df[error_df["is_matched"] == False]  # noqa: E712 (keep original comparison semantics)

In [4]:
# Narrow the frame to the ground-truth label and the predicted differential list.
error_df = error_df.loc[:, ["PATHOLOGY", "predicted_diagnosis"]]
error_df

Unnamed: 0,PATHOLOGY,predicted_diagnosis
5,Bronchospasm / acute asthma exacerbation,"['Bronchiectasis', 'Tuberculosis', 'Bronchospa..."
12,SLE,"['Inguinal hernia', 'SLE']"
22,Acute rhinosinusitis,['Chronic rhinosinusitis']
23,Acute otitis media,['Croup']
29,URTI,['Chronic rhinosinusitis']
...,...,...
132411,Viral pharyngitis,['Chronic rhinosinusitis']
132417,Chronic rhinosinusitis,"['Acute rhinosinusitis', 'Chronic rhinosinusit..."
132427,Viral pharyngitis,['Acute laryngitis']
132443,Viral pharyngitis,['Acute otitis media']


In [5]:
# Horizontal bar chart of how many misclassified rows each true pathology has.
# Explicit fig/ax so the figure can be closed afterwards (plt.clf() alone leaked it).
fig, ax = plt.subplots(figsize=(6, 8))
error_df["PATHOLOGY"].value_counts().sort_values().plot.barh(ax=ax)
ax.set_title("Prediction Error Frequency")
ax.set_xlabel("Count")
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_freq.jpg"),
    bbox_inches="tight",
)
plt.close(fig)

<Figure size 432x576 with 0 Axes>

In [6]:
# Load the condition catalogue; its keys (in file order) are the diseases analysed below.
# JSON is UTF-8 by spec, so don't rely on the platform's default encoding
# (e.g. cp1252 on Windows), and build the path portably.
with open(os.path.join(base_path, "input", "release_conditions.json"), encoding="utf-8") as f:
    disease_dict = json.load(f)
disease_list = list(disease_dict)

In [7]:
# For each true pathology: the share of every *wrong* label among all labels the
# model predicted on that pathology's misclassified rows (the true label itself is
# excluded). Shares are rounded to 3 decimals; pathologies with no wrong labels are skipped.
#
# Single pass over the rows instead of re-filtering the frame and calling iterrows()
# once per disease. Row order and Counter first-seen order are preserved, so keys
# come out in the same order as before.
target_diseases = set(disease_list)
predictions_by_pathology = {}
for pathology, predicted in zip(error_df["PATHOLOGY"], error_df["predicted_diagnosis"]):
    if pathology in target_diseases:
        # predicted_diagnosis holds the string repr of a Python list, e.g. "['SLE']".
        predictions_by_pathology.setdefault(pathology, []).extend(ast.literal_eval(predicted))

pred_miss_dict = {}
for disease in disease_list:
    miss_list = [label for label in predictions_by_pathology.get(disease, []) if label != disease]
    if miss_list:
        total = len(miss_list)
        pred_miss_dict[disease] = {
            label: round(n / total, 3) for label, n in Counter(miss_list).items()
        }

In [8]:
# Inspect the per-pathology miss distribution: true pathology -> {wrong label: share}.
pred_miss_dict

{'Spontaneous pneumothorax': {'Unstable angina': 0.075,
  'Pericarditis': 0.776,
  'Stable angina': 0.092,
  'Pulmonary embolism': 0.057},
 'Cluster headache': {'Acute otitis media': 0.682,
  'Viral pharyngitis': 0.268,
  'Possible NSTEMI / STEMI': 0.05},
 'Boerhaave': {'Possible NSTEMI / STEMI': 0.947,
  'GERD': 0.026,
  'Unstable angina': 0.026},
 'GERD': {'Acute laryngitis': 0.206,
  'Acute otitis media': 0.317,
  'Viral pharyngitis': 0.238,
  'Cluster headache': 0.016,
  'Pericarditis': 0.159,
  'Boerhaave': 0.032,
  'Anemia': 0.032},
 'HIV (initial infection)': {'Influenza': 1.0},
 'Anemia': {'SLE': 0.043,
  'Stable angina': 0.348,
  'Possible NSTEMI / STEMI': 0.304,
  'Pericarditis': 0.043,
  'Myocarditis': 0.087,
  'PSVT': 0.174},
 'Viral pharyngitis': {'Acute otitis media': 0.49,
  'Chronic rhinosinusitis': 0.072,
  'Acute laryngitis': 0.116,
  'Cluster headache': 0.213,
  'GERD': 0.01,
  'Acute rhinosinusitis': 0.09,
  'URTI': 0.005,
  'Possible NSTEMI / STEMI': 0.001,
  'Bron

In [9]:
# Pivot the nested dict into a matrix:
# - rows: true pathologies (index name "disease"), in pred_miss_dict order;
# - columns: every disease in disease_list;
# - values: share of misses going to that column, 0 where the label was never predicted.
# Labels not in disease_list are dropped, as before.
# Built with one constructor chain instead of a per-column loop plus inplace set_index.
pred_miss_df = (
    pd.DataFrame.from_dict(pred_miss_dict, orient="index")
    .reindex(columns=disease_list)
    .fillna(0)
    .rename_axis("disease")
)

In [10]:
# Display the true-pathology x predicted-label share matrix.
pred_miss_df

Unnamed: 0_level_0,Spontaneous pneumothorax,Cluster headache,Boerhaave,Spontaneous rib fracture,GERD,HIV (initial infection),Anemia,Viral pharyngitis,Inguinal hernia,Myasthenia gravis,...,Pneumonia,Acute rhinosinusitis,Chronic rhinosinusitis,Bronchiolitis,Pulmonary neoplasm,Possible NSTEMI / STEMI,Sarcoidosis,Pancreatic neoplasm,Acute pulmonary edema,Pericarditis
disease,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Spontaneous pneumothorax,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0.776
Cluster headache,0.0,0.0,0.0,0,0.0,0.0,0.0,0.268,0.0,0,...,0,0.0,0.0,0,0.0,0.05,0,0.0,0.0,0.0
Boerhaave,0.0,0.0,0.0,0,0.026,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.947,0,0.0,0.0,0.0
GERD,0.0,0.016,0.032,0,0.0,0.0,0.032,0.238,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0.159
HIV (initial infection),0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0.0
Anemia,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.304,0,0.0,0.0,0.043
Viral pharyngitis,0.0,0.213,0.0,0,0.01,0.0,0.0,0.0,0.0,0,...,0,0.09,0.072,0,0.0,0.001,0,0.0,0.0,0.0
Inguinal hernia,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.109,0.0,0.0
Anaphylaxis,0.0,0.0,0.0,0,0.0,0.0,0.0,0.0,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0.0
Acute laryngitis,0.0,0.0,0.0,0,0.015,0.0,0.0,0.364,0.0,0,...,0,0.0,0.0,0,0.0,0.0,0,0.0,0.0,0.0


In [11]:
# Flatten the share matrix into a long edge list (Actual -> Prediction Miss, weight)
# and keep only non-zero shares. Both sort keys are descending, so "Actual" comes out
# in reverse alphabetical order, with the strongest miss first inside each group.
pred_miss_graph = (
    pred_miss_df
    .stack()
    .rename_axis(('Actual', 'Prediction Miss'))
    .reset_index(name='weight')
    .loc[lambda edges: edges["weight"] > 0]
    .sort_values(['Actual', 'weight'], ascending=False)
)
pred_miss_graph

Unnamed: 0,Actual,Prediction Miss,weight
325,Viral pharyngitis,Acute otitis media,0.490
295,Viral pharyngitis,Cluster headache,0.213
308,Viral pharyngitis,Acute laryngitis,0.116
334,Viral pharyngitis,Acute rhinosinusitis,0.090
335,Viral pharyngitis,Chronic rhinosinusitis,0.072
...,...,...,...
448,Acute laryngitis,Viral pharyngitis,0.364
445,Acute laryngitis,GERD,0.015
475,Acute laryngitis,Bronchitis,0.005
1209,Acute COPD exacerbation / infection,Bronchospasm / acute asthma exacerbation,0.737


In [12]:
# Persist the edge list, with a portable path.
# NOTE: the "weigths" misspelling is kept in case other code reads this filename.
pred_miss_graph.to_csv(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_weigths.csv"),
    index=False,
)

In [13]:
# Directed graph: true pathology -> wrongly predicted label, weighted by miss share.
G = nx.from_pandas_edgelist(pred_miss_graph, 'Actual', 'Prediction Miss', edge_attr='weight', create_using=nx.DiGraph())
# Fixed seed keeps the layout (reused by the following plotting cell) reproducible.
pos = nx.spring_layout(G, seed=0)
# nx.info() was deprecated in NetworkX 2.7 and removed in 3.0; print the same summary directly.
n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()
avg_degree = n_edges / n_nodes if n_nodes else 0.0  # in a DiGraph, mean in-degree == mean out-degree
print(f"Type: {type(G).__name__}")
print(f"Number of nodes: {n_nodes}")
print(f"Number of edges: {n_edges}")
print(f"Average in degree: {avg_degree:8.4f}")
print(f"Average out degree: {avg_degree:8.4f}")

Name: 
Type: DiGraph
Number of nodes: 39
Number of edges: 166
Average in degree:   4.2564
Average out degree:   4.2564


In [14]:
# Draw the full miss graph; darker blue = larger share of that pathology's misses.
fig, ax = plt.subplots(figsize=(15, 8))
weights = [G[u][v]['weight'] for u, v in G.edges()]
# Pin the colour scale to the [0, 1] share range instead of each plot's own min/max,
# so colours are comparable across figures.
nx.draw(
    G, pos, ax=ax,
    edge_color=weights, edge_cmap=plt.cm.Blues, edge_vmin=0.0, edge_vmax=1.0,
    with_labels=True, arrowsize=20,
)
ax.set_title("Prediction Miss", fontsize=20)
fig.tight_layout()  # works on a subplot axes (nx.draw's default full-figure axes did not support it)
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_network.jpg"),
    bbox_inches='tight',
)
plt.close(fig)  # release the figure; plt.clf() only emptied it

  """


<Figure size 1080x576 with 0 Axes>

In [15]:
# One plot per pathology, showing every edge into or out of that disease.
for disease in pred_miss_dict:
    touches_disease = (pred_miss_graph["Actual"] == disease) | (pred_miss_graph["Prediction Miss"] == disease)
    G = nx.from_pandas_edgelist(pred_miss_graph[touches_disease], 'Actual', 'Prediction Miss', edge_attr='weight', create_using=nx.DiGraph())
    pos = nx.spring_layout(G, seed=0)
    fig, ax = plt.subplots(figsize=(15, 8))
    weights = [G[u][v]['weight'] for u, v in G.edges()]
    # Fixed [0, 1] colour scale: with auto-scaling, a single-edge graph has vmin == vmax
    # and its edge is drawn with the palest colour of the map.
    nx.draw(
        G, pos, ax=ax,
        edge_color=weights, edge_cmap=plt.cm.Blues, edge_vmin=0.0, edge_vmax=1.0,
        with_labels=True, arrowsize=20,
    )
    ax.set_title(f"Prediction Miss - {disease}", fontsize=20)
    fig.tight_layout()
    # Raw string: '\.' in a plain literal is an invalid escape sequence (Deprecation/SyntaxWarning).
    # Keeps letters, digits, spaces, newlines and dots; spaces then become underscores.
    img_filename = re.sub(r'[^a-zA-Z0-9 \n.]', '', disease).replace(" ", "_")
    fig.savefig(
        os.path.join(base_path, "output", "error_analysis", f"pred_miss_{img_filename}.jpg"),
        bbox_inches='tight',
    )
    # Close every figure; clearing alone left them open and they piled up across iterations.
    plt.close(fig)

  
  after removing the cwd from sys.path.


<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

## Logistic Regression

In [16]:
# Load the logistic-regression validation results and keep only the rows whose
# prediction did not match the ground truth (is_matched == False).
# os.path.join replaces hard-coded Windows "\\" separators so the path also works on POSIX.
error_df = pd.read_csv(
    os.path.join(base_path, "output", "error_analysis", "validation_logreg_df_all_patients.csv")
)
error_df = error_df[error_df["is_matched"] == False]  # noqa: E712 (keep original comparison semantics)

In [17]:
# Narrow the frame to the ground-truth label and the predicted differential list.
error_df = error_df.loc[:, ["PATHOLOGY", "predicted_diagnosis"]]
error_df

Unnamed: 0,PATHOLOGY,predicted_diagnosis
22,Acute rhinosinusitis,['Chronic rhinosinusitis']
23,Acute otitis media,['Allergic sinusitis']
29,URTI,['Viral pharyngitis']
58,Acute otitis media,['Viral pharyngitis']
66,Influenza,['URTI']
...,...,...
132389,Bronchiectasis,['Bronchospasm / acute asthma exacerbation']
132399,Tuberculosis,['Bronchiectasis']
132400,Acute otitis media,['Viral pharyngitis']
132427,Viral pharyngitis,['Acute laryngitis']


In [18]:
# Horizontal bar chart of how many misclassified rows each true pathology has.
# Explicit fig/ax so the figure can be closed afterwards (plt.clf() alone leaked it).
fig, ax = plt.subplots(figsize=(6, 8))
error_df["PATHOLOGY"].value_counts().sort_values().plot.barh(ax=ax)
ax.set_title("Prediction Error Frequency")
ax.set_xlabel("Count")
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_freq_logreg.jpg"),
    bbox_inches="tight",
)
plt.close(fig)

<Figure size 432x576 with 0 Axes>

In [19]:
# For each true pathology: the share of every *wrong* label among all labels the
# model predicted on that pathology's misclassified rows (the true label itself is
# excluded). Shares are rounded to 3 decimals; pathologies with no wrong labels are skipped.
#
# Single pass over the rows instead of re-filtering the frame and calling iterrows()
# once per disease. Row order and Counter first-seen order are preserved, so keys
# come out in the same order as before.
target_diseases = set(disease_list)
predictions_by_pathology = {}
for pathology, predicted in zip(error_df["PATHOLOGY"], error_df["predicted_diagnosis"]):
    if pathology in target_diseases:
        # predicted_diagnosis holds the string repr of a Python list, e.g. "['SLE']".
        predictions_by_pathology.setdefault(pathology, []).extend(ast.literal_eval(predicted))

pred_miss_dict = {}
for disease in disease_list:
    miss_list = [label for label in predictions_by_pathology.get(disease, []) if label != disease]
    if miss_list:
        total = len(miss_list)
        pred_miss_dict[disease] = {
            label: round(n / total, 3) for label, n in Counter(miss_list).items()
        }

In [20]:
# Inspect the per-pathology miss distribution: true pathology -> {wrong label: share}.
pred_miss_dict

{'Spontaneous pneumothorax': {'Pericarditis': 0.899, 'Stable angina': 0.101},
 'Boerhaave': {'Possible NSTEMI / STEMI': 1.0},
 'GERD': {'Viral pharyngitis': 0.667,
  'Cluster headache': 0.018,
  'Pericarditis': 0.175,
  'Tuberculosis': 0.035,
  'Boerhaave': 0.053,
  'Anemia': 0.053},
 'HIV (initial infection)': {'Pancreatic neoplasm': 0.294, 'Influenza': 0.706},
 'Anemia': {'Stable angina': 0.933, 'PSVT': 0.067},
 'Viral pharyngitis': {'Acute otitis media': 0.066,
  'Cluster headache': 0.643,
  'Acute laryngitis': 0.291},
 'Inguinal hernia': {'SLE': 1.0},
 'Anaphylaxis': {'SLE': 1.0},
 'Epiglottitis': {'Acute laryngitis': 1.0},
 'Acute laryngitis': {'Acute otitis media': 0.713, 'Viral pharyngitis': 0.287},
 'Croup': {'Larygospasm': 1.0},
 'PSVT': {'Pericarditis': 0.759,
  'Bronchospasm / acute asthma exacerbation': 0.037,
  'Atrial fibrillation': 0.093,
  'Anemia': 0.111},
 'Atrial fibrillation': {'PSVT': 0.995,
  'Bronchospasm / acute asthma exacerbation': 0.005},
 'Bronchiectasis': {

In [21]:
# Pivot the nested dict into a matrix:
# - rows: true pathologies (index name "disease"), in pred_miss_dict order;
# - columns: every disease in disease_list;
# - values: share of misses going to that column, 0 where the label was never predicted.
# Labels not in disease_list are dropped, as before.
# Built with one constructor chain instead of a per-column loop plus inplace set_index.
pred_miss_df = (
    pd.DataFrame.from_dict(pred_miss_dict, orient="index")
    .reindex(columns=disease_list)
    .fillna(0)
    .rename_axis("disease")
)

In [22]:
# Display the true-pathology x predicted-label share matrix.
pred_miss_df

Unnamed: 0_level_0,Spontaneous pneumothorax,Cluster headache,Boerhaave,Spontaneous rib fracture,GERD,HIV (initial infection),Anemia,Viral pharyngitis,Inguinal hernia,Myasthenia gravis,...,Pneumonia,Acute rhinosinusitis,Chronic rhinosinusitis,Bronchiolitis,Pulmonary neoplasm,Possible NSTEMI / STEMI,Sarcoidosis,Pancreatic neoplasm,Acute pulmonary edema,Pericarditis
disease,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Spontaneous pneumothorax,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.899
Boerhaave,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,1.0,0,0.0,0,0.0
GERD,0.0,0.018,0.053,0,0,0.0,0.053,0.667,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.175
HIV (initial infection),0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.294,0,0.0
Anemia,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0
Viral pharyngitis,0.0,0.643,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0
Inguinal hernia,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0
Anaphylaxis,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0
Epiglottitis,0.0,0.0,0.0,0,0,0.0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0
Acute laryngitis,0.0,0.0,0.0,0,0,0.0,0.0,0.287,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0.0,0,0.0


In [23]:
# Flatten the share matrix into a long edge list (Actual -> Prediction Miss, weight)
# and keep only non-zero shares. Both sort keys are descending, so "Actual" comes out
# in reverse alphabetical order, with the strongest miss first inside each group.
pred_miss_graph = (
    pred_miss_df
    .stack()
    .rename_axis(('Actual', 'Prediction Miss'))
    .reset_index(name='weight')
    .loc[lambda edges: edges["weight"] > 0]
    .sort_values(['Actual', 'weight'], ascending=False)
)
pred_miss_graph

Unnamed: 0,Actual,Prediction Miss,weight
246,Viral pharyngitis,Cluster headache,0.643
259,Viral pharyngitis,Acute laryngitis,0.291
276,Viral pharyngitis,Acute otitis media,0.066
911,Unstable angina,Stable angina,0.397
926,Unstable angina,Possible NSTEMI / STEMI,0.362
...,...,...,...
472,Acute laryngitis,Acute otitis media,0.713
448,Acute laryngitis,Viral pharyngitis,0.287
768,Acute dystonic reactions,Bronchospasm / acute asthma exacerbation,1.000
1258,Acute COPD exacerbation / infection,Bronchospasm / acute asthma exacerbation,0.733


In [24]:
# Persist the edge list, with a portable path.
# NOTE: the "weigths" misspelling is kept in case other code reads this filename.
pred_miss_graph.to_csv(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_weigths_logreg.csv"),
    index=False,
)

In [25]:
# Directed graph: true pathology -> wrongly predicted label, weighted by miss share.
G = nx.from_pandas_edgelist(pred_miss_graph, 'Actual', 'Prediction Miss', edge_attr='weight', create_using=nx.DiGraph())
# Fixed seed keeps the layout (reused by the following plotting cell) reproducible.
pos = nx.spring_layout(G, seed=0)
# nx.info() was deprecated in NetworkX 2.7 and removed in 3.0; print the same summary directly.
n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()
avg_degree = n_edges / n_nodes if n_nodes else 0.0  # in a DiGraph, mean in-degree == mean out-degree
print(f"Type: {type(G).__name__}")
print(f"Number of nodes: {n_nodes}")
print(f"Number of edges: {n_edges}")
print(f"Average in degree: {avg_degree:8.4f}")
print(f"Average out degree: {avg_degree:8.4f}")

Name: 
Type: DiGraph
Number of nodes: 43
Number of edges: 110
Average in degree:   2.5581
Average out degree:   2.5581


In [26]:
# Draw the full miss graph; darker blue = larger share of that pathology's misses.
fig, ax = plt.subplots(figsize=(15, 8))
weights = [G[u][v]['weight'] for u, v in G.edges()]
# Pin the colour scale to the [0, 1] share range instead of each plot's own min/max,
# so colours are comparable across figures.
nx.draw(
    G, pos, ax=ax,
    edge_color=weights, edge_cmap=plt.cm.Blues, edge_vmin=0.0, edge_vmax=1.0,
    with_labels=True, arrowsize=20,
)
ax.set_title("Prediction Miss", fontsize=20)
fig.tight_layout()  # works on a subplot axes (nx.draw's default full-figure axes did not support it)
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_network_logreg.jpg"),
    bbox_inches='tight',
)
plt.close(fig)  # release the figure; plt.clf() only emptied it

  """


<Figure size 1080x576 with 0 Axes>

In [27]:
# One plot per pathology, showing every edge into or out of that disease.
for disease in pred_miss_dict:
    touches_disease = (pred_miss_graph["Actual"] == disease) | (pred_miss_graph["Prediction Miss"] == disease)
    G = nx.from_pandas_edgelist(pred_miss_graph[touches_disease], 'Actual', 'Prediction Miss', edge_attr='weight', create_using=nx.DiGraph())
    pos = nx.spring_layout(G, seed=0)
    fig, ax = plt.subplots(figsize=(15, 8))
    weights = [G[u][v]['weight'] for u, v in G.edges()]
    # Fixed [0, 1] colour scale: with auto-scaling, a single-edge graph has vmin == vmax
    # and its edge is drawn with the palest colour of the map.
    nx.draw(
        G, pos, ax=ax,
        edge_color=weights, edge_cmap=plt.cm.Blues, edge_vmin=0.0, edge_vmax=1.0,
        with_labels=True, arrowsize=20,
    )
    ax.set_title(f"Prediction Miss - {disease}", fontsize=20)
    fig.tight_layout()
    # Raw string: '\.' in a plain literal is an invalid escape sequence (Deprecation/SyntaxWarning).
    # Keeps letters, digits, spaces, newlines and dots; spaces then become underscores.
    img_filename = re.sub(r'[^a-zA-Z0-9 \n.]', '', disease).replace(" ", "_")
    fig.savefig(
        os.path.join(base_path, "output", "error_analysis", f"pred_miss_{img_filename}_logreg.jpg"),
        bbox_inches='tight',
    )
    # Close every figure; clearing alone left them open and they piled up across iterations.
    plt.close(fig)

  
  after removing the cwd from sys.path.


<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>