In [None]:
import os
import sys
import math
import argparse
import numpy as np
from collections import Counter

# Keras imports
from keras.models import Model
from keras import backend as K

# ML4CVD Imports
from ml4cvd.arguments import parse_args
from ml4cvd.models import make_multimodal_multitask_model, train_model_from_generators
from ml4cvd.tensor_generators import TensorGenerator, big_batch_from_minibatch_generator, test_train_valid_tensor_generators

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt


In [None]:
def gradients_from_output(args, model, output_layer, output_index, learning_phase=1):
    """Build a Keras backend function returning one output unit and its normalized input gradient.

    Args:
        args: Unused; kept only so existing callers that pass parsed arguments keep working.
        model: Keras model whose ``model.input`` is differentiated against. A single input
            tensor is assumed (it is wrapped in a list for ``K.function``) — TODO confirm
            before using with multi-input models.
        output_layer: Name of the layer looked up with ``model.get_layer``.
        output_index: Index into the second axis of that layer's output, i.e. the unit whose
            gradient is taken (the layer output is assumed to be rank 2, (batch, units)).
        learning_phase: Value passed to ``K.set_learning_phase`` before the graph pieces are
            built. Defaults to 1 to preserve previous behavior; 0 selects the inference phase,
            which is what saliency maps usually want — NOTE(review): confirm which is intended.

    Returns:
        A callable that takes ``[input_array]`` and returns ``[activation, grads]``, where
        ``grads`` is the gradient of the selected unit w.r.t. the model input divided by the
        root-mean-square of all gradient entries (plus 1e-6).
    """
    # This mutates Keras backend state; it is not scoped to this call.
    K.set_learning_phase(learning_phase)
    input_tensor = model.input
    activation = model.get_layer(output_layer).output[:, output_index]
    grads = K.gradients(activation, input_tensor)[0]
    # Normalization trick: divide by the RMS over the whole batch (not per sample) so gradient
    # magnitudes are comparable across outputs; the epsilon guards against an all-zero gradient.
    grads = grads / (K.sqrt(K.mean(K.square(grads))) + 1e-6)
    return K.function([input_tensor], [activation, grads])

def saliency_map(input_tensor, model, output_layer, output_index):
    """Return the normalized gradient of one output unit with respect to the model input.

    Args:
        input_tensor: Batch of input data fed to the model (a numpy array in this notebook,
            despite the name).
        model: Keras model to explain.
        output_layer: Name of the output layer whose unit is differentiated.
        output_index: Index of the unit within that layer's output.

    Returns:
        The gradient array produced by ``gradients_from_output``'s backend function for
        ``input_tensor``. Its shape is printed along with the selected unit's activation.
    """
    # gradients_from_output never reads its `args` argument, so pass None instead of depending
    # on a notebook-global `args` that may not exist (or may be stale) when this is called.
    get_gradients = gradients_from_output(None, model, output_layer, output_index)
    activation, grads = get_gradients([input_tensor])
    print('Activation is:', activation, 'gradient shape:', grads.shape)
    return grads


In [None]:
def plot_ecgs(ecgs, rows=3, cols=4, time_interval=2.5, raw_scale=0.005, hertz=500, figsize=(18, 16)):
    """Plot multi-lead traces on a ``rows`` x ``cols`` grid, overlaying every entry of ``ecgs``.

    Row ``i`` shows the window [i * time_interval, (i + 1) * time_interval) seconds, and column
    ``j`` of that row shows lead ``j + i * cols``. The grid therefore covers ``rows * cols``
    leads, with each row using its own time window.

    Args:
        ecgs: Mapping from legend label to a 2-D array indexed as [sample, lead].
        rows: Number of subplot rows (one time window per row).
        cols: Number of subplot columns (leads shown per row).
        time_interval: Length of each row's window, in seconds.
        raw_scale: Factor every trace is multiplied by before plotting.
        hertz: Sampling rate, used to turn seconds into sample indices and samples into ms.
        figsize: Matplotlib figure size in inches.
    """
    # squeeze=False keeps `axes` 2-D even when rows == 1 or cols == 1.
    _, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    ms_per_sample = 1000.0 / hertz
    for i in range(rows):
        start = int(i * time_interval * hertz)
        stop = int((i + 1) * time_interval * hertz)
        for j in range(cols):
            ax = axes[i, j]
            lead = j + i * cols
            # The x-axis is in milliseconds so that it matches the axis label below.
            ax.set_xlim(start * ms_per_sample, stop * ms_per_sample)
            for label, ecg in ecgs.items():
                segment = ecg[start:stop, lead]
                # Build x from the actual slice length: a recording shorter than the window is
                # plotted as far as it goes instead of raising an x/y length mismatch.
                times = (start + np.arange(len(segment))) * ms_per_sample
                ax.plot(times, segment * raw_scale, label=label)
            ax.legend(loc='lower right')
            ax.set_xlabel('milliseconds')

In [None]:
# Saliency of the QT-interval regressor with respect to the resting ECG input.
# parse_args() is driven by sys.argv here, so the command line is assembled from
# (flag, value) pairs, in order, before parsing.
sys.argv = ['train'] + [
    token
    for flag_value in (
        ('--tensors', '/mnt/disks/sax-lax-40k-teacher/2019-11-21/'),
        ('--input_tensors', 'ecg_rest'),
        ('--output_tensors', 'qt-interval'),
        ('--batch_size', '6'),
        ('--epochs', '2'),
        ('--learning_rate', '0.001'),
        ('--training_steps', '128'),
        ('--validation_steps', '10'),
        ('--test_steps', '1'),
        ('--model_file', '/mnt/ml4cvd/projects/models/ecg_rest_qt_only/ecg_rest_qt_only.hd5'),
        ('--id', 'ecg_rest_qt_only'),
    )
    for token in flag_value
]
args = parse_args()
model = make_multimodal_multitask_model(**vars(args))
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**vars(args))
# Presumably gathers args.test_steps minibatches from the test generator into one batch.
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
test_tensor = test_data['input_strip_ecg_rest']
grads_qt = saliency_map(test_tensor, model, 'output_QTInterval_continuous', 0)

In [None]:
# Saliency of the ventricular-rate regressor on the same test batch.
# NOTE: reuses `test_tensor` from the QT-interval cell above, so that cell must run first.
sys.argv = ['train'] + [
    token
    for flag_value in (
        ('--tensors', '/mnt/disks/sax-lax-40k-teacher/2019-11-21/'),
        ('--input_tensors', 'ecg_rest'),
        ('--output_tensors', 'ventricular-rate'),
        ('--batch_size', '6'),
        ('--epochs', '2'),
        ('--learning_rate', '0.001'),
        ('--training_steps', '128'),
        ('--validation_steps', '10'),
        ('--test_steps', '1'),
        ('--model_file', '/mnt/ml4cvd/projects/models/ecg_rest_ventricular_rate/ecg_rest_ventricular_rate.hd5'),
        ('--id', 'ecg_rest_ventricular_rate'),
    )
    for token in flag_value
]
args = parse_args()
model = make_multimodal_multitask_model(**vars(args))
grads_vr = saliency_map(test_tensor, model, 'output_VentricularRate_continuous', 0)

In [None]:
# Overlay sample 1 of the test batch (raw ECG) with both saliency maps for that same sample.
# NOTE(review): plot_ecgs multiplies every trace by raw_scale, including the normalized
# gradients — confirm the gradient traces are still visible at that scale.
ecg_dict = dict(raw=test_tensor[1], QTInterval=grads_qt[1], VentricularRate=grads_vr[1])
plot_ecgs(ecg_dict)