In [1]:
"""Convert TF v1 MobilenetV1 to TF v2 Keras."""

import os
from typing import Text, List, Dict, Tuple, Callable

import tensorflow as tf

In [2]:
tf.io.decode_csv('1,2,3', record_defaults=[0,0,0])

[<tf.Tensor: shape=(), dtype=int32, numpy=1>,
 <tf.Tensor: shape=(), dtype=int32, numpy=2>,
 <tf.Tensor: shape=(), dtype=int32, numpy=3>]

In [1]:
from research.mobilenet.mobilenet_v1 import mobilenet_v1
from research.mobilenet.configs import archs
from research.mobilenet.mobilenet_trainer import _get_dataset_config, get_dataset

In [2]:
%%bash 

cd "$( dirname "${BASH_SOURCE[0]}" )" || exit
DIR="$( pwd )"
SRC_DIR=${DIR}"/../../../"
export PYTHONPATH=${PYTHONPATH}:${SRC_DIR}

In [5]:
source_checkpoint = '/Users/luoshixin/Downloads/mobilenet_checkpoints/mobilenet_v1_1.0_224/'

In [6]:
ck = tf.train.latest_checkpoint(source_checkpoint)

In [7]:
ck

In [8]:
tf.train.list_variables()

TypeError: list_variables() missing 1 required positional argument: 'ckpt_dir_or_file'

In [4]:
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from collections import defaultdict
from typing import Text, List, Dict, Tuple, Callable

import tensorflow as tf


def _process_moving_average(ma_terms: List[Text]) -> List[Text]:
  """
    MobilenetV2/Conv/BatchNorm/moving_variance
    MobilenetV2/Conv/BatchNorm/moving_variance/ExponentialMovingAverage
    MobilenetV3/expanded_conv_9/project/BatchNorm/moving_mean/ExponentialMovingAverage
  Args:
    ma_terms: a list of names related to moving average

  Returns:
    a list of names after de-duplicating
  """

  dedup_holder = dict()
  for item in ma_terms:
    base_name = item
    item_split = item.split('/')
    if 'moving_' in item_split[-2]:
      base_name = '/'.join(item_split[0:-1])

    if ((base_name not in dedup_holder)
        or (len(item) > len(dedup_holder[base_name]))):
      dedup_holder[base_name] = item

  return list(dedup_holder.values())


def _load_weights_from_ckpt(checkpoint_path: Text,
                            include_filters: List[Text],
                            exclude_filters: List[Text]
                            ) -> Dict[Text, tf.Tensor]:
  """Load all the weights stored in the checkpoint as {var_name: var_value}

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list

  Returns:
    A dictionary of {var_name: tensor values}
  """
  reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
  var_shape_map = reader.get_variable_to_shape_map()
  ma_terms = list()

  var_value_map = {}
  for item in var_shape_map:
    include_check = True
    if include_filters:
      include_check = any([to_check in item
                           for to_check in include_filters])
    exclude_check = True
    if exclude_filters:
      exclude_check = all([to_check not in item
                           for to_check in exclude_filters])

    if exclude_check and 'moving_' in item:
      ma_terms.append(item)
    elif exclude_check and include_check:
      var_value_map[item] = reader.get_tensor(item)

  processed_ma_terms = _process_moving_average(ma_terms)
  for p_item in processed_ma_terms:
    var_value_map[p_item] = reader.get_tensor(p_item)

  return var_value_map


def _decouple_layer_name(var_value_map: Dict[Text, tf.Tensor],
                         use_mv_average: bool = True
                         ) -> List[Tuple[Text, Text, tf.Tensor]]:
  """Sort the names of the weights by the layer they correspond to. The example
  names of the weightes:
    MobilenetV1/Conv2d_0/weights
    MobilenetV1/Conv2d_9_pointwise/weights
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage
    MobilenetV2/expanded_conv_9/project/weights/ExponentialMovingAverage
    MobilenetV2/expanded_conv_9/project/BatchNorm/beta/ExponentialMovingAverage
    Model_Name/Layer_Name/Component_Name[/Extra]

  Args:
    var_value_map: a dictionary of {var_name: tensor values}
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A list of (layer_num, layer_name, layer_component, weight_value)
  """
  layer_list = []
  for weight_name, weight_value in var_value_map.items():
    weight_name_split = weight_name.split('/')
    if use_mv_average and 'ExponentialMovingAverage' in weight_name:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage
      layer_name = '/'.join(weight_name_split[1:-2])
      layer_component = '/'.join(weight_name_split[-2:-1])
    else:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
      layer_name = '/'.join(weight_name_split[1:-1])
      layer_component = '/'.join(weight_name_split[-1:])

    layer_list.append(
      (layer_name, layer_component, weight_value))

  return layer_list


def _layer_weights_list_to_map(
    layer_ordered_list: List[Tuple[Text, Text, tf.Tensor]]
) -> Dict[Text, List[tf.Tensor]]:
  """Organize same layer with multiple components into group.
  For example: BatchNorm has 'gamma', 'beta', 'moving_mean', 'moving_variance'

  Args:
    layer_ordered_list: A list of (layer_num, layer_name,
    layer_component, weight_value)

  Returns:
    A dictionary of {layer_name: layer_weights}
  """

  # define the vars order in Keras layer
  batchnorm_order = ['gamma', 'beta', 'moving_mean', 'moving_variance']
  dense_cnn_order = ['weights', 'biases']
  depthwise_order = ['depthwise_weights', 'biases']

  # Organize same layer with multiple components into group
  keras_weights = defaultdict(list)

  for (layer_name, layer_component, weight) in layer_ordered_list:
    keras_weights[layer_name].append((layer_component, weight))

  # Sort within each group. The ordering should be
  ordered_layer_weights = {}

  for group_name, group in keras_weights.items():
    # format of group: [(layer_component, weight)]
    if len(group) == 1:
      order_weight_group = [group[0][1]]
    else:
      group_len = len(group)
      order_weight_group = [0] * group_len

      if group_len == 2:
        target_order = dense_cnn_order
        if 'depthwise' in group_name:
          target_order = depthwise_order
      elif group_len == 4:
        target_order = batchnorm_order
      else:
        raise ValueError(
          'The number of components {} in a layer is not supported'.format(
            group_len))

      for item_name, item_value in group:
        index = target_order.index(item_name)
        order_weight_group[index] = item_value

    ordered_layer_weights[group_name] = order_weight_group

  return ordered_layer_weights


def generate_layer_weights_map(checkpoint_path: Text,
                               include_filters: List[Text],
                               exclude_filters: List[Text],
                               use_mv_average: bool = True
                               ) -> Dict[Text, List[tf.Tensor]]:
  """Generate a dictionary of {layer_name: layer_weights} from checkpoint.

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A dictionary of {layer_name: layer_weights}
  """
  var_value_map = _load_weights_from_ckpt(
    checkpoint_path=checkpoint_path,
    include_filters=include_filters,
    exclude_filters=exclude_filters)

  layer_ordered_list = _decouple_layer_name(
    var_value_map=var_value_map,
    use_mv_average=use_mv_average)

  ordered_layer_weights = _layer_weights_list_to_map(
    layer_ordered_list=layer_ordered_list)

  return ordered_layer_weights


def load_tf2_keras_model_weights(keras_model: tf.keras.Model,
                                 weights_map: Dict[Text, List[tf.Tensor]],
                                 name_map_fn: Callable
                                 ):
  """Load a TF2 Keras model with a {layer_name: layer_weights} dictionary
  generated from TF1 checkpoint.

  Args:
    keras_model: TF2 Keras model
    weights_map: a dictionary of {layer_name: layer_weights}
    name_map_fn: a function that convert TF2 layer name to TF1 layer name

  Returns:

  """
  trainable_layer_types = (
    tf.keras.layers.Conv2D,
    tf.keras.layers.BatchNormalization,
    tf.keras.layers.Dense,
    tf.keras.layers.DepthwiseConv2D,
  )

  trainable_layers = [layer for layer in keras_model.layers
                      if isinstance(layer, trainable_layer_types)]

  for layer in trainable_layers:
    name = layer.name
    tf1_name = name_map_fn(name)
    weight = weights_map[tf1_name]
    layer.set_weights(weight)


def save_keras_checkpoint(keras_model: tf.keras.Model,
                          save_path: Text,
                          save_format: Text = 'ckpt'
                          ):
  """Save a TF2 Keras model to a checkpoint.

  Args:
    keras_model: TF2 Keras model
    save_format: save format: ckpt and tf
    save_path: path to save the checkpoint

  Returns:

  """
  if save_format == 'ckpt':
    checkpoint = tf.train.Checkpoint(model=keras_model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=save_path,
                                         max_to_keep=1)
    manager.save()
  else:
    keras_model.save(save_path, save_format=save_format)


In [5]:
var_value_map = _load_weights_from_ckpt(source_checkpoint, 
                                 ['ExponentialMovingAverage'], 
                                 ['RMSProp', 'global_step', 'loss'])

In [6]:
layer_ordered_list = _decouple_layer_name(var_value_map)

In [7]:
order_keras_weights = _layer_weights_list_to_map(layer_ordered_list)

In [8]:
sorted([(item[0], [iitem.shape for iitem in item[1]]) for item in order_keras_weights.items()])

[('Conv2d_0', [(3, 3, 3, 24)]),
 ('Conv2d_0/BatchNorm', [(24,), (24,), (24,), (24,)]),
 ('Conv2d_10_depthwise', [(3, 3, 384, 1)]),
 ('Conv2d_10_depthwise/BatchNorm', [(384,), (384,), (384,), (384,)]),
 ('Conv2d_10_pointwise', [(1, 1, 384, 384)]),
 ('Conv2d_10_pointwise/BatchNorm', [(384,), (384,), (384,), (384,)]),
 ('Conv2d_11_depthwise', [(3, 3, 384, 1)]),
 ('Conv2d_11_depthwise/BatchNorm', [(384,), (384,), (384,), (384,)]),
 ('Conv2d_11_pointwise', [(1, 1, 384, 384)]),
 ('Conv2d_11_pointwise/BatchNorm', [(384,), (384,), (384,), (384,)]),
 ('Conv2d_12_depthwise', [(3, 3, 384, 1)]),
 ('Conv2d_12_depthwise/BatchNorm', [(384,), (384,), (384,), (384,)]),
 ('Conv2d_12_pointwise', [(1, 1, 384, 768)]),
 ('Conv2d_12_pointwise/BatchNorm', [(768,), (768,), (768,), (768,)]),
 ('Conv2d_13_depthwise', [(3, 3, 768, 1)]),
 ('Conv2d_13_depthwise/BatchNorm', [(768,), (768,), (768,), (768,)]),
 ('Conv2d_13_pointwise', [(1, 1, 768, 768)]),
 ('Conv2d_13_pointwise/BatchNorm', [(768,), (768,), (768,), (76

In [9]:
def mobinetv1_tf1_tf2_name_convert(tf2_layer_name: Text) -> Text:
  """Convert TF2 layer name to TF1 layer name. Examples:
  Conv2d_0/batch_norm -> Conv2d_0/BatchNorm
  Conv2d_4/pointwise/batch_norm -> Conv2d_4_pointwise/BatchNorm
  top/Conv2d_1x1_output -> Logits/Conv2d_1c_1x1

  Args:
    tf2_layer_name: name of TF2 layer

  Returns:
    name of TF1 layer
  """
  if 'top/Conv2d_1x1_output' in tf2_layer_name:
    tf1_layer_name = 'Logits/Conv2d_1c_1x1'
  else:
    if 'batch_norm' in tf2_layer_name:
      tf2_layer_name = tf2_layer_name.replace('batch_norm', 'BatchNorm')

    if '/pointwise' in tf2_layer_name:
      tf1_layer_name = tf2_layer_name.replace('/pointwise', '_pointwise')
    elif '/depthwise' in tf2_layer_name:
      tf1_layer_name = tf2_layer_name.replace('/depthwise', '_depthwise')
    else:
      tf1_layer_name = tf2_layer_name

  return tf1_layer_name

In [3]:
config = archs.MobileNetV1Config()
config.width_multiplier = 1.0
model = mobilenet_v1(config=config)
model.summary()

Model: "MobileNetV1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           [(None, 224, 224, 3)]     0         
_________________________________________________________________
Conv2d_0_0 (Conv2D)          (None, 112, 112, 32)      864       
_________________________________________________________________
Conv2d_0_0/batch_norm (Batch (None, 112, 112, 32)      128       
_________________________________________________________________
Conv2d_0_0/relu6 (Activation (None, 112, 112, 32)      0         
_________________________________________________________________
Conv2d_1/depthwise (Depthwis (None, 112, 112, 32)      288       
_________________________________________________________________
Conv2d_1/depthwise/batch_nor (None, 112, 112, 32)      128       
_________________________________________________________________
Conv2d_1/depthwise/relu6 (Ac (None, 112, 112, 32)      

In [11]:
load_tf2_keras_model_weights(model, order_keras_weights, mobinetv1_tf1_tf2_name_convert)

In [12]:
model.compile(
    optimizer='rmsprop',
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=[tf.keras.metrics.categorical_crossentropy])

save_format = 'ckpt'
save_path = '/Users/luoshixin/Downloads/mobilenet_v1_ck'

if save_format == 'ckpt':
    checkpoint = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=save_path,
                                         max_to_keep=1)
    manager.save()
else:
    model.save(save_path, save_format=save_format)

In [13]:
d_config = _get_dataset_config().get("imagenette")()
d_config.split = 'validation'
eval_dataset = get_dataset(d_config)
for batch in eval_dataset.take(1):
    data, label = batch[0], batch[1]

In [14]:
model.predict(data)

array([[2.4249034e-09, 5.1374682e-09, 1.3301625e-09, ..., 5.1343602e-10,
        5.2039693e-09, 8.2226943e-06],
       [1.3226485e-11, 2.8857193e-11, 2.6858283e-11, ..., 2.2072804e-11,
        6.1748051e-09, 4.1291774e-09],
       [1.6818857e-09, 7.9699576e-01, 5.4981319e-06, ..., 2.1244281e-04,
        1.1076544e-05, 6.3732131e-08],
       ...,
       [3.2478045e-08, 5.9373335e-08, 1.2961380e-07, ..., 2.5657037e-08,
        1.0997113e-06, 6.1789278e-06],
       [1.4707939e-07, 5.6232446e-05, 3.6148110e-06, ..., 1.9385812e-08,
        2.1747068e-04, 8.2438401e-06],
       [9.8925716e-07, 1.5429567e-05, 3.1129814e-05, ..., 8.9486957e-06,
        5.9194866e-05, 2.1502106e-05]], dtype=float32)