In [None]:
# Relevant libraries and functions
from __future__ import print_function
import random
import matplotlib.pyplot as plt
import scipy.ndimage
import numpy as np, h5py
import os, time, sys
import tensorflow as tf
import keras
from keras.models import Model
from keras.layers import BatchNormalization, Convolution2D, Input, SpatialDropout2D, UpSampling2D, MaxPooling2D, concatenate
from keras.layers.core import Activation, Layer
from keras.layers import Dense, Dropout, Conv1D, Input, Conv2D, add, Conv3D, Reshape
from keras.callbacks import History, EarlyStopping, ModelCheckpoint, CSVLogger
from itertools import cycle
from sklearn import metrics
from keras.optimizers import RMSprop
from keras.utils import np_utils
from keras.backend.tensorflow_backend import set_session
from keras.layers.convolutional import Convolution2D, MaxPooling2D, SeparableConv2D, Conv2DTranspose

In [None]:
f_data = '.../trainData'  # Directory with training data (one HDF5 file per sample)
# Sorted so the sample order -- and therefore which samples fall into the
# validation split -- does not depend on filesystem listing order.
stacks = sorted(os.listdir(f_data))
numS = len(stacks)

nTG = 160  # Number of time-points
xX = 28
yY = 28

# Preallocated buffers; every sample slot is overwritten in the loop below.
tpsfD = np.empty((numS, nTG, xX, yY, 1), dtype=np.float32)
t1 = np.empty((numS, xX, yY, 1), dtype=np.float32)
t2 = np.empty((numS, xX, yY, 1), dtype=np.float32)
tR = np.empty((numS, xX, yY, 1), dtype=np.float32)

for i, d in enumerate(stacks):
    # Open each file once and close it deterministically. Indexing with f[key]
    # (rather than f.get) raises a clear KeyError if a dataset is missing.
    # 'sigD' must fit a (nTG, xX, yY) slot; 't1', 't2', 'rT' must fit (xX, yY).
    with h5py.File(os.path.join(f_data, d), 'r') as f:
        tpsfD[i, :, :, :, 0] = f['sigD']
        t1[i, :, :, 0] = f['t1']
        t2[i, :, :, 0] = f['t2']
        tR[i, :, :, 0] = f['rT']

# Move the time axis behind the spatial axes: (samples, x, y, time-points, 1),
# the layout expected by the network's Input layer.
tpsfD = np.moveaxis(tpsfD, 1, -2)

In [None]:
# Ensure TPSF voxel shape is correct dimensionality (# samples, x, y, time-points, 1)
assert tpsfD.shape == (numS, xX, yY, nTG, 1), tpsfD.shape
tpsfD.shape

In [None]:
# Relevant resblock functions (Keras API)
def resblock_2D(num_filters, size_filter, x):
    """Plain 2D residual block: relu(x + conv(relu(conv(x)))).

    Both convolutions use 'same' padding and `num_filters` filters, so `x`
    must already carry `num_filters` channels for the identity shortcut
    to be addable. Returns the tensor produced by the final ReLU.
    """
    branch = x
    for step in range(2):
        branch = Conv2D(num_filters, size_filter, padding='same', activation=None)(branch)
        if step == 0:
            # Nonlinearity only between the convolutions; the second conv
            # stays linear until after the skip connection is added.
            branch = Activation('relu')(branch)
    merged = add([branch, x])
    return Activation('relu')(merged)

def resblock_2D_BN(num_filters, size_filter, x):
    """2D residual block with batch normalization after each convolution.

    Computes relu(x + BN(conv(relu(BN(conv(x)))))) with 'same' padding, so
    `x` must already carry `num_filters` channels for the identity shortcut.
    Returns the tensor produced by the final ReLU.
    """
    branch = x
    for step in range(2):
        branch = Conv2D(num_filters, size_filter, padding='same', activation=None)(branch)
        branch = BatchNormalization()(branch)
        if step == 0:
            # The second conv+BN pair stays linear until after the addition.
            branch = Activation('relu')(branch)
    merged = add([branch, x])
    return Activation('relu')(merged)

def resblock_3D_BN(num_filters, size_filter, x):
    """3D residual block with batch normalization after each convolution.

    Same structure as the 2D variant but built from Conv3D layers:
    relu(x + BN(conv(relu(BN(conv(x)))))) with 'same' padding. `x` must
    already carry `num_filters` channels for the identity shortcut.
    Returns the tensor produced by the final ReLU.
    """
    branch = x
    for step in range(2):
        branch = Conv3D(num_filters, size_filter, padding='same', activation=None)(branch)
        branch = BatchNormalization()(branch)
        if step == 0:
            # The second conv+BN pair stays linear until after the addition.
            branch = Activation('relu')(branch)
    merged = add([branch, x])
    return Activation('relu')(merged)

def xCeptionblock_2D_BN(num_filters, size_filter, x):
    """Residual block built from depthwise-separable 2D convolutions.

    Computes relu(x + BN(sepconv(relu(BN(sepconv(x)))))) with 'same' padding;
    `x` must already carry `num_filters` channels for the identity shortcut.
    Returns the tensor produced by the final ReLU.
    """
    branch = x
    for step in range(2):
        branch = SeparableConv2D(num_filters, size_filter, padding='same', activation=None)(branch)
        branch = BatchNormalization()(branch)
        if step == 0:
            # The second sepconv+BN pair stays linear until after the addition.
            branch = Activation('relu')(branch)
    merged = add([branch, x])
    return Activation('relu')(merged)

In [None]:
modelD = None
xX, yY = 28, 28

# Temporal geometry of the 3D stage. The Reshape below folds
# (time-points // stride) * 3D filters = (160 // 5) * 50 = 1600 values into channels.
N_GATES = 160
GATE_STRIDE = 5
N_FILT_3D = 50


def _lifetime_head(features):
    """Per-pixel regression head: 64 -> 32 -> 1 channels of 1x1 convs.

    Hidden layers are conv + BN + ReLU; the single-channel output goes
    through a ReLU as well, so predictions are non-negative.
    """
    head = features
    for width in (64, 32):
        head = Conv2D(width, 1, padding='same', activation=None)(head)
        head = BatchNormalization()(head)
        head = Activation('relu')(head)
    head = Conv2D(1, 1, padding='same', activation=None)(head)
    return Activation('relu')(head)


t_data = Input(shape=(xX, yY, N_GATES, 1))
tpsf = t_data

# # # # # # # # 3D-Model # # # # # # # #

# Temporal feature extraction: (1, 1, k) kernels slide along the time axis only.
tpsf = Conv3D(N_FILT_3D, kernel_size=(1, 1, 10), strides=(1, 1, GATE_STRIDE),
              padding='same', activation=None, data_format="channels_last")(tpsf)
tpsf = BatchNormalization()(tpsf)
tpsf = Activation('relu')(tpsf)
tpsf = resblock_3D_BN(N_FILT_3D, (1, 1, 5), tpsf)

# Fold time into channels; every following layer is a 1x1 (per-pixel) 2D conv.
tpsf = Reshape((xX, yY, (N_GATES // GATE_STRIDE) * N_FILT_3D))(tpsf)
for trunk_idx in range(2):
    tpsf = Conv2D(256, 1, padding='same', activation=None, data_format="channels_last")(tpsf)
    tpsf = BatchNormalization()(tpsf)
    tpsf = Activation('relu')(tpsf)
for trunk_idx in range(2):
    tpsf = resblock_2D_BN(256, 1, tpsf)

# Three independent heads share the trunk; built in this order so that
# auto-generated layer names match previously saved weight files.
imgT1 = _lifetime_head(tpsf)  # Short-lifetime branch
imgT2 = _lifetime_head(tpsf)  # Long-lifetime branch
imgTR = _lifetime_head(tpsf)  # Amplitude-Ratio branch

modelD = Model(inputs=[t_data], outputs=[imgT1, imgT2, imgTR])
rmsprop = RMSprop(lr=1e-5)

# One MSE loss per output (summed by Keras), with MAE reported as a metric.
modelD.compile(optimizer=rmsprop, loss='mse', metrics=['mae'])

In [None]:
# Stop when validation loss has not improved for `patience` epochs (patience = 15 recommended)
earlyStopping = EarlyStopping(monitor='val_loss',
                              patience=15,
                              verbose=0,
                              mode='auto')

fN = 'testName'  # Base name for the saved weights (<fN>.h5) and training log (<fN>.log)

# Keep only the weights with the lowest validation loss seen so far
# (monitor='' can be changed to focus on other tau parameters).
modelCheckPoint = ModelCheckpoint(filepath=fN + '.h5',
                                  monitor='val_loss',
                                  save_best_only=True,
                                  verbose=0)
# Save loss curve (mse) and MAE information over all trained epochs.
csv_logger = CSVLogger(fN + '.log')

# Train network (80/20 train/validation split, batch_size=20 recommended;
# epochs is an upper bound -- early stopping usually ends training sooner).
# `epochs` replaces the deprecated Keras 1 argument `nb_epoch`.
history = modelD.fit([tpsfD], [t1, t2, tR],
                     validation_split=0.2,
                     batch_size=20, epochs=500, verbose=1, shuffle=True,
                     callbacks=[earlyStopping, csv_logger, modelCheckPoint])

In [None]:
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Post-training: load "best" trained weights (obtained through patience - lowest value of loss)
# THIS CAN BE ANY WEIGHT FILE, AS LONG AS THE NETWORK ARCHITECTURE MATCHES THE ONE USED!
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

best_weights_path = fN + '.h5'  # the file written by ModelCheckpoint during training
modelD.load_weights(best_weights_path)

In [None]:
# Load test data and use 3D-CNN for inference.
# A dedicated path variable is used so the model's Input tensor `t_data`
# is not overwritten, and so the test directory (not the training one)
# is actually listed, sized and read.
f_dataT = '.../testData'  # directory with test data
stacksT = sorted(os.listdir(f_dataT))
numT = len(stacksT)

nTG = 160
xX = 28
yY = 28

# Preallocated buffers; every sample slot is overwritten in the loop below.
tpsfT = np.empty((numT, nTG, xX, yY, 1), dtype=np.float32)
t1T = np.empty((numT, xX, yY, 1), dtype=np.float32)
t2T = np.empty((numT, xX, yY, 1), dtype=np.float32)
tRT = np.empty((numT, xX, yY, 1), dtype=np.float32)

for i, d in enumerate(stacksT):
    # Open each file once and close it deterministically; f[key] raises a
    # clear KeyError if a dataset is missing.
    with h5py.File(os.path.join(f_dataT, d), 'r') as f:
        tpsfT[i, :, :, :, 0] = f['sigD']
        t1T[i, :, :, 0] = f['t1']
        t2T[i, :, :, 0] = f['t2']
        tRT[i, :, :, 0] = f['rT']

# Same layout as the training data: (samples, x, y, time-points, 1).
tpsfT = np.moveaxis(tpsfT, 1, -2)

In [None]:
# Perform inference on test data with trained model
testV = modelD.predict(tpsfT)
t1P, t2P, tRP = testV  # predicted tau1, tau2 and amplitude-ratio maps

# Visualize one test sample: predictions (left column) vs. ground truth (right column).
n = 2  # index of the test sample to illustrate
panels = [
    # (label, predicted, ground truth, (vmin, vmax))
    ('tau1', t1P, t1T, (0.2, 0.6)),
    ('tau2', t2P, t2T, (0.8, 1.5)),
    ('amplitude ratio', tRP, tRT, (0.0, 1.0)),
]

fig, axes = plt.subplots(3, 2, figsize=(10, 10))
for row_axes, (label, pred, truth, (lo, hi)) in zip(axes, panels):
    for ax, img, source in ((row_axes[0], pred, 'Predicted'), (row_axes[1], truth, 'G.T.')):
        # Shared colour limits per row so prediction and truth are directly comparable.
        im = ax.imshow(img[n, :, :, 0], interpolation='nearest', vmin=lo, vmax=hi)
        ax.set_title('{} {}'.format(source, label))
        fig.colorbar(im, ax=ax)
plt.show()