In [157]:
from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras import datasets
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import os
import numpy as np

In [3]:
# Load CIFAR-10 and pass both image arrays (cast to float32 first) through the
# ResNet50 preprocessing function. The label arrays are left as loaded.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train = tf.keras.applications.resnet50.preprocess_input(x_train.astype(np.float32))
x_test = tf.keras.applications.resnet50.preprocess_input(x_test.astype(np.float32))

In [None]:
def define_resnet50():
    """Build the reference classifier: stock ResNet50 backbone plus a dense head.

    The 32x32x3 input is upsampled by a factor of 7 per spatial axis
    (32 * 7 = 224) so that it matches the backbone's 224x224x3 input shape.
    The backbone is created without its top and with ImageNet weights; its
    features are pooled and fed through 1024- and 512-unit ReLU layers into a
    10-way softmax layer named "classification".

    Returns:
        A ``tf.keras.Model`` from the 32x32x3 input to the softmax output.
    """
    image_in = tf.keras.layers.Input(shape=(32, 32, 3))

    # Layers are constructed in the order they are applied.
    pipeline = (
        tf.keras.layers.UpSampling2D(size=(7, 7)),
        tf.keras.applications.ResNet50(input_shape=(224, 224, 3),
                                       include_top=False,
                                       weights='imagenet'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1024, activation="relu"),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax", name="classification"),
    )

    out = image_in
    for stage in pipeline:
        out = stage(out)

    return tf.keras.Model(inputs=image_in, outputs=out)


# Build the reference ResNet50 classifier and print its layer summary.
# NOTE: the name `model` is bound again further down in this notebook.
model = define_resnet50()
model.summary()

In [71]:
def resnet_conv_block(input_shape, filters, conv_shortcut, name, model_name, kernel_size=3, stride=1, DW=False):
    """Build one ResNet bottleneck block as a standalone Keras model.

    The main path is: 1x1 convolution -> spatial convolution -> 1x1
    convolution with ``4 * filters`` channels, each followed by batch
    normalisation. The result is added to the shortcut and passed through a
    ReLU.

    Args:
        input_shape: shape handed to ``tf.keras.layers.Input`` for the block
            input (callers pass the previous tensor's ``shape[1:]``).
        filters: channel count of the first two convolutions; the block
            output has ``4 * filters`` channels.
        conv_shortcut: if true, the shortcut is a strided 1x1 convolution
            followed by batch normalisation; otherwise the block input is
            added unchanged.
        name: prefix of every layer name created inside the block.
        model_name: name of the returned model.
        kernel_size: kernel size of the spatial convolution. It is now also
            honoured by the depthwise variant, which previously always used 3.
        stride: stride of the first 1x1 convolution and of the convolutional
            shortcut.
        DW: if true, the spatial convolution is replaced by a depthwise
            convolution followed by a 1x1 convolution.

    Returns:
        A ``tf.keras.Model`` named ``model_name`` mapping the block input to
        the block output.
    """
    inputs = tf.keras.layers.Input(input_shape)
    bn_axis = 3

    def batch_norm(tensor, suffix):
        # Every normalisation in the block shares axis and epsilon; only the
        # layer name differs.
        return layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + suffix)(tensor)

    if conv_shortcut:
        shortcut = layers.Conv2D(
            4 * filters, 1, strides=stride, name=name + '_0_conv')(inputs)
        shortcut = batch_norm(shortcut, '_0_bn')
    else:
        shortcut = inputs

    # 1x1 reduction; this is the layer that carries the stride on the main path.
    x = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(inputs)
    x = batch_norm(x, '_1_bn')
    x = layers.Activation('relu', name=name + '_1_relu')(x)

    if not DW:
        # Plain spatial convolution.
        x = layers.Conv2D(
            filters, kernel_size, padding='SAME', name=name + '_2_conv')(x)
        x = batch_norm(x, '_2_bn')
        x = layers.Activation('relu', name=name + '_2_relu')(x)
    else:
        # Depthwise spatial convolution followed by a 1x1 convolution.
        x = tf.keras.layers.DepthwiseConv2D(
            kernel_size=kernel_size, padding='SAME', name=name + '_2_DWconv')(x)
        x = layers.Conv2D(
            filters, kernel_size=1, padding='SAME', name=name + '_2_conv')(x)
        x = batch_norm(x, '_2_bn')
        # NOTE(review): unlike the plain branch, no ReLU follows '_2_bn' here.
        # Left unchanged so existing trained weights keep their behaviour --
        # confirm whether the omission is intended.

    # 1x1 expansion to 4 * filters channels.
    x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x)
    x = batch_norm(x, '_3_bn')

    x = layers.Add(name=name + '_add')([shortcut, x])
    x = layers.Activation('relu', name=name + '_out')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name=model_name)

In [218]:
# def define_modular_resnet50(load_conv_weights=False, folder='weights', dw_indx=1):
#     global layers
# Settings for the modular ResNet50 assembled in this cell. The function
# wrapper above is commented out, so these are plain module-level names.
# load_conv_weights: when true, saved per-block weights are loaded once the
# model has been assembled.
load_conv_weights = True
# folder: directory the per-block weight files are read from.
folder = 'weights'
# dw_indx: selects which bottleneck block is built in its depthwise variant.
#   0      -> every block
#   1..16  -> only the block at that position (conv2_1 is 1, conv5_3 is 16)
#   other  -> no block (30 is such an out-of-range value)
dw_indx = 30

# Rebinds the module-level name `layers`; any function that looks `layers` up
# at call time uses this object from here on.
layers = VersionAwareLayers()

# Both values are read by before_conv_model below.
use_bias = True
bn_axis = 3


def before_conv_model(inputs):
    """Wrap the network stem in its own model named 'input_resize'.

    The stem upsamples the input by 7 along each spatial axis, then applies
    zero padding, a 7x7 stride-2 convolution with 64 filters, batch
    normalisation, ReLU, zero padding and a 3x3 stride-2 max pooling.

    Args:
        inputs: Keras input tensor the stem is attached to.

    Returns:
        A ``tf.keras.Model`` from ``inputs`` to the pooled stem features.
    """
    # Layers are constructed in the order they are applied.
    stem_layers = (
        tf.keras.layers.UpSampling2D(size=(7, 7)),
        layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name='conv1_pad'),
        layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv'),
        layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1_bn'),
        layers.Activation('relu', name='conv1_relu'),
        layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad'),
        layers.MaxPooling2D(3, strides=2, name='pool1_pool'),
    )

    features = inputs
    for stem_layer in stem_layers:
        features = stem_layer(features)

    return tf.keras.Model(inputs=inputs, outputs=features, name='input_resize')


def _bottleneck(features, stage, block, filters, block_indx, conv_shortcut=False, stride=1):
    """Build bottleneck ``block`` of stage ``conv<stage>`` sized for ``features``.

    The block uses the depthwise variant when ``dw_indx`` is 0 or equals
    ``block_indx`` (the block's 1-based position in the network).
    """
    return resnet_conv_block(input_shape=features.shape[1:],
                             filters=filters,
                             conv_shortcut=conv_shortcut,
                             name='conv%d' % stage,
                             model_name='block_conv%d_%d' % (stage, block),
                             stride=stride,
                             DW=dw_indx in (0, block_indx))


inputs = tf.keras.layers.Input(shape=(32, 32, 3))
input_resize_model = before_conv_model(inputs)

x = input_resize_model(inputs)

# Stage conv2: 3 blocks with 64 filters (positions 1-3).
model_conv2_1 = _bottleneck(x, 2, 1, 64, 1, conv_shortcut=True)
x = model_conv2_1(x)
model_conv2_2 = _bottleneck(x, 2, 2, 64, 2)
x = model_conv2_2(x)
model_conv2_3 = _bottleneck(x, 2, 3, 64, 3)
x = model_conv2_3(x)

# Stage conv3: 4 blocks with 128 filters (positions 4-7); the first one has stride 2.
model_conv3_1 = _bottleneck(x, 3, 1, 128, 4, conv_shortcut=True, stride=2)
x = model_conv3_1(x)
model_conv3_2 = _bottleneck(x, 3, 2, 128, 5)
x = model_conv3_2(x)
model_conv3_3 = _bottleneck(x, 3, 3, 128, 6)
x = model_conv3_3(x)
model_conv3_4 = _bottleneck(x, 3, 4, 128, 7)
x = model_conv3_4(x)

# Stage conv4: 6 blocks with 256 filters (positions 8-13).
# NOTE(review): conv4_1 and conv5_1 use stride 1, whereas the Keras reference
# ResNet50 downsamples at the start of these stages -- confirm intended.
model_conv4_1 = _bottleneck(x, 4, 1, 256, 8, conv_shortcut=True)
x = model_conv4_1(x)
model_conv4_2 = _bottleneck(x, 4, 2, 256, 9)
x = model_conv4_2(x)
model_conv4_3 = _bottleneck(x, 4, 3, 256, 10)
x = model_conv4_3(x)
model_conv4_4 = _bottleneck(x, 4, 4, 256, 11)
x = model_conv4_4(x)
model_conv4_5 = _bottleneck(x, 4, 5, 256, 12)
x = model_conv4_5(x)
model_conv4_6 = _bottleneck(x, 4, 6, 256, 13)
x = model_conv4_6(x)

# Stage conv5: 3 blocks with 512 filters (positions 14-16).
model_conv5_1 = _bottleneck(x, 5, 1, 512, 14, conv_shortcut=True)
x = model_conv5_1(x)
model_conv5_2 = _bottleneck(x, 5, 2, 512, 15)
x = model_conv5_2(x)
model_conv5_3 = _bottleneck(x, 5, 3, 512, 16)
x = model_conv5_3(x)


def classifier(input_shape):
    """Build the classification head as a model named 'classifier_model'.

    The head applies global average pooling, a flatten, 1024- and 512-unit
    ReLU dense layers and a 10-way softmax layer named "classification".

    Args:
        input_shape: shape handed to ``tf.keras.layers.Input`` for the head
            input.

    Returns:
        A ``tf.keras.Model`` from the feature input to the softmax output.
    """
    feature_in = tf.keras.layers.Input(shape=input_shape)

    # Layers are constructed in the order they are applied.
    head_layers = (
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1024, activation="relu"),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(10, activation="softmax", name="classification"),
    )

    scores = feature_in
    for head_layer in head_layers:
        scores = head_layer(scores)

    return tf.keras.Model(inputs=feature_in, outputs=scores, name='classifier_model')


# Size the classification head for the last block's output and attach it.
classifier_model = classifier(x.shape[1:])
x = classifier_model(x)

# Registry of every sub-model, keyed by the variable name it is bound to.
# The first two entries are the classifier and the stem, so positions in this
# dict do NOT line up with the 1..16 block positions used by dw_indx.
models_dict = {'classifier_model': classifier_model,
               'input_resize_model': input_resize_model,
               'model_conv2_1': model_conv2_1,
               'model_conv2_2': model_conv2_2,
               'model_conv2_3': model_conv2_3,
               'model_conv3_1': model_conv3_1,
               'model_conv3_2': model_conv3_2,
               'model_conv3_3': model_conv3_3,
               'model_conv3_4': model_conv3_4,
               'model_conv4_1': model_conv4_1,
               'model_conv4_2': model_conv4_2,
               'model_conv4_3': model_conv4_3,
               'model_conv4_4': model_conv4_4,
               'model_conv4_5': model_conv4_5,
               'model_conv4_6': model_conv4_6,
               'model_conv5_1': model_conv5_1,
               'model_conv5_2': model_conv5_2,
               'model_conv5_3': model_conv5_3}

# End-to-end model: 32x32x3 input through the stem, the 16 blocks and the head.
model = tf.keras.Model(inputs=inputs, outputs=x)

if load_conv_weights:
    # Restore the saved weights of every bottleneck block that was built with
    # the plain spatial convolution. Blocks are numbered 1..16 in network
    # order, the same numbering dw_indx uses when the blocks are built.
    #
    # An earlier attempt (removed) enumerated models_dict directly; that dict
    # starts with 'classifier_model' and 'input_resize_model', so its
    # positions are shifted relative to dw_indx and the loop did not work.
    #
    # NOTE(review): the weights of input_resize_model and classifier_model are
    # not restored here -- confirm whether that is intended.
    conv_block_names = ['model_conv%d_%d' % (stage, block)
                        for stage, n_blocks in ((2, 3), (3, 4), (4, 6), (5, 3))
                        for block in range(1, n_blocks + 1)]

    for block_indx, block_name in enumerate(conv_block_names, start=1):
        # A depthwise block has a different layer stack than a plain one, so
        # it is skipped. dw_indx == 0 makes EVERY block depthwise; the
        # previous per-block checks only skipped dw_indx == block_indx and
        # therefore tried to load all 16 files in that case.
        if dw_indx in (0, block_indx):
            continue
        models_dict[block_name].load_weights(folder + '/' + block_name + '/conv_weights.h5')
    # return model, models_dict

# model, models_dict = define_modular_resnet50(load_conv_weights=True,dw_indx=30)
# model.summary()

In [91]:
# Checkpoint callback: saves weights only (not the full model) under the
# 'checkpoints/' path, and only when val_accuracy reaches a new maximum.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/',
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

In [212]:
# Optimizer selection: Adam with AMSGrad is active; the commented-out lines
# are alternatives kept for experimentation.
# optim = optimizers.RMSprop(centered=False, learning_rate=0.001)
optim = optimizers.Adam(learning_rate=0.001, amsgrad=True)
# optim=optimizers.SGD(learning_rate=0.0001,nesterov=True)
# optim=optimizers.Adadelta(learning_rate=0.001)
# The sparse loss takes the integer class labels as loaded (no one-hot encoding).
model.compile(optimizer=optim, metrics=['accuracy'], loss='sparse_categorical_crossentropy')

In [214]:
# model.load_weights('modular_resnet_acc_train=0.84_test=0.79.h5')

In [216]:
# Train for up to 100 epochs with batches of 90. The CIFAR-10 test split is
# passed as validation data, so the checkpoint callback's val_accuracy is
# measured on the test set.
model.fit(x_train, y_train, batch_size=90, epochs=100, validation_data=(x_test, y_test), shuffle=True,
          callbacks=[model_checkpoint_callback])

Epoch 1/100
 73/556 [==>...........................] - ETA: 5:06 - loss: 1.6451 - accuracy: 0.3682

KeyboardInterrupt: 

In [104]:
# Save the weights of the whole model to a single HDF5 file.
# NOTE(review): '0.xx' in the file name looks like a placeholder for the
# measured train/test accuracy -- presumably edited by hand before running.
model.save_weights('modular_resnet_acc_train=0.xx_test=0.xx.h5')

In [172]:
def eval_model(model, x_train, y_train, x_test, y_test):
    """Evaluate ``model`` on the train split, then the test split, and print both.

    Each result is whatever ``model.evaluate`` returns for that split. The
    train split is evaluated and printed before the test split is evaluated.

    Args:
        model: object exposing ``evaluate(x, y)``.
        x_train, y_train: training inputs and labels.
        x_test, y_test: test inputs and labels.
    """
    train_metrics = model.evaluate(x_train, y_train)
    print('train data\n', train_metrics, '\n')

    test_metrics = model.evaluate(x_test, y_test)
    print('test data\n', test_metrics)

In [201]:
# Report loss and accuracy of the current model on the train and test splits.
eval_model(model, x_train, y_train, x_test, y_test)

 201/1563 [==>...........................] - ETA: 1:45 - loss: 2.6873 - accuracy: 0.0859

KeyboardInterrupt: 

In [194]:
def save_modules_weights(folder, models_dict):
    """Save every sub-model's weights to ``<folder>/<name>/conv_weights.h5``.

    One sub-directory per dictionary key is created under ``folder``. Missing
    directories, including ``folder`` itself, are created, and existing ones
    are reused.

    Args:
        folder: root directory for the per-module weight files.
        models_dict: mapping from module name to an object exposing
            ``save_weights(path)``.
    """
    for name, module in models_dict.items():
        module_dir = os.path.join(folder, name)
        # makedirs also creates a missing parent and, with exist_ok, avoids
        # the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(module_dir, exist_ok=True)
        module.save_weights(os.path.join(module_dir, 'conv_weights.h5'))


In [195]:
# Write each sub-model's weights under 'weights/<name>/conv_weights.h5'.
save_modules_weights('weights', models_dict)

## Learning pipeline for ResNet modules