In [None]:
import os
import sys
import math
import argparse
import numpy as np
from collections import Counter

# Keras imports
from keras.models import Model
from keras import backend as K

# ML4CVD Imports
from ml4cvd.arguments import parse_args
from ml4cvd.models import make_multimodal_multitask_model, train_model_from_generators

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def _channel_ascent_function(model, interior, normalizer):
    """Build a backend function mapping an input batch to ``[objective, normalized_grads]``.

    Args:
        model: Keras model; gradients are taken with respect to ``model.input``.
        interior: Activation tensor whose squared sum is maximized.
        normalizer: Divisor applied to the squared sum (the layer's per-sample output size).

    Returns:
        A backend function ``f([input_batch]) -> [objective, grads]``. ``grads`` is divided by
        its RMS so a fixed step size behaves similarly across layers.
    """
    input_tensor = model.input
    # The objective is a plain tensor expression; no backend variable is needed to hold it.
    objective = K.sum(K.square(interior)) / normalizer
    # compute the gradient of the input picture wrt this loss
    grads = K.gradients(objective, input_tensor)[0]
    # normalization trick: divide by the gradient's RMS (epsilon guards an all-zero gradient)
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-6)
    return K.function([input_tensor], [objective, grads])


def iterate_channel(args, model, layer_dict, layer_name='conv5_1', channel=0, border=2):
    """Return a gradient-ascent step function that maximizes one channel of a 2-D conv layer.

    Args:
        args: Parsed arguments; not read here (kept for a uniform ``iterate_fxn`` signature).
        model: Keras model whose input is optimized.
        layer_dict: Mapping from layer name to layer.
        layer_name: Name of the layer whose channel is maximized.
        channel: Index of the channel to maximize.
        border: Number of activation positions trimmed from each spatial edge before
            computing the objective.

    Returns:
        A backend function ``f([input_batch]) -> [objective, grads]``.
    """
    # NOTE(review): learning phase 1 is Keras' training phase, so dropout / batch-norm layers
    # run in training mode during gradient ascent -- confirm this is intended.
    K.set_learning_phase(1)
    layer = layer_dict[layer_name]
    if K.image_data_format() == 'channels_first':
        activation = layer.output[:, channel, :, :]
    else:
        activation = layer.output[:, :, :, channel]
    width = activation.shape[1]
    height = activation.shape[2]
    # Trim the map's edges (presumably to suppress padding/edge artifacts).
    interior = activation[:, border:width - border, border:height - border]
    return _channel_ascent_function(model, interior, np.prod(layer.output_shape[1:]))


def iterate_channel_1d(args, model, layer_dict, layer_name='conv5_1', channel=0, border=2):
    """Return a gradient-ascent step function that maximizes one channel of a 1-D conv layer.

    Args:
        args: Parsed arguments; not read here (kept for a uniform ``iterate_fxn`` signature).
        model: Keras model whose input is optimized.
        layer_dict: Mapping from layer name to layer.
        layer_name: Name of the layer whose channel is maximized.
        channel: Index of the channel to maximize.
        border: Number of activation positions trimmed from each end before computing
            the objective.

    Returns:
        A backend function ``f([input_batch]) -> [objective, grads]``.
    """
    # NOTE(review): learning phase 1 is Keras' training phase -- confirm this is intended.
    K.set_learning_phase(1)
    layer = layer_dict[layer_name]
    if K.image_data_format() == 'channels_first':
        activation = layer.output[:, channel, :]
    else:
        activation = layer.output[:, :, channel]
    length = activation.shape[1]
    interior = activation[:, border:length - border]
    return _channel_ascent_function(model, interior, np.prod(layer.output_shape[1:]))


def _gradient_ascent(iterate, input_shape, steps, learning_rate, jitter=0.000001):
    """Run ``steps`` jittered gradient-ascent updates starting from uniform noise; return the input."""
    input_img_data = np.random.random(input_shape)
    for _ in range(steps):
        # Evaluate the gradient at a slightly perturbed point, then step from the unperturbed one.
        random_jitter = jitter * (np.random.random(input_shape) - 0.5)
        input_img_data += random_jitter
        _, grads_value = iterate([input_img_data])
        input_img_data -= random_jitter
        input_img_data += learning_rate * grads_value
    return input_img_data


def _save_1d_channels(input_img_data, out_file):
    """Plot every channel of a ``(1, length, channels)`` batch in a subplot grid and save it."""
    total_plots = input_img_data.shape[-1]
    rows = max(2, int(math.sqrt(total_plots)))
    # Round up so every channel gets a subplot.
    cols = max(2, int(math.ceil(total_plots / rows)))
    fig, axes = plt.subplots(rows, cols, figsize=(8 * cols, 4 * rows))
    # axes.T.flat walks the grid column by column (down each column first).
    for channel, ax in zip(range(total_plots), axes.T.flat):
        ax.plot(input_img_data[0, :, channel])
        ax.set_title('channel %d' % channel)
    fig.savefig(out_file)
    plt.show()
    plt.close(fig)


def write_filters(args, model, input_shape, iterate_fxn):
    """Visualize conv filters by gradient ascent on the model input, saving one PNG per filter.

    For every layer whose name contains ``'conv'``, every 8th output channel is maximized.
    Each channel gets ``args.epochs`` steps of size ``args.learning_rate``, starting from
    uniform random noise. Each result is written to
    ``<args.output_folder>/<args.id>/write_filters/<layer>_filter_<index>.png``.

    Args:
        args: Parsed arguments providing ``output_folder``, ``id``, ``epochs`` and ``learning_rate``.
        model: Keras model to visualize.
        input_shape: Shape of the optimized input batch. A 4-D shape is saved as an image of
            ``[0, :, :, 0]``. A 3-D shape is saved as a grid with one line plot per channel.
        iterate_fxn: Factory such as ``iterate_channel`` / ``iterate_channel_1d`` that returns a
            function ``f([input]) -> [loss, grads]``.

    Raises:
        ValueError: If ``input_shape`` is neither 3-D nor 4-D (nothing could be saved).
    """
    if len(input_shape) not in (3, 4):
        raise ValueError('write_filters expects a 3-D or 4-D input_shape, got %r' % (input_shape,))
    layer_dict = {layer.name: layer for layer in model.layers}
    out_dir = os.path.join(args.output_folder, args.id, 'write_filters')

    for layer in model.layers:
        # The name filter is loop-invariant; checking it first also avoids reading
        # output_shape of non-conv layers.
        if 'conv' not in layer.name:
            continue
        for filter_index in range(0, layer.output_shape[-1], 8):
            iterate = iterate_fxn(args, model, layer_dict, layer.name, filter_index)
            input_img_data = _gradient_ascent(iterate, input_shape, args.epochs, args.learning_rate)
            out_file = os.path.join(out_dir, '%s_filter_%d.png' % (layer.name, filter_index))
            os.makedirs(out_dir, exist_ok=True)
            if len(input_shape) == 4:
                plt.imsave(out_file, input_img_data[0, :, :, 0])
            else:
                _save_1d_channels(input_img_data, out_file)
            print('Saved:', out_file)

In [None]:
def gradients_from_output(args, model, output_layer, output_index):
    """Return a function computing one output unit's squared activation and its input gradient.

    Args:
        args: Unused here.
        model: Keras model whose ``input`` is differentiated against.
        output_layer: Name of the layer to read.
        output_index: Index along axis 1 of that layer's output.

    Returns:
        A backend function ``f([input_batch]) -> [objective, grads]``. ``objective`` is the sum
        over the batch of the selected unit squared. ``grads`` is the gradient of ``objective``
        w.r.t. ``model.input``, divided by its RMS.
    """
    input_tensor = model.input
    selected = model.get_layer(output_layer).output[:, output_index]
    # The objective is a plain tensor expression; no backend variable is needed to hold it.
    objective = K.sum(K.square(selected))

    # compute the gradient of the input picture wrt this loss
    grads = K.gradients(objective, input_tensor)[0]

    # normalization trick: divide by the gradient's RMS (epsilon guards an all-zero gradient)
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-6)

    return K.function([input_tensor], [objective, grads])


def saliency_map(input_tensor, model, output_layer, output_index, args=None):
    """Plot the normalized gradient of one output unit with respect to ``input_tensor``.

    Args:
        input_tensor: Numpy input batch with a leading batch axis. 4-D inputs are drawn with
            ``imshow`` using the first sample and first channel (assumes channels_last).
            3-D inputs are drawn as one line per last-axis channel of the first sample.
        model: Keras model to differentiate.
        output_layer: Name of the layer whose unit is explained.
        output_index: Index of the unit along axis 1 of ``output_layer``'s output.
        args: Forwarded to ``gradients_from_output``, which does not read it. It defaults to
            None, so no notebook-global ``args`` is needed.
    """
    get_gradients = gradients_from_output(args, model, output_layer, output_index)
    activation, grads = get_gradients([input_tensor])
    print('Activation is:', activation, 'gradient shape:', grads.shape)
    # matplotlib cannot draw the batched gradient directly, so drop the batch axis first.
    if len(input_tensor.shape) == 4:
        plt.imshow(grads[0, :, :, 0])
        plt.title('Saliency of %s[%d]' % (output_layer, output_index))
    elif len(input_tensor.shape) == 3:
        plt.plot(grads[0])
        plt.title('Saliency of %s[%d]' % (output_layer, output_index))

In [None]:
# Configure and build the MRI slice segmentation U-Net via the ml4cvd argument parser.
# write_filters later reuses --epochs and --learning_rate as the number and size of
# gradient-ascent steps.
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/data/generated/tensors/test/2019-03-21/',
    '--input_tensors', 'mri_slice_weighted',
    '--output_tensors', 'mri_slice_segmented_weighted',
    '--epochs', '260',
    '--learning_rate', '0.1',
    '--u_connect',
    '--model_layers', '/mnt/ml4cvd/projects/jamesp/data/models/mri_slice_seg_unet.hd5',
    '--id', 'mri_slice_segmenter',
]
args = parse_args()
model = make_multimodal_multitask_model(**vars(args))


In [None]:
# Optimize a batch holding one single-channel image of the configured size.
input_shape = (1, args.x, args.y, 1)
write_filters(args, model, input_shape=input_shape, iterate_fxn=iterate_channel)

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_1_filter_16.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_4_filter_8.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_6_filter_0.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_8_filter_8.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_12_filter_0.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_11_filter_0.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_7_filter_16.png'))

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv2d_14_filter_0.png'))

In [None]:
# Configure and build the 1-lead resting-ECG multitask model via the ml4cvd argument parser.
# write_filters later reuses --epochs and --learning_rate as the number and size of
# gradient-ascent steps.
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/ecg-text3/2019-04-26/',
    '--input_tensors', 'ecg_rest_1lead',
    '--output_tensors',
    'ecg_median_1lead', 'ecg_normal', 'ecg_rhythm',
    'p-axis', 'p-duration', 'p-offset', 'p-onset', 'pp-interval', 'pq-interval',
    'q-offset', 'q-onset', 'qrs-duration', 'qrs-num', 'qt-interval', 'qtc-interval',
    'ventricular-rate',
    '--epochs', '260',
    '--learning_rate', '0.1',
    '--u_connect',
    '--model_layers', '/mnt/ml4cvd/projects/jamesp/data/models/ecg_rest_wave_regress_afib_1lead.hd5',
    '--id', 'ecg_rest_wave_regress_afib_1lead',
]
args = parse_args()
model = make_multimodal_multitask_model(**vars(args))


In [None]:
# Saliency of the first ecg_rhythm output unit for a uniform-noise input.
# NOTE(review): (1, 600, 8) is presumably (batch, samples, channels) -- confirm against model.input_shape.
input_tensor = np.random.random((1, 600, 8))
for layer in model.layers:
    print(layer.name)
saliency_map(input_tensor, model, 'output_ecg_rhythm_categorical', 0)

In [None]:
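# Disabled: plot every kernel of the first conv1d layer (one subplot per input/output channel pair).
# NOTE(review): figsize scales width by rows and height by cols -- likely swapped if re-enabled.
# d1 = model.get_layer('conv1d_1')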
# w1 = d1.get_weights()
# rows = max(2, w1[0].shape[-2])
# cols = max(2, w1[0].shape[-1])
# f, axes = plt.subplots(rows, cols, sharex=True, figsize=(int(rows * 12.5), int(cols * 2.5)))
# for row in range(rows):
#     for col in range(cols):
#         axes[row, col].plot(w1[0][:,row, col])
    
# plt.show()

In [None]:
# Optimize a batch holding one 600 x 8 signal with the 1-D channel objective.
input_shape = (1, 600, 8)
write_filters(args, model, input_shape=input_shape, iterate_fxn=iterate_channel_1d)

In [None]:
Image(os.path.join(args.output_folder, args.id, 'write_filters', 'conv1d_8_filter_16.png'))