In [1]:
import os
from typing import Text, List, Dict, Tuple, Callable
from collections import defaultdict

import numpy as np
import tensorflow as tf

In [2]:
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import classification_model

In [45]:
from research.mobilenet.mobilenet_v3 import mobilenet_v3_large
from research.mobilenet.mobilenet_trainer import _get_dataset_config, get_dataset

from research.mobilenet.configs import archs
MobileNetV3LargeConfig = archs.MobileNetV3LargeConfig

In [49]:
def _process_moving_average(ma_terms: List[Text]) -> List[Text]:
  """
    MobilenetV2/Conv/BatchNorm/moving_variance
    MobilenetV2/Conv/BatchNorm/moving_variance/ExponentialMovingAverage
    MobilenetV3/expanded_conv_9/project/BatchNorm/moving_mean/ExponentialMovingAverage
  Args:
    ma_terms: a list of names related to moving average

  Returns:
    a list of names after de-duplicating
  """
  output_list = list()
  base_name_set = set()
  replace_flag = False
  for item in ma_terms:
    base_name = item
    item_split = item.split('/')
    if 'moving_' in item_split[-2]:
      base_name = '/'.join(item_split[0:-1])
      replace_flag = True

    if base_name in base_name_set:
      if replace_flag:
        t_index = output_list.index(base_name)
        output_list[t_index] = item
    else:
      output_list.append(item)
      base_name_set.add(base_name)

    replace_flag = False

  return output_list


def _load_weights_from_ckpt(checkpoint_path: Text,
                            include_filters: List[Text],
                            exclude_filters: List[Text]
                            ) -> Dict[Text, tf.Tensor]:
  """Load all the weights stored in the checkpoint as {var_name: var_value}

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list

  Returns:
    A dictionary of {var_name: tensor values}
  """
  reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
  var_shape_map = reader.get_variable_to_shape_map()
  ma_terms = list()

  var_value_map = {}
  for item in var_shape_map:
    include_check = True
    if include_filters:
      include_check = any([to_check in item
                           for to_check in include_filters])
    exclude_check = True
    if exclude_filters:
      exclude_check = all([to_check not in item
                           for to_check in exclude_filters])

    if exclude_check and 'moving_' in item:
      ma_terms.append(item)
    elif include_check and exclude_check:
      var_value_map[item] = reader.get_tensor(item)

  processed_ma_terms = _process_moving_average(ma_terms)
  for p_item in processed_ma_terms:
    var_value_map[p_item] = reader.get_tensor(p_item)
  
  return var_value_map


def _decouple_layer_name(var_value_map: Dict[Text, tf.Tensor],
                         use_mv_average: bool = True
                         ) -> List[Tuple[Text, Text, tf.Tensor]]:
  """Sort the names of the weights by the layer they correspond to. The example
  names of the weightes:
    MobilenetV1/Conv2d_0/weights
    MobilenetV1/Conv2d_9_pointwise/weights
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
    MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage

    Model_Name/Layer_Name/Component_Name[/Extra]

  Args:
    var_value_map: a dictionary of {var_name: tensor values}
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A list of (layer_num, layer_name, layer_component, weight_value)
  """
  layer_list = []
  for weight_name, weight_value in var_value_map.items():
    weight_name_split = weight_name.split('/')
    if use_mv_average and 'ExponentialMovingAverage' in weight_name:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta/ExponentialMovingAverage
      layer_name = '/'.join(weight_name_split[1:-2])
      layer_component = '/'.join(weight_name_split[-2:-1])
    else:
      # MobilenetV1/Conv2d_9_pointwise/BatchNorm/beta
      layer_name = '/'.join(weight_name_split[1:-1])
      layer_component = '/'.join(weight_name_split[-1:])

    layer_list.append(
      (layer_name, layer_component, weight_value))

  return layer_list


def _layer_weights_list_to_map(
    layer_ordered_list: List[Tuple[Text, Text, tf.Tensor]]
) -> Dict[Text, List[tf.Tensor]]:
  """Organize same layer with multiple components into group.
  For example: BatchNorm has 'gamma', 'beta', 'moving_mean', 'moving_variance'

  Args:
    layer_ordered_list: A list of (layer_num, layer_name,
    layer_component, weight_value)

  Returns:
    A dictionary of {layer_name: layer_weights}
  """

  # define the vars order in Keras layer
  batchnorm_order = ['gamma', 'beta', 'moving_mean', 'moving_variance']
  dense_cnn_order = ['weights', 'biases']
  depthwise_order = ['depthwise_weights', 'biases']

  # Organize same layer with multiple components into group
  keras_weights = defaultdict(list)

  for (layer_name, layer_component, weight) in layer_ordered_list:
    keras_weights[layer_name].append((layer_component, weight))

  # Sort within each group. The ordering should be
  ordered_layer_weights = {}

  for group_name, group in keras_weights.items():
    # format of group: [(layer_component, weight)]
    if len(group) == 1:
      order_weight_group = [group[0][1]]
    else:
      group_len = len(group)
      order_weight_group = [0] * group_len

      if group_len == 2:
        target_order = dense_cnn_order
        if 'depthwise' in group_name:
          target_order = depthwise_order
      elif group_len == 4:
        target_order = batchnorm_order
      else:
        raise ValueError(
          'The number of components {} in a layer is not supported for {}'.format(
            group_len, group_name))

      for item_name, item_value in group:
        index = target_order.index(item_name)
        order_weight_group[index] = item_value

    ordered_layer_weights[group_name] = order_weight_group

  return ordered_layer_weights


def generate_layer_weights_map(checkpoint_path: Text,
                               include_filters: List[Text],
                               exclude_filters: List[Text],
                               use_mv_average: bool = True
                               ) -> Dict[Text, List[tf.Tensor]]:
  """Generate a dictionary of {layer_name: layer_weights} from checkpoint.

  Args:
    checkpoint_path: path to the checkpoint file xxxxx.ckpt
    include_filters: list of keywords that determine which var_names should be
    kept in the output list
    exclude_filters: list of keywords that determine which var_names should be
    excluded from the output list
    use_mv_average: whether `ExponentialMovingAverage` should be used. If this
    is true, the `ExponentialMovingAverage` related weightes should be included
    in  `var_value_map`.

  Returns:
    A dictionary of {layer_name: layer_weights}
  """
  var_value_map = _load_weights_from_ckpt(
    checkpoint_path=checkpoint_path,
    include_filters=include_filters,
    exclude_filters=exclude_filters)

  layer_ordered_list = _decouple_layer_name(
    var_value_map=var_value_map,
    use_mv_average=use_mv_average)

  ordered_layer_weights = _layer_weights_list_to_map(
    layer_ordered_list=layer_ordered_list)

  return ordered_layer_weights

## Load checkpoint

In [146]:
source_checkpoint = '/home/jupyter/v3-large_224_1.0_float/pristine/model.ckpt-540000'
reader = tf.compat.v1.train.NewCheckpointReader(source_checkpoint)
var_shape_map = reader.get_variable_to_shape_map()

In [147]:
var_value_map = _load_weights_from_ckpt(source_checkpoint, 
                                 ['ExponentialMovingAverage'], 
                                 ['RMSProp', 'global_step', 'loss', 'Momentum'])

In [148]:
layer_ordered_list = _decouple_layer_name(var_value_map)

In [149]:
order_keras_weights = _layer_weights_list_to_map(layer_ordered_list)

In [150]:
sorted([(item[0], [iitem.shape for iitem in item[1]]) for item in order_keras_weights.items()])

[('Conv', [(3, 3, 3, 16)]),
 ('Conv/BatchNorm', [(16,), (16,), (16,), (16,)]),
 ('Conv_1', [(1, 1, 160, 960)]),
 ('Conv_1/BatchNorm', [(960,), (960,), (960,), (960,)]),
 ('Conv_2', [(1, 1, 960, 1280), (1280,)]),
 ('Logits/Conv2d_1c_1x1', [(1, 1, 1280, 1001), (1001,)]),
 ('expanded_conv/depthwise', [(3, 3, 16, 1)]),
 ('expanded_conv/depthwise/BatchNorm', [(16,), (16,), (16,), (16,)]),
 ('expanded_conv/project', [(1, 1, 16, 16)]),
 ('expanded_conv/project/BatchNorm', [(16,), (16,), (16,), (16,)]),
 ('expanded_conv_1/depthwise', [(3, 3, 64, 1)]),
 ('expanded_conv_1/depthwise/BatchNorm', [(64,), (64,), (64,), (64,)]),
 ('expanded_conv_1/expand', [(1, 1, 16, 64)]),
 ('expanded_conv_1/expand/BatchNorm', [(64,), (64,), (64,), (64,)]),
 ('expanded_conv_1/project', [(1, 1, 64, 24)]),
 ('expanded_conv_1/project/BatchNorm', [(24,), (24,), (24,), (24,)]),
 ('expanded_conv_10/depthwise', [(3, 3, 480, 1)]),
 ('expanded_conv_10/depthwise/BatchNorm', [(480,), (480,), (480,), (480,)]),
 ('expanded_conv

## Create MNV3 model

In [152]:
tf.keras.backend.clear_session()
tf.keras.backend.set_image_data_format('channels_last')

backbone = backbones.MobileNet(
    model_id='MobileNetV3Large', filter_size_scale=1.0)

num_classes = 1001
model = classification_model.ClassificationModel(
    backbone=backbone,
    num_classes=num_classes,
    dropout_rate=0.2,
)

In [153]:
model.summary()

Model: "classification_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
mobile_net (MobileNet)       {'2': (None, None, None,  4226432   
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1001)              1282281   
Total params: 5,508,713
Trainable params: 5,484,313
Non-trainable params: 24,400
_________________________________________________________________


In [154]:
[(item.name, item.shape) for item in model.layers[4].weights]

[('dense/kernel:0', TensorShape([1280, 1001])),
 ('dense/bias:0', TensorShape([1001]))]

In [155]:
[layer.name for layer in model.layers[1].layers if layer.weights]

['conv2dbn_block',
 'inverted_bottleneck_block',
 'inverted_bottleneck_block_1',
 'inverted_bottleneck_block_2',
 'inverted_bottleneck_block_3',
 'inverted_bottleneck_block_4',
 'inverted_bottleneck_block_5',
 'inverted_bottleneck_block_6',
 'inverted_bottleneck_block_7',
 'inverted_bottleneck_block_8',
 'inverted_bottleneck_block_9',
 'inverted_bottleneck_block_10',
 'inverted_bottleneck_block_11',
 'inverted_bottleneck_block_12',
 'inverted_bottleneck_block_13',
 'inverted_bottleneck_block_14',
 'conv2dbn_block_1',
 'conv2dbn_block_2']

### Conv Layer

In [156]:
[item.name for item in model.layers[1].layers[1].weights]

['conv2dbn_block/conv2d/kernel:0',
 'conv2dbn_block/batch_normalization/gamma:0',
 'conv2dbn_block/batch_normalization/beta:0',
 'conv2dbn_block/batch_normalization/moving_mean:0',
 'conv2dbn_block/batch_normalization/moving_variance:0']

In [157]:
[item.name for item in model.layers[1].layers[-2].weights]

['conv2dbn_block_2/conv2d/kernel:0', 'conv2dbn_block_2/conv2d/bias:0']

### Inverted Bottleneck Layer without Projection

In [158]:
[item.name for item in model.layers[1].layers[3].weights]

['inverted_bottleneck_block/depthwise_conv2d/depthwise_kernel:0',
 'inverted_bottleneck_block/batch_normalization/gamma:0',
 'inverted_bottleneck_block/batch_normalization/beta:0',
 'inverted_bottleneck_block/conv2d/kernel:0',
 'inverted_bottleneck_block/batch_normalization_1/gamma:0',
 'inverted_bottleneck_block/batch_normalization_1/beta:0',
 'inverted_bottleneck_block/batch_normalization/moving_mean:0',
 'inverted_bottleneck_block/batch_normalization/moving_variance:0',
 'inverted_bottleneck_block/batch_normalization_1/moving_mean:0',
 'inverted_bottleneck_block/batch_normalization_1/moving_variance:0']

### Inverted Bottleneck Layer

In [159]:
[item.name for item in model.layers[1].layers[5].weights]

['inverted_bottleneck_block_1/conv2d/kernel:0',
 'inverted_bottleneck_block_1/batch_normalization/gamma:0',
 'inverted_bottleneck_block_1/batch_normalization/beta:0',
 'inverted_bottleneck_block_1/depthwise_conv2d/depthwise_kernel:0',
 'inverted_bottleneck_block_1/batch_normalization_1/gamma:0',
 'inverted_bottleneck_block_1/batch_normalization_1/beta:0',
 'inverted_bottleneck_block_1/conv2d_1/kernel:0',
 'inverted_bottleneck_block_1/batch_normalization_2/gamma:0',
 'inverted_bottleneck_block_1/batch_normalization_2/beta:0',
 'inverted_bottleneck_block_1/batch_normalization/moving_mean:0',
 'inverted_bottleneck_block_1/batch_normalization/moving_variance:0',
 'inverted_bottleneck_block_1/batch_normalization_1/moving_mean:0',
 'inverted_bottleneck_block_1/batch_normalization_1/moving_variance:0',
 'inverted_bottleneck_block_1/batch_normalization_2/moving_mean:0',
 'inverted_bottleneck_block_1/batch_normalization_2/moving_variance:0']

### Inverted Bottleneck Layer with SE

In [160]:
[(item.name, item.shape) for item in model.layers[1].layers[9].weights]

[('inverted_bottleneck_block_3/conv2d/kernel:0', TensorShape([1, 1, 24, 72])),
 ('inverted_bottleneck_block_3/batch_normalization/gamma:0',
  TensorShape([72])),
 ('inverted_bottleneck_block_3/batch_normalization/beta:0', TensorShape([72])),
 ('inverted_bottleneck_block_3/depthwise_conv2d/depthwise_kernel:0',
  TensorShape([5, 5, 72, 1])),
 ('inverted_bottleneck_block_3/batch_normalization_1/gamma:0',
  TensorShape([72])),
 ('inverted_bottleneck_block_3/batch_normalization_1/beta:0',
  TensorShape([72])),
 ('inverted_bottleneck_block_3/squeeze_excitation/conv2d_2/kernel:0',
  TensorShape([1, 1, 72, 24])),
 ('inverted_bottleneck_block_3/squeeze_excitation/conv2d_2/bias:0',
  TensorShape([24])),
 ('inverted_bottleneck_block_3/squeeze_excitation/conv2d_3/kernel:0',
  TensorShape([1, 1, 24, 72])),
 ('inverted_bottleneck_block_3/squeeze_excitation/conv2d_3/bias:0',
  TensorShape([72])),
 ('inverted_bottleneck_block_3/conv2d_1/kernel:0',
  TensorShape([1, 1, 72, 40])),
 ('inverted_bottleneck

## Convert checkpoint 

In [161]:
def tf2_to_tf1_name(tf2_layer_name):
    '''
    ['conv2dbn_block',
     'inverted_bottleneck_block',
     'inverted_bottleneck_block_1',
     'inverted_bottleneck_block_2',
     ......
     'conv2dbn_block_1',
     'conv2dbn_block_2']
     
     ['Conv',
      'expanded_conv',
      'expanded_conv_1',
      'expanded_conv_2',
      ......
      'Conv_1',
      'Conv_2'
     ]
    '''
    if 'conv2dbn_block' in tf2_layer_name:
        tf1_layer_name = tf2_layer_name.replace('conv2dbn_block', 'Conv')
    elif 'inverted_bottleneck_block' in tf2_layer_name:
        tf1_layer_name = tf2_layer_name.replace('inverted_bottleneck_block', 'expanded_conv')
    else:
        raise ValueError

    return tf1_layer_name
        
def get_tf1_layers_by_tf2_name(tf2_layer_name, tf1_order_keras_weights):
    tf1_layer_name = tf2_to_tf1_name(tf2_layer_name)
    
    tf1_layer_dict = {layer_name: tf1_order_keras_weights[layer_name] 
                      for layer_name in tf1_order_keras_weights 
                      if layer_name.split('/')[0] == tf1_layer_name}
    
    return tf1_layer_name, tf1_layer_dict


def order_tf1_layer_weights(tf1_layer_dict, tf1_layer_name):
    conv_order = ['expand', 'depthwise', 'squeeze_excite', 'project']
    batch_norm_order = ['expand', 'depthwise', 'project']
    
    result_holder = []
    
    conv_holder = defaultdict(list)
    batch_norm_holder = dict()
    if tf1_layer_name.startswith('expanded_conv'):
        for (layer_name, layer_weights) in sorted(tf1_layer_dict.items(), key=lambda x: x[0]):
            layer_name_splits = layer_name.split('/')
            component = layer_name_splits[1]
                
            if len(layer_name_splits) == 3 and layer_name_splits[2] == 'BatchNorm':
                batch_norm_holder[component] = layer_weights
            else:
                conv_holder[component].extend(layer_weights)
    
        print([(name, len(conv_holder[name])) for name in conv_holder])
        print([(name, len(conv_holder[name])) for name in batch_norm_holder])
        for comp in conv_order:
            if comp in conv_holder:
                result_holder.extend(conv_holder[comp])
            if comp in batch_norm_holder:
                result_holder.extend(batch_norm_holder[comp][0:2])
        for comp in batch_norm_order:
            if comp in batch_norm_holder:
                result_holder.extend(batch_norm_holder[comp][2:])
    
    elif tf1_layer_name.startswith('Conv'):
        for (layer_name, layer_weights) in sorted(tf1_layer_dict.items(), key=lambda x: x[0]):
            result_holder.extend(layer_weights)
    
    else:
        raise ValueError

    return result_holder

In [162]:
tf1_layer_name, tf1_layer_dict = get_tf1_layers_by_tf2_name('conv2dbn_block_2', order_keras_weights)
sorted(list(tf1_layer_dict.keys()))


['Conv_2']

In [163]:
len(order_tf1_layer_weights(tf1_layer_dict, tf1_layer_name))

2

## Set weights

In [164]:
for layer in model.layers[1].layers:
    if layer.weights:
        layer_name = layer.name
        print(layer_name)
        tf1_layer_name, tf1_layer_dict = get_tf1_layers_by_tf2_name(layer_name, order_keras_weights)
        tf1_weights = order_tf1_layer_weights(tf1_layer_dict, tf1_layer_name)
        layer.set_weights(tf1_weights)

conv2dbn_block
inverted_bottleneck_block
[('depthwise', 1), ('project', 1)]
[('depthwise', 1), ('project', 1)]
inverted_bottleneck_block_1
[('depthwise', 1), ('expand', 1), ('project', 1)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_2
[('depthwise', 1), ('expand', 1), ('project', 1)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_3
[('depthwise', 1), ('expand', 1), ('project', 1), ('squeeze_excite', 4)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_4
[('depthwise', 1), ('expand', 1), ('project', 1), ('squeeze_excite', 4)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_5
[('depthwise', 1), ('expand', 1), ('project', 1), ('squeeze_excite', 4)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_6
[('depthwise', 1), ('expand', 1), ('project', 1)]
[('depthwise', 1), ('expand', 1), ('project', 1)]
inverted_bottleneck_block_7
[('depthwise', 1), ('exp

In [165]:
head_kernel, head_bias = order_keras_weights['Logits/Conv2d_1c_1x1']
head_kernel = np.squeeze(head_kernel)
(head_kernel.shape, head_bias.shape)

model.layers[4].set_weights([head_kernel, head_bias])

In [166]:
model.summary()

Model: "classification_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
mobile_net (MobileNet)       {'2': (None, None, None,  4226432   
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1001)              1282281   
Total params: 5,508,713
Trainable params: 5,484,313
Non-trainable params: 24,400
_________________________________________________________________


## Launch evaluation

In [167]:
from research.mobilenet import mobilenet_trainer

d_config = mobilenet_trainer._get_dataset_config().get('imagenet2012')()
# build evaluation dataset
d_config.split = 'validation'
d_config.batch_size = 128
d_config.one_hot = False
d_config.data_dir = 'gs://tf_mobilenet/imagenet/imagenet-2012-tfrecord'

d_config

ImageNetConfig(name='imagenet2012', data_dir='gs://tf_mobilenet/imagenet/imagenet-2012-tfrecord', filenames=None, builder='records', split='validation', image_size=224, num_classes=1000, num_channels=3, num_examples=1281167, batch_size=128, use_per_replica_batch_size=True, num_devices=1, dtype='float32', one_hot=False, augmenter=AugmentConfig(name='autoaugment', params=None), download=False, shuffle_buffer_size=10000, file_shuffle_buffer_size=100, skip_decoding=True, cache=False, tf_data_service=None, mean_subtract=True, standardize=True, num_eval_examples=50000)

In [168]:
# the checkpoint is trained using slim
eval_dataset = mobilenet_trainer.get_dataset(d_config, slim_preprocess=True)

# compile model
if d_config.one_hot:
    loss_obj = tf.keras.losses.CategoricalCrossentropy()
else:
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()

model.compile(
    optimizer='rmsprop',
    loss=loss_obj,
    metrics=[mobilenet_trainer._get_metrics(one_hot=d_config.one_hot)['acc']])

# run evaluation
eval_result = model.evaluate(eval_dataset)



## Save checkpoint

In [169]:
save_format = 'ckpt'
save_path = '/home/jupyter/mobilenet_v3large_1.0'

checkpoint_items = model.checkpoint_items

if save_format == 'ckpt':
    checkpoint = tf.train.Checkpoint(model=model, **checkpoint_items)
    manager = tf.train.CheckpointManager(checkpoint,
                                         directory=save_path,
                                         max_to_keep=3)
    manager.save()
else:
    model.save(save_path, save_format=save_format)