In [1]:
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, add, concatenate, multiply, Activation, Input, SpatialDropout2D, BatchNormalization

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Record the Keras release this notebook was executed against (printed value shown in the output below).
import keras
print(keras.__version__)

2.3.1


In [3]:
import tensorflow
# The shape-check cells in this notebook call ``tf.Variable`` / ``tf.random.normal``;
# bind the conventional alias so they run when uncommented (``tensorflow`` stays bound too).
import tensorflow as tf
# Record the TensorFlow release this notebook was executed against.
print(tensorflow.__version__)

1.14.0


In [4]:
def ConvBnElu(inp, filters, kernel_size=3, strides=1):
    '''Apply Conv2D -> BatchNormalization -> ELU to ``inp``.

    The convolution uses "same" padding, He-uniform initialisation and no bias term
    (the BatchNormalization that follows carries its own learned offset).

    Args:
        inp: input tensor.
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size (default 3).
        strides: convolution stride (default 1); with "same" padding a stride of 2
            halves the spatial size (rounding up).

    Returns:
        The ELU-activated, batch-normalised convolution output.
    '''
    conv_layer = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                        padding="same", kernel_initializer="he_uniform", use_bias=False)
    features = conv_layer(inp)

    norm_layer = BatchNormalization()
    features = norm_layer(features)

    elu_layer = Activation("elu")
    return elu_layer(features)

In [5]:
# #test
# t = tf.Variable(tf.random.normal((1,16,16,3)))
# assert ConvBnElu(t, filters=16).get_shape().as_list() == [1,16,16,16]
# assert ConvBnElu(t, filters=8, strides=2).get_shape().as_list() == [1,8,8,8]

In [6]:
def deconv(inp):
    '''Learned 2x upsampling: Conv2DTranspose -> BatchNormalization -> ELU.

    A stride-2 transposed convolution with "same" padding doubles the height and width
    of ``inp``. Its filter count is read from the last axis of ``inp``, so the channel
    count is preserved. The kernel is 4x4 and there is no bias (BatchNormalization follows).

    Args:
        inp: input tensor whose channel count (last axis) is statically known.

    Returns:
        Tensor with twice the spatial size of ``inp`` and the same number of channels.
    '''
    channels = inp.get_shape().as_list()[-1]

    transpose_layer = Conv2DTranspose(filters=channels, kernel_size=4, strides=2,
                                      padding="same", use_bias=False,
                                      kernel_initializer="he_uniform")
    upsampled = transpose_layer(inp)

    norm_layer = BatchNormalization()
    upsampled = norm_layer(upsampled)

    elu_layer = Activation("elu")
    return elu_layer(upsampled)

In [7]:
# #test
# t = tf.Variable(tf.random.normal((1,16,16,3)))
# assert deconv(t).get_shape().as_list() == [1,32,32,3]

In [8]:
def attention_gate(inp_1, inp_2, n_intermediate_filters):
    '''Additive attention gate: scale ``inp_1`` by a single-channel mask derived from both inputs.

    Both inputs are projected to ``n_intermediate_filters`` channels with 1x1 convolutions,
    summed and passed through ReLU. That result is reduced to one channel by another 1x1
    convolution and squashed with a sigmoid. The mask is then multiplied into ``inp_1``,
    broadcasting over its channels.

    Args:
        inp_1: tensor to be gated (attention_concat_upsample passes the skip connection here).
        inp_2: gating tensor; must have the same spatial size as ``inp_1`` for the add to work.
        n_intermediate_filters: channel count of the two 1x1 projections.

    Returns:
        ``inp_1 * mask`` with the same shape as ``inp_1``.
    '''
    def _pointwise(tensor, n_filters):
        # 1x1 conv (with the Conv2D default bias) used for every projection in the gate.
        return Conv2D(n_filters, kernel_size=1, strides=1, padding="same",
                      kernel_initializer="he_uniform")(tensor)

    skip_proj = _pointwise(inp_1, n_intermediate_filters)
    gate_proj = _pointwise(inp_2, n_intermediate_filters)

    combined = add([skip_proj, gate_proj])
    rectified = Activation("relu")(combined)

    mask_logits = _pointwise(rectified, 1)
    mask = Activation("sigmoid")(mask_logits)

    return multiply([inp_1, mask])

In [9]:
# #test
# t1 = tf.Variable(tf.random.normal((1,16,16,64)))
# t2 = tf.Variable(tf.random.normal((1,16,16,64)))


# assert attention_gate(t1, t2, 16).get_shape().as_list() == [1,16,16,64]

In [10]:
def attention_concat_upsample(across, below):
    '''Upsample ``below`` and concatenate it with an attention-gated copy of ``across``.

    ``below`` is doubled spatially by ``deconv`` (channel count kept), so it must have half
    the height and width of ``across``. The upsampled tensor is the gating signal for
    ``attention_gate``, whose intermediate width equals the channel count of ``below``.

    Args:
        across: higher-resolution tensor (the encoder skip connection in AttentionR2Unet).
        below: lower-resolution tensor coming up from the deeper level.

    Returns:
        Channel-wise concatenation ``[upsampled below, gated across]``. Its channel count is
        the sum of the channel counts of ``below`` and ``across``.
    '''
    upsampled = deconv(below)
    gate_width = below.get_shape().as_list()[-1]
    gated_across = attention_gate(across, upsampled, gate_width)
    return concatenate([upsampled, gated_across])

In [11]:
# #test
# acr = tf.Variable(tf.random.normal((1,16,16,64)))
# bel = tf.Variable(tf.random.normal((1,8,8,32)))

# assert attention_concat_upsample(acr, bel).get_shape().as_list() == [1,16,16,96]

In [12]:
    def RR_block(inp, out_filters, dropout=0.2):
        '''Recurrent-residual block: 1x1 projection, two stacked two-step conv units, residual sum.

        Structure (every conv is ConvBnElu, 3x3 unless noted; drop = SpatialDropout2D):
            p  = ConvBnElu_1x1(inp)               # projection, reused as the residual shortcut
            a1 = drop(conv(p));   u1 = drop(conv(p + a1))
            a2 = drop(conv(u1));  u2 = conv(u1 + a2)
            out = p + u2

        Every step constructs a fresh Conv2D, so weights are not shared between steps.

        Args:
            inp: input tensor; the 1x1 projection maps its channels to ``out_filters``.
            out_filters: channel count of every conv in the block and of the output.
            dropout: SpatialDropout2D rate applied after the first three 3x3 convs.

        Returns:
            Tensor with ``out_filters`` channels and the same spatial size as ``inp``.
        '''
        def _drop(tensor):
            return SpatialDropout2D(dropout)(tensor)

        projected = ConvBnElu(inp, out_filters, kernel_size=1)

        # First two-step unit, fed by the projection.
        step_a = _drop(ConvBnElu(projected, out_filters))
        unit_1 = _drop(ConvBnElu(add([projected, step_a]), out_filters))

        # Second two-step unit, fed by the first unit's output.
        step_b = _drop(ConvBnElu(unit_1, out_filters))
        unit_2 = ConvBnElu(add([unit_1, step_b]), out_filters)

        return add([projected, unit_2])

In [13]:
def AttentionR2Unet(input_shape=(256,256,1), output_classes=1, depth=4, n_filters_init=64, dropout_enc=0.2, dropout_dec=0.2,
                    dropout_bottleneck=None):
    '''Build an Attention R2U-Net: recurrent-residual blocks with attention-gated skip connections.

    Structure:
        - Encoder: ``depth`` stages. Each runs an RR_block, stores the result as a skip
          connection, then downsamples with a stride-2, kernel-4 ConvBnElu. The filter
          count doubles after every stage.
        - Bottleneck: one RR_block.
        - Decoder: ``depth`` stages. Each upsamples, concatenates with the attention-gated
          skip from the matching encoder stage, and applies an RR_block with half the filters.
        - Head: 1x1 Conv2D (no bias) -> BatchNormalization -> sigmoid.

    Args:
        input_shape: shape of one sample, channels last (channel counts are read from axis -1).
            Every known spatial dimension must be divisible by ``2**depth``. Each encoder stage
            halves it (rounding up under "same" padding) and each decoder stage doubles it, so
            any other size makes an upsampled tensor disagree with its skip connection.
        output_classes: number of output channels, each passed through a sigmoid.
        depth: number of downsampling / upsampling stages.
        n_filters_init: filter count of the first encoder stage.
        dropout_enc: SpatialDropout2D rate inside the encoder RR blocks.
        dropout_dec: SpatialDropout2D rate inside the decoder RR blocks.
        dropout_bottleneck: SpatialDropout2D rate inside the bottleneck RR block. ``None``
            (default) reuses ``dropout_enc``; previously the bottleneck was fixed at 0.2
            regardless of the dropout arguments.

    Returns:
        An uncompiled keras Model mapping ``input_shape`` inputs to ``output_classes`` sigmoid maps.

    Raises:
        ValueError: if a known spatial dimension of ``input_shape`` is not divisible by ``2**depth``.
    '''
    if dropout_bottleneck is None:
        dropout_bottleneck = dropout_enc

    # Fail early with a clear message instead of a merge-shape error deep inside the decoder.
    downsample_factor = 2 ** max(depth, 0)
    for dim in input_shape[:-1]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                "AttentionR2Unet: spatial dimensions of input_shape must be divisible by "
                "2**depth = {}, got input_shape={}".format(downsample_factor, tuple(input_shape)))

    inputs = x = Input(input_shape)
    skips = []
    features = n_filters_init

    # encoder
    for _ in range(depth):
        x = RR_block(x, features, dropout=dropout_enc)
        skips.append(x)
        x = ConvBnElu(x, features, kernel_size=4, strides=2)
        features *= 2

    # bottleneck
    x = RR_block(x, features, dropout=dropout_bottleneck)

    # decoder: consume skip connections deepest-first
    for skip in reversed(skips):
        features //= 2
        x = attention_concat_upsample(across=skip, below=x)
        x = RR_block(x, features, dropout=dropout_dec)

    # head
    final_conv = Conv2D(output_classes, kernel_size=1, strides=1, padding="same",
                        kernel_initializer="he_uniform", use_bias=False)(x)
    final_bn = BatchNormalization()(final_conv)
    sigm = Activation("sigmoid")(final_bn)
    return Model(inputs, sigm)

In [14]:
# Build the network with its default configuration spelled out: 256x256 single-channel input,
# one sigmoid output channel, 4 resolution stages starting at 64 filters.
u = AttentionR2Unet(input_shape=(256, 256, 1), output_classes=1, depth=4, n_filters_init=64)

In [19]:
# Print the layer table with 150-character lines so the "Connected to" column stays readable.
u.summary(line_length=150, print_fn=print)

Model: "model_1"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             (None, 256, 256, 1)              0                                                                   
______________________________________________________________________________________________________________________________________________________
conv2d_1 (Conv2D)                                (None, 256, 256, 64)             64                input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
batch_normalization_1 (BatchNormalization)       (None, 256, 256, 64)        