In [1]:
import tensorflow as tf
from matplotlib import pyplot
import matplotlib
import matplotlib.cm as colormap
import numpy
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib
from sklearn import manifold

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import conv_model
import plot_batch
import generate_report_kfold
import icd_util

In [2]:
# Fold 0 of the k-fold run: fetch inputs (X0), targets (Y0) and model outputs
# (P0) for model 1469795 at checkpoint index 2. H, metadata and priors are
# assigned by every fold cell, so whichever fold cell runs last determines them.
H, X0, Y0, P0, metadata, priors = generate_report_kfold.generate_predictions(
    model_id=1469795,
    fold_index=0,
    checkpoint_index=2,
    example_count_log2=14,
)

found hypes ../hypes/1469795_20200512-210303.json 
found weights /scr1/checkpoints/1469795_20200512-210303_00384.ckpt
loading predictions


In [3]:
# Fold 1 of the k-fold run: fetch inputs (X1), targets (Y1) and model outputs
# (P1) for model 1469816 at checkpoint index 2. H, metadata and priors are
# rebound here, replacing whatever an earlier fold cell assigned.
H, X1, Y1, P1, metadata, priors = generate_report_kfold.generate_predictions(
    model_id=1469816,
    fold_index=1,
    checkpoint_index=2,
    example_count_log2=14,
)

found hypes ../hypes/1469816_20200512-213718.json 
found weights /scr1/checkpoints/1469816_20200512-213718_00384.ckpt
loading predictions


In [4]:
# Fold 2 of the k-fold run: fetch inputs (X2), targets (Y2) and model outputs
# (P2) for model 1470209 at checkpoint index 2. H, metadata and priors are
# rebound here, so the cells below see this fold's values for those three names.
H, X2, Y2, P2, metadata, priors = generate_report_kfold.generate_predictions(
    model_id=1470209,
    fold_index=2,
    checkpoint_index=2,
    example_count_log2=14,
)

found hypes ../hypes/1470209_20200513-050523.json 
found weights /scr1/checkpoints/1470209_20200513-050523_00384.ckpt
loading predictions


In [5]:
# Pool the three folds into single arrays. Predictions are stacked row-wise;
# the input and target dicts are concatenated key by key along axis 0, using
# fold 0's keys as the reference set (a key missing from a later fold raises).
P = numpy.vstack((P0, P1, P2))
X = {
    key: numpy.concatenate([fold[key] for fold in (X0, X1, X2)], axis=0)
    for key in X0
}
Y = {
    key: numpy.concatenate([fold[key] for fold in (Y0, Y1, Y2)], axis=0)
    for key in Y0
}

In [6]:
%%time

%matplotlib widget

# Build the interactive report from the pooled arrays. H, metadata and priors
# are whatever the most recently executed fold cell left bound.
plotter = generate_report_kfold.generate_plotter(H, X, Y, P, metadata, priors)
# Calling the plotter renders the canvas and controls; its return value is
# echoed as the cell output because this is the last expression of the cell.
plotter()

(43557, 90) predictions shape


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(SelectionSlider(continuous_update=False, description='Condition:', layout=Layout(width='…

CPU times: user 6.65 s, sys: 14.9 s, total: 21.6 s
Wall time: 21 s


<function generate_report_kfold.get_plotter.<locals>.update(code, threshold, example_index, log_scale)>

In [7]:
# Diagnosis codes grouped by organ system / syndrome. Entries are matched
# verbatim against the code strings passed to lookup_color, so range-style
# entries such as '430-438' are literal labels, not expanded ranges.
# NOTE(review): these look like ICD-9 codes and groups — confirm with icd_util.
organs = {
    'brain': [
        '191', '198.3', '348.1', '348.31', '348.4',
        '348.5', '349.82', '430-438', '430', '431',
        '432.1', '434', '437.3', '850-854', '852',
    ],
    'liver': [
        '070', '155', '155.0', '570', '571', '571.1',
        '571.2', '572', '572.2', '572.3', '572.4', '574',
    ],
    'lung': [
        '162', '480-488', '507', '511', '518.0', '518.81', '997.3',
    ],
    'kidney': [
        '403', '580-589', '584', '585',
    ],
    'heart': [
        '396', '397.0', '410-414', '410', '410.7', '414.0',
        '416', '424.0', '424.1', '425', '426', '427',
        '427.1', '427.31', '427.32', '427.41', '427.5', '428',
        '428.0', '428.2', '428.3', '785.51', '997.1',
    ],
    'sepsis': [
        '038', '785.52', '995.9', '995.92',
    ],
}

# Marker colour for each group above; the keys must mirror those of `organs`.
colors = dict(
    brain='green',
    liver='gold',
    lung='blue',
    kidney='magenta',
    heart='red',
    sepsis='peru',
)

def lookup_color(code):
    """Return the plot colour for a diagnosis code.

    The code is compared verbatim with the entries of each group in the
    module-level ``organs`` table, in the table's insertion order; the first
    group that lists it decides the colour via ``colors``.

    Parameters
    ----------
    code : str
        Code string exactly as it appears in ``organs``.

    Returns
    -------
    str
        The colour registered for the first matching group, or ``'white'``
        when no group lists the code.
    """
    for organ, codes in organs.items():
        if code in codes:
            return colors[organ]
    # Plain fall-through instead of a ``for ... else``: the loop never breaks,
    # so an ``else`` clause would always run and only obscure the intent.
    return 'white'

In [8]:
# Lookup of group descriptions keyed by code: get_name below tests membership
# with `in` and reads entries by subscript. The concrete container type is
# decided by icd_util and is not visible from this notebook.
group_names = icd_util.load_group_strings()

def get_name(code):
    """Build the display label for one output code.

    Codes present in ``group_names`` become ``'<code>: <description>'`` with
    every ``/`` turned into ``_`` and the characters ``' , ( ) [ ]`` removed.
    Any other code is treated as a free-form identifier: underscores become
    spaces and the result is title-cased.
    """
    if code in group_names:
        label = code + ': ' + group_names[code]
        # One pass: map '/' to '_' and drop the punctuation characters.
        cleanup = str.maketrans('/', '_', "',()[]")
        return label.translate(cleanup)
    return code.replace('_', ' ').title()

# One display label per entry of priors.index, in index order.
names = list(map(get_name, priors.index))

In [9]:
# Embed every output code in 2-D: P.T has one row per column of P, i.e. one
# row per entry of priors.index, so z[k] is the position of the k-th code.
# random_state pins the otherwise stochastic t-SNE layout; `z` is reused by the
# per-example figure further down, so it must not change between runs.
# NOTE(review): recent scikit-learn releases rename n_iter to max_iter — adjust
# if the environment is upgraded.
z = manifold.TSNE(perplexity=25, n_iter=1000, random_state=0).fit_transform(P.T)

fig = pyplot.figure(7)
fig.clear()

# Colour each code by the organ group that lists it ('white' when none does).
c = [lookup_color(i) for i in priors.index]
pyplot.scatter(z[:, 0], z[:, 1], c=c, s=64, linewidths=1, edgecolors='black')

# Strip the frame and ticks: t-SNE axes carry no interpretable units.
ax = fig.gca()
for spine in ax.spines.values():
    spine.set_visible(False)
pyplot.xticks([])
pyplot.yticks([])

# Label every point; the four leading spaces push the text clear of the marker.
texts = [
    pyplot.text(z[i, 0], z[i, 1], ' ' * 4 + txt, ha='left', va='center', fontsize=6)
    for i, txt in enumerate(names)
]

# Resize through the widget layout of the canvas (relies on the
# `%matplotlib widget` backend selected earlier) and display it.
fig.canvas.layout.width = '100%'
fig.canvas.layout.height = '1200px'
fig.canvas

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(layout=Layout(height='1200px', width='100%'), toolbar=Toolbar(toolitems=[('Home', 'Reset original view'…

In [41]:
# Sanity check on the pooled predictions: rows are examples from all three
# folds, columns line up with the entries of priors.index used for plotting.
P.shape

(43557, 90)

In [12]:
# Overlay one randomly drawn example on the t-SNE layout `z` computed earlier:
# each code's marker and label fade in with the example's prediction for it.
fig = pyplot.figure(6)
fig.clear()
fig.canvas.layout.height = '1000px'
fig.canvas.layout.width = '100%'

# The draw is unseeded, so every execution of the cell shows another example.
i = numpy.random.randint(len(P))
p, y, y0 = P[i], Y['diagnosis'][i], Y['admission_diagnosis'][i]

# Predictions go through the Reds colormap (clipped to [0, 1]); the last RGBA
# column is then overwritten so opacity equals the prediction itself.
norm = matplotlib.colors.Normalize(vmin=0, vmax=1, clip=True)
mapper = colormap.ScalarMappable(norm=norm, cmap=colormap.Reds)
c = mapper.to_rgba(p)
c[:, -1] = p

# Outline colour encodes the target: entries of y equal to 1 get a white
# edge, every other entry a black one.
edges = ['white' if label == 1 else 'black' for label in y]
pyplot.scatter(z[:, 0], z[:, 1], c=c, s=80, linewidths=2, edgecolors=edges)

ax = fig.gca()
for spine in ax.spines.values():
    spine.set_visible(False)
pyplot.xticks([])
pyplot.yticks([])

# Labels share the marker's opacity, so low-probability codes fade out.
texts = [
    pyplot.text(
        z[j, 0], z[j, 1], ' ' * 4 + names[j],
        ha='left', va='center', fontsize=6, alpha=prob,
    )
    for j, prob in enumerate(p)
]
pyplot.title('Admitted for: {}'.format(y0.decode()))
fig.canvas

Canvas(layout=Layout(height='1000px', width='100%'), toolbar=Toolbar(toolitems=[('Home', 'Reset original view'…