In [None]:
# Imports
import os
import re
import sys
import math
import h5py
import glob
import pickle
import random
import logging
import hashlib
import operator
from textwrap import wrap
from functools import reduce
from datetime import datetime
from multiprocessing import Pool
from itertools import islice, product
from collections import Counter, OrderedDict, defaultdict
from typing import Iterable, DefaultDict, Dict, List, Tuple, Optional, Callable

from sklearn import manifold
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.metrics import brier_score_loss, precision_score, recall_score, f1_score
from sklearn.calibration import calibration_curve
import seaborn as sns
from biosppy.signals import ecg
from scipy.ndimage.filters import gaussian_filter
from scipy import stats

from ml4h.TensorMap import TensorMap
from ml4h.metrics import concordance_index, coefficient_of_determination
from ml4h.defines import IMAGE_EXT, JOIN_CHAR, PDF_EXT, TENSOR_EXT, ECG_REST_LEADS, PARTNERS_DATETIME_FORMAT, PARTNERS_DATE_FORMAT

import matplotlib
matplotlib.use('Agg')  # Need this to write images from the GSA servers.  Order matters:
import matplotlib.pyplot as plt  # First import matplotlib, then use Agg, then import plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import NullFormatter
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import AutoMinorLocator, MultipleLocator

# Axis captions for ROC / precision-recall style plots; each spells out the synonyms of the metric.
RECALL_LABEL = 'Recall | Sensitivity | True Positive Rate | TP/(TP+FN)'
FALLOUT_LABEL = 'Fallout | 1 - Specificity | False Positive Rate | FP/(FP+TN)'
PRECISION_LABEL = 'Precision | Positive Predictive Value | TP/(TP+FP)'

# Multiplier applied to the column and row counts to get the figsize of a subplot grid.
SUBPLOT_SIZE = 8

# Palette indexed by _hash_string_to_color; its length and order determine which color a string maps to.
# NOTE(review): 'orange' appears twice, so it is slightly more likely to be picked than other colors.
COLOR_ARRAY = [
    'tan', 'indigo', 'cyan', 'pink', 'purple', 'blue', 'chartreuse', 'deepskyblue', 'green', 'salmon', 'aqua', 'magenta', 'aquamarine', 'red',
    'coral', 'tomato', 'grey', 'black', 'maroon', 'hotpink', 'steelblue', 'orange', 'papayawhip', 'wheat', 'chocolate', 'darkkhaki', 'gold',
    'orange', 'crimson', 'slategray', 'violet', 'cadetblue', 'midnightblue', 'darkorchid', 'paleturquoise', 'plum', 'lime',
    'teal', 'peru', 'silver', 'darkgreen', 'rosybrown', 'firebrick', 'saddlebrown', 'dodgerblue', 'orangered',
]

import csv
import gzip
import h5py
import shutil
import zipfile
import pydicom
import numpy as np


# Keras imports
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import History
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import model_to_dot
from tensorflow.keras.layers import LeakyReLU, PReLU, ELU, ThresholdedReLU, Lambda, Reshape, LayerNormalization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from tensorflow.keras.layers import SpatialDropout1D, SpatialDropout2D, SpatialDropout3D, add, concatenate
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, LSTM, RepeatVector
from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D, UpSampling1D, UpSampling2D, UpSampling3D, MaxPooling1D
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D


from ml4h.defines import StorageType
from ml4h.arguments import parse_args, TMAPS, _get_tmap
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.tensor_generators import test_train_valid_tensor_generators, big_batch_from_minibatch_generator
from ml4h.models import train_model_from_generators, make_multimodal_multitask_model, _inspect_model, train_model_from_generators, make_hidden_layer_model
from ml4h.recipes import test_multimodal_multitask, train_multimodal_multitask, saliency_maps, _predict_and_evaluate

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec

# Constants
# Folder handed to the `--tensors` command line argument when the run arguments are built.
# NOTE(review): absolute, machine-specific path — presumably has to be edited to run anywhere else.
HD5_FOLDER = '/mnt/disks/ecg-rest-38k-tensors/2020-03-14/'

In [None]:
# Emulate a command line invocation, one flag group per line, so the arguments can be parsed
# from inside the notebook (presumably parse_args() reads sys.argv — confirm against ml4h.arguments).
sys.argv = (
    ['train']
    + ['--tensors', HD5_FOLDER]
    + ['--input_tensors', 'ecg_rest', 'genetic_caucasian']
    + ['--output_tensors', 'poor_data_quality']
    + ['--protected_tensors', 'sex', 'genetic_caucasian', 'age_0']
    + ['--training_steps', '96']
    + ['--validation_steps', '24']
    + ['--test_steps', '24']
    + ['--epochs', '6']
    + ['--batch_size', '24']
    + ['--id', 'ecg_rest_bias']
)

args = parse_args()

In [None]:
# Every parsed argument is forwarded as a keyword; the call yields the train, validation and test generators.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**vars(args))

In [None]:
# Build the model from the parsed arguments, then fit it with the train and validation generators.
model = make_multimodal_multitask_model(**vars(args))
train_model_from_generators(
    model,
    generate_train,
    generate_valid,
    args.training_steps,
    args.validation_steps,
    args.batch_size,
    args.epochs,
    args.patience,
    args.output_folder,
    args.id,
    args.inspect_model,
    args.inspect_show_labels,
)

In [None]:
def evaluate_predictions(
    tm: TensorMap, y_predictions: np.ndarray, y_truth: np.ndarray, protected: Dict[TensorMap, np.ndarray], title: str, folder: str, test_paths: List[str] = None,
    max_melt: int = 150000, rocs: Optional[List[Tuple[np.ndarray, np.ndarray, Dict[str, int]]]] = None,
    scatters: Optional[List[Tuple[np.ndarray, np.ndarray, str, List[str]]]] = None,
) -> Dict[str, float]:
    """ Evaluate predictions for a given TensorMap with truth data and plot the appropriate metrics.
    Accumulates data in the rocs list to facilitate subplotting.

    Only categorical TensorMaps for which tm.axes() == 1 are handled here; for any other TensorMap
    nothing is plotted and an empty dictionary is returned.

    :param tm: The TensorMap predictions to evaluate
    :param y_predictions: The predictions
    :param y_truth: The truth
    :param protected: TensorMaps and tensors sensitive to bias
    :param title: A title for the plots
    :param folder: The folder to save the plots at
    :param test_paths: The tensor paths that were predicted (not used by this function)
    :param max_melt: For multi-dimensional prediction the maximum number of prediction to allow in the flattened array
                     (not used by this function)
    :param rocs: (output) List of Tuples which are inputs for ROC curve plotting to allow subplotting downstream.
                 When omitted a new list is created for this call only.
    :param scatters: (output) List of Tuples which are inputs for scatter plots to allow subplotting downstream
                     (not populated by this function)
    :return: Dictionary of performance metrics with string keys for labels and float values
    """
    # A list literal as default would be created once and shared by every call that omits `rocs`,
    # so tuples appended by one evaluation would leak into the next. Create it per call instead.
    if rocs is None:
        rocs = []

    performance_metrics = {}
    if tm.is_categorical() and tm.axes() == 1:
        logging.info(f"For tm:{tm.name} with channel map:{tm.channel_map} examples:{y_predictions.shape[0]}")
        logging.info(f"\nSum Truth:{np.sum(y_truth, axis=0)} \nSum pred :{np.sum(y_predictions, axis=0)}")
        performance_metrics.update(subplot_roc_per_class(y_predictions, y_truth, tm.channel_map, protected,
                                                      title, folder))
        rocs.append((y_predictions, y_truth, tm.channel_map))

    return performance_metrics


def get_fpr_tpr_roc_pred(y_pred, test_truth, labels):
    """Compute the ROC curve and the area under it for every class in labels.

    :param y_pred: Predictions, indexed as y_pred[:, channel]
    :param test_truth: Truth, indexed as test_truth[:, channel]
    :param labels: Dictionary from class name to the channel (column) index of that class
    :return: Three dictionaries keyed by channel index: false positive rates, true positive rates, ROC areas
    """
    false_positive_rates, true_positive_rates, areas = {}, {}, {}

    for channel in labels.values():
        fp_rate, tp_rate, _thresholds = roc_curve(test_truth[:, channel], y_pred[:, channel])
        false_positive_rates[channel] = fp_rate
        true_positive_rates[channel] = tp_rate
        areas[channel] = auc(fp_rate, tp_rate)

    return false_positive_rates, true_positive_rates, areas
def _hash_string_to_color(string):
    """Pick an entry of COLOR_ARRAY from the SHA-1 digest of a string.

    hashlib is used rather than the built-in hash so the same string gets the same color in every run.
    """
    digest = hashlib.sha1(string.encode('utf-8')).hexdigest()
    palette_index = int(digest, 16) % len(COLOR_ARRAY)
    return COLOR_ARRAY[palette_index]


def _text_on_plot(axes, x, y, text, alpha=0.8, background='white'):
    t = axes.text(x, y, text)
    t.set_bbox({'facecolor': background, 'alpha': alpha, 'edgecolor': background})


    

def new_predict_and_evaluate(model, test_data, test_labels, tensor_maps_in, tensor_maps_out,
                             tensor_maps_protected, batch_size, hidden_layer, plot_path,
                             test_paths, embed_visualization, alpha):
    """Predict on the test data and evaluate every output TensorMap that matches a model layer.

    NOTE(review): sample_from_language_model, subplot_rocs, subplot_scatters and _tsne_wrapper are not
    defined or imported in the visible part of this notebook — confirm they are in scope before relying
    on the branches that call them.

    :param model: Model providing `layers` and `predict`
    :param test_data: Inputs handed to model.predict
    :param test_labels: Dictionary from output name to truth for output and protected TensorMaps
    :param tensor_maps_in: Input TensorMaps, forwarded to the language sampling and t-SNE helpers
    :param tensor_maps_out: Output TensorMaps, paired in order with the model predictions
    :param tensor_maps_protected: TensorMaps sensitive to bias; their truth is looked up in test_labels
    :param batch_size: Batch size used for prediction
    :param hidden_layer: Forwarded to the t-SNE helper
    :param plot_path: Folder the plots are written to
    :param test_paths: The tensor paths that were predicted
    :param embed_visualization: The t-SNE helper is only called when this equals "tsne"
    :param alpha: Forwarded to the t-SNE helper
    :return: Dictionary of performance metrics merged over all evaluated output TensorMaps
    """
    model_layer_names = [layer.name for layer in model.layers]
    metrics = {}
    scatter_inputs = []
    roc_inputs = []

    protected_data = {}
    for protected_tm in tensor_maps_protected:
        protected_data[protected_tm] = test_labels[protected_tm.output_name()]
    print(f'tm prot {len(protected_data)}')

    predictions = model.predict(test_data, batch_size=batch_size)
    # When models have a single output model.predict returns a ndarray otherwise it returns a list
    single_output = not isinstance(predictions, list)
    for output_prediction, output_tm in zip(predictions, tensor_maps_out):
        output_name = output_tm.output_name()
        if output_name not in model_layer_names:
            continue
        if single_output:
            output_prediction = predictions
        output_truth = np.array(test_labels[output_tm.output_name()])
        output_metrics = evaluate_predictions(
            output_tm, output_prediction, output_truth, protected_data, output_tm.name, plot_path,
            test_paths, rocs=roc_inputs, scatters=scatter_inputs,
        )
        metrics.update(output_metrics)
        if output_tm.is_language():
            sample_from_language_model(tensor_maps_in, output_tm, model, test_data, max_samples=16)

    # Combined figures are only drawn when more than one output contributed.
    if len(roc_inputs) > 1:
        subplot_rocs(roc_inputs, plot_path)
    if len(scatter_inputs) > 1:
        subplot_scatters(scatter_inputs, plot_path)

    test_labels_1d = {}
    for output_tm in tensor_maps_out:
        if output_tm.output_name() in test_labels:
            test_labels_1d[output_tm] = np.array(test_labels[output_tm.output_name()])
    if embed_visualization == "tsne":
        _tsne_wrapper(
            model, hidden_layer, alpha, plot_path, test_paths, test_labels_1d,
            test_data=test_data, tensor_maps_in=tensor_maps_in, batch_size=batch_size,
        )

    return metrics

In [None]:
def subplot_roc_per_class(prediction, truth, labels, protected, title, prefix='./figures/'):
    """Plot per-class ROC curves for all samples and for one subgroup of every protected TensorMap.

    Panel (0, 0) of the figure shows the ROC curves computed on all samples. Every protected TensorMap
    gets its own panel showing the ROC curves restricted to a subgroup:
      * categorical: the samples whose first channel of the protected tensor equals 1
      * continuous: the samples whose protected value is above the median ("Highest")
    Protected TensorMaps that are neither categorical nor continuous get an empty panel.
    The figure is saved as 'per_class_roc_' + title + IMAGE_EXT under prefix and then closed.

    :param prediction: Predictions, indexed as prediction[:, channel] and row-masked per subgroup
    :param truth: Truth with the same layout as prediction
    :param labels: Dictionary from class name to the channel (column) index of that class
    :param protected: Dictionary from protected TensorMap to its tensor, indexed as protected[p][:, 0]
    :param title: Title of the overall panel, also part of the file name
    :param prefix: Folder the figure is written to
    :return: Dictionary from class name to the ROC area computed on all samples
    """
    lw = 2
    labels_to_areas = {}
    true_sums = np.sum(truth, axis=0)
    total_plots = len(protected) + 1
    cols = max(2, int(math.ceil(math.sqrt(total_plots))))
    rows = max(2, int(math.ceil(total_plots / cols)))
    fig, axes = plt.subplots(rows, cols, figsize=(cols*SUBPLOT_SIZE, rows*SUBPLOT_SIZE))
    fpr, tpr, roc_auc = get_fpr_tpr_roc_pred(prediction, truth, labels)

    # Panel (0, 0) is reserved for the overall curves, so protected panels start at (1, 0)
    # and fill the grid column by column.
    row, col = 1, 0
    for p in protected:
        ax = axes[row, col]
        ax.plot([0, 1], [0, 1], 'k:', lw=0.5)
        ax.set_title(f'Protected {p.name}')
        logging.debug(
            f'Protected {p.name}: truth shape {truth.shape}, tensor map shape {p.shape}, '
            f'protected data shape {protected[p].shape}',
        )

        # The subgroup does not depend on the class, so select it once per protected TensorMap.
        protected_indexes = None
        group_text = ''
        if p.is_categorical():
            protected_indexes = protected[p][:, 0] == 1
        elif p.is_continuous():
            threshold = np.median(protected[p])
            protected_indexes = (protected[p] > threshold)[:, 0]
            group_text = ' Highest '
            logging.debug(f'Protected {p.name}: median {threshold}')

        if protected_indexes is not None:
            group_size = np.sum(protected_indexes)
            if group_size == 0:
                # A ROC curve can not be computed without samples; leave the panel without curves.
                logging.warning(f'Protected {p.name} selects no samples, skipping its ROC curves.')
            else:
                pfpr, ptpr, proc_auc = get_fpr_tpr_roc_pred(
                    prediction[protected_indexes], truth[protected_indexes], labels,
                )
                for key in labels:
                    label_text = f'{key} roc={proc_auc[labels[key]]:.3f}{group_text} n={group_size:.0f}'
                    color = _hash_string_to_color(p.name+key)
                    ax.plot(pfpr[labels[key]], ptpr[labels[key]], color=color, lw=lw, label=label_text)

        # Same axis range and captions for every protected panel, whatever the TensorMap kind.
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([-0.02, 1.03])
        ax.set_ylabel(RECALL_LABEL)
        ax.set_xlabel(FALLOUT_LABEL)
        ax.legend(loc='lower right')

        row += 1
        if row == rows:
            row = 0
            col += 1
            if col >= cols:
                break

    overall_ax = axes[0, 0]
    for key in labels:
        labels_to_areas[key] = roc_auc[labels[key]]
        # For binary tasks the curve of the negative class is redundant, so only its area is reported.
        if 'no_' in key and len(labels) == 2:
            continue
        color = _hash_string_to_color(key)
        label_text = f'{key} area: {roc_auc[labels[key]]:.3f} n={true_sums[labels[key]]:.0f}'
        overall_ax.plot(fpr[labels[key]], tpr[labels[key]], color=color, lw=lw, label=label_text)
        logging.info(f'ROC Label {label_text} Truth shape {truth.shape}, true sums {true_sums}')

    overall_ax.set_title(f'ROC {title} n={truth.shape[0]:.0f}\n')
    overall_ax.set_xlim([0.0, 1.0])
    overall_ax.set_ylim([-0.02, 1.03])
    overall_ax.set_ylabel(RECALL_LABEL)
    overall_ax.set_xlabel(FALLOUT_LABEL)
    overall_ax.legend(loc='lower right')

    figure_path = os.path.join(prefix, 'per_class_roc_' + title + IMAGE_EXT)
    figure_folder = os.path.dirname(figure_path)
    if figure_folder:  # Empty when the figure goes to the working directory; makedirs('') would raise
        os.makedirs(figure_folder, exist_ok=True)
    fig.savefig(figure_path, bbox_inches='tight')
    # Close instead of clear: a cleared figure stays registered with pyplot and accumulates over calls.
    plt.close(fig)
    logging.info(f"Saved ROC curve at: {figure_path} with {len(protected)} protected TensorMaps.")
    return labels_to_areas

In [None]:
# Gather the requested number of test batches into one big batch, then predict and plot under
# the run-specific output folder.
out_path = os.path.join(args.output_folder, args.id + '/')
test_data, test_labels, test_paths = big_batch_from_minibatch_generator(generate_test, args.test_steps)
new_predict_and_evaluate(
    model,
    test_data,
    test_labels,
    args.tensor_maps_in,
    args.tensor_maps_out,
    args.tensor_maps_protected,
    args.batch_size,
    args.hidden_layer,
    out_path,
    test_paths,
    args.embed_visualization,
    args.alpha,
)