In [None]:
import os
import sys
import math
import argparse
import numpy as np
from collections import Counter

# Keras imports
from tensorflow.keras.models import Model
from keras import backend as K

# ml4h Imports
from ml4h.arguments import parse_args
from ml4h.models import make_multimodal_multitask_model, train_model_from_generators
from ml4h.tensor_generators import TensorGenerator, big_batch_from_minibatch_generator, test_train_valid_tensor_generators

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt


In [None]:
def gradients_from_output(args, model, output_layer, output_index):
    """Build a backend function returning one output activation and its normalized input gradient.

    Args:
        args: Not used by this function; kept so existing call sites keep working.
        model: Keras model. ``model.input`` is used as the differentiation target
            (presumably a single-input model, since it is wrapped in a one-element list below).
        output_layer: Name of the layer looked up with ``model.get_layer``.
        output_index: Column selected from that layer's output as ``output[:, output_index]``.

    Returns:
        A ``K.function`` that takes ``[input_batch]`` and returns ``[activation, grads]``.
        ``grads`` is the gradient of the selected activation (summed over the batch, as
        ``K.gradients`` does) with respect to the model input, divided by its RMS + 1e-6.

    Side effects:
        Calls ``K.set_learning_phase(1)``, which changes the global Keras learning phase.
    """
    # NOTE(review): phase 1 is the *training* phase. If the model contains dropout or
    # batch-norm layers, they would behave as in training. Confirm this is intended here.
    K.set_learning_phase(1)
    model_input = model.input
    target = model.get_layer(output_layer).output[:, output_index]
    raw_grads = K.gradients(target, model_input)[0]
    # The RMS is taken over every element of the whole batch, so all samples share one
    # scale factor. The 1e-6 guards against division by zero.
    rms = K.sqrt(K.mean(K.square(raw_grads)))
    normalized_grads = raw_grads / (rms + 1e-6)
    return K.function([model_input], [target, normalized_grads])

def saliency_map(input_tensor, model, output_layer, output_index, args=None):
    """Compute a normalized input-gradient (saliency) map for one model output unit.

    Args:
        input_tensor: Batch of inputs fed to ``model`` (e.g. ``test_tensor[:6]``).
        model: Keras model to differentiate (presumably single-input; see
            ``gradients_from_output``).
        output_layer: Name of the layer whose output is differentiated.
        output_index: Column of that layer's output (``output[:, output_index]``) to explain.
        args: Forwarded to ``gradients_from_output``, which does not use it. Previously this
            was read from a notebook global defined in a later cell. Passing it explicitly
            (or omitting it) removes that hidden-state dependency.

    Returns:
        The normalized gradients returned by the backend function built by
        ``gradients_from_output``. Presumably these have the same shape as ``input_tensor``.

    Side effects:
        Prints the activations and the gradient shape. Rebuilds the gradient function on
        every call.
    """
    get_gradients = gradients_from_output(args, model, output_layer, output_index)
    activation, grads = get_gradients([input_tensor])
    print('Activation is:', activation, 'gradient shape:', grads.shape)
    return grads


In [None]:
# Configure ml4h as if invoked from the command line, then build the test generator,
# the model (presumably restored from MODEL_FILE via --model_file) and one big test batch.
TENSORS_DIR = '/mnt/disks/sax-lax-40k-lvm/2019-11-21/'
MODEL_FILE = '/home/sam/ml/trained_models/t2_flair_brain_age_converge/t2_flair_brain_age_converge.hd5'
sys.argv = [
    'train',
    '--tensors', TENSORS_DIR,
    '--input_tensors', 't2_flair_brain_30_slices',
    '--output_tensors', 'age_2',
    '--batch_size', '4',
    '--test_steps', '2',
    '--model_file', MODEL_FILE,
    '--id', 'brain_age',
]

args = parse_args()
arg_dict = vars(args)  # the same dict object as args.__dict__
_, _, generate_test = test_train_valid_tensor_generators(**arg_dict)
model = make_multimodal_multitask_model(**arg_dict)
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)

In [None]:
def plot_brain(brain, cols=3, rows=10):
    """Plot every slice of a 3-D volume on a ``rows`` x ``cols`` grid of grayscale panels.

    Args:
        brain: Array (or ``np.ma.MaskedArray``) indexed as ``brain[:, :, i]``. The last
            axis is treated as the slice axis.
        cols: Number of panel columns.
        rows: Number of panel rows.

    All panels share one grey scale (``vmin``/``vmax`` over the whole volume), so
    intensities are comparable across slices.

    Raises:
        ValueError: If the volume has more slices than the grid has panels.
    """
    # Use the slice count of *this* volume, not of a notebook-global tensor.
    n_slices = brain.shape[-1]
    if n_slices > rows * cols:
        raise ValueError('brain has {} slices but the {}x{} grid only has {} panels'.format(
            n_slices, rows, cols, rows * cols))
    # squeeze=False keeps `axes` 2-D even when rows or cols is 1, so axes[r, c] always works.
    _, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
    vmin = np.min(brain)
    vmax = np.max(brain)
    for i in range(n_slices):
        ax = axes[i // cols, i % cols]
        ax.imshow(brain[:, :, i], cmap='gray', vmin=vmin, vmax=vmax)
        ax.set_yticklabels([])
        ax.set_xticklabels([])

In [None]:
# Pull the T2-FLAIR input volumes out of the test batch, report their shape,
# and plot the slices of the first test sample.
INPUT_KEY = 'input_t2_flair_brain_30_slices_ukb_brain_mri'
test_tensor = test_data[INPUT_KEY]
print(test_tensor.shape, test_tensor[:1].shape)
plot_brain(brain=test_tensor[0])

In [None]:
# Saliency of the continuous age output (column 0) for the first six test samples;
# then show the map for sample 0.
AGE_OUTPUT_LAYER = 'output_21003_Age-when-attended-assessment-centre_2_continuous'
grads = saliency_map(test_tensor[:6], model, output_layer=AGE_OUTPUT_LAYER, output_index=0)
plot_brain(brain=grads[0])

In [None]:
# Input slices of test sample 1.
plot_brain(brain=test_tensor[1])

In [None]:
# Normalized gradient (saliency) slices of test sample 1.
plot_brain(brain=grads[1])

In [None]:
# Sample 0 input, with voxels whose normalized gradient is below -5 masked out (left undrawn).
plot_brain(brain=np.ma.masked_where(-5 > grads[0], test_tensor[0]))

In [None]:
# Sample 1 input, with voxels whose |normalized gradient| exceeds 3 masked out (left undrawn).
# Thresholds on the gradient magnitude, like the cells for samples 2-5.
plot_brain(np.ma.masked_where(np.abs(grads[1]) > 3, test_tensor[1]))

In [None]:
# Sample 3 input, with voxels whose |normalized gradient| exceeds 5 masked out.
plot_brain(brain=np.ma.masked_where(5 < np.absolute(grads[3]), test_tensor[3]))

In [None]:
# Sample 2 input, with voxels whose |normalized gradient| exceeds 5 masked out.
plot_brain(brain=np.ma.masked_where(5 < np.absolute(grads[2]), test_tensor[2]))

In [None]:
# Sample 4 input, with voxels whose |normalized gradient| exceeds 5 masked out.
plot_brain(brain=np.ma.masked_where(5 < np.absolute(grads[4]), test_tensor[4]))

In [None]:
# Sample 5 input, with voxels whose |normalized gradient| exceeds 3 masked out.
plot_brain(brain=np.ma.masked_where(3 < np.absolute(grads[5]), test_tensor[5]))

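- TODO: repeat this saliency analysis for a hyperintensity output.
- TODO: repeat it for ventricular volumes.
- TODO: repeat it for other brain volumes.
- TODO: move the slice window up so it includes less of the brain stem.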