In [2]:
"""Convert TF v1 MobilenetV2 to TF v2 Keras."""

import os
from typing import Text, List, Dict, Tuple, Callable
from collections import defaultdict

import tensorflow as tf

In [1]:
from research.mobilenet.mobilenet_v2 import mobilenet_v2
from research.mobilenet.mobilenet_trainer import _get_dataset_config, get_dataset

from research.mobilenet.configs import archs
MobileNetV2Config = archs.MobileNetV2Config

In [None]:
source_checkpoint = '/Users/luoshixin/Downloads/mobilenet_v2_0.75_224/mobilenet_v2_0.75_224.ckpt'

In [4]:
reader = tf.compat.v1.train.NewCheckpointReader(source_checkpoint)
var_shape_map = reader.get_variable_to_shape_map()

In [5]:
def _process_moving_average(ma_terms: List[Text]) -> List[Text]:
  """
    MobilenetV2/Conv/BatchNorm/moving_variance
    MobilenetV2/Conv/BatchNorm/moving_variance/ExponentialMovingAverage
  Args:
    ma_terms: a list of names related to moving average

  Returns:
    a list of names after de-duplicating
  """
#   print(ma_terms)
  output_list = list()
  replace_flag = False
  for item in ma_terms:
    base_name = item
    item_split = item.split('/')
    if 'moving_' in item_split[-2]:
      base_name = '/'.join(item_split[0:-1])
      replace_flag = True

    if (base_name in output_list) and replace_flag:
      t_index = output_list.index(base_name)
      output_list[t_index] = item

    if base_name not in output_list:
      output_list.append(item)

    replace_flag = False

#   print(output_list)
  return output_list


def _load_weights_from_ckpt(checkpoint_path: Text,
                            include_filters: List[Text],
                            exclude_filters: List[Text]
                            ) -> Dict[Text, tf.Tensor]:
  """Load all the weights stored in the checkpoint as {var_name: var_value}

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list

  Returns:
    A dictionary of {var_name: tensor values}
  """
  reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
  var_shape_map = reader.get_variable_to_shape_map()
  ma_terms = list()

  var_value_map = {}
  for item in var_shape_map:
    include_check = True
    if include_filters:
      include_check = any([to_check in item
                           for to_check in include_filters])
    exclude_check = True
    if exclude_filters:
      exclude_check = all([to_check not in item
                           for to_check in exclude_filters])

    if exclude_check and 'moving_' in item:
      ma_terms.append(item)
    elif include_check and exclude_check:
      var_value_map[item] = reader.get_tensor(item)

  processed_ma_terms = _process_moving_average(ma_terms)
  for p_item in processed_ma_terms:
    var_value_map[p_item] = reader.get_tensor(p_item)
  
  return var_value_map


def _decouple_layer_name(var_value_map: Dict[Text, tf.Tensor],
                         use_mv_average: bool = True
                         ) -> List[Tuple[Text, Text, tf.Tensor]]:
  """Sort the names of the weights by the layer they correspond to. The example
  names of the weightes:
    MobilenetV1/Conv2d_0/weights
    MobilenetV1/Conv2d_9_pointwise/weights
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage

    Model_Name/Layer_Name/Component_Name[/Extra]

  Args:
    var_value_map: a dictionary of {var_name: tensor values}
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A list of (layer_num, layer_name, layer_component, weight_value)
  """
  layer_list = []
  for weight_name, weight_value in var_value_map.items():
    weight_name_split = weight_name.split('/')
    if use_mv_average and 'ExponentialMovingAverage' in weight_name:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage
      layer_name = '/'.join(weight_name_split[1:-2])
      layer_component = '/'.join(weight_name_split[-2:-1])
    else:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
      layer_name = '/'.join(weight_name_split[1:-1])
      layer_component = '/'.join(weight_name_split[-1:])

    layer_list.append(
      (layer_name, layer_component, weight_value))

  return layer_list


def _layer_weights_list_to_map(
    layer_ordered_list: List[Tuple[Text, Text, tf.Tensor]]
) -> Dict[Text, List[tf.Tensor]]:
  """Organize same layer with multiple components into group.
  For example: BatchNorm has 'gamma', 'beta', 'moving_mean', 'moving_variance'

  Args:
    layer_ordered_list: A list of (layer_num, layer_name,
    layer_component, weight_value)

  Returns:
    A dictionary of {layer_name: layer_weights}
  """

  # define the vars order in Keras layer
  batchnorm_order = ['gamma', 'beta', 'moving_mean', 'moving_variance']
  dense_cnn_order = ['weights', 'biases']
  depthwise_order = ['depthwise_weights', 'biases']

  # Organize same layer with multiple components into group
  keras_weights = defaultdict(list)

  for (layer_name, layer_component, weight) in layer_ordered_list:
    keras_weights[layer_name].append((layer_component, weight))

  # Sort within each group. The ordering should be
  ordered_layer_weights = {}

  for group_name, group in keras_weights.items():
    # format of group: [(layer_component, weight)]
    if len(group) == 1:
      order_weight_group = [group[0][1]]
    else:
      group_len = len(group)
      order_weight_group = [0] * group_len

      if group_len == 2:
        target_order = dense_cnn_order
        if 'depthwise' in group_name:
          target_order = depthwise_order
      elif group_len == 4:
        target_order = batchnorm_order
      else:
        raise ValueError(
          'The number of components {} in a layer is not supported'.format(
            group_len))

      for item_name, item_value in group:
        index = target_order.index(item_name)
        order_weight_group[index] = item_value

    ordered_layer_weights[group_name] = order_weight_group

  return ordered_layer_weights


def generate_layer_weights_map(checkpoint_path: Text,
                               include_filters: List[Text],
                               exclude_filters: List[Text],
                               use_mv_average: bool = True
                               ) -> Dict[Text, List[tf.Tensor]]:
  """Generate a dictionary of {layer_name: layer_weights} from checkpoint.

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A dictionary of {layer_name: layer_weights}
  """
  var_value_map = _load_weights_from_ckpt(
    checkpoint_path=checkpoint_path,
    include_filters=include_filters,
    exclude_filters=exclude_filters)

  layer_ordered_list = _decouple_layer_name(
    var_value_map=var_value_map,
    use_mv_average=use_mv_average)

  ordered_layer_weights = _layer_weights_list_to_map(
    layer_ordered_list=layer_ordered_list)

  return ordered_layer_weights


def load_tf2_keras_model_weights(keras_model: tf.keras.Model,
                                 weights_map: Dict[Text, List[tf.Tensor]],
                                 name_map_fn: Callable
                                 ):
  """Load a TF2 Keras model with a {layer_name: layer_weights} dictionary
  generated from TF1 checkpoint.

  Args:
    keras_model: TF2 Keras model
    weights_map: a dictionary of {layer_name: layer_weights}
    name_map_fn: a function that convert TF2 layer name to TF1 layer name

  Returns:

  """
  trainable_layer_types = (
    tf.keras.layers.Conv2D,
    tf.keras.layers.BatchNormalization,
    tf.keras.layers.Dense,
    tf.keras.layers.DepthwiseConv2D,
  )

  trainable_layers = [layer for layer in keras_model.layers
                      if isinstance(layer, trainable_layer_types)]

  for layer in trainable_layers:
    name = layer.name
    tf1_name = name_map_fn(name)
    weight = weights_map[tf1_name]
    layer.set_weights(weight)


def save_keras_checkpoint(keras_model: tf.keras.Model,
                          save_path: Text,
                          save_format: Text = 'ckpt'
                          ):
  """Save a TF2 Keras model to a checkpoint.

  Args:
    keras_model: TF2 Keras model
    save_format: save format: ckpt and tf
    save_path: path to save the checkpoint

  Returns:

  """
  if save_format == 'ckpt':
    checkpoint = tf.train.Checkpoint(model=keras_model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=save_path,
                                         max_to_keep=1)
    manager.save()
  else:
    keras_model.save(save_path, save_format=save_format)

In [6]:
var_value_map = _load_weights_from_ckpt(source_checkpoint, 
                                 ['ExponentialMovingAverage'], 
                                 ['RMSProp', 'global_step', 'loss'])

In [7]:
layer_ordered_list = _decouple_layer_name(var_value_map)

In [8]:
order_keras_weights = _layer_weights_list_to_map(layer_ordered_list)

In [9]:
sorted([(item[0], [iitem.shape for iitem in item[1]]) for item in order_keras_weights.items()])

[('Conv', [(3, 3, 3, 24)]),
 ('Conv/BatchNorm', [(24,), (24,), (24,), (24,)]),
 ('Conv_1', [(1, 1, 240, 1280)]),
 ('Conv_1/BatchNorm', [(1280,), (1280,), (1280,), (1280,)]),
 ('Logits/Conv2d_1c_1x1', [(1, 1, 1280, 1001), (1001,)]),
 ('expanded_conv/depthwise', [(3, 3, 24, 1)]),
 ('expanded_conv/depthwise/BatchNorm', [(24,), (24,), (24,), (24,)]),
 ('expanded_conv/project', [(1, 1, 24, 16)]),
 ('expanded_conv/project/BatchNorm', [(16,), (16,), (16,), (16,)]),
 ('expanded_conv_1/depthwise', [(3, 3, 96, 1)]),
 ('expanded_conv_1/depthwise/BatchNorm', [(96,), (96,), (96,), (96,)]),
 ('expanded_conv_1/expand', [(1, 1, 16, 96)]),
 ('expanded_conv_1/expand/BatchNorm', [(96,), (96,), (96,), (96,)]),
 ('expanded_conv_1/project', [(1, 1, 96, 24)]),
 ('expanded_conv_1/project/BatchNorm', [(24,), (24,), (24,), (24,)]),
 ('expanded_conv_10/depthwise', [(3, 3, 288, 1)]),
 ('expanded_conv_10/depthwise/BatchNorm', [(288,), (288,), (288,), (288,)]),
 ('expanded_conv_10/expand', [(1, 1, 48, 288)]),
 ('ex

In [10]:
def mobinetv2_tf1_tf2_name_convert(tf2_layer_name: Text) -> Text:
  """Convert TF2 layer name to TF1 layer name. Examples:
  Conv2d_0 -> Conv
  Conv2d_0/batch_norm -> Conv/BatchNorm
  Conv2d_18 -> Conv_1
  Conv2d_18/batch_norm -> Conv_1/BatchNorm

  expanded_conv_1/project -> expanded_conv/project
  expanded_conv_1/depthwise -> expanded_conv/depthwise
  expanded_conv_2/expand -> expanded_conv_1/expand
  expanded_conv_2/expand/batch_norm -> expanded_conv_1/expand/BatchNorm
  expanded_conv_2/project -> expanded_conv_1/project
  expanded_conv_2/depthwise -> expanded_conv_1/depthwise

  top/Conv2d_1x1_output -> Logits/Conv2d_1c_1x1

  Args:
    tf2_layer_name: name of TF2 layer

  Returns:
    name of TF1 layer
  """

  if 'top/Conv2d_1x1_output' in tf2_layer_name:
    tf1_layer_name = 'Logits/Conv2d_1c_1x1'
  else:
    if 'batch_norm' in tf2_layer_name:
      tf2_layer_name = tf2_layer_name.replace('batch_norm', 'BatchNorm')
    # process layer number
    tf2_layer_name_split = tf2_layer_name.split('/')
    layer_num_re, reminder = tf2_layer_name_split[0], tf2_layer_name_split[1:]
    layer_num_re_split = layer_num_re.split('_')
    layer_type = '_'.join(layer_num_re_split[0:-1])
    layer_num = int(layer_num_re_split[-1])

    if layer_type == 'Conv2d':
      layer_type = 'Conv'
      if layer_num == 0:
        target_num = ''
      else:
        target_num = '1'
    elif layer_type == 'expanded_conv':
      if layer_num == 1:
        target_num = ''
      else:
        target_num = str(layer_num - 1)
    else:
      raise ValueError('The layer number and type combination is not '
                       'supported: {}, {}'.format(layer_type, str(layer_num)))

    if target_num:
      tf1_layer_name = '/'.join(['_'.join([layer_type, target_num])] + reminder)
    else:
      tf1_layer_name = '/'.join([layer_type] + reminder)

  return tf1_layer_name

In [3]:
config = archs.MobileNetV2Config()
config.width_multiplier = 1.0
model = mobilenet_v2(config=config)
model.summary()

Model: "MobileNetV2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input (InputLayer)              [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv2d_0_0 (Conv2D)             (None, 112, 112, 32) 864         Input[0][0]                      
__________________________________________________________________________________________________
Conv2d_0_0/batch_norm (BatchNor (None, 112, 112, 32) 128         Conv2d_0_0[0][0]                 
__________________________________________________________________________________________________
Conv2d_0_0/relu6 (Activation)   (None, 112, 112, 32) 0           Conv2d_0_0/batch_norm[0][0]      
________________________________________________________________________________________

In [12]:
load_tf2_keras_model_weights(model, order_keras_weights, mobinetv2_tf1_tf2_name_convert)

In [13]:
model.compile(
    optimizer='rmsprop',
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=[tf.keras.metrics.categorical_crossentropy])

save_format = 'ckpt'
save_path = '/Users/luoshixin/Downloads/mobilenet_v2_ck'

if save_format == 'ckpt':
    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=save_path,
                                         max_to_keep=1)
    manager.save()
else:
    model.save(save_path, save_format=save_format)

In [14]:
d_config = _get_dataset_config().get("imagenette")()
d_config.split = 'validation'
eval_dataset = get_dataset(d_config)
for batch in eval_dataset.take(1):
    data = batch[0]

In [15]:
model.predict(data)

array([[1.2755347e-04, 7.2183131e-05, 4.0077844e-05, ..., 1.2479289e-04,
        3.7609778e-05, 4.2228425e-05],
       [6.1000499e-04, 9.2725881e-04, 7.7758351e-04, ..., 1.1315856e-04,
        5.4444000e-04, 8.3419589e-05],
       [3.0527397e-05, 5.8289793e-06, 3.0618980e-05, ..., 3.2101882e-05,
        2.0800960e-04, 7.0051624e-06],
       ...,
       [2.8618096e-04, 9.4132386e-02, 1.8524384e-03, ..., 7.9685684e-05,
        2.7463472e-04, 1.6357186e-05],
       [1.9896599e-04, 3.7686979e-05, 8.5134081e-05, ..., 3.5276949e-05,
        2.8962622e-05, 2.6553948e-04],
       [4.3279899e-04, 1.8749197e-04, 6.2476168e-04, ..., 1.9237812e-04,
        1.7259331e-04, 5.6337107e-05]], dtype=float32)