In [4]:
from tensorflow.python.keras import backend
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import VersionAwareLayers
from tensorflow.keras import layers
import tensorflow as tf
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras import datasets
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import os
import numpy as np

In [5]:
# Load CIFAR-10, then run both image arrays through the ResNet50 input
# preprocessing after casting them to float32.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train, x_test = (
    tf.keras.applications.resnet50.preprocess_input(images.astype(np.float32))
    for images in (x_train, x_test)
)

In [125]:
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

In [136]:
def resnet_conv_block(input_shape, filters, conv_shortcut, name, model_name, kernel_size=3, stride=1, DW=False):
    """Build one ResNet bottleneck block as a standalone ``tf.keras.Model``.

    Layout: 1x1 conv -> middle conv -> 1x1 conv with ``4 * filters`` outputs,
    added to a shortcut branch and passed through a final ReLU. Tensors are
    treated as channels-last (batch normalization runs over axis 3).

    Args:
        input_shape: Shape of the block input without the batch dimension.
        filters: Bottleneck width; the block outputs ``4 * filters`` channels.
        conv_shortcut: If true, the shortcut is a strided 1x1 convolution
            followed by batch normalization. Otherwise the input is added
            unchanged, so it must already match the block output shape.
        name: Prefix used for the layer names inside the block.
        model_name: Name given to the returned model.
        kernel_size: Kernel size of the middle convolution, for both the
            regular and the depthwise variant.
        stride: Stride of the first 1x1 convolution and of the conv shortcut.
        DW: If true, the middle convolution is replaced by
            DepthwiseConv2D -> BN -> ReLU6 -> 1x1 pointwise conv -> BN.

    Returns:
        A ``tf.keras.Model`` mapping the block input to the block output.
    """
    inputs = tf.keras.layers.Input(input_shape)
    bn_axis = 3

    if conv_shortcut:
        shortcut = layers.Conv2D(
            4 * filters, 1, strides=stride, name=name + '_0_conv')(inputs)
        shortcut = layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(shortcut)
    else:
        shortcut = inputs

    # 1x1 reduction; this is where the block applies its stride.
    x = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(inputs)
    x = layers.BatchNormalization(
        axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)
    x = layers.Activation('relu', name=name + '_1_relu')(x)

    if not DW:
        # Regular spatial convolution.
        x = layers.Conv2D(
            filters, kernel_size, padding='SAME', name=name + '_2_conv')(x)
        x = layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)
        x = layers.Activation('relu', name=name + '_2_relu')(x)
    else:
        # Depthwise-separable replacement. The depthwise kernel now follows
        # `kernel_size` (it used to be fixed at 3 regardless of the argument).
        pw_filters = _make_divisible(filters, 8)
        x = tf.keras.layers.DepthwiseConv2D(
            kernel_size=kernel_size, padding='SAME', name=name + '_2_DWconv')(x)
        x = layers.BatchNormalization(axis=bn_axis,
                                      epsilon=1e-3,
                                      momentum=0.999,
                                      name=name + 'depthwise_BN')(x)
        x = layers.ReLU(6., name=name + 'depthwise_relu')(x)
        # Pointwise projection; no activation follows its batch normalization.
        x = layers.Conv2D(
            pw_filters, kernel_size=1, padding='SAME', name=name + '_2_conv')(x)
        x = layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)

    # 1x1 expansion to the block output width.
    x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x)
    x = layers.BatchNormalization(
        axis=bn_axis, epsilon=1.001e-5, name=name + '_3_bn')(x)

    x = layers.Add(name=name + '_add')([shortcut, x])
    x = layers.Activation('relu', name=name + '_out')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name=model_name)

In [212]:
def define_modular_resnet50(load_conv_weights=False, load_dw_conv_weights_index=[], folder='weights', dw_indxs=[]):
    """Assemble a ResNet50-style CIFAR-10 classifier out of per-block sub-models.

    The network is a stem (``input_resize_model``), sixteen bottleneck blocks
    built by ``resnet_conv_block`` (3 + 4 + 6 + 3 across stages conv2..conv5;
    only the first conv3 block uses stride 2) and a dense classifier head.
    Every part is its own ``tf.keras.Model`` so its weights can be saved and
    loaded independently.

    Args:
        load_conv_weights: If true, load ``<folder>/<name>/conv_weights.h5``
            for every block whose index is not in ``dw_indxs``, and for the
            stem and the classifier.
        load_dw_conv_weights_index: Block indices (0-15, in network order)
            whose ``<folder>/<name>/dw_conv_weights.h5`` should be loaded.
        folder: Root directory holding one sub-directory per sub-model.
        dw_indxs: Block indices (0-15) to build with the depthwise-separable
            middle convolution instead of the regular one.

    Returns:
        ``(model, models_dict)``: the end-to-end model, and a dict mapping
        ``'model_convS_K'`` / ``'classifier_model'`` / ``'input_resize_model'``
        to the sub-models. The sixteen blocks come first, in network order.

    Note:
        The list defaults are never mutated. The function rebinds the
        module-level ``layers`` name, as the original implementation did.
    """
    global layers
    layers = VersionAwareLayers()

    use_bias = True
    bn_axis = 3

    def before_conv_model(inputs):
        """Stem: 7x upsampling, 7x7/2 convolution, BN, ReLU and 3x3/2 max pooling."""
        resized = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
        x = layers.ZeroPadding2D(
            padding=((3, 3), (3, 3)), name='conv1_pad')(resized)
        x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name='conv1_conv')(x)

        x = layers.BatchNormalization(
            axis=bn_axis, epsilon=1.001e-5, name='conv1_bn')(x)
        x = layers.Activation('relu', name='conv1_relu')(x)

        x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name='pool1_pad')(x)
        x = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(x)
        return tf.keras.Model(inputs=inputs, outputs=x, name='input_resize')

    def classifier(input_shape):
        """Head: global average pooling, two ReLU dense layers and a 10-way softmax."""
        inputs = tf.keras.layers.Input(shape=input_shape)
        x = tf.keras.layers.GlobalAveragePooling2D()(inputs)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(1024, activation="relu")(x)
        x = tf.keras.layers.Dense(512, activation="relu")(x)
        x = tf.keras.layers.Dense(10, activation="softmax", name="classification")(x)
        return tf.keras.Model(inputs=inputs, outputs=x, name='classifier_model')

    # (stage name, bottleneck filters, number of blocks, stride of the first block)
    stage_specs = (
        ('conv2', 64, 3, 1),
        ('conv3', 128, 4, 2),
        ('conv4', 256, 6, 1),
        ('conv5', 512, 3, 1),
    )

    inputs = tf.keras.layers.Input(shape=(32, 32, 3))
    input_resize_model = before_conv_model(inputs)
    x = input_resize_model(inputs)

    models_dict = {}
    block_names = []
    for stage_name, filters, num_blocks, first_stride in stage_specs:
        for position in range(1, num_blocks + 1):
            is_first = position == 1
            # Position in network order; this is the index dw_indxs refers to.
            block_index = len(block_names)
            block = resnet_conv_block(
                input_shape=x.shape[1:], filters=filters,
                # Only the first block of a stage changes the channel count,
                # so only it needs a projection shortcut and the stage stride.
                conv_shortcut=is_first, name=stage_name,
                model_name=f'block_{stage_name}_{position}',
                stride=first_stride if is_first else 1,
                DW=block_index in dw_indxs)
            x = block(x)

            key = f'model_{stage_name}_{position}'
            models_dict[key] = block
            block_names.append(key)

    classifier_model = classifier(x.shape[1:])
    x = classifier_model(x)

    # Keep these two entries last: the blocks must occupy indices 0-15.
    models_dict['classifier_model'] = classifier_model
    models_dict['input_resize_model'] = input_resize_model

    model = tf.keras.Model(inputs=inputs, outputs=x)

    if load_conv_weights:
        for i, name in enumerate(block_names):
            if i not in dw_indxs:
                models_dict[name].load_weights(f'{folder}/{name}/conv_weights.h5')

        input_resize_model.load_weights(f'{folder}/input_resize_model/conv_weights.h5')
        classifier_model.load_weights(f'{folder}/classifier_model/conv_weights.h5')

    # Iterate the parameter itself; the previous code read a notebook global
    # (`load_dw_index`) here, which only worked when that global happened to exist.
    for index in load_dw_conv_weights_index:
        name = block_names[index]
        models_dict[name].load_weights(f'{folder}/{name}/dw_conv_weights.h5')

    return model, models_dict


# Block indices built with depthwise-separable convolutions, and the subset of
# them whose depthwise weights are restored from disk.
dw_index = [0, 1, 2, 3]
load_dw_index = [0, 1, 2]
load_weights = True

model, models_dict = define_modular_resnet50(
    load_conv_weights=load_weights,
    load_dw_conv_weights_index=load_dw_index,
    dw_indxs=dw_index,
)

In [109]:
# Checkpoint weights only, keeping just the epoch with the highest val_accuracy.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

In [213]:
# Adam with AMSGrad is the optimizer in use. Alternatives kept for reference:
#   optimizers.RMSprop(centered=True, learning_rate=0.00001)
#   optimizers.SGD(learning_rate=0.0001, nesterov=True)
#   optimizers.Adadelta(learning_rate=0.00001)
optim = optimizers.Adam(amsgrad=True, learning_rate=0.0001)
model.compile(
    optimizer=optim,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [63]:
# Restore weights from the 'checkpoints/' prefix, the same path the
# ModelCheckpoint callback writes to.
# Alternative kept for reference: load an exported HDF5 weights file instead.
# model.load_weights('modular_resnet_acc_train=0.86_test=0.81.h5')
model.load_weights('checkpoints/')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x2393c87a7f0>

In [None]:
# Train with the test split as validation data; the checkpoint callback
# monitors the resulting val_accuracy.
model.fit(
    x_train,
    y_train,
    batch_size=80,
    epochs=100,
    shuffle=True,
    validation_data=(x_test, y_test),
    callbacks=[model_checkpoint_callback],
)

Epoch 1/100

In [206]:
# Save the weights of the whole model. The file name embeds str(dw_index),
# brackets and spaces included, e.g. 'modular_resnet_[0, 1, 2, 3].h5'.
model.save_weights(f'modular_resnet_{dw_index}.h5')

In [14]:
def eval_model(model, x_train, y_train, x_test, y_test):
    """Evaluate ``model`` on the train split, then the test split, printing each result."""
    train_metrics = model.evaluate(x_train, y_train)
    print('train data\n', train_metrics, '\n')

    test_metrics = model.evaluate(x_test, y_test)
    print('test data\n', test_metrics)

In [166]:
# Print the evaluation results for the training split and then the test split.
eval_model(model, x_train, y_train, x_test, y_test)

  61/1563 [>.............................] - ETA: 1:59 - loss: 0.7711 - accuracy: 0.7280

KeyboardInterrupt: 

In [210]:
# Evaluate on the training split only; the return value is the cell output.
model.evaluate(x_train, y_train)

 164/1563 [==>...........................] - ETA: 1:51 - loss: 0.4546 - accuracy: 0.8441

KeyboardInterrupt: 

In [209]:
# Evaluate on the test split only; the return value is the cell output.
model.evaluate(x_test, y_test)



[0.5997694134712219, 0.8079000115394592]

In [151]:
def save_modules_weights(folder, models_dict, dw_index, save_all=False):
    """Save each sub-model's weights to ``<folder>/<name>/<file>.h5``.

    Args:
        folder: Root output directory; created if it does not exist.
        models_dict: Mapping of sub-model name to model. The last two entries
            are treated as the non-block models (classifier and input stem)
            and are skipped by the per-block loop.
        dw_index: Positions (in ``models_dict`` order) of the blocks whose
            weights are written to ``dw_conv_weights.h5``.
        save_all: If true, also write ``conv_weights.h5`` for every other
            block and for ``'input_resize_model'`` and ``'classifier_model'``.
    """
    def _save(name, filename):
        # makedirs also creates `folder` itself and tolerates existing
        # directories, so saving works on a completely fresh output tree.
        module_dir = folder + '/' + name
        os.makedirs(module_dir, exist_ok=True)
        models_dict[name].save_weights(module_dir + '/' + filename)

    names = list(models_dict.keys())
    for i, name in enumerate(names[:-2]):
        # Every block gets its directory, even when nothing is written to it.
        os.makedirs(folder + '/' + name, exist_ok=True)

        if i in dw_index:
            _save(name, 'dw_conv_weights.h5')
        elif save_all:
            _save(name, 'conv_weights.h5')

    if save_all:
        # These two are outside the loop above, so their directories must be
        # created here as well.
        _save('input_resize_model', 'conv_weights.h5')
        _save('classifier_model', 'conv_weights.h5')

In [211]:
# Write per-module weight files under 'weights/'. save_all=True also saves the
# non-depthwise blocks, the input stem and the classifier.
save_modules_weights('weights', models_dict, dw_index, save_all=True)

## Learning pipeline for resnet modules