In [None]:
from ipywidgets import interact, FloatSlider, interactive_output, BoundedIntText, interactive, GridspecLayout, Dropdown, Layout
import numpy as np
import seaborn as sns

import sys
import argparse

# ml4h Imports
from ml4h.arguments import parse_args
from sklearn.decomposition import PCA

from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
from ml4h.models import make_variational_multimodal_multitask_model, train_model_from_generators
from ml4h.tensor_generators import test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.recipes import train_multimodal_multitask
import matplotlib.pyplot as plt
from itertools import product
# Render matplotlib figures inline in the notebook's cell output.
%matplotlib inline
# Configure bokeh so that later `show(...)` calls display inside the notebook.
output_notebook()

In [None]:
# `IPython.core.display` is an internal module and importing `display` from it
# is deprecated; `IPython.display` is the public, stable location.
from IPython.display import display, HTML

# Inject a CSS rule that widens the notebook container so the widgets below
# have more horizontal room.
display(HTML("<style>.container { width:90% !important; }</style>"))  # WIDER DISPLAY FOR WIDGETS

In [None]:
# Fake a command line for `parse_args()`: each inner tuple is one CLI token
# group (program name, a flag with its value, or a bare switch). The groups
# are flattened, in order, into `sys.argv`.
# NOTE(review): the tensor/output/model paths are absolute and machine-specific.
cli_token_groups = (
    ('train',),
    ('--tensors', '/mnt/disks/ecg-rest-31k/2019-06-10'),
    ('--input_tensors', 'ecg_rest_median'),
    ('--output_tensors', 'ecg_rest_median'),
    ('--training_steps', '150'),
    ('--validation_steps', '30'),
    ('--epochs', '200'),
    ('--batch_size', '32'),
    ('--output_folder', '/home/ndiamant/train_runs/'),
    ('--dense_layers', '32'),
    ('--id', 'vae_ecg_short_for_explore'),
    ('--variational',),
    ('--model_file', '/home/ndiamant/train_runs/vae_ecg_short_for_explore/vae_ecg_short_for_explore.hd5'),
)
sys.argv = [token for group in cli_token_groups for token in group]
args = parse_args()

In [None]:
# Build the model from the parsed arguments. The three returned objects are
# used below as: `m` (input -> reconstruction), `enc` (input -> latent values)
# and `dec` (latent values -> reconstruction).
m, enc, dec = make_variational_multimodal_multitask_model(**vars(args))

In [None]:
# Create the train / validation / test tensor generators from the same parsed
# arguments; only `generate_test` is consumed in the cells visible here.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**vars(args))

In [None]:
# Collect 64 minibatches from the test generator into a single large batch.
x_test_in, x_test_out, _ = big_batch_from_minibatch_generator(
    generate_test,
    minibatches=64,
)
# The plots below only look at the first input of the batch.
input_key = list(x_test_in)[0]
# Full-model reconstructions, then the latent encodings, of the same batch.
pred = m.predict(x_test_in)
encoded = enc.predict(x_test_in)

In [None]:
# `i` indexes a latent dimension in the first figure and a test sample in the
# second, so cap the loop at whichever of the two runs out first (at most 6).
# A bare `range(6)` raises IndexError for a latent space or batch smaller than 6.
n_examples = min(6, encoded.shape[-1], len(pred))
for i in range(n_examples):
    # Distribution of latent dimension `i` over every encoded test sample.
    # NOTE(review): `sns.distplot` is deprecated in newer seaborn releases;
    # confirm the pinned seaborn version before migrating to `histplot`.
    plt.figure(figsize=(10, 5))
    plt.title(f'Latent dim {i} distribution.')
    sns.distplot(encoded[:, i], bins=25)
    plt.show()

    # Channel 0 of test sample `i` overlaid with the model's reconstruction.
    plt.figure(figsize=(10, 5))
    plt.title(f'Sample {i}: original vs. reconstruction.')
    plt.plot(x_test_in[input_key][i, :, 0], label='og')
    plt.plot(pred[i, :, 0], c='k', linestyle='--', label='reconstructed')
    plt.legend()
    plt.show()

In [None]:
# Bokeh figure holding three line glyphs that `Updater` rewrites in place:
#   r        - reconstruction decoded from the current slider values
#   og_recon - the model's reconstruction of the selected sample
#   og       - the selected sample itself
test_signals = x_test_in[input_key]
# The y-axis is fixed, so it has to span the originals as well as the
# reconstructions; sizing it from `pred` alone can clip the original trace.
y_low = float(min(pred.min(), test_signals.min()))
y_high = float(max(pred.max(), test_signals.max()))
p = figure(title="ECG reconstruction", plot_height=400, plot_width=1200, y_range=(y_low, y_high),
           background_fill_color='#efefef')
x = np.arange(test_signals.shape[1])
y = pred[0, :, 0]
r = p.line(x, y, color="#8888cc", line_width=1.5, alpha=0.8, legend_label='Reconstruction from Slider Values')
og_recon = p.line(x, y, color='black', line_width=1, line_dash='dashed', legend_label='Sample Reconstruction')
# Start the 'Original Sample' trace on the actual sample 0 rather than on its
# reconstruction, so the figure is correct even before the first update.
og = p.line(x, test_signals[0, :, 0], color='red', line_width=1, line_dash='dotted', legend_label='Original Sample')
# Clicking a legend entry toggles the visibility of its line.
p.legend.click_policy="hide"
# Number of latent dimensions, i.e. how many sliders are needed.
ndim = encoded.shape[-1]

class Updater:
    """Callable that redraws the bokeh traces from the latent-value sliders.

    Owns one ``FloatSlider`` per latent dimension plus the sample and channel
    selectors. Calling the instance (directly, through ``interact``, or from a
    slider change event) decodes the current slider values with ``dec`` and
    pushes the three traces ``r``, ``og_recon`` and ``og`` to the notebook.

    Relies on the notebook globals ``encoded``, ``pred``, ``x_test_in``,
    ``input_key``, ``ndim``, ``dec``, ``r``, ``og_recon`` and ``og``.
    """

    def __init__(self):
        # One slider per latent dimension, starting at the encoding of sample 0.
        self.sliders = [
            FloatSlider(min=-3, max=3, step=.01, value=encoded[0, i], description=f'latent {i}')
            for i in range(ndim)
        ]

        self.sample_slider = BoundedIntText(min=0, max=len(pred)-1)
        self.channel_slider = Dropdown(options=list(range(pred.shape[-1])))
        for slider in self.sliders:
            slider.observe(self._slider_callback, names='value')
        self.prev_sample = 0
        # True only while __call__ is loading a new sample's encoding into the
        # sliders; used to mute the per-slider change events during that load.
        self._syncing = False
        # NOTE(review): presumably lets the instance pass for a named function
        # where one is expected (e.g. by interact) - confirm.
        self.__name__ = 'update_latent_values'

    def _slider_callback(self, change):
        """Redraw after a user moved a latent slider.

        Change events fired while ``__call__`` is bulk-loading slider values
        are ignored: ``__call__`` redraws once itself when the load is done.
        """
        if self._syncing:
            return
        self(sample=self.sample_slider.value, channel=self.channel_slider.value)

    #self = None is gross, but necessary to avoid error in interact
    def __call__(self=None, sample=0, channel=0):
        """Refresh the three traces for ``sample`` / ``channel``.

        When ``sample`` differs from the previously shown one, the sliders are
        first reset to that sample's latent encoding. The slider values are
        then decoded and all three line data sources are updated and pushed.
        """
        if self.prev_sample != sample:
            # Record the new sample *before* touching the sliders and mute
            # their observers meanwhile. Otherwise every assignment below
            # re-enters this method with a stale prev_sample, nesting one
            # level deeper per slider and running a decode for each of them.
            self.prev_sample = sample
            self._syncing = True
            try:
                for i, slider in enumerate(self.sliders):
                    slider.value = encoded[sample, i]
            finally:
                self._syncing = False
        # Single-row batch of the current slider values.
        z = np.array([[slider.value for slider in self.sliders]])
        recon = dec.predict(z)
        r.data_source.data['y'] = recon[0, :, channel]
        og_recon.data_source.data['y'] = pred[sample, :, channel]
        og.data_source.data['y'] = x_test_in[input_key][sample, :, channel]
        push_notebook()

In [None]:
# Display the figure with a notebook handle so that push_notebook() calls made
# by `Updater` can refresh it in place.
show(p, notebook_handle=True)
update = Updater()
# interact() wires the sample / channel widgets to `update`.
interact(update, sample=update.sample_slider, channel=update.channel_slider)

# Lay the latent sliders out four per row. The row count is rounded up (and is
# at least 1): `ndim // 4` rows would leave out the trailing sliders whenever
# the latent size is not a multiple of 4, and would produce a 0-row grid for a
# latent size below 4.
n_cols = 4
n_rows = max(1, -(-len(update.sliders) // n_cols))
grid = GridspecLayout(n_rows, n_cols)
for k, slider in enumerate(update.sliders):
    grid[k // n_cols, k % n_cols] = slider
grid

In [None]:
# Fit a PCA with default settings on the latent encodings and plot the
# fraction of variance explained by each component.
pca = PCA().fit(encoded)
plt.plot(pca.explained_variance_ratio_)
plt.title('Scree plot for latent variables')
plt.show()

In [None]:
# Cumulative explained variance of the latent-space PCA, marking where it
# first exceeds 90%.
plt.figure(figsize=(10, 7))
sum_variance = pca.explained_variance_ratio_.cumsum()
plt.plot(sum_variance)
plt.axhline(.9, c='g', linestyle='--', label='90% variance explained')
# 0-based position of the first component whose cumulative variance exceeds
# 90%; this is also its x coordinate on the plot.
pct_90 = np.where(sum_variance > .9)[0][0]
# A 0-based position of k means k + 1 components are needed, so the legend
# must report the count rather than the index.
n_pcs_90 = pct_90 + 1
plt.axvline(pct_90, c='r', linestyle='--', label=f'# pcs for 90% variance explained: {n_pcs_90}')
plt.xlabel('Principal component index (0-based)')
plt.ylabel('Cumulative explained variance ratio')
plt.legend()
plt.show()