In [None]:
import os
import sys
import math
import argparse
import numpy as np
from collections import Counter

# Keras imports
from tensorflow.keras.models import Model
from keras import backend as K
from sklearn.manifold import TSNE

# ml4h Imports
from ml4h.arguments import parse_args
from ml4h.tensor_generators import test_train_valid_tensor_generators
from ml4h.recipes import train_shallow_model, train_multimodal_multitask, test_multimodal_multitask
from ml4h.models import make_multimodal_multitask_model, train_model_from_generators, make_shallow_model, make_hidden_layer_model

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Emulate a command line so ml4h's `parse_args` can be reused from the notebook;
# the resulting `args` carries the input/output TensorMaps and model settings.
OUTPUT_TENSORS = [
    'allergic_rhinitis', 'anxiety', 'asthma', 'atrial_fibrillation_or_flutter',
    'back_pain', 'breast_cancer', 'cardiac_surgery', 'cervical_cancer',
    'colorectal_cancer', 'coronary_artery_disease_hard',
    'coronary_artery_disease_intermediate', 'coronary_artery_disease_soft',
    'death', 'diabetes_all', 'diabetes_type_1', 'diabetes_type_2',
    'enlarged_prostate', 'heart_failure', 'hypertension', 'lung_cancer',
    'migraine', 'myocardial_infarction', 'osteoporosis', 'skin_cancer', 'stroke',
]
sys.argv = [
    'train_shallow',
    '--tensors', '/mnt/disks/survey-tensors2/2019-03-25/',
    '--input_tensors', 'categorical-phenotypes-965',
    '--output_tensors', *OUTPUT_TENSORS,
    '--model_file', '/mnt/ml4cvd/projects/jamesp/data/models/shallow_cat965.hd5',
    '--id', 'shallow_cat965',
    '--epochs', '1',
    '--test_steps', '2',
    '--batch_size', '128',
    '--training_steps', '2',
    '--validation_steps', '1',
]

args = parse_args()
# NOTE(review): presumably builds the model, or loads it from --model_file when present — confirm in ml4h.models.
m = make_shallow_model(args.tensor_maps_in, args.tensor_maps_out, args.learning_rate, args.model_file, args.model_layers)
# Invert the first input TensorMap's channel_map so channel index (value) -> channel name (key).
rev_cm = {channel_index: channel_name for channel_name, channel_index in args.tensor_maps_in[0].channel_map.items()}

In [None]:
# One bar chart per output TensorMap. Each chart shows the `top_k` most negative
# (red) and most positive (green) weights feeding that output layer, and each bar
# is labelled with its input channel name from `rev_cm`. Subplots are filled
# column by column.
top_k = 3
text_limit = 64  # channel labels longer than this many characters are truncated
total_plots = len(args.tensor_maps_out)
rows = max(2, int(math.sqrt(total_plots)))
# Round up so the grid has a slot for every output. Floor division dropped the
# remainder (e.g. 26 outputs -> 5x5 grid -> the last output was never plotted).
cols = max(2, math.ceil(total_plots / rows))
fig, axes = plt.subplots(rows, cols, figsize=(48, 48))

for plot_idx, tm in enumerate(args.tensor_maps_out):
    ax = axes[plot_idx % rows, plot_idx // rows]  # column-major placement
    kernel = m.get_layer(tm.output_name()).get_weights()[0]
    # NOTE(review): column 1 is assumed to be the "positive" unit of a 2-unit
    # output layer. Confirm this holds for every output TensorMap.
    weights = kernel[:, 1]
    order = np.argsort(weights)  # ascending, computed once
    lowest, highest = order[:top_k], order[-top_k:]
    selected = np.concatenate([lowest, highest])
    bar_colors = ['red'] * len(lowest) + ['green'] * len(highest)
    label_y = np.max(weights)  # labels start at the top of the tallest bar
    for bar_x, channel_idx in enumerate(selected):
        ax.text(bar_x, label_y, rev_cm[channel_idx][:text_limit], rotation=90)
    ax.bar(np.arange(len(selected)), weights[selected], color=bar_colors)
    ax.set_title(tm.name)
    ax.set_ylabel('weight')

# Hide grid cells that did not receive an output map.
for plot_idx in range(total_plots, rows * cols):
    axes[plot_idx % rows, plot_idx // rows].axis('off')
plt.show()

In [None]:
# For each output TensorMap, list every input channel from the largest to the
# smallest weight in column 1 of its output layer's kernel. Each name is prefixed
# with '+' when the weight is > 0 and '-' otherwise.
for tensor_map in args.tensor_maps_out:
    layer_name = tensor_map.output_name()
    print('\n\n~~~~~~~~~~~~~~~ Looking at TM: ', layer_name, ' ~~~~~~~~~~~~~~~~~~')
    kernel = m.get_layer(layer_name).get_weights()[0]
    descending = np.argsort(kernel[:, 1])[::-1]
    for channel_idx in descending:
        prefix = '+' if kernel[channel_idx, 1] > 0 else '-'
        print(prefix + rev_cm[channel_idx])