In [None]:
import os
import sys
import h5py
import argparse
import numpy as np
from collections import Counter

# 3rd party imports
from sklearn import manifold
from keras.models import Model
from keras import backend as K
from IPython.display import Image

# ML4CVD Imports
from ml4cvd.arguments import parse_args
from ml4cvd.tensor_generators import test_train_valid_tensor_generators
from ml4cvd.recipes import train_shallow_model, train_multimodal_multitask, test_multimodal_multitask
from ml4cvd.models import make_multimodal_multitask_model, train_model_from_generators, make_hidden_layer_model

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter

In [None]:
# Emulate a command-line invocation so ml4cvd's argparse-based parse_args()
# can be reused from inside the notebook. Each flag/value pair is on its own line.
# NOTE(review): '--id' says 'cat32_mlp' while the tensors and '--model_file'
# refer to 94 categorical phenotypes — confirm the id is intentional.
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/survey-tensors2/2019-03-25/',
    '--input_tensors', 'categorical-phenotypes-94',
    '--output_tensors', 'categorical-phenotypes-94',
    '--id', 'cat32_mlp',
    '--epochs', '1',
    '--training_steps', '100',
    '--validation_steps', '2',
    '--test_steps', '1',
    '--batch_size', '1024',
    '--model_file', '/mnt/ml4cvd/projects/jamesp/data/models/cat94_autoencoder.hd5',
]
args = parse_args()

# Train / validation / test generators over the same tensor maps used by the model.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    args.tensor_maps_in,
    args.tensor_maps_out,
    args.tensors,
    args.batch_size,
    args.valid_ratio,
    args.test_ratio,
    args.test_modulo,
    args.icd_csv,
    args.balance_by_icds,
)

# Build the model from the full parsed configuration. Presumably this also loads
# weights from --model_file — confirm in ml4cvd.models.
model = make_multimodal_multitask_model(**vars(args))


In [None]:
# Inspect a hidden layer, then build a sub-model that outputs that layer's
# activations so each test sample can be mapped to an embedding vector.
layer_name = 'dense_1'

# Inverse of the first input TensorMap's channel_map (value -> key);
# presumably channel index -> channel name — confirm against ml4cvd's TensorMap.
rev_cm = {idx: name for name, idx in args.tensor_maps_in[0].channel_map.items()}

d1 = model.get_layer(layer_name)
w1 = d1.get_weights()
for weight_array in w1:
    print(weight_array.shape)

embed_model = make_hidden_layer_model(model, args.tensor_maps_in, layer_name)
embed_model.summary()

# One batch from the test generator; test_paths are the hd5 files the batch came from.
test_data, test_labels, test_paths = next(generate_test)
print(list(test_data.keys()))

x_embed = embed_model.predict(test_data)  # activations of `layer_name` per test sample
y_pred = model.predict(test_data)         # full-model predictions for the same batch

In [None]:
# Labels used to colour the t-SNE plots, read from each test sample's hd5 file.
# Categorical labels become 0/1 flags. Continuous labels are numeric; a key of
# the form 'a|b' lists alternative hd5 keys that are tried in turn.
# Duplicate entries removed; first-occurrence order is unchanged.
categorical_labels = ['Genetic-sex_Female_0_0', 'hypertension', 'coronary_artery_disease',
                      'Handedness-chiralitylaterality_Righthanded_0_0', 'atrial_fibrillation_or_flutter',
                      'coronary_artery_disease_soft', 'death',
                      'diabetes_type_2', 'heart_failure', 'hypercholesterolemia',
                      'ischemic_stroke', 'myocardial_infarction', 'stroke']
continuous_labels = ['22200_Year-of-birth_0_0|34_Year-of-birth_0_0', '21001_Body-mass-index-BMI_0_0',
                     '1070_Time-spent-watching-television-TV_0_0', '102_Pulse-rate-automated-reading_0_0', '1488_Tea-intake_0_0', '21002_Weight_0_0']
label_dict = {k: np.zeros(len(test_paths)) for k in categorical_labels + continuous_labels}

# Indices of samples with a missing or negative (invalid) continuous value.
# A set, so a sample invalid for several labels is counted and deleted once.
invalid_rows = set()
for i, tp in enumerate(test_paths):
    # Context manager so every hd5 file handle is closed, even on error.
    with h5py.File(tp, 'r') as hd5:
        for k in categorical_labels:
            if k in hd5['categorical']:
                label_dict[k][i] = 1
            elif k in hd5 and hd5[k][0] == 1:
                label_dict[k][i] = 1
        for mk in continuous_labels:
            found = False
            for k in mk.split('|'):
                if k in hd5['continuous']:
                    found = True
                    value = hd5['continuous'][k][0]
                    if value < 0:
                        invalid_rows.add(i)
                    else:
                        label_dict[mk][i] = value
            # No alternative key present: the value is missing. Drop the sample
            # rather than keep a silent 0 (single keys were already handled this way).
            if not found:
                invalid_rows.add(i)

to_delete = sorted(invalid_rows)
print(list(label_dict.keys()))
print(x_embed.shape)
print('Will delete:', len(to_delete), 'because they are missing or invalid continuous values.')
for k in label_dict:
    label_dict[k] = np.delete(label_dict[k], to_delete)
# NOTE: x_embed is overwritten here. Re-running this cell without first re-running
# the embedding cell would delete rows from the already-filtered array.
x_embed = np.delete(x_embed, to_delete, axis=0)
print(x_embed.shape)

In [None]:
# t-SNE projections of the embedding at several perplexities.
# One row per label (at most max_rows), one column per perplexity.
n_components = 2
max_rows = 60
perplexities = [8, 18, 35]
label_keys = list(label_dict)[:max_rows]
rows = len(label_keys)
# squeeze=False keeps `subplots` 2-D even when only one row or column is drawn.
# The column count follows `perplexities` instead of a hard-coded 3.
(fig, subplots) = plt.subplots(rows, len(perplexities), figsize=(20, rows*4), squeeze=False)

# Fit one t-SNE per perplexity; fixed random_state for reproducible layouts.
p2y = {}
for perplexity in perplexities:
    tsne = manifold.TSNE(n_components=n_components, init='random', random_state=0, perplexity=perplexity)
    p2y[perplexity] = tsne.fit_transform(x_embed)

for j, k in enumerate(label_keys):
    values = label_dict[k]
    print('process key:', k)
    for i, perplexity in enumerate(perplexities):
        ax = subplots[j, i]
        coords = p2y[perplexity]
        ax.set_title(k+", Perplexity=%d" % perplexity)
        if k in categorical_labels:
            red = values == 1.0      # samples with the label
            green = values != 1.0    # samples without it
            ax.scatter(coords[green, 0], coords[green, 1], c="g", alpha=0.7)
            ax.scatter(coords[red, 0], coords[red, 1], c="r", alpha=0.7)
            ax.legend(['no_'+k, k], loc='lower left')
        elif k in continuous_labels:
            points = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap='jet', alpha=0.7)
            # A single colorbar per row, attached to the rightmost panel.
            if i == len(perplexities)-1:
                fig.colorbar(points, ax=ax)
        # t-SNE coordinates have no meaningful scale, so hide tick labels.
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.axis('tight')
plt.show()