In [1]:
from data import train_data_generator, stack_gen
from model_bn import unet
from keras.callbacks import ModelCheckpoint, EarlyStopping
import os
from os import makedirs, path
import tensorflow as tf
import pandas as pd
from datetime import datetime
from sklearn.model_selection import train_test_split
from random import randint

In [2]:
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate, BatchNormalization
from keras.optimizers import Adam
from keras.initializers import *
from keras.metrics import SparseTopKCategoricalAccuracy, SparseCategoricalCrossentropy, MeanSquaredError
import keras.backend as K
import tensorflow as tf
from keras.objectives import mean_squared_error

In [3]:
# --- Experiment configuration -----------------------------------------------
# Input tiles are square: image_side_length x image_side_length (see input_size below).
image_side_length = 512
batch_size = 8
validation_split_rate = 0.1
epoch_num = 5
# Kernel initializer used by the new decoder convolutions.
initializer = 'he_normal'

auto_encoder = True
# The seed selects which pretrained autoencoder checkpoint file is loaded.
seed = 82
model_load_path = "transfer_learning/autoencoder{seed}.hdf5".format(seed=seed)

In [4]:
# Build the 2-channel autoencoder U-Net and point `pretrained_weights` at the
# checkpoint chosen by `seed`. Its encoder is reused below as the backbone.
# `auto_encoder=True` is a literal on purpose: a later cell sets the global
# `auto_encoder` to False, and this cell must not change meaning if re-run.
auto_encoder_model = unet(
    pretrained_weights=model_load_path,
    input_size=(image_side_length, image_side_length, 2),
    auto_encoder=True,
    classify_level=4,
    learning_rate=1e-3,
)

In [5]:
# Inspect the loaded autoencoder's architecture (layer output shapes and parameter counts).
auto_encoder_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 64) 1216        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 64) 36928       conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 64) 256         conv2d_1[0][0]                   
______________________________________________________________________________________________

In [15]:
# (index, name) pairs; used to locate the encoder's BatchNormalization layers.
list(enumerate(layer.name for layer in auto_encoder_model.layers))

[(0, 'input_1'),
 (1, 'conv2d'),
 (2, 'conv2d_1'),
 (3, 'batch_normalization'),
 (4, 'max_pooling2d'),
 (5, 'conv2d_2'),
 (6, 'conv2d_3'),
 (7, 'batch_normalization_1'),
 (8, 'max_pooling2d_1'),
 (9, 'conv2d_4'),
 (10, 'conv2d_5'),
 (11, 'batch_normalization_2'),
 (12, 'max_pooling2d_2'),
 (13, 'conv2d_6'),
 (14, 'conv2d_7'),
 (15, 'batch_normalization_3'),
 (16, 'max_pooling2d_3'),
 (17, 'conv2d_8'),
 (18, 'conv2d_9'),
 (19, 'batch_normalization_4'),
 (20, 'up_sampling2d'),
 (21, 'conv2d_10'),
 (22, 'concatenate'),
 (23, 'conv2d_11'),
 (24, 'conv2d_12'),
 (25, 'batch_normalization_5'),
 (26, 'up_sampling2d_1'),
 (27, 'conv2d_13'),
 (28, 'concatenate_1'),
 (29, 'conv2d_14'),
 (30, 'conv2d_15'),
 (31, 'batch_normalization_6'),
 (32, 'up_sampling2d_2'),
 (33, 'conv2d_16'),
 (34, 'concatenate_2'),
 (35, 'conv2d_17'),
 (36, 'conv2d_18'),
 (37, 'batch_normalization_7'),
 (38, 'up_sampling2d_3'),
 (39, 'conv2d_19'),
 (40, 'concatenate_3'),
 (41, 'conv2d_20'),
 (42, 'conv2d_21'),
 (43, 'batch

In [20]:
# Collect the encoder's per-level BatchNormalization outputs. Levels 1-4 feed the
# decoder's skip connections; level 5 is the bottleneck.
# In the layer listing above, each encoder level is conv, conv, batch_norm,
# max_pool, so level i's BatchNormalization sits at index 3 + 4*(i-1)
# (3, 7, 11, 15, 19).
ENCODER_LEVELS = 5
conv_outputs = {}
for level in range(1, ENCODER_LEVELS + 1):
    layer = auto_encoder_model.layers[3 + 4 * (level - 1)]
    # Fail loudly instead of silently wiring the wrong tensor if the
    # autoencoder architecture (and therefore the layer order) changes.
    assert layer.name.startswith('batch_normalization'), (
        'expected a batch_normalization layer for encoder level {}, got {!r}'
        .format(level, layer.name))
    conv_outputs[level] = layer.output
conv_outputs

{1: <tf.Tensor 'batch_normalization/Identity:0' shape=(None, 512, 512, 64) dtype=float32>,
 2: <tf.Tensor 'batch_normalization_1/Identity:0' shape=(None, 256, 256, 128) dtype=float32>,
 3: <tf.Tensor 'batch_normalization_2/Identity:0' shape=(None, 128, 128, 256) dtype=float32>,
 4: <tf.Tensor 'batch_normalization_3/Identity:0' shape=(None, 64, 64, 512) dtype=float32>,
 5: <tf.Tensor 'batch_normalization_4/Identity:0' shape=(None, 32, 32, 1024) dtype=float32>}

In [None]:
# NOTE(review): this cell was never executed (In [None]) and neither name is
# used by later cells. conv_outputs[5] already holds the batch_normalization_4
# bottleneck tensor, and `conv4` is only a layer-name string, not a tensor.
conv4 = 'max_pooling2d_3'
latent_output = auto_encoder_model.get_layer(name='batch_normalization_4').output

In [27]:
# Switch the head built in the next cell from reconstruction to per-pixel
# classification over `classify_level` classes (softmax + sparse CE).
# NOTE(review): this overwrites the `auto_encoder` set in the config cell.
classify_level = 4
auto_encoder = False

In [28]:
def _decoder_stage(x, skip, filters, kernel_initializer):
    """One U-Net expanding step built on top of the pretrained encoder.

    Upsamples `x` by 2 and applies a 2x2 Conv2D. The result is concatenated
    with the encoder tensor `skip` on axis 3 (channels), then refined by two
    3x3 Conv2D layers and a BatchNormalization. Returns the normalized tensor.
    Layers are constructed in the same order as the original hand-written
    stages.
    """
    up = Conv2D(filters, 2, activation = 'relu', padding = 'same', kernel_initializer = kernel_initializer)(UpSampling2D(size = (2,2))(x))
    merged = concatenate([skip, up], axis = 3)
    out = Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = kernel_initializer)(merged)
    out = Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = kernel_initializer)(out)
    return BatchNormalization()(out)


# Fresh decoder: the bottleneck (level 5) is upsampled back to full resolution,
# fusing the matching encoder level at each step.
conv6 = _decoder_stage(conv_outputs[5], conv_outputs[4], 512, initializer)
conv7 = _decoder_stage(conv6, conv_outputs[3], 256, initializer)
conv8 = _decoder_stage(conv7, conv_outputs[2], 128, initializer)
conv9 = _decoder_stage(conv8, conv_outputs[1], 64, initializer)

# Output head: a 2-channel sigmoid reconstruction when auto_encoder is set,
# otherwise a per-pixel softmax over classify_level classes.
if auto_encoder:
    conv10 = Conv2D(2, 1, activation = 'sigmoid')(conv9)
    loss_func = "mean_squared_error"
    metrics = [
        MeanSquaredError()
        ]
else:
    conv10 = Conv2D(classify_level, 1, activation = 'softmax')(conv9)
    loss_func = 'sparse_categorical_crossentropy'
    metrics = [
        SparseTopKCategoricalAccuracy(k=1),
        SparseCategoricalCrossentropy(axis=-1)
        ]

In [29]:
# Tie the pretrained encoder's input to the head built in the previous cell (conv10).
unet_model = Model(outputs=conv10, inputs=auto_encoder_model.input)

In [30]:
# Show the assembled model: the pretrained encoder layers followed by the new decoder.
unet_model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 64) 1216        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 64) 36928       conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 64) 256         conv2d_1[0][0]                   
____________________________________________________________________________________________

In [32]:
# Freeze the pretrained encoder so only the new decoder is trained.
# In the autoencoder listing, indices 0-19 run from input_1 through
# batch_normalization_4 (the bottleneck).
# NOTE(review): assumes the encoder occupies the same first 20 positions in
# unet_model (its summary starts identically) — confirm if the encoder changes.
for layer in unet_model.layers[:20]:
    layer.trainable = False

In [35]:
# Compile after freezing: Keras only honours changes to `trainable` once the
# model is (re)compiled.
# `learning_rate` replaces the deprecated `lr` alias of Keras optimizers,
# which newer Keras/TF releases no longer accept.
learning_rate = 1e-3
unet_model.compile(
        optimizer = Adam(learning_rate = learning_rate),
        loss = loss_func,
        metrics = metrics
        )

In [36]:
# Re-print after freezing and compiling; the trainable vs. non-trainable
# parameter totals Keras prints at the end should now reflect the frozen encoder.
unet_model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 512, 512, 2) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 512, 512, 64) 1216        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 512, 512, 64) 36928       conv2d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512, 512, 64) 256         conv2d_1[0][0]                   
____________________________________________________________________________________________