In [None]:
import os
import sys
import argparse
import numpy as np
from collections import Counter

# Keras imports
from keras.models import Model
from keras import backend as K
from sklearn.manifold import TSNE

# IPython imports
from IPython.display import Image

# ML4CVD Imports
sys.path.append("../ml4cvd")
from arguments import parse_args
from tensor_generators import test_train_valid_tensor_generators
from recipes import train_shallow_model, train_multimodal_multitask, test_multimodal_multitask
from models import make_multimodal_to_multilabel_model, make_shallow_model, train_model_from_generators, make_hidden_layer_model

In [None]:
# Shallow model over the categorical phenotype input, with one output per outcome
# listed below. parse_args() reads sys.argv, so the command line is injected here.
shallow_output_tensors = [
    'allergic_rhinitis', 'anxiety', 'asthma', 'atrial_fibrillation_or_flutter', 'back_pain',
    'breast_cancer', 'cardiac_surgery', 'cervical_cancer', 'colorectal_cancer', 'coronary_artery_disease_hard',
    'coronary_artery_disease_intermediate', 'coronary_artery_disease_soft', 'death', 'diabetes_all',
    'diabetes_type_1', 'diabetes_type_2', 'enlarged_prostate', 'heart_failure', 'hypertension', 'lung_cancer',
    'migraine', 'myocardial_infarction', 'osteoporosis', 'skin_cancer', 'stroke',
]
sys.argv = ['train_shallow',
            '--tensors', '/mnt/disks/survey-tensors2/2019-03-25/',
            '--input_tensors', 'categorical-phenotypes-965',
            '--output_tensors', *shallow_output_tensors,
            '--model_file', '../trained_models/shallow_cat965/shallow_cat965.hd5',
            '--id', 'shallow_cat965',
            '--epochs', '1',
            '--test_steps', '2',
            '--batch_size', '128',
            '--training_steps', '2',
            '--validation_steps', '1']

args = parse_args()
# NOTE(review): --model_file is passed through, presumably so saved weights are
# loaded from that path — confirm against make_shallow_model.
m = make_shallow_model(args.tensor_maps_in, args.tensor_maps_out, args.learning_rate, args.model_file, args.model_layers)
# Reverse channel map (channel index -> channel name), used to label input weights.
rev_cm = {index: name for name, index in args.tensor_maps_in[0].channel_map.items()}

In [None]:
# For each output task, list input channels from the largest to the smallest
# weight in column 1 of that output layer's first weight array, prefixed by sign.
# NOTE(review): assumes get_weights()[0] is a kernel whose rows line up with the
# input channel indices in rev_cm — confirm.
for tm in args.tensor_maps_out:
    out_name = tm.output_name()
    print('\n\n~~~~~~~~~~~~~~~ Looking at TM: ', out_name, ' ~~~~~~~~~~~~~~~~~~')
    d1 = m.get_layer(out_name)
    w1 = d1.get_weights()
    column_weights = w1[0][:, 1]
    ranked_channels = np.argsort(column_weights)[::-1]
    for channel_idx in ranked_channels:
        prefix = '+' if column_weights[channel_idx] > 0 else '-'
        print(prefix + rev_cm[channel_idx])

In [None]:
# Multitask model over the same input and outputs, built by
# make_multimodal_to_multilabel_model. sys.argv is assembled one flag/value
# segment at a time so each option sits next to its value.
sys.argv = (
    ['train']
    + ['--tensors', '/mnt/disks/survey-tensors2/2019-03-25/']
    + ['--input_tensors', 'categorical-phenotypes-965']
    + ['--output_tensors',
       'allergic_rhinitis', 'anxiety', 'asthma', 'atrial_fibrillation_or_flutter', 'back_pain',
       'breast_cancer', 'cardiac_surgery', 'cervical_cancer', 'colorectal_cancer', 'coronary_artery_disease_hard',
       'coronary_artery_disease_intermediate', 'coronary_artery_disease_soft', 'death', 'diabetes_all',
       'diabetes_type_1', 'diabetes_type_2', 'enlarged_prostate', 'heart_failure', 'hypertension', 'lung_cancer',
       'migraine', 'myocardial_infarction', 'osteoporosis', 'skin_cancer', 'stroke']
    + ['--id', 'mlp_cat965']
    + ['--epochs', '1']
    + ['--training_steps', '100']
    + ['--validation_steps', '2']
    + ['--batch_size', '512']
    + ['--model_file', '../trained_models/mlp_cat965_deep_new/mlp_cat965_deep_new.hd5']
)
args = parse_args()

# Train / validation / test generators over the tensor directory.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    args.tensor_maps_in,
    args.tensor_maps_out,
    args.tensors,
    args.batch_size,
    args.valid_ratio,
    args.test_ratio,
    args.icd_csv,
    args.balance_by_icds,
)

# All arguments are positional: their order must match the function signature exactly.
model = make_multimodal_to_multilabel_model(
    args.model_file, args.model_layers, args.model_freeze,
    args.tensor_maps_in, args.tensor_maps_out, args.activation,
    args.dense_layers, args.dropout, args.mlp_concat,
    args.conv_layers, args.max_pools, args.res_layers, args.dense_blocks, args.block_size,
    args.conv_bn, args.conv_x, args.conv_y, args.conv_z, args.conv_dropout, args.conv_width,
    args.u_connect, args.pool_x, args.pool_y, args.pool_z, args.padding,
    args.learning_rate,
)

In [None]:
# Inspect one output layer, build a sub-model ending at that layer, and run both
# models on a single batch drawn from the test generator.
layer_name = 'output_coronary_artery_disease_soft_categorical_index'
d1 = model.get_layer(layer_name)
w1 = d1.get_weights()
rev_cm = {index: name for name, index in args.tensor_maps_in[0].channel_map.items()}
for weight_array in w1:
    print(weight_array.shape)

# NOTE(review): layer_name is an "output_" layer, so x_embed is presumably that
# layer's output rather than a hidden representation — confirm this is intended.
embed_model = make_hidden_layer_model(model, layer_name)
embed_model.summary()

test_batch = next(generate_test)
test_data, test_labels = test_batch
print([key for key in test_data.keys()])
x_embed = embed_model.predict(test_data)
y_pred = model.predict(test_data)

In [None]:
# Map each task name to its test labels for the batch drawn above. The labels
# come straight from test_labels, so there is no need to iterate over y_pred.
label_dict = {tm.name: test_labels[tm.output_name()] for tm in args.tensor_maps_out}
print(x_embed.shape)
# Keras predict() returns a bare array (not a list) for a single-output model;
# wrap it so zip pairs whole per-task prediction arrays with tensor maps
# instead of iterating over batch rows.
y_preds = y_pred if isinstance(y_pred, list) else [y_pred]
for y, tm in zip(y_preds, args.tensor_maps_out):
    print(tm.name, y.shape)

In [None]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

from matplotlib.ticker import NullFormatter
from sklearn import manifold, datasets
from time import time
# t-SNE projections of x_embed coloured by each task's test labels: one row per
# task (capped at max_rows), one column per perplexity.
n_components = 2
max_rows = 28
perplexities = [8, 18, 35]
n_rows = min(max_rows, len(label_dict))
assert n_rows > 0, 'label_dict is empty: nothing to plot'
# squeeze=False keeps `subplots` 2-D even when only one row is plotted, so the
# subplots[row, col] indexing below always works. Height scales with rows drawn.
(fig, subplots) = plt.subplots(n_rows, len(perplexities), figsize=(16, n_rows * 5), squeeze=False)

# Fit t-SNE once per perplexity; every task row reuses the same projections.
p2y = {}
for perplexity in perplexities:
    tsne = manifold.TSNE(n_components=n_components, init='random', random_state=0, perplexity=perplexity)
    p2y[perplexity] = tsne.fit_transform(x_embed)

for j, k in enumerate(list(label_dict)[:n_rows]):
    labels = label_dict[k]
    red = labels[:, 1] == 1.0    # drawn and labelled as k
    green = labels[:, 0] == 1.0  # drawn and labelled as 'no_' + k
    print('process key:', k)
    for i, perplexity in enumerate(perplexities):
        ax = subplots[j, i]
        points = p2y[perplexity]
        ax.set_title(k + ", Perplexity=%d" % perplexity)
        ax.scatter(points[green, 0], points[green, 1], c="g")
        ax.scatter(points[red, 0], points[red, 1], c="r")
        ax.legend(['no_' + k, k], loc='lower left')
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.axis('tight')
plt.show()