In [None]:
import os
import sys
import h5py
import argparse
import numpy as np
from collections import Counter

# 3rd party imports
from sklearn import manifold
from keras.models import Model
from keras import backend as K
from IPython.display import Image

# ML4CVD Imports
from ml4cvd.arguments import parse_args
from ml4cvd.tensor_generators import big_batch_from_minibatch_generator, test_train_valid_tensor_generators
from ml4cvd.recipes import train_shallow_model, train_multimodal_multitask, test_multimodal_multitask
from ml4cvd.models import make_multimodal_to_multilabel_model, train_model_from_generators, make_hidden_layer_model

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.ticker import NullFormatter

In [None]:
# Emulate a command-line invocation so parse_args() can be used from the notebook
# (presumably parse_args reads sys.argv -- confirm in ml4cvd.arguments).
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/pix-size-tensors/',
    '--input_tensors', 'mri_systole_diastole_weighted',
    '--output_tensors',
    'mri_systole_diastole_segmented_weighted', 'lv_mass',
    'end_systole_volume', 'end_diastole_volume', 'ejection_fraction',
    '--batch_size', '12',
    '--pool_z', '1',
    '--epochs', '2',
    '--learning_rate', '0.001',
    '--u_connect',
    '--training_steps', '128',
    '--validation_steps', '10',
    '--test_steps', '24',
    '--model_file', '/mnt/ml4cvd/projects/jamesp/data/models/mri_systole_diastole_une_mass.hd5',
    '--id', 'mri_systole_diastole_unet_lv_mass_diseases',
]
args = parse_args()

# Train / validation / test generators. All arguments are positional, so their order
# must match the signature of test_train_valid_tensor_generators.
# NOTE(review): the meaning of the trailing `True` is defined by that function and is
# not visible from this notebook -- confirm.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    args.tensor_maps_in,
    args.tensor_maps_out,
    args.tensors,
    args.batch_size,
    args.valid_ratio,
    args.test_ratio,
    args.icd_csv,
    args.balance_by_icds,
    True,
)

# Build the model from the parsed arguments (again purely positional).
model = make_multimodal_to_multilabel_model(
    args.model_file, args.model_layers, args.model_freeze,
    args.tensor_maps_in, args.tensor_maps_out,
    args.activation, args.dense_layers, args.dropout, args.mlp_concat,
    args.conv_layers, args.max_pools, args.res_layers,
    args.dense_blocks, args.block_size, args.conv_bn,
    args.conv_x, args.conv_y, args.conv_z,
    args.conv_dropout, args.conv_width, args.u_connect,
    args.pool_x, args.pool_y, args.pool_z,
    args.padding, args.learning_rate,
)

# Collect args.test_steps mini-batches from the test generator; the call returns the
# inputs, the labels and the paths that the later cells use.
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(
    args.tensor_maps_in, args.tensor_maps_out, generate_test, args.test_steps,
)

In [None]:
# Name of the layer whose output is used as the embedding in the cells below.
layer_name = 'embed'

# Print the shape of each weight array held by that layer.
embed_layer = model.get_layer(layer_name)
embed_weights = embed_layer.get_weights()
for weight_array in embed_weights:
    print(weight_array.shape)

# Sub-model built by make_hidden_layer_model from the named input and the embedding layer.
embed_model = make_hidden_layer_model(model, 'input_mri_systole_diastole', layer_name)
embed_model.summary()

# List the keys of the test batch that is fed to both models.
input_names = list(test_data.keys())
print(input_names)

# Embedding-layer output and full-model output for the same test batch.
x_embed = embed_model.predict(test_data, batch_size=args.batch_size)
predictions = model.predict(test_data, batch_size=args.batch_size)

In [None]:
# Labels looked up in each test sample's HD5 file.
# A continuous label may list several alternative HD5 keys separated by '|'.
categorical_labels = ['Genetic-sex_Female_0_0', 'hypertension', 'coronary_artery_disease', 'Handedness-chiralitylaterality_Righthanded_0_0']
continuous_labels = ['22200_Year-of-birth_0_0|34_Year-of-birth_0_0', '21001_Body-mass-index-BMI_0_0', '1070_Time-spent-watching-television-TV_0_0', '102_Pulse-rate-automated-reading_0_0', '1488_Tea-intake_0_0', '21002_Weight_0_0']

# One value per test path for every label; entries stay 0 when nothing is found in the file.
label_dict = {k: np.zeros((len(test_paths))) for k in categorical_labels + continuous_labels}

for i, tp in enumerate(test_paths):
    # Open each file in a `with` block so its handle is closed before the next one is
    # opened; previously every file stayed open for the lifetime of the kernel.
    with h5py.File(tp, 'r') as hd5:
        for k in categorical_labels:
            # A categorical label is set when the key exists under 'categorical', or when
            # a top-level dataset of that name has 1 as its first element.
            if k in hd5['categorical'] or (k in hd5 and hd5[k][0] == 1):
                label_dict[k][i] = 1
        for mk in continuous_labels:
            # Try every alternative key; when several are present the last one wins.
            for k in mk.split('|'):
                if k in hd5['continuous']:
                    label_dict[mk][i] = hd5['continuous'][k][0]

print(list(label_dict.keys()))
print(label_dict['22200_Year-of-birth_0_0|34_Year-of-birth_0_0'])

In [None]:
# t-SNE projections of the embedding: one row per label, one column per perplexity.
n_components = 2
max_rows = 10
perplexities = [8, 18, 35]

# Size the grid from what will actually be drawn. squeeze=False keeps `subplots` 2-D
# even when there is a single row or a single perplexity, so subplots[j, i] always works.
n_rows = min(max_rows, len(label_dict))
(fig, subplots) = plt.subplots(n_rows, len(perplexities), figsize=(16, n_rows*4), squeeze=False)

# Fit one projection per perplexity; each is reused for every label row.
p2y = {}
for perplexity in perplexities:
    tsne = manifold.TSNE(n_components=n_components, init='random', random_state=0, perplexity=perplexity)
    p2y[perplexity] = tsne.fit_transform(x_embed)

# Only the first n_rows labels (in label_dict order) are plotted.
for j, k in enumerate(list(label_dict)[:n_rows]):
    if k in categorical_labels:
        # Boolean masks splitting the samples into label present / absent.
        red = label_dict[k] == 1.0
        green = label_dict[k] != 1.0
    elif k in continuous_labels:
        colors = label_dict[k]
    print('process key:', k)
    for i, perplexity in enumerate(perplexities):
        ax = subplots[j, i]
        ax.set_title(k+", Perplexity=%d" % perplexity)
        projection = p2y[perplexity]
        if k in categorical_labels:
            # Draw order (green, then red) must match the legend entry order below.
            ax.scatter(projection[green, 0], projection[green, 1], c="g")
            ax.scatter(projection[red, 0], projection[red, 1], c="r")
            ax.legend(['no_'+k, k], loc='lower left')
        elif k in continuous_labels:
            points = ax.scatter(projection[:, 0], projection[:, 1], c=colors, cmap='jet')
            # One colorbar per row, attached to the last column.
            if i == len(perplexities)-1:
                fig.colorbar(points, ax=ax)

        # Hide the tick labels on both axes.
        ax.xaxis.set_major_formatter(NullFormatter())
        ax.yaxis.set_major_formatter(NullFormatter())
        ax.axis('tight')
plt.show()
