In [None]:
def squeeze_excite_block(input, ratio=8):
    """Apply a 3-D squeeze-and-excitation (channel attention) block.

    The input is global-average-pooled to one value per channel. That vector
    goes through a two-layer bottleneck (ReLU, then sigmoid). The resulting
    per-channel weights rescale the original tensor.

    Args:
        input: 5-D Keras tensor (batch + 3 spatial dims + channels). Its
            channel axis follows ``K.image_data_format()``: axis 1 for
            ``channels_first``, otherwise the last axis.
        ratio: Reduction factor for the hidden (bottleneck) Dense layer.

    Returns:
        Tensor with the same shape as ``input``, scaled channel-wise.
    """
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    # K.int_shape works for both standalone Keras and tf.keras tensors.
    # The private ``_keras_shape`` attribute is not set by tf.keras.
    filters = K.int_shape(init)[channel_axis]
    # Keep at least one hidden unit, so inputs with fewer channels than
    # ``ratio`` do not build a zero-width Dense layer.
    bottleneck = max(1, filters // ratio)
    se_shape = (1, 1, 1, filters)

    se = GlobalAveragePooling3D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(bottleneck, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    if K.image_data_format() == 'channels_first':
        # The Dense layers leave channels last; move them back to axis 1
        # so they broadcast correctly against ``init``.
        se = Permute((4, 1, 2, 3))(se)

    return multiply([init, se])

def Conv3d_BN(x, nb_filter, kernel_size, strides=1, padding='same', name=None):
    """Apply a ReLU Conv3D layer followed by BatchNormalization.

    The ReLU activation is applied inside the convolution, i.e. *before*
    batch normalisation. This order is kept as-is, because changing it would
    alter the model architecture.

    Args:
        x: Input tensor in ``channels_last`` layout (the conv layer is
            explicitly configured with ``data_format='channels_last'``).
        nb_filter: Number of convolution filters.
        kernel_size: Convolution kernel size (int or 3-tuple).
        strides: Convolution strides.
        padding: Convolution padding mode.
        name: Optional base layer name. When given, the conv layer is named
            ``name + '_conv'`` and the batch-norm layer ``name + '_bn'``.
            When None, Keras auto-generates the names. Previously this
            argument was accepted but ignored.

    Returns:
        Output tensor of the BatchNormalization layer.
    """
    conv_name = None if name is None else name + '_conv'
    bn_name = None if name is None else name + '_bn'
    x = Conv3D(nb_filter, kernel_size, padding=padding, data_format='channels_last', strides=strides,
               activation='relu', name=conv_name)(x)
    x = BatchNormalization(name=bn_name)(x)
    return x

def identity_Block(inpt, nb_filter, kernel_size, strides=1, with_conv_shortcut=False):
    """Build one SE-residual block: two Conv3d_BN layers, SE, then a skip add.

    Args:
        inpt: Block input tensor.
        nb_filter: Number of filters for every convolution in the block.
        kernel_size: Kernel size used for the main path and, when enabled,
            for the shortcut convolution.
        strides: Strides of the first main-path convolution (and of the
            shortcut convolution when ``with_conv_shortcut`` is True).
        with_conv_shortcut: If True, project ``inpt`` through a Conv3d_BN
            before the addition, and apply Dropout(0.2) to the main path.
            If False, ``inpt`` is added directly. In that case ``inpt`` must
            already match the main path's shape, which presumably requires
            strides=1 and matching channel count.

    Returns:
        The element-wise sum of the main path and the shortcut.
    """
    residual = Conv3d_BN(inpt, nb_filter=nb_filter, kernel_size=kernel_size,
                         strides=strides, padding='same')
    residual = Conv3d_BN(residual, nb_filter=nb_filter, kernel_size=kernel_size,
                         padding='same')
    residual = squeeze_excite_block(residual)

    # Plain identity skip: no projection, no dropout.
    if not with_conv_shortcut:
        return add([residual, inpt])

    # Projection skip. The shortcut layers are created before the Dropout,
    # so layer construction order (and auto-generated names) is unchanged.
    projection = Conv3d_BN(inpt, nb_filter=nb_filter, strides=strides,
                           kernel_size=kernel_size)
    residual = Dropout(0.2)(residual)
    return add([residual, projection])
    
def SEResNet(input_shape=(64, 64, 64, 1)):
    """Build a small 3-D SE-ResNet binary classifier.

    Architecture:
      1. A stem of Conv3d_BN (8 filters, 6x6x6 kernel) plus 2x2x2 max pooling.
      2. Three residual stages of 16, 32 and 64 filters. Each stage opens
         with a stride-2 projection block.
      3. Global average pooling and a single sigmoid unit.

    Args:
        input_shape: Shape of one sample in ``channels_last`` layout,
            i.e. (depth, height, width, channels). Defaults to the original
            hard-coded (64, 64, 64, 1). The head uses global average pooling,
            so other spatial sizes also work. The stem pool and the three
            stride-2 stages downsample each spatial dim by 16 in total, so
            very small inputs may fail to build.

    Returns:
        An uncompiled ``keras.Model`` with one sigmoid output named ``'d1_1'``.
    """
    inputs = Input(input_shape)

    # conv1
    c1 = Conv3d_BN(inputs, nb_filter=8, kernel_size=(6, 6, 6), strides=1, padding='same')
    c1 = MaxPooling3D(pool_size=(2, 2, 2), strides=2, data_format='channels_last')(c1)

    # conv2_x
    c2 = identity_Block(c1, nb_filter=16, kernel_size=(2, 2, 2), strides=2, with_conv_shortcut=True)
    c2 = identity_Block(c2, nb_filter=16, kernel_size=(2, 2, 2))

    # conv3_x
    c3 = identity_Block(c2, nb_filter=32, kernel_size=(2, 2, 2), strides=2, with_conv_shortcut=True)
    c3 = identity_Block(c3, nb_filter=32, kernel_size=(2, 2, 2))

    # conv4_x
    c4 = identity_Block(c3, nb_filter=64, kernel_size=(3, 3, 3), strides=2, with_conv_shortcut=True)
    c4 = identity_Block(c4, nb_filter=64, kernel_size=(3, 3, 3))
    c4 = identity_Block(c4, nb_filter=64, kernel_size=(3, 3, 3))

    # Collapse the spatial dims to one feature per channel.
    x = GlobalAveragePooling3D(data_format='channels_last')(c4)

    out = Dense(1, activation='sigmoid', name='d1_1')(x)

    return Model(inputs=[inputs], outputs=[out])