In [1]:
%matplotlib widget

import tensorflow as tf
from matplotlib import pyplot
import matplotlib
import matplotlib.cm as colormap
import numpy
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib
from sklearn import manifold
from scipy import stats

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import conv_model
import plot_batch
import generate_report_kfold
import icd_util

pyplot.style.use('default')
# pyplot.style.use('dark_background')

In [2]:
%%time

model_ids = [1469795, 1469816, 1470209]
checkpoint_index = 2
Y_dict, Z, P = [], [], []

for fold_index, model_id in enumerate(model_ids):
    H, x, y, p, metadata, priors = generate_report_kfold.generate_predictions(
        model_id = model_id,
        fold_index = fold_index,
        checkpoint_index = checkpoint_index,
        example_count_log2 = 14
    )
    Y_dict.append(y)
    P.append(p)
    Z.append(numpy.log(p / (1 - p)))

Y = [y['diagnosis'] for y in Y_dict]

found hypes ../hypes/1469795_20200512-210303.json 
found weights /scr1/checkpoints/1469795_20200512-210303_00384.ckpt
loading predictions
found hypes ../hypes/1469816_20200512-213718.json 
found weights /scr1/checkpoints/1469816_20200512-213718_00384.ckpt
loading predictions
found hypes ../hypes/1470209_20200513-050523.json 
found weights /scr1/checkpoints/1470209_20200513-050523_00384.ckpt
loading predictions
CPU times: user 16 s, sys: 5.34 s, total: 21.3 s
Wall time: 21.3 s


In [3]:
group_names = icd_util.load_group_strings()

def get_name(code):
    if code not in group_names:
        return code.replace('_', ' ').title()
    name = code + ': ' + group_names[code]
    name = name.replace('/', '_')
    for character in "',()[]":
        name = name.replace(character, '')
    return name

names = numpy.array([get_name(i).strip() for i in priors.index])

def gaussian(diff, sig):
    a = sig * numpy.sqrt(2*numpy.pi)
    b = -2 * sig**2
    return numpy.exp(diff**2 / b) / a

from scipy import stats
mu, sig = numpy.random.rand(2) * 5
p1 = stats.norm(mu, sig).pdf(0)
p2 = gaussian(mu, sig)
assert(numpy.isclose(p1, p2))

In [4]:
%%time
bandwidth = 0.8
K = []
Z_ = []
for z in Z:
    z_low, z_high = numpy.percentile(z, [0.1, 99.9], axis=0)
    z_ = numpy.linspace(z_low, z_high, 1000, axis=1, dtype='float32')
    Z_.append(z_)
    diff = numpy.expand_dims(z, axis=-1) - numpy.expand_dims(z_, axis=0)
    K.append(gaussian(diff, bandwidth))

CPU times: user 51.4 s, sys: 8.38 s, total: 59.7 s
Wall time: 59.7 s


In [5]:
M = metadata.reset_index()[['subject_id', 'rec_id']].drop_duplicates()
M = M.set_index('rec_id', verify_integrity=True)
subject_ids = [M.loc[i['rec_id']].values[:, 0] for i in Y_dict]
unique_ids = [numpy.array(sorted(set(i))) for i in subject_ids]

In [6]:
%%time

density = [k.sum(axis=0) for k in K]
sums = [d.sum(axis=1) * (z_[:, 1] - z_[:, 0]) for d, z_ in zip(density, Z_)]
density = [d / numpy.expand_dims(s, axis=-1) for d, s in zip(density, sums)]
K_pos = [k * numpy.expand_dims(y == 1, axis=-1) for y, k in zip(Y, K)]
K_neg = [k * numpy.expand_dims(y == -1, axis=-1) for y, k in zip(Y, K)]
del K
density_pos = [k.sum(axis=0) for k in K_pos]
density_neg = [k.sum(axis=0) for k in K_neg]

CPU times: user 19.2 s, sys: 5.59 s, total: 24.8 s
Wall time: 24.8 s


In [8]:
j = next(i for i in range(len(names)) if 'cirrhosis of liv' in names[i].lower())
print(j, names[j])
pyplot.close(0)
fig, ax = pyplot.subplots(num=0)
# ax.plot(domain[j], density[j], 'w');
ax.plot(Z_[0][j], K_pos[0].sum(axis=0)[j] / K_pos[0].sum(axis=0)[j].mean(), 'r');
ax.plot(Z_[0][j], K_neg[0].sum(axis=0)[j] / K_neg[0].sum(axis=0)[j].mean(), 'g');
pyplot.ylabel('Probability Density')
pyplot.xlabel('Network Output')
pyplot.legend(['Positive', 'Negative'])

62 571.2: Alcoholic Cirrhosis Of Liver


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7f50184f9850>

In [9]:
%%time

density_pos_pt = [
    numpy.stack([
        K_pos[i][subject_ids[i] == j].sum(axis=0)
        for j in unique_ids[i]
    ])
    for i in range(len(Z))
]

density_neg_pt = [
    numpy.stack([
        K_neg[i][subject_ids[i] == j].sum(axis=0) 
        for j in unique_ids[i]
    ])
    for i in range(len(Z))
]

CPU times: user 7.21 s, sys: 592 ms, total: 7.8 s
Wall time: 7.81 s


In [10]:
%%time

range_indices = numpy.expand_dims(numpy.arange(Z[0].shape[1]), 1)
diffs = [z_[:, 1:2] - z_[:, :1] for z_ in Z_]
JZ = [(Z[i].T - Z_[i][:, :1]) / diffs[i] for i in range(len(Z))]
JZ = [numpy.round(jz.clip(0, Z_[0].shape[1] - 1)).astype('int') for jz in JZ]

Q = [numpy.zeros(z.shape) * numpy.nan for z in Z]
Pi = [numpy.zeros(z.shape) * numpy.nan for z in Z]

for i in range(len(Z)):
    for j, subject_id in enumerate(unique_ids[i]):
        a = density_pos[i] - density_pos_pt[i][j] + 1
        b = density_neg[i] - density_neg_pt[i][j]
        q_ = a / (a + b)
        J = subject_ids[i] == subject_id
        assert(numpy.isnan(Q[i][J]).all())
        Q[i][J] = q_[range_indices, JZ[i][:, J]].T
        pi = (Y[i][~J] == 1).sum(axis=0)
        pi = pi / numpy.maximum((Y[i][~J] != 0).sum(axis=0), 1)
        Pi[i][J] = numpy.expand_dims(pi, 0)

risk = [numpy.log2(q / pi) for q, pi in zip(Q, Pi)]
# risk = [numpy.log2(p / pi) for p, pi in zip(P, Pi)]
        
error = (Z_[0][range_indices, JZ[0]] - Z[0].T) / diffs[0]
assert(numpy.abs(error).mean() < 1)
assert(not any(numpy.isnan(q).any() for q in Q))

CPU times: user 7.34 s, sys: 36 ms, total: 7.38 s
Wall time: 7.37 s


In [11]:
Y_stack = numpy.vstack(Y)
Z_stack = numpy.vstack(Z)
risk_stack = numpy.vstack(risk)

In [12]:
j = next(i for i in range(len(names)) if 'cirrhosis of liv' in names[i].lower())
print(j, names[j])

# sorted_indices = risk_stack[:, j].argsort()
triage_order = numpy.vstack(P)[:, j].argsort()[::-1]
y_sorted = Y_stack[triage_order, j]

sensitivity = (y_sorted == 1).cumsum() / (y_sorted == 1).sum()

pyplot.close(30)
fig, ax = pyplot.subplots(num=30)

p0 = 0.2
s0 = sensitivity[round(triage_order.size * p0)]

ax.plot([0, 1], [0, 1], '--', color='y')
ax.plot(numpy.linspace(0, 1, triage_order.size), sensitivity, 'b')
ax.plot([p0, p0], [0, 1], ':', color='gray')
ax.plot(p0, p0, 'ow', mfc='none')
ax.plot(p0, s0, 'ow', mfc='none')
ax.annotate(
    '{:d}% detected'.format(int(round(p0 * 100))), 
    (p0, p0), 
    textcoords="offset points",
    xytext=(8, -8),
    ha='left'
)

ax.annotate(
    '{:d}% detected'.format(int(round(s0 * 100))),
    (p0, s0), 
    textcoords="offset points",
    xytext=(8, -8),
    ha='left'
)

ax.set_xlabel('Percentile', fontsize=14)
ax.set_ylabel('Sensitivity', fontsize=14)
# ax.set_title(names[j])

pyplot.legend(['Random', 'Triaged'])

fig.canvas.layout.height = '600px'
ax.set_aspect('equal')

62 571.2: Alcoholic Cirrhosis Of Liver


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
j = next(i for i in range(len(names)) if 'heart fail' in names[i].lower())
print(j, names[j])
pyplot.close(30)
fig, ax = pyplot.subplots(num=30)

42 428: Heart Failure


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [567]:
# with open('posterior.pkl', 'wb') as f:
#     pickle.dump([Q, Pi], f)

In [14]:
j = next(i for i in range(len(names)) if 'heart fail' in names[i].lower())
print(j, names[j])
pyplot.close(1)
fig, ax = pyplot.subplots(num=1)
z = numpy.hstack([z[:, j] for z in Z])
r = numpy.hstack([r[:, j] for r in risk])
ax.plot(z, r, '.m')

42 428: Heart Failure


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7f50437e5e10>]

In [15]:
j = next(i for i in range(len(names)) if 'heart fail' in names[i].lower())
print(j, names[j])
pyplot.close(2)
fig, ax = pyplot.subplots(num=2)
is_pos = numpy.vstack(Y) == 1
is_neg = numpy.vstack(Y) == -1
z2 = numpy.hstack([numpy.log(q[:, j] / (1 - q[:, j])) for q in Q])
# ax.hist(z2, bins=100, color='w');
ax.hist(z2[is_neg[:, j]], bins=100, color='g');
ax.hist(z2[is_pos[:, j]], bins=100, color='r');

42 428: Heart Failure


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
pos = (Y_stack == 1).sum(axis=0)
neg = (Y_stack == -1).sum(axis=0)
prevalence = pos / (pos + neg)

# flagged = numpy.logical_and(risk_stack < -2, risk_stack > -3)

min_pos_pred = numpy.expand_dims(numpy.percentile(Z_stack, 10, axis=0), axis=0)
flagged_already = numpy.zeros(risk_stack.shape, dtype='bool')
flagged_low_risk = []
for i in [-4, -3, -2, -1]:
    flagged = numpy.logical_and(risk_stack < i, ~flagged_already)
    flagged_low_risk.append(flagged)
    flagged_already = numpy.logical_or(flagged_already, flagged)
# flagged_low_risk = [risk_stack < i for i in [-4, -3, -2, -1]]

flagged_high_risk = []
for i in [2, 1]:
    flagged = numpy.logical_and(risk_stack > i, Z_stack > min_pos_pred)
    flagged = numpy.logical_and(flagged, ~flagged_already)
    flagged_high_risk.append(flagged)
    flagged_already = numpy.logical_or(flagged_already, flagged)

F = numpy.stack(flagged_low_risk + [~flagged_already] + flagged_high_risk[::-1])

pos_flagged = (F * (Y_stack == 1)).sum(axis=1)
neg_flagged = (F * (Y_stack == -1)).sum(axis=1)
prevalence_flagged = pos_flagged / numpy.maximum(pos_flagged + neg_flagged, 1)
risk_flagged = numpy.log2((prevalence_flagged.clip(1e-12, 1) / prevalence))

flag_index = 0
probs = 100 * F.mean(axis=1)[flag_index]
# probs = 100 * prob_flagged[flag_index]
I = numpy.argsort(probs)[::-1]
print((probs > 2).sum(), 'diseases')
for i in I:
    if probs[i] > 2:
        print(round(probs[i], 1), round(risk_flagged[flag_index, i], 1), names[i])

6 diseases
8.7 -4.6 785.51: Cardiogenic Shock
4.0 -3.3 348.5: Cerebral Edema
2.7 -5.6 426: Conduction Disorders
2.3 -5.5 430: Subarachnoid Hemorrhage
2.2 -4.5 410: Acute Myocardial Infarction
2.1 -5.8 348.4: Compression Of Brain


In [17]:
probs = F.mean(axis=1)
starts = probs.cumsum(axis=0) - probs
I = numpy.arange(F.shape[0]) != 4
J = numpy.argsort(probs[I].sum(0))
for j in J[::-1][:10]:
    print(names[j])

785.51: Cardiogenic Shock
426: Conduction Disorders
427.41: Ventricular Fibrillation
428.2: Systolic Heart Failure
437.3: Cerebral Aneurysm Nonruptured
425: Cardiomyopathy
410: Acute Myocardial Infarction
410.7: Subendocardial Infarction
198.3: Brain And Spinal Cord
191: Malignant Neoplasm Of Brain


In [37]:
ax.barh?

[0;31mSignature:[0m [0max[0m[0;34m.[0m[0mbarh[0m[0;34m([0m[0my[0m[0;34m,[0m [0mwidth[0m[0;34m,[0m [0mheight[0m[0;34m=[0m[0;36m0.8[0m[0;34m,[0m [0mleft[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0;34m*[0m[0;34m,[0m [0malign[0m[0;34m=[0m[0;34m'center'[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Make a horizontal bar plot.

The bars are positioned at *y* with the given *align*\ment. Their
dimensions are given by *width* and *height*. The horizontal baseline
is *left* (default 0).

Each of *y*, *width*, *height*, and *left* may either be a scalar
applying to all bars, or it may be a sequence of length N providing a
separate value for each bar.

Parameters
----------
y : scalar or array-like
    The y coordinates of the bars. See also *align* for the
    alignment of the bars to the coordinates.

width : scalar or array-like
    The width(s) of the bars.

height : sequence of scalars, optional, 

In [55]:
pyplot.close(4)
fig, ax = pyplot.subplots(num=4)

colors = [
    [0, 1, 0, 1],
    [0, 1, 0, 0.75],
    [0, 1, 0, 0.5],
    [0, 1, 0, 0.25],
    [0, 0, 0, 0],
    [1, 0, 0, 0.5],
    [1, 0, 0, 1]
]

labels = [i[:50] for i in names[J]]
n = 35
for i in range(len(probs)):
    ax.barh(labels[-n:], probs[i][J][-n:], left=starts[i][J][-n:], height=1, edgecolor='k', linewidth=1, color=colors[i])

ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
fig.canvas.layout.height = '900px'
fig.subplots_adjust(right=0.5)
ax.set_ylim(0.5, len(labels[-n:]) - 0.5)
ax.set_xlabel('1 - Percentile', fontsize=14)
ax.legend(['1/16', 
    '1/8', 
    '1/4', 
    '1/2', 
    '1', 
    '2', 
    '4'
], title='Relative Risk', loc='lower center', framealpha=1)
fig.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [56]:
pyplot.close(5)
fig, ax = pyplot.subplots(num=5)

colors = [
    [0, 1, 0, 1],
    [0, 1, 0, 0.75],
    [0, 1, 0, 0.5],
    [0, 1, 0, 0.25],
    [0, 0, 0, 0],
    [1, 0, 0, 0.5],
    [1, 0, 0, 1]
]

targets = [-4, -3, -2, -1, 1, 2]

ax.plot(targets, targets, 'o', mfc='none', color='black', markersize=10, linewidth=3)

labels = [i[:50] for i in names[J]]
for i, r in enumerate(range(-4, 3)):
    if r != 0:
        J = numpy.logical_and(probs[i] > 0.02, risk_flagged[i] > -10)
        x = (risk_flagged[i][J] * probs[i][J]).sum() / probs[i][J].sum()
        dx = risk_flagged[i][J].std()
        pyplot.plot([r, r], [x - dx, x + dx], ':', color='black')
#         pyplot.plot(r, x, 'o', color=colors[i][:-1])
        pyplot.plot(r, x, 'o', color='black', markersize=10)

ax.yaxis.set_label_position('right')
ax.yaxis.tick_right()
ax.set_xlabel('Estimated Risk', fontsize=20, labelpad=20)
ax.set_ylabel('Actual Risk', fontsize=20, labelpad=10)
labels = ['$\\frac{1}{16}$', '$\\frac{1}{8}$', '$\\frac{1}{4}$', '$\\frac{1}{2}$', '1', '2', '4']
ax.set_xticks(range(-4, 3))
ax.set_yticks(range(-5, 4))
ax.set_xticklabels(labels, fontsize=20)
ax.set_yticklabels(['$\\frac{1}{32}$'] + labels + ['8'], fontsize=20)
# fig.legend(['Target', 'Error', 'Actual'], loc='upper left', 
#            bbox_to_anchor=(0.05, 0.95))
fig.canvas.layout.height = '1000px'
fig.canvas.layout.width = '1000px'
fig.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
%%time
bandwidth = 0.5
r_ = numpy.linspace(-6, 4, 1000, dtype='float32')
K_risk = []
for i in range(len(probs)):
    J = numpy.logical_and(probs[i] > 0.02, risk_flagged[i] > -10)
    r = risk_flagged[i][J]
    diff = numpy.expand_dims(r, axis=-1) - numpy.expand_dims(r_, axis=0)
    K_risk.append(gaussian(diff, bandwidth) * probs[i][J][:, numpy.newaxis])

CPU times: user 20 ms, sys: 28 ms, total: 48 ms
Wall time: 43.1 ms


In [33]:
pyplot.close(6)
fig, ax = pyplot.subplots(num=6)
risk_density = numpy.vstack([i.sum(0) for i in K_risk])
risk_density /= risk_density.sum(1, keepdims=True)
for i in range(len(probs)):
    ax.plot(r_, risk_density[i], color=colors[i])
# ax.set_ylim(0, 20)
fig.tight_layout()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [616]:
codes_life_threat = [
    'Died', '038', '198.3', '348.1', '348.31', '348.4', '348.5', '410', 
    '427.41', '427.5', '428', '430', '431', '434', '570', '571.1', '572.2', 
    '572.3', '572.4', '785.51', '785.52'
]
J_life_threat = [next(i for i, name in enumerate(names) if j in name) for j in codes_life_threat]
for i in J_life_threat:
    print(names[i])

Died
038: Septicemia
198.3: Brain And Spinal Cord
348.1: Anoxic Brain Damage
348.31: Metabolic Encephalopathy
348.4: Compression Of Brain
348.5: Cerebral Edema
410: Acute Myocardial Infarction
427.41: Ventricular Fibrillation
427.5: Cardiac Arrest
428: Heart Failure
430: Subarachnoid Hemorrhage
431: Intracerebral Hemorrhage
434: Occlusion Of Cerebral Arteries
570: Acute And Subacute Necrosis Of Liver
571.1: Acute Alcoholic Hepatitis
572.2: Hepatic Coma
572.3: Portal Hypertension
572.4: Hepatorenal Syndrome
785.51: Cardiogenic Shock
785.52: Septic Shock


In [638]:
Q_life_threat = numpy.vstack(Q)[:, J_life_threat]
q_life_threat = 1 - numpy.prod(1 - Q_life_threat, axis=1)
I_life_threat = numpy.argsort(q_life_threat)
y_life_threat = numpy.vstack(Y)[:, J_life_threat].any(axis=1)
y_life_threat.mean()

0.9372316734394013

In [630]:
y_life_threat = numpy.vstack(Y)[:, J_life_threat].any(axis=1)
y_life_threat.mean()

0.9372316734394013

In [639]:
y_life_threat[I_life_threat[-1000:]].mean()

0.963

In [621]:
len(I_life_threat)

43557

In [625]:
q_life_threat[I_life_threat[:10]]

array([0.24118276, 0.24127313, 0.25864473, 0.27165263, 0.27254666,
       0.27551108, 0.27778007, 0.27783205, 0.27897889, 0.28018849])

In [653]:
j = next(i for i in range(len(names)) if 'cardiogenic' in names[i].lower())
names[j]

'785.51: Cardiogenic Shock'

In [662]:
score = risk_stack[:, j]
I_life_threat = numpy.argsort(score)
y_triaged = numpy.vstack(Y)[:, j][I_life_threat]
print(
    (y_triaged == 1)[:1000].sum() / (y_triaged == 1).sum(), 
    (y_triaged == 1)[-1000:].sum() / (y_triaged == 1).sum()
)

0.0018999366687777073 0.13489550348321722


In [663]:
score = (risk_stack[:, J_life_threat] < -3).mean(axis=1)
I_life_threat = numpy.argsort(score)
y_triaged = y_life_threat[I_life_threat]
print(
    (y_triaged == 1)[:1000].sum() / (y_triaged == 1).sum(), 
    (y_triaged == 1)[-1000:].sum() / (y_triaged == 1).sum()
)

0.02226685936849325 0.02354065110354457


In [659]:
(y_life_threat == 1).mean()

0.9372316734394013