In [1]:
%matplotlib widget

import tensorflow as tf
from matplotlib import pyplot
import matplotlib
import matplotlib.cm as colormap
import numpy
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib
from sklearn import manifold
from scipy import stats

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import conv_model
import plot_batch
import generate_report_kfold
import icd_util

# Dark figure theme for every plot in this notebook (pyplot.style is matplotlib.style).
matplotlib.style.use('dark_background')

In [2]:
%%time

model_ids = [1469795, 1469816, 1470209]  # one trained model per cross-validation fold
checkpoint_index = 2
X_dict, Y_dict, Z = [], [], []

for fold_index, model_id in enumerate(model_ids):
    # NOTE(review): H, metadata and priors are reassigned on every iteration, so
    # later cells only see the last fold's values — confirm that is intended.
    H, x, y, p, metadata, priors = generate_report_kfold.generate_predictions(
        model_id=model_id,
        fold_index=fold_index,
        checkpoint_index=checkpoint_index,
        example_count_log2=14,
    )
    X_dict.append(x)
    Y_dict.append(y)
    # Convert probabilities to logits, log(p / (1 - p)); a p of exactly 0 or 1
    # would give -inf / inf here.
    logits = numpy.log(p / (1 - p))
    Z.append(logits)

X = [fold_inputs['signals'] for fold_inputs in X_dict]
Y = [fold_labels['diagnosis'] for fold_labels in Y_dict]

found hypes ../hypes/1469795_20200512-210303.json 
found weights /scr1/checkpoints/1469795_20200512-210303_00384.ckpt
loading predictions
found hypes ../hypes/1469816_20200512-213718.json 
found weights /scr1/checkpoints/1469816_20200512-213718_00384.ckpt
loading predictions
found hypes ../hypes/1470209_20200513-050523.json 
found weights /scr1/checkpoints/1470209_20200513-050523_00384.ckpt
loading predictions
CPU times: user 16 s, sys: 5.66 s, total: 21.7 s
Wall time: 21.6 s


In [3]:
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
with open('posterior.pkl', mode='rb') as posterior_file:
    Q, Pi = pickle.load(posterior_file)

In [4]:
# Condition codes (and code-group labels such as '430-438') grouped by organ
# system. Entries are plain strings, so a range label is just another label.
organs = dict(
    brain=[
        '191', '198.3', '348.1', '348.31', '348.4', '348.5', '349.82',
        '430-438', '430', '431', '432.1', '434', '437.3', '850-854', '852',
    ],
    liver=[
        '070', '155', '155.0', '570', '571', '571.1', '571.2', '572',
        '572.2', '572.3', '572.4', '574',
    ],
    lung=['162', '480-488', '507', '511', '518.0', '518.81', '997.3'],
    kidney=['403', '580-589', '584', '585'],
    heart=[
        '396', '397.0', '410-414', '410', '410.7', '414.0', '416', '424.0',
        '424.1', '425', '426', '427', '427.1', '427.31', '427.32', '427.41',
        '427.5', '428', '428.0', '428.2', '428.3', '785.51', '997.1',
    ],
    sepsis=['038', '785.52', '995.9', '995.92'],
)

# Matplotlib colour for each organ group (same keys, same order as `organs`).
colors = dict(
    brain='green',
    liver='gold',
    lung='blue',
    kidney='magenta',
    heart='red',
    sepsis='peru',
)

def lookup_color(code, organ_codes=None, organ_colors=None, default='white'):
    """Return the plot colour of the first organ group that lists `code`.

    Membership is an exact match against each group's list, so a range label
    such as '430-438' matches only that label, not the codes it spans. Groups
    are checked in dict order and the first match wins.

    Args:
        code: condition code / label to look up (e.g. an entry of `priors.index`).
        organ_codes: mapping of organ name -> list of codes; defaults to the
            module-level `organs`.
        organ_colors: mapping of organ name -> matplotlib colour; defaults to
            the module-level `colors`.
        default: colour returned when no group lists `code`.

    Returns:
        The matching group's colour, or `default`.
    """
    if organ_codes is None:
        organ_codes = organs
    if organ_colors is None:
        organ_colors = colors
    for organ, codes in organ_codes.items():
        if code in codes:
            return organ_colors[organ]
    return default

group_names = icd_util.load_group_strings()


def get_name(code):
    """Return a display label for a condition code.

    Codes present in `group_names` become "<code>: <description>", with '/'
    turned into '_' and the characters ' , ( ) [ ] removed. Any other code
    has its underscores turned into spaces and is title-cased.
    """
    if code in group_names:
        # maketrans(from, to, deletechars): map '/' to '_' and drop the rest.
        cleanup = str.maketrans('/', '_', "',()[]")
        return (code + ': ' + group_names[code]).translate(cleanup)
    return code.replace('_', ' ').title()


# One display label per condition, in the order of `priors.index`.
names = list(map(get_name, priors.index))

In [7]:
# t-SNE of the conditions: each row of numpy.vstack(Z).T holds one condition's
# logits across all stacked examples, so each plotted point is a condition.
TSNE_SEED = 0  # fixed so the layout is reproducible on Restart & Run All
# NOTE: newer scikit-learn releases rename `n_iter` to `max_iter`.
z = manifold.TSNE(perplexity=25, n_iter=1000, random_state=TSNE_SEED).fit_transform(
    numpy.vstack(Z).T
)

# A string figure label keeps this figure separate from the integer `fig_num`s
# passed to plot_condition(), whose pyplot.close(fig_num) would otherwise
# close it (e.g. fig_num=7).
pyplot.close('tsne')
fig, ax = pyplot.subplots(num='tsne')
c = [lookup_color(i) for i in priors.index]
ax.scatter(z[:, 0], z[:, 1], c=c, s=64, linewidths=1, edgecolors='black')
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
texts = [
    ax.text(z[i, 0], z[i, 1], ' ' * 4 + txt, ha='left', va='center', fontsize=5)
    for i, txt in enumerate(names)
]
fig.canvas.layout.width = '1200px'
fig.canvas.layout.height = '800px'

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
# All folds' input signals stacked into one array.
sigs = numpy.vstack(X)
# Element-wise log2 ratio of Q to Pi (both from posterior.pkl), stacked across folds.
risk = numpy.log2(numpy.divide(numpy.vstack(Q), numpy.vstack(Pi)))

In [17]:
# Index of the first condition whose display name mentions atrial fibrillation.
j = next(idx for idx, label in enumerate(names) if 'atrial fib' in label.lower())
names[j]

'427.31: Atrial Fibrillation'

In [22]:
# `I` exists only as a local variable inside plot_condition(), so it has to be
# built here for this cell to run on a fresh kernel.
# NOTE(review): the ordering used when the recorded output was produced is not
# visible; sorting by this condition's `risk` column (so I[0] / I[-1] give its
# minimum / maximum) is an assumption — confirm.
I = numpy.argsort(risk[:, j])
print(risk[I[0], j], risk[I[-1], j])

-4.978618938049022 1.6255879248358038


In [82]:
# Per-fold logits stacked row-wise into one array (recomputed again further down).
Z_stack = numpy.vstack(tuple(Z))

In [131]:
# Scratch value: the first stacked example's logits summed over all conditions.
# plot_condition() takes its own `i1` argument and does not read this global.
i1 = numpy.sum(numpy.vstack(Z), axis=1)[0]

In [223]:
Z_stack = numpy.vstack(Z)

# rec_id -> subject_id lookup table. drop_duplicates() collapses repeated
# (subject_id, rec_id) pairs; verify_integrity=True then raises if any rec_id
# is still associated with more than one subject.
M = (
    metadata.reset_index()[['subject_id', 'rec_id']]
    .drop_duplicates()
    .set_index('rec_id', verify_integrity=True)
)
# Subject id for every example, concatenated fold by fold (the same fold order
# used to stack Z_stack and sigs).
# NOTE(review): `metadata` comes from the last generate_predictions() call
# only; this assumes it covers the recordings of every fold — confirm.
subject_ids = numpy.hstack([M.loc[fold_y['rec_id']].values[:, 0] for fold_y in Y_dict])

def plot_condition(substr, fig_num, i1=0, i2=-1):
    """Overlay the validation signals of two examples ranked by a condition's logit.

    The first condition whose lower-cased display name contains `substr` is
    selected. Examples are sorted by their logit for that condition (ascending,
    via `Z_stack`). The examples at sorted positions `i1` and `i2` are drawn on
    top of each other, one subplot per validation input signal.

    Args:
        substr: lower-case substring to search for in `names` (a phrase such as
            'atrial fib' or a code prefix such as '424.1').
        fig_num: matplotlib figure number; any existing figure with this number
            is closed and replaced.
        i1: sorted position of the example drawn in translucent green
            (default 0, the lowest logit).
        i2: sorted position of the example drawn in red (default -1, the
            highest logit); its subject id becomes the bottom x-axis label.

    Raises:
        ValueError: if no condition name contains `substr`, or (from `.index`)
            if a validation signal is missing from H['input_sigs'].
    """
    j0 = next((k for k, name in enumerate(names) if substr in name.lower()), None)
    if j0 is None:
        raise ValueError('no condition name contains {!r}'.format(substr))
    I = numpy.argsort(Z_stack[:, j0])

    # Positions in `sigs` of the channels listed for validation plots.
    J = [H['input_sigs'].index(sig) for sig in H['input_sigs_validation']]
    # Mixed indexing: the scalar row index and the list J are separated by a
    # slice, so NumPy puts the broadcast index dimension first and each result
    # has shape (len(J), sigs.shape[1]); x1[k] is the k-th validation signal.
    x1, x2 = sigs[I[i1], :, J], sigs[I[i2], :, J]

    pyplot.close(fig_num)
    # squeeze=False keeps `axes` indexable even when there is a single signal.
    fig, axes = pyplot.subplots(nrows=len(J), num=fig_num, squeeze=False)
    axes = axes[:, 0]
    axes[0].set_title(names[j0])
    axes[-1].set_xlabel('Subject {}'.format(subject_ids[I[i2]]))

    for k, axis in enumerate(axes):
        for spine in axis.spines.values():
            spine.set_visible(False)
        axis.set_xticks([])
        axis.plot(x1[k], c=[0, 1, 0, 0.2], linewidth=0.5)
        axis.plot(x2[k], c=[1, 0, 0, 0.6], linewidth=0.5)
        axis.set_ylabel(H['input_sigs_validation'][k])
        axis.yaxis.tick_right()
        axis.tick_params(axis='y', colors='gray')

    fig.tight_layout(pad=1)
    fig.subplots_adjust(hspace=0.1)

In [213]:
# Atrial fibrillation: 6th-lowest (green) vs 2nd-highest (red) logit.
plot_condition(substr='atrial fib', fig_num=0, i1=5, i2=-2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [214]:
# 'Age at least ...' label: lowest (green) vs highest (red) logit.
plot_condition(substr='age at least', fig_num=1, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [215]:
# Cardiogenic shock: lowest (green) vs highest (red) logit.
plot_condition(substr='cardiogenic', fig_num=2, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [216]:
# First condition matching 'cerebral an': 4th-lowest (green) vs 4th-highest (red) logit.
plot_condition(substr='cerebral an', fig_num=3, i1=3, i2=-4)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [217]:
# Neoplasm of brain: lowest (green) vs highest (red) logit.
plot_condition(substr='neoplasm of brain', fig_num=4, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [218]:
# Code 424.1 (matched via the "<code>: ..." display name): lowest vs highest logit.
plot_condition(substr='424.1', fig_num=5, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [225]:
# Hepatorenal syndrome: lowest (green) vs highest (red) logit.
plot_condition(substr='hepatorenal', fig_num=6, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [226]:
# Alcoholic hepatitis: lowest (green) vs highest (red) logit.
plot_condition(substr='alcoholic hepatitis', fig_num=7, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [227]:
# Alcoholic cirrhosis: lowest (green) vs 5th-highest (red) logit.
plot_condition(substr='alcoholic cir', fig_num=8, i1=0, i2=-5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [232]:
# Hepatic coma: lowest (green) vs 9th-highest (red) logit.
plot_condition(substr='hepatic coma', fig_num=9, i1=0, i2=-9)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [235]:
# Cardiac arrest: lowest (green) vs 4th-highest (red) logit.
plot_condition(substr='cardiac arr', fig_num=10, i1=0, i2=-4)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [237]:
# 'Died' label: lowest (green) vs highest (red) logit.
plot_condition(substr='died', fig_num=11, i1=0, i2=-1)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# Intentionally empty: the 'died' comparison (figure 11) is drawn by the previous cell.