In [1]:
from tensorflow.keras import models, layers

In [6]:
def conv_3d_block(input_layer, n_filter, ksize, padding='same', activation='relu', name='block'):
    """Apply a Conv3D layer followed by an activation layer.

    Args:
        input_layer: Keras tensor fed to the convolution.
        n_filter: Number of output filters of the Conv3D layer.
        ksize: Kernel size passed to Conv3D (int or 3-tuple).
        padding: Conv3D padding mode (default 'same').
        activation: 'leaky_relu' selects LeakyReLU with slope 0.01; any other
            value is passed to ``layers.Activation`` (e.g. 'relu', 'linear').
        name: Prefix for the created layer names ('<name>_conv', '<name>_act').

    Returns:
        The output tensor after convolution and activation.
    """
    output = layers.Conv3D(n_filter, ksize, padding=padding, name=name + '_conv')(input_layer)
    if activation == 'leaky_relu':
        # Keras' built-in 'leaky_relu' string is not used so the slope is pinned to 0.01.
        return layers.LeakyReLU(0.01, name=name + '_act')(output)
    # Honour the requested activation; hard-coding 'relu' here would turn a
    # requested 'linear' output layer into a ReLU.
    return layers.Activation(activation, name=name + '_act')(output)

def upconv_3d_block(input_layer, n_filter, ksize, strides, padding='same', activation='relu', name='block'):
    """Apply a Conv3DTranspose layer followed by an activation layer.

    Args:
        input_layer: Keras tensor fed to the transposed convolution.
        n_filter: Number of output filters of the Conv3DTranspose layer.
        ksize: Kernel size passed to Conv3DTranspose (int or 3-tuple).
        strides: Strides passed to Conv3DTranspose (int or 3-tuple).
        padding: Conv3DTranspose padding mode (default 'same').
        activation: 'leaky_relu' selects LeakyReLU with slope 0.01; any other
            value is passed to ``layers.Activation`` (e.g. 'relu', 'linear').
        name: Prefix for the created layer names ('<name>_conv', '<name>_act').

    Returns:
        The output tensor after transposed convolution and activation.
    """
    output = layers.Conv3DTranspose(n_filter, ksize, strides, padding=padding, name=name + '_conv')(input_layer)
    if activation == 'leaky_relu':
        # Keras' built-in 'leaky_relu' string is not used so the slope is pinned to 0.01.
        return layers.LeakyReLU(0.01, name=name + '_act')(output)
    # Honour the requested activation instead of always falling back to 'relu'.
    return layers.Activation(activation, name=name + '_act')(output)

def SR3D(input_shape=(None, None, None, 1), layer_activation='relu', last_activation='linear', name='3D_SR',
         scale=(1, 1, 6)):
    '''Build a 3D super-resolution Keras model.

    Architecture:
      * four stacked 3x3x3 conv blocks with 64, 128, 256 and 512 filters
        ('same' padding, so spatial size is preserved);
      * a 64-filter transposed convolution with kernel size and stride
        ``scale`` that upsamples the deepest features;
      * a skip branch that upsamples the first block's 64-channel features
        by ``scale`` with UpSampling3D (value repetition);
      * channel-wise concatenation of both branches, then a 1x1x1 conv with
        a single filter and ``last_activation``.

    Args:
        input_shape: Shape of one sample: (Cor, Sag, Axi, channels), i.e.
            three spatial axes followed by the channel axis. Spatial sizes
            may be None.
        layer_activation: Activation for every hidden block
            ('relu', 'leaky_relu', ...).
        last_activation: Activation requested for the final output block.
        name: Model name; also used as prefix for every layer name.
        scale: Per-axis upsampling factor (int or 3-tuple). The default
            (1, 1, 6) upsamples 6x along the third spatial axis only.

    Returns:
        A ``models.Model`` whose output has one channel and spatial
        dimensions equal to the input's multiplied by ``scale``.
    '''
    input_layer = layers.Input(input_shape, name=name + '_input')

    en1 = conv_3d_block(input_layer, 64, 3, activation=layer_activation, name=name + '_en1')

    # Skip branch: bring the shallow features to the target resolution so
    # they can be concatenated with the upsampled deep features.
    skip = layers.UpSampling3D(size=scale, name=name + '_up_en1')(en1)

    en2 = conv_3d_block(en1, 128, 3, activation=layer_activation, name=name + '_en2')
    en3 = conv_3d_block(en2, 256, 3, activation=layer_activation, name=name + '_en3')
    en4 = conv_3d_block(en3, 512, 3, activation=layer_activation, name=name + '_en4')

    # Kernel == stride with 'same' padding multiplies each spatial size by
    # exactly `scale`, matching the skip branch's shape.
    up = upconv_3d_block(en4, 64, scale, strides=scale, padding='same',
                         activation=layer_activation, name=name + '_up')
    concat = layers.Concatenate(axis=-1, name=name + '_concat')([skip, up])

    output = conv_3d_block(concat, 1, 1, activation=last_activation, name=name + '_output')

    return models.Model(inputs=input_layer, outputs=output, name=name)