# import statements

In [27]:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation, ReLU
from tensorflow.keras.layers import BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.utils import plot_model

from tensorflow.keras.applications import VGG16

# define blocks for the model

In [55]:
# convolution block

def convolution_block(inputs, num_filters):
    """Run ``inputs`` through two Conv(3x3) -> BatchNorm -> ReLU stages.

    Args:
        inputs: Tensor the first convolution is applied to.
        num_filters: Number of filters used by both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = inputs

    # Two identical stages; padding='same' keeps the convolutions from
    # shrinking the spatial dimensions. No pooling is applied in this block.
    for _ in range(2):
        features = Conv2D(num_filters, (3, 3), padding='same')(features)
        features = BatchNormalization()(features)
        features = Activation('relu')(features)

    return features

In [59]:
# decoder block

def decoder_block(inputs, skip_tensor, num_filters):
    """Up-sample ``inputs``, merge in a skip connection, then convolve.

    Args:
        inputs: Tensor coming from the previous (coarser) stage.
        skip_tensor: Encoder tensor concatenated with the up-sampled input.
        num_filters: Number of filters for the transposed convolution and for
            the convolution block that follows.

    Returns:
        The output tensor of the trailing convolution block.
    """
    # The stride of 2 is what performs the up-sampling.
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)

    # Bring in the skip layer alongside the up-sampled features.
    merged = Concatenate()([upsampled, skip_tensor])

    return convolution_block(merged, num_filters)

# set out model

In [69]:
# build vgg-16 

def build_vgg16_unet(input_shape):
    """Build a U-Net whose encoder is a frozen, ImageNet-pretrained VGG16.

    Args:
        input_shape: Shape of one input sample, without the batch axis. The
            first two entries are treated as the spatial dimensions (this
            assumes the channels-last layout — TODO confirm for other
            configurations). Each known spatial dimension must be a positive
            multiple of 16; ``None`` entries are not checked.

    Returns:
        An uncompiled ``Model`` named ``'first_VGG16_UNET'`` that maps the
        input to a single-channel sigmoid output.

    Raises:
        ValueError: If a known spatial dimension is not a positive multiple
            of 16.
    """
    # Four pooling steps lie between the first skip layer and the bridge, and
    # every decoder step doubles the resolution again. A dimension that is not
    # a multiple of 16 is floored on the way down, so the up-sampled tensor
    # would no longer match its skip tensor in Concatenate. Fail early with a
    # readable message instead.
    for size in tuple(input_shape)[:2]:
        if size is not None and (size <= 0 or size % 16 != 0):
            raise ValueError(
                'build_vgg16_unet needs spatial dimensions that are positive '
                'multiples of 16, got input_shape={!r}'.format(input_shape)
            )

    inputs = Input(input_shape)

    # see actual VGG-16 here: https://github.com/keras-team/keras/blob/v2.9.0/keras/applications/vgg16.py#L43-L227
    vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=inputs)
    # Freeze the pretrained encoder.
    vgg16.trainable = False

    # Encoder - skip layers.
    # Resolutions are relative to the input (e.g. 224 -> 224, 112, 56, 28).
    skip1 = vgg16.get_layer('block1_conv2').output  # full resolution, 64 filters in vgg16
    skip2 = vgg16.get_layer('block2_conv2').output  # 1/2 resolution, 128 filters in vgg16
    skip3 = vgg16.get_layer('block3_conv3').output  # 1/4 resolution, 256 filters in vgg16
    skip4 = vgg16.get_layer('block4_conv3').output  # 1/8 resolution, 512 filters in vgg16

    # Only the skip layers need to be picked out: VGG16 is the encoder and
    # already contains its own max-pooling layers.

    # Bridge: 1/16 resolution, 512 filters in vgg16.
    bridge = vgg16.get_layer('block5_conv3').output

    # Decoder: each step doubles the resolution and merges the matching skip.
    d1 = decoder_block(bridge, skip4, 512)  # 512 filters, as per the bridge
    d2 = decoder_block(d1, skip3, 256)      # 256 filters
    d3 = decoder_block(d2, skip2, 128)      # 128 filters
    d4 = decoder_block(d3, skip1, 64)       #  64 filters

    # Output: 1x1 convolution down to a single sigmoid channel.
    outputs = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(d4)

    model = Model(inputs, outputs, name='first_VGG16_UNET')

    return model
    

In [91]:
# Shape handed to build_vgg16_unet, which passes it straight to keras Input.
# Presumably (height, width, channels) — confirm against the training data.
our_input = (224,224,3)
model = build_vgg16_unet(our_input)

In [92]:
def compile_model(m):
    """Compile ``m`` with binary cross-entropy loss and the Adam optimizer.

    Args:
        m: Object exposing a Keras-style ``compile`` method.

    Returns:
        The same object ``m``, after ``compile`` has been called on it.
    """
    compile_options = {
        'loss': 'binary_crossentropy',
        'optimizer': 'adam',
    }
    m.compile(**compile_options)
    return m

In [93]:
# compile_model compiles its argument and returns that same object, so this
# rebinding keeps `model` pointing at the one model instance.
model = compile_model(model)

## Could do with a timestamp method for this

In [97]:
# The path is relative, so it resolves against the notebook's working directory.
# The .h5 suffix presumably makes Keras write a single HDF5 file — confirm
# against the installed Keras version.
model_path_and_filename = '../models/first_UNET_input_shape_224x224x3.h5'
model.save(model_path_and_filename)

In [74]:
# Calling `save('my_model')` creates a SavedModel folder `my_model`.
# model.save("my_model")
# Don't do it this way: it is harder work to reload and run a model (it will require both the model and the code that created it)

In [95]:
# Reload the file written by model.save(); it can be used to reconstruct the model identically.
# Uses load_model from tensorflow.keras.models (imported in the import cell).
reconstructed_model = load_model(model_path_and_filename)

In [None]:
# TODO: get the data in using tf.data.Dataset

# TODO: build the dataset with tf.data.Dataset.from_tensor_slices

# TODO: decode the .tif images

In [68]:
# Print the layer-by-layer summary of the model.
# NOTE(review): the totals below were copied from a recorded run and may be
# stale — re-run this cell to refresh them.
model.summary()
# Total params: 18,549,761
# Trainable params: 3,833,153
# Non-trainable params: 14,716,608

Model: "first_VGG16_UNET"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_21 (InputLayer)          [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 224, 224, 64  1792        ['input_21[0][0]']               
                                )                                                                 
                                                                                                  
 block1_conv2 (Conv2D)          (None, 224, 224, 64  36928       ['block1_conv1[0][0]']           
                                )                                                  