In [None]:
# Imports
import os
import sys
import pickle
import random
from typing import List, Dict, Callable
from collections import defaultdict, Counter

import csv
import gzip
import h5py
import shutil
import zipfile
import pydicom
import numpy as np


# Keras imports
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import History
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import model_to_dot
from tensorflow.keras.layers import LeakyReLU, PReLU, ELU, ThresholdedReLU, Lambda, Reshape, LayerNormalization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback
from tensorflow.keras.layers import SpatialDropout1D, SpatialDropout2D, SpatialDropout3D, add, concatenate
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, LSTM, RepeatVector
from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D, UpSampling1D, UpSampling2D, UpSampling3D, MaxPooling1D
from tensorflow.keras.layers import MaxPooling2D, MaxPooling3D, AveragePooling1D, AveragePooling2D, AveragePooling3D, Layer
from tensorflow.keras.layers import SeparableConv1D, SeparableConv2D, DepthwiseConv2D


%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec

from ml4cvd.defines import StorageType
from ml4cvd.arguments import parse_args, TMAPS, _get_tmap
from ml4cvd.TensorMap import TensorMap, Interpretation
from ml4cvd.tensor_generators import test_train_valid_tensor_generators
from ml4cvd.models import train_model_from_generators, make_multimodal_multitask_model, _inspect_model, train_model_from_generators, make_hidden_layer_model
from ml4cvd.recipes import test_multimodal_multitask, train_multimodal_multitask, saliency_maps

# Constants
# Folder passed to `--tensors` below. It can be overridden with the HD5_FOLDER
# environment variable so the notebook runs on another machine without edits.
HD5_FOLDER = os.environ.get('HD5_FOLDER', '/mnt/disks/brains-all-together/2020-02-11/')
# NOTE(review): not referenced by any of the cells shown here.
MODEL_FOLDER = './models/'

In [None]:
# Simulate a command-line invocation so ml4cvd's argument parser can be reused
# from the notebook: the `t2_20_slices_1` input maps onto the `sex` output map.
_flag_value_pairs = [
    ('--tensors', HD5_FOLDER),
    ('--input_tensors', 't2_20_slices_1'),
    ('--output_tensors', 'sex'),
    ('--training_steps', '96'),
    ('--validation_steps', '24'),
    ('--test_steps', '24'),
    ('--batch_size', '2'),
    ('--id', 't2_20_slices_1_slice_share'),
]
# Flatten the (flag, value) pairs and append the value-less boolean flag.
sys.argv = ['train'] + [token for pair in _flag_value_pairs for token in pair] + ['--inspect_model']

args = parse_args()

In [None]:
# Build a slice-sharing model. A single network, built by ml4cvd for one slice of
# the input volume, embeds every slice with the same weights. The per-slice
# embeddings are concatenated and fed through dense layers to one output head
# per output TensorMap.

# Axis of the input volume to slice along: either a negative index counted from
# the end of the TensorMap shape, or a non-negative index into that shape.
slice_axis = -1
volume_tm = args.tensor_maps_in[0]
volume_axes = volume_tm.axes()
if not -volume_axes <= slice_axis < volume_axes:
    raise ValueError(f'Can not handle slice axis {slice_axis} with original shape {volume_tm.shape}')
# Normalize to a negative index. It then addresses the same axis of the batched
# model input (which has an extra leading batch axis) as of the TensorMap shape.
neg_slice_axis = slice_axis - volume_axes if slice_axis >= 0 else slice_axis
slices = volume_tm.shape[neg_slice_axis]
slice_shape = list(volume_tm.shape)
del slice_shape[neg_slice_axis]
slice_shape = tuple(slice_shape)
slice_tm = TensorMap(f'slice_{volume_tm.input_name()}', shape=slice_shape)
print(f'Slice name: {slice_tm.name} slice shape: {slice_shape} slices: {slices} original shape: {volume_tm.shape}')

# Build the per-slice model from a copy of the arguments instead of mutating
# `args`, so re-running this cell still starts from the volume TensorMap.
slice_model_kwargs = {**args.__dict__, 'tensor_maps_in': [slice_tm]}
slice_model = make_multimodal_multitask_model(**slice_model_kwargs)
embed_slice_model = make_hidden_layer_model(slice_model, [slice_tm], 'embed')
embed_slice_model.summary()

in_volume = Input(shape=volume_tm.shape, name=volume_tm.input_name())
# `(Ellipsis, i) + trailing_axes` selects slice i along neg_slice_axis. It keeps
# the batch axis, every axis before the slice axis, and every axis after it,
# e.g. `[..., i]` for axis -1 and `[..., i, :, :]` for axis -3.
trailing_axes = (slice(None),) * (-neg_slice_axis - 1)
embeddings = [embed_slice_model(in_volume[(Ellipsis, i) + trailing_axes]) for i in range(slices)]
# Some Keras versions require at least two inputs to concatenate, so a single
# slice embedding is used directly.
multimodal_activation = embeddings[0] if len(embeddings) == 1 else concatenate(embeddings, axis=-1)
for units in args.dense_layers:
    multimodal_activation = Dense(units=units, activation=args.activation)(multimodal_activation)

# Build one decoder head per output TensorMap and collect the per-output loss,
# loss weight and metrics used when compiling the model.
losses = []
my_metrics = {}
loss_weights = []
output_predictions = {}
tensor_maps_out = args.tensor_maps_out
for tm in tensor_maps_out:
    if tm.is_categorical():
        head_activation = 'softmax'
    elif tm.axes() == 1:
        head_activation = tm.activation
    else:
        # Fail before any loss is recorded, rather than with a KeyError at Model(...).
        raise ValueError(f'Output TensorMap {tm.name} with shape {tm.shape} is not supported: '
                         f'only categorical or 1D outputs can be decoded here.')
    losses.append(tm.loss)
    loss_weights.append(tm.loss_weight)
    my_metrics[tm.output_name()] = tm.metrics
    output_predictions[tm] = Dense(units=tm.shape[0], activation=head_activation, name=tm.output_name())(multimodal_activation)

m = Model(inputs=[in_volume], outputs=[output_predictions[tm] for tm in tensor_maps_out])
m.summary()

In [None]:
# Re-parse the command line so the generators are built from the original volume
# input TensorMaps, whatever earlier cells did to `args`.
args = parse_args()
generator_kwargs = vars(args)
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**generator_kwargs)

In [None]:
# Compile the slice-sharing model with the per-output losses, loss weights and
# metrics collected when its heads were built, then train it on the generators.
LEARNING_RATE = 1e-5
# Pass `learning_rate`, not the deprecated `lr` alias. Depending on the
# TensorFlow/Keras version, `lr` is either ignored with only a warning (leaving
# the default learning rate in effect) or rejected outright.
opt = Adam(learning_rate=LEARNING_RATE, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
m.compile(optimizer=opt, loss=losses, loss_weights=loss_weights, metrics=my_metrics)
train_model_from_generators(m, generate_train, generate_valid, args.training_steps, args.validation_steps,
                            args.batch_size, args.epochs, args.patience, args.output_folder, args.id,
                            args.inspect_model, args.inspect_show_labels)