# Dependencies

In [1]:
import warnings

import tensorflow as tf
from tensorflow.keras.applications import efficientnet

In [2]:
def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
  """Map the weights in checkpoint file (tf) to h5 file (keras).

  Every weight of `keras_model` is matched by name to a variable in the
  checkpoint, overwritten in place when its value differs, and the model's
  weights are finally saved to `path_h5`.

  Args:
    path_h5: str, path to output hdf5 file to write weights loaded from ckpt
      files.
    path_ckpt: str, path to the ckpt files (e.g. 'efficientnet-b0/model.ckpt')
      that records efficientnet weights from original repo
      https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
    keras_model: keras model, built from keras.applications efficientnet
      functions (e.g. EfficientNetB0). Its weights are modified in place.
    use_ema: Bool, whether to use ExponentialMovingAverage result or not

  Raises:
    ValueError: if the block counts of model and checkpoint differ, a block's
      variables are missing from the checkpoint, a weight name cannot be
      parsed, or a non-top variable fails to load.
    KeyError: if a stem/top weight name has no known checkpoint counterpart.
  """
  # Substrings identifying the classification head; load failures for these
  # are downgraded to warnings further down.
  top_markers = ('top', 'predictions', 'probs')

  model_name_keras = keras_model.name
  # 'efficientnetb0' -> 'efficientnet-b0' (the scope used in the checkpoint)
  model_name_tf = model_name_keras.replace('efficientnet', 'efficientnet-')

  keras_weight_names = [w.name for w in keras_model.weights]
  # NOTE(review): the name list is fetched with the helper's default setting
  # (EMA variables), so the per-block check below always validates the EMA
  # variable set, even when use_ema=False.
  tf_weight_names = get_variable_names_from_ckpt(path_ckpt)

  keras_blocks = get_keras_blocks(keras_weight_names)
  tf_blocks = get_tf_blocks(tf_weight_names)

  # zip() would silently truncate to the shorter list, leaving blocks
  # unchecked or unmapped, so fail early with a clear message instead.
  if len(keras_blocks) != len(tf_blocks):
    raise ValueError(
        'block count mismatch: {} blocks in keras model, {} in checkpoint'
        .format(len(keras_blocks), len(tf_blocks)))

  print('check variables match in each block')
  for keras_block, tf_block in zip(keras_blocks, tf_blocks):
    check_match(keras_block, tf_block, keras_weight_names, tf_weight_names,
                model_name_tf)
    print('{} and {} match.'.format(tf_block, keras_block))

  # Both lists are sorted, so position i in one corresponds to position i in
  # the other.
  block_mapping = dict(zip(keras_blocks, tf_blocks))

  changed_weights = 0
  for w in keras_model.weights:
    if 'block' in w.name:
      # example: 'block1a_dwconv/depthwise_kernel:0' -> 'block1a'
      keras_block = w.name.split('/')[0].split('_')[0]
      tf_block = block_mapping[keras_block]
      tf_name = keras_name_to_tf_name_block(
          w.name,
          keras_block=keras_block,
          tf_block=tf_block,
          use_ema=use_ema,
          model_name_tf=model_name_tf)
    elif any(x in w.name for x in ('stem',) + top_markers):
      tf_name = keras_name_to_tf_name_stem_top(
          w.name, use_ema=use_ema, model_name_tf=model_name_tf)
    elif 'normalization' in w.name:
      print('skipping variable {}: normalization is a layer '
            'in keras implementation, but preprocessing in '
            'TF implementation.'.format(w.name))
      continue
    else:
      raise ValueError('{} failed to parse.'.format(w.name))

    try:
      w_tf = tf.train.load_variable(path_ckpt, tf_name)
      w_keras = w.value().numpy()
      # Compare shapes first: `!=` between arrays that cannot be broadcast
      # does not give an element-wise result to call .any() on. On a shape
      # mismatch we go straight to assign(), which is expected to reject it
      # with ValueError so the handler below can report it.
      if w_keras.shape != w_tf.shape or (w_keras != w_tf).any():
        w.assign(w_tf)
        changed_weights += 1
    except ValueError as e:
      if any(x in w.name for x in top_markers):
        warnings.warn('Fail to load top layer variable {} '
                      'from {} because of {}.'.format(w.name, tf_name, e))
      else:
        raise ValueError(
            'Fail to load {} from {}'.format(w.name, tf_name)) from e

  total_weights = len(keras_model.weights)
  print('{}/{} weights updated'.format(changed_weights, total_weights))
  keras_model.save_weights(path_h5)


def get_variable_names_from_ckpt(path_ckpt, use_ema=True):
  """List the checkpoint variable names that are candidates for conversion.

  Args:
    path_ckpt: str, path to the ckpt files
    use_ema: Bool. When true only names containing
      'ExponentialMovingAverage' are kept; otherwise only names without it.

  Returns:
    List of str, the selected variable names in the order the checkpoint
    lists them. Names containing 'RMS' (RMSprop bookkeeping) are always
    dropped.
  """
  ema_tag = 'ExponentialMovingAverage'
  want_ema = bool(use_ema)

  selected = []
  for entry in tf.train.list_variables(path_ckpt):
    # each entry is indexed for its name; the rest of the entry is ignored
    name = entry[0]
    if (ema_tag in name) != want_ema:
      continue
    if 'RMS' in name:
      # optimizer slot variable, not a model weight
      continue
    selected.append(name)
  return selected


def get_tf_blocks(tf_weight_names):
  """Return the distinct block scopes of the checkpoint, in numeric order.

  Args:
    tf_weight_names: iterable of str, full checkpoint variable names such as
      'efficientnet-b0/blocks_0/conv2d/kernel'.

  Returns:
    List of str (e.g. ['blocks_0', 'blocks_1', ...]) sorted by the integer
    after the underscore, so 'blocks_10' comes after 'blocks_2'.
  """
  block_scopes = set()
  for name in tf_weight_names:
    if 'block' in name:
      # the block scope is the second path component
      block_scopes.add(name.split('/')[1])

  def block_index(scope):
    return int(scope.split('_')[1])

  return sorted(block_scopes, key=block_index)


def get_keras_blocks(keras_weight_names):
  """Return the distinct Keras block prefixes, sorted alphabetically.

  Args:
    keras_weight_names: iterable of str, Keras weight names such as
      'block1a_dwconv/depthwise_kernel:0'.

  Returns:
    Sorted list of str holding everything before the first underscore of each
    name that contains 'block' (e.g. ['block1a', 'block2a', ...]).
  """
  prefixes = set()
  for name in keras_weight_names:
    if 'block' in name:
      prefixes.add(name.split('_')[0])
  ordered = list(prefixes)
  ordered.sort()
  return ordered


def keras_name_to_tf_name_stem_top(keras_name,
                                   use_ema=True,
                                   model_name_tf='efficientnet-b0'):
  """Translate a stem or top (head) weight name from h5 to ckpt naming.

  Args:
    keras_name: str, the name of weight in the h5 file of keras implementation
    use_ema: Bool, whether the returned name should point at the
      ExponentialMovingAverage copy of the variable in the ckpt
    model_name_tf: str, the name of model in ckpt.

  Returns:
    String for the name of weight as in ckpt file.

  Raises:
    KeyError: if we cannot parse the keras_name.
  """
  ema = '/ExponentialMovingAverage' if use_ema else ''

  # keras weight name -> ckpt name template (model name, ema suffix)
  templates = {
      'probs/bias:0': '{}/head/dense/bias{}',
      'probs/kernel:0': '{}/head/dense/kernel{}',
      'predictions/bias:0': '{}/head/dense/bias{}',
      'predictions/kernel:0': '{}/head/dense/kernel{}',
      'stem_conv/kernel:0': '{}/stem/conv2d/kernel{}',
      'top_conv/kernel:0': '{}/head/conv2d/kernel{}',
  }
  name_map = {
      keras_key: template.format(model_name_tf, ema)
      for keras_key, template in templates.items()
  }

  # batch normalization of the stem and of the top / head
  bn_templates = (
      ('stem_bn/{}:0', '{}/stem/tpu_batch_normalization/{}{}'),
      ('top_bn/{}:0', '{}/head/tpu_batch_normalization/{}{}'),
  )
  bn_weight_kinds = ('beta', 'gamma', 'moving_mean', 'moving_variance')
  for keras_template, tf_template in bn_templates:
    for kind in bn_weight_kinds:
      name_map[keras_template.format(kind)] = tf_template.format(
          model_name_tf, kind, ema)

  if keras_name not in name_map:
    raise KeyError('{} from h5 file cannot be parsed'.format(keras_name))
  return name_map[keras_name]


def keras_name_to_tf_name_block(keras_name,
                                keras_block='block1a',
                                tf_block='blocks_0',
                                use_ema=True,
                                model_name_tf='efficientnet-b0'):
  """Translate the name of a weight inside a block from h5 to ckpt naming.

  The ckpt name is assembled as a '/'-joined path: model scope, block scope,
  layer scope(s), variable name and, optionally, the EMA suffix.

  Args:
    keras_name: str, the name of weight in the h5 file of keras implementation
    keras_block: str, the block name for keras implementation (e.g. 'block1a')
    tf_block: str, the block name for tf implementation (e.g. 'blocks_0')
    use_ema: Bool, whether the returned name should point at the
      ExponentialMovingAverage copy of the variable in the ckpt
    model_name_tf: str, the name of model in ckpt.

  Returns:
    String for the name of weight as in ckpt file.

  Raises:
    ValueError if keras_block does not show up in keras_name
  """
  if keras_block not in keras_name:
    raise ValueError('block name {} not found in {}'.format(
        keras_block, keras_name))

  # Blocks of the first group ('block1*') have no expand conv / bn, which
  # shifts the numbering of the conv and batch-norm scopes in the ckpt.
  in_first_group = keras_block[5] == '1'

  path = [model_name_tf, tf_block]

  # depthwise conv
  if 'dwconv' in keras_name:
    path.extend(['depthwise_conv2d', 'depthwise_kernel'])

  # regular conv layers: first-group blocks only own a single conv2d
  if 'project_conv' in keras_name:
    path.extend(['conv2d' if in_first_group else 'conv2d_1', 'kernel'])
  elif not in_first_group and 'expand_conv' in keras_name:
    path.extend(['conv2d', 'kernel'])

  # squeeze-and-excitation layers
  if '_se_' in keras_name:
    if 'reduce' in keras_name:
      path.append('se/conv2d')
    elif 'expand' in keras_name:
      path.append('se/conv2d_1')

    for variable in ('kernel', 'bias'):
      if variable in keras_name:
        path.append(variable)
        break

  # batch normalization layers
  if 'bn' in keras_name:
    if 'project' in keras_name:
      bn_scope = ('tpu_batch_normalization_1'
                  if in_first_group else 'tpu_batch_normalization_2')
    elif not in_first_group and 'expand' in keras_name:
      bn_scope = 'tpu_batch_normalization'
    else:
      bn_scope = ('tpu_batch_normalization'
                  if in_first_group else 'tpu_batch_normalization_1')
    path.append(bn_scope)

    path.extend(
        stat for stat in ('moving_mean', 'moving_variance', 'beta', 'gamma')
        if stat in keras_name)

  if use_ema:
    path.append('ExponentialMovingAverage')
  return '/'.join(path)


def check_match(keras_block, tf_block, keras_weight_names, tf_weight_names,
                model_name_tf, use_ema=True):
  """Check if the weights in h5 and ckpt match.

  Each name from keras_weight_names that belongs to keras_block is translated
  to its expected ckpt name, and the result is compared with the names from
  tf_weight_names that live in tf_block.

  Args:
    keras_block: str, the block name for keras implementation (e.g. 'block1a')
    tf_block: str, the block name for tf implementation (e.g. 'blocks_0')
    keras_weight_names: list of str, weight names in keras implementation
    tf_weight_names: list of str, weight names in tf implementation
    model_name_tf: str, the name of model in ckpt.
    use_ema: Bool, whether the expected ckpt names carry the
      ExponentialMovingAverage suffix. Must agree with how tf_weight_names
      was filtered; defaults to True.

  Raises:
    ValueError: if a variable expected by the keras block is absent from the
      checkpoint names. Checkpoint variables that no keras weight maps to
      only trigger a warning.
  """
  names_from_keras = {
      keras_name_to_tf_name_block(
          x,
          keras_block=keras_block,
          tf_block=tf_block,
          use_ema=use_ema,
          model_name_tf=model_name_tf)
      for x in keras_weight_names
      if keras_block in x
  }

  names_from_tf = set()
  for x in tf_weight_names:
    scopes = x.split('/')
    # The block scope is the second path component. Require an exact match so
    # that e.g. 'blocks_1' never picks up 'blocks_11', and skip names that
    # have no second component at all.
    if len(scopes) > 1 and scopes[1] == tf_block:
      names_from_tf.add(x)

  names_missing = names_from_keras - names_from_tf
  if names_missing:
    raise ValueError('{} variables not found in checkpoint file: {}'.format(
        len(names_missing), names_missing))

  names_unused = names_from_tf - names_from_keras
  if names_unused:
    warnings.warn('{} variables from checkpoint file are not used: {}'.format(
        len(names_unused), names_unused))

# Download checkpoints

In [3]:
# Fetch the Noisy Student EfficientNet checkpoint archives, one tarball per
# variant (b0 through b7), into the current working directory.
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b0.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b1.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b2.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b3.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b4.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b5.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b6.tar.gz
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b7.tar.gz

--2021-02-15 17:13:48--  https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b0.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.193.128, 74.125.31.128, 172.217.204.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.193.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 156406658 (149M) [application/x-tar]
Saving to: ‘noisy_student_efficientnet-b0.tar.gz’


2021-02-15 17:13:50 (81.2 MB/s) - ‘noisy_student_efficientnet-b0.tar.gz’ saved [156406658/156406658]

--2021-02-15 17:13:51--  https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b1.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.217.128, 172.217.193.128, 173.194.216.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.217.128|:443... connected.
HTTP request sent, awaiting respons

# Unzip checkpoints

In [4]:
# Extract every downloaded archive in place.
# NOTE: the conversion cells below read b0 from 'noisy_student_efficientnet-b0'
# (underscores) but b1-b7 from 'noisy-student-efficientnet-bN' (hyphens), so
# the extracted directory names are not uniform across variants.
!tar -xf noisy_student_efficientnet-b0.tar.gz
!tar -xf noisy_student_efficientnet-b1.tar.gz
!tar -xf noisy_student_efficientnet-b2.tar.gz
!tar -xf noisy_student_efficientnet-b3.tar.gz
!tar -xf noisy_student_efficientnet-b4.tar.gz
!tar -xf noisy_student_efficientnet-b5.tar.gz
!tar -xf noisy_student_efficientnet-b6.tar.gz
!tar -xf noisy_student_efficientnet-b7.tar.gz

# Convert checkpoints to .h5 weights

In [5]:
# include_top is a boolean flag; False builds the headless model. The
# default `weights` argument is left untouched: write_ckpt_to_h5 skips the
# `normalization` variables, so those keep whatever values the Keras model
# was constructed with.
write_ckpt_to_h5(
    path_h5='efficientnetb0_notop.h5',
    path_ckpt='noisy_student_efficientnet-b0/model.ckpt',
    keras_model=efficientnet.EfficientNetB0(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block2a match.
blocks_2 and block2b match.
blocks_3 and block3a match.
blocks_4 and block3b match.
blocks_5 and block4a match.
blocks_6 and block4b match.
blocks_7 and block4c match.
blocks_8 and block5a match.
blocks_9 and block5b match.
blocks_10 and block5c match.
blocks_11 and block6a match.
blocks_12 and block6b match.
blocks_13 and block6c match.
blocks_14 and block6d match.
blocks_15 and block7a match.
skipping variable normalization/mean:0: normalization is a layerin keras implementation, but preprocessing in TF implementation.
skipping variable normalization/variance:0: normalization is a layerin keras implementation, but preprocessing in TF implementation.
skipping variable normalization/count:0: normalization is a layerin keras implementation, but preprocessing in TF implementation.
309/312 weights update

In [6]:
# include_top is a boolean flag; False builds the headless model.
# Note the hyphenated checkpoint directory, unlike the b0 one.
write_ckpt_to_h5(
    path_h5='efficientnetb1_notop.h5',
    path_ckpt='noisy-student-efficientnet-b1/model.ckpt',
    keras_model=efficientnet.EfficientNetB1(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb1_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block2a match.
blocks_3 and block2b match.
blocks_4 and block2c match.
blocks_5 and block3a match.
blocks_6 and block3b match.
blocks_7 and block3c match.
blocks_8 and block4a match.
blocks_9 and block4b match.
blocks_10 and block4c match.
blocks_11 and block4d match.
blocks_12 and block5a match.
blocks_13 and block5b match.
blocks_14 and block5c match.
blocks_15 and block5d match.
blocks_16 and block6a match.
blocks_17 and block6b match.
blocks_18 and block6c match.
blocks_19 and block6d match.
blocks_20 and block6e match.
blocks_21 and block7a match.
blocks_22 and block7b match.
skipping variable normalization_1/mean:0: normalization is a layerin keras implementation, but preprocessing in TF implementation.
skipping variable normalization_1/variance:0: normalization is a layerin keras i

In [7]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb2_notop.h5',
    path_ckpt='noisy-student-efficientnet-b2/model.ckpt',
    keras_model=efficientnet.EfficientNetB2(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb2_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block2a match.
blocks_3 and block2b match.
blocks_4 and block2c match.
blocks_5 and block3a match.
blocks_6 and block3b match.
blocks_7 and block3c match.
blocks_8 and block4a match.
blocks_9 and block4b match.
blocks_10 and block4c match.
blocks_11 and block4d match.
blocks_12 and block5a match.
blocks_13 and block5b match.
blocks_14 and block5c match.
blocks_15 and block5d match.
blocks_16 and block6a match.
blocks_17 and block6b match.
blocks_18 and block6c match.
blocks_19 and block6d match.
blocks_20 and block6e match.
blocks_21 and block7a match.
blocks_22 and block7b match.
skipping variable normalization_2/mean:0: normalization is a layerin keras implementation, but preprocessing in TF implementation.
skipping variable normalization_2/variance:0: normalization is a layerin keras i

In [8]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb3_notop.h5',
    path_ckpt='noisy-student-efficientnet-b3/model.ckpt',
    keras_model=efficientnet.EfficientNetB3(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb3_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block2a match.
blocks_3 and block2b match.
blocks_4 and block2c match.
blocks_5 and block3a match.
blocks_6 and block3b match.
blocks_7 and block3c match.
blocks_8 and block4a match.
blocks_9 and block4b match.
blocks_10 and block4c match.
blocks_11 and block4d match.
blocks_12 and block4e match.
blocks_13 and block5a match.
blocks_14 and block5b match.
blocks_15 and block5c match.
blocks_16 and block5d match.
blocks_17 and block5e match.
blocks_18 and block6a match.
blocks_19 and block6b match.
blocks_20 and block6c match.
blocks_21 and block6d match.
blocks_22 and block6e match.
blocks_23 and block6f match.
blocks_24 and block7a match.
blocks_25 and block7b match.
skipping variable normalization_3/mean:0: normalization is a layerin keras implementation, but preprocessing in TF implement

In [9]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb4_notop.h5',
    path_ckpt='noisy-student-efficientnet-b4/model.ckpt',
    keras_model=efficientnet.EfficientNetB4(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb4_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block2a match.
blocks_3 and block2b match.
blocks_4 and block2c match.
blocks_5 and block2d match.
blocks_6 and block3a match.
blocks_7 and block3b match.
blocks_8 and block3c match.
blocks_9 and block3d match.
blocks_10 and block4a match.
blocks_11 and block4b match.
blocks_12 and block4c match.
blocks_13 and block4d match.
blocks_14 and block4e match.
blocks_15 and block4f match.
blocks_16 and block5a match.
blocks_17 and block5b match.
blocks_18 and block5c match.
blocks_19 and block5d match.
blocks_20 and block5e match.
blocks_21 and block5f match.
blocks_22 and block6a match.
blocks_23 and block6b match.
blocks_24 and block6c match.
blocks_25 and block6d match.
blocks_26 and block6e match.
blocks_27 and block6f match.
blocks_28 and block6g match.
blocks_29 and block6h match.
blocks_3

In [10]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb5_notop.h5',
    path_ckpt='noisy-student-efficientnet-b5/model.ckpt',
    keras_model=efficientnet.EfficientNetB5(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb5_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block1c match.
blocks_3 and block2a match.
blocks_4 and block2b match.
blocks_5 and block2c match.
blocks_6 and block2d match.
blocks_7 and block2e match.
blocks_8 and block3a match.
blocks_9 and block3b match.
blocks_10 and block3c match.
blocks_11 and block3d match.
blocks_12 and block3e match.
blocks_13 and block4a match.
blocks_14 and block4b match.
blocks_15 and block4c match.
blocks_16 and block4d match.
blocks_17 and block4e match.
blocks_18 and block4f match.
blocks_19 and block4g match.
blocks_20 and block5a match.
blocks_21 and block5b match.
blocks_22 and block5c match.
blocks_23 and block5d match.
blocks_24 and block5e match.
blocks_25 and block5f match.
blocks_26 and block5g match.
blocks_27 and block6a match.
blocks_28 and block6b match.
blocks_29 and block6c match.
blocks_3

In [11]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb6_notop.h5',
    path_ckpt='noisy-student-efficientnet-b6/model.ckpt',
    keras_model=efficientnet.EfficientNetB6(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb6_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block1c match.
blocks_3 and block2a match.
blocks_4 and block2b match.
blocks_5 and block2c match.
blocks_6 and block2d match.
blocks_7 and block2e match.
blocks_8 and block2f match.
blocks_9 and block3a match.
blocks_10 and block3b match.
blocks_11 and block3c match.
blocks_12 and block3d match.
blocks_13 and block3e match.
blocks_14 and block3f match.
blocks_15 and block4a match.
blocks_16 and block4b match.
blocks_17 and block4c match.
blocks_18 and block4d match.
blocks_19 and block4e match.
blocks_20 and block4f match.
blocks_21 and block4g match.
blocks_22 and block4h match.
blocks_23 and block5a match.
blocks_24 and block5b match.
blocks_25 and block5c match.
blocks_26 and block5d match.
blocks_27 and block5e match.
blocks_28 and block5f match.
blocks_29 and block5g match.
blocks_3

In [12]:
# include_top is a boolean flag; False builds the headless model.
write_ckpt_to_h5(
    path_h5='efficientnetb7_notop.h5',
    path_ckpt='noisy-student-efficientnet-b7/model.ckpt',
    keras_model=efficientnet.EfficientNetB7(include_top=False))

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb7_notop.h5
check variables match in each block
blocks_0 and block1a match.
blocks_1 and block1b match.
blocks_2 and block1c match.
blocks_3 and block1d match.
blocks_4 and block2a match.
blocks_5 and block2b match.
blocks_6 and block2c match.
blocks_7 and block2d match.
blocks_8 and block2e match.
blocks_9 and block2f match.
blocks_10 and block2g match.
blocks_11 and block3a match.
blocks_12 and block3b match.
blocks_13 and block3c match.
blocks_14 and block3d match.
blocks_15 and block3e match.
blocks_16 and block3f match.
blocks_17 and block3g match.
blocks_18 and block4a match.
blocks_19 and block4b match.
blocks_20 and block4c match.
blocks_21 and block4d match.
blocks_22 and block4e match.
blocks_23 and block4f match.
blocks_24 and block4g match.
blocks_25 and block4h match.
blocks_26 and block4i match.
blocks_27 and block4j match.
blocks_28 and block5a match.
blocks_29 and block5b match.
blocks_3

# Clean-up

In [13]:
# Files: delete the downloaded archives; only the generated .h5 weight files
# are kept.
!rm noisy_student_efficientnet-b0.tar.gz
!rm noisy_student_efficientnet-b1.tar.gz
!rm noisy_student_efficientnet-b2.tar.gz
!rm noisy_student_efficientnet-b3.tar.gz
!rm noisy_student_efficientnet-b4.tar.gz
!rm noisy_student_efficientnet-b5.tar.gz
!rm noisy_student_efficientnet-b6.tar.gz
!rm noisy_student_efficientnet-b7.tar.gz

# Directories: delete the extracted checkpoint folders. The b0 folder is
# spelled with underscores while b1-b7 use hyphens, matching the paths the
# conversion cells read from.
!rm -r noisy_student_efficientnet-b0
!rm -r noisy-student-efficientnet-b1
!rm -r noisy-student-efficientnet-b2
!rm -r noisy-student-efficientnet-b3
!rm -r noisy-student-efficientnet-b4
!rm -r noisy-student-efficientnet-b5
!rm -r noisy-student-efficientnet-b6
!rm -r noisy-student-efficientnet-b7