In [1]:
import os
import sys
import math
import argparse
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from sklearn.decomposition import PCA
from typing import Dict, List, Tuple, Iterable, Union, Optional, Set, Sequence, Callable, DefaultDict, Any

# Keras imports
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import LeakyReLU, PReLU, ELU, ThresholdedReLU, Lambda, Reshape, LayerNormalization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from tensorflow.keras.layers import SpatialDropout1D, SpatialDropout2D, SpatialDropout3D, add, concatenate
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, LSTM, RepeatVector
from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D, UpSampling1D, UpSampling2D, UpSampling3D, MaxPooling1D
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Concatenate, Add
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalAveragePooling3D


# ML4CVD Imports
from ml4cvd.TensorMap import TensorMap
from ml4cvd.arguments import parse_args
from ml4cvd.models import make_multimodal_multitask_model, train_model_from_generators, make_hidden_layer_model, _conv_layer_from_kind_and_dimension
from ml4cvd.tensor_generators import TensorGenerator, big_batch_from_minibatch_generator, test_train_valid_tensor_generators
from ml4cvd.recipes import plot_predictions, infer_hidden_layer_multimodal_multitask

# IPython imports
%matplotlib inline
import matplotlib.pyplot as plt


Tensor = tf.Tensor

# activation name -> Keras advanced-activation layer. Every entry is a pre-built layer instance.
# NOTE(review): a single instance reused at several call sites is one shared layer (PReLU would share its
# weights); callers should build a fresh layer per use rather than handing these objects out directly.
ACTIVATION_CLASSES = {
    'leaky': LeakyReLU(),
    'prelu': PReLU(),
    'elu': ELU(),
    # Instantiated like the other entries: the bare class, when "applied" to a tensor, would construct a new
    # layer (with the tensor as `theta`) instead of producing an activation.
    'thresh_relu': ThresholdedReLU(),
}
# activation name -> plain callable, to be wrapped in a Keras `Activation` layer
ACTIVATION_FUNCTIONS = {
    'swish': tf.nn.swish,
    'gelu': tfa.activations.gelu,
    'lisht': tfa.activations.lisht,
    'mish': tfa.activations.mish,
}
# normalization name -> layer class (instantiated with default arguments by the caller)
NORMALIZATION_CLASSES = {
    'batch_norm': BatchNormalization,
    'layer_norm': LayerNormalization,
    'instance_norm': tfa.layers.InstanceNormalization,
    'poincare_norm': tfa.layers.PoincareNormalize,
}
CONV_REGULARIZATION_CLASSES = {
    # class name -> (dimension -> class)
    'spatial_dropout': {2: SpatialDropout1D, 3: SpatialDropout2D, 4: SpatialDropout3D},
    # defaultdict calls its factory with NO arguments; the previous one-argument lambda raised TypeError on lookup.
    'dropout': defaultdict(lambda: Dropout),
}
DENSE_REGULARIZATION_CLASSES = {
    'dropout': Dropout,  # TODO: add l1, l2
}

In [2]:

def _activation_layer(activation: str) -> Activation:
    """
    Build a fresh activation layer for the name ``activation``.

    Lookup order:
      1. ``ACTIVATION_CLASSES``. An entry may be a layer class, which is instantiated, or an already-built layer
         instance, which is cloned from its config. Either way every call returns a NEW layer, so call sites never
         share one layer object (and therefore never share weights, e.g. PReLU's alpha).
      2. ``ACTIVATION_FUNCTIONS``: a plain callable, wrapped in ``Activation``.
      3. Otherwise ``activation`` itself is passed to ``Activation`` (e.g. a built-in Keras name like 'relu').

    :param activation: activation name
    :return: a Keras layer applying the activation
    """
    entry = ACTIVATION_CLASSES.get(activation, None)
    if entry is None:
        return Activation(ACTIVATION_FUNCTIONS.get(activation, None) or activation)
    if isinstance(entry, type):
        return entry()
    config = entry.get_config()
    # Drop the name so Keras assigns a unique one; duplicate layer names are rejected when building a Model.
    config.pop('name', None)
    return type(entry).from_config(config)


def _normalization_layer(norm: str) -> Layer:
    """
    Return a new normalization layer for ``norm``, or an identity function when ``norm`` is falsy.

    The layer class comes from ``NORMALIZATION_CLASSES`` and is built with default arguments; an unknown
    name raises ``KeyError``.
    """
    if norm:
        layer_class = NORMALIZATION_CLASSES[norm]
        return layer_class()
    return lambda x: x


def _regularization_layer(dimension: int, regularization_type: str, rate: float):
    if not regularization_type:
        return lambda x: x
    if regularization_type in DENSE_REGULARIZATION_CLASSES:
        return DENSE_REGULARIZATION_CLASSES[regularization_type](rate)
    return CONV_REGULARIZATION_CLASSES[regularization_type][dimension](rate)


def _calc_start_shape(
        num_upsamples: int, output_shape: Tuple[int, ...], upsample_rates: Sequence[int], channels: int,
) -> Tuple[int, ...]:
    """
    Given the number of blocks in the decoder and the upsample rates, return required input shape to get to output shape
    """
    upsample_rates = list(upsample_rates) + [1] * len(output_shape)
    return tuple((shape // rate**num_upsamples for shape, rate in zip(output_shape[:-1], upsample_rates))) + (channels,)




class FlatToStructure:
    """
    Map a flat tensor onto ``output_shape``.

    The pipeline is Dense(prod(output_shape)) -> activation -> normalization -> Reshape(output_shape).
    """
    def __init__(
            self,
            output_shape: Tuple[int, ...],
            activation: str,
            normalization: str,
    ):
        # NOTE(review): despite its name this attribute holds the *output* shape; kept for compatibility.
        self.input_shapes = output_shape
        flat_units = int(np.prod(output_shape))
        self.dense = Dense(units=flat_units)
        self.activation = _activation_layer(activation)
        self.reshape = Reshape(output_shape)
        self.norm = _normalization_layer(normalization)

    def __call__(self, x: Tensor) -> Tensor:
        hidden = self.dense(x)
        hidden = self.activation(hidden)
        hidden = self.norm(hidden)
        return self.reshape(hidden)


def _conv_layer_from_kind_and_dimension(
        dimension: int, conv_layer_type: str, conv_x: List[int], conv_y: List[int], conv_z: List[int],
) -> Tuple[Layer, List[Tuple[int, ...]]]:
    """
    Pick a convolution layer CLASS for ``(dimension, conv_layer_type)`` and zip per-layer kernel sizes.

    The kernel tuples are built from the first 1, 2 or 3 of ``conv_x``/``conv_y``/``conv_z``, depending on
    the layer's spatial rank. Unsupported combinations raise ``ValueError``.

    NOTE(review): this redefines (shadows) the ``_conv_layer_from_kind_and_dimension`` imported from
    ``ml4cvd.models`` in the first cell. The annotation says ``Layer``, but a class is returned.
    """
    # (dimension, type) -> (layer class, number of spatial kernel axes)
    layer_table = {
        (4, 'conv'): (Conv3D, 3),
        (3, 'conv'): (Conv2D, 2),
        (2, 'conv'): (Conv1D, 1),
        (3, 'separable'): (SeparableConv2D, 2),
        (2, 'separable'): (SeparableConv1D, 1),
        (3, 'depth'): (DepthwiseConv2D, 2),
    }
    key = (dimension, conv_layer_type)
    if key not in layer_table:
        raise ValueError(f'Unknown convolution type: {conv_layer_type} for dimension: {dimension}')
    conv_layer, n_axes = layer_table[key]
    per_axis_sizes = (conv_x, conv_y, conv_z)[:n_axes]
    return conv_layer, list(zip(*per_axis_sizes))


def _upsampler(dimension, pool_x, pool_y, pool_z):
    if dimension == 4:
        return UpSampling3D(size=(pool_x, pool_y, pool_z))
    elif dimension == 3:
        return UpSampling2D(size=(pool_x, pool_y))
    elif dimension == 2:
        return UpSampling1D(size=pool_x)
    

    
def _one_by_n_kernel(dimension):
    return tuple([1] * (dimension - 1))


class DenseConvolutionalBlock:
    """
    A densely connected stack of convolutions.

    Each layer computes ``normalize(regularize(activate(conv(x))))``. For every layer except the last, the
    output is concatenated with the block input and all earlier outputs, and that concatenation feeds the next
    layer. The block returns the last layer's output without concatenation.
    """
    def __init__(
            self,
            *,
            dimension: int,
            block_size: int,
            conv_layer_type: str,
            filters: int,
            conv_x: List[int],
            conv_y: List[int],
            conv_z: List[int],
            activation: str,
            normalization: str,
            regularization: str,
            regularization_rate: float,
    ):
        conv_layer, kernels = _conv_layer_from_kind_and_dimension(dimension, conv_layer_type, conv_x, conv_y, conv_z)
        # conv_layer is a class, not an instance. isinstance() was therefore always False, and DepthwiseConv2D
        # was built with an unsupported `filters` argument. DepthwiseConv2D takes no `filters`.
        if issubclass(conv_layer, DepthwiseConv2D):
            self.conv_layers = [conv_layer(kernel_size=kernel, padding='same') for kernel in kernels]
        else:
            self.conv_layers = [conv_layer(filters=filters, kernel_size=kernel, padding='same') for kernel in kernels]
        self.activations = [_activation_layer(activation) for _ in range(block_size)]
        self.normalizations = [_normalization_layer(normalization) for _ in range(block_size)]
        self.regularizations = [_regularization_layer(dimension, regularization, regularization_rate) for _ in range(block_size)]
        print(f'Dense Block Convolutional Layers (num_filters, kernel_size): {list(zip([filters]*len(kernels), kernels))}')

    def __call__(self, x: Tensor) -> Tensor:
        dense_connections = [x]
        # zip truncates to min(len(kernels), block_size); the "last layer" test must use that same length,
        # otherwise the final layer's output could be concatenated when kernels outnumber block_size.
        layers = list(zip(self.conv_layers, self.activations, self.normalizations, self.regularizations))
        last_index = len(layers) - 1
        for i, (convolve, activate, normalize, regularize) in enumerate(layers):
            x = normalize(regularize(activate(convolve(x))))
            if i < last_index:  # the final layer's output is returned as-is, not concatenated
                dense_connections.append(x)
                # Pass a copy: the list keeps growing after this call ("[:] is necessary because of tf weirdness").
                x = Concatenate()(dense_connections[:])
        return x

    
class ConvDecoder2:
    """
    Convolutional decoder: (upsample -> dense block) once per entry of ``filters_per_dense_block``, then a
    kernel-size-1 convolution.

    The final convolution produces ``tensor_map_out.shape[-1]`` channels with ``tensor_map_out.activation``
    and is named ``tensor_map_out.output_name()``.
    """
    def __init__(
            self,
            *,
            tensor_map_out: TensorMap,
            filters_per_dense_block: List[int],
            conv_layer_type: str,
            conv_x: List[int],
            conv_y: List[int],
            conv_z: List[int],
            block_size: int,
            activation: str,
            normalization: str,
            regularization: str,
            regularization_rate: float,
            upsample_x: int,
            upsample_y: int,
            upsample_z: int,
    ):
        dimension = tensor_map_out.axes()
        # One dense block per (filters, x, y, z); zip truncates to the shortest of the four lists.
        self.dense_blocks = [
            DenseConvolutionalBlock(
                dimension=dimension, conv_layer_type=conv_layer_type, filters=filters, conv_x=[x]*block_size,
                conv_y=[y]*block_size, conv_z=[z]*block_size, block_size=block_size, activation=activation, normalization=normalization,
                regularization=regularization, regularization_rate=regularization_rate,
            )
            for filters, x, y, z in zip(filters_per_dense_block, conv_x, conv_y, conv_z)
        ]
        conv_layer, _ = _conv_layer_from_kind_and_dimension(dimension, 'conv', conv_x, conv_y, conv_z)
        self.conv_label = conv_layer(tensor_map_out.shape[-1], _one_by_n_kernel(dimension), activation=tensor_map_out.activation, name=tensor_map_out.output_name())
        # Exactly one upsampler per dense block. The previous len + 1 layers left one that __call__ never used.
        self.upsamples = [_upsampler(dimension, upsample_x, upsample_y, upsample_z) for _ in self.dense_blocks]
        print(f'Decode has: {list(enumerate(zip(self.dense_blocks, self.upsamples)))}')

    def __call__(self, x: Tensor) -> Tensor:
        for dense_block, upsample in zip(self.dense_blocks, self.upsamples):
            x = upsample(x)
            x = dense_block(x)
        return self.conv_label(x)
    
    

In [3]:
def make_paired_autoencoder_model(
    pairs: List[Tuple[TensorMap, TensorMap]],
    **kwargs
) -> Model:
    """
    Build and compile a paired autoencoder.

    For every ``(left, right)`` pair:
      * An encoder is cut at ``args.hidden_layer`` out of a ``make_multimodal_multitask_model`` model built
        for each side.
      * A ``paired_<left>_<right>`` output gives the mean squared difference between the two embeddings. Its
        target TensorMap always yields zeros.
    All embeddings are then concatenated and decoded into every original output TensorMap.

    NOTE(review): configuration is read from the notebook-global ``args``, not from ``kwargs``, and ``args``
    is MUTATED:
      * ``args.tensor_maps_in`` is rebound per side while the encoders are built.
      * On return, ``args.tensor_maps_out`` has the new ``paired_*`` TensorMaps appended.
    The calling cell relies on this to build generators that yield targets for the paired outputs. Calling
    this twice on the same ``args`` is therefore not idempotent.

    :param pairs: TensorMaps whose embeddings should be pulled together; each must be in ``args.tensor_maps_in``
    :param kwargs: only ``model_layers`` is read: an optional weights file, loaded by layer name
    :return: the compiled Keras model
    """
    inputs = {tm: Input(shape=tm.shape, name=tm.input_name()) for tm in args.tensor_maps_in}
    # De-duplicated, insertion-ordered snapshot of the outputs, taken before any mutation of args.
    original_outputs = list(dict.fromkeys(args.tensor_maps_out))
    multimodal_activations = []
    desired_distance_tm = []
    outputs = []
    losses = []
    for left, right in pairs:
        args.tensor_maps_in = [left]
        left_model = make_multimodal_multitask_model(**args.__dict__)
        encode_left = make_hidden_layer_model(left_model, [left], args.hidden_layer)
        h_left = encode_left(inputs[left])

        args.tensor_maps_in = [right]
        right_model = make_multimodal_multitask_model(**args.__dict__)
        encode_right = make_hidden_layer_model(right_model, [right], args.hidden_layer)
        h_right = encode_right(inputs[right])

        # The target for the paired output is always 0, presumably so the two embeddings are driven together.
        tff = lambda tm, hd5, d: np.zeros((1,))
        tm0 = TensorMap(f'paired_{left.name}_{right.name}', shape=(1,), tensor_from_file=tff)
        desired_distance_tm.append(tm0)

        # Mean over the last axis of the squared difference. If the embeddings are (batch, width), this is the
        # squared L2 distance divided by width (not the plain L2 distance).
        l2_layer = Lambda(lambda tensors: K.mean(K.square(tensors[0] - tensors[1]), axis=-1, keepdims=True), name=tm0.output_name())
        l2_distance = l2_layer([h_left, h_right])
        outputs.append(l2_distance)
        # NOTE(review): binary cross-entropy on an unbounded distance against a 0 target saturates once the
        # distance reaches ~1 (Keras clips predictions). A regression loss may be intended; confirm.
        losses.append('binary_crossentropy')
        multimodal_activations.extend([h_left, h_right])

    multimodal_activation = Concatenate()(multimodal_activations)

    for tm in original_outputs:
        shape = _calc_start_shape(num_upsamples=len(args.dense_blocks), output_shape=tm.shape,
                                  upsample_rates=[args.pool_x, args.pool_y, args.pool_z],
                                  channels=args.dense_blocks[-1])

        restructure = FlatToStructure(output_shape=shape, activation=args.activation,
                                      normalization=args.dense_normalize)

        decode = ConvDecoder2(
            tensor_map_out=tm,
            filters_per_dense_block=args.dense_blocks[::-1],
            conv_layer_type=args.conv_type,
            conv_x=args.conv_x,
            conv_y=args.conv_y,
            conv_z=args.conv_z,
            block_size=args.block_size,
            activation=args.activation,
            normalization=args.conv_normalize,
            regularization=args.conv_regularize,
            regularization_rate=args.conv_regularize_rate,
            upsample_x=args.pool_x,
            upsample_y=args.pool_y,
            upsample_z=args.pool_z,
        )

        outputs.append(decode(restructure(multimodal_activation)))
        losses.append(tm.loss)

    # Publish the extended outputs so generators built from args also yield targets for the paired outputs.
    args.tensor_maps_out = original_outputs + desired_distance_tm
    args.tensor_maps_in = list(inputs.keys())

    # `learning_rate=` replaces the deprecated `lr=` alias, which newer Keras optimizers reject.
    # NOTE(review): the rate is hard-coded; --learning_rate (args.learning_rate) is ignored here. Confirm intent.
    opt = Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    m = Model(inputs=list(inputs.values()), outputs=outputs)
    m.compile(optimizer=opt, loss=losses)
    m.summary()

    model_layers = kwargs.get('model_layers')
    if model_layers is not None:
        m.load_weights(model_layers, by_name=True)
        print(f"Loaded model weights from:{model_layers}")

    return m

In [None]:
# Train the paired autoencoder on the 2-chamber and 3-chamber long-axis diastolic slices.
# parse_args() reads sys.argv, so the run configuration is written out as a command line (flag order preserved).
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/segmented-sax-lax/2020-07-07/',
    '--input_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d',
    '--output_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d',
    '--activation', 'swish',
    '--conv_layers', '32',
    '--conv_x', '3', '3', '3', '3',
    '--conv_y', '3', '3', '3', '3',
    '--conv_z', '3', '3', '3', '3',
    '--dense_blocks', '32', '24',
    '--block_size', '3',
    '--dense_layers', '392',
    '--pool_x', '2', '--pool_y', '2',
    '--batch_size', '1',
    '--patience', '32', '--epochs', '24',
    '--learning_rate', '0.001',
    '--training_steps', '256', '--validation_steps', '30', '--test_steps', '2',
    '--num_workers', '4',
    '--inspect_model',
    '--tensormap_prefix', 'ml4cvd.tensormap.ukb.mri',
    # NOTE(review): the id ends in "_relu" but --activation above is swish; the id names the output folder.
    '--id', 'lax_2ch_3ch_diastole_paired_autoencoder_relu',
]
args = parse_args()

# Pair the first two input TensorMaps. make_paired_autoencoder_model appends the paired-distance TensorMaps
# to args.tensor_maps_out, so the generators below must be built AFTER it runs.
pairs = [(args.tensor_maps_in[0], args.tensor_maps_in[1])]
overparameterized_model = make_paired_autoencoder_model(pairs, **args.__dict__)
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
train_model_from_generators(
    overparameterized_model,
    generate_train, generate_valid,
    args.training_steps, args.validation_steps, args.batch_size,
    args.epochs, args.patience,
    args.output_folder, args.id,
    args.inspect_model, args.inspect_show_labels,
)

2020-08-17 14:21:16,656 - logger:25 - INFO - Logging configuration was loaded. Log messages can be found at ./recipes_output/lax_2ch_3ch_diastole_paired_autoencoder_relu/log_2020-08-17_14-21_0.log.
2020-08-17 14:21:16,659 - arguments:414 - INFO - Command Line was: 
./scripts/tf.sh train --tensors /mnt/disks/segmented-sax-lax/2020-07-07/ --input_tensors lax_2ch_diastole_slice0_3d lax_3ch_diastole_slice0_3d --output_tensors lax_2ch_diastole_slice0_3d lax_3ch_diastole_slice0_3d --activation swish --conv_layers 32 --conv_x 3 3 3 3 --conv_y 3 3 3 3 --conv_z 3 3 3 3 --dense_blocks 32 24 --block_size 3 --dense_layers 392 --pool_x 2 --pool_y 2 --batch_size 1 --patience 32 --epochs 24 --learning_rate 0.001 --training_steps 256 --validation_steps 30 --test_steps 2 --num_workers 4 --inspect_model --tensormap_prefix ml4cvd.tensormap.ukb.mri --id lax_2ch_3ch_diastole_paired_autoencoder_relu

2020-08-17 14:21:16,672 - models:379 - INFO - Residual Block Convolutional Layers (num_filters, kernel_size)

2020-08-17 14:21:17,654 - models:379 - INFO - Residual Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3))]
2020-08-17 14:21:17,660 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3)), (32, (3, 3)), (32, (3, 3))]
2020-08-17 14:21:17,665 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(24, (3, 3)), (24, (3, 3)), (24, (3, 3))]
2020-08-17 14:21:17,679 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(24, (3, 3)), (24, (3, 3)), (24, (3, 3))]
2020-08-17 14:21:17,683 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3)), (32, (3, 3)), (32, (3, 3))]
2020-08-17 14:21:17,691 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(24, (3, 3)), (24, (3, 3)), (24, (3, 3))]
2020-08-17 14:21:17,696 - models:414 - INFO - Dense Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3)), (32

Dense Block Convolutional Layers (num_filters, kernel_size): [(24, (3, 3)), (24, (3, 3)), (24, (3, 3))]
Dense Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3)), (32, (3, 3)), (32, (3, 3))]
Decode has: [(0, (<__main__.DenseConvolutionalBlock object at 0x7ff67033a198>, <tensorflow.python.keras.layers.convolutional.UpSampling2D object at 0x7ff670346358>)), (1, (<__main__.DenseConvolutionalBlock object at 0x7ff67033a9e8>, <tensorflow.python.keras.layers.convolutional.UpSampling2D object at 0x7ff670346438>))]
Dense Block Convolutional Layers (num_filters, kernel_size): [(24, (3, 3)), (24, (3, 3)), (24, (3, 3))]
Dense Block Convolutional Layers (num_filters, kernel_size): [(32, (3, 3)), (32, (3, 3)), (32, (3, 3))]
Decode has: [(0, (<__main__.DenseConvolutionalBlock object at 0x7ff67030e588>, <tensorflow.python.keras.layers.convolutional.UpSampling2D object at 0x7ff670296748>)), (1, (<__main__.DenseConvolutionalBlock object at 0x7ff67030edd8>, <tensorflow.python.keras.layer

2020-08-17 14:21:22,578 - tensor_generators:661 - INFO - Found 31871 train, 9242 validation, and 4569 testing tensors at: /mnt/disks/segmented-sax-lax/2020-07-07/
2020-08-17 14:21:22,935 - models:1316 - INFO - Saving architecture diagram to:./recipes_output/lax_2ch_3ch_diastole_paired_autoencoder_relu/architecture_graph_lax_2ch_3ch_diastole_paired_autoencoder_relu.png
2020-08-17 14:21:23,970 - tensor_generators:151 - INFO - Started 3 train workers with cache size 0.875GB.
2020-08-17 14:21:24,367 - tensor_generators:151 - INFO - Started 1 validation workers with cache size 0.875GB.
Train for 256 steps, validate for 1 steps
2020-08-17 14:21:50,301 - models:1254 - INFO - Spent:26.48 seconds training, Samples trained on:256 Per sample training speed:0.103 seconds.
2020-08-17 14:21:52,056 - models:1260 - INFO - Spent:1.75 seconds predicting, Samples inferred:256 Per sample inference speed:0.0548 seconds.
Train for 256 steps, validate for 30 steps
Epoch 1/24
Epoch 00001: val_loss improved fr

In [5]:
# Plot predictions of a previously saved paired-autoencoder model (loaded via --model_file), including LVM.
# NOTE(review): this cell's log (13:52) predates the training cell's run (14:21). The notebook was executed
# out of order, and the .h5 loaded here comes from an earlier run with a different configuration.
sys.argv = [
    'train',
    '--tensors', '/mnt/disks/segmented-sax-lax/2020-07-07/',
    '--input_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d',
    '--output_tensors', 'lax_2ch_diastole_slice0_3d', 'lax_3ch_diastole_slice0_3d', 'LVM',
    '--activation', 'swish',
    '--conv_layers', '24',
    '--conv_x', '3', '3', '3',
    '--conv_y', '3', '3', '3',
    '--conv_z', '3', '3', '3',
    '--dense_blocks', '24',
    '--block_size', '4',
    '--dense_layers', '92',
    '--pool_x', '2', '--pool_y', '2',
    '--batch_size', '2',
    '--patience', '32', '--epochs', '292',
    '--learning_rate', '0.001',
    '--training_steps', '256', '--validation_steps', '30', '--test_steps', '2',
    '--num_workers', '4',
    '--inspect_model',
    '--model_file', './recipes_output/lax_2ch_3ch_diastole_paired_autoencoder/lax_2ch_3ch_diastole_paired_autoencoder.h5',
    '--tensormap_prefix', 'ml4cvd.tensormap.ukb.mri',
    '--id', 'lax_2ch_3ch_diastole_paired_autoencoder',
]
args = parse_args()
plot_predictions(args)

2020-08-17 13:52:58,487 - logger:25 - INFO - Logging configuration was loaded. Log messages can be found at ./recipes_output/lax_2ch_3ch_diastole_paired_autoencoder/log_2020-08-17_13-52_0.log.
2020-08-17 13:52:58,489 - arguments:414 - INFO - Command Line was: 
./scripts/tf.sh train --tensors /mnt/disks/segmented-sax-lax/2020-07-07/ --input_tensors lax_2ch_diastole_slice0_3d lax_3ch_diastole_slice0_3d --output_tensors lax_2ch_diastole_slice0_3d lax_3ch_diastole_slice0_3d LVM --activation swish --conv_layers 24 --conv_x 3 3 3 --conv_y 3 3 3 --conv_z 3 3 3 --dense_blocks 24 --block_size 4 --dense_layers 92 --pool_x 2 --pool_y 2 --batch_size 2 --patience 32 --epochs 292 --learning_rate 0.001 --training_steps 256 --validation_steps 30 --test_steps 2 --num_workers 4 --inspect_model --model_file ./recipes_output/lax_2ch_3ch_diastole_paired_autoencoder/lax_2ch_3ch_diastole_paired_autoencoder.h5 --tensormap_prefix ml4cvd.tensormap.ukb.mri --id lax_2ch_3ch_diastole_paired_autoencoder

2020-08-17

2020-08-17 13:53:03,527 - tensor_generators:151 - INFO - Started 4 test workers with cache size 0.0GB.
2020-08-17 13:53:04,668 - tensor_generators:504 - INFO - Made a big batch of tensors with key:input_lax_2ch_diastole_slice0_3d_continuous and shape:(4, 200, 160, 1).
2020-08-17 13:53:04,669 - tensor_generators:504 - INFO - Made a big batch of tensors with key:input_lax_3ch_diastole_slice0_3d_continuous and shape:(4, 200, 160, 1).
2020-08-17 13:53:04,675 - tensor_generators:504 - INFO - Made a big batch of tensors with key:output_lax_3ch_diastole_slice0_3d_continuous and shape:(4, 200, 160, 1).
2020-08-17 13:53:04,676 - tensor_generators:504 - INFO - Made a big batch of tensors with key:output_lax_2ch_diastole_slice0_3d_continuous and shape:(4, 200, 160, 1).
2020-08-17 13:53:04,677 - tensor_generators:504 - INFO - Made a big batch of tensors with key:output_LVM_continuous and shape:(4, 1).
2020-08-17 13:53:05,827 - explorations:57 - INFO - Write predictions as PNGs TensorMap:lax_3ch_di