In [1]:
%matplotlib widget

import tensorflow as tf
from matplotlib import pyplot
import matplotlib
import matplotlib.cm as colormap
import numpy
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib
from sklearn import manifold
from scipy import stats

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import conv_model
import plot_batch
import generate_report_kfold
import icd_util

# Apply matplotlib's 'dark_background' style sheet to the figures created below.
pyplot.style.use('dark_background')

In [2]:
# Strings returned by the project's icd_util loader (presumably diagnosis-group
# code -> description; the same loader is called again later for get_name).
group_names = icd_util.load_group_strings()
# NOTE(review): `metadata` is rebound by the generate_predictions loop in the
# next cell, so the frame read here is shadowed before it is ever used.
# NOTE(review): absolute, machine-specific path.
metadata = pandas.read_hdf('/scr-ssd/mimic/metadata.hdf')

In [3]:
%%time

# One trained model per cross-validation fold; the position in this list is
# passed to generate_predictions as the fold index.
model_ids = [1469795, 1469816, 1470209]
checkpoint_index = 2
# Per-fold accumulators: label dicts, input dicts, and logits.
Y_dict, X_dict, Z = [], [], []

for fold_index, model_id in enumerate(model_ids):
    # Only x, y and p are accumulated. `H` is not used by any visible cell, and
    # `metadata` / `priors` are overwritten on every iteration, so after the
    # loop they hold the values returned for the LAST fold only.
    H, x, y, p, metadata, priors = generate_report_kfold.generate_predictions(
        model_id = model_id,
        fold_index = fold_index,
        checkpoint_index = checkpoint_index,
        example_count_log2 = 14
    )
    X_dict.append(x)
    Y_dict.append(y)
    # Logit of p (log-odds); infinite wherever p is exactly 0 or 1.
    Z.append(numpy.log(p / (1 - p)))

# Pull the 'signals' entry out of every fold's input dict (kept as a list).
X = [x['signals'] for x in X_dict]
# Pull the 'diagnosis' entry out of every fold's label dict and stack the folds
# row-wise into a single array.
Y = [y['diagnosis'] for y in Y_dict]
Y = numpy.vstack(Y)

found hypes ../hypes/1469795_20200512-210303.json 
found weights /scr1/checkpoints/1469795_20200512-210303_00384.ckpt
loading predictions
found hypes ../hypes/1469816_20200512-213718.json 
found weights /scr1/checkpoints/1469816_20200512-213718_00384.ckpt
loading predictions
found hypes ../hypes/1470209_20200513-050523.json 
found weights /scr1/checkpoints/1470209_20200513-050523_00384.ckpt
loading predictions
CPU times: user 15.9 s, sys: 5.28 s, total: 21.2 s
Wall time: 21.2 s


In [4]:
# One row per recording: keep the identifying columns, drop repeated rows, and
# index by rec_id (verify_integrity raises if any rec_id is still duplicated).
M = (
    metadata.reset_index()[['subject_id', 'rec_id', 'admission_diagnosis']]
    .drop_duplicates()
    .set_index('rec_id', verify_integrity=True)
)

# Look up the admission diagnosis text for every rec_id of every fold and
# concatenate the folds into one flat array.
per_fold_admissions = [
    M.loc[fold_labels['rec_id'], 'admission_diagnosis'] for fold_labels in Y_dict
]
admit_fors_ = numpy.hstack(per_fold_admissions)
admit_fors_.shape

(43557,)

In [5]:
import string

# Characters that survive cleaning: ASCII letters (compared after lower-casing)
# and the space character.
C = string.ascii_lowercase + ' '


def _split_admission_text(text):
    """Split one admission string into fragments on ';', then '\\', then '/'.

    A backslash-level fragment containing 'S/P' or 'R/O' is not broken up at
    '/': the whole fragment is emitted instead, once for every '/'-separated
    part it has (so it appears more than once in the result).
    NOTE(review): the repetition looks unintended -- confirm before relying on
    the multiplicity of entries.
    Question marks are removed and surrounding whitespace is stripped.
    """
    fragments = []
    for semicolon_part in text.split(';'):
        for backslash_part in semicolon_part.split('\\'):
            keep_whole = 'S/P' in backslash_part or 'R/O' in backslash_part
            for slash_part in backslash_part.split('/'):
                fragment = backslash_part if keep_whole else slash_part
                fragments.append(fragment.replace('?', '').strip())
    return fragments


def _normalize_problem(fragment):
    """Collapse whitespace runs, keep only letters/spaces, title-case, drop 'Acute '."""
    collapsed = ' '.join(fragment.split())
    letters_only = ''.join(c for c in collapsed if c.lower() in C)
    return letters_only.title().replace('Acute ', '')


# One list of cleaned problem strings per admission.
admit_fors = [
    [_normalize_problem(fragment) for fragment in _split_admission_text(text)]
    for text in admit_fors_.astype('str')
]
len(admit_fors)

43557

In [6]:
# unique_problems.txt holds blank-line-separated groups; every line of a group
# is "<count> <problem text>". Read it inside a context manager so the file
# handle is closed deterministically instead of being left to the garbage
# collector.
with open('unique_problems.txt') as problems_file:
    raw_groups = problems_file.read().split('\n\n')

# Parse each group into (count, problem) pairs, skipping lines that do not have
# both a count and some text, and keep only groups whose counts sum to more
# than 75.
groups_ = []
for raw_group in raw_groups:
    group = []
    for line in raw_group.split('\n'):
        fields = line.split()
        if len(fields) > 1:
            group.append((int(fields[0]), ' '.join(fields[1:])))
    if sum(count for count, problem in group) > 75:
        groups_.append(group)

# Drop the counts: each surviving group becomes a list of problem strings.
groups = [[problem for count, problem in group] for group in groups_]
len(groups)

41

In [7]:
# For every admission, the indices of all problem groups that contain at least
# one of its cleaned problem strings (possibly none, possibly several).
labels = [
    [
        group_index
        for group_index, group in enumerate(groups)
        if any(problem in group for problem in admit_for)
    ]
    for admit_for in admit_fors
]

In [8]:
# How many admissions matched exactly 0, 1, 2, 3 and 4 problem groups.
labels_per_admission = [len(label) for label in labels]
[labels_per_admission.count(k) for k in range(5)]

[16866, 22588, 3708, 342, 53]

In [9]:
# Tallies indexed by (problem group, diagnosis column of Y), accumulated over
# every admission matched to the group.
tally_shape = (len(groups), Y.shape[1])
admit_and_diagnosed_pos = numpy.zeros(tally_shape)  # entries of Y equal to 1
admit_and_diagnosed = numpy.zeros(tally_shape)      # entries of Y that are non-zero
for i, label in enumerate(labels):
    if not label:
        continue
    matched_groups = numpy.array(label)
    admit_and_diagnosed_pos[matched_groups] += Y[i] == 1
    admit_and_diagnosed[matched_groups] += Y[i] != 0

# Share of non-zero entries that equal 1. The denominator is floored at 1 to
# avoid dividing by zero, and cells backed by fewer than 100 non-zero entries
# are set to NaN.
P_ = numpy.where(
    admit_and_diagnosed < 100,
    numpy.nan,
    admit_and_diagnosed_pos / numpy.maximum(admit_and_diagnosed, 1),
)

In [10]:
# Lookup consulted by get_name below.
# NOTE(review): this is the same loader call that populated `group_names` in an
# earlier cell; presumably one of the two could be reused.
disease_group_names = icd_util.load_group_strings()

def get_name(code):
    """Return a display name for a diagnosis code.

    A code found in ``disease_group_names`` becomes ``"<code>: <description>"``
    with every '/' turned into '_' and the characters ' , ( ) [ ] removed.
    Any other code is returned with underscores replaced by spaces and
    title-cased.
    """
    if code in disease_group_names:
        # Single-pass equivalent of "replace '/' with '_', then delete the rest".
        cleanup = str.maketrans({'/': '_', **dict.fromkeys("',()[]")})
        return (code + ': ' + disease_group_names[code]).translate(cleanup)
    return code.replace('_', ' ').title()

# Display name for every entry of priors.index, in index order. `priors` is the
# value left behind by the last generate_predictions call in the loading loop.
disease_names = [get_name(i) for i in priors.index]

In [212]:
# For each diagnosis column, print its display name followed by the first
# problem string of the group with the highest value in P_ (NaNs are skipped).
for code_index in range(Y.shape[1]):
    best_group = numpy.nanargmax(P_[:, code_index])
    print(disease_names[code_index])
    print(groups[best_group][0] + '\n')

038: Septicemia
Sepsis

041: Bacterial Infection In Conditions Classified Elsewhere And Of Unspecified Site
Sp Motor Vehicle Accident

070: Viral Hepatitis
Copd Exacerbation

140-239: Neoplasms 
Brain Mass

155: Malignant Neoplasm Of Liver And Intrahepatic Bile Ducts
Fever

155.0: Liver Primary
Fever

162: Malignant Neoplasm Of Trachea Bronchus And Lung
Brain Mass

191: Malignant Neoplasm Of Brain
Brain Mass

198.3: Brain And Spinal Cord
Brain Mass

250: Diabetes Mellitus
Renal Failure

250.4: Diabetes With Renal Manifestations
Renal Failure

250.6: Diabetes With Neurological Manifestations
Shortness Of Breath

276.2: Acidosis
Shortness Of Breath

303: Alcohol Dependence Syndrome
Sp Motor Vehicle Accident

305: Nondependent Abuse Of Drugs
Sp Motor Vehicle Accident

305.0: Alcohol Abuse
Sp Motor Vehicle Accident

348.1: Anoxic Brain Damage
Cardiac Arrest

348.31: Metabolic Encephalopathy
Altered Mental Status

348.4: Compression Of Brain
Brain Mass

348.5: Cerebral Edema
Brain Mass

349

In [197]:
# Print the raw (upper-cased) and cleaned admission text for up to the first 11
# admissions that matched more than two problem groups.
shown = 0
for raw_text, cleaned, label in zip(admit_fors_, admit_fors, labels):
    if len(label) <= 2:
        continue
    print(raw_text.upper())
    print(cleaned, '\n')
    shown += 1
    if shown > 10:
        break

CHEST PAIN;CORONARY ARTERY DISEASE\CATH/STENT PLACEMENT
['Chest Pain', 'Coronary Artery Disease', 'Cath', 'Stent Placement'] 

SEPSIS;RESPIRATORY FAILURE;ACUTE RENAL FAILURE
['Sepsis', 'Respiratory Failure', 'Renal Failure'] 

NON-ST SEGMENT ELEVATION MYOCARDIAL INFARCTION;ATRIAL FIBRILLATION;HYPOTENSION\CARDIAC CATH
['Nonst Segment Elevation Myocardial Infarction', 'Atrial Fibrillation', 'Hypotension', 'Cardiac Cath'] 

AORTIC STENOSIS;CONGESTIVE HEART FAILURE\CARDIAC CATHETERIZATION
['Aortic Stenosis', 'Congestive Heart Failure', 'Cardiac Catheterization'] 

NON-ST SEGMENT ELEVATION MYOCARDIAL INFARCTION;ATRIAL FIBRILLATION;HYPOTENSION\CARDIAC CATH
['Nonst Segment Elevation Myocardial Infarction', 'Atrial Fibrillation', 'Hypotension', 'Cardiac Cath'] 

AORTIC STENOSIS;CONGESTIVE HEART FAILURE\CARDIAC CATHETERIZATION
['Aortic Stenosis', 'Congestive Heart Failure', 'Cardiac Catheterization'] 

CHEST PAIN;CORONARY ARTERY DISEASE\CATH/STENT PLACEMENT
['Chest Pain', 'Coronary Artery Disea

In [213]:
# List every (diagnosis, problem group) pair whose entry in P_ is exactly 1.
# NaN entries compare unequal to 1 and are therefore skipped.
for group_index, code_index in numpy.argwhere(P_ == 1):
    print(disease_names[code_index], '<->', groups[group_index][0])

577: Diseases Of Pancreas <-> Pancreatitis
427: Cardiac Dysrhythmias <-> Syncope
427.31: Atrial Fibrillation <-> Syncope
427: Cardiac Dysrhythmias <-> Cardiac Arrest
427: Cardiac Dysrhythmias <-> Ventricular Tachycardia
427.1: Paroxysmal Ventricular Tachycardia <-> Ventricular Tachycardia
437.3: Cerebral Aneurysm Nonruptured <-> Brain Aneurysm


In [245]:
# Load two objects, Q and Pi, from posterior.pkl.
# SECURITY: pickle.load can execute arbitrary code while deserializing; only
# open this file if it was produced by a trusted run of this project.
with open('posterior.pkl', 'rb') as f:
    Q, Pi = pickle.load(f)
# Stack the loaded pieces of Q row-wise into one array.
Q = numpy.vstack(Q)

In [360]:
# Smallest entry of Q.
Q.min()

0.0004999344237148762

In [361]:
# NOTE(review): `P` is first assigned in the following cell, so on a fresh
# kernel executed top-to-bottom this line raises NameError. The following cell
# also recomputes `diff2_` (with an additional clamp on the ratio), which
# supersedes the value produced here.
diff2_ = (1-Q) * numpy.log((1-Q) / numpy.maximum(1-P, 1e-9))

  """Entry point for launching an IPython kernel.
  """Entry point for launching an IPython kernel.


In [365]:
# Per-example reference probabilities: start from all-NaN and fill in only the
# rows of admissions that matched exactly one problem group, copying that
# group's row of P_. Everything else stays NaN.
P = numpy.nan * numpy.ones(Y.shape)
for i, label in enumerate(labels):
    # Disabled alternative: use the element-wise max over all matched groups.
#     if len(label) > 0:
#         p_ = P_[numpy.array(label)]
#         P[i] = numpy.nanmax(p_, axis=0)
    if len(label) == 1:
        P[i] = P_[label[0]]
        
# Treat exact zeros as missing so Q / P below never divides by zero.
P[P == 0] = numpy.nan
# The two terms of an element-wise Bernoulli KL divergence KL(Q || P):
#   Q * log(Q / P)  +  (1 - Q) * log((1 - Q) / (1 - P))
# The second term clamps both (1 - P) and the ratio at 1e-9 to keep the log finite.
diff1_ = Q * numpy.log(Q / P)
diff2_ = (1-Q) * numpy.log(numpy.maximum((1-Q) / numpy.maximum(1-P, 1e-9), 1e-9))
diff_ = diff1_ + diff2_
# Average over the non-NaN columns of each row (NaN-aware sum / count, with the
# count floored at 1).
diff = numpy.nansum(diff_, axis=1)
diff /= numpy.maximum((~numpy.isnan(diff_)).sum(axis=1), 1)
# Rows with no usable column get -inf so they sort to the end of the ranking.
diff[numpy.isnan(diff_).all(axis=1)] = -numpy.inf
# Fraction of rows with a finite score. The value is discarded because this is
# not the last expression of the cell.
(diff > -numpy.inf).mean()
# Example indices ordered from largest to smallest score.
I = numpy.argsort(diff)[::-1]
# Displayed cell output: the largest score.
diff[I[0]]

1.5341167338766872

In [275]:
# Diagnosis codes grouped by the organ (or syndrome) they are plotted under.
organs = {
    'brain': [
        '191', '198.3', '348.1', '348.31', '348.4', '348.5', '349.82',
        '430-438', '430', '431', '432.1', '434', '437.3', '850-854', '852',
    ],
    'liver': [
        '070', '155', '155.0', '570', '571', '571.1', '571.2', '572', '572.2',
        '572.3', '572.4', '574',
    ],
    'lung': ['162', '480-488', '507', '511', '518.0', '518.81', '997.3'],
    'kidney': ['403', '580-589', '584', '585'],
    'heart': [
        '396', '397.0', '410-414', '410', '410.7', '414.0', '416', '424.0',
        '424.1', '425', '426', '427', '427.1', '427.31', '427.32', '427.41',
        '427.5', '428', '428.0', '428.2', '428.3', '785.51', '997.1',
    ],
    'sepsis': ['038', '785.52', '995.9', '995.92'],
}

# Marker colour used for each key of `organs`.
colors = dict(
    brain='green',
    liver='gold',
    lung='blue',
    kidney='magenta',
    heart='red',
    sepsis='peru',
)


def lookup_color(code):
    """Return the colour of the first organ whose code list contains ``code``.

    Organs are checked in the insertion order of ``organs``. Codes that belong
    to no organ get 'white'.
    """
    matching_colors = (
        colors[organ] for organ, codes in organs.items() if code in codes
    )
    return next(matching_colors, 'white')

In [366]:
# 2-D t-SNE embedding with one point per diagnosis code: the per-fold logits in
# Z are stacked and transposed so that each row holds one code's values
# (assumes every Z entry is laid out as examples x codes -- TODO confirm).
# random_state is pinned so the layout is identical on every run; without it
# the embedding changes each time the cell is executed, and the later cell that
# reuses `z` would silently plot on a different layout.
z = manifold.TSNE(perplexity=25, n_iter=1000, random_state=0).fit_transform(
    numpy.vstack(Z).T
)

# Reuse figure number 7 so re-running the cell replaces the previous widget.
pyplot.close(7)
fig, ax = pyplot.subplots(num=7)

# One marker per code, coloured by the organ group it belongs to.
c = [lookup_color(code) for code in priors.index]
ax.scatter(z[:, 0], z[:, 1], c=c, s=64, linewidths=1, edgecolors='black')

# Strip the frame and ticks: t-SNE axes carry no interpretable scale.
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

# Label each point; the four leading spaces push the text clear of the marker.
texts = [
    ax.text(z[row, 0], z[row, 1], ' ' * 4 + txt, ha='left', va='center', fontsize=5)
    for row, txt in enumerate(disease_names)
]

# Size the canvas (widget backend).
fig.canvas.layout.width = '1200px'
fig.canvas.layout.height = '800px'

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [278]:
# Element-wise log2 ratio of Q to the row-stacked Pi: +1 means Q is twice Pi,
# -1 means half (presumably posterior vs. prior, going by the file name
# posterior.pkl -- confirm).
risk = numpy.log2(Q / numpy.vstack(Pi))

In [None]:
# Score of the example at position 1100 (0-based) of the descending ordering I.
diff[I[1100]]

In [384]:
# NOTE(review): scratch arithmetic that no other cell uses -- candidate for removal.
5 +5 

10

In [383]:
# Reuse figure number 0 so re-running the cell replaces the previous widget.
pyplot.close(0)
fig, ax = pyplot.subplots(num=0)
fig.canvas.layout.width = '1200px'
fig.canvas.layout.height = '800px'

# Example to display: position 500 of the descending score ordering I.
# i = numpy.random.randint(len(risk))
i = I[500]
y = numpy.vstack(Y)[i]
Y_diag = numpy.hstack([fold_labels['admission_diagnosis'] for fold_labels in Y_dict])
y0 = Y_diag[i]

# RGBA fill per diagnosis code: red when the log2 ratio is above 1 (alpha from
# the ratio capped at 2), green when it is below -1 (alpha from the magnitude
# capped at 4), fully transparent otherwise.
c = []
for j in range(len(y)):
    r = risk[i, j]
    if r > 1:
        c.append([1, 0, 0, round(min(r, 2)) / 2])
    elif r < -1:
        c.append([0, 1, 0, round(min(abs(r), 4)) / 4])
    else:
        c.append([0, 0, 0, 0])

# White outline for codes whose entry in y equals 1, black for the rest.
edges = ['white' if flag == 1 else 'black' for flag in y]

ax.scatter(z[:, 0], z[:, 1], c=c, s=80, linewidths=2, edgecolors=edges)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

# Every code gets a label, but only those with |log2 ratio| > 1 or an entry of
# 1 in y are visible; the others are drawn with alpha 0.
texts = [
    ax.text(
        z[j, 0], z[j, 1], ' ' * 4 + disease_names[j],
        ha='left', va='center', fontsize=6,
        alpha=1 if numpy.abs(risk[i, j]) > 1 or y[j] == 1 else 0,
    )
    for j in range(len(y))
]
ax.set_title('Admitted for: {}'.format(y0.decode()))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Admitted for: Cerebral Aneurysm/Sda')

In [321]:
# Number of entries in the ordering I.
len(I)

43557