In [1]:
# Model set up libraries 
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, ReLU
from tensorflow.keras.layers import BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.models import Model, Sequential 
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers

from tensorflow.keras.applications import VGG16

from sklearn.model_selection import train_test_split

In [8]:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras import backend


In [16]:
def convolution_block(inputs, num_filters):
    """Apply two Conv2D -> BatchNormalization -> ReLU stages to ``inputs``.

    Args:
        inputs: Keras tensor fed to the first convolution.
        num_filters: Number of output filters used by both 3x3 convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = inputs

    # Two identical stages; padding='same' keeps the convolutions from
    # trimming the spatial dimensions.
    for _ in range(2):
        features = Conv2D(num_filters, (3, 3), padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)

    # No pooling inside the block: any down-/up-sampling is left to the caller.
    return features

In [17]:
def decoder_block(inputs, skip_tensor, num_filters):
    """Up-sample ``inputs``, merge it with ``skip_tensor`` and refine the result.

    Args:
        inputs: Keras tensor coming from the previous (coarser) stage.
        skip_tensor: Encoder tensor to concatenate after up-sampling; it must
            match the up-sampled tensor on every axis except the last.
        num_filters: Number of filters for the transposed convolution and for
            the convolution block that follows.

    Returns:
        The output tensor of the trailing ``convolution_block``.
    """
    # The stride of 2 is what performs the up-sampling.
    upsample = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')
    upsampled = upsample(inputs)

    # Bring in the skip connection: up-sampled features first, then the skip.
    merged = Concatenate()([upsampled, skip_tensor])

    return convolution_block(merged, num_filters)

In [18]:
def build_resnet50(input_shape):
    """Build a U-Net style model on top of a frozen, ImageNet-pretrained ResNet50.

    The ResNet50 backbone (without its classification head) acts as the
    encoder. Four of its intermediate tensors are reused as skip connections
    for a four-stage decoder, followed by a 1-channel sigmoid output layer.

    The shapes of the skip tensors and of the bridge are printed while the
    model is being built.

    Args:
        input_shape: Shape of one input image, e.g. ``(512, 512, 3)``.

    Returns:
        A ``keras.Model`` named ``'first_RESNET50_UNET'`` mapping the input
        image to a single-channel map with values in [0, 1].
    """
    inputs = Input(input_shape)

    # Pretrained encoder wired directly onto our own Input tensor.
    resnet = ResNet50(
        include_top=False,
        weights='imagenet',
        input_tensor=inputs,
    )

    # Freeze the encoder weights.
    resnet.trainable = False

    ### Skip connections (sizes quoted for a 512 x 512 x 3 input)
    # Use the Input tensor itself instead of get_layer('input_1'): the input
    # layer's auto-generated name depends on how many inputs were created
    # earlier in the session (and on the Keras version), so the lookup by name
    # fails unless the name counter happens to have just been reset.
    skip1 = inputs                                          # 512 x 512, 3 channels
    skip2 = resnet.get_layer('conv1_relu').output           # 256 x 256, 64 channels
    skip3 = resnet.get_layer('conv2_block3_out').output     # 128 x 128, 256 channels
    skip4 = resnet.get_layer('conv3_block4_out').output     # 64 x 64, 512 channels

    print(skip1.shape, skip2.shape, skip3.shape, skip4.shape)

    ### Bridge
    bridge = resnet.get_layer('conv4_block6_out').output    # 32 x 32, 1024 channels
    print(bridge.shape)

    ### Decoder: each stage doubles the resolution and merges one skip tensor
    d1 = decoder_block(bridge, skip4, 512)  # -> 64 x 64, 512 filters
    d2 = decoder_block(d1, skip3, 256)      # -> 128 x 128, 256 filters
    d3 = decoder_block(d2, skip2, 128)      # -> 256 x 256, 128 filters
    d4 = decoder_block(d3, skip1, 64)       # -> 512 x 512, 64 filters

    ### Output: one sigmoid-activated channel per pixel
    outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='first_RESNET50_UNET')

    return model


In [19]:
# Reset Keras' global state before building a fresh model.
backend.clear_session()

# Build the model for 512 x 512 RGB inputs; the call prints the skip/bridge shapes.
res_model = build_resnet50(input_shape=(512, 512, 3))
# res_model.summary()

(None, 512, 512, 3) (None, 256, 256, 64) (None, 128, 128, 256) (None, 64, 64, 512)
(None, 32, 32, 1024)


In [None]:
# Hard-coded list of layer names (presumably pasted from a ResNet50 model's
# layer listing — TODO confirm where it came from).
# NOTE(review): this assignment rebinds `layers`, shadowing the
# `tensorflow.keras.layers` module imported in the first cell; after this cell
# runs, `layers.<Something>` no longer refers to the Keras module.
layers = ['input_2', 'conv1_pad', 'conv1_conv', 'conv1_bn', 'conv1_relu', 'pool1_pad', 'pool1_pool', 'conv2_block1_1_conv', 'conv2_block1_1_bn', 'conv2_block1_1_relu', 'conv2_block1_2_conv', 'conv2_block1_2_bn', 'conv2_block1_2_relu', 'conv2_block1_0_conv', 'conv2_block1_3_conv', 'conv2_block1_0_bn', 'conv2_block1_3_bn', 'conv2_block1_add', 'conv2_block1_out', 'conv2_block2_1_conv', 'conv2_block2_1_bn', 'conv2_block2_1_relu', 'conv2_block2_2_conv', 'conv2_block2_2_bn', 'conv2_block2_2_relu', 'conv2_block2_3_conv', 'conv2_block2_3_bn', 'conv2_block2_add', 'conv2_block2_out', 'conv2_block3_1_conv', 'conv2_block3_1_bn', 'conv2_block3_1_relu', 'conv2_block3_2_conv', 'conv2_block3_2_bn', 'conv2_block3_2_relu', 'conv2_block3_3_conv', 'conv2_block3_3_bn', 'conv2_block3_add', 'conv2_block3_out', 'conv3_block1_1_conv', 'conv3_block1_1_bn', 'conv3_block1_1_relu', 'conv3_block1_2_conv', 'conv3_block1_2_bn', 'conv3_block1_2_relu', 'conv3_block1_0_conv', 'conv3_block1_3_conv', 'conv3_block1_0_bn', 'conv3_block1_3_bn', 'conv3_block1_add', 'conv3_block1_out', 'conv3_block2_1_conv', 'conv3_block2_1_bn', 'conv3_block2_1_relu', 'conv3_block2_2_conv', 'conv3_block2_2_bn', 'conv3_block2_2_relu', 'conv3_block2_3_conv', 'conv3_block2_3_bn', 'conv3_block2_add', 'conv3_block2_out', 'conv3_block3_1_conv', 'conv3_block3_1_bn', 'conv3_block3_1_relu', 'conv3_block3_2_conv', 'conv3_block3_2_bn', 'conv3_block3_2_relu', 'conv3_block3_3_conv', 'conv3_block3_3_bn', 'conv3_block3_add', 'conv3_block3_out', 'conv3_block4_1_conv', 'conv3_block4_1_bn', 'conv3_block4_1_relu', 'conv3_block4_2_conv', 'conv3_block4_2_bn', 'conv3_block4_2_relu', 'conv3_block4_3_conv', 'conv3_block4_3_bn', 'conv3_block4_add', 'conv3_block4_out', 'conv4_block1_1_conv', 'conv4_block1_1_bn', 'conv4_block1_1_relu', 'conv4_block1_2_conv', 'conv4_block1_2_bn', 'conv4_block1_2_relu', 'conv4_block1_0_conv', 'conv4_block1_3_conv', 'conv4_block1_0_bn', 'conv4_block1_3_bn', 'conv4_block1_add', 'conv4_block1_out', 'conv4_block2_1_conv', 
'conv4_block2_1_bn', 'conv4_block2_1_relu', 'conv4_block2_2_conv', 'conv4_block2_2_bn', 'conv4_block2_2_relu', 'conv4_block2_3_conv', 'conv4_block2_3_bn', 'conv4_block2_add', 'conv4_block2_out', 'conv4_block3_1_conv', 'conv4_block3_1_bn', 'conv4_block3_1_relu', 'conv4_block3_2_conv', 'conv4_block3_2_bn', 'conv4_block3_2_relu', 'conv4_block3_3_conv', 'conv4_block3_3_bn', 'conv4_block3_add', 'conv4_block3_out', 'conv4_block4_1_conv', 'conv4_block4_1_bn', 'conv4_block4_1_relu', 'conv4_block4_2_conv', 'conv4_block4_2_bn', 'conv4_block4_2_relu', 'conv4_block4_3_conv', 'conv4_block4_3_bn', 'conv4_block4_add', 'conv4_block4_out', 'conv4_block5_1_conv', 'conv4_block5_1_bn', 'conv4_block5_1_relu', 'conv4_block5_2_conv', 'conv4_block5_2_bn', 'conv4_block5_2_relu', 'conv4_block5_3_conv', 'conv4_block5_3_bn', 'conv4_block5_add', 'conv4_block5_out', 'conv4_block6_1_conv', 'conv4_block6_1_bn', 'conv4_block6_1_relu', 'conv4_block6_2_conv', 'conv4_block6_2_bn', 'conv4_block6_2_relu', 'conv4_block6_3_conv', 'conv4_block6_3_bn', 'conv4_block6_add', 'conv4_block6_out', 'conv5_block1_1_conv', 'conv5_block1_1_bn', 'conv5_block1_1_relu', 'conv5_block1_2_conv', 'conv5_block1_2_bn', 'conv5_block1_2_relu', 'conv5_block1_0_conv', 'conv5_block1_3_conv', 'conv5_block1_0_bn', 'conv5_block1_3_bn', 'conv5_block1_add', 'conv5_block1_out', 'conv5_block2_1_conv', 'conv5_block2_1_bn', 'conv5_block2_1_relu', 'conv5_block2_2_conv', 'conv5_block2_2_bn', 'conv5_block2_2_relu', 'conv5_block2_3_conv', 'conv5_block2_3_bn', 'conv5_block2_add', 'conv5_block2_out', 'conv5_block3_1_conv', 'conv5_block3_1_bn', 'conv5_block3_1_relu', 'conv5_block3_2_conv', 'conv5_block3_2_bn', 'conv5_block3_2_relu', 'conv5_block3_3_conv', 'conv5_block3_3_bn', 'conv5_block3_add', 'conv5_block3_out']

In [None]:
# Sort the layer-name list lexicographically, in place (mutates `layers`).
layers.sort()

In [None]:
# Show the current contents of `layers` as the cell output.
layers

In [None]:
class BuildResNet50():
    