In [1]:
import tensorflow as tf
import numpy
from scipy import spatial, linalg
from sklearn import cluster, manifold
from matplotlib import pyplot
import os
import json, pickle
import pandas
from functools import partial, reduce
import importlib

import sys
sys.path.append('../libs')

import flacdb
import prepare_data
import initialize
import data_pipeline
import loss_metrics
import conv_model
import plot_batch
import load_diagnosis
import icd_util

%matplotlib widget

In [2]:
! nvidia-smi

Sun May 10 16:44:46 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX TIT...  On   | 00000000:05:00.0 Off |                  N/A |
| 22%   31C    P8    16W / 250W |      1MiB / 12212MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [3]:
# ! cp -r /scr1/mimic/initial_data/ /scr1/mimic/initial_data_1451803/

In [3]:
model_id = 1449529
checkpoint_index = -1
ckpts = os.listdir('/scr1/checkpoints')
ckpts = sorted(i for i in ckpts if 'index' in i and str(model_id) in i)
hypes_path = '../hypes/{}.json'.format(ckpts[0].split('.')[0][:-6])
weights_path = '/scr1/checkpoints/' + ckpts[checkpoint_index]
assert(os.path.isfile(hypes_path) and os.path.isfile(weights_path))
weights_path = weights_path.replace('.index', '')
print('found hypes', hypes_path, '\nfound weights', weights_path)
H0 = json.load(open(hypes_path))
H = initialize.load_hypes()
H = {**H, **H0}
part = 'validation'
load_path = '/scr1/mimic/initial_data_{}/'.format(model_id)
# load_path = '/scr1/mimic/initial_data/'
tensors, metadata, priors = initialize.run(H, parts=[part], load_path=load_path)
model = conv_model.build(H, priors, output_activations=True)
model.load_weights(weights_path)
dataset = data_pipeline.build(H, tensors[part], part)

found hypes ../hypes/1449529_20200428-231359.json 
found weights /scr1/checkpoints/1449529_20200428-231359_01152.ckpt


W0510 16:44:58.546582 139644670654208 training_utils.py:1444] Output dense missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to dense.
W0510 16:44:58.547749 139644670654208 training_utils.py:1444] Output tf_op_layer_add_20 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to tf_op_layer_add_20.
W0510 16:44:58.548813 139644670654208 training_utils.py:1444] Output tf_op_layer_add_19 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to tf_op_layer_add_19.
W0510 16:44:58.549370 139644670654208 training_utils.py:1444] Output tf_op_layer_add_18 missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to tf_op_layer_add_18.
W0510 16:44:58.549954 139644670654208 

In [4]:
x, y = next(iter(dataset))
x = {**x, 'mask': tf.cast(x['mask'], 'float')}
p, *activations = model.predict(x)
Z = activations[0]
Z = numpy.hstack([Z, numpy.ones([Z.shape[0], 1])])
A = numpy.vstack([i.numpy() for i in model.weights[-2:]]).T
assert(numpy.allclose(tf.sigmoid(A.dot(Z.T).T), p))
print(p.shape, Z.shape, A.shape)

(32, 96) (32, 257) (96, 257)


In [5]:
w = [i.numpy().flatten() for i in model.weights]
# [(i.min(), i.max(), i.mean()) for i in w]
any(i.nonzero()[0].shape[0] < i.shape[0] for i in w)

False

In [6]:
%%time

Y, Z = [], []
for x, y in dataset.take(200):
    x_ = {**x, 'mask': tf.cast(x['mask'], 'float')}
    p, z, *other = model.predict(x_)
    Z.append(z)
    Y.append(y)

Z = numpy.concatenate(Z)
Z = numpy.hstack([Z, numpy.ones([Z.shape[0], 1])])
Y = {k: numpy.concatenate([y[k] for y in Y]) for k in Y[0].keys()}
Z.shape

CPU times: user 1min 51s, sys: 16.6 s, total: 2min 8s
Wall time: 1min 25s


(5996, 257)

In [7]:
Y.keys()

dict_keys(['diagnosis', 'height', 'weight', 'age', 'rec_id', 'seg_id', 'is_good'])

In [18]:
pyplot.style.use('dark_background')
for a in activations[::-1]:
    b = a.reshape([a.shape[0], -1])
    print((b.sum(0) > 0).sum(), 'nonzero columns')
pyplot.matshow(b)

15991 nonzero columns
1147 nonzero columns
1116 nonzero columns
1307 nonzero columns
1740 nonzero columns
1999 nonzero columns
68 nonzero columns


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f00bcf29450>

In [19]:
%matplotlib widget
AZ = A.dot(Z.T)
W = spatial.distance_matrix(AZ, AZ)
W = numpy.exp(-(W/100)**2)
pyplot.matshow(W)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7f00bcf4ff50>

In [10]:
group_names = icd_util.load_group_strings()

def get_name(code):
    if code not in group_names:
        return code.replace('_', ' ').title()
    name = code + ': ' + group_names[code]
    name = name.replace('/', '_')
    for character in "',()[]":
        name = name.replace(character, '')
    return name

names = [get_name(i) for i in priors.index]

for i in range(len(priors)):
    print(i, get_name(priors.index[i]))

0 038: Septicemia
1 070: Viral Hepatitis
2 140-239: Neoplasms 
3 155: Malignant Neoplasm Of Liver And Intrahepatic Bile Ducts
4 157: Malignant Neoplasm Of Pancreas
5 162: Malignant Neoplasm Of Trachea Bronchus And Lung
6 179-189: Malignant Neoplasm Of Genitourinary Organs 
7 191: Malignant Neoplasm Of Brain
8 250: Diabetes Mellitus
9 250.4: Diabetes With Renal Manifestations
10 250.6: Diabetes With Neurological Manifestations
11 276.2: Acidosis
12 303: Alcohol Dependence Syndrome
13 305: Nondependent Abuse Of Drugs
14 305.0: Alcohol Abuse
15 317-319: Mental Retardation 
16 320-326: Inflammatory Diseases Of The Central Nervous System 
17 348.1: Anoxic Brain Damage
18 348.4: Compression Of Brain
19 348.5: Cerebral Edema
20 357: Inflammatory And Toxic Neuropathy
21 362.0: Diabetic Retinopathy
22 365: Glaucoma
23 396: Diseases Of Mitral And Aortic Valves
24 397.0: Diseases Of Tricuspid Valve
25 403: Hypertensive Chronic Kidney Disease
26 410: Acute Myocardial Infarction
27 410-414: Ischemi

In [11]:
organs = {
    'brain': [
        '191', '198.3', '348.1', '348.31', '348.4', '348.5', '349.82', 
        '430-438', '430', '431', '432.1', '434', '437.3', '850-854', '852'
    ],
    'liver': [
        '070', '155', '155.0', '570', '571', '571.1', '571.2', '572', '572.2',
        '572.3', '572.4', '574'
    ],
    'lung': ['162', '480-488', '507', '511', '518.0', '518.81', '997.3'],
    'kidney': ['403', '580-589', '584', '585'],
    'heart': [
        '396', '397.0', '410-414', '410', '410.7', '414.0', '416', '424.0', 
        '424.1', '425', '426', '427', '427.1', '427.31', '427.32', '427.41', 
        '427.5', '428', '428.0', '428.2', '428.3', '785.51', '997.1'
    ],
    'sepsis': ['038', '785.52', '995.9', '995.92'],
#     'diabetes': ['250', '250.4', '250.6', '362.0']
}

colors = {
    'brain': 'green',
    'liver': 'gold',
    'lung': 'blue', 
    'kidney': 'magenta',
    'heart': 'red',
    'sepsis': 'peru',
#     'diabetes': 'hotpink'
}

In [12]:
def lookup_color(code):
    for k in organs:
        if code in organs[k]:
            return colors[k]
    else:
        return 'white'

In [13]:
from scipy.cluster.hierarchy import dendrogram

def plot_dendrogram(clusterer):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = numpy.zeros(clustering.children_.shape[0], dtype='int')
    n = len(clustering.labels_)

    for i, merge in enumerate(clustering.children_):
        for j in merge:
            counts[i] += 1 if j < n else counts[j - n]
        
    linkage_matrix = numpy.column_stack([
        clusterer.children_,
        clusterer.distances_,
        counts
    ]).astype(float)

    # Plot the corresponding dendrogram
    return dendrogram(
        linkage_matrix, 
        labels=names, 
        orientation='left', 
        leaf_font_size=8,
        show_leaf_counts=True
    );

clustering = cluster.AgglomerativeClustering(
    distance_threshold=0, 
    n_clusters=None
)
clustering.fit(AZ)

result = plot_dendrogram(clustering)
fig = pyplot.figure(1)
fig.canvas.layout.width = '100%'
fig.canvas.layout.height = '1600px'
fig.tight_layout()
spines = fig.axes[0].spines
for i in spines:
    spines[i].set_visible(False)
pyplot.xticks([])



([], <a list of 0 Text xticklabel objects>)

In [14]:
AZ_ = AZ[result['leaves']]
W_ = spatial.distance_matrix(AZ_, AZ_)
W_ = numpy.exp(-(W_/100)**2)
fig = pyplot.figure(2)
fig.clear()
pyplot.matshow(W_, fignum=2)
fig.canvas
# pyplot.matshow(spatial.distance_matrix(AZ_, AZ_))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
AZ = A.dot(Z.T)
W = spatial.distance_matrix(AZ, AZ)
W = numpy.exp(-(W/100)**2)
fig = pyplot.figure(3)
fig.clear()
pyplot.matshow(W, fignum=3)
fig.canvas

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
[(i, j, names[j]) for i, j in enumerate(result['leaves'])]

[(0, 41, '427.5: Cardiac Arrest'),
 (1, 77, '785.51: Cardiogenic Shock'),
 (2, 17, '348.1: Anoxic Brain Damage'),
 (3, 40, '427.41: Ventricular Fibrillation'),
 (4, 54, '441: Aortic Aneurysm And Dissection'),
 (5, 34, '424.1: Aortic Valve Disorders'),
 (6, 53, '440: Atherosclerosis'),
 (7, 27, '410-414: Ischemic Heart Disease '),
 (8, 29, '414.0: Coronary Atherosclerosis'),
 (9, 26, '410: Acute Myocardial Infarction'),
 (10, 28, '410.7: Subendocardial Infarction'),
 (11, 88, '997.1: Cardiac Complications'),
 (12, 38, '427.1: Paroxysmal Ventricular Tachycardia'),
 (13, 33, '424.0: Mitral Valve Disorders'),
 (14, 31, '415-417: Diseases Of Pulmonary Circulation '),
 (15, 32, '416: Chronic Pulmonary Heart Disease'),
 (16, 45, '428.3: Diastolic Heart Failure'),
 (17, 44, '428.2: Systolic Heart Failure'),
 (18, 42, '428: Heart Failure'),
 (19, 43, '428.0: Congestive Heart Failure Unspecified'),
 (20, 35, '425: Cardiomyopathy'),
 (21, 36, '426: Conduction Disorders'),
 (22, 8, '250: Diabetes 

In [59]:
from adjustText import adjust_text

In [36]:
A.shape

(96, 257)

In [38]:
AZ_ = manifold.TSNE(perplexity=25, n_iter=1000).fit_transform(AZ)
# AZ_ = manifold.MDS(n_components=2).fit_transform(AZ)
fig = pyplot.figure(7)
fig.clear()
fig.canvas.layout.width = '1600px'
fig.canvas.layout.height = '1200px'
c = [lookup_color(i) for i in priors.index]
pyplot.scatter(AZ_[:, 0], AZ_[:, 1], c=c, s=64, linewidths=1, edgecolors='black')
ax = fig.gca()
for i in ax.spines.values():
    i.set_visible(False)
pyplot.xticks([])
pyplot.yticks([])
# ax.set_visible(False)
texts = [
    pyplot.text(AZ_[i, 0], AZ_[i, 1], ' ' * 4 + txt, ha='left', va='center', fontsize=6) 
    for i, txt in enumerate(names)
]
# adjust_text(texts)
# for i, txt in enumerate(names):
# #     ax.annotate(priors.index[i].split('_')[0], (AZ_[i, 0], AZ_[i, 1]), fontsize=8)
#     ax.annotate('    ' + txt, (AZ_[i, 0], AZ_[i, 1]), fontsize=6)

# fig.canvas.layout.width = '800px'
# fig.canvas.layout.height = '600px'
# fig.tight_layout()
fig.canvas

Canvas(layout=Layout(height='1200px', width='1600px'), toolbar=Toolbar(toolitems=[('Home', 'Reset original vie…

In [32]:
fig.savefig('embedding.svg', format='svg', transparent=True)

In [None]:
n = len(priors)
weights = [(i, j, W[i, j]) for i in range(n) for j in range(i+1, n)]
to_name = lambda i: get_name(priors.index[i])
weights = [(to_name(i), to_name(j), w) for i, j, w in weights]
weights = sorted(weights, key=lambda i: i[-1])
weights[:50]

In [None]:
sorted(weights, key=lambda i: -i[-1])

In [None]:
from sklearn import cluster
k = 2
clusterer = cluster.SpectralClustering(n_clusters=k, assign_labels="discretize")
clustering = clusterer.fit(A.dot(Z.T))
I = clustering.labels_
W = A.dot(Z.T)
W_ = numpy.vstack([W[I==i] for i in range(k)])
pyplot.style.use('dark_background')
pyplot.matshow(spatial.distance_matrix(W, W))

In [None]:
D = spatial.distance_matrix(A.dot(Z.T), A.dot(Z.T))

In [None]:
D[5, 26]

In [None]:
W = A.dot(Z.T)
W = numpy.vstack([W[I==i] for i in range(k)])
pyplot.style.use('dark_background')
pyplot.matshow(spatial.distance_matrix(W, W))

In [None]:
numpy.allclose(tf.sigmoid(A.dot(Z.T).T), diagnosis)

In [None]:
((tf.sigmoid(A.dot(Z.T).T) - diagnosis).numpy()).min()

In [None]:
A.dot(Z.T).shape

In [None]:
diagnosis

In [None]:
Z = 

In [None]:
model.

In [None]:
w = model.weights[-2].numpy()
b = model.weights[-1].numpy()
A = numpy.vstack([w, b]).T
W = numpy.exp(-spatial.distance_matrix(A, A)**2/2)
D = numpy.diag(W.sum(0))
pyplot.style.use('dark_background')
pyplot.matshow(spatial.distance_matrix(A, A))

In [None]:
s, v = linalg.eigh(D - W)
s[:10]

In [None]:
from sklearn import cluster

In [None]:
k = 5
clustering = cluster.SpectralClustering(n_clusters=k, assign_labels="discretize").fit(A)
I = clustering.labels_

In [None]:
A_ = numpy.vstack([A[I==i] for i in [4, 0, 2, 3, 1]])
pyplot.style.use('dark_background')
pyplot.matshow(-spatial.distance_matrix(A_, A_))

In [None]:
group_names = icd_util.load_group_strings()

def get_name(code):
    if code not in group_names:
        return code
    name = code + ': ' + group_names[code]
    name = name.replace('/', '_')
    for character in "',()[]":
        name = name.replace(character, '')
    return name

In [None]:
for i in [4, 0, 2, 3, 1]:
    for j in priors.index[I==i]:
        print(get_name(j))
    print('\n')

In [None]:
k = 5
clustering2 = cluster.SpectralClustering(n_clusters=k).fit(A)
I_ = clustering2.labels_
A_ = numpy.vstack([A[I_==i] for i in range(k)])
pyplot.style.use('dark_background')
pyplot.matshow(-spatial.distance_matrix(A_, A_))
for i in range(k):
    for j in priors.index[I_==i]:
        print(get_name(j))
    print('\n')

In [None]:
u.shape

In [None]:
%matplotlib widget

In [None]:
pyplot.style.use('dark_background')
pyplot.matshow(W)

In [None]:
A.T.shape

In [None]:
b.shape

In [None]:
W.shape

In [None]:
numpy.vstack([W, b]).shape

In [None]:
numpy.concate

In [None]:
from scipy.cluster.hierarchy import dendrogram
from sklearn.datasets import load_iris
from sklearn.cluster import AgglomerativeClustering


def plot_dendrogram(clusterer, **kwargs):
    # Create linkage matrix and then plot the dendrogram

    # create the counts of samples under each node
    counts = numpy.zeros(clusterer.children_.shape[0])
    n_samples = len(clusterer.labels_)
    for i, merge in enumerate(clusterer.children_):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count
        
    linkage_matrix = numpy.column_stack([
        clusterer.children_,
        clusterer.distances_,
        counts
    ]).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)

# setting distance_threshold=0 ensures we compute the full tree.
clusterer = AgglomerativeClustering(distance_threshold=0, n_clusters=None)

clusterer.fit(A)
pyplot.title('Hierarchical Clustering Dendrogram')
# plot the top three levels of the dendrogram
plot_dendrogram(clusterer, truncate_mode='level', p=3)
pyplot.xlabel("Number of points in node (or index of point if no parenthesis).")
pyplot.show()

In [None]:
clusterer.labels_

In [None]:
clusterer.n_connected_components_

In [None]:
clusterer.distances_.shape

In [None]:
clusterer.children_.shape