In [None]:
import os
import sys
import h5py
import argparse
import numpy as np
from collections import Counter
import xml.etree.ElementTree as et

# Keras imports
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils.vis_utils import model_to_dot
from tensorflow.keras.layers import Input, Dense, Dropout, AlphaDropout, BatchNormalization, Activation

# ML4CVD Imports
from ml4cvd.plots import plot_ecg
from ml4cvd.arguments import parse_args
from ml4cvd.tensor_writer_ukbb import write_tensors
from ml4cvd.recipes import train_multimodal_multitask
from ml4cvd.models import make_multimodal_multitask_model
from ml4cvd.tensor_generators import big_batch_from_minibatch_generator, test_train_valid_tensor_generators, TensorGenerator

# IPython imports
from IPython.display import Image
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
# Assemble the command line that ml4cvd's parse_args() reads from sys.argv.
# Argument order is significant only for readability, but it is kept exactly as originally written.
sys.argv = (
    ['train']
    + ['--tensors', '/mnt/disks/tensors-lv-mass/']
    + ['--input_tensors', 'mri_systole_diastole_weighted']
    + ['--output_tensors', 'mri_systole_diastole_segmented_weighted', 'lv_mass',
       'end_systole_volume', 'end_diastole_volume', 'ejection_fraction']
    + ['--batch_size', '4']
    + ['--pool_z', '1']
    + ['--epochs', '2']
    + ['--learning_rate', '0.001']
    + ['--u_connect']
    + ['--training_steps', '128']
    + ['--validation_steps', '10']
    + ['--test_steps', '12']
    + ['--b_slice_force', '4']
    # NOTE(review): the file name says "une_mass" while --id says "unet_mass"; confirm the intended path.
    + ['--model_file', '/mnt/ml4cvd/projects/jamesp/data/models/mri_systole_diastole_une_mass.hd5']
    + ['--id', 'mri_systole_diastole_unet_mass']
)
args = parse_args()
# Training is not run in this notebook; call train_multimodal_multitask(args) here to (re)train first.

In [None]:
# Build train/valid/test generators; only the test generator is consumed in the visible cells.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    args.tensor_maps_in, args.tensor_maps_out, args.tensors, args.batch_size,
    args.valid_ratio, args.test_ratio, args.test_modulo, args.balance_csvs,
)

# Every parsed argument is forwarded as a keyword argument to the model factory.
model = make_multimodal_multitask_model(**vars(args))

# Pull args.test_steps minibatches into one big batch. The trailing positional True is a flag whose
# meaning is defined in ml4cvd.tensor_generators (presumably "keep paths", since paths are returned) — TODO confirm.
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(
    args.tensor_maps_in, args.tensor_maps_out, generate_test, args.test_steps, True,
)

predictions = model.predict(test_data, batch_size=args.batch_size)

In [None]:
IMAGE_EXT = '.png'
def plot_scatter(prediction, truth, title, paths=None, prefix='./figures/', top_k=3):
    """Scatter predictions against ground truth, annotate outliers, and save the figure.

    Args:
        prediction: array of predicted values, shape (N,) or (N, D). When D > 1 only
            column 0 is used to rank samples and position the path annotations.
        truth: array of true values with the same shape as ``prediction``.
        title: plot title; also used to build the output file name.
        paths: optional sequence of N sample paths. When given, the basenames of the
            ``top_k`` best- and ``top_k`` worst-predicted samples (by absolute error)
            are written next to their points.
        prefix: directory the figure is saved into; created if missing. ``''`` saves
            into the current working directory.
        top_k: number of best and of worst samples to annotate (0 annotates none).

    Returns:
        dict ``{title + '_pearson': r}``, where r is the Pearson correlation between the
        flattened predictions and truths.
    """
    prediction = np.asarray(prediction)
    truth = np.asarray(truth)
    plt.figure(figsize=(16, 8))
    # Identity (y = x) reference lines over the truth range and over the prediction range.
    plt.plot([np.min(truth), np.max(truth)], [np.min(truth), np.max(truth)], linewidth=2)
    plt.plot([np.min(prediction), np.max(prediction)], [np.min(prediction), np.max(prediction)], linewidth=4)
    plt.scatter(prediction, truth)
    if paths is not None:
        # View inputs as (N, D) so 1-D and 2-D arrays are handled alike; rank and place labels by column 0.
        pred_col = prediction.reshape(len(prediction), -1)[:, 0]
        truth_col = truth.reshape(len(truth), -1)[:, 0]
        # Labels are shifted along the x (prediction) axis, so size the shift from the prediction range.
        margin = (np.max(pred_col) - np.min(pred_col)) / 100
        ranked = np.argsort(np.abs(pred_col - truth_col))  # best -> worst
        best = ranked[:top_k]
        # Slice explicitly: ranked[-0:] would be the whole array when top_k == 0.
        worst = ranked[max(len(ranked) - top_k, 0):]
        for idx in list(best) + list(worst):
            plt.text(pred_col[idx] + margin, truth_col[idx], os.path.basename(paths[idx]))
    plt.xlabel('Predictions')
    plt.ylabel('Actual')
    plt.title(title + '\n')
    # np.corrcoef returns the full 2x2 correlation matrix; the off-diagonal entry is Pearson's r.
    pearson = np.corrcoef(prediction.flatten(), truth.flatten())[1, 0]
    print("Pearson coefficient is: {}".format(pearson))
    plt.text(np.min(prediction), np.max(truth), 'Pearson:%0.3f R^2:%0.3f' % (pearson, (pearson * pearson)))
    figure_path = os.path.join(prefix, 'scatter_' + title + IMAGE_EXT)
    figure_dir = os.path.dirname(figure_path)
    # dirname is '' when prefix is '' (save to CWD), and os.makedirs('') would raise.
    if figure_dir:
        os.makedirs(figure_dir, exist_ok=True)
    plt.savefig(figure_path)
    print("Saved scatter plot at: {}".format(figure_path))
    return {title + '_pearson': pearson}

In [None]:
plot_path = './recipes_output/'
print(len(test_paths))
# Scatter-plot only non-categorical outputs with 1-D shapes; image-shaped and categorical maps are skipped.
for y, tm in zip(predictions, args.tensor_maps_out):
    if not tm.is_categorical_any() and len(tm.shape) == 1:
        # .copy() guards the source arrays in case rescale works in place (not verifiable here).
        plot_scatter(
            tm.rescale(y.copy()),
            tm.rescale(test_labels[tm.output_name()].copy()),
            tm.name,
            paths=test_paths,
            prefix=plot_path,
        )

In [None]:
# Inspect one subject: run the model on a single hd5 file and compare segmentations.
xdir = '/mnt/disks/tensors-lv-mass/2019-04-10/'
paths = [os.path.join(xdir, '5688245.hd5')]
# Seed before building the generator, presumably so any random sampling/augmentation is reproducible.
np.random.seed(1234)
generator = TensorGenerator(1, args.tensor_maps_in, args.tensor_maps_out, paths, None, True)
data, labels, _, hd5s = next(generator)
# Distinct name: re-binding `predictions` would overwrite the test-set predictions, so re-running
# the scatter-plot cell afterwards would silently plot this single sample instead.
sample_predictions = model.predict(data, batch_size=1)
y = {tm.output_name(): p for p, tm in zip(sample_predictions, args.tensor_maps_out)}
print(list(labels.keys()))
# Position of each cardiac phase along axis 3 of the MRI / segmentation arrays.
systole_idx = 1
diastole_idx = 0
# NOTE(review): this handle is not used in the visible cells and is never closed.
hd5 = h5py.File(hd5s[0], 'r')
# NOTE(review): printed as produced by the generator/model, without tm.rescale (unlike the scatter plots).
print('true lv mass:', labels['output_lv_mass_continuous'], 'predicted lv mass:', y['output_lv_mass_continuous'])
# argmax over the last axis turns per-class scores into a label map (assumes last axis = classes).
truth = np.argmax(labels['output_mri_systole_diastole_segmented_categorical'][0, :, :, systole_idx, :], axis=-1)
sys_prediction = np.argmax(y['output_mri_systole_diastole_segmented_categorical'][0, :, :, systole_idx, :], axis=-1)
# Input MRI with pixels of class 2 hidden (class meaning is defined by the tensor map — TODO confirm).
true_donut = np.ma.masked_where(truth == 2, data['input_mri_systole_diastole'][0, :, :, systole_idx, 0])
predict_donut = np.ma.masked_where(sys_prediction == 2, data['input_mri_systole_diastole'][0, :, :, systole_idx, 0])
plt.title('Ground-truth segmentation (systole)')
plt.imshow(truth)

In [None]:
# Predicted label map for the systolic frame.
plt.title('Predicted segmentation (systole)')
plt.imshow(sys_prediction)

In [None]:
# Raw input MRI slice for the systolic frame (first sample, channel 0).
plt.title('Input MRI (systole)')
plt.imshow(data['input_mri_systole_diastole'][0, :, :, systole_idx, 0])

In [None]:
# Input MRI with ground-truth class-2 pixels masked out.
plt.title('Input MRI, ground-truth class 2 masked (systole)')
plt.imshow(true_donut)

In [None]:
# Input MRI with predicted class-2 pixels masked out.
plt.title('Input MRI, predicted class 2 masked (systole)')
plt.imshow(predict_donut)

In [None]:
# Same views as the systolic cells, for the diastolic frame.
# argmax over the last axis turns per-class scores into a label map (assumes last axis = classes).
diastole_truth = np.argmax(labels['output_mri_systole_diastole_segmented_categorical'][0, :, :, diastole_idx, :], axis=-1)
diastole_prediction = np.argmax(y['output_mri_systole_diastole_segmented_categorical'][0, :, :, diastole_idx, :], axis=-1)
# Input MRI with pixels of class 2 hidden (class meaning is defined by the tensor map — TODO confirm).
diastole_donut = np.ma.masked_where(diastole_truth == 2, data['input_mri_systole_diastole'][0, :, :, diastole_idx, 0])
diastole_predict_donut = np.ma.masked_where(diastole_prediction == 2, data['input_mri_systole_diastole'][0, :, :, diastole_idx, 0])
plt.title('Ground-truth segmentation (diastole)')
plt.imshow(diastole_truth)

In [None]:
# Predicted label map for the diastolic frame.
plt.title('Predicted segmentation (diastole)')
plt.imshow(diastole_prediction)

In [None]:
# Raw input MRI slice for the diastolic frame (first sample, channel 0).
plt.title('Input MRI (diastole)')
plt.imshow(data['input_mri_systole_diastole'][0, :, :, diastole_idx, 0])

In [None]:
# Input MRI with ground-truth class-2 pixels masked out.
plt.title('Input MRI, ground-truth class 2 masked (diastole)')
plt.imshow(diastole_donut)

In [None]:
# Input MRI with predicted class-2 pixels masked out.
plt.title('Input MRI, predicted class 2 masked (diastole)')
plt.imshow(diastole_predict_donut)