In [None]:
import os
import sys
import h5py
import argparse
import numpy as np
from collections import Counter
import xml.etree.ElementTree as et 

# Keras imports
from keras.models import Model
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from keras.layers import Input, Dense, Dropout, AlphaDropout, BatchNormalization, Activation

import matplotlib
matplotlib.use('Agg') # Need this to write images from the GSA servers.  Order matters:
import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt

# IPython imports
from IPython.display import Image

# ML4CVD Imports
from ml4cvd.arguments import parse_args
from ml4cvd.tensor_writer_ukbb import write_tensors
from ml4cvd.tensor_generators import test_train_valid_tensor_generators
from ml4cvd.defines import TENSOR_EXT, IMAGE_EXT, MRI_TO_SEGMENT, MRI_SEGMENTED
from ml4cvd.recipes import train_multimodal_multitask, test_multimodal_multitask
from ml4cvd.models import make_multimodal_multitask_model, train_model_from_generators

In [None]:
# Scan the tensor directory for the first file containing MRI slice `slice_idx`
# (with labels), and save the raw slice and its segmentation labels as images.
slice_idx = '56'
tensor_path = '/mnt/disks/data/generated/tensors/test/2019-03-21/'
# os.listdir returns names in arbitrary order; sort so the chosen example is
# the same on every run and machine.
for t in sorted(os.listdir(tensor_path)):
    if os.path.splitext(t)[-1] != TENSOR_EXT:
        continue
    with h5py.File(os.path.join(tensor_path, t), 'r') as mri_tensors:
        # Both the slices and their labels must be present; indexing
        # MRI_SEGMENTED without this check can raise KeyError mid-scan.
        if MRI_TO_SEGMENT not in mri_tensors or MRI_SEGMENTED not in mri_tensors:
            continue
        if slice_idx in mri_tensors[MRI_TO_SEGMENT] and slice_idx in mri_tensors[MRI_SEGMENTED]:
            print('got 1 at t:', t, len(mri_tensors[MRI_TO_SEGMENT]))
            mri_slice = np.array(mri_tensors[MRI_TO_SEGMENT][slice_idx])
            slice_labels = np.array(mri_tensors[MRI_SEGMENTED][slice_idx])
            plt.imsave('./mri_label_example'+slice_idx+IMAGE_EXT, slice_labels)
            plt.imsave('./mri_slice_example'+slice_idx+IMAGE_EXT, mri_slice)
            break
else:
    # The loop never hit `break`: nothing was saved, so the display cells
    # below would show a stale image from an earlier run (or fail). Stop here.
    raise FileNotFoundError('No {} file in {} has slice {} under both {} and {}'.format(
        TENSOR_EXT, tensor_path, slice_idx, MRI_TO_SEGMENT, MRI_SEGMENTED))

In [None]:
# Display the raw MRI slice image written by the slice-search cell.
Image('./mri_slice_example{}{}'.format(slice_idx, IMAGE_EXT))

In [None]:
# Display the segmentation-label image written for the same slice.
Image('./mri_label_example{}{}'.format(slice_idx, IMAGE_EXT))

In [None]:
# Scan the tensor directory for the first file containing MRI slice `slice_idx`
# (with labels), save the raw slice and its labels as images, and show the slice.
slice_idx = '296'
tensor_path = '/mnt/disks/data/generated/tensors/test/2019-03-21/'
# os.listdir returns names in arbitrary order; sort so the chosen example is
# the same on every run and machine.
for t in sorted(os.listdir(tensor_path)):
    if os.path.splitext(t)[-1] != TENSOR_EXT:
        continue
    with h5py.File(os.path.join(tensor_path, t), 'r') as mri_tensors:
        # Both the slices and their labels must be present; indexing
        # MRI_SEGMENTED without this check can raise KeyError mid-scan.
        if MRI_TO_SEGMENT not in mri_tensors or MRI_SEGMENTED not in mri_tensors:
            continue
        if slice_idx in mri_tensors[MRI_TO_SEGMENT] and slice_idx in mri_tensors[MRI_SEGMENTED]:
            print('got 1 at t:', t, len(mri_tensors[MRI_TO_SEGMENT]))
            mri_slice = np.array(mri_tensors[MRI_TO_SEGMENT][slice_idx])
            slice_labels = np.array(mri_tensors[MRI_SEGMENTED][slice_idx])
            plt.imsave('./mri_label_example'+slice_idx+IMAGE_EXT, slice_labels)
            plt.imsave('./mri_slice_example'+slice_idx+IMAGE_EXT, mri_slice)
            break
else:
    # The loop never hit `break`: nothing was saved, so the Image call below
    # would show a stale file from an earlier run (or fail). Stop here.
    raise FileNotFoundError('No {} file in {} has slice {} under both {} and {}'.format(
        TENSOR_EXT, tensor_path, slice_idx, MRI_TO_SEGMENT, MRI_SEGMENTED))
Image('./mri_slice_example'+slice_idx+IMAGE_EXT)

In [None]:
# Display the segmentation-label image written for the same slice.
Image('./mri_label_example{}{}'.format(slice_idx, IMAGE_EXT))

In [None]:
# Scan the tensor directory for the first file containing MRI slice `slice_idx`
# (with labels), save the raw slice and its labels as images, and show the slice.
# `mri_slice` set here is reused by the later prediction-on-slice cell.
slice_idx = '496'
tensor_path = '/mnt/disks/data/generated/tensors/test/2019-03-21/'
# os.listdir returns names in arbitrary order; sort so the chosen example is
# the same on every run and machine.
for t in sorted(os.listdir(tensor_path)):
    if os.path.splitext(t)[-1] != TENSOR_EXT:
        continue
    with h5py.File(os.path.join(tensor_path, t), 'r') as mri_tensors:
        # Both the slices and their labels must be present; indexing
        # MRI_SEGMENTED without this check can raise KeyError mid-scan.
        if MRI_TO_SEGMENT not in mri_tensors or MRI_SEGMENTED not in mri_tensors:
            continue
        if slice_idx in mri_tensors[MRI_TO_SEGMENT] and slice_idx in mri_tensors[MRI_SEGMENTED]:
            print('got 1 at t:', t, len(mri_tensors[MRI_TO_SEGMENT]))
            mri_slice = np.array(mri_tensors[MRI_TO_SEGMENT][slice_idx])
            slice_labels = np.array(mri_tensors[MRI_SEGMENTED][slice_idx])
            plt.imsave('./mri_label_example'+slice_idx+IMAGE_EXT, slice_labels)
            plt.imsave('./mri_slice_example'+slice_idx+IMAGE_EXT, mri_slice)
            break
else:
    # The loop never hit `break`: nothing was saved, and `mri_slice` would be
    # stale from an earlier search (or undefined). Stop here.
    raise FileNotFoundError('No {} file in {} has slice {} under both {} and {}'.format(
        TENSOR_EXT, tensor_path, slice_idx, MRI_TO_SEGMENT, MRI_SEGMENTED))
Image('./mri_slice_example'+slice_idx+IMAGE_EXT)

In [None]:
# Display the segmentation-label image written for the same slice.
Image('./mri_label_example{}{}'.format(slice_idx, IMAGE_EXT))

In [None]:
# Tensorize: fake a `tensorize` command line, parse it with the project parser,
# and hand every parsed field to write_tensors positionally.
# NOTE(review): `--xml_field_id` is given with no value — confirm this matches
# the parser's definition of the option (and its nargs).
sys.argv = [
    'tensorize',
    '--tensors', './my_mri_tensors/',
    '--max_sample_id', '1050000',
    '--xml_field_id',
]
args = parse_args()
write_tensors(
    args.id, args.db, args.xml_folder, args.zip_folder, args.phenos_folder, args.output_folder,
    args.tensors, args.dicoms, args.volume_csv, args.lv_mass_csv, args.icd_csv,
    args.categorical_field_ids, args.continuous_field_ids, args.mri_field_ids, args.xml_field_ids,
    args.x, args.y, args.z,
    args.include_heart_zoom, args.zoom_x, args.zoom_y, args.zoom_width, args.zoom_height,
    args.write_pngs, args.min_sample_id, args.max_sample_id, args.min_values,
    args.ukbb7089_sample_id_to_hail_pkl_path, args.filtered_genotypes_array_path,
)

In [None]:
# Configure a short training run (1 epoch, 30 steps, batch 32) mapping
# 'mri_slice_weighted' inputs to 'mri_slice_segmented_weighted' outputs,
# with --inspect_model and --u_connect enabled, then build the data generators.
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/data/generated/tensors/test/2019-03-21/',
    '--input_tensors', 'mri_slice_weighted',
    '--output_tensors', 'mri_slice_segmented_weighted',
    '--batch_size', '32',
    '--epochs', '1',
    '--training_steps', '30',
    '--inspect_model',
    '--u_connect',
    '--id', 'mri_slice_labeler',
]
args = parse_args()
# NOTE(review): the helper's name lists "test" first; this unpacking assumes it
# returns (train, valid, test) in that order — confirm against ml4cvd.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    args.tensor_maps_in, args.tensor_maps_out, args.tensors, args.batch_size,
    args.valid_ratio, args.test_ratio, args.test_modulo, args.balance_csvs,
)

In [None]:
# Build the model, forwarding every parsed argument as a keyword argument.
model = make_multimodal_multitask_model(**vars(args))


In [None]:
# Predict one batch from the test generator with the model as built above
# (before any training in this notebook), and save the per-pixel argmax over
# the last axis of the first sample as an image.
# NOTE(review): the filename reuses `slice_idx` from the slice-search cells,
# but this prediction comes from the test generator, not from that slice.
# plt.imsave does not create missing parent directories, so ensure ./figures exists.
os.makedirs('./figures', exist_ok=True)
prediction = model.predict_generator(generate_test, steps=1)
plt.imsave('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT, np.argmax(prediction[0], axis=-1))
Image('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT)

In [None]:
# Train with the step/epoch/patience settings parsed from the `train` command
# line, and keep the model returned by the training helper.
model = train_model_from_generators(
    model, generate_train, generate_valid,
    args.training_steps, args.validation_steps, args.batch_size, args.epochs, args.patience,
    args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,
)

In [None]:
# Architecture diagram for the `mri_slice_labeler` run; presumably written by
# the training helper because --inspect_model was set — confirm.
Image('./recipes_output/mri_slice_labeler/'
      'architecture_graph_mri_slice_labeler.png')

In [None]:
# Run the model on `mri_slice` (set by the most recent slice-search cell).
# A trailing axis (presumably channels) and a leading batch axis are added
# before predicting. The argmax over the last output axis of the single
# sample is saved as an image.
# plt.imsave does not create missing parent directories, so ensure ./figures exists.
os.makedirs('./figures', exist_ok=True)
prediction = model.predict(np.expand_dims(np.expand_dims(mri_slice, axis=-1), axis=0))
plt.imsave('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT, np.argmax(prediction[0], axis=-1))
Image('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT)

In [None]:
# Longer run: 1500 training steps with model inspection off, rebuilding the
# model with args.model_file pointing at the earlier run's .hd5 file
# (presumably loaded by make_multimodal_multitask_model — confirm).
# The generators were already built with args.batch_size, so this assignment
# does not change their batch size.
args.batch_size = 32
args.training_steps = 1500
args.inspect_model = False
args.model_file = './recipes_output/mri_slice_labeler/mri_slice_labeler.hd5'
model = make_multimodal_multitask_model(**vars(args))
model = train_model_from_generators(
    model, generate_train, generate_valid,
    args.training_steps, args.validation_steps, args.batch_size, args.epochs, args.patience,
    args.output_folder, args.id, args.inspect_model, args.inspect_show_labels,
)

In [None]:
# Rebuild the model from the saved .hd5 file (presumably written by the
# training cells under output_folder/id — confirm), predict one test batch,
# and save the argmax over the last axis for the first sample.
# plt.imsave does not create missing parent directories, so ensure ./figures exists.
os.makedirs('./figures', exist_ok=True)
args.model_file = './recipes_output/mri_slice_labeler/mri_slice_labeler.hd5'
model = make_multimodal_multitask_model(**args.__dict__)
prediction = model.predict_generator(generate_test, steps=1)
plt.imsave('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT, np.argmax(prediction, axis=-1)[0])
Image('./figures/mri_prediction_example'+slice_idx+IMAGE_EXT)

In [None]:
# Run the project's test recipe with evaluation limited to a single test step.
args.test_steps = 1
test_multimodal_multitask(args)

In [None]:
# Precision-recall plot for the segmented output; presumably written by
# test_multimodal_multitask — confirm.
Image('./recipes_output/mri_slice_labeler/'
      'precision_recall_mri_slice_segmented.png')

In [None]:
# Per-class ROC plot for the segmented output; presumably written by
# test_multimodal_multitask — confirm.
Image('./recipes_output/mri_slice_labeler/'
      'per_class_roc_mri_slice_segmented.png')