In [None]:
import os
# Configure CUDA through environment variables before any deep-learning
# library is imported: order devices by PCI bus id and expose device "-1"
# (i.e. no valid GPU), which presumably forces the notebook to run on CPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "-1",
})

In [None]:
from copy import deepcopy
import imageio
import numpy as np
import cv2
import matplotlib.pyplot as plt
import keras
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tqdm import tqdm
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, InputSpec, DepthwiseConv2D, BatchNormalization, Activation, Conv2D, Add, Conv2DTranspose
from tensorflow.keras import backend as K
from tensorflow.keras.backend import int_shape, permute_dimensions
from collections import OrderedDict
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, SeparableConv2D
from tensorflow.keras.initializers import RandomNormal, RandomUniform

In [None]:
def rename_layer(model_config, src_name, dst_name):
    """Rename a layer everywhere it is referenced inside a model config.

    The config is modified in place and also returned.

    Parameters
    ----------
    model_config : dict
        Model config with ``'input_layers'``, ``'output_layers'`` and
        ``'layers'`` entries. Every reference is a list whose first element
        is a layer name.
    src_name : str
        Name to look for.
    dst_name : str
        Name written in place of ``src_name``.

    Returns
    -------
    dict
        The same ``model_config`` object, after renaming.
    """
    # Model-level input / output references.
    for section in ('input_layers', 'output_layers'):
        for tensor_ref in model_config[section]:
            if tensor_ref[0] == src_name:
                tensor_ref[0] = dst_name

    for layer in model_config['layers']:
        # The name is stored twice per layer: on the entry and in its config.
        if layer['name'] == src_name:
            layer['name'] = dst_name
        if layer['config']['name'] == src_name:
            layer['config']['name'] = dst_name

        # A layer that is called several times has one inbound node per call;
        # every node must be updated, not only the first one.
        for node in layer['inbound_nodes']:
            for tensor_ref in node:
                if tensor_ref[0] == src_name:
                    tensor_ref[0] = dst_name
    return model_config

In [None]:
# Skeleton of a serialized Conv2D layer entry (1x1 kernel, one filter, linear
# activation, with bias). '_node_name' and '_in_node_name' are placeholder
# names -- presumably meant to be substituted before the entry is inserted
# into a model config; nothing in the visible notebook uses the template yet.
Conv2D_config_template = {
    'name': '_node_name',
    'class_name': 'Conv2D',
    'config': {
        'name': '_node_name',
        'trainable': True,
        'dtype': 'float32',
        'filters': 1,
        'kernel_size': (1, 1),
        'strides': (1, 1),
        'padding': 'same',
        'data_format': 'channels_last',
        'dilation_rate': (1, 1),
        'activation': 'linear',
        'use_bias': True,
        'kernel_initializer': {
            'class_name': 'GlorotUniform',
            'config': {'seed': None},
        },
        'bias_initializer': {'class_name': 'Zeros', 'config': {}},
        'kernel_regularizer': None,
        'bias_regularizer': None,
        'activity_regularizer': None,
        'kernel_constraint': None,
        'bias_constraint': None,
    },
    'inbound_nodes': [[['_in_node_name', 0, 0, {}]]],
}

In [None]:
# Source model: Conv2D 'src_1' followed by BatchNormalization 'src_2'.
# The convolution bias and all four BatchNormalization parameters get random
# initializers -- presumably so that the fused-vs-original comparison in the
# later cells is not trivially satisfied by default (identity-like) values.
placeholder = Input((32, 32, 3), name='data')
x = a = Conv2D(4, (7,7), padding='same', name='src_1', bias_initializer=RandomNormal())(placeholder)
x = BatchNormalization(name='src_2', 
                       beta_initializer=RandomNormal(),
                       gamma_initializer=RandomNormal(),
                       moving_mean_initializer=RandomNormal(),
                       moving_variance_initializer=RandomUniform(1,2))(x)
# x = Add()([a, x])
src_model = Model(placeholder, x, name='src_model')

# Destination model: a single Conv2D 'dst_1' with the same filter count,
# kernel size and padding as 'src_1'. Note that `placeholder` and `x` are
# rebound here; only `src_model` and `dst_model` are used by later cells.
placeholder = Input((32, 32, 3), name='data')
x = Conv2D(4, (7, 7), padding='same', name='dst_1')(placeholder)
dst_model = Model(placeholder, x, name='dst_model')

In [None]:
# Fold BatchNormalization 'src_2' into convolution 'src_1' by hand and store
# the result in 'dst_1':
#   BN(conv(x) + bias) = a * conv(x) + a * (bias - mean) + beta,
#   with a = gamma / sqrt(var + eps), one value per output channel.
weight, bias = src_model.get_layer('src_1').get_weights()
gamma, beta, mean, var = src_model.get_layer('src_2').get_weights()
eps = src_model.get_layer('src_2').get_config()['epsilon']
a = gamma / np.sqrt(var + eps)
# Keras lays Conv2D kernels out as (kernel_h, kernel_w, in_channels, filters),
# so the per-output-channel scale belongs on the LAST axis. Scaling axis 2
# (input channels) is wrong and does not even broadcast when
# in_channels != filters, as is the case here (3 vs 4).
weight = weight*a.reshape((1,1,1,-1))
bias = a*(bias - mean) + beta
dst_model.get_layer('dst_1').set_weights([weight, bias])

In [None]:
# Sanity check: feed one random sample through both models and print the
# summed absolute difference of their outputs (expected to be close to 0).
x_in = np.random.uniform(size=(1, *dst_model.input_shape[1:]))
print(abs(dst_model.predict(x_in) - src_model.predict(x_in)).sum())

In [None]:
def transfer_Conv2DBatchNormalization_Conv2D(src_model, dst_model, transfer_rule):
    """Fold a Conv2D + BatchNormalization pair into a single Conv2D.

    Computes, per output channel, ``a = gamma / sqrt(var + eps)`` and stores
    ``kernel * a`` and ``a * (bias - mean) + beta`` in the destination layer.

    Parameters
    ----------
    src_model : model exposing ``get_layer(name)``
        Holds the original convolution and batch-normalization layers.
    dst_model : model exposing ``get_layer(name)``
        Holds the merged convolution that receives ``[kernel, bias]``.
    transfer_rule : dict
        ``'src_c'``: name of the source convolution, ``'src_b'``: name of the
        source batch normalization, ``'dst'``: name of the merged layer.

    Notes
    -----
    A convolution without bias is treated as having a zero bias. A batch
    normalization built with ``scale=False`` / ``center=False`` is treated
    as gamma = 1 / beta = 0. The destination layer always receives two
    arrays, so it must have been built with a bias.
    """
    conv_layer = src_model.get_layer(transfer_rule['src_c'])
    bn_layer = src_model.get_layer(transfer_rule['src_b'])
    bn_config = bn_layer.get_config()

    conv_weights = list(conv_layer.get_weights())
    weight = conv_weights[0]
    # use_bias=False convolutions expose only the kernel.
    bias = conv_weights[1] if len(conv_weights) > 1 else 0.0

    # BatchNormalization lists gamma and beta first, and omits each of them
    # when the corresponding option is disabled; the moving statistics
    # always come last.
    bn_weights = list(bn_layer.get_weights())
    gamma = bn_weights.pop(0) if bn_config.get('scale', True) else 1.0
    beta = bn_weights.pop(0) if bn_config.get('center', True) else 0.0
    mean, var = bn_weights

    a = gamma / np.sqrt(var + bn_config['epsilon'])
    # `a` has one entry per output channel and broadcasts over the kernel's
    # last axis.
    weight = weight * a
    bias = a * (bias - mean) + beta

    dst_model.get_layer(transfer_rule['dst']).set_weights([weight, bias])

def get_outbound_nodes(keras_config):
    """Index a model config by layer name and by producer -> consumer edges.

    Parameters
    ----------
    keras_config : dict
        Model config with a ``'layers'`` list. Each entry has a ``'name'``
        and an ``'inbound_nodes'`` list; every node is a list of references
        whose first element is the name of the producing layer.

    Returns
    -------
    outbound_dict : dict
        Maps a layer name to the names of the layers that consume it, in
        config order, each consumer listed once. Layers that nobody consumes
        have no key.
    index_dict : dict
        Maps a layer name to its position in ``keras_config['layers']``.
    """
    outbound_dict = {}
    index_dict = {}
    for position, layer in enumerate(keras_config['layers']):
        consumer = layer['name']
        index_dict[consumer] = position
        # Look at every input of every call: multi-input layers (e.g. Add)
        # consume more than the first referenced tensor.
        for node in layer['inbound_nodes']:
            for tensor_ref in node:
                consumers = outbound_dict.setdefault(tensor_ref[0], [])
                if consumer not in consumers:
                    consumers.append(consumer)

    return outbound_dict, index_dict
    
def detect_transform_Conv2DBatchNormalization(keras_config):
    """Find BatchNormalization layers that can be folded into their input Conv2D.

    A BatchNormalization layer qualifies when the layer feeding it

    * is a ``Conv2D``,
    * has a linear (or no) activation -- with a non-linear activation between
      the convolution and the normalization the two cannot be merged into a
      single convolution, and
    * is consumed by exactly one layer (the BatchNormalization itself).

    Parameters
    ----------
    keras_config : dict
        Model config with a ``'layers'`` list.

    Returns
    -------
    list of int
        Positions, in ``keras_config['layers']``, of the foldable
        BatchNormalization layers.
    """
    outbound_dict, index_dict = get_outbound_nodes(keras_config)
    layers = keras_config['layers']
    index_list = []
    for i, item in enumerate(layers):
        if item['class_name'] != 'BatchNormalization':
            continue
        if len(item['inbound_nodes']) == 0:
            # Layer is never called, so there is nothing to fold it into.
            continue
        in_node_name = item['inbound_nodes'][0][0][0]
        producer = layers[index_dict[in_node_name]]
        if producer['class_name'] != 'Conv2D':
            continue
        if producer['config'].get('activation', 'linear') not in (None, 'linear'):
            continue
        if len(outbound_dict[in_node_name]) == 1:
            index_list.append(i)
    return index_list

def apply_transform_Conv2DBatchNormalization(keras_config):
    """Rewrite a model config so that foldable Conv2D + BatchNormalization pairs become one Conv2D.

    For every detected pair the BatchNormalization entry is removed, all
    references to it are redirected to the convolution, and the convolution
    is renamed to ``<conv name> + '_M'``. The config is modified in place.

    Parameters
    ----------
    keras_config : dict
        Model config with ``'layers'``, ``'input_layers'`` and
        ``'output_layers'`` entries.

    Returns
    -------
    keras_config : dict
        The rewritten config (same object as the argument).
    weight_transfer_rule_dict : dict
        Maps each merged layer name to a rule with keys ``'transfer_call'``
        (function computing the merged weights), ``'src_c'`` (original
        convolution name), ``'src_b'`` (original batch-normalization name)
        and ``'dst'`` (merged layer name).

    Notes
    -----
    Two BatchNormalization layers stacked directly after one convolution are
    not supported: the second merge would record ``'src_c'`` as the already
    renamed ``'..._M'`` layer, which does not exist in the source model.
    """
    weight_transfer_rule_dict = {}
    index_list = detect_transform_Conv2DBatchNormalization(keras_config)
    while index_list:
        # Removing an entry shifts the remaining positions, so only the first
        # hit is used and detection is re-run afterwards.
        bn_layer_config = keras_config['layers'].pop(index_list[0])
        src_name = bn_layer_config['name']
        dst_name = bn_layer_config['inbound_nodes'][0][0][0]
        merged_dst = dst_name + '_M'

        # Consumers of the batch normalization now read from the convolution,
        # which then gets its merged name.
        keras_config = rename_layer(keras_config, src_name, dst_name)
        keras_config = rename_layer(keras_config, dst_name, merged_dst)

        # The folded shift (beta - a * mean) lives in the bias, so the merged
        # convolution needs one even if the original was built without it.
        for layer in keras_config['layers']:
            if layer['name'] == merged_dst:
                layer['config']['use_bias'] = True

        weight_transfer_rule_dict[merged_dst] = {
            'transfer_call': transfer_Conv2DBatchNormalization_Conv2D,
            'src_c': dst_name,
            'src_b': src_name,
            'dst': merged_dst,
        }
        index_list = detect_transform_Conv2DBatchNormalization(keras_config)
    return keras_config, weight_transfer_rule_dict

In [None]:
def transfer_weights(src_model, dst_model, weight_transfer_rule_dict):
    """Fill ``dst_model`` with weights taken from ``src_model``.

    Parameters
    ----------
    src_model : model
        Model providing the weights; it is only read.
    dst_model : model
        Model whose layers receive weights.
    weight_transfer_rule_dict : dict
        Maps a destination layer name to a rule dict. The rule's
        ``'transfer_call'`` entry is called as
        ``transfer_call(src_model, dst_model, rule)`` and is responsible for
        setting that layer's weights.

    Layers of ``dst_model`` without a rule get the weights of the layer with
    the same name in ``src_model``.
    """
    for dst_layer in tqdm(dst_model.layers):
        transfer_rule = weight_transfer_rule_dict.get(dst_layer.name)
        if transfer_rule is not None:
            transfer_rule['transfer_call'](src_model, dst_model, transfer_rule)
        else:
            # Copy source -> destination; the source model must stay untouched.
            dst_layer.set_weights(src_model.get_layer(dst_layer.name).get_weights())

In [None]:
# Derive the fused-model config from src_model's config and collect, per
# merged layer, the rule describing how its weights are to be computed.
dst_model_config, weight_transfer_rule_dict = apply_transform_Conv2DBatchNormalization(src_model.get_config())

In [None]:
# Instantiate the fused model from the rewritten config. This rebinds
# `dst_model`, replacing the hand-built single-convolution model from the
# earlier cell.
dst_model = Model.from_config(dst_model_config)

In [None]:
# Run the weight transfer for every layer of dst_model, using the rules
# collected when the config was rewritten.
transfer_weights(src_model, dst_model, weight_transfer_rule_dict)

In [None]:
# Sanity check for the automatically fused model: one random sample through
# both models, then print the summed absolute output difference (expected to
# be close to 0).
x_in = np.random.uniform(size=(1, *dst_model.input_shape[1:]))
print(abs(dst_model.predict(x_in) - src_model.predict(x_in)).sum())