In [None]:
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# CIFAR-10

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# Convert images to tensors, then normalize every RGB channel with mean 0.5
# and std 0.5 (imshow below undoes this with `img / 2 + 0.5`).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# The evaluation split has to be requested with train=False; asking for
# train=True here would silently evaluate on the training images.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable names indexed by the integer class label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Mount Google Drive at /content/drive so later cells can read and write files
# there.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime - confirm before running this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:

import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Show a normalized, channel-first image tensor with matplotlib.

    Args:
        img: tensor-like object supporting arithmetic and `.numpy()`; it is
            rescaled with `img / 2 + 0.5` and its axes are permuted from
            (0, 1, 2) to (1, 2, 0) before plotting.
    """
    # Undo the mean=0.5 / std=0.5 normalization applied by `transform`.
    unnormalized = img / 2 + 0.5
    pixels = unnormalized.numpy()
    # Move the leading axis to the end, as plt.imshow expects channels last.
    plt.imshow(pixels.transpose((1, 2, 0)))
    plt.show()
    
# Fetch a single mini-batch. Use the built-in next(): DataLoader iterators
# implement __next__, and the legacy `.next()` alias is not available in
# current PyTorch releases.
dataiter = iter(trainloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))

# Labels are padded to 5 characters; join with a space so that 5-letter names
# ("plane", "horse", "truck") do not run into their neighbours.
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))


# lars_optimizer

In [None]:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer-wise Adaptive Rate Scaling optimizer for large-batch training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf


class LARSOptimizer(tf.train.Optimizer):
  """Layer-wise Adaptive Rate Scaling for large batch training.

  Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
  I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
  Implements the LARS learning rate scheme presented in the paper above. This
  optimizer is useful when scaling the batch size to up to 32K without
  significant performance degradation. It is recommended to use the optimizer
  in conjunction with:
      - Gradual learning rate warm-up
      - Linear learning rate scaling
      - Poly rule learning rate decay
  Note, LARS scaling is currently only enabled for dense tensors. Sparse tensors
  use the default momentum optimizer.
  """

  def __init__(
      self,
      learning_rate,
      momentum=0.9,
      weight_decay=0.0001,
      # The LARS coefficient is a hyperparameter
      eeta=0.001,
      epsilon=0.0,
      name="LARSOptimizer",
      # Enable skipping variables from LARS scaling.
      # TODO(sameerkm): Enable a direct mechanism to pass a
      # subset of variables to the optimizer.
      skip_list=None,
      use_nesterov=False):
    """Construct a new LARS Optimizer.

    Args:
      learning_rate: A `Tensor`, a floating point value, or a zero-argument
        callable returning one of those. The base learning rate.
      momentum: A floating point value. Momentum hyperparameter.
      weight_decay: A floating point value. Weight decay hyperparameter.
      eeta: LARS coefficient as used in the paper. Default set to LARS
        coefficient from the paper. (eeta / weight_decay) determines the highest
        scaling factor in LARS.
      epsilon: Optional epsilon parameter to be set in models that have very
        small gradients. Default set to 0.0.
      name: Optional name prefix for variables and ops created by LARSOptimizer.
      skip_list: List of strings to enable skipping variables from LARS scaling.
        If any of the strings in skip_list is a subset of var.name, variable
        'var' is skipped from LARS scaling. For a typical classification model
        with batch normalization, the skip_list is ['batch_normalization',
        'bias']
      use_nesterov: when set to True, nesterov momentum will be enabled

    Raises:
      ValueError: If a hyperparameter is set to a non-sensical value.
    """
    if momentum < 0.0:
      raise ValueError("momentum should be positive: %s" % momentum)
    if weight_decay < 0.0:
      raise ValueError("weight_decay should be positive: %s" % weight_decay)
    super(LARSOptimizer, self).__init__(use_locking=False, name=name)

    self._learning_rate = learning_rate
    self._momentum = momentum
    self._weight_decay = weight_decay
    self._eeta = eeta
    self._epsilon = epsilon
    self._name = name
    self._skip_list = skip_list
    self._use_nesterov = use_nesterov

  def _create_slots(self, var_list):
    """Creates one zero-initialized "momentum" slot per variable."""
    for v in var_list:
      self._zeros_slot(v, "momentum", self._name)

  def compute_lr(self, grad, var):
    """Computes the per-variable learning rate and the regularized gradient.

    Args:
      grad: Dense gradient for `var`.
      var: The variable being updated; `var.name` is matched against the
        skip list.

    Returns:
      A tuple `(scaled_lr, grad)`. For variables matched by the skip list this
      is the base learning rate and the unmodified gradient. Otherwise the
      learning rate is multiplied by the LARS trust ratio and
      `weight_decay * var` is added to the gradient.
    """
    # Resolve a callable learning rate here as well: `_prepare` accepts one,
    # so the dense path must not multiply the bare callable.
    learning_rate = self._learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate()
    scaled_lr = learning_rate
    if self._skip_list is None or not any(v in var.name
                                          for v in self._skip_list):
      w_norm = tf.norm(var, ord=2)
      g_norm = tf.norm(grad, ord=2)
      # Fall back to a trust ratio of 1.0 whenever either norm is zero, which
      # avoids scaling by 0 (zero weights) or dividing by 0 (zero gradient).
      trust_ratio = tf.where(
          tf.math.greater(w_norm, 0),
          tf.where(
              tf.math.greater(g_norm, 0),
              (self._eeta * w_norm /
               (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0),
          1.0)
      scaled_lr = learning_rate * trust_ratio
      # Add the weight regularization gradient
      grad = grad + self._weight_decay * var
    return scaled_lr, grad

  def _apply_dense(self, grad, var):
    """Applies a LARS-scaled momentum update to a ref variable."""
    scaled_lr, grad = self.compute_lr(grad, var)
    mom = self.get_slot(var, "momentum")
    # The op receives a learning rate of 1.0 because the scaled learning rate
    # is already folded into the gradient passed below.
    return tf.raw_ops.ApplyMomentum(
        var,
        mom,
        tf.cast(1.0, var.dtype.base_dtype),
        grad * scaled_lr,
        self._momentum,
        use_locking=False,
        use_nesterov=self._use_nesterov)

  def _resource_apply_dense(self, grad, var):
    """Applies a LARS-scaled momentum update to a resource variable."""
    scaled_lr, grad = self.compute_lr(grad, var)
    mom = self.get_slot(var, "momentum")
    # As in `_apply_dense`, lr is 1.0 and the gradient carries the scaling.
    return tf.raw_ops.ResourceApplyMomentum(
        var=var.handle,
        accum=mom.handle,
        lr=tf.cast(1.0, var.dtype.base_dtype),
        grad=grad * scaled_lr,
        momentum=self._momentum,
        use_locking=False,
        use_nesterov=self._use_nesterov)

  # Fallback to momentum optimizer for sparse tensors
  def _apply_sparse(self, grad, var):
    """Applies a plain (non-LARS) momentum update for a sparse gradient."""
    mom = self.get_slot(var, "momentum")
    return tf.raw_ops.SparseApplyMomentum(
        var,
        mom,
        tf.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad.values,
        grad.indices,
        tf.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_sparse(self, grad, var, indices):
    """Sparse momentum update (no LARS scaling) for a resource variable."""
    mom = self.get_slot(var, "momentum")
    return tf.raw_ops.ResourceSparseApplyMomentum(
        var.handle,
        mom.handle,
        tf.cast(self._learning_rate_tensor, grad.dtype),
        grad,
        indices,
        tf.cast(self._momentum_tensor, grad.dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)

  def _prepare(self):
    """Materializes learning rate and momentum as tensors for sparse updates.

    Callable hyperparameters are invoked first; the results are stored in
    `_learning_rate_tensor` and `_momentum_tensor`.
    """
    learning_rate = self._learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate()
    self._learning_rate_tensor = tf.convert_to_tensor(
        learning_rate, name="learning_rate")
    momentum = self._momentum
    if callable(momentum):
      momentum = momentum()
    self._momentum_tensor = tf.convert_to_tensor(momentum, name="momentum")

# utils.py 

In [None]:
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys

from absl import flags
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf

#import lars_optimizer
from tensorflow.python.tpu import tpu_function  # pylint:disable=g-direct-tensorflow-import

FLAGS = flags.FLAGS


def build_learning_rate(initial_lr,
                        global_step,
                        steps_per_epoch=None,
                        lr_decay_type='exponential',
                        decay_factor=0.97,
                        decay_epochs=2.4,
                        total_steps=None,
                        warmup_epochs=5):
  """Build learning rate.

  Args:
    initial_lr: Base learning rate.
    global_step: Current training step (a tensor for every schedule except a
      'constant' schedule without warm-up).
    steps_per_epoch: Number of steps in one epoch. Required for the
      'exponential' and 'poly' schedules and whenever `warmup_epochs` is set.
    lr_decay_type: One of 'exponential', 'cosine', 'constant' or 'poly'.
    decay_factor: Multiplicative factor of the exponential schedule.
    decay_epochs: Epochs between two exponential decay steps.
    total_steps: Total number of training steps. Required for 'cosine' and
      'poly'.
    warmup_epochs: Number of epochs of linear warm-up from 0 to `initial_lr`;
      a falsy value disables warm-up.

  Returns:
    The learning rate: `initial_lr` itself for a 'constant' schedule without
    warm-up, otherwise a tensor.

  Raises:
    AssertionError: If `lr_decay_type` is unknown or an argument required by
      the selected schedule / warm-up is missing.
  """
  if lr_decay_type == 'exponential':
    assert steps_per_epoch is not None
    decay_steps = steps_per_epoch * decay_epochs
    lr = tf.train.exponential_decay(
        initial_lr, global_step, decay_steps, decay_factor, staircase=True)
  elif lr_decay_type == 'cosine':
    assert total_steps is not None
    lr = 0.5 * initial_lr * (
        1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
  elif lr_decay_type == 'constant':
    lr = initial_lr
  elif lr_decay_type == 'poly':
    tf.logging.info('Using poly LR schedule')
    assert steps_per_epoch is not None
    assert total_steps is not None
    # `warmup_epochs or 0` lets warm-up be disabled with None as well as 0.
    warmup_steps = int(steps_per_epoch * (warmup_epochs or 0))
    min_step = tf.constant(1, dtype=tf.int64)
    # The decay only starts counting once the warm-up steps have elapsed.
    decay_steps = tf.maximum(min_step, tf.subtract(global_step, warmup_steps))
    lr = tf.train.polynomial_decay(
        initial_lr,
        decay_steps,
        total_steps - warmup_steps + 1,
        end_learning_rate=0.1,
        power=2.0)
  else:
    assert False, 'Unknown lr_decay_type : %s' % lr_decay_type

  if warmup_epochs:
    # Fail with a clear message instead of a TypeError from `int(x * None)`.
    assert steps_per_epoch is not None, (
        'steps_per_epoch is required when warmup_epochs is set.')
    logging.info('Learning rate warmup_epochs: %d', warmup_epochs)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    # Linear ramp from 0 to initial_lr over the warm-up steps.
    warmup_lr = (
        initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
            warmup_steps, tf.float32))
    lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)

  return lr


def build_optimizer(learning_rate,
                    optimizer_name='rmsprop',
                    decay=0.9,
                    epsilon=0.001,
                    momentum=0.9,
                    lars_weight_decay=None,
                    lars_epsilon=None):
  """Build optimizer.

  Args:
    learning_rate: Learning rate handed to the optimizer.
    optimizer_name: One of 'sgd', 'momentum', 'rmsprop' or 'lars'.
    decay: Decay of the RMSProp optimizer.
    epsilon: Epsilon of the RMSProp optimizer.
    momentum: Momentum for the 'momentum', 'rmsprop' and 'lars' optimizers.
    lars_weight_decay: Weight decay for LARS; required when
      `optimizer_name` is 'lars'.
    lars_epsilon: Epsilon for LARS; required when `optimizer_name` is 'lars'.

  Returns:
    The constructed optimizer instance.

  Raises:
    AssertionError: If 'lars' is requested without its two hyperparameters.
  """
  if optimizer_name == 'sgd':
    logging.info('Using SGD optimizer')
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
  elif optimizer_name == 'momentum':
    logging.info('Using Momentum optimizer')
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=momentum)
  elif optimizer_name == 'rmsprop':
    logging.info('Using RMSProp optimizer')
    optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
                                          epsilon)
  elif optimizer_name == 'lars':
    logging.info('Using LARS optimizer')
    assert lars_weight_decay is not None, 'LARS weight decay is None.'
    assert lars_epsilon is not None, 'LARS epsilon is None.'
    # LARSOptimizer is defined earlier in this notebook; the `lars_optimizer`
    # module import is commented out, so the class is referenced directly.
    optimizer = LARSOptimizer(
        learning_rate,
        momentum=momentum,
        weight_decay=lars_weight_decay,
        skip_list=['batch_normalization', 'bias', 'beta', 'gamma'],
        epsilon=lars_epsilon)
  else:
    logging.fatal('Unknown optimizer: %s', optimizer_name)

  return optimizer


class TpuBatchNormalization(tf.layers.BatchNormalization):
  """Batch normalization whose moments can be averaged across TPU replicas."""

  def __init__(self, fused=False, **kwargs):
    """Creates the layer; only the non-fused implementation is accepted."""
    if fused in (None, True):
      raise ValueError('TpuBatchNormalization does not support fused=True.')
    super().__init__(fused=fused, **kwargs)

  def _cross_replica_average(self, t, num_shards_per_group):
    """Averages `t` over replica groups of `num_shards_per_group` shards."""
    total_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards_per_group > 1:
      if total_shards % num_shards_per_group != 0:
        raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
                         % (total_shards, num_shards_per_group))
      group_count = total_shards // num_shards_per_group
      # Consecutive shard ids share a group: [0..k-1], [k..2k-1], ...
      groups = [
          list(range(g * num_shards_per_group, (g + 1) * num_shards_per_group))
          for g in range(group_count)
      ]
    else:
      groups = None
    summed = tf.tpu.cross_replica_sum(t, groups)
    return summed / tf.cast(num_shards_per_group, t.dtype)

  def _moments(self, inputs, reduction_axes, keep_dims):
    """Overrides `_moments` to combine per-shard statistics across replicas."""
    local_mean, local_variance = super()._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    shard_count = tpu_function.get_tpu_context().number_of_shards or 1
    # Skip cross_replica for 2x2 or smaller slices (8 shards or fewer).
    num_shards_per_group = 1 if shard_count <= 8 else max(8, shard_count // 8)
    logging.info('TpuBatchNormalization with num_shards_per_group %s',
                 num_shards_per_group)
    if num_shards_per_group <= 1:
      return (local_mean, local_variance)

    # Var[X] = E[X^2] - E[X]^2, so average E[X] and E[X^2] over the group and
    # rebuild the variance from the averaged quantities.
    local_mean_of_square = local_variance + tf.math.square(local_mean)
    group_mean = self._cross_replica_average(local_mean, num_shards_per_group)
    group_mean_of_square = self._cross_replica_average(
        local_mean_of_square, num_shards_per_group)
    return (group_mean, group_mean_of_square - tf.math.square(group_mean))


class BatchNormalization(tf.layers.BatchNormalization):
  """Plain batch norm whose default name is 'tpu_batch_normalization'.

  Sharing the default name with TpuBatchNormalization keeps the layer names
  of the two classes aligned.
  """

  def __init__(self, name='tpu_batch_normalization', **kwargs):
    """Forwards everything to the base layer, supplying the default name."""
    super().__init__(name=name, **kwargs)


def train_batch_norm(**kwargs):
  """Returns the batch-norm layer to use while training.

  DistributedBatchNormalization is chosen when the `optimizer` flag exists and
  equals 'lars'; otherwise TpuBatchNormalization. `kwargs` go to the
  constructor unchanged.
  """
  use_lars = 'optimizer' in FLAGS and FLAGS.optimizer == 'lars'
  layer_cls = (
      DistributedBatchNormalization if use_lars else TpuBatchNormalization)
  return layer_cls(**kwargs)


def eval_batch_norm(**kwargs):
  """Returns the batch-norm layer to use during evaluation.

  DistributedBatchNormalization is chosen when the `optimizer` flag exists and
  equals 'lars'; otherwise BatchNormalization. `kwargs` go to the constructor
  unchanged.
  """
  use_lars = 'optimizer' in FLAGS and FLAGS.optimizer == 'lars'
  layer_cls = DistributedBatchNormalization if use_lars else BatchNormalization
  return layer_cls(**kwargs)


class DistributedBatchNormalization:
  """Distributed batch normalization used in https://arxiv.org/abs/2011.00071.

  A callable (not a Keras/tf.layers layer) that creates its own `gamma`,
  `beta`, `moving_mean` and `moving_variance` variables and reads
  `FLAGS.num_replicas`, `FLAGS.batch_norm_batch_size` and
  `FLAGS.train_batch_size` when called.
  """

  def __init__(self, axis, momentum, epsilon):
    """Stores the hyperparameters.

    Args:
      axis: Stored on the instance; `__call__` does not read it and always
        reduces over axes [0, 1, 2].
      momentum: Momentum of the moving mean/variance update.
      epsilon: Small constant added to the variance when normalizing.
    """
    self.axis = axis
    self.momentum = momentum
    self.epsilon = epsilon

  def __call__(self, x, training, distname='batch_normalization'):
    """Normalizes `x` with batch statistics (training) or moving statistics.

    Args:
      x: Input tensor; it is cast to float32 and reduced over axes [0, 1, 2],
        with per-channel variables sized by its last dimension.
      training: Python truth value selecting the training or inference path.
      distname: Not used by the body; the variable scope name is hard-coded.

    Returns:
      The normalized tensor.
    """
    # One parameter per entry of the last input dimension.
    shape = [x.shape[-1]]
    with tf.variable_scope('batch_normalization'):
      ones = tf.initializers.ones()
      zeros = tf.initializers.zeros()
      gamma = tf.get_variable(
          'gamma', shape, initializer=ones, trainable=True, use_resource=True)
      beta = tf.get_variable(
          'beta', shape, initializer=zeros, trainable=True, use_resource=True)
      moving_mean = tf.get_variable(
          'moving_mean',
          shape,
          initializer=zeros,
          trainable=False,
          use_resource=True)
      moving_variance = tf.get_variable(
          'moving_variance',
          shape,
          initializer=ones,
          trainable=False,
          use_resource=True)
    num_replicas = FLAGS.num_replicas

    x = tf.cast(x, tf.float32)
    if training:
      if num_replicas <= 8:
        # No explicit grouping: the sum runs with group_assign=None and is
        # divided by the total replica count.
        group_assign = None
        group_shards = tf.cast(num_replicas, tf.float32)
      else:

        # Replicas per normalization group: the target batch-norm batch size
        # divided by the per-replica batch size, but at least 1.
        group_shards = max(
            1,
            int(FLAGS.batch_norm_batch_size /
                (FLAGS.train_batch_size / num_replicas)))
        # Consecutive replica ids form one group (rows of the reshaped array).
        # NOTE(review): the reshape presumably requires num_replicas to be a
        # multiple of group_shards - confirm against the flag values used.
        group_assign = np.arange(num_replicas, dtype=np.int32)
        group_assign = group_assign.reshape([-1, group_shards])
        group_assign = group_assign.tolist()
        group_shards = tf.cast(group_shards, tf.float32)

      # Group-averaged first moment.
      mean = tf.reduce_mean(x, [0, 1, 2])
      mean = tf.tpu.cross_replica_sum(mean, group_assign) / group_shards

      # Var[x] = E[x^2] - E[x]^2
      mean_sq = tf.reduce_mean(tf.math.square(x), [0, 1, 2])
      mean_sq = tf.tpu.cross_replica_sum(mean_sq, group_assign) / group_shards
      variance = mean_sq - tf.math.square(mean)

      decay = tf.cast(1. - self.momentum, tf.float32)

      def u(moving, normal, name):
        """Returns an op moving `moving` toward the all-replica mean of `normal`.

        The statistic is averaged over every replica (no group assignment) and
        the update is moving -= (1 - momentum) * (moving - average).
        """
        num_replicas_fp = tf.cast(num_replicas, tf.float32)
        normal = tf.tpu.cross_replica_sum(normal) / num_replicas_fp
        diff = decay * (moving - normal)
        return tf.assign_sub(moving, diff, use_locking=True, name=name)

      # The moving-statistic updates are only registered in UPDATE_OPS here;
      # nothing in this method runs them.
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           u(moving_mean, mean, name='moving_mean'))
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                           u(moving_variance, variance, name='moving_variance'))

      x = tf.nn.batch_normalization(
          x,
          mean=mean,
          variance=variance,
          offset=beta,
          scale=gamma,
          variance_epsilon=self.epsilon)
    else:

      # Inference: normalize with the stored moving statistics.
      x, _, _ = tf.nn.fused_batch_norm(
          x,
          scale=gamma,
          offset=beta,
          mean=moving_mean,
          variance=moving_variance,
          epsilon=self.epsilon,
          is_training=False)

    return x


def drop_connect(inputs, is_training, survival_prob):
  """Randomly drops whole examples of a residual branch (stochastic depth).

  See "Deep Networks with Stochastic Depth",
  https://arxiv.org/pdf/1603.09382.pdf.

  Args:
    inputs: Tensor whose first dimension indexes the examples; the mask has
      shape [batch, 1, 1, 1].
    is_training: When falsy, `inputs` is returned untouched.
    survival_prob: Probability of keeping an example.

  Returns:
    `inputs` itself outside training; otherwise `inputs / survival_prob`
    multiplied by a per-example 0/1 mask.
  """
  if is_training:
    num_examples = tf.shape(inputs)[0]
    noise = tf.random_uniform([num_examples, 1, 1, 1], dtype=inputs.dtype)
    # floor(p + U[0, 1)) is 1 with probability p and 0 otherwise.
    keep_mask = tf.floor(survival_prob + noise)
    # Rescale at training time instead of multiplying by survival_prob at
    # test time, so inference needs no extra computation.
    return tf.div(inputs, survival_prob) * keep_mask
  return inputs


def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
  """Archive a checkpoint if the metric is better.

  The best objective seen so far is kept in `best_objective.txt` next to the
  checkpoint. When `ckpt_objective` is not worse than it, every file matching
  `ckpt_path + '.*'` is copied into a freshly recreated `archive` directory
  together with a `checkpoint` state file and `best_eval.txt`.

  Args:
    ckpt_eval: Evaluation result written verbatim to `best_eval.txt`.
    ckpt_objective: Numeric objective of this checkpoint; higher is better.
    ckpt_path: Checkpoint path prefix (directory plus checkpoint name).

  Returns:
    True if the checkpoint was archived, False if it is worse than the saved
    best or no checkpoint files were found.
  """
  ckpt_dir, ckpt_name = os.path.split(ckpt_path)

  saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
  saved_objective = float('-inf')
  if tf.gfile.Exists(saved_objective_path):
    with tf.gfile.GFile(saved_objective_path, 'r') as f:
      saved_objective = float(f.read())
  if saved_objective > ckpt_objective:
    logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
    return False

  filenames = tf.gfile.Glob(ckpt_path + '.*')
  # Glob yields a (possibly empty) list, never None. Bail out on an empty
  # match so the existing archive is not wiped for a checkpoint with no files.
  if not filenames:
    logging.info('No files to copy for checkpoint %s', ckpt_path)
    return False

  # Clear the old folder.
  dst_dir = os.path.join(ckpt_dir, 'archive')
  if tf.gfile.Exists(dst_dir):
    tf.gfile.DeleteRecursively(dst_dir)
  tf.gfile.MakeDirs(dst_dir)

  # Write checkpoints.
  for f in filenames:
    dest = os.path.join(dst_dir, os.path.basename(f))
    tf.gfile.Copy(f, dest, overwrite=True)
  ckpt_state = tf.train.generate_checkpoint_state_proto(
      dst_dir,
      model_checkpoint_path=ckpt_name,
      all_model_checkpoint_paths=[ckpt_name])
  with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f:
    f.write(str(ckpt_state))
  with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f:
    f.write('%s' % ckpt_eval)

  # Update the best objective.
  with tf.gfile.GFile(saved_objective_path, 'w') as f:
    f.write('%f' % ckpt_objective)

  logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir)
  return True


def get_ema_vars():
  """Collects the variables tracked by the exponential moving average.

  Returns:
    A de-duplicated list (in arbitrary order) of the trainable variables, the
    'moving_vars' collection and every global variable whose name contains
    'moving_mean' or 'moving_variance'.
  """
  candidates = tf.trainable_variables() + tf.get_collection('moving_vars')
  # Batch norm moving mean and variance are averaged as well.
  candidates.extend(
      v for v in tf.global_variables()
      if 'moving_mean' in v.name or 'moving_variance' in v.name)
  return list(set(candidates))


class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer):
  """Keras DepthwiseConv2D exposed as a tf.layers layer.

  No behavior is added; the class only combines the two base classes.
  """


class Conv2D(tf.layers.Conv2D):
  """Conv2D with a matmul shortcut for single-example 1x1 convolutions."""

  def _bias_activation(self, outputs):
    """Adds the bias (NCHW layout) and applies the activation, if any."""
    if self.use_bias:
      outputs = tf.nn.bias_add(outputs, self.bias, data_format='NCHW')
    return outputs if self.activation is None else self.activation(outputs)

  def _can_run_fast_1x1(self, inputs):
    """True for channels_first input with static batch 1 and a 1x1 kernel."""
    leading_dim = inputs.shape.as_list()[0]
    return all((
        self.data_format == 'channels_first',
        leading_dim == 1,
        self.kernel_size == (1, 1),
    ))

  def _call_fast_1x1(self, inputs):
    """Evaluates the 1x1 convolution as a single matrix multiplication."""
    dynamic_shape = tf.shape(inputs)
    # Flatten to [channels, everything else].
    channels_by_pixels = tf.reshape(inputs, [dynamic_shape[1], -1])
    projected = tf.matmul(
        tf.squeeze(self.kernel), channels_by_pixels, transpose_a=True)
    # Restore [1, filters, <trailing input dims>].
    restored_shape = tf.concat(
        [[1, self.filters], dynamic_shape[2:]], axis=0)
    return self._bias_activation(tf.reshape(projected, restored_shape))

  def call(self, inputs):
    """Uses the matmul shortcut when eligible, else the regular convolution."""
    fast_path = self._can_run_fast_1x1(inputs)
    return self._call_fast_1x1(inputs) if fast_path else super().call(inputs)


class EvalCkptDriver(object):
  """A driver for running eval inference.

  Subclasses must override `build_model` and `get_preprocess_fn`.

  Attributes:
    model_name: str. Model name to eval.
    batch_size: int. Eval batch size.
    image_size: int. Input image size, determined by model name.
    num_classes: int. Number of classes, default to 1000 for ImageNet.
    include_background_label: whether to include extra background label.
    advprop_preprocessing: whether to use advprop preprocessing.
  """

  def __init__(self,
               model_name,
               batch_size=1,
               image_size=224,
               num_classes=1000,
               include_background_label=False,
               advprop_preprocessing=False):
    """Initialize internal variables."""
    self.model_name = model_name
    self.batch_size = batch_size
    self.num_classes = num_classes
    self.include_background_label = include_background_label
    self.image_size = image_size
    self.advprop_preprocessing = advprop_preprocessing

  def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None):
    """Restore variables from checkpoint dir.

    Args:
      sess: Session in which the variables are initialized and restored.
      ckpt_dir: Directory whose latest checkpoint is loaded.
      enable_ema: If True, restore through the exponential-moving-average
        variable name mapping.
      export_ckpt: Optional path; when set, the restored model is saved there.
    """
    sess.run(tf.global_variables_initializer())
    checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if enable_ema:
      ema = tf.train.ExponentialMovingAverage(decay=0.0)
      ema_vars = get_ema_vars()
      var_dict = ema.variables_to_restore(ema_vars)
      ema_assign_op = ema.apply(ema_vars)
    else:
      var_dict = get_ema_vars()
      ema_assign_op = None

    tf.train.get_or_create_global_step()
    # Run the initializer again; presumably this covers variables created
    # after the first run (EMA shadows, global step).
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_dict, max_to_keep=1)
    saver.restore(sess, checkpoint)

    if export_ckpt:
      if ema_assign_op is not None:
        sess.run(ema_assign_op)
      saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
      saver.save(sess, export_ckpt)

  def build_model(self, features, is_training):
    """Build model with input features."""
    del features, is_training
    raise ValueError('Must be implemented by subclasses.')

  def get_preprocess_fn(self):
    """Returns the image preprocessing function; subclasses must override."""
    raise ValueError('Must be implemented by subclasses.')

  def build_dataset(self, filenames, labels, is_training):
    """Build input dataset.

    Args:
      filenames: List of image file paths.
      labels: List of labels aligned with `filenames`.
      is_training: Forwarded to the preprocessing function.

    Returns:
      A tuple `(images, labels)` of tensors from a one-shot iterator over the
      batched dataset.
    """
    batch_drop_remainder = False
    if 'condconv' in self.model_name and not is_training:
      # CondConv layers can only be called with known batch dimension. Thus, we
      # must drop all remaining examples that do not make up one full batch.
      # To ensure all examples are evaluated, use a batch size that evenly
      # divides the number of files.
      batch_drop_remainder = True
      num_files = len(filenames)
      if num_files % self.batch_size != 0:
        tf.logging.warn('Remaining examples in last batch are not being '
                        'evaluated.')
    filenames = tf.constant(filenames)
    labels = tf.constant(labels)
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

    def _parse_function(filename, label):
      """Reads one file and preprocesses it into a float32 image."""
      image_string = tf.read_file(filename)
      preprocess_fn = self.get_preprocess_fn()
      image_decoded = preprocess_fn(
          image_string, is_training, image_size=self.image_size)
      image = tf.cast(image_decoded, tf.float32)
      return image, label

    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(self.batch_size,
                            drop_remainder=batch_drop_remainder)

    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels

  def run_inference(self,
                    ckpt_dir,
                    image_files,
                    labels,
                    enable_ema=True,
                    export_ckpt=None):
    """Build and run inference on the target images and labels.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      image_files: List[str]. Image file paths.
      labels: Labels aligned with `image_files`.
      enable_ema: enable exponential moving average.
      export_ckpt: export ckpt folder.

    Returns:
      A tuple `(prediction_idx, prediction_prob)` of lists with one entry per
      executed batch, holding the top 5 indices and their probabilities.
    """
    label_offset = 1 if self.include_background_label else 0
    with tf.Graph().as_default(), tf.Session() as sess:
      images, labels = self.build_dataset(image_files, labels, False)
      probs = self.build_model(images, is_training=False)
      if isinstance(probs, tuple):
        probs = probs[0]

      self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt)

      prediction_idx = []
      prediction_prob = []
      # NOTE(review): the argsort / top-5 slicing below treats `out_probs` as
      # a 1-D vector - presumably build_model squeezes a batch of one; confirm
      # before relying on batch_size > 1.
      for _ in range(len(image_files) // self.batch_size):
        out_probs = sess.run(probs)
        idx = np.argsort(out_probs)[::-1]
        prediction_idx.append(idx[:5] - label_offset)
        prediction_prob.append([out_probs[pid] for pid in idx[:5]])

      # Return the top 5 predictions (idx and prob) for each image.
      return prediction_idx, prediction_prob

  def eval_example_images(self,
                          ckpt_dir,
                          image_files,
                          labels_map_file,
                          enable_ema=True,
                          export_ckpt=None):
    """Eval a list of example images.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      image_files: List[str]. A list of image file paths.
      labels_map_file: str. The labels map file path.
      enable_ema: enable exponential moving average.
      export_ckpt: export ckpt folder.

    Returns:
      A tuple (pred_idx, and pred_prob), where pred_idx is the top 5 prediction
      index and pred_prob is the top 5 prediction probability.
    """
    # Close the label map file deterministically instead of leaking the handle.
    with tf.gfile.Open(labels_map_file) as f:
      classes = json.loads(f.read())
    pred_idx, pred_prob = self.run_inference(
        ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt)
    for i in range(len(image_files)):
      print('predicted class for image {}: '.format(image_files[i]))
      for j, idx in enumerate(pred_idx[i]):
        print('  -> top_{} ({:4.2f}%): {}  '.format(j, pred_prob[i][j] * 100,
                                                    classes[str(idx)]))
    return pred_idx, pred_prob

  def eval_imagenet(self, ckpt_dir, imagenet_eval_glob,
                    imagenet_eval_label, num_images, enable_ema, export_ckpt):
    """Eval ImageNet images and report top1/top5 accuracy.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      imagenet_eval_glob: str. File path glob for all eval images.
      imagenet_eval_label: str. File path for eval label.
      num_images: int. Number of images to eval: -1 means eval the whole
        dataset.
      enable_ema: enable exponential moving average.
      export_ckpt: export checkpoint folder.

    Returns:
      A tuple (top1, top5) for top1 and top5 accuracy.
    """
    # Read the labels (one integer per line) inside a context manager so the
    # file is closed once they are parsed.
    with tf.gfile.GFile(imagenet_eval_label) as f:
      imagenet_val_labels = [int(i) for i in f]
    imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
    if num_images < 0:
      num_images = len(imagenet_filenames)
    image_files = imagenet_filenames[:num_images]
    labels = imagenet_val_labels[:num_images]

    pred_idx, _ = self.run_inference(
        ckpt_dir, image_files, labels, enable_ema, export_ckpt)
    top1_cnt, top5_cnt = 0.0, 0.0
    for i, label in enumerate(labels):
      top1_cnt += label in pred_idx[i][:1]
      top5_cnt += label in pred_idx[i][:5]
      if i % 100 == 0:
        print('Step {}: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%'.format(
            i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
        sys.stdout.flush()
    top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
    print('Final: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%'.format(top1, top5))
    return top1, top5

# *condconv_layers*.py

In [None]:
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CondConv implementations in Tensorflow Layers.
[1] Brandon Yang, Gabriel Bender, Quoc V. Le, Jiquan Ngiam
  CondConv: Conditionally Parameterized Convolutions for Efficient Inference.
  NeurIPS'19, https://arxiv.org/abs/1904.04971
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf


def get_condconv_initializer(initializer, num_experts, expert_shape):
  """Adapts a per-expert initializer to the flattened CondConv layout.

  CondConv keeps kernels and biases in a single [num_experts, num_params]
  matrix. The returned function initializes every expert separately with
  `initializer` in its original shape, flattens each result and stacks them
  into that matrix.

  Arguments:
    initializer: The initializer to apply for each individual expert.
    num_experts: The number of experts to be initialized.
    expert_shape: The original shape of each individual expert.

  Returns:
    The initializer for the num_experts x num_params CondConv variable.
  """
  def condconv_initializer(expected_shape, dtype=None, partition=None):
    """Builds the stacked [num_experts, num_params] initial value."""
    num_params = np.prod(expert_shape)
    shape_mismatch = (
        len(expected_shape) != 2
        or expected_shape[0] != num_experts
        or expected_shape[1] != num_params)
    if shape_mismatch:
      raise ValueError(
          'CondConv variables must have shape [num_experts, num_params]')
    # One independently initialized, flattened kernel per expert.
    return tf.stack([
        tf.reshape(initializer(expert_shape, dtype, partition), [-1])
        for _ in range(num_experts)
    ])

  return condconv_initializer


class CondConv2D(tf.keras.layers.Conv2D):
  """2D conditional convolution layer (e.g. spatial convolution over images).

  This layer extends the base 2D convolution layer to compute
  example-dependent parameters. A CondConv2D layer has `num_experts` kernels
  and biases, stored flattened as the rows of [num_experts, num_params]
  variables. It computes a kernel and bias for each example as a weighted sum
  of experts using the input example-dependent routing weights, then applies
  the 2D convolution to each example.

  Attributes:
    filters: Integer, the dimensionality of the output space (i.e. the number of
      output filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the height
      and width of the 2D convolution window. Can be a single integer to specify
      the same value for all spatial dimensions.
    num_experts: The number of expert kernels and biases in the CondConv layer.
    strides: An integer or tuple/list of 2 integers, specifying the strides of
      the convolution along the height and width. Can be a single integer to
      specify the same value for all spatial dimensions. Specifying any stride
      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
    padding: one of `"valid"` or `"same"` (case-insensitive).
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs. `channels_last` corresponds
      to inputs with shape `(batch, height, width, channels)` while
      `channels_first` corresponds to inputs with shape `(batch, channels,
      height, width)`. It defaults to the `image_data_format` value found in
      your Keras config file at `~/.keras/keras.json`. If you never set it, then
      it will be "channels_last".
    dilation_rate: an integer or tuple/list of 2 integers, specifying the
      dilation rate to use for dilated convolution. Can be a single integer to
      specify the same value for all spatial dimensions. Currently, specifying
      any `dilation_rate` value != 1 is incompatible with specifying any stride
      value != 1.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied
      (ie. "linear" activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the kernel matrix.
    bias_constraint: Constraint function applied to the bias vector.
  Input shape:
    4D tensor with shape: `(samples, channels, rows, cols)` if
      data_format='channels_first'
    or 4D tensor with shape: `(samples, rows, cols, channels)` if
      data_format='channels_last'.
  Output shape:
    4D tensor with shape: `(samples, filters, new_rows, new_cols)` if
      data_format='channels_first'
    or 4D tensor with shape: `(samples, new_rows, new_cols, filters)` if
      data_format='channels_last'. `rows` and `cols` values might have changed
      due to padding.
  """

  def __init__(self,
               filters,
               kernel_size,
               num_experts,
               strides=(1, 1),
               padding='valid',
               data_format=None,
               dilation_rate=(1, 1),
               activation=None,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    """Configures the base Conv2D layer and records the number of experts.

    Raises:
      ValueError: If `num_experts` is smaller than 1.
    """
    super(CondConv2D, self).__init__(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint,
        **kwargs)
    if num_experts < 1:
      raise ValueError('A CondConv layer must have at least one expert.')
    self.num_experts = num_experts
    # Data format string in the form expected by the tf.nn ops used in `call`.
    if self.data_format == 'channels_first':
      self.converted_data_format = 'NCHW'
    else:
      self.converted_data_format = 'NHWC'

  def build(self, input_shape):
    """Creates the flattened expert kernel (and optional bias) variables.

    Args:
      input_shape: Shape of the rank-4 layer input; its channel dimension must
        be defined.

    Raises:
      ValueError: If the input is not rank 4 or its channel dimension is
        undefined.
    """
    if len(input_shape) != 4:
      # Format the shape into the message; passing it as a second argument to
      # ValueError would render the error as a tuple.
      raise ValueError(
          'Inputs to `CondConv2D` should have rank 4. '
          'Received input shape: %s' % (str(input_shape),))
    input_shape = tf.TensorShape(input_shape)
    channel_axis = self._get_channel_axis()
    if input_shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs '
                       'should be defined. Found `None`.')
    input_dim = int(input_shape[channel_axis])

    # Shape of a single expert kernel: (height, width, in_channels, filters).
    self.kernel_shape = self.kernel_size + (input_dim, self.filters)
    kernel_num_params = 1
    for kernel_dim in self.kernel_shape:
      kernel_num_params *= kernel_dim
    # All experts live in one variable, one flattened kernel per row.
    condconv_kernel_shape = (self.num_experts, kernel_num_params)
    self.condconv_kernel = self.add_weight(
        name='condconv_kernel',
        shape=condconv_kernel_shape,
        initializer=get_condconv_initializer(self.kernel_initializer,
                                             self.num_experts,
                                             self.kernel_shape),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        trainable=True,
        dtype=self.dtype)

    if self.use_bias:
      self.bias_shape = (self.filters,)
      condconv_bias_shape = (self.num_experts, self.filters)
      self.condconv_bias = self.add_weight(
          name='condconv_bias',
          shape=condconv_bias_shape,
          initializer=get_condconv_initializer(self.bias_initializer,
                                               self.num_experts,
                                               self.bias_shape),
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None

    self.input_spec = tf.layers.InputSpec(
        ndim=self.rank + 2, axes={channel_axis: input_dim})

    self.built = True

  def call(self, inputs, routing_weights):
    """Convolves every example with its own routing-weighted kernel.

    Args:
      inputs: Rank-4 input tensor whose batch dimension is statically known.
      routing_weights: Per-example expert weights. It is matrix-multiplied
        with the [num_experts, num_params] expert variables and the result is
        split per example, so it must have shape [batch_size, num_experts].

    Returns:
      The per-example convolution outputs concatenated along the batch axis,
      with bias and activation applied when configured.

    Raises:
      ValueError: If the batch dimension of `inputs` is not statically known.
    """
    # Compute example dependent kernels
    kernels = tf.matmul(routing_weights, self.condconv_kernel)
    # `dimension_value` accepts both a `Dimension` (V1 shape behavior) and a
    # plain int/None (V2 shape behavior), unlike a direct `.value` access.
    batch_size = tf.dimension_value(inputs.shape[0])
    if batch_size is None:
      raise ValueError(
          '`CondConv2D` requires a statically known batch size. '
          'Received input shape: %s' % (str(inputs.shape),))
    inputs = tf.split(inputs, batch_size, 0)
    kernels = tf.split(kernels, batch_size, 0)
    # Apply example-dependent convolution to each example in the batch
    outputs_list = []
    for input_tensor, kernel in zip(inputs, kernels):
      kernel = tf.reshape(kernel, self.kernel_shape)
      outputs_list.append(
          tf.nn.convolution(
              input_tensor,
              kernel,
              strides=self.strides,
              padding=self._get_padding_op(),
              dilations=self.dilation_rate,
              data_format=self.converted_data_format))
    outputs = tf.concat(outputs_list, 0)

    if self.use_bias:
      # Compute example-dependent biases
      biases = tf.matmul(routing_weights, self.condconv_bias)
      outputs = tf.split(outputs, batch_size, 0)
      biases = tf.split(biases, batch_size, 0)
      # Add example-dependent bias to each example in the batch
      bias_outputs_list = []
      for output, bias in zip(outputs, biases):
        bias = tf.squeeze(bias, axis=0)
        bias_outputs_list.append(
            tf.nn.bias_add(output, bias,
                           data_format=self.converted_data_format))
      outputs = tf.concat(bias_outputs_list, 0)

    if self.activation is not None:
      return self.activation(outputs)
    return outputs

  def get_config(self):
    """Returns the base layer config extended with `num_experts`."""
    config = {'num_experts': self.num_experts}
    base_config = super(CondConv2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def _get_channel_axis(self):
    """Returns the index of the channel axis for the configured data format."""
    if self.data_format == 'channels_first':
      return 1
    else:
      return -1

  def _get_padding_op(self):
    """Returns the padding argument in the form passed to tf.nn.convolution."""
    # 'causal' padding is mapped to 'valid' for the underlying op.
    if self.padding == 'causal':
      op_padding = 'valid'
    else:
      op_padding = self.padding
    if not isinstance(op_padding, (list, tuple)):
      op_padding = op_padding.upper()
    return op_padding


class DepthwiseCondConv2D(tf.keras.layers.DepthwiseConv2D):
  """Depthwise separable 2D conditional convolution layer.

  This layer extends the base depthwise 2D convolution layer to compute
  example-dependent parameters. A DepthwiseCondConv2D layer has `num_experts`
  kernels and biases. It computes a kernel and bias for each example as a
  weighted sum of experts using the input example-dependent routing weights,
  then applies the depthwise convolution to each example.

  Attributes:
    kernel_size: An integer or tuple/list of 2 integers, specifying the height
      and width of the 2D convolution window. Can be a single integer to specify
      the same value for all spatial dimensions.
    num_experts: The number of expert kernels and biases in the
      DepthwiseCondConv2D layer.
    strides: An integer or tuple/list of 2 integers, specifying the strides of
      the convolution along the height and width. Can be a single integer to
      specify the same value for all spatial dimensions. Specifying any stride
      value != 1 is incompatible with specifying any `dilation_rate` value != 1.
    padding: one of `'valid'` or `'same'` (case-insensitive).
    depth_multiplier: The number of depthwise convolution output channels for
      each input channel. The total number of depthwise convolution output
      channels will be equal to `filters_in * depth_multiplier`.
    data_format: A string, one of `channels_last` (default) or `channels_first`.
      The ordering of the dimensions in the inputs. `channels_last` corresponds
      to inputs with shape `(batch, height, width, channels)` while
      `channels_first` corresponds to inputs with shape `(batch, channels,
      height, width)`. It defaults to the `image_data_format` value found in
      your Keras config file at `~/.keras/keras.json`. If you never set it, then
      it will be 'channels_last'.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied
      (ie. 'linear' activation: `a(x) = x`).
    use_bias: Boolean, whether the layer uses a bias vector.
    depthwise_initializer: Initializer for the depthwise kernel matrix.
    bias_initializer: Initializer for the bias vector.
    depthwise_regularizer: Regularizer function applied to the depthwise kernel
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its 'activation').
    depthwise_constraint: Constraint function applied to the depthwise kernel
      matrix.
    bias_constraint: Constraint function applied to the bias vector.
  Input shape:
    4D tensor with shape: `[batch, channels, rows, cols]` if
      data_format='channels_first'
    or 4D tensor with shape: `[batch, rows, cols, channels]` if
      data_format='channels_last'.
  Output shape:
    4D tensor with shape: `[batch, filters, new_rows, new_cols]` if
      data_format='channels_first'
    or 4D tensor with shape: `[batch, new_rows, new_cols, filters]` if
      data_format='channels_last'. `rows` and `cols` values might have changed
      due to padding.
  """

  def __init__(self,
               kernel_size,
               num_experts,
               strides=(1, 1),
               padding='valid',
               depth_multiplier=1,
               data_format=None,
               activation=None,
               use_bias=True,
               depthwise_initializer='glorot_uniform',
               bias_initializer='zeros',
               depthwise_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               depthwise_constraint=None,
               bias_constraint=None,
               **kwargs):
    """Configures the base DepthwiseConv2D layer and records `num_experts`.

    Raises:
      ValueError: If `num_experts` is smaller than 1.
    """
    super(DepthwiseCondConv2D, self).__init__(
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        depth_multiplier=depth_multiplier,
        data_format=data_format,
        activation=activation,
        use_bias=use_bias,
        depthwise_initializer=depthwise_initializer,
        bias_initializer=bias_initializer,
        depthwise_regularizer=depthwise_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        depthwise_constraint=depthwise_constraint,
        bias_constraint=bias_constraint,
        **kwargs)
    if num_experts < 1:
      raise ValueError('A CondConv layer must have at least one expert.')
    self.num_experts = num_experts
    # Data format string in the form expected by the tf.nn ops used in `call`.
    if self.data_format == 'channels_first':
      self.converted_data_format = 'NCHW'
    else:
      self.converted_data_format = 'NHWC'

  def build(self, input_shape):
    """Creates the flattened expert depthwise kernel (and bias) variables.

    Args:
      input_shape: Shape of the rank-4 layer input; its channel dimension must
        be defined.

    Raises:
      ValueError: If the input is not rank 4 or its channel dimension is
        undefined.
    """
    # Exactly rank 4 is required: the channel axis below is hard-coded for
    # rank-4 inputs and the input spec is declared with ndim=4.
    if len(input_shape) != 4:
      # Format the shape into the message; passing it as a second argument to
      # ValueError would render the error as a tuple.
      raise ValueError(
          'Inputs to `DepthwiseCondConv2D` should have rank 4. '
          'Received input shape: %s' % (str(input_shape),))
    input_shape = tf.TensorShape(input_shape)
    if self.data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = 3
    if input_shape.dims[channel_axis].value is None:
      raise ValueError('The channel dimension of the inputs to '
                       '`DepthwiseCondConv2D` '
                       'should be defined. Found `None`.')
    input_dim = int(input_shape[channel_axis])
    # Shape of a single expert kernel:
    # (height, width, in_channels, depth_multiplier).
    self.depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1],
                                   input_dim, self.depth_multiplier)

    depthwise_kernel_num_params = 1
    for dim in self.depthwise_kernel_shape:
      depthwise_kernel_num_params *= dim
    # All experts live in one variable, one flattened kernel per row.
    depthwise_condconv_kernel_shape = (self.num_experts,
                                       depthwise_kernel_num_params)

    self.depthwise_condconv_kernel = self.add_weight(
        shape=depthwise_condconv_kernel_shape,
        initializer=get_condconv_initializer(self.depthwise_initializer,
                                             self.num_experts,
                                             self.depthwise_kernel_shape),
        name='depthwise_condconv_kernel',
        regularizer=self.depthwise_regularizer,
        constraint=self.depthwise_constraint,
        trainable=True)

    if self.use_bias:
      bias_dim = input_dim * self.depth_multiplier
      self.bias_shape = (bias_dim,)
      condconv_bias_shape = (self.num_experts, bias_dim)
      self.condconv_bias = self.add_weight(
          name='condconv_bias',
          shape=condconv_bias_shape,
          initializer=get_condconv_initializer(self.bias_initializer,
                                               self.num_experts,
                                               self.bias_shape),
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          trainable=True,
          dtype=self.dtype)
    else:
      self.bias = None
    # Set input spec.
    self.input_spec = tf.layers.InputSpec(
        ndim=4, axes={channel_axis: input_dim})
    self.built = True

  def call(self, inputs, routing_weights):
    """Applies a routing-weighted depthwise convolution to every example.

    Args:
      inputs: Rank-4 input tensor whose batch dimension is statically known.
      routing_weights: Per-example expert weights. It is matrix-multiplied
        with the [num_experts, num_params] expert variables and the result is
        split per example, so it must have shape [batch_size, num_experts].

    Returns:
      The per-example depthwise convolution outputs concatenated along the
      batch axis, with bias and activation applied when configured.

    Raises:
      ValueError: If the batch dimension of `inputs` is not statically known.
    """
    # Compute example dependent depthwise kernels
    depthwise_kernels = tf.matmul(routing_weights,
                                  self.depthwise_condconv_kernel)
    # `dimension_value` accepts both a `Dimension` (V1 shape behavior) and a
    # plain int/None (V2 shape behavior), unlike a direct `.value` access.
    batch_size = tf.dimension_value(inputs.shape[0])
    if batch_size is None:
      raise ValueError(
          '`DepthwiseCondConv2D` requires a statically known batch size. '
          'Received input shape: %s' % (str(inputs.shape),))
    inputs = tf.split(inputs, batch_size, 0)
    depthwise_kernels = tf.split(depthwise_kernels, batch_size, 0)
    # The op takes 4-element strides laid out like the data format; this does
    # not depend on the example, so it is computed once.
    if self.data_format == 'channels_first':
      converted_strides = (1, 1) + self.strides
    else:
      converted_strides = (1,) + self.strides + (1,)
    # Apply example-dependent depthwise convolution to each example in the batch
    outputs_list = []
    for input_tensor, depthwise_kernel in zip(inputs, depthwise_kernels):
      depthwise_kernel = tf.reshape(depthwise_kernel,
                                    self.depthwise_kernel_shape)
      outputs_list.append(
          tf.nn.depthwise_conv2d(
              input_tensor,
              depthwise_kernel,
              strides=converted_strides,
              padding=self.padding.upper(),
              dilations=self.dilation_rate,
              data_format=self.converted_data_format))
    outputs = tf.concat(outputs_list, 0)

    if self.use_bias:
      # Compute example-dependent biases
      biases = tf.matmul(routing_weights, self.condconv_bias)
      outputs = tf.split(outputs, batch_size, 0)
      biases = tf.split(biases, batch_size, 0)
      # Add example-dependent bias to each example in the batch
      bias_outputs_list = []
      for output, bias in zip(outputs, biases):
        bias = tf.squeeze(bias, axis=0)
        bias_outputs_list.append(
            tf.nn.bias_add(output, bias,
                           data_format=self.converted_data_format))
      outputs = tf.concat(bias_outputs_list, 0)

    if self.activation is not None:
      return self.activation(outputs)

    return outputs

  def get_config(self):
    """Returns the base layer config extended with `num_experts`."""
    config = {'num_experts': self.num_experts}
    base_config = super(DepthwiseCondConv2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


# efficientnet_builder.py

In [None]:
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Builder for EfficientNet."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import re
from absl import logging
import numpy as np
import six
import tensorflow.compat.v1 as tf

#import efficientnet_model
#import utils
# Per-channel RGB mean and standard deviation, each a fractional statistic
# scaled by 255 (presumably the ImageNet values on a 0-255 pixel range --
# TODO confirm against the preprocessing code).
MEAN_RGB = [channel_mean * 255 for channel_mean in (0.485, 0.456, 0.406)]
STDDEV_RGB = [channel_std * 255 for channel_std in (0.229, 0.224, 0.225)]


def efficientnet_params(model_name):
  """Looks up the compound-scaling coefficients of a named EfficientNet.

  Args:
    model_name: Full model name, e.g. 'efficientnet-b0'.

  Returns:
    A tuple (width_coefficient, depth_coefficient, resolution, dropout_rate).

  Raises:
    KeyError: If `model_name` is not one of the predefined variants.
  """
  scaling_rows = (
      # (variant, width_coefficient, depth_coefficient, resolution, dropout)
      ('b0', 1.0, 1.0, 224, 0.2),
      ('b1', 1.0, 1.1, 240, 0.2),
      ('b2', 1.1, 1.2, 260, 0.3),
      ('b3', 1.2, 1.4, 300, 0.3),
      ('b4', 1.4, 1.8, 380, 0.4),
      ('b5', 1.6, 2.2, 456, 0.4),
      ('b6', 1.8, 2.6, 528, 0.5),
      ('b7', 2.0, 3.1, 600, 0.5),
      ('b8', 2.2, 3.6, 672, 0.5),
      ('l2', 4.3, 5.3, 800, 0.5),
  )
  scaling_by_name = {
      'efficientnet-%s' % variant: (width, depth, resolution, dropout)
      for variant, width, depth, resolution, dropout in scaling_rows
  }
  return scaling_by_name[model_name]


class BlockDecoder(object):
  """Converts block arguments to and from their compact string notation.

  A block string is a '_'-separated list of options, for example
  'r1_k3_s11_e1_i32_o16_se0.25'. Each option is a letter key followed by its
  value: r=num_repeat, k=kernel_size, s=strides (two digits), e=expand_ratio,
  i=input_filters, o=output_filters, se=se_ratio, c=conv_type, f=fused_conv,
  d=space2depth, a=activation (0 for relu, otherwise swish). The bare tokens
  'noskip' and 'cc' disable the identity skip and enable CondConv.
  """

  def _decode_block_string(self, block_string):
    """Gets a block through a string notation of arguments.

    Args:
      block_string: A single block notation string.

    Returns:
      An `efficientnet_model.BlockArgs` built from the parsed options.

    Raises:
      ValueError: If the strides option is missing or is not two characters.
    """
    # NOTE(review): `efficientnet_model` has to be resolvable when this method
    # runs; the `import efficientnet_model` line above is commented out in this
    # notebook -- confirm how the name is provided.
    assert isinstance(block_string, six.string_types)
    ops = block_string.split('_')
    options = {}
    for op in ops:
      # Split at the first digit: 'se0.25' -> ['se', '0.25', ''].
      # Tokens without a digit ('noskip', 'cc') yield a single element.
      splits = re.split(r'(\d.*)', op)
      if len(splits) >= 2:
        key, value = splits[:2]
        options[key] = value

    if 's' not in options or len(options['s']) != 2:
      raise ValueError('Strides options should be a pair of integers.')

    return efficientnet_model.BlockArgs(
        kernel_size=int(options['k']),
        num_repeat=int(options['r']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        expand_ratio=int(options['e']),
        id_skip=('noskip' not in block_string),
        se_ratio=float(options['se']) if 'se' in options else None,
        strides=[int(options['s'][0]),
                 int(options['s'][1])],
        conv_type=int(options['c']) if 'c' in options else 0,
        fused_conv=int(options['f']) if 'f' in options else 0,
        space2depth=int(options['d']) if 'd' in options else 0,
        condconv=('cc' in block_string),
        activation_fn=(tf.nn.relu if int(options['a']) == 0
                       else tf.nn.swish) if 'a' in options else None)

  def _encode_block_string(self, block):
    """Encodes a block to a string.

    Args:
      block: An object exposing the block argument fields as attributes.

    Returns:
      The '_'-joined notation string. The squeeze-excite option is emitted
      only when `se_ratio` is set and lies in (0, 1].
    """
    args = [
        'r%d' % block.num_repeat,
        'k%d' % block.kernel_size,
        's%d%d' % (block.strides[0], block.strides[1]),
        'e%s' % block.expand_ratio,
        'i%d' % block.input_filters,
        'o%d' % block.output_filters,
        'c%d' % block.conv_type,
        'f%d' % block.fused_conv,
        'd%d' % block.space2depth,
    ]
    # Decoding yields se_ratio=None when the string has no 'se' option;
    # comparing None with a number raises TypeError on Python 3, so guard it.
    se_ratio = block.se_ratio
    if se_ratio is not None and 0 < se_ratio <= 1:
      args.append('se%s' % se_ratio)
    if block.id_skip is False:  # pylint: disable=g-bool-id-comparison
      args.append('noskip')
    if block.condconv:
      args.append('cc')
    return '_'.join(args)

  def decode(self, string_list):
    """Decodes a list of string notations to specify blocks inside the network.
    Args:
      string_list: a list of strings, each string is a notation of block.
    Returns:
      A list of namedtuples to represent blocks arguments.
    """
    assert isinstance(string_list, list)
    return [
        self._decode_block_string(block_string)
        for block_string in string_list
    ]

  def encode(self, blocks_args):
    """Encodes a list of Blocks to a list of strings.
    Args:
      blocks_args: A list of namedtuples to represent blocks arguments.
    Returns:
      a list of strings, each string is a notation of block.
    """
    return [self._encode_block_string(block) for block in blocks_args]


def swish(features, use_native=True, use_hard=False):
  """Computes the Swish activation function.

  Three alternatives are provided:
    - Native tf.nn.swish, which uses less memory during training than the
      composable swish.
    - Quantization-friendly hard swish.
    - A composable swish, equivalent to tf.nn.swish, but more general for
      finetuning and TF-Hub.

  Args:
    features: A `Tensor` representing preactivation values.
    use_native: Whether to use the native swish from tf.nn that uses a custom
      gradient to reduce memory usage, or to use customized swish that uses
      default TensorFlow gradient computation.
    use_hard: Whether to use quantization-friendly hard swish. Because
      `use_native` defaults to True, it must be passed as False explicitly
      when selecting the hard variant.

  Returns:
    The activation value.

  Raises:
    ValueError: If both `use_native` and `use_hard` are set.
  """
  if use_native:
    if use_hard:
      raise ValueError('Cannot specify both use_native and use_hard.')
    return tf.nn.swish(features)

  if not use_hard:
    # Composable variant: x * sigmoid(x) with default gradients.
    features = tf.convert_to_tensor(features, name='features')
    return features * tf.nn.sigmoid(features)

  # Hard variant: x * relu6(x + 3) / 6.
  hard_gate = tf.nn.relu6(features + np.float32(3))
  return features * hard_gate * (1. / 6.)


# Default block configuration, one string per stage, in the notation parsed by
# BlockDecoder: r=num_repeat, k=kernel_size, s=strides, e=expand_ratio,
# i=input_filters, o=output_filters, se=se_ratio.
_DEFAULT_BLOCKS_ARGS = [
    'r1_k3_s11_e1_i32_o16_se0.25',
    'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25',
    'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25',
    'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
]


def efficientnet(width_coefficient=None,
                 depth_coefficient=None,
                 dropout_rate=0.2,
                 survival_prob=0.8):
  """Assembles the default EfficientNet global parameters.

  Despite its name this function does not build a network; it only fills in
  an `efficientnet_model.GlobalParams` tuple.

  Args:
    width_coefficient: Value stored in the `width_coefficient` field.
    depth_coefficient: Value stored in the `depth_coefficient` field.
    dropout_rate: Value stored in the `dropout_rate` field.
    survival_prob: Value stored in the `survival_prob` field.

  Returns:
    The populated `GlobalParams` namedtuple.
  """
  params_cls = efficientnet_model.GlobalParams
  settings = dict(
      blocks_args=_DEFAULT_BLOCKS_ARGS,
      batch_norm_momentum=0.99,
      batch_norm_epsilon=1e-3,
      dropout_rate=dropout_rate,
      survival_prob=survival_prob,
      data_format='channels_last',
      num_classes=1000,
      width_coefficient=width_coefficient,
      depth_coefficient=depth_coefficient,
      depth_divisor=8,
      min_depth=None,
      relu_fn=tf.nn.swish,
      # The default is TPU-specific batch norm.
      # The alternative is tf.layers.BatchNormalization.
      batch_norm=utils.train_batch_norm,  # TPU-specific requirement.
      use_se=True,
      clip_projection_output=False,
  )
  return params_cls(**settings)


def get_model_params(model_name, override_params):
  """Resolves the block arguments and global params of a named model.

  Args:
    model_name: Predefined model name; must start with 'efficientnet'.
    override_params: Optional dict of GlobalParams fields to replace.

  Returns:
    A tuple (blocks_args, global_params), where `blocks_args` is the decoded
    form of `global_params.blocks_args`.

  Raises:
    NotImplementedError: If `model_name` does not start with 'efficientnet'.
  """
  if not model_name.startswith('efficientnet'):
    raise NotImplementedError('model name is not pre-defined: %s' % model_name)

  # The third element (input resolution) is not used here.
  width, depth, _, dropout = efficientnet_params(model_name)
  global_params = efficientnet(width, depth, dropout)

  if override_params:
    # ValueError will be raised here if override_params has fields not included
    # in global_params.
    global_params = global_params._replace(**override_params)

  blocks_args = BlockDecoder().decode(global_params.blocks_args)

  logging.info('global_params= %s', global_params)
  return blocks_args, global_params


def build_model(images,
                model_name,
                training,
                override_params=None,
                model_dir=None,
                fine_tuning=False,
                features_only=False,
                pooled_features_only=False):
  """A helper function to create a model and return predicted logits.
  Args:
    images: input images tensor.
    model_name: string, the predefined model name.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The legacy key `drop_connect_rate` is
      also accepted and translated to `survival_prob`. The dictionary is not
      modified.
    model_dir: string, optional model dir for saving configs.
    fine_tuning: boolean, whether the model is used for finetuning.
    features_only: build the base feature network only (excluding final
      1x1 conv layer, global pooling, dropout and fc head).
    pooled_features_only: build the base network for features extraction (after
      1x1 conv layer and global pooling, but before dropout and fc head).
  Returns:
    logits: the logits tensor of classes.
    endpoints: the endpoints for each layer.
  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)
  assert not (features_only and pooled_features_only)

  # Work on a private copy so the caller's dictionary is never mutated.
  override_params = dict(override_params) if override_params else {}

  # For backward compatibility: translate the legacy `drop_connect_rate` into
  # `survival_prob`. The legacy key is removed because GlobalParams has no such
  # field and `_replace` would reject it.
  drop_connect_rate = override_params.pop('drop_connect_rate', None)
  if drop_connect_rate is not None:
    override_params['survival_prob'] = 1 - drop_connect_rate

  if not training or fine_tuning:
    override_params['batch_norm'] = utils.eval_batch_norm
    if fine_tuning:
      # The composable swish is used for fine-tuning instead of the native op.
      override_params['relu_fn'] = functools.partial(swish, use_native=False)
  blocks_args, global_params = get_model_params(model_name, override_params)

  if model_dir:
    # Record the resolved configuration once; an existing file is kept.
    param_file = os.path.join(model_dir, 'model_params.txt')
    if not tf.gfile.Exists(param_file):
      if not tf.gfile.Exists(model_dir):
        tf.gfile.MakeDirs(model_dir)
      with tf.gfile.GFile(param_file, 'w') as f:
        logging.info('writing to %s', param_file)
        f.write('model_name= %s\n\n' % model_name)
        f.write('global_params= %s\n\n' % str(global_params))
        f.write('blocks_args= %s\n\n' % str(blocks_args))

  with tf.variable_scope(model_name):
    model = efficientnet_model.Model(blocks_args, global_params)
    outputs = model(
        images,
        training=training,
        features_only=features_only,
        pooled_features_only=pooled_features_only)
  # Give the output tensor a stable name that reflects what was built.
  if features_only:
    outputs = tf.identity(outputs, 'features')
  elif pooled_features_only:
    outputs = tf.identity(outputs, 'pooled_features')
  else:
    outputs = tf.identity(outputs, 'logits')
  return outputs, model.endpoints


def build_model_base(images, model_name, training, override_params=None):
  """Create a base feature network and return the features before pooling.
  Args:
    images: input images tensor.
    model_name: string, the predefined model name.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The legacy key `drop_connect_rate` is
      also accepted and translated to `survival_prob`. The dictionary is not
      modified.
  Returns:
    features: base features before pooling.
    endpoints: the endpoints for each layer.
  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)

  # Work on a private copy so the caller's dictionary is never mutated.
  override_params = dict(override_params) if override_params else {}

  # For backward compatibility: translate the legacy `drop_connect_rate` into
  # `survival_prob`. The legacy key is removed because GlobalParams has no such
  # field and `_replace` would reject it.
  drop_connect_rate = override_params.pop('drop_connect_rate', None)
  if drop_connect_rate is not None:
    override_params['survival_prob'] = 1 - drop_connect_rate

  blocks_args, global_params = get_model_params(model_name, override_params)

  with tf.variable_scope(model_name):
    model = efficientnet_model.Model(blocks_args, global_params)
    features = model(images, training=training, features_only=True)

  features = tf.identity(features, 'features')
  return features, model.endpoints

# efficientnet_model.py


In [None]:
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains definitions for EfficientNet model.
[1] Mingxing Tan, Quoc V. Le
  EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
  ICML'19, https://arxiv.org/abs/1905.11946
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import math

from absl import logging
import numpy as np
import six
from six.moves import xrange
import tensorflow.compat.v1 as tf

#import utils
#from condconv import condconv_layers

# Network-wide hyperparameters shared by every block of the model.
GlobalParams = collections.namedtuple(
    'GlobalParams',
    'batch_norm_momentum batch_norm_epsilon dropout_rate data_format '
    'num_classes width_coefficient depth_coefficient depth_divisor '
    'min_depth survival_prob relu_fn batch_norm use_se '
    'se_coefficient local_pooling condconv_num_experts '
    'clip_projection_output blocks_args fix_head_stem use_bfloat16')
# Every field defaults to None. Note that None is not necessarily a valid
# value: it is valid to leave width_coefficient and depth_coefficient at None,
# which is treated as 1.0 (and which also allows depth_divisor and min_depth to
# be left at None).
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)

# Per-block hyperparameters; one instance describes one stage of the network.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    'kernel_size num_repeat input_filters output_filters '
    'expand_ratio id_skip strides se_ratio conv_type fused_conv '
    'space2depth condconv activation_fn')
# Every field defaults to None. `defaults` only became a public argument of
# namedtuple in Python 3.7, hence the assignment through __new__:
# https://docs.python.org/3/library/collections.html#collections.namedtuple
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)


def conv_kernel_initializer(shape, dtype=None, partition_info=None):
  """Draws initial convolution kernel values from a normal distribution.

  The values have mean 0 and standard deviation sqrt(2 / fan_out), where
  fan_out = kernel_height * kernel_width * out_filters.

  The main difference with tf.variance_scaling_initializer is that
  tf.variance_scaling_initializer uses a truncated normal with an uncorrected
  standard deviation, whereas here we use a normal distribution. Similarly,
  tf.initializers.variance_scaling uses a truncated normal with
  a corrected standard deviation.

  Args:
    shape: shape of the variable; unpacked as
      (kernel_height, kernel_width, <ignored>, out_filters).
    dtype: dtype of the variable.
    partition_info: unused.

  Returns:
    A tensor of the given shape holding the initial values.
  """
  del partition_info  # Part of the initializer signature; not needed here.
  height, width, _, num_outputs = shape
  fan_out = int(height * width * num_outputs)
  stddev = np.sqrt(2.0 / fan_out)
  return tf.random_normal(shape, mean=0.0, stddev=stddev, dtype=dtype)


def dense_kernel_initializer(shape, dtype=None, partition_info=None):
  """Draws initial dense kernel values uniformly from [-r, r), r = 1/sqrt(shape[1]).

  This initialization is equal to
    tf.variance_scaling_initializer(scale=1.0/3.0, mode='fan_out',
                                    distribution='uniform').
  It is written out explicitly here for clarity.

  Args:
    shape: shape of the variable; shape[1] sets the sampling range.
    dtype: dtype of the variable.
    partition_info: unused.

  Returns:
    A tensor of the given shape holding the initial values.
  """
  del partition_info  # Part of the initializer signature; not needed here.
  limit = 1.0 / np.sqrt(shape[1])
  return tf.random_uniform(shape, minval=-limit, maxval=limit, dtype=dtype)


def round_filters(filters, global_params, skip=False):
  """Scales a filter count by the width coefficient and snaps it to the divisor.

  Args:
    filters: the unscaled number of filters.
    global_params: object providing `width_coefficient`, `depth_divisor` and
      `min_depth`.
    skip: if True, return `filters` untouched.

  Returns:
    `filters` unchanged when `skip` is set or no width coefficient is given;
    otherwise the scaled count rounded to a multiple of `depth_divisor`, never
    below `min_depth` (or `depth_divisor` when `min_depth` is unset), as an int.
  """
  width_coefficient = global_params.width_coefficient
  divisor = global_params.depth_divisor
  lower_bound = global_params.min_depth
  if skip or not width_coefficient:
    return filters

  scaled = filters * width_coefficient
  lower_bound = lower_bound or divisor
  # Round to the nearest multiple of `divisor`, clamped from below.
  nearest_multiple = int(scaled + divisor / 2) // divisor * divisor
  rounded = max(lower_bound, nearest_multiple)
  # Rounding down must not lose more than 10% of the scaled filter count.
  if rounded < 0.9 * scaled:
    rounded += divisor
  logging.info('round_filter input=%s output=%s', filters, rounded)
  return int(rounded)


def round_repeats(repeats, global_params, skip=False):
  """Scales a block repeat count by the depth coefficient, rounding up.

  Args:
    repeats: the unscaled number of block repeats.
    global_params: object providing `depth_coefficient`.
    skip: if True, return `repeats` untouched.

  Returns:
    `repeats` unchanged when `skip` is set or no depth coefficient is given;
    otherwise ceil(depth_coefficient * repeats) as an int.
  """
  depth_coefficient = global_params.depth_coefficient
  should_scale = bool(depth_coefficient) and not skip
  if should_scale:
    return int(math.ceil(depth_coefficient * repeats))
  return repeats


class MBConvBlock(tf.keras.layers.Layer):
  """A class of MBConv: Mobile Inverted Residual Bottleneck.

  All sub-layers are created eagerly in `__init__` (through `_build`); `call`
  wires them together as: optional space2depth conv -> fused conv OR
  (optional expansion conv -> depthwise conv) -> optional squeeze-and-excite
  -> projection conv -> optional residual connection.

  Attributes:
    endpoints: dict mapping an endpoint name to an internal tensor. It is None
      until `call()` has run, after which it holds 'expansion_output'.
    conv_cls: callable used to create the regular convolution layers.
    depthwise_conv_cls: callable used to create the depthwise convolution.
  """

  def __init__(self, block_args, global_params):
    """Initializes a MBConv block.

    Args:
      block_args: BlockArgs, arguments to create a Block.
      global_params: GlobalParams, a set of global parameters.
    """
    super(MBConvBlock, self).__init__()
    self._block_args = block_args
    self._local_pooling = global_params.local_pooling
    self._batch_norm_momentum = global_params.batch_norm_momentum
    self._batch_norm_epsilon = global_params.batch_norm_epsilon
    self._batch_norm = global_params.batch_norm
    self._condconv_num_experts = global_params.condconv_num_experts
    self._data_format = global_params.data_format
    self._se_coefficient = global_params.se_coefficient
    # Any data_format other than 'channels_first' (including None) is handled
    # as channels-last.
    if self._data_format == 'channels_first':
      self._channel_axis = 1
      self._spatial_dims = [2, 3]
    else:
      self._channel_axis = -1
      self._spatial_dims = [1, 2]

    # Activation precedence: per-block override, then the global setting, then
    # swish.
    self._relu_fn = (self._block_args.activation_fn
                     or global_params.relu_fn or tf.nn.swish)
    # Squeeze-and-excite is used only when enabled globally and the block has
    # an se_ratio in (0, 1].
    self._has_se = (
        global_params.use_se and self._block_args.se_ratio is not None and
        0 < self._block_args.se_ratio <= 1)

    self._clip_projection_output = global_params.clip_projection_output

    self.endpoints = None

    self.conv_cls = utils.Conv2D
    self.depthwise_conv_cls = utils.DepthwiseConv2D
    if self._block_args.condconv:
      # CondConv layers additionally need the number of experts.
      self.conv_cls = functools.partial(
          condconv_layers.CondConv2D, num_experts=self._condconv_num_experts)
      self.depthwise_conv_cls = functools.partial(
          condconv_layers.DepthwiseCondConv2D,
          num_experts=self._condconv_num_experts)

    # Builds the block accordings to arguments.
    self._build()

  def block_args(self):
    """Returns the BlockArgs this block was constructed with."""
    return self._block_args

  def _build(self):
    """Builds block according to the arguments.

    Every candidate layer (fused, expansion and depthwise convolutions) is
    created here; `call` decides which of them are actually applied.
    """
    if self._block_args.space2depth == 1:
      # 2x2 stride-2 convolution applied before the rest of the block.
      self._space2depth = tf.layers.Conv2D(
          self._block_args.input_filters,
          kernel_size=[2, 2],
          strides=[2, 2],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=False)
      self._bnsp = self._batch_norm(
          axis=self._channel_axis,
          momentum=self._batch_norm_momentum,
          epsilon=self._batch_norm_epsilon)

    if self._block_args.condconv:
      # Add the example-dependent routing function
      self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
          data_format=self._data_format)
      self._routing_fn = tf.layers.Dense(
          self._condconv_num_experts, activation=tf.nn.sigmoid)

    filters = self._block_args.input_filters * self._block_args.expand_ratio
    kernel_size = self._block_args.kernel_size

    # Fused expansion phase. Called if using fused convolutions.
    self._fused_conv = self.conv_cls(
        filters=filters,
        kernel_size=[kernel_size, kernel_size],
        strides=self._block_args.strides,
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)

    # Expansion phase. Called if not using fused convolutions and expansion
    # phase is necessary.
    self._expand_conv = self.conv_cls(
        filters=filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)
    self._bn0 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)

    # Depth-wise convolution phase. Called if not using fused convolutions.
    self._depthwise_conv = self.depthwise_conv_cls(
        kernel_size=[kernel_size, kernel_size],
        strides=self._block_args.strides,
        depthwise_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)

    # Shared by the fused and the depthwise paths (only one of them runs).
    self._bn1 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)

    if self._has_se:
      # An unset (falsy) se_coefficient is treated as 1.
      num_reduced_filters = int(self._block_args.input_filters * (
          self._block_args.se_ratio * (self._se_coefficient
                                       if self._se_coefficient else 1)))
      # account for space2depth transformation in SE filter depth, since
      # the SE compression ratio is w.r.t. the original filter depth before
      # space2depth is applied.
      num_reduced_filters = (num_reduced_filters // 4
                             if self._block_args.space2depth == 1
                             else num_reduced_filters)
      num_reduced_filters = max(1, num_reduced_filters)
      # Squeeze and Excitation layer.
      self._se_reduce = utils.Conv2D(
          num_reduced_filters,
          kernel_size=[1, 1],
          strides=[1, 1],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=True)
      self._se_expand = utils.Conv2D(
          filters,
          kernel_size=[1, 1],
          strides=[1, 1],
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=True)

    # Output phase.
    filters = self._block_args.output_filters
    self._project_conv = self.conv_cls(
        filters=filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False)
    self._bn2 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon)

  def _call_se(self, input_tensor):
    """Call Squeeze and Excitation layer.

    Args:
      input_tensor: Tensor, a single input tensor for Squeeze/Excitation layer.

    Returns:
      A output tensor, which should have the same shape as input.
    """
    if self._local_pooling:
      shape = input_tensor.get_shape().as_list()
      # NOTE(review): this ksize puts the spatial extents at positions 1 and 2,
      # i.e. a channels-last layout, and `data_format` receives the Keras-style
      # string held by this block. tf.nn.avg_pool documents 'NHWC'/'NCHW' for
      # data_format -- confirm this branch works for the configurations that
      # combine SE with local pooling.
      kernel_size = [
          1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1]
      se_tensor = tf.nn.avg_pool(
          input_tensor,
          ksize=kernel_size,
          strides=[1, 1, 1, 1],
          padding='VALID',
          data_format=self._data_format)
    else:
      se_tensor = tf.reduce_mean(
          input_tensor, self._spatial_dims, keepdims=True)
    se_tensor = self._se_expand(self._relu_fn(self._se_reduce(se_tensor)))
    logging.info('Built SE %s : %s', self.name, se_tensor.shape)
    # Gate the input with the sigmoid of the excitation tensor.
    return tf.sigmoid(se_tensor) * input_tensor

  def call(self, inputs, training=True, survival_prob=None):
    """Implementation of call().

    Args:
      inputs: the inputs tensor.
      training: boolean, whether the model is constructed for training.
      survival_prob: float, between 0 to 1, drop connect rate.

    Returns:
      A output tensor.
    """
    logging.info('Block %s input shape: %s', self.name, inputs.shape)
    x = inputs

    fused_conv_fn = self._fused_conv
    expand_conv_fn = self._expand_conv
    depthwise_conv_fn = self._depthwise_conv
    project_conv_fn = self._project_conv

    if self._block_args.condconv:
      pooled_inputs = self._avg_pooling(inputs)
      routing_weights = self._routing_fn(pooled_inputs)
      # Capture routing weights as additional input to CondConv layers
      fused_conv_fn = functools.partial(
          self._fused_conv, routing_weights=routing_weights)
      expand_conv_fn = functools.partial(
          self._expand_conv, routing_weights=routing_weights)
      depthwise_conv_fn = functools.partial(
          self._depthwise_conv, routing_weights=routing_weights)
      project_conv_fn = functools.partial(
          self._project_conv, routing_weights=routing_weights)

    # creates conv 2x2 kernel
    if self._block_args.space2depth == 1:
      with tf.variable_scope('space2depth'):
        x = self._relu_fn(
            self._bnsp(self._space2depth(x), training=training))
      logging.info('Block start with space2depth shape: %s', x.shape)

    if self._block_args.fused_conv:
      # If use fused mbconv, skip expansion and use regular conv.
      x = self._relu_fn(self._bn1(fused_conv_fn(x), training=training))
      logging.info('Conv2D shape: %s', x.shape)
    else:
      # Otherwise, first apply expansion and then apply depthwise conv.
      if self._block_args.expand_ratio != 1:
        x = self._relu_fn(self._bn0(expand_conv_fn(x), training=training))
        logging.info('Expand shape: %s', x.shape)

      x = self._relu_fn(self._bn1(depthwise_conv_fn(x), training=training))
      logging.info('DWConv shape: %s', x.shape)

    if self._has_se:
      with tf.variable_scope('se'):
        x = self._call_se(x)

    self.endpoints = {'expansion_output': x}

    x = self._bn2(project_conv_fn(x), training=training)
    # Add identity so that quantization-aware training can insert quantization
    # ops correctly.
    x = tf.identity(x)
    if self._clip_projection_output:
      x = tf.clip_by_value(x, -6, 6)
    if self._block_args.id_skip:
      # The residual is only valid when the block keeps both the spatial size
      # (all strides are 1) and the channel count. Channels are compared on the
      # layout's channel axis so that 'channels_first' inputs are not compared
      # on their width dimension instead.
      channel_axis = self._channel_axis
      in_channels = inputs.get_shape().as_list()[channel_axis]
      out_channels = x.get_shape().as_list()[channel_axis]
      if all(
          s == 1 for s in self._block_args.strides
      ) and in_channels == out_channels:
        # Apply only if skip connection presents.
        if survival_prob:
          x = utils.drop_connect(x, training, survival_prob)
        x = tf.add(x, inputs)
    logging.info('Project shape: %s', x.shape)
    return x


class MBConvBlockWithoutDepthwise(MBConvBlock):
  """MBConv-like block without depthwise convolution and squeeze-and-excite.

  The block is an optional 3x3 expansion convolution followed by a 1x1
  projection convolution, plus the residual connection when applicable.
  """

  def _build(self):
    """Creates the expansion (if needed) and projection layers."""
    args = self._block_args
    expanded_filters = args.input_filters * args.expand_ratio

    def new_batch_norm():
      # Every batch norm in this block shares the same configuration.
      return self._batch_norm(
          axis=self._channel_axis,
          momentum=self._batch_norm_momentum,
          epsilon=self._batch_norm_epsilon)

    # Expansion phase: only present when the block actually expands.
    if args.expand_ratio != 1:
      self._expand_conv = tf.layers.Conv2D(
          expanded_filters,
          kernel_size=[3, 3],
          strides=args.strides,
          kernel_initializer=conv_kernel_initializer,
          padding='same',
          use_bias=False)
      self._bn0 = new_batch_norm()

    # Output phase: 1x1 projection to the block's output filters.
    self._project_conv = tf.layers.Conv2D(
        args.output_filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        use_bias=False)
    self._bn1 = new_batch_norm()

  def call(self, inputs, training=True, survival_prob=None):
    """Runs the block on `inputs`.

    Args:
      inputs: the inputs tensor.
      training: boolean, whether the model is constructed for training.
      survival_prob: float, between 0 to 1, drop connect rate.

    Returns:
      The output tensor of the block.
    """
    logging.info('Block input shape: %s', inputs.shape)
    args = self._block_args

    x = inputs
    if args.expand_ratio != 1:
      x = self._expand_conv(inputs)
      x = self._bn0(x, training=training)
      x = self._relu_fn(x)
    logging.info('Expand shape: %s', x.shape)

    self.endpoints = {'expansion_output': x}

    x = self._project_conv(x)
    x = self._bn1(x, training=training)
    # The identity op gives quantization-aware training a place to insert its
    # quantization ops correctly.
    x = tf.identity(x)
    if self._clip_projection_output:
      x = tf.clip_by_value(x, -6, 6)

    # A skip connection exists only when requested and when the block keeps
    # both the resolution (unit strides) and the filter count.
    has_skip = (
        args.id_skip and
        all(stride == 1 for stride in args.strides) and
        args.input_filters == args.output_filters)
    if has_skip:
      if survival_prob:
        x = utils.drop_connect(x, training, survival_prob)
      x = tf.add(x, inputs)
    logging.info('Project shape: %s', x.shape)
    return x


class Model(tf.keras.Model):
  """A class implements tf.keras.Model for MNAS-like model.

    Reference: https://arxiv.org/abs/1807.11626

  The network is a stem convolution, a sequence of MBConv-style blocks, and a
  head (1x1 convolution, pooling, optional dropout and optional dense layer).

  Attributes:
    endpoints: dict mapping an endpoint name to an intermediate tensor. It is
      None until `call()` has run.
  """

  def __init__(self, blocks_args=None, global_params=None):
    """Initializes an `Model` instance.

    Args:
      blocks_args: A list of BlockArgs to construct block modules.
      global_params: GlobalParams, a set of global parameters.

    Raises:
      ValueError: when blocks_args is not specified as a list.
    """
    super(Model, self).__init__()
    if not isinstance(blocks_args, list):
      raise ValueError('blocks_args should be a list.')
    self._global_params = global_params
    self._blocks_args = blocks_args
    self._relu_fn = global_params.relu_fn or tf.nn.swish
    self._batch_norm = global_params.batch_norm
    self._fix_head_stem = global_params.fix_head_stem

    self.endpoints = None

    self._build()

  def _get_conv_block(self, conv_type):
    """Returns the block class registered for the given `conv_type` (0 or 1)."""
    conv_block_map = {0: MBConvBlock, 1: MBConvBlockWithoutDepthwise}
    return conv_block_map[conv_type]

  def _build(self):
    """Builds a model."""
    self._blocks = []
    batch_norm_momentum = self._global_params.batch_norm_momentum
    batch_norm_epsilon = self._global_params.batch_norm_epsilon
    if self._global_params.data_format == 'channels_first':
      channel_axis = 1
      self._spatial_dims = [2, 3]
    else:
      channel_axis = -1
      self._spatial_dims = [1, 2]

    # Stem part.
    self._conv_stem = utils.Conv2D(
        filters=round_filters(32, self._global_params, self._fix_head_stem),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn0 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    # Builds blocks.
    for i, block_args in enumerate(self._blocks_args):
      assert block_args.num_repeat > 0
      assert block_args.space2depth in [0, 1, 2]
      # Update block input and output filters based on depth multiplier.
      input_filters = round_filters(block_args.input_filters,
                                    self._global_params)

      output_filters = round_filters(block_args.output_filters,
                                     self._global_params)
      kernel_size = block_args.kernel_size
      # With fix_head_stem, the first and last block groups keep their
      # configured repeat count instead of being depth-scaled.
      if self._fix_head_stem and (i == 0 or i == len(self._blocks_args) - 1):
        repeats = block_args.num_repeat
      else:
        repeats = round_repeats(block_args.num_repeat, self._global_params)
      block_args = block_args._replace(
          input_filters=input_filters,
          output_filters=output_filters,
          num_repeat=repeats)

      # The first block needs to take care of stride and filter size increase.
      conv_block = self._get_conv_block(block_args.conv_type)
      if not block_args.space2depth:  #  no space2depth at all
        self._blocks.append(conv_block(block_args, self._global_params))
      else:
        # if space2depth, adjust filters, kernels, and strides.
        depth_factor = int(4 / block_args.strides[0] / block_args.strides[1])
        block_args = block_args._replace(
            input_filters=block_args.input_filters * depth_factor,
            output_filters=block_args.output_filters * depth_factor,
            kernel_size=((block_args.kernel_size + 1) // 2 if depth_factor > 1
                         else block_args.kernel_size))
        # if the first block has stride-2 and space2depth transformation
        if (block_args.strides[0] == 2 and block_args.strides[1] == 2):
          block_args = block_args._replace(strides=[1, 1])
          self._blocks.append(conv_block(block_args, self._global_params))
          block_args = block_args._replace(  # sp stops at stride-2
              space2depth=0,
              input_filters=input_filters,
              output_filters=output_filters,
              kernel_size=kernel_size)
        elif block_args.space2depth == 1:
          self._blocks.append(conv_block(block_args, self._global_params))
          block_args = block_args._replace(space2depth=2)
        else:
          self._blocks.append(conv_block(block_args, self._global_params))
      if block_args.num_repeat > 1:  # rest of blocks with the same block_arg
        # pylint: disable=protected-access
        block_args = block_args._replace(
            input_filters=block_args.output_filters, strides=[1, 1])
        # pylint: enable=protected-access
      for _ in xrange(block_args.num_repeat - 1):
        self._blocks.append(conv_block(block_args, self._global_params))

    # Head part.
    self._conv_head = utils.Conv2D(
        filters=round_filters(1280, self._global_params, self._fix_head_stem),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False)
    self._bn1 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon)

    self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D(
        data_format=self._global_params.data_format)
    if self._global_params.num_classes:
      self._fc = tf.layers.Dense(
          self._global_params.num_classes,
          kernel_initializer=dense_kernel_initializer)
    else:
      self._fc = None

    # GlobalParams fields default to None, and `None > 0` raises TypeError on
    # Python 3, so an unset dropout_rate is treated as "no dropout".
    dropout_rate = self._global_params.dropout_rate
    if dropout_rate and dropout_rate > 0:
      self._dropout = tf.keras.layers.Dropout(dropout_rate)
    else:
      self._dropout = None

  def call(self,
           inputs,
           training=True,
           features_only=None,
           pooled_features_only=False):
    """Implementation of call().

    Besides returning the output, every invocation repopulates
    `self.endpoints` with the intermediate tensors.

    Args:
      inputs: input tensors.
      training: boolean, whether the model is constructed for training.
      features_only: build the base feature network only.
      pooled_features_only: build the base network for features extraction
        (after 1x1 conv layer and global pooling, but before dropout and fc
        head).

    Returns:
      output tensors.
    """
    outputs = None
    self.endpoints = {}
    reduction_idx = 0
    # Calls Stem layers
    with tf.variable_scope('stem'):
      outputs = self._relu_fn(
          self._bn0(self._conv_stem(inputs), training=training))
    logging.info('Built stem layers with output shape: %s', outputs.shape)
    self.endpoints['stem'] = outputs

    # Calls blocks.
    for idx, block in enumerate(self._blocks):
      is_reduction = False  # reduction flag for blocks after the stem layer
      # If the first block has space-to-depth layer, then stem is
      # the first reduction point.
      if (block.block_args().space2depth == 1 and idx == 0):
        reduction_idx += 1
        self.endpoints['reduction_%s' % reduction_idx] = outputs

      # A block is a reduction point when it is the last block or when the
      # next block is strided.
      elif ((idx == len(self._blocks) - 1) or
            self._blocks[idx + 1].block_args().strides[0] > 1):
        is_reduction = True
        reduction_idx += 1

      with tf.variable_scope('blocks_%s' % idx):
        survival_prob = self._global_params.survival_prob
        if survival_prob:
          # The drop rate grows linearly with the block index.
          drop_rate = 1.0 - survival_prob
          survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks)
          logging.info('block_%s survival_prob: %s', idx, survival_prob)
        outputs = block.call(
            outputs, training=training, survival_prob=survival_prob)
        self.endpoints['block_%s' % idx] = outputs
        if is_reduction:
          self.endpoints['reduction_%s' % reduction_idx] = outputs
        if block.endpoints:
          for k, v in six.iteritems(block.endpoints):
            self.endpoints['block_%s/%s' % (idx, k)] = v
            if is_reduction:
              self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
    self.endpoints['features'] = outputs

    if not features_only:
      # Calls final layers and returns logits.
      with tf.variable_scope('head'):
        outputs = self._relu_fn(
            self._bn1(self._conv_head(outputs), training=training))
        self.endpoints['head_1x1'] = outputs

        if self._global_params.local_pooling:
          shape = outputs.get_shape().as_list()
          # NOTE(review): this ksize puts the spatial extents at positions 1
          # and 2 and no data_format is passed to avg_pool, which looks like a
          # channels-last assumption -- confirm for 'channels_first'.
          kernel_size = [
              1, shape[self._spatial_dims[0]], shape[self._spatial_dims[1]], 1]
          outputs = tf.nn.avg_pool(
              outputs, ksize=kernel_size, strides=[1, 1, 1, 1], padding='VALID')
          self.endpoints['pooled_features'] = outputs
          if not pooled_features_only:
            if self._dropout:
              outputs = self._dropout(outputs, training=training)
            self.endpoints['global_pool'] = outputs
            if self._fc:
              # Drop the pooled spatial dimensions before the dense layer.
              outputs = tf.squeeze(outputs, self._spatial_dims)
              outputs = self._fc(outputs)
            self.endpoints['head'] = outputs
        else:
          outputs = self._avg_pooling(outputs)
          self.endpoints['pooled_features'] = outputs
          if not pooled_features_only:
            if self._dropout:
              outputs = self._dropout(outputs, training=training)
            self.endpoints['global_pool'] = outputs
            if self._fc:
              outputs = self._fc(outputs)
            self.endpoints['head'] = outputs
    return outputs

# efficientnet_lite_model_qat.py


In [None]:
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements EfficientNet Lite model for Quantization Aware Training.
[1] Mingxing Tan, Quoc V. Le
  EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
  ICML'19, https://arxiv.org/abs/1905.11946
"""
import functools

import tensorflow.compat.v1 as tf

#import efficientnet_model


class FunctionalModelBuilder:
  """A class that builds functional api keras models.

  Subclasses create their layers in `build()` and wire them in `call()`.
  Invoking the instance builds it on first use and then delegates to `call()`.

  Attributes:
    name: name given to this builder.
    built: whether `build()` has already run for this instance.
  """

  def __init__(self, name='FunctionalModel'):
    """Initializes the builder.

    Args:
      name: name of the builder.
    """
    self.name = name
    self.built = False

  def build(self, input_shape: tf.TensorShape):
    """Creates the layers; the base implementation only marks the builder built.

    Args:
      input_shape: shape(s) of the inputs without the batch dimension.
    """
    del input_shape  # Only used by subclasses.
    self.built = True

  def call(self, inputs, training):
    """Applies the layers to `inputs`; must be implemented by subclasses.

    Args:
      inputs: input tensor(s).
      training: boolean, whether the model is constructed for training.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('This function is implemented in subclasses.')

  def get_functional_model(self, input_shape, training):
    """Wraps this builder into a functional `tf.keras.Model`.

    Args:
      input_shape: full input shape; the first entry is used as the batch size
        and the remaining entries as the per-example shape.
      training: boolean, forwarded to `call()`.

    Returns:
      A `tf.keras.Model` mapping a fresh Keras input to this builder's output.
    """
    functional_inputs = tf.keras.Input(
        shape=input_shape[1:], batch_size=input_shape[0])
    functional_outputs = self(functional_inputs, training)
    return tf.keras.Model(inputs=functional_inputs, outputs=functional_outputs)

  def __call__(self, inputs, training):
    """Builds the layers on first use, then runs `call()`.

    Args:
      inputs: a tensor or a nested structure of tensors.
      training: boolean, forwarded to `call()`.

    Returns:
      Whatever `call()` returns.
    """
    if not self.built:
      # Strip the batch dimension from every input shape. For nested inputs
      # this is done per tensor; slicing the list of shapes itself would drop
      # the first tensor's shape rather than the batch dimension.
      if tf.nest.is_nested(inputs):
        input_shapes = [
            input_tensor.shape[1:] for input_tensor in tf.nest.flatten(inputs)
        ]
      else:
        input_shapes = inputs.shape[1:]
      self.build(input_shapes)
      # Subclasses may override build() without chaining to the base
      # implementation, so mark the instance as built here; otherwise the
      # layers would be created again on every call.
      self.built = True
    return self.call(inputs, training)


class FunctionalMBConvBlock(FunctionalModelBuilder):
  """A class of MBConv: Mobile Inverted Residual Bottleneck.

  The block is: optional 1x1 expansion conv -> depthwise conv -> 1x1
  projection conv, with ReLU6 activations and an optional residual connection.

  Attributes:
    endpoints: dict mapping an endpoint name to an internal tensor. It is None
      until `call()` has run, after which it holds 'expansion_output'.
  """

  def __init__(self, block_args, global_params, dtype, name, **kwargs):
    """Initializes a MBConv block.

    Args:
      block_args: BlockArgs, arguments to create a Block.
      global_params: GlobalParams, a set of global parameters.
      dtype: Layer type.
      name: Layer name.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self._block_args = block_args
    self._dtype = dtype
    self._name = name
    self._batch_norm_momentum = global_params.batch_norm_momentum
    self._batch_norm_epsilon = global_params.batch_norm_epsilon
    self._batch_norm = global_params.batch_norm
    self._data_format = global_params.data_format
    self._conv_kernel_initializer = tf.compat.v2.keras.initializers.VarianceScaling(
        scale=2.0, mode='fan_out', distribution='untruncated_normal')
    # Any data_format other than 'channels_first' (including None) is handled
    # as channels-last.
    if self._data_format == 'channels_first':
      self._channel_axis = 1
      self._spatial_dims = [2, 3]
    else:
      self._channel_axis = -1
      self._spatial_dims = [1, 2]

    # Factory for ReLU layers capped at 6.0; a new layer is created per use.
    self._relu_fn = functools.partial(tf.keras.layers.ReLU, 6.0)
    self._survival_prob = global_params.survival_prob

    self.endpoints = None

  def block_args(self):
    """Returns the BlockArgs this block was constructed with."""
    return self._block_args

  def build(self, input_shape):
    """Builds block according to the arguments.

    Layer names are '<block name>/conv2d' and
    '<block name>/tpu_batch_normalization', with a '_<n>' suffix for every
    layer of that kind after the first one created in this block.

    Args:
      input_shape: not used by this block.
    """
    conv2d_id = 0
    batch_norm_id = 0
    if self._block_args.expand_ratio != 1:
      self._expand_conv = tf.keras.layers.Conv2D(
          filters=(self._block_args.input_filters *
                   self._block_args.expand_ratio),
          kernel_size=[1, 1],
          strides=[1, 1],
          kernel_initializer=self._conv_kernel_initializer,
          padding='same',
          data_format=self._data_format,
          use_bias=False,
          dtype=self._dtype,
          name=f'{self._name}/conv2d')
      conv2d_id += 1
      self._bn0 = self._batch_norm(
          axis=self._channel_axis,
          momentum=self._batch_norm_momentum,
          epsilon=self._batch_norm_epsilon,
          dtype=self._dtype,
          name=f'{self._name}/tpu_batch_normalization')
      batch_norm_id += 1

    self._depthwise_conv = tf.keras.layers.DepthwiseConv2D(
        kernel_size=[
            self._block_args.kernel_size, self._block_args.kernel_size
        ],
        strides=self._block_args.strides,
        depthwise_initializer=self._conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False,
        dtype=self._dtype,
        name=f'{self._name}/depthwise_conv2d')

    batch_norm_name_suffix = f'_{batch_norm_id}' if batch_norm_id else ''
    self._bn1 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon,
        dtype=self._dtype,
        name=f'{self._name}/tpu_batch_normalization{batch_norm_name_suffix}')
    batch_norm_id += 1

    # Output phase.
    conv2d_name_suffix = f'_{conv2d_id}' if conv2d_id else ''
    self._project_conv = tf.keras.layers.Conv2D(
        filters=self._block_args.output_filters,
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=self._conv_kernel_initializer,
        padding='same',
        data_format=self._data_format,
        use_bias=False,
        dtype=self._dtype,
        name=f'{self._name}/conv2d{conv2d_name_suffix}')
    batch_norm_name_suffix = f'_{batch_norm_id}' if batch_norm_id else ''
    self._bn2 = self._batch_norm(
        axis=self._channel_axis,
        momentum=self._batch_norm_momentum,
        epsilon=self._batch_norm_epsilon,
        dtype=self._dtype,
        name=f'{self._name}/tpu_batch_normalization{batch_norm_name_suffix}')
    # The dropout layer is only used in call() when a survival probability is
    # set, and `1 - None` raises TypeError, so it is only created in that case.
    if self._survival_prob:
      self._spartial_dropout_2d = tf.keras.layers.SpatialDropout2D(
          rate=1 - self._survival_prob, dtype=self._dtype)
    else:
      self._spartial_dropout_2d = None

  def call(self, inputs, training):
    """Implementation of call().

    Args:
      inputs: the inputs tensor.
      training: boolean, whether the model is constructed for training.

    Returns:
      A output tensor.
    """
    x = inputs

    if self._block_args.expand_ratio != 1:
      x = self._relu_fn()(self._bn0(self._expand_conv(x), training=training))

    x = self._relu_fn()(self._bn1(self._depthwise_conv(x), training=training))
    self.endpoints = {'expansion_output': x}

    x = self._bn2(self._project_conv(x), training=training)

    # The residual is only valid when the block keeps both the spatial size
    # (all strides are 1) and the channel count. Channels are compared on the
    # layout's channel axis so that 'channels_first' inputs are not compared on
    # their width dimension instead.
    in_channels = inputs.get_shape().as_list()[self._channel_axis]
    out_channels = x.get_shape().as_list()[self._channel_axis]
    if (all(s == 1 for s in self._block_args.strides) and
        in_channels == out_channels):
      # Apply only if skip connection presents.
      if self._survival_prob:
        # NOTE(review): `training` is not forwarded to the dropout layer,
        # unlike the batch-norm calls above -- confirm this is intended.
        x = self._spartial_dropout_2d(x)
      x = tf.keras.layers.Add(dtype=self._dtype)([x, inputs])

    return x


class FunctionalModel(FunctionalModelBuilder):
  """MNAS-like model assembled on top of `FunctionalModelBuilder`.

  `build()` creates the stem, the stack of `FunctionalMBConvBlock`s and the
  head layers; `call()` applies them to a tensor and records intermediate
  results in `self.endpoints`.

    Reference: https://arxiv.org/abs/1807.11626
  """

  def __init__(self,
               model_name,
               blocks_args=None,
               global_params=None,
               features_only=None,
               pooled_features_only=False,
               **kwargs):
    """Initializes a `FunctionalModel` instance.
    Args:
      model_name: Name of the model; used as the prefix of every layer name.
      blocks_args: A list of BlockArgs to construct block modules.
      global_params: GlobalParams, a set of global parameters. Although it
        defaults to None, its attributes are read unconditionally below, so a
        real value is required.
      features_only: build the base feature network only.
      pooled_features_only: build the base network for features extraction
        (after 1x1 conv layer and global pooling, but before dropout and fc
        head).
      **kwargs: Keyword arguments forwarded to the base class constructor.
    Raises:
      ValueError: when blocks_args is not specified as a list.
    """
    super().__init__(**kwargs)
    if not isinstance(blocks_args, list):
      raise ValueError('blocks_args should be a list.')
    self._model_name = model_name
    self._global_params = global_params
    self._blocks_args = blocks_args
    # Layer dtype/policy name: float32 unless bfloat16 was requested.
    self._dtype = 'float32'
    if self._global_params.use_bfloat16:
      self._dtype = 'mixed_bfloat16'
    self._features_only = features_only
    self._pooled_features_only = pooled_features_only
    # Factory: each call creates a new ReLU layer capped at 6.0.
    self._relu_fn = functools.partial(tf.keras.layers.ReLU, 6.0)
    self._batch_norm = global_params.batch_norm
    self._fix_head_stem = global_params.fix_head_stem
    self._conv_kernel_initializer = tf.compat.v2.keras.initializers.VarianceScaling(
        scale=2.0, mode='fan_out', distribution='untruncated_normal')
    self._dense_kernel_initializer = tf.keras.initializers.VarianceScaling(
        scale=1.0 / 3.0, mode='fan_out', distribution='uniform')
    # Filled with intermediate tensors by call().
    self.endpoints = None

  def build(self, input_shape):
    """Creates the stem, block and head layers.

    Args:
      input_shape: ignored; every layer is configured from the global params
        and the block args.
    """
    del input_shape  # Unused.
    self._blocks = []
    batch_norm_momentum = self._global_params.batch_norm_momentum
    batch_norm_epsilon = self._global_params.batch_norm_epsilon
    # Channel axis for batch norm and spatial axes for the final pooling.
    if self._global_params.data_format == 'channels_first':
      channel_axis = 1
      self._spatial_dims = [2, 3]
    else:
      channel_axis = -1
      self._spatial_dims = [1, 2]

    # Stem part.
    self._conv_stem = tf.keras.layers.Conv2D(
        filters=efficientnet_model.round_filters(32, self._global_params,
                                                 self._fix_head_stem),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=self._conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False,
        dtype=self._dtype,
        name=f'{self._model_name}/stem/conv2d')
    # NOTE(review): unlike the head batch norm below, no dtype is passed
    # here -- confirm whether that is intentional.
    self._bn0 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon,
        name=f'{self._model_name}/stem/tpu_batch_normalization')

    # Builds blocks.
    for i, block_args in enumerate(self._blocks_args):
      assert block_args.num_repeat > 0
      assert block_args.space2depth in [0, 1, 2]
      # Update block input and output filters based on depth multiplier.
      input_filters = efficientnet_model.round_filters(block_args.input_filters,
                                                       self._global_params)

      output_filters = efficientnet_model.round_filters(
          block_args.output_filters, self._global_params)
      # With fix_head_stem the first and last stages keep their repeat count.
      if self._fix_head_stem and (i == 0 or i == len(self._blocks_args) - 1):
        repeats = block_args.num_repeat
      else:
        repeats = efficientnet_model.round_repeats(block_args.num_repeat,
                                                   self._global_params)
      block_args = block_args._replace(
          input_filters=input_filters,
          output_filters=output_filters,
          num_repeat=repeats)

      # The first block needs to take care of stride and filter size increase.
      self._blocks.append(
          FunctionalMBConvBlock(
              block_args=block_args,
              global_params=self._global_params,
              dtype=self._dtype,
              name=f'{self._model_name}/blocks_{len(self._blocks)}'))

      if block_args.num_repeat > 1:  # rest of blocks with the same block_arg
        # pylint: disable=protected-access
        block_args = block_args._replace(
            input_filters=block_args.output_filters, strides=[1, 1])
        # pylint: enable=protected-access
      for _ in range(block_args.num_repeat - 1):
        self._blocks.append(
            FunctionalMBConvBlock(
                block_args,
                self._global_params,
                dtype=self._dtype,
                name=f'{self._model_name}/blocks_{len(self._blocks)}'))

    # Head part.
    self._conv_head = tf.keras.layers.Conv2D(
        filters=efficientnet_model.round_filters(1280, self._global_params,
                                                 self._fix_head_stem),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=self._conv_kernel_initializer,
        padding='same',
        data_format=self._global_params.data_format,
        use_bias=False,
        dtype=self._dtype,
        name=f'{self._model_name}/head/conv2d')
    self._bn1 = self._batch_norm(
        axis=channel_axis,
        momentum=batch_norm_momentum,
        epsilon=batch_norm_epsilon,
        dtype=self._dtype,
        name=f'{self._model_name}/head/tpu_batch_normalization')

    # The classifier is optional: a falsy num_classes disables it.
    if self._global_params.num_classes:
      self._fc = tf.keras.layers.Dense(
          self._global_params.num_classes,
          kernel_initializer=self._dense_kernel_initializer,
          dtype=self._dtype,
          name=f'{self._model_name}/head/dense')
    else:
      self._fc = None

    # Dropout is only created for a positive rate.
    if self._global_params.dropout_rate > 0:
      self._dropout = tf.keras.layers.Dropout(
          self._global_params.dropout_rate, dtype=self._dtype)
    else:
      self._dropout = None

  def call(self, inputs, training):
    """Implementation of call().
    Args:
      inputs: input tensors.
      training: boolean, whether the model is constructed for training.
    Returns:
      A list whose first element is the final output tensor, followed by the
      'reduction_1' .. 'reduction_5' endpoints that were recorded.
    """
    outputs = None
    self.endpoints = {}
    reduction_idx = 0

    # Calls Stem layers
    outputs = self._relu_fn()(
        self._bn0(self._conv_stem(inputs), training=training))
    self.endpoints['stem'] = outputs

    # Calls blocks.
    for idx, block in enumerate(self._blocks):
      is_reduction = False  # reduction flag for blocks after the stem layer
      # A block is a reduction point when it is the last block or when the
      # next block has a first stride greater than 1.
      if ((idx == len(self._blocks) - 1) or
          self._blocks[idx + 1].block_args().strides[0] > 1):
        is_reduction = True
        reduction_idx += 1

      # NOTE(review): this depth-scaled survival_prob is computed but never
      # passed to the block or read afterwards -- confirm whether the block is
      # meant to receive it.
      survival_prob = self._global_params.survival_prob
      if survival_prob:
        drop_rate = 1.0 - survival_prob
        survival_prob = 1.0 - drop_rate * float(idx) / len(self._blocks)
      outputs = block(outputs, training)
      self.endpoints['block_%s' % idx] = outputs

      if is_reduction:
        self.endpoints['reduction_%s' % reduction_idx] = outputs
      # Re-export the block's own endpoints under block/reduction prefixes.
      if block.endpoints:
        for k, v in block.endpoints.items():
          self.endpoints['block_%s/%s' % (idx, k)] = v
          if is_reduction:
            self.endpoints['reduction_%s/%s' % (reduction_idx, k)] = v
    self.endpoints['features'] = outputs

    if not self._features_only:
      outputs = self._relu_fn()(
          self._bn1(self._conv_head(outputs), training=training))
      self.endpoints['head_1x1'] = outputs

      # Average pooling with a window equal to the full spatial extent.
      shape = outputs.get_shape().as_list()
      outputs = tf.keras.layers.AveragePooling2D(
          pool_size=(shape[self._spatial_dims[0]],
                     shape[self._spatial_dims[1]]),
          strides=[1, 1],
          padding='valid',
          dtype=self._dtype)(
              outputs)
      self.endpoints['pooled_features'] = outputs
      if not self._pooled_features_only:
        if self._dropout:
          outputs = self._dropout(outputs)
        self.endpoints['global_pool'] = outputs
        if self._fc:
          outputs = tf.keras.layers.Flatten(dtype=self._dtype)(outputs)
          outputs = self._fc(outputs)
        self.endpoints['head'] = outputs

    # Final output first, then whichever reduction endpoints exist.
    return [outputs] + list(
        filter(lambda endpoint: endpoint is not None, [
            self.endpoints.get('reduction_1'),
            self.endpoints.get('reduction_2'),
            self.endpoints.get('reduction_3'),
            self.endpoints.get('reduction_4'),
            self.endpoints.get('reduction_5'),
        ]))

# efficientnet_lite_builder

In [None]:
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Builder for EfficientNet Edge Models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import logging
import tensorflow.compat.v1 as tf

#import efficientnet_builder
#import efficientnet_model
#import utils
#from lite import efficientnet_lite_model_qat
# Edge models use inception-style MEAN and STDDEV for better post-quantization.
# NOTE(review): presumably per-channel (R, G, B) values on the 0-255 pixel
# scale used to normalize inputs -- confirm at the preprocessing call site,
# which is not part of this cell.
MEAN_RGB = [127.0, 127.0, 127.0]
STDDEV_RGB = [128.0, 128.0, 128.0]


def efficientnet_lite_params(model_name):
  """Looks up the scaling preset of an EfficientNet-Lite variant.

  Args:
    model_name: string such as 'efficientnet-lite0'. Everything from the first
      '-qat' onwards is ignored, so the '-qat' variants share the preset of
      their base model.

  Returns:
    A tuple (width_coefficient, depth_coefficient, resolution, dropout_rate).

  Raises:
    KeyError: if the (stripped) name is not one of the known variants.
  """
  # partition() yields the whole name unchanged when '-qat' is absent.
  base_name = model_name.partition('-qat')[0]
  presets = {
      # (width_coefficient, depth_coefficient, resolution, dropout_rate)
      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
  }
  return presets[base_name]


# Encoded block specs, one string per stage. `efficientnet_lite` stores this
# list as `blocks_args`, and `get_model_params` decodes it with
# `efficientnet_builder.BlockDecoder`.
# NOTE(review): the fields presumably read r=num_repeat, k=kernel_size,
# s=strides, e=expand_ratio, i=input_filters, o=output_filters, se=se_ratio --
# confirm against BlockDecoder, which is not part of this cell.
_DEFAULT_BLOCKS_ARGS = [
    'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
]


def efficientnet_lite(width_coefficient=None,
                      depth_coefficient=None,
                      dropout_rate=0.2,
                      survival_prob=0.8):
  """Assembles the `GlobalParams` used by the EfficientNet-Lite models.

  Args:
    width_coefficient: stored as `width_coefficient`.
    depth_coefficient: stored as `depth_coefficient`.
    dropout_rate: stored as `dropout_rate`.
    survival_prob: stored as `survival_prob`.

  Returns:
    An `efficientnet_model.GlobalParams` combining the arguments above with
    the fixed lite-specific settings below.
  """
  # Values supplied by the caller.
  scaling = dict(
      width_coefficient=width_coefficient,
      depth_coefficient=depth_coefficient,
      dropout_rate=dropout_rate,
      survival_prob=survival_prob)

  # Settings shared by every lite variant.
  fixed = dict(
      blocks_args=_DEFAULT_BLOCKS_ARGS,
      batch_norm_momentum=0.99,
      batch_norm_epsilon=1e-3,
      data_format='channels_last',
      num_classes=1000,
      depth_divisor=8,
      min_depth=None,
      # Relu6 is for easier quantization.
      relu_fn=tf.nn.relu6,
      # The default is TPU-specific batch norm (a TPU-specific requirement).
      # The alternative is tf.layers.BatchNormalization.
      batch_norm=utils.TpuBatchNormalization,
      clip_projection_output=False,
      # Don't scale stem and head.
      fix_head_stem=True,
      # Special cases for tflite issues.
      local_pooling=True,
      # SE is not well supported on many lite devices.
      use_se=False,
      # This flag is only read by QAT version of the model.
      use_bfloat16=False)

  return efficientnet_model.GlobalParams(**scaling, **fixed)


def get_model_params(model_name, override_params):
  """Returns the decoded block args and the global params of a model.

  Args:
    model_name: string; must start with 'efficientnet-lite'.
    override_params: optional dict of GlobalParams fields to replace.

  Returns:
    A tuple (blocks_args, global_params).

  Raises:
    NotImplementedError: if `model_name` is not an EfficientNet-Lite name.
  """
  if not model_name.startswith('efficientnet-lite'):
    raise NotImplementedError('model name is not pre-defined: %s' % model_name)

  # The resolution entry of the preset is not needed to build the params.
  width, depth, _, dropout = efficientnet_lite_params(model_name)
  global_params = efficientnet_lite(width, depth, dropout)

  if override_params:
    # ValueError will be raised here if override_params has fields not included
    # in global_params.
    global_params = global_params._replace(**override_params)

  blocks_args = efficientnet_builder.BlockDecoder().decode(
      global_params.blocks_args)

  logging.info('global_params= %s', global_params)
  return blocks_args, global_params


def build_model(images,
                model_name,
                training,
                override_params=None,
                model_dir=None,
                fine_tuning=False,
                features_only=False,
                pooled_features_only=False):
  """A helper function to create a model and return predicted logits.
  Args:
    images: input images tensor.
    model_name: string, the predefined model name. A '-qat' suffix selects the
      quantization-aware functional model.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The dictionary is copied, never mutated.
    model_dir: string, optional model dir for saving configs.
    fine_tuning: boolean, whether the model is used for finetuning.
    features_only: build the base feature network only (excluding final
      1x1 conv layer, global pooling, dropout and fc head).
    pooled_features_only: build the base network for features extraction (after
      1x1 conv layer and global pooling, but before dropout and fc head).
  Returns:
    logits: the logits tensor of classes.
    endpoints: the endpoints for each layer.
  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)
  assert not (features_only and pooled_features_only)

  # Work on a private copy: the overrides are amended below, and the caller's
  # dict must not pick up those edits (otherwise a `batch_norm` forced for an
  # eval build would leak into a later training build reusing the same dict).
  override_params = dict(override_params) if override_params else {}

  # For backward compatibility.
  if override_params.get('drop_connect_rate', None):
    override_params['survival_prob'] = 1 - override_params['drop_connect_rate']

  # Strip the '-qat' marker; it only selects the model implementation.
  if '-qat' in model_name:
    model_name = model_name[:model_name.find('-qat')]
    with_quantization_aware_training = True
  else:
    with_quantization_aware_training = False

  # Outside of plain training, swap in the regular batch norm implementation.
  if not training or fine_tuning:
    override_params['batch_norm'] = utils.BatchNormalization
  blocks_args, global_params = get_model_params(model_name, override_params)

  # Record the resolved configuration once per model directory.
  if model_dir:
    param_file = os.path.join(model_dir, 'model_params.txt')
    if not tf.gfile.Exists(param_file):
      if not tf.gfile.Exists(model_dir):
        tf.gfile.MakeDirs(model_dir)
      with tf.gfile.GFile(param_file, 'w') as f:
        logging.info('writing to %s', param_file)
        f.write('model_name= %s\n\n' % model_name)
        f.write('global_params= %s\n\n' % str(global_params))
        f.write('blocks_args= %s\n\n' % str(blocks_args))

  with tf.variable_scope(model_name):
    if with_quantization_aware_training:
      model = efficientnet_lite_model_qat.FunctionalModel(
          model_name, blocks_args, global_params, features_only,
          pooled_features_only)
      # The functional model returns a list; the first entry is the output.
      outputs = model(images, training=training)[0]
    else:
      model = efficientnet_model.Model(blocks_args, global_params)
      outputs = model(
          images,
          training=training,
          features_only=features_only,
          pooled_features_only=pooled_features_only)
  # Name the output op after what it represents.
  if features_only:
    outputs = tf.identity(outputs, 'features')
  elif pooled_features_only:
    outputs = tf.identity(outputs, 'pooled_features')
  else:
    outputs = tf.identity(outputs, 'logits')
  return outputs, model.endpoints


def build_model_base(images, model_name, training, override_params=None):
  """Create a base feature network and return the features before pooling.
  Args:
    images: input images tensor.
    model_name: string, the predefined model name.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The dictionary is copied, never mutated.
  Returns:
    features: base features before pooling.
    endpoints: the endpoints for each layer.
  Raises:
    When model_name specified an undefined model, raises NotImplementedError.
    When override_params has invalid fields, raises ValueError.
  """
  assert isinstance(images, tf.Tensor)

  # Copy before amending so the caller's dict is left untouched.
  override_params = dict(override_params) if override_params else {}

  # For backward compatibility.
  if override_params.get('drop_connect_rate', None):
    override_params['survival_prob'] = 1 - override_params['drop_connect_rate']

  blocks_args, global_params = get_model_params(model_name, override_params)

  with tf.variable_scope(model_name):
    model = efficientnet_model.Model(blocks_args, global_params)
    features = model(images, training=training, features_only=True)

  features = tf.identity(features, 'features')
  return features, model.endpoints

# efficientnet_lite_builder_test.py


In [None]:
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for efficientnet_lite_builder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow.compat.v1 as tf

#from lite import efficientnet_lite_builder


class EfficientnetBuilderTest(tf.test.TestCase):
  """Checks the trainable-parameter count of each EfficientNet-Lite variant."""

  def _test_model_params(self,
                         model_name,
                         input_size,
                         expected_params,
                         override_params=None,
                         features_only=False,
                         pooled_features_only=False):
    """Builds `model_name` in training mode and compares its parameter count.

    Args:
      model_name: name passed to `efficientnet_lite_builder.build_model`.
      input_size: side length of the square, 3-channel zero image fed in.
      expected_params: expected sum of the sizes of all trainable variables.
      override_params: forwarded to `build_model`.
      features_only: forwarded to `build_model`.
      pooled_features_only: forwarded to `build_model`.
    """
    dummy_batch = tf.zeros((1, input_size, input_size, 3), dtype=tf.float32)
    efficientnet_lite_builder.build_model(
        dummy_batch,
        model_name=model_name,
        override_params=override_params,
        training=True,
        features_only=features_only,
        pooled_features_only=pooled_features_only)

    variable_sizes = [np.prod(var.shape) for var in tf.trainable_variables()]
    total_params = np.sum(variable_sizes)
    self.assertEqual(total_params, expected_params)

  def test_efficientnet_b0(self):
    """Parameter count of efficientnet-lite0 at 224x224."""
    self._test_model_params(
        model_name='efficientnet-lite0',
        input_size=224,
        expected_params=4652008)

  def test_efficientnet_b1(self):
    """Parameter count of efficientnet-lite1 at 240x240."""
    self._test_model_params(
        model_name='efficientnet-lite1',
        input_size=240,
        expected_params=5416680)

  def test_efficientnet_b2(self):
    """Parameter count of efficientnet-lite2 at 260x260."""
    self._test_model_params(
        model_name='efficientnet-lite2',
        input_size=260,
        expected_params=6092072)

  def test_efficientnet_b3(self):
    """Parameter count of efficientnet-lite3 at 280x280."""
    self._test_model_params(
        model_name='efficientnet-lite3',
        input_size=280,
        expected_params=8197096)

  def test_efficientnet_b4(self):
    """Parameter count of efficientnet-lite4 at 300x300."""
    self._test_model_params(
        model_name='efficientnet-lite4',
        input_size=300,
        expected_params=13006568)


# Run the test cases defined above when this code is executed as the main
# module.
if __name__ == '__main__':
  tf.test.main()

# efficientnet_lite_model_qat_test.py 


In [None]:
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for efficientnet_lite_model_qat."""

from absl.testing import parameterized
import tensorflow as tf
#import tensorflow_model_optimization as tfmot

#from lite import efficientnet_lite_builder
#from lite import efficientnet_lite_model_qat


class EfficientnetLiteModelQatTest(parameterized.TestCase, tf.test.TestCase):
  """Checks the '-qat' functional model against the reference builder."""

  # Every entry is a 1-tuple; 'efficientnet-lite1' used to lack its trailing
  # comma and was therefore a bare string instead of a tuple.
  @parameterized.parameters(('efficientnet-lite0',), ('efficientnet-lite1',),
                            ('efficientnet-lite2',), ('efficientnet-lite3',),
                            ('efficientnet-lite4',))
  def test_values_match(self, model_name):
    """Both builds, seeded identically, must give the same output sum."""
    images = tf.random.stateless_uniform((1, 224, 224, 3), seed=(2, 3))

    # Reset the seed before each build so both get the same initial weights.
    tf.random.set_seed(0)
    outputs, _ = efficientnet_lite_builder.build_model(
        images,
        model_name=model_name,
        override_params=None,
        training=False,
        features_only=False,
        pooled_features_only=False)

    tf.random.set_seed(0)
    outputs_qat, _ = efficientnet_lite_builder.build_model(
        images,
        model_name=model_name + '-qat',
        override_params=None,
        training=False,
        features_only=False,
        pooled_features_only=False)

    self.assertAllClose(tf.reduce_sum(outputs), tf.reduce_sum(outputs_qat))

  @parameterized.parameters(('efficientnet-lite0',), ('efficientnet-lite1',),
                            ('efficientnet-lite2',), ('efficientnet-lite3',),
                            ('efficientnet-lite4',))
  def test_model_quantizable(self, model_name):
    """`quantize_model` must accept the functional model without raising."""
    images = tf.random.stateless_uniform((1, 224, 224, 3), seed=(2, 3))
    # Build with the stock Keras batch norm instead of the default one.
    override_params = {}
    override_params['batch_norm'] = tf.keras.layers.BatchNormalization
    blocks_args, global_params = efficientnet_lite_builder.get_model_params(
        model_name, override_params=override_params)
    model_qat = efficientnet_lite_model_qat.FunctionalModel(
        model_name=model_name,
        blocks_args=blocks_args,
        global_params=global_params,
        features_only=False,
        pooled_features_only=False).get_functional_model(
            training=True, input_shape=images.shape)
    # Any exception is reported as a test failure carrying its message.
    try:
      tfmot.quantization.keras.quantize_model(model_qat)
    except Exception as e:  # pylint: disable=broad-except
      self.fail('Exception raised: %s' % str(e))


# Run the test cases defined above when this code is executed as the main
# module.
if __name__ == '__main__':
  tf.test.main()

# 剛開始設定的初始值

In [None]:
import torch
import torch.nn as nn
from math import ceil


# One row per stage of the baseline (phi = 0) network, in the order
# [expand_ratio, channels, repeats, stride, kernel_size].
base_model = [
    [1, 16, 1, 1, 3],
    [6, 24, 2, 2, 3],
    [6, 40, 2, 2, 5],
    [6, 80, 3, 2, 3],
    [6, 112, 3, 1, 5],
    [6, 192, 4, 2, 5],
    [6, 320, 1, 1, 3],
]

# Per-version settings: version -> (phi, resolution, drop_rate).
phi_values = dict(
    b0=(0, 224, 0.2),
    b1=(0.5, 240, 0.2),
    b2=(1, 260, 0.3),
    b3=(2, 300, 0.3),
    b4=(3, 380, 0.4),
    b5=(4, 456, 0.4),
    b6=(5, 528, 0.5),
    b7=(6, 600, 0.5),
)

# CNNBlock

In [None]:

import torch
import torch.nn as nn
from math import ceil






class CNNBlock(nn.Module):
  """Convolution followed by batch normalization and a SiLU activation."""

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
               groups=1):
    """Creates the three layers.

    Args:
      in_channels: channels of the input of the convolution.
      out_channels: channels produced by the convolution (and normalized).
      kernel_size: convolution kernel size.
      stride: convolution stride.
      padding: convolution padding.
      groups: convolution groups.
    """
    super().__init__()
    self.cnn = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups)
    self.bn = nn.BatchNorm2d(num_features=out_channels)
    self.silu = nn.SiLU()  # SiLU <-> Swish

  def forward(self, x):
    """Returns silu(bn(cnn(x)))."""
    convolved = self.cnn(x)
    normalized = self.bn(convolved)
    return self.silu(normalized)






    

  

# SqueezeExcitation

In [None]:
class SqueezeExcitation(nn.Module):
  """Rescales each channel of the input by a gate computed from the input."""

  def __init__(self, in_channels, reduced_dim):
    """Creates the gating branch.

    Args:
      in_channels: channels of the input (and of the gate).
      reduced_dim: channels of the intermediate 1x1 convolution.
    """
    super().__init__()
    gate_layers = [
        nn.AdaptiveAvgPool2d(1),  # pool every channel to a single value
        nn.Conv2d(in_channels, reduced_dim, 1),
        nn.SiLU(),
        nn.Conv2d(reduced_dim, in_channels, 1),
        nn.Sigmoid(),  # gate values in (0, 1)
    ]
    self.se = nn.Sequential(*gate_layers)

  def forward(self, x):
    """Returns `x` multiplied by its gate."""
    gate = self.se(x)
    return x * gate





  

# InvertedResidualBlock

In [None]:
class InvertedResidualBlock(nn.Module):
  """Expansion conv, depthwise conv, squeeze-excitation and 1x1 projection.

  When the input and output channel counts match and the stride is 1, the
  input is added back to the output, and during training the transformed
  branch is randomly dropped per sample (stochastic depth).
  """

  def __init__(
      self,
      in_channels,
      out_channels,
      kernel_size,
      stride,
      padding,
      expand_ratio,
      reduction=4,
      survival_prob=0.8,

  ):
    """Creates the layers of the block.

    Args:
      in_channels: channels of the block input.
      out_channels: channels of the block output.
      kernel_size: kernel size of the depthwise convolution.
      stride: stride of the depthwise convolution.
      padding: padding of the depthwise convolution.
      expand_ratio: the hidden width is `in_channels * expand_ratio`.
      reduction: the squeeze-excitation bottleneck has
        `int(in_channels / reduction)` channels.
      survival_prob: probability of keeping the transformed branch of a
        sample during training.
    """
    super(InvertedResidualBlock, self).__init__()
    # Honour the caller's value; it used to be pinned to 0.8 regardless of
    # the argument.
    self.survival_prob = survival_prob
    self.use_residual = in_channels == out_channels and stride == 1
    hidden_dim = in_channels * expand_ratio
    # No expansion layer is needed when the hidden width equals the input.
    self.expand = in_channels != hidden_dim
    reduced_dim = int(in_channels / reduction)

    if self.expand:
      # NOTE(review): the expansion uses a 3x3 kernel; a 1x1 expansion is the
      # more common MBConv layout. Left unchanged so that parameter shapes of
      # existing checkpoints stay valid -- confirm the intent.
      self.expand_conv = CNNBlock(
          in_channels, hidden_dim, kernel_size=3, stride=1, padding=1,
      )

    self.conv = nn.Sequential(
        # Depthwise convolution: one group per hidden channel.
        CNNBlock(
            hidden_dim, hidden_dim, kernel_size, stride, padding,
            groups=hidden_dim,
        ),
        SqueezeExcitation(hidden_dim, reduced_dim),
        # Projection to the output width, without activation.
        nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
        nn.BatchNorm2d(out_channels),
    )

  def stochastic_depth(self, x):
    """Randomly zeroes whole samples of `x` during training.

    Kept samples are divided by `survival_prob`; in eval mode `x` is returned
    unchanged.
    """
    if not self.training:
      return x
    # One keep/drop decision per sample, broadcast over the other dimensions.
    binary_tensor = (
        torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob)
    return torch.div(x, self.survival_prob) * binary_tensor

  def forward(self, inputs):
    """Applies the block; adds the skip connection when it is enabled."""
    x = self.expand_conv(inputs) if self.expand else inputs

    if self.use_residual:
      return self.stochastic_depth(self.conv(x)) + inputs
    return self.conv(x)

# Efficient Net

In [None]:
class EfficientNet(nn.Module):
  """Image classifier: stem conv, stages from `base_model`, head and linear.

  The width and depth of the network are scaled from the per-version `phi`
  stored in `phi_values`.
  """

  def __init__(self, version, num_classes):
    """Builds the network.

    Args:
      version: key of `phi_values`, e.g. "b0".
      num_classes: size of the output of the final linear layer.
    """
    super(EfficientNet, self).__init__()
    width_factor, depth_factor, dropout_rate = self.calculate_factors(version)
    last_channels = ceil(1280 * width_factor)
    self.pool = nn.AdaptiveAvgPool2d(1)
    self.features = self.create_features(
        width_factor, depth_factor, last_channels)
    self.classifier = nn.Sequential(
        nn.Dropout(dropout_rate),
        nn.Linear(last_channels, num_classes),
    )

  def calculate_factors(self, version, alpha=1.2, beta=1.1):
    """Returns (width_factor, depth_factor, drop_rate) for `version`.

    The depth factor is `alpha ** phi` and the width factor `beta ** phi`.
    The resolution stored in `phi_values` is not used here.
    """
    phi, res, drop_rate = phi_values[version]
    depth_factor = alpha ** phi
    width_factor = beta ** phi
    return width_factor, depth_factor, drop_rate

  def create_features(self, width_factor, depth_factor, last_channels):
    """Builds the feature extractor as one `nn.Sequential`.

    Args:
      width_factor: multiplier applied to every channel count.
      depth_factor: multiplier applied to every stage's repeat count.
      last_channels: output channels of the final 1x1 convolution.

    Returns:
      The stem, every stage listed in `base_model`, and the 1x1 head conv.
    """
    channels = int(32 * width_factor)
    features = [CNNBlock(3, channels, 3, stride=2, padding=1)]
    in_channels = channels

    for expand_ratio, channels, repeats, stride, kernel_size in base_model:
      # Round the scaled width up to a multiple of 4.
      out_channels = 4 * ceil(int(channels * width_factor) / 4)
      layers_repeats = ceil(repeats * depth_factor)

      for layer in range(layers_repeats):
        features.append(
            InvertedResidualBlock(
                in_channels,
                out_channels,
                expand_ratio=expand_ratio,
                # Only the first block of a stage applies the stage stride.
                stride=stride if layer == 0 else 1,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
            )
        )
        in_channels = out_channels

    # The head is appended once, after ALL stages. It used to sit inside the
    # stage loop together with the return, which truncated the network after
    # the first stage.
    features.append(
        CNNBlock(in_channels, last_channels, kernel_size=1, stride=1, padding=0)
    )
    return nn.Sequential(*features)

  def forward(self, x):
    """Returns the class scores for a batch of images."""
    x = self.pool(self.features(x))
    # Flatten everything but the batch dimension for the classifier.
    return self.classifier(x.view(x.shape[0], -1))
  
        


        



# Evaluation & comparison


In [None]:

# Model variant and problem size for this experiment.
version = "b0"
num_examples, num_classes = 4, 10

net = EfficientNet(version=version, num_classes=num_classes)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
net.to(device)

In [None]:
import torch.optim as optim

# Classification loss applied to the raw outputs of `net`.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over every parameter of `net`.
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

In [None]:


# Resolve the target device once instead of on every batch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Make sure BatchNorm, Dropout and stochastic depth are in training mode, even
# if this cell is re-run after an evaluation cell switched the model to eval.
net.train()

for epoch in range(10):

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss of every 2000 mini-batches.
        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch+1},{i+1:5d}] loss: {running_loss / 2000 :.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:

# Save only the parameters/buffers of the model (its state_dict).
PATH ='./clfar_net.pth'
torch.save(net.state_dict(),PATH)

In [None]:
# Fetch one batch from the test loader first, so that the forward pass below
# runs on test images rather than on whatever `images` held before.
# `next(dataiter)` replaces `dataiter.next()`, which recent PyTorch releases
# no longer provide.
dataiter = iter(testloader)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ' ,''.join(f'{classes[labels[j]]:5s}' for j in range(len(labels))))

# Rebuild the same architecture that was trained (there is no `Net` class in
# this notebook) and restore the saved weights. `map_location='cpu'` lets a
# checkpoint written from a GPU be loaded on a CPU-only machine.
net = EfficientNet(version=version, num_classes=num_classes)
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.eval()  # inference mode for BatchNorm / Dropout / stochastic depth

with torch.no_grad():
    outputs = net(images)

In [None]:
# Per-class counters: correct predictions and total samples seen.
correct_pred = dict.fromkeys(classes, 0)
total_pred = dict.fromkeys(classes, 0)

In [None]:
# Evaluate in inference mode: without this, BatchNorm keeps using batch
# statistics and Dropout / stochastic depth keep dropping activations.
net.eval()
# Feed the inputs on whichever device the model currently lives on.
eval_device = next(net.parameters()).device

with torch.no_grad():
  for data in testloader:
    images, labels = data
    outputs = net(images.to(eval_device))
    # Predicted class = index of the largest output; compare on the CPU,
    # where the labels are.
    _, predictions = torch.max(outputs, 1)
    predictions = predictions.cpu()

    for label, prediction in zip(labels, predictions):
      if label == prediction:
        correct_pred[classes[label]] += 1
      total_pred[classes[label]] += 1


# Report the accuracy of every class as a percentage.
for classname, correct_count in correct_pred.items():
  accuracy = 100 * float(correct_count) / total_pred[classname]
  print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')