In [None]:
import os

# Order CUDA devices by PCI bus ID and expose none of them, so TensorFlow
# (imported in the next cell) runs on the CPU. These variables only take
# effect if they are set before TensorFlow is imported.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "-1",
})

In [None]:
from copy import deepcopy
import imageio
import numpy as np
import cv2
import matplotlib.pyplot as plt
import keras
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tqdm import tqdm
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, InputSpec, DepthwiseConv2D, BatchNormalization, Activation, Conv2D, concatenate, Conv2DTranspose
from tensorflow.keras import backend as K
from tensorflow.keras.backend import int_shape, permute_dimensions
from collections import OrderedDict
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, SeparableConv2D

In [None]:
# Layer-config skeletons used by apply_transform_SeparableConv2D.
# '_node_name' and '_in_node_name' are placeholders: the transform overwrites
# 'name', config['name'] and 'inbound_nodes' for every layer it creates.
# Config keys that also exist on the SeparableConv2D being replaced are copied
# over from it; the remaining values below act as defaults.

# 1x1 pointwise convolution (second half of a SeparableConv2D).
Conv2D_config_template = {
    'name': '_node_name',
    'class_name': 'Conv2D',
    'config': {
        'name': '_node_name',
        'trainable': True,
        'dtype': 'float32',
        'filters': 1,
        'kernel_size': (1, 1),
        'strides': (1, 1),
        'padding': 'same',
        'data_format': 'channels_last',
        'dilation_rate': (1, 1),
        'activation': 'linear',
        'use_bias': True,
        'kernel_initializer': {'class_name': 'GlorotUniform',
                               'config': {'seed': None}},
        'bias_initializer': {'class_name': 'Zeros', 'config': {}},
        'kernel_regularizer': None,
        'bias_regularizer': None,
        'activity_regularizer': None,
        'kernel_constraint': None,
        'bias_constraint': None,
    },
    'inbound_nodes': [[['_in_node_name', 0, 0, {}]]],
}

# Depthwise convolution (first half of a SeparableConv2D).
DepthwiseConv2D_config_template = {
    'name': '_node_name',
    'class_name': 'DepthwiseConv2D',
    'config': {
        'name': '_node_name',
        'trainable': True,
        'dtype': 'float32',
        'kernel_size': (3, 3),
        'strides': (1, 1),
        'padding': 'same',
        'data_format': 'channels_last',
        'dilation_rate': (1, 1),
        'activation': 'linear',
        'use_bias': False,
        'bias_initializer': {'class_name': 'Zeros',
                             'config': {'dtype': 'float32'}},
        'bias_regularizer': None,
        'activity_regularizer': None,
        'bias_constraint': None,
        'depth_multiplier': 1,
        'depthwise_initializer': {'class_name': 'GlorotUniform',
                                  'config': {'seed': None, 'dtype': 'float32'}},
        'depthwise_regularizer': None,
        'depthwise_constraint': None,
    },
    'inbound_nodes': [[['_in_node_name', 0, 0, {}]]],
}

In [None]:
# Shared hyper-parameters so the two hand-built models are guaranteed to agree.
INPUT_SHAPE = (32, 32, 3)
N_FILTERS = 4
KERNEL_SIZE = (5, 5)

# Reference model: a single SeparableConv2D layer.
src_input = Input(INPUT_SHAPE, name='data')
src_output = SeparableConv2D(N_FILTERS, KERNEL_SIZE, padding='same',
                             name='src_layer')(src_input)
src_model = Model(src_input, src_output, name='src_model')

# Hand-built equivalent: bias-free DepthwiseConv2D followed by a 1x1 Conv2D
# that carries the bias.
dst_input = Input(INPUT_SHAPE, name='data')
dst_hidden = DepthwiseConv2D(KERNEL_SIZE, padding='same', name='dst_1',
                             use_bias=False)(dst_input)
dst_output = Conv2D(N_FILTERS, (1, 1), padding='same', name='dst_2')(dst_hidden)
dst_model = Model(dst_input, dst_output, name='dst_model')

In [None]:
# Split the SeparableConv2D weights across the two replacement layers:
# the first array goes to the depthwise layer, the remaining two
# (pointwise kernel and bias) go to the 1x1 Conv2D.
w_1, w_2, b = src_model.get_layer('src_layer').get_weights()
for dst_name, dst_weights in (('dst_1', [w_1]), ('dst_2', [w_2, b])):
    dst_model.get_layer(dst_name).set_weights(dst_weights)

In [None]:
# Sanity check: the hand-built depthwise + pointwise model must reproduce the
# SeparableConv2D output up to float32 rounding. A seeded generator keeps the
# check reproducible on Restart & Run All.
rng = np.random.default_rng(0)
x_in = rng.uniform(size=(1,) + src_model.input_shape[1:])
dst_out = dst_model.predict(x_in)
src_out = src_model.predict(x_in)
print(np.abs(dst_out - src_out).sum())
np.testing.assert_allclose(dst_out, src_out, rtol=1e-4, atol=1e-4)

In [None]:
def transfer_SeparableConv2D_DepthwiseConv2D(src_model, dst_model, transfer_rule):
    """Copy the depthwise kernel of a SeparableConv2D into a DepthwiseConv2D.

    Only the first array returned by the source layer's ``get_weights()`` is
    written to the destination layer.

    Args:
        src_model: model holding the layer named ``transfer_rule['src']``.
        dst_model: model holding the layer named ``transfer_rule['dst']``;
            its weights are replaced.
        transfer_rule: dict with at least ``'src'`` and ``'dst'`` layer names.
    """
    depthwise_kernel = src_model.get_layer(transfer_rule['src']).get_weights()[0]
    target_layer = dst_model.get_layer(transfer_rule['dst'])
    target_layer.set_weights([depthwise_kernel])
    
def transfer_SeparableConv2D_Conv2D(src_model, dst_model, transfer_rule):
    """Copy the pointwise part of a SeparableConv2D into a 1x1 Conv2D.

    Everything after the first weight array (the depthwise kernel) is
    forwarded. That is ``[pointwise_kernel, bias]`` when the source uses a
    bias, or just ``[pointwise_kernel]`` when it does not. The destination
    Conv2D is configured with the same ``use_bias`` flag, so the slice fits
    either case.

    Args:
        src_model: model holding the layer named ``transfer_rule['src']``.
        dst_model: model holding the layer named ``transfer_rule['dst']``;
            its weights are replaced.
        transfer_rule: dict with at least ``'src'`` and ``'dst'`` layer names.
    """
    source_weights = src_model.get_layer(transfer_rule['src']).get_weights()
    dst_model.get_layer(transfer_rule['dst']).set_weights(source_weights[1:])

# Locate the layers that the SeparableConv2D transform has to rewrite.
def detect_transform_SeparableConv2D(keras_config):
    """Return the positions in ``keras_config['layers']`` of every layer whose
    ``class_name`` is ``'SeparableConv2D'``, in ascending order."""
    return [position
            for position, layer_config in enumerate(keras_config['layers'])
            if layer_config['class_name'] == 'SeparableConv2D']

def apply_transform_SeparableConv2D(keras_config):
    """Replace every SeparableConv2D in a functional-model config with a
    DepthwiseConv2D followed by a 1x1 Conv2D.

    The new Conv2D keeps the original layer's name, so downstream
    ``inbound_nodes`` and the model's output references still resolve. The new
    DepthwiseConv2D is named ``<original name>_dwc_<layer index>``.

    Args:
        keras_config: dict as returned by ``Model.get_config()``; must contain
            a ``'layers'`` list. It is deep-copied, so the caller's dict is
            left untouched.

    Returns:
        ``(new_config, weight_transfer_rule_dict)``.
        ``new_config`` is the rewritten model config. ``weight_transfer_rule_dict``
        maps each newly created layer name to a rule dict with
        ``'transfer_call'``, ``'src'`` and ``'dst'`` keys, for use by
        ``transfer_weights``.
    """
    # Work on a copy so the caller's config is not mutated.
    keras_config = deepcopy(keras_config)
    weight_transfer_rule_dict = {}
    # SeparableConv2D's pointwise_* options correspond to Conv2D's kernel_*.
    pointwise_to_kernel = (
        ('pointwise_initializer', 'kernel_initializer'),
        ('pointwise_regularizer', 'kernel_regularizer'),
        ('pointwise_constraint', 'kernel_constraint'),
    )
    # Spatial geometry belongs to the depthwise step; the 1x1 Conv2D keeps
    # the template's (1, 1) values for these.
    depthwise_only_keys = {'kernel_size', 'strides', 'dilation_rate'}

    index_list = detect_transform_SeparableConv2D(keras_config)
    while index_list:
        i = index_list[0]
        r_layer_config = keras_config['layers'].pop(i)
        r_config = r_layer_config['config']
        src_name = r_layer_config['name']

        # --- Depthwise half: takes over the original inputs. ---
        # TODO: check that the generated name is unique within the model.
        dw_name = src_name + f'_dwc_{i}'
        dw_layer_config = deepcopy(DepthwiseConv2D_config_template)
        for key in DepthwiseConv2D_config_template['config']:
            if key in r_config:
                dw_layer_config['config'][key] = r_config[key]
        dw_layer_config['name'] = dw_name
        dw_layer_config['config']['name'] = dw_name
        # SeparableConv2D applies bias, activation and activity regularization
        # only after the pointwise step, so the depthwise half must be linear.
        dw_layer_config['config']['use_bias'] = False
        dw_layer_config['config']['activation'] = 'linear'
        dw_layer_config['config']['activity_regularizer'] = None
        dw_layer_config['inbound_nodes'] = r_layer_config['inbound_nodes']
        keras_config['layers'].insert(i, dw_layer_config)
        weight_transfer_rule_dict[dw_name] = {
            'transfer_call': transfer_SeparableConv2D_DepthwiseConv2D,
            'src': src_name, 'dst': dw_name}

        # --- Pointwise half: keeps the original name, fed by the depthwise. ---
        pw_layer_config = deepcopy(Conv2D_config_template)
        for key in Conv2D_config_template['config']:
            if key not in depthwise_only_keys and key in r_config:
                pw_layer_config['config'][key] = r_config[key]
        for sep_key, conv_key in pointwise_to_kernel:
            if sep_key in r_config:
                pw_layer_config['config'][conv_key] = r_config[sep_key]
        pw_layer_config['name'] = src_name
        pw_layer_config['config']['name'] = src_name
        # Call k of the Conv2D consumes the output of call k of the depthwise
        # layer, which keeps layers that were called several times wired up.
        pw_layer_config['inbound_nodes'] = [
            [[dw_name, node_index, 0, {}]]
            for node_index in range(len(r_layer_config['inbound_nodes']))]
        keras_config['layers'].insert(i + 1, pw_layer_config)
        weight_transfer_rule_dict[src_name] = {
            'transfer_call': transfer_SeparableConv2D_Conv2D,
            'src': src_name, 'dst': src_name}

        # Indices shift after each replacement, so rescan the layer list.
        index_list = detect_transform_SeparableConv2D(keras_config)
    return keras_config, weight_transfer_rule_dict

In [None]:
def transfer_weights(src_model, dst_model, weight_transfer_rule_dict):
    """Populate ``dst_model``'s weights from ``src_model``.

    Destination layers named in ``weight_transfer_rule_dict`` are filled by
    their rule's ``transfer_call``. Every other destination layer is looked
    up in ``src_model`` by the same name, and its weights are copied over
    unchanged.

    Args:
        src_model: model providing the weights (only read).
        dst_model: model whose layers receive the weights (modified in place).
        weight_transfer_rule_dict: mapping ``dst layer name -> rule``, where
            each rule is a dict with ``'transfer_call'``, ``'src'`` and
            ``'dst'`` keys, as returned by ``apply_transform_SeparableConv2D``.
    """
    for dst_layer in tqdm(dst_model.layers):
        if dst_layer.name in weight_transfer_rule_dict:
            transfer_rule = weight_transfer_rule_dict[dst_layer.name]
            transfer_call = transfer_rule['transfer_call']
            transfer_call(src_model, dst_model, transfer_rule)
        else:
            # Layer left untouched by the transform: copy its weights
            # src -> dst, never the other way round.
            src_weights = src_model.get_layer(dst_layer.name).get_weights()
            dst_layer.set_weights(src_weights)

In [None]:
# Rewrite the source model's config so each SeparableConv2D becomes a
# DepthwiseConv2D + 1x1 Conv2D pair, and collect the weight-transfer rules.
src_model_config = src_model.get_config()
dst_model_config, weight_transfer_rule_dict = apply_transform_SeparableConv2D(src_model_config)

In [None]:
# Instantiate the rewritten architecture (freshly initialized weights).
dst_model = Model.from_config(config=dst_model_config)

In [None]:
# Fill the rebuilt model with the original model's weights.
transfer_weights(
    src_model=src_model,
    dst_model=dst_model,
    weight_transfer_rule_dict=weight_transfer_rule_dict,
)

In [None]:
# Final check: the config-rewritten model must reproduce the original model's
# output up to float32 rounding. A seeded generator keeps the check
# reproducible on Restart & Run All.
rng = np.random.default_rng(0)
x_in = rng.uniform(size=(1,) + dst_model.input_shape[1:])
dst_out = dst_model.predict(x_in)
src_out = src_model.predict(x_in)
print(np.abs(dst_out - src_out).sum())
np.testing.assert_allclose(dst_out, src_out, rtol=1e-4, atol=1e-4)