In [1]:
%matplotlib widget

import tensorflow as tf
from matplotlib import pyplot
import matplotlib
import matplotlib.cm as colormap
import numpy
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib
from sklearn import manifold
from scipy import stats

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import conv_model
import plot_batch
import generate_report_kfold
import icd_util

# pyplot.style.use('dark_background')

In [2]:
%%time

model_ids = [1469795, 1469816, 1470209]
checkpoint_index = 2
Y_dict, Z = [], []

for fold_index, model_id in enumerate(model_ids):
    H, x, y, p, metadata, priors = generate_report_kfold.generate_predictions(
        model_id = model_id,
        fold_index = fold_index,
        checkpoint_index = checkpoint_index,
        example_count_log2 = 14
    )
    Y_dict.append(y)
    Z.append(numpy.log(p / (1 - p)))

Y = [y['diagnosis'] for y in Y_dict]

found hypes ../hypes/1469795_20200512-210303.json 
found weights /scr1/checkpoints/1469795_20200512-210303_00384.ckpt
loading predictions
found hypes ../hypes/1469816_20200512-213718.json 
found weights /scr1/checkpoints/1469816_20200512-213718_00384.ckpt
loading predictions
found hypes ../hypes/1470209_20200513-050523.json 
found weights /scr1/checkpoints/1470209_20200513-050523_00384.ckpt
loading predictions
CPU times: user 15.8 s, sys: 5.22 s, total: 21 s
Wall time: 21 s


In [9]:
# Load the pickled (Q, Pi) pair; both are stacked per-fold and divided to
# form `risk` later (presumably posteriors and priors — confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load a
# posterior.pkl produced by this project.
with open('posterior.pkl', mode='rb') as posterior_file:
    Q, Pi = pickle.load(posterior_file)

In [4]:
# Diagnosis codes (and code-range groups such as '430-438') attributed to each
# organ system. lookup_color scans organs in this insertion order, so the first
# organ listing a code determines its color.
organs = dict(
    brain=[
        '191', '198.3', '348.1', '348.31', '348.4', '348.5', '349.82',
        '430-438', '430', '431', '432.1', '434', '437.3', '850-854', '852',
    ],
    liver=[
        '070', '155', '155.0', '570', '571', '571.1', '571.2', '572', '572.2',
        '572.3', '572.4', '574',
    ],
    lung=['162', '480-488', '507', '511', '518.0', '518.81', '997.3'],
    kidney=['403', '580-589', '584', '585'],
    heart=[
        '396', '397.0', '410-414', '410', '410.7', '414.0', '416', '424.0',
        '424.1', '425', '426', '427', '427.1', '427.31', '427.32', '427.41',
        '427.5', '428', '428.0', '428.2', '428.3', '785.51', '997.1',
    ],
    sepsis=['038', '785.52', '995.9', '995.92'],
)

# Matplotlib color name used for each organ system in the scatter plots.
colors = dict(
    brain='green',
    liver='gold',
    lung='blue',
    kidney='magenta',
    heart='red',
    sepsis='peru',
)


def lookup_color(code):
    """Return the plot color for ``code``.

    The color belongs to the first organ (in ``organs`` order) whose code list
    contains ``code`` exactly; codes not listed anywhere map to ``'white'``.
    """
    return next(
        (colors[organ] for organ, codes in organs.items() if code in codes),
        'white',
    )

# Mapping from code to a description string, as consumed by get_name
# (presumably human-readable group titles — confirm in icd_util).
group_names = icd_util.load_group_strings()


def get_name(code):
    """Return a display label for ``code``.

    A code with an entry in ``group_names`` becomes ``'<code>: <description>'``,
    with '/' turned into '_' and the characters ' , ( ) [ ] removed. Any other
    code is shown with underscores replaced by spaces, in title case.
    """
    if code in group_names:
        label = code + ': ' + group_names[code]
        # One pass: map '/' -> '_' and delete the listed punctuation.
        return label.translate(str.maketrans('/', '_', "',()[]"))
    return code.replace('_', ' ').title()


# One label per entry of priors.index (priors comes from the last fold of the
# prediction-loading loop); the order matches the t-SNE points.
names = list(map(get_name, priors.index))

In [3]:
# Sanity check: shape of the per-fold logits stacked along the first axis.
numpy.shape(numpy.vstack(Z))

NameError: name 'Z' is not defined

In [5]:
# Reset matplotlib rcParams to the stock style before drawing the figures below.
matplotlib.style.use('default')

In [6]:
# 2-D t-SNE embedding with one point per entry of priors.index (the rows of the
# transposed stacked logits), colored by organ system via lookup_color.
# random_state is fixed so the layout is reproducible; the per-example plot
# later in the notebook reuses `z`, so an unseeded layout would move every
# point there on each re-run.
# NOTE(review): `n_iter` was renamed `max_iter` in scikit-learn 1.5 and removed
# in 1.7 — switch the keyword when upgrading.
TSNE_SEED = 0
z = manifold.TSNE(
    perplexity=25, n_iter=1000, random_state=TSNE_SEED
).fit_transform(numpy.vstack(Z).T)

pyplot.close(7)
fig, ax = pyplot.subplots(num=7)
c = [lookup_color(i) for i in priors.index]
ax.scatter(z[:, 0], z[:, 1], c=c, s=64, linewidths=1, edgecolors='black')

# Bare canvas: no frame and no ticks (t-SNE axes carry no meaning).
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

# Labels are offset by four spaces so they sit beside their marker.
texts = [
    ax.text(z[i, 0], z[i, 1], ' ' * 4 + txt, ha='left', va='center', fontsize=5)
    for i, txt in enumerate(names)
]

# Legend via empty proxy scatters, one per organ color; codes not listed under
# any organ are drawn white.
for organ, color in colors.items():
    ax.scatter([], [], color=color, edgecolors='black', s=64, label=organ)
ax.legend(frameon=False, fontsize=6, loc='best')
ax.set_title('t-SNE of code logits, colored by organ system')

fig.canvas.layout.width = '1200px'
fig.canvas.layout.height = '800px'

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Element-wise log2 ratio of the stacked Q to the stacked Pi: > 0 where Q
# exceeds Pi, < 0 where it falls below (1 means twice, -1 means half).
risk = numpy.log2(numpy.divide(numpy.vstack(Q), numpy.vstack(Pi)))

In [64]:
# Per-example view on the t-SNE layout `z`:
#   * fill is red where risk > 1 (Q more than twice Pi, since risk is a log2
#     ratio) and green where risk < -1; other codes are transparent.
#     Opacity steps use Python's round() (half-to-even): red is capped at
#     risk 2 (alpha 0.5 or 1.0), green at |risk| 4 (alpha 0.25 .. 1.0).
#   * codes with y == 1 for this example get a black ring;
#   * only highlighted or ringed codes get a visible text label.
# Set EXAMPLE_INDEX to an int to reproduce a specific example. With None a
# random example is drawn on every run; its index is shown in the title so any
# figure can be recreated.
EXAMPLE_INDEX = None

Y_stacked = numpy.vstack(Y)
Y_diag = numpy.hstack([y['admission_diagnosis'] for y in Y_dict])
# risk is built from posterior.pkl while Y / Y_dict come from
# generate_predictions; indexing them with the same row index is only
# meaningful if they contain the same examples in the same order.
assert len(Y_stacked) == len(risk) == len(Y_diag), (
    len(Y_stacked), len(risk), len(Y_diag)
)

i = numpy.random.randint(len(risk)) if EXAMPLE_INDEX is None else EXAMPLE_INDEX
y = Y_stacked[i]
y0 = Y_diag[i]

c = [
    [1, 0, 0, round(min(risk[i, j], 2)) / 2] if risk[i, j] > 1
    else [0, 1, 0, round(min(abs(risk[i, j]), 4)) / 4] if risk[i, j] < -1
    else [0, 0, 0, 0]
    for j in range(risk.shape[1])
]
edges = ['black' if label == 1 else 'white' for label in y]

pyplot.close(0)
fig, ax = pyplot.subplots(num=0)
fig.canvas.layout.width = '1200px'
fig.canvas.layout.height = '800px'
ax.scatter(z[:, 0], z[:, 1], c=c, s=80, linewidths=2, edgecolors=edges)
for spine in ax.spines.values():
    spine.set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
texts = [
    ax.text(
        z[j, 0], z[j, 1], ' ' * 4 + names[j], ha='left', va='center', fontsize=6,
        alpha=1 if numpy.abs(risk[i, j]) > 1 or y[j] == 1 else 0,
    )
    for j in range(risk.shape[1])
]
# y0 is decoded as bytes, as in the original cell.
ax.set_title('Admitted for: {} (example {})'.format(y0.decode(), i))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0.5, 1.0, 'Admitted for: Tracheal Esophageal Fistula')

In [20]:
# (examples across all folds, codes)
numpy.shape(risk)

(43557, 90)

In [21]:
# NOTE(review): `p` is the loop variable left over from the prediction-loading
# cell, so this is the size of the LAST fold only, not the total across folds
# (compare with risk.shape[0]).
len(p)

13576