In [105]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import json
from collections import Counter
import ast
import networkx as nx
import re

In [89]:
# Root data directory: the "data" folder that sits next to the current working
# directory's parent. os.path.join picks the separator for the running OS, so
# the value is unchanged on Windows and becomes a valid path elsewhere.
base_path = os.path.join(os.path.dirname(os.getcwd()), "data")

In [90]:
# Load the validation-miss workbook (presumably one row per missed validation
# case -- confirm against the step that writes it). The path is assembled with
# os.path.join instead of hard-coded "\\" separators.
error_df = pd.read_excel(
    os.path.join(base_path, "output", "error_analysis", "validation_miss.xlsx")
)

In [91]:
# Preview only the true label and the serialized prediction dict of each row.
preview_columns = ["PATHOLOGY", "predicted_diagnosis"]
error_df[preview_columns]

Unnamed: 0,PATHOLOGY,predicted_diagnosis
0,Viral pharyngitis,"{'Cluster headache': 1.0, 'Acute laryngitis': ..."
1,Viral pharyngitis,"{'Acute laryngitis': 1.0, 'Acute otitis media'..."
2,Chronic rhinosinusitis,"{'Viral pharyngitis': 1.0, 'Acute otitis media..."
3,Chronic rhinosinusitis,"{'Viral pharyngitis': 0.9990000000000001, 'Acu..."
4,URTI,"{'Chronic rhinosinusitis': 0.998, 'Viral phary..."
...,...,...
555,Viral pharyngitis,"{'Acute otitis media': 1.0, 'Chronic rhinosinu..."
556,Bronchitis,"{'Viral pharyngitis': 1.0, 'Acute otitis media..."
557,URTI,"{'Acute rhinosinusitis': 0.99, 'Viral pharyngi..."
558,URTI,"{'Chronic rhinosinusitis': 0.9971428571428572,..."


In [92]:
# Horizontal bar chart of how many rows of error_df carry each PATHOLOGY value,
# sorted ascending so the most frequent pathology ends up at the top.
# An explicit figure/axes pair is used and the figure is closed after saving:
# plt.clf() only emptied the figure, leaving it open and making the notebook
# render a blank "<Figure ... with 0 Axes>" output.
fig, ax = plt.subplots()
error_df["PATHOLOGY"].value_counts().sort_values().plot.barh(ax=ax)
ax.set_title("Prediction Miss Frequency")
ax.set_xlabel("Count")
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_freq.jpg"),
    bbox_inches='tight',
)
plt.close(fig)

<Figure size 432x288 with 0 Axes>

In [93]:
# Read the condition catalogue; its top-level keys are used as the list of
# disease names. The encoding is given explicitly because open() otherwise
# falls back to the platform default (e.g. cp1252 on Windows), which can
# mis-decode or reject a UTF-8 JSON file.
conditions_path = os.path.join(base_path, "input", "release_conditions.json")
with open(conditions_path, encoding="utf-8") as f:
    disease_dict = json.load(f)
disease_list = list(disease_dict.keys())

In [94]:
# For every disease in disease_list, gather the labels found in the
# "predicted_diagnosis" dicts of the rows whose PATHOLOGY is that disease and
# store each label's share of all gathered labels (rounded to 3 decimals).
# Diseases without any matching row are left out of the mapping.
# (Store Counter(miss_list) directly instead if raw counts are wanted.)
pred_miss_dict = {}
for disease in disease_list:
    serialized_predictions = error_df.loc[
        error_df["PATHOLOGY"] == disease, "predicted_diagnosis"
    ]
    # Each cell holds the textual repr of a dict; literal_eval turns it back
    # into a dict, and iterating that dict yields the predicted labels.
    miss_list = [
        label
        for serialized in serialized_predictions
        for label in ast.literal_eval(serialized)
    ]
    if not miss_list:
        continue
    total_labels = len(miss_list)
    pred_miss_dict[disease] = {
        label: round(occurrences / total_labels, 3)
        for label, occurrences in Counter(miss_list).items()
    }

In [95]:
# Display the mapping: actual disease -> {predicted label: share of that
# disease's collected predicted labels}.
pred_miss_dict

{'GERD': {'Acute laryngitis': 0.321,
  'Acute otitis media': 0.321,
  'Viral pharyngitis': 0.321,
  'Cluster headache': 0.012,
  'Possible NSTEMI / STEMI': 0.012,
  'Bronchitis': 0.012},
 'Anemia': {'Stable angina': 0.25,
  'Possible NSTEMI / STEMI': 0.25,
  'SLE': 0.292,
  'Pericarditis': 0.083,
  'Myocarditis': 0.083,
  'Pulmonary embolism': 0.042},
 'Viral pharyngitis': {'Cluster headache': 0.071,
  'Acute laryngitis': 0.096,
  'Possible NSTEMI / STEMI': 0.068,
  'Acute otitis media': 0.315,
  'GERD': 0.056,
  'Acute rhinosinusitis': 0.17,
  'Chronic rhinosinusitis': 0.151,
  'Bronchitis': 0.052,
  'URTI': 0.019,
  'Stable angina': 0.003},
 'Acute laryngitis': {'Acute otitis media': 0.333,
  'Bronchitis': 0.097,
  'Viral pharyngitis': 0.333,
  'GERD': 0.236},
 'Myocarditis': {'Stable angina': 0.333,
  'Pericarditis': 0.273,
  'Possible NSTEMI / STEMI': 0.273,
  'Unstable angina': 0.03,
  'Spontaneous pneumothorax': 0.061,
  'SLE': 0.03},
 'SLE': {'Myocarditis': 0.067,
  'Pericarditi

In [96]:
# Wide share matrix: one row per actual disease present in pred_miss_dict
# (index named "disease"), one column per entry of disease_list, and 0 where a
# label was never predicted for that disease.
actual_diseases = list(pred_miss_dict)
share_by_prediction = {
    predicted: [pred_miss_dict[actual].get(predicted, 0) for actual in actual_diseases]
    for predicted in disease_list
}
pred_miss_df = pd.DataFrame(
    share_by_prediction,
    index=pd.Index(actual_diseases, name="disease"),
)

In [97]:
# Display the wide share matrix (rows: actual disease, columns: predicted label).
pred_miss_df

Unnamed: 0_level_0,Spontaneous pneumothorax,Cluster headache,Boerhaave,Spontaneous rib fracture,GERD,HIV (initial infection),Anemia,Viral pharyngitis,Inguinal hernia,Myasthenia gravis,...,Pneumonia,Acute rhinosinusitis,Chronic rhinosinusitis,Bronchiolitis,Pulmonary neoplasm,Possible NSTEMI / STEMI,Sarcoidosis,Pancreatic neoplasm,Acute pulmonary edema,Pericarditis
disease,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
GERD,0.0,0.012,0.0,0,0.0,0,0.0,0.321,0,0,...,0,0.0,0.0,0.0,0.0,0.012,0,0,0,0.0
Anemia,0.0,0.0,0.0,0,0.0,0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.25,0,0,0,0.083
Viral pharyngitis,0.0,0.071,0.0,0,0.056,0,0.0,0.0,0,0,...,0,0.17,0.151,0.0,0.0,0.068,0,0,0,0.0
Acute laryngitis,0.0,0.0,0.0,0,0.236,0,0.0,0.333,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0,0,0.0
Myocarditis,0.061,0.0,0.0,0,0.0,0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.273,0,0,0,0.273
SLE,0.0,0.0,0.0,0,0.0,0,0.2,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.267,0,0,0,0.067
Unstable angina,0.278,0.0,0.0,0,0.0,0,0.028,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.056,0,0,0,0.0
Stable angina,0.037,0.037,0.0,0,0.0,0,0.037,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.259,0,0,0,0.222
Acute otitis media,0.0,0.055,0.008,0,0.059,0,0.0,0.308,0,0,...,0,0.19,0.173,0.0,0.0,0.051,0,0,0,0.0
Bronchospasm / acute asthma exacerbation,0.0,0.0,0.0,0,0.0,0,0.0,0.0,0,0,...,0,0.0,0.0,0.0,0.0,0.0,0,0,0,0.0


In [98]:
# Reshape the wide share matrix into a long edge list with the columns
# Actual / Prediction Miss / weight, keeping only strictly positive weights.
pred_miss_graph = (
    pred_miss_df.stack()
    .rename_axis(("Actual", "Prediction Miss"))
    .reset_index(name="weight")
    .loc[lambda edges: edges["weight"] > 0]
)
pred_miss_graph

Unnamed: 0,Actual,Prediction Miss,weight
1,GERD,Cluster headache,0.012
7,GERD,Viral pharyngitis,0.321
14,GERD,Acute laryngitis,0.321
31,GERD,Acute otitis media,0.321
34,GERD,Bronchitis,0.012
...,...,...,...
881,Possible NSTEMI / STEMI,Pericarditis,0.133
904,Pericarditis,Myocarditis,0.296
908,Pericarditis,SLE,0.037
911,Pericarditis,Stable angina,0.333


In [99]:
# Persist the edge list without the row index. The path is assembled with
# os.path.join instead of hard-coded "\\" separators.
# NOTE(review): the file name is misspelled ("weigths"); it is kept as-is
# because other tooling may read the file under this exact name.
pred_miss_graph.to_csv(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_weigths.csv"),
    index=False,
)

In [100]:
# Directed graph with one edge Actual -> Prediction Miss per row of the edge
# list; the share is stored as the "weight" edge attribute.
G = nx.from_pandas_edgelist(
    pred_miss_graph,
    'Actual',
    'Prediction Miss',
    edge_attr='weight',
    create_using=nx.DiGraph(),
)
# Fixed seed so the layout is the same on every run.
pos = nx.spring_layout(G, seed=0)

# nx.info() was removed in NetworkX 3.0, so the same summary is printed from
# graph methods that exist in every release. In a directed graph each edge
# adds one in-degree and one out-degree, hence both averages are edges / nodes.
node_count = G.number_of_nodes()
edge_count = G.number_of_edges()
average_degree = edge_count / node_count if node_count else 0.0
print(f"Type: {type(G).__name__}")
print(f"Number of nodes: {node_count}")
print(f"Number of edges: {edge_count}")
print(f"Average in degree: {average_degree:8.4f}")
print(f"Average out degree: {average_degree:8.4f}")

Name: 
Type: DiGraph
Number of nodes: 29
Number of edges: 124
Average in degree:   4.2759
Average out degree:   4.2759


In [103]:
# Draw the full miss network, colouring each edge by its weight, and save it.
# The Axes is created explicitly and handed to nx.draw: without `ax`, nx.draw
# adds a full-figure Axes that tight_layout cannot handle and warns about.
# The figure is closed after saving so it does not stay open in memory and
# the notebook no longer renders an empty figure.
fig, ax = plt.subplots(figsize=(15, 8))
# Same order as G.edges(), so weights line up with the edges being drawn.
weights = [weight for _, _, weight in G.edges(data='weight')]
nx.draw(G, pos, ax=ax, edge_color=weights, edge_cmap=plt.cm.Blues, with_labels=True)
ax.set_title("Prediction Miss", fontsize=20)
fig.tight_layout()
fig.savefig(
    os.path.join(base_path, "output", "error_analysis", "pred_miss_network.jpg"),
    bbox_inches='tight',
)
plt.close(fig)

  """


<Figure size 1080x576 with 0 Axes>

In [107]:
# One network image per actual disease, showing only that disease's outgoing
# "predicted instead" edges.
# Changes from the previous version of this cell:
#   * loop-local names are used so the full-network `G` / `pos` globals are no
#     longer overwritten (re-running the full-network cell afterwards would
#     otherwise draw the last disease's graph);
#   * every figure is closed after saving, so figures do not pile up;
#   * the Axes is passed to nx.draw so tight_layout works without a warning;
#   * the regex is a raw string ("\." is an invalid escape in a normal string).
for disease in pred_miss_dict:
    disease_edges = pred_miss_graph[pred_miss_graph["Actual"] == disease]
    disease_graph = nx.from_pandas_edgelist(
        disease_edges,
        'Actual',
        'Prediction Miss',
        edge_attr='weight',
        create_using=nx.DiGraph(),
    )
    disease_pos = nx.spring_layout(disease_graph, seed=0)
    # Same order as disease_graph.edges(), so weights line up with the edges.
    disease_weights = [weight for _, _, weight in disease_graph.edges(data='weight')]

    fig, ax = plt.subplots(figsize=(15, 8))
    nx.draw(
        disease_graph,
        disease_pos,
        ax=ax,
        edge_color=disease_weights,
        edge_cmap=plt.cm.Blues,
        with_labels=True,
    )
    ax.set_title(f"Prediction Miss - {disease}", fontsize=20)
    fig.tight_layout()

    # File-name-safe label: keep letters, digits, spaces, newlines and dots,
    # then turn spaces into underscores.
    img_filename = re.sub(r'[^a-zA-Z0-9 \n\.]', '', disease).replace(" ", "_")
    fig.savefig(
        os.path.join(base_path, "output", "error_analysis", f"pred_miss_{img_filename}.jpg"),
        bbox_inches='tight',
    )
    plt.close(fig)

  


<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>

<Figure size 1080x576 with 0 Axes>