In [None]:
#@markdown # Installation

%%bash
pip3 install -qU tensorflow-datasets
pip3 install -qU tensorflow-addons --no-deps
pip3 install -qU jupyter-tensorboard
pip3 install -qU --pre efficientnet

In [None]:
#@markdown # Prepare a package for custom dataset.

!pip install -Uq tensorflow_datasets

def _snake_to_camel(word):
  return ''.join(x.capitalize() or '_' for x in word.split('_'))

NAME_DATASET = 'aigc1110v1' #@param ['aigc1109v2', 'aigc1110v1'] {type: 'string'}
# Class name used for the generated builder (e.g. 'aigc1110v1' -> 'Aigc1110v1').
NAME_DATASET_CAMEL = _snake_to_camel(NAME_DATASET)

import os
# Scaffold the dataset package with the TFDS CLI only on the first run.
# Presumably `tfds new` creates <NAME_DATASET>/<NAME_DATASET>.py, which is
# overwritten with `_code` later in this cell.
if not os.path.exists(NAME_DATASET):
  !tfds new $NAME_DATASET

# Source of the TFDS builder module written to <NAME_DATASET>/<NAME_DATASET>.py
# below. This is an f-string: NAME_DATASET and NAME_DATASET_CAMEL are
# substituted here, so every literal brace in the generated code is doubled.
_code = f'''

import os
import io
import json

from absl import logging
from xml.etree import ElementTree

import tensorflow_datasets as tfds
from tensorflow.io import decode_jpeg, encode_jpeg, extract_jpeg_shape

_DESCRIPTION = """
Trash datasets made by IRIS Lab at 2020/10/19.
"""

_OBJECT_LABELS = [
    'paper',
    'paperpack',
    'can',
    'glass',
    'pet',
    'plastic',
    'vinyl',
]

# os.path.splitext keeps the leading dot, so every extension includes it.
_JPEG_EXTENSIONS = ('.jpeg', '.jpg', '.JPEG', '.JPG')
_PNG_EXTENSIONS = ('.png', '.PNG')
_ANNOTATION_EXTENSIONS = ('.xml', '.json')


class {NAME_DATASET_CAMEL}(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for {NAME_DATASET} dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {{
      '1.0.0': 'Initial release.',
  }}
  MANUAL_DOWNLOAD_INSTRUCTIONS = """\
  manual_dir should contain two files:  {NAME_DATASET}_train.tar and
  {NAME_DATASET}_validation.tar
  """

  def _info(self) -> tfds.core.DatasetInfo:
    """Returns the dataset metadata."""
    annotations = {{
        'label' : tfds.features.ClassLabel(names=_OBJECT_LABELS),
        'bbox'  : tfds.features.BBoxFeature()
    }}
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({{
            'image' : tfds.features.Image(encoding_format='jpeg'),
            'objects': tfds.features.Sequence(annotations),
            'image_id': tfds.features.Text(),
        }}),
        # If there's a common (input, target) tuple from the
        # features, specify them here. They'll be used if
        # `as_supervised=True` in `builder.as_dataset`.
        supervised_keys=None,  # e.g. ('image', 'label')
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Returns SplitGenerators."""
    train_path = os.path.join(dl_manager.manual_dir, '{NAME_DATASET}_train.tar')
    valid_path = os.path.join(dl_manager.manual_dir, '{NAME_DATASET}_validation.tar')

    # Each archive is iterated twice (annotations first, then images), so two
    # independent iterators are passed.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={{
                'archive': dl_manager.iter_archive(train_path),
                're_archive': dl_manager.iter_archive(train_path)
            }},
        ),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs={{
                'archive': dl_manager.iter_archive(valid_path),
                're_archive': dl_manager.iter_archive(valid_path)
            }},
        ),
    ]

  def _parse_xml(self, byte_string):
    """Parses a VOC-style XML file into a list of label/absolute-bbox dicts."""
    root_elem = ElementTree.parse(byte_string)
    obj_elems = root_elem.iter(tag='object')

    annotation = []
    for obj_elem in obj_elems:
      obj_dict = dict()
      label = obj_elem.find('name').text
      bbox_elem = obj_elem.find('bndbox')
      bbox = [int(bbox_elem.find('ymin').text),
              int(bbox_elem.find('xmin').text),
              int(bbox_elem.find('ymax').text),
              int(bbox_elem.find('xmax').text)]
      obj_dict['label'] = label
      obj_dict['bbox_raw'] = bbox
      annotation.append(obj_dict)
    return annotation

  def _parse_json(self, byte_string):
    """Parses a JSON annotation file into a list of label/absolute-bbox dicts."""
    byte_string = byte_string.getvalue()
    _label_conversion_tbl = {{
        'c_1': 'paper',
        'c_2': 'paperpack',
        'c_3': 'can',
        'c_4': 'glass',
        'c_5': 'pet',
        'c_6': 'plastic',
        'c_7': 'vinyl',
    }}
    annot_dict = json.loads(byte_string)
    objects = annot_dict['object']
    annotation = []
    for obj in objects:
      obj_dict = dict()
      obj_dict['label'] = _label_conversion_tbl[obj['label']]
      box = obj['box']
      # (xmin, ymin, xmax, ymax) -> (ymin, xmin, ymax, xmax)
      obj_dict['bbox_raw'] = [box[1], box[0], box[3], box[2]]
      annotation.append(obj_dict)
    return annotation

  def _build_relative_bbox(self, absolute_bbox, shape):
    """Divides (ymin, xmin, ymax, xmax) pixel coords by image height/width."""
    height  = shape[0]
    width = shape[1]
    relative_bbox = []
    relative_bbox.append(absolute_bbox[0] / height)
    relative_bbox.append(absolute_bbox[1] / width)
    relative_bbox.append(absolute_bbox[2] / height)
    relative_bbox.append(absolute_bbox[3] / width)
    return relative_bbox


  def _convert_png_to_jpeg(self, png_image_bytes):
    """Re-encodes raw PNG bytes as JPEG and wraps them in a BytesIO."""
    return io.BytesIO(tfds.core.utils.png_to_jpeg(png_image_bytes))

  def _generate_examples(self, archive, re_archive):
    """Yields (key, example) pairs.

    The tar is read twice because annotations and images may appear in any
    order: pass 1 (archive) collects annotations keyed by image id, and
    pass 2 (re_archive) pairs every image with its annotations.
    """
    failed_ids = set()
    all_annotations = dict()

    # Pass 1: parse every annotation file.
    for fpath, fobj in archive:
      prefix, ext = os.path.splitext(fpath)
      if ext not in _ANNOTATION_EXTENSIONS:
        continue
      image_id = prefix.split(os.path.sep)[-1]
      try:
        fobj_mem = io.BytesIO(fobj.read())
        if ext == '.xml':
          annotations = self._parse_xml(fobj_mem)
        else:
          annotations = self._parse_json(fobj_mem)
        all_annotations[image_id] = annotations
      except Exception:  # Malformed file or unknown label: skip this image.
        failed_ids.add(image_id)
        logging.warning(
            f" Error occurs during parsing '{{fpath}}'. "
             "Skip this file."
        )

    # Pass 2: pair each image with its annotations.
    for fpath, fobj in re_archive:
      prefix, ext = os.path.splitext(fpath)
      if ext in _ANNOTATION_EXTENSIONS:
        continue
      image_id = prefix.split(os.path.sep)[-1]
      if ext in _JPEG_EXTENSIONS:
        if image_id in failed_ids: continue
        image = io.BytesIO(fobj.read())
      elif ext in _PNG_EXTENSIONS:
        if image_id in failed_ids: continue
        logging.info(
            f" Convert '{{fpath}}' to JPEG file."
        )
        # png_to_jpeg takes raw bytes, not a file object.
        image = self._convert_png_to_jpeg(fobj.read())
      else:
        # Skip: falling through would pair this id with the previous
        # iteration's image (or fail on an unbound name).
        logging.warning((
            f' {{ext}} format is not expected.'
            f' {{fpath}} made a error.'))
        continue

      _, fname = os.path.split(fpath)
      annotations = all_annotations.get(image_id)
      if annotations is None:
        logging.warning(
            f" Error occurs. There is no annotation file for '{{fpath}}'."
             "Skip this file."
        )
        continue

      # The shape is needed once per image (not once per object).
      try:  # There are images which raise error with 'extract_jpeg_shape'.
        shape = extract_jpeg_shape(image.getvalue()).numpy()
      except Exception:
        # Decode and re-encode so the stored JPEG is one TF can parse.
        decoded = decode_jpeg(image.getvalue())
        shape = decoded.shape
        image = io.BytesIO(encode_jpeg(decoded).numpy())

      # Build fresh dicts instead of mutating the cached annotations, so a
      # repeated image id cannot hit an already-deleted key.
      objects = [
          {{
              'label': obj['label'],
              'bbox': tfds.features.BBox(
                  *self._build_relative_bbox(obj['bbox_raw'], shape)),
          }}
          for obj in annotations
      ]
      record = {{
          'image': image,
          'objects': objects,
          'image_id': image_id,
      }}
      yield fname, record

'''

# Overwrite the builder module scaffolded by `tfds new` with the generated code.
with open(f'{NAME_DATASET}/{NAME_DATASET}.py', 'w') as f:
  f.write(_code)
  
import importlib
# Import the dataset package so the builder class gets defined (presumably the
# package __init__ created by `tfds new` imports the builder module — confirm).
# NOTE(review): import_module returns the cached module on re-runs, so edits to
# `_code` only take effect after a kernel restart or importlib.reload.
_ = importlib.import_module(NAME_DATASET)


2020-11-16 06:06:15.014801: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Dataset generated at /content/aigc1110v1
You can start searching `TODO(aigc1110v1)` to complete the implementation.
Please check https://www.tensorflow.org/datasets/add_dataset for additional details.


In [None]:
#@markdown # **Hyperparameters & Presets**
#@markdown * BATCH_SIZE_PER_REPLICA is recommended to be a multiple of 64.
#@markdown * RAND_AUGMENT - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/pdf/1909.13719)
#@markdown * USE_BFLOAT16 - Whether to use [mixed precision](https://www.tensorflow.org/guide/mixed_precision) (bfloat16).
 
# Input resolution per EfficientNet variant; MODEL_CODE selects one of these.
_input_sizes = {
  'B0': 224,
  'B1': 240,
  'B2': 260,
  'B3': 300,
  'B4': 380,
  'B5': 456,
  'B6': 528,
  'B7': 600,
}

#@markdown ## **Tensorboard & Checkpoints**
SHOW_TENSORBOARD = False #@param {type: 'boolean'}
WRITE_TB = True #@param {type: 'boolean'}
STORE_CKPT = True #@param {type: 'boolean'}
TOKEN = 'f1_loss_OR0.4_t3_v1110' #@param {type: 'string'}

#@markdown ## **Training Setting**
MODEL_CODE = 'B7' #@param ['B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']
INPUT_SIZE = _input_sizes[MODEL_CODE]
# Pre-crop size: INPUT_SIZE is a 0.875 center crop of RESIZE_SIZE (see CROP_PADDING).
RESIZE_SIZE = int(INPUT_SIZE * (1/0.875))
EPOCHS =  150#@param {type: 'integer'}
BATCH_SIZE_PER_REPLICA = 64 #@param {type: "slider", min:64, max:256, step:64}
# NOTE(review): FIX_* sizes are not used in the visible code; presumably for a
# later fine-tuning stage at a larger resolution — confirm.
FIX_INPUT_SIZE = 660 #@param {type: 'integer'}
FIX_RESIZE_SIZE = int(FIX_INPUT_SIZE * (1/0.875))
#@markdown ## **Loss & Optimizer**
LOSS_FUNC = 'f1_loss' #@param ["binary_crossentropy", "f1_loss", "categorical_crossentropy", "hybrid"]
OPTIMIZER = 'SGD' #@param ["SGD", "RMSprop", "Adam"]
# Learning rate for a global batch of 256; scaled linearly in the init cell.
LR_BASE_VALUE = 0.1 #@param {type: 'number'}
LR_DECAY_TYPE = 'cosine' #@param ["exponential", "cosine", "constant", "poly"]
LR_DECAY_EPOCHS = 10 #@param {type: 'number'}
LR_WARMUP_EPOCHS = 5 #@param {type: 'number'}
LABEL_SMOOTHING = 0.1 #@param {type: 'number'}
#@markdown ## **Augmentation**
RAND_AUGMENT = True #@param {type: 'boolean'}
RA_NUM_LAYERS = 2 #@param {type: "slider", min:1, max:3, step:1}
RA_MAGNITUDE = 13 #@param {type: "slider", min:5, max:30, step:1}
USE_BFLOAT16 = True #@param {type: 'boolean'}
#@markdown ## **VM & GCS Setting**
# The bucket name's '-<region>' suffix is checked against the runtime region below.
BUCKET = 'gs://iris-us' #@param ["gs://iris-us"]
GCP_VM = False #@param {type: 'boolean'}
GCP_TPU_NAME = '' #@param {type: 'string'}
#@markdown ## **Finetuning Options**
OPEN_RATIO =  0.4 #@param {type: 'slider', min: 0, max: 1.0, step: 0.1}

# NOTE(review): SEED and EPOCH are not referenced in the visible code; confirm
# SEED is applied to the RNGs and that EPOCH is distinct from EPOCHS.
SEED = 8047
EPOCH = 1

import os
import json
import requests

# Check current region: the runtime's timezone continent must match the
# bucket's region suffix (e.g. 'gs://iris-us' -> 'us' -> 'America').
_regions = {
    'us': 'America',
    'eu': 'Europe',
}
# A timeout prevents the cell from hanging indefinitely if ipinfo.io is
# unreachable; raise_for_status surfaces HTTP errors instead of a JSON error.
res = requests.get('http://ipinfo.io', timeout=10)
res.raise_for_status()
info = res.json()
region = info['timezone'].split('/')[0]
if region != _regions[BUCKET.split('-')[-1]]:
  raise Exception('[WARNING] Region setting is wrong.')

# Set GCS TFDS path.
DATA_DIR = os.path.join(BUCKET, 'tfds_datasets')

# Checkpoint path.
CKPT_PATH = os.path.join(BUCKET, 'ai', 'checkpoints', TOKEN)

# Set TPU name. (None means auto-detection)
TPU_NAME = GCP_TPU_NAME if GCP_VM else None

# Class names in label-id order. Keep identical (same order) to _OBJECT_LABELS
# in the generated dataset builder: ClassLabel ids are positional.
OBJECT_LABELS = [
    'paper',
    'paperpack',
    'can',
    'glass',
    'pet',
    'plastic',
    'vinyl',
]


In [None]:
#@markdown  # Import packages

import sys
import time
import inspect
import functools

import pickle
import json

from IPython.display import display
from PIL import Image
import matplotlib.pyplot as plt

import math
import numpy as np

from tqdm import tqdm

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from tensorflow.keras.applications.imagenet_utils import preprocess_input
from tensorflow.keras.mixed_precision import experimental as mixed_precision
import efficientnet.tfkeras as efn

from sklearn.metrics import f1_score

# Outside a GCP VM, `initialize()` uses this module to configure GCS auth.
if not GCP_VM:
  import tensorflow_gcs_config as tgc

# Global Keras policy: bfloat16 compute with float32 variables. It applies to
# every model built after this cell runs.
if USE_BFLOAT16:
  policy = mixed_precision.Policy('mixed_bfloat16')
  mixed_precision.set_policy(policy)



## TPU & GCS

In [None]:
def initialize():
  """Sets up credentials, connects to the TPU and returns a TPUStrategy.

  When not running on a GCP VM (presumably Colab), a key file is written to
  'my_key.json' and exported via GOOGLE_APPLICATION_CREDENTIALS. After the
  TPU is initialized, GCS access is configured from the Colab auth.

  Returns:
    A `tf.distribute.TPUStrategy` for TPU_NAME (None means auto-detection).
  """
  outside_gcp = not GCP_VM

  if outside_gcp:
    key_path = 'my_key.json'
    # NOTE(review): never commit real credentials with the notebook.
    key_data = {
      # GCS Auth key data.
    }
    with open(key_path, mode='w') as key_file:
      json.dump(key_data, key_file)
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = key_path

  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(TPU_NAME)
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  tpu_strategy = tf.distribute.TPUStrategy(tpu_resolver)

  if outside_gcp:
    tgc.configure_gcs_from_colab_auth()

  return tpu_strategy

## Initialize systems.

In [None]:
strategy = initialize()
# Global batch size summed over all TPU replicas.
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
# Linear LR scaling: LR_BASE_VALUE is the rate for a global batch of 256.
INITIAL_LR = LR_BASE_VALUE * (BATCH_SIZE / 256)

INFO:tensorflow:Initializing the TPU system: grpc://10.77.58.122:8470


INFO:tensorflow:Initializing the TPU system: grpc://10.77.58.122:8470


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Clearing out eager caches


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Finished initializing TPU system.


INFO:tensorflow:Found TPU system:


INFO:tensorflow:Found TPU system:


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Cores: 8


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Workers: 1


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Num TPU Cores Per Worker: 8


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)


# **Functions**

## Data Processing

### RandAugment

In [None]:
def autocontrast(image):
  """Implements Autocontrast function from PIL using TF ops.

  Each of the three channels is stretched independently so that its minimum
  maps to 0 and its maximum to 255. A channel whose values are all equal is
  returned unchanged.

  Args:
    image: A 3D uint8 tensor.
  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def _stretch(channel):
    """Linearly rescales one 2D channel so it spans [0, 255]."""
    low = tf.cast(tf.reduce_min(channel), 'float32')
    high = tf.cast(tf.reduce_max(channel), 'float32')

    def _rescale(values):
      gain = 255.0 / (high - low)
      bias = -low * gain
      stretched = tf.cast(values, 'float32') * gain + bias
      return tf.cast(tf.clip_by_value(stretched, 0.0, 255.0), tf.uint8)

    return tf.cond(high > low, lambda: _rescale(channel), lambda: channel)

  # Assumes RGB: stretch the three channels separately, then restack them.
  channels = [_stretch(image[:, :, idx]) for idx in range(3)]
  return tf.stack(channels, 2)

def equalize(image):
  """Equalizes the image histogram (delegates to `tfa.image.equalize`)."""
  equalized = tfa.image.equalize(image)
  return equalized

def invert(image):
  """Inverts pixel intensities of a uint8 image: each value v becomes 255 - v."""
  inverted = 255 - image
  return inverted

def wrap(image):
  """Prepares `image` for a geometric op followed by `unwrap`.

  Delegates to `tfa.image.utils.wrap`.
  """
  wrapped = tfa.image.utils.wrap(image)
  return wrapped

def unwrap(image, replace):
  """Reverses `wrap`; `replace` fills pixels left empty by the geometric op.

  Delegates to `tfa.image.utils.unwrap`.
  """
  restored = tfa.image.utils.unwrap(image, replace)
  return restored

def rotate(image, degrees, replace):
  """Rotates `image` by `degrees`, filling exposed pixels with `replace`.

  Args:
    image: An image Tensor of type uint8.
    degrees: Float, a scalar angle in degrees. The direction follows
      `tfa.image.rotate`, which documents positive angles as counterclockwise.
      `_rotate_level_to_arg` randomizes the sign, so either direction occurs.
    replace: A one or three value 1D tensor to fill empty pixels caused by
      the rotate operation.
  Returns:
    The rotated version of image.
  """
  radians = degrees * (math.pi / 180.0)
  rotated = tfa.image.rotate(wrap(image), radians)
  return unwrap(rotated, replace)

def blend(image1, image2, factor):
  """Blends `image1` and `image2` with `tfa.image.blend`; result is cast to uint8."""
  mixed = tfa.image.blend(image1, image2, factor)
  return tf.cast(mixed, 'uint8')

def posterize(image, bits):
  """Equivalent of PIL Posterize: keeps only the `bits` most significant bits."""
  shift = 8 - bits
  # Shifting right then left zeroes the low (8 - bits) bits of every pixel.
  truncated = tf.bitwise.right_shift(image, shift)
  return tf.bitwise.left_shift(truncated, shift)

def solarize(image, threshold=128):
  """Inverts every pixel whose value is at or above `threshold`.

  Pixels below the threshold are kept; all others become 255 - pixel.
  """
  below = image < threshold
  return tf.where(below, image, 255 - image)

def color(image, factor):
  """Equivalent of PIL Color: blends `image` with its own grayscale version."""
  gray = tf.image.rgb_to_grayscale(image)
  degenerate = tf.image.grayscale_to_rgb(gray)
  return blend(degenerate, image, factor)

def contrast(image, factor):
  """Equivalent of PIL Contrast.

  Blends `image` with a constant gray image whose value is the (rounded) mean
  grayscale intensity of `image`.

  Args:
    image: An RGB image Tensor of type uint8.
    factor: Blend factor forwarded to `blend`.
  Returns:
    The contrast-adjusted image as a uint8 Tensor.
  """
  degenerate = tf.image.rgb_to_grayscale(image)
  # The previous version summed the histogram counts and divided by 256. That
  # yields num_pixels / 256, not the mean intensity, and for large images it
  # clips to 255, so the blend target was always white. Use the real mean,
  # rounded like PIL's ImageEnhance.Contrast.
  mean = tf.round(tf.reduce_mean(tf.cast(degenerate, tf.float32)))
  degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)

def brightness(image, factor):
  """Equivalent of PIL Brightness: blends `image` with an all-zero (black) image."""
  black = tf.zeros_like(image)
  return blend(black, image, factor)

def sharpness(image, factor):
  """Adjusts image sharpness by `factor` (delegates to `tfa.image.sharpness`)."""
  sharpened = tfa.image.sharpness(image, factor)
  return sharpened

def shear_x(image, level, replace):
  """Shears along x by `level` via `tfa.image.shear_x`; `replace` is the fill value."""
  sheared = tfa.image.shear_x(image, level, replace)
  return sheared

def shear_y(image, level, replace):
  """Shears along y by `level` via `tfa.image.shear_y`; `replace` is the fill value."""
  sheared = tfa.image.shear_y(image, level, replace)
  return sheared

def translate_x(image, pixels, replace):
  """Equivalent of PIL Translate in X dimension."""
  # tfa.image.translate takes [dx, dy]; only the x component is non-zero.
  offset = [-pixels, 0]
  shifted = tfa.image.translate(wrap(image), offset)
  return unwrap(shifted, replace)


def translate_y(image, pixels, replace):
  """Equivalent of PIL Translate in Y dimension."""
  # tfa.image.translate takes [dx, dy]; only the y component is non-zero.
  offset = [0, -pixels]
  shifted = tfa.image.translate(wrap(image), offset)
  return unwrap(shifted, replace)

def cutout(image, pad_size, replace=0):
  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.

  A square patch of side at most 2*pad_size is placed around a uniformly
  random pixel. It is clipped at the image borders and filled with `replace`.

  Args:
    image: An image Tensor of type uint8 with 3 channels.
    pad_size: Half the side length of the patch.
    replace: Pixel value written into the patch.
  Returns:
    An image Tensor that is of type uint8.
  """
  height = tf.shape(image)[0]
  width = tf.shape(image)[1]

  # Random centre of the patch (row first, then column).
  center_y = tf.random.uniform(
      shape=[], minval=0, maxval=height,
      dtype=tf.int32)
  center_x = tf.random.uniform(
      shape=[], minval=0, maxval=width,
      dtype=tf.int32)

  # How much of the image stays untouched on each side of the patch.
  pad_top = tf.maximum(0, center_y - pad_size)
  pad_bottom = tf.maximum(0, height - center_y - pad_size)
  pad_left = tf.maximum(0, center_x - pad_size)
  pad_right = tf.maximum(0, width - center_x - pad_size)

  patch_shape = [height - (pad_top + pad_bottom),
                 width - (pad_left + pad_right)]
  # Mask is 0 inside the patch and 1 elsewhere.
  mask = tf.pad(
      tf.zeros(patch_shape, dtype=image.dtype),
      [[pad_top, pad_bottom], [pad_left, pad_right]],
      constant_values=1)
  mask = tf.tile(mask[:, :, tf.newaxis], [1, 1, 3])
  fill = tf.ones_like(image, dtype=image.dtype) * replace
  return tf.where(tf.equal(mask, 0), fill, image)

def solarize_add(image, addition=0, threshold=128):
  """Adds `addition` to every pixel below `threshold`, clipping to [0, 255].

  Pixels at or above the threshold are returned unchanged. `addition` is
  expected to lie in [-128, 128].
  """
  # Widen to int64 so the addition cannot wrap around before clipping.
  shifted = tf.cast(image, tf.int64) + addition
  shifted = tf.cast(tf.clip_by_value(shifted, 0, 255), tf.uint8)
  return tf.where(image < threshold, shifted, image)

In [None]:
# Reference scale for RandAugment levels; every level is divided by this value.
_MAX_LEVEL = 10.0

# RandAugment op name -> implementing function.
NAME_TO_FUNC = dict(
    AutoContrast=autocontrast,
    Equalize=equalize,
    Invert=invert,
    Rotate=rotate,
    Posterize=posterize,
    Solarize=solarize,
    SolarizeAdd=solarize_add,
    Color=color,
    Contrast=contrast,
    Brightness=brightness,
    Sharpness=sharpness,
    ShearX=shear_x,
    ShearY=shear_y,
    TranslateX=translate_x,
    TranslateY=translate_y,
    Cutout=cutout,
)

def _randomly_negate_tensor(tensor):
  """Returns `tensor` or `-tensor`, each with probability 0.5."""
  keep_sign = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
  return tf.cond(keep_sign, lambda: tensor, lambda: -tensor)

def _translate_level_to_arg(level, translate_const):
  """Maps `level` to a 1-tuple holding a pixel offset with a random sign.

  The magnitude is (level / _MAX_LEVEL) * translate_const.
  """
  offset = level / _MAX_LEVEL * float(translate_const)
  return (_randomly_negate_tensor(offset),)

def _rotate_level_to_arg(level):
  """Maps `level` to a 1-tuple holding an angle in degrees with a random sign.

  The magnitude is (level / _MAX_LEVEL) * 30.
  """
  angle = level / _MAX_LEVEL * 30.
  return (_randomly_negate_tensor(angle),)

def _enhance_level_to_arg(level):
  """Maps `level` linearly to a blend factor (0.1 at level 0, 1.9 at _MAX_LEVEL)."""
  factor = level / _MAX_LEVEL * 1.8 + 0.1
  return (factor,)

def _shear_level_to_arg(level):
  """Maps `level` to a 1-tuple holding a shear amount with a random sign.

  The magnitude is (level / _MAX_LEVEL) * 0.3.
  """
  shear = level / _MAX_LEVEL * 0.3
  return (_randomly_negate_tensor(shear),)

def level_to_arg(hparams):
  """Builds the table mapping each op name to a `level -> args tuple` function.

  Args:
    hparams: Dict with 'cutout_const' and 'translate_const'. Both are read
      lazily, i.e. only when the corresponding function is called.
  Returns:
    Dict with the same keys as NAME_TO_FUNC.
  """
  def no_args(level):
    return ()

  def int_scaled(max_value):
    return lambda level: (int((level / _MAX_LEVEL) * max_value),)

  def cutout_args(level):
    return (int((level / _MAX_LEVEL) * hparams['cutout_const']),)

  def translate_args(level):
    return _translate_level_to_arg(level, hparams['translate_const'])

  return {
      'AutoContrast': no_args,
      'Equalize': no_args,
      'Invert': no_args,
      'Rotate': _rotate_level_to_arg,
      'Posterize': int_scaled(4),
      'Solarize': int_scaled(256),
      'SolarizeAdd': int_scaled(110),
      'Color': _enhance_level_to_arg,
      'Contrast': _enhance_level_to_arg,
      'Brightness': _enhance_level_to_arg,
      'Sharpness': _enhance_level_to_arg,
      'ShearX': _shear_level_to_arg,
      'ShearY': _shear_level_to_arg,
      'Cutout': cutout_args,
      'TranslateX': translate_args,
      'TranslateY': translate_args,
  }


def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
  """Resolves op `name` into a (function, prob, positional args) triple.

  The args come from `level` via `level_to_arg`. `prob` is prepended when the
  function declares a 'prob' parameter. `replace_value` is appended when it
  declares a 'replace' parameter, which must be its last one.
  """
  func = NAME_TO_FUNC[name]
  args = level_to_arg(augmentation_hparams)[name](level)
  param_names = inspect.getfullargspec(func)[0]

  if 'prob' in param_names:
    args = (prob,) + tuple(args)

  # pytype:disable=wrong-arg-types
  if 'replace' in param_names:
    # `replace` is appended positionally, so it must be the final parameter.
    assert 'replace' == param_names[-1]
    args = tuple(args) + (replace_value,)
  return (func, prob, args)


def distort_image_with_randaugment(image, num_layers, magnitude):
  """Applies the RandAugment policy to `image`.
  RandAugment is from the paper https://arxiv.org/abs/1909.13719,
  Each of the `num_layers` layers picks one op uniformly at random from
  `available_ops` and applies it with the shared `magnitude`.
  Args:
    image: `Tensor` of shape [height, width, 3] representing an image.
    num_layers: Integer, the number of augmentation transformations to apply
      sequentially to an image. Represented as (N) in the paper. Usually best
      values will be in the range [1, 3].
    magnitude: Integer, shared magnitude across all augmentation operations.
      Represented as (M) in the paper. Usually best values are in the range
      [5, 30].
  Returns:
    The augmented version of `image`.
  """
  # Gray fill value for pixels exposed by geometric ops (rotate/shear/translate).
  replace_value = [128] * 3

  # Pixel scales for Cutout and Translate levels (see level_to_arg).
  augmentation_hparams = dict()
  augmentation_hparams['cutout_const'] = 40
  augmentation_hparams['translate_const'] = 100
  
  available_ops = [
      'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize',
      'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness',
      'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd']

  for layer_num in range(num_layers):
    # Index of the single op this layer will actually run.
    op_to_select = tf.random.uniform(
        [], maxval=len(available_ops), dtype=tf.int32)
    random_magnitude = float(magnitude)
    with tf.name_scope('randaug_layer_{}'.format(layer_num)):
      # Every op is built, but tf.cond runs only the one matching op_to_select.
      for (i, op_name) in enumerate(available_ops):
        # NOTE(review): no function in NAME_TO_FUNC takes a 'prob' argument,
        # so this sampled prob is effectively unused.
        prob = tf.random.uniform([], minval=0.2, maxval=0.8, dtype=tf.float32)
        func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
                                           replace_value, augmentation_hparams)
        # Default arguments bind this iteration's func/args (avoids the
        # late-binding closure pitfall inside the loop).
        image = tf.cond(
            tf.equal(i, op_to_select),
            lambda selected_func=func, selected_args=args: selected_func(
                image, *selected_args),
            lambda: image)
  return image

### Preprocessing functions for JPEG: Image Process

In [None]:
# Extra pixels around the center crop. _decode_and_center_crop keeps a square
# of image_size / (image_size + CROP_PADDING) of the shorter side, which is
# about 0.875 when image_size == INPUT_SIZE.
CROP_PADDING = RESIZE_SIZE - INPUT_SIZE

def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.
  See `tf.image.sample_distorted_bounding_box` for more documentation.
  Args:
    image_bytes: `Tensor` of binary image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within in this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
    scope: Optional `str` for name scope. Defaults to
        'distorted_bounding_box_crop'.
  Returns:
    cropped image `Tensor` (uint8, 3 channels)
  """
  # `scope` used to be accepted but silently ignored; honour it when given.
  with tf.name_scope(scope or 'distorted_bounding_box_crop'):
    shape = tf.image.extract_jpeg_shape(image_bytes)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Decode only the sampled crop window of the JPEG.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)

    return image

def _at_least_x_are_equal(a, b, x):
  """Returns a scalar bool Tensor: True iff at least `x` elements of `a` and `b` match."""
  num_matches = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
  return tf.greater_equal(num_matches, x)


def _resize_image(image, image_size, method=None):
  """Resizes `image` to `image_size` x `image_size`; `method` defaults to bicubic."""
  resize_method = 'bicubic' if method is None else method
  return tf.image.resize(image, [image_size, image_size], resize_method)


def _decode_and_center_crop(image_bytes, image_size, resize_method=None):
  """Crops to center of image with padding then scales image_size.

  The square crop side is image_size / (image_size + CROP_PADDING) times the
  shorter image side. Only that window is decoded, and it is then resized.
  """
  jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = jpeg_shape[0], jpeg_shape[1]

  crop_fraction = image_size / (image_size + CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * shorter_side, tf.int32)

  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2
  window = tf.stack([top, left, crop_size, crop_size])
  cropped = tf.io.decode_and_crop_jpeg(image_bytes, window, channels=3)
  return _resize_image(cropped, image_size, resize_method)


def _decode_and_random_crop(image_bytes, image_size, resize_method=None):
  """Make a random crop of image_size.

  If the sampled crop turns out to be the whole image, falls back to a
  center crop instead.

  Args:
    image_bytes: `Tensor` of JPEG-encoded bytes.
    image_size: Side length of the square output.
    resize_method: Method for `tf.image.resize`; None means bicubic. It is
      used on both the random-crop and the center-crop fallback paths.
  Returns:
    A [image_size, image_size, 3] image tensor.
  """
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10,
      scope=None)
  original_shape = tf.io.extract_jpeg_shape(image_bytes)
  # Compare height and width only. The crop is always decoded with 3 channels,
  # so including the channel dim would miss full-image crops of grayscale
  # (1-channel) JPEGs.
  bad = _at_least_x_are_equal(original_shape[:2], tf.shape(image)[:2], 2)

  image = tf.cond(
      bad,
      lambda: _decode_and_center_crop(image_bytes, image_size, resize_method),
      lambda: _resize_image(image, image_size, resize_method))

  return image


def _flip(image):
  """Random horizontal image flip (left-right, probability 0.5)."""
  return tf.image.random_flip_left_right(image)

### Preprocessing function for JPEG: Utils

In [None]:
def preprocess_for_train(image_bytes,
                         use_bfloat16=USE_BFLOAT16,
                         image_size=INPUT_SIZE,
                         randaug_num_layers=None,
                         randaug_magnitude=None,
                         resize_method=None):
  """Training preprocessing: random crop, random horizontal flip, RandAugment.

  Args:
    image_bytes: scalar string tensor holding JPEG bytes.
    use_bfloat16: cast the result to bfloat16 instead of float32.
    image_size: output side length.
    randaug_num_layers: RandAugment number of layers.
    randaug_magnitude: RandAugment magnitude.
    resize_method: interpolation method (bicubic when None).

  Returns:
    A `[image_size, image_size, 3]` tensor cast from the uint8 RandAugment
    output, i.e. values in [0, 255].
  """
  image = _decode_and_random_crop(image_bytes, image_size, resize_method)
  image = _flip(image)
  image = tf.reshape(image, [image_size, image_size, 3])

  # Bicubic resizing can overshoot [0, 255]; clamp before the uint8 cast that
  # RandAugment's input requires.
  image = tf.clip_by_value(image, 0.0, 255.0)
  image = tf.cast(image, dtype=tf.uint8)
  image = distort_image_with_randaugment(image, randaug_num_layers,
                                         randaug_magnitude)
  image = tf.cast(image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
  return image


def preprocess_for_eval(image_bytes,
                        use_bfloat16=USE_BFLOAT16,
                        image_size=INPUT_SIZE,
                        resize_method=None):
  """Evaluation preprocessing: deterministic centre crop, no augmentation.

  Returns a `[image_size, image_size, 3]` tensor in bfloat16 or float32.
  """
  out_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  cropped = _decode_and_center_crop(image_bytes, image_size, resize_method)
  cropped = tf.reshape(cropped, [image_size, image_size, 3])
  return tf.cast(cropped, dtype=out_dtype)


def preprocess_image(image_bytes,
                     is_training=False,
                     use_bfloat16=USE_BFLOAT16,
                     image_size=INPUT_SIZE,
                     randaug_num_layers=None,
                     randaug_magnitude=None,
                     resize_method=None):
  """Dispatch JPEG bytes to the training or evaluation preprocessing path.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    is_training: `bool`; True selects random crop + flip + RandAugment,
      False selects a deterministic centre crop.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.
    randaug_num_layers: 'int', number of RandAugment layers (training only).
    randaug_magnitude: 'int', RandAugment magnitude (training only).
    resize_method: 'string' or None. Bicubic resizing is used when None.
  Returns:
    A preprocessed image `Tensor` with value range of [0, 255].
  """
  if not is_training:
    return preprocess_for_eval(image_bytes, use_bfloat16, image_size,
                               resize_method)
  return preprocess_for_train(image_bytes, use_bfloat16, image_size,
                              randaug_num_layers, randaug_magnitude,
                              resize_method)

### Normal Augmentation.

In [None]:
def preprocess_normal(x, is_train, use_bfloat16, image_size):
  """Non-RandAugment preprocessing for an already-decoded image.

  The image is resized so its shorter side is `image_size / 0.875`, keeping
  the aspect ratio. Training takes a random `image_size` crop followed by
  flips and colour jitter; evaluation takes a centre crop/pad.

  Args:
    x: decoded image tensor of shape (height, width, 3).
    is_train: Python bool selecting the training path.
    use_bfloat16: cast the result to bfloat16 instead of float32.
    image_size: output side length.

  Returns:
    A `[image_size, image_size, 3]` tensor.
  """
  def normal_augmentation(x):
    x = tf.image.random_flip_left_right(x, seed=SEED)
    x = tf.image.random_flip_up_down(x, seed=SEED)
    # NOTE(review): x appears to be on a [0, 255] scale here (resize of a
    # decoded uint8 image), so a +/-0.4 brightness delta is nearly a no-op —
    # confirm whether ~0.4*255 was intended.
    x = tf.image.random_brightness(x, 0.4, seed=SEED)
    x = tf.image.random_contrast(x, 0.6, 1.4, seed=SEED)
    x = tf.image.random_saturation(x, 0.6, 1.4, seed=SEED)
    return x

  dtype = 'bfloat16' if use_bfloat16 else 'float32'
  shape = tf.shape(x)
  h = tf.cast(shape[0], tf.float32)
  w = tf.cast(shape[1], tf.float32)
  resized_size = int(image_size * (1/0.875))

  # Scale so the shorter side becomes resized_size and the longer side keeps
  # the true aspect ratio. The previous integer `h//w` collapsed most
  # ratios (e.g. 4:3) to 1 and squashed the image into a square.
  scale = resized_size / tf.minimum(h, w)
  target = tf.cast(tf.math.round(tf.stack([h, w]) * scale), tf.int32)

  x = tf.image.resize(x, target, antialias=True)
  x = tf.image.random_crop(x, [image_size, image_size, 3]) if is_train else tf.image.resize_with_crop_or_pad(x, image_size, image_size)
  x = normal_augmentation(x) if is_train else x
  x = tf.cast(x, dtype)
  return x


### Loader function using TFDS

In [None]:
from tensorflow.keras.applications.imagenet_utils import preprocess_input

def labels_to_multihot(labels):
  """Collapse a vector of class ids into a 7-way multi-hot float vector.

  Duplicate ids still produce 1.0 (counts are clipped). Presumably `labels`
  is 1-D (one id per object) — confirm against the TFDS feature spec.
  """
  one_hot = tf.one_hot(labels, depth=7)
  per_class_counts = tf.reduce_sum(one_hot, axis=0)
  return tf.minimum(per_class_counts, 1.0)

def apply_as_supervised(combined_data):
  """Map a TFDS example dict to an `(image, multi_hot_labels)` pair."""
  return (combined_data['image'],
          labels_to_multihot(combined_data['objects']['label']))

def decode_multihot(multihot):
  """Render a multi-hot vector as space-separated class names from OBJECT_LABELS."""
  values = multihot.numpy() if isinstance(multihot, tf.Tensor) else multihot
  active = np.where(values != 0.0)[0]
  return ' '.join(OBJECT_LABELS[index] for index in active)

def load_data(data_dir,
              split,
              is_fix=False,
              crop_ratio=None,
              is_normalize=True,
              strategy=None):
  """Build a batched, prefetched `tf.data` pipeline for one split of NAME_DATASET.

  Side effect: rebinds the module-level CROP_PADDING, which
  `_decode_and_center_crop` reads while this pipeline is traced.

  Args:
    data_dir: TFDS data directory.
    split: TFDS split name; 'train' enables shuffling, augmentation and
      `drop_remainder` batching.
    is_fix: use FIX_INPUT_SIZE / FIX_RESIZE_SIZE instead of INPUT_SIZE / RESIZE_SIZE.
    crop_ratio: if truthy, apply `tf.image.central_crop` with this fraction after batching.
    is_normalize: apply Keras `preprocess_input(mode='torch')` after batching.
    strategy: accepted for API compatibility; not used here.

  Returns:
    A dataset of `(images, multi_hot_labels)` batches of size BATCH_SIZE.
  """
  global CROP_PADDING
  is_train = split == 'train'
  autotune = tf.data.experimental.AUTOTUNE

  if not is_fix:
    input_size = INPUT_SIZE
    CROP_PADDING = RESIZE_SIZE - INPUT_SIZE
  else:
    input_size = FIX_INPUT_SIZE
    CROP_PADDING = FIX_RESIZE_SIZE - FIX_INPUT_SIZE

  if RAND_AUGMENT:
    # RandAugment path works on raw JPEG bytes, so skip TFDS image decoding.
    def preprocess(image, label):
      return preprocess_image(image, is_training=is_train,
                              use_bfloat16=USE_BFLOAT16,
                              image_size=input_size,
                              randaug_num_layers=RA_NUM_LAYERS,
                              randaug_magnitude=RA_MAGNITUDE), label
    load_kwargs = {'decoders': {'image': tfds.decode.SkipDecoding()}}
  else:
    def preprocess(image, label):
      return preprocess_normal(image, is_train=is_train,
                               use_bfloat16=USE_BFLOAT16,
                               image_size=input_size), label
    load_kwargs = {}

  data = tfds.load(NAME_DATASET, data_dir=data_dir, split=split, **load_kwargs)
  if is_train:
    data = data.shuffle(10000)
  data = data.map(apply_as_supervised, autotune)
  data = data.map(preprocess, autotune)
  data = data.batch(BATCH_SIZE, drop_remainder=is_train)
  if is_normalize:
    data = data.map(lambda x, y: (preprocess_input(x, mode='torch'), y), autotune)
  if crop_ratio:
    data = data.map(lambda x, y: (tf.image.central_crop(x, crop_ratio), y), autotune)
  return data.prefetch(autotune)

## Visualization

In [None]:
def show_image(value, name=None):
  """Display an image (or the first image of a batch) with matplotlib.

  bfloat16 input is cast to float32, because matplotlib cannot draw bfloat16.
  Any floating-point input whose maximum exceeds 2 is treated as a [0, 255]
  image and divided by 255. Without this, `imshow` clips float data to
  [0, 1], so an unnormalised float image would render white.

  Args:
    value: `tf.Tensor` or NumPy array, shape (H, W, C) or (N, H, W, C).
    name: optional figure title.
  """
  if value.dtype == tf.bfloat16:
    value = tf.cast(value, 'float32')

  if isinstance(value, tf.Tensor):
    value = value.numpy()

  # Previously this rescale only ran for bfloat16 input; apply it to every
  # float dtype so float32 images in [0, 255] display correctly too.
  if value.dtype.kind == 'f' and value.max() > 2:
    value = value / 255.

  if len(value.shape) == 4:
    value = value[0]

  plt.figure(figsize=(13,13))
  plt.title(name)
  plt.imshow(value)

## Tensorboard

In [None]:
def write_hist_weight(writer, model, epoch):
  """Log one histogram per model weight to `writer` at step `epoch`."""
  with writer.as_default():
    for variable in model.weights:
      tf.summary.histogram(variable.name, variable, step=epoch)
  
def write_scalar(writer, value, epoch, name):
  """Log a single scalar `value` under tag `name` at step `epoch`."""
  with writer.as_default():
    tf.summary.scalar(name, value, step=epoch)

# **Main**

## Create file writers for TensorBoard.

In [None]:
# Create one TensorBoard file writer per split, under
# <BUCKET>/ai/tensorboard/<TOKEN>/{train,validation,test}.
if WRITE_TB:
  train_writer, valid_writer, test_writer = [
      tf.summary.create_file_writer(
          os.path.join(BUCKET, 'ai', 'tensorboard', TOKEN, subdir))
      for subdir in ('train', 'validation', 'test')]

## Load the train, validation, and test datasets.

In [None]:
# Build the three input pipelines; only 'train' is shuffled and augmented.
train_data, valid_data, test_data = [
    load_data(DATA_DIR, split=split_name, strategy=strategy)
    for split_name in ('train', 'validation', 'test')]


Instructions for updating:
`seed2` arg is deprecated.Use sample_distorted_bounding_box_v2 instead.


Instructions for updating:
`seed2` arg is deprecated.Use sample_distorted_bounding_box_v2 instead.


## Build a model and define functions for training.

### Functions for Distributed Training

In [None]:
def _distribute_tensors(ctx, tensor):
  """Return the contiguous slice of `tensor` (along axis 0) owned by this replica.

  The global batch is cut into chunks of `ceil(batch_size / num_replicas)`.
  Replica `r` gets chunk `r`, so trailing replicas may get a short or empty
  slice when the batch does not divide evenly.

  Args:
    ctx: a `tf.distribute.ValueContext` (reads `replica_id_in_sync_group`
      and `num_replicas_in_sync`).
    tensor: batch-major tensor / array with a static leading dimension.
  """
  # Use the real replica count; it was previously hard-coded to 8, which is
  # wrong on any strategy that does not have exactly 8 replicas.
  num_replicas = ctx.num_replicas_in_sync
  idx = ctx.replica_id_in_sync_group
  batch_size = tensor.shape[0]
  k = int(math.ceil(batch_size / float(num_replicas)))
  return tensor[idx * k:(idx + 1) * k]

def make_tensors_per_replica(tensors, strategy):
  """Split a global batch into a distributed value, one slice per replica."""
  def _slice_for_replica(ctx):
    return _distribute_tensors(ctx, tensors)
  return strategy.experimental_distribute_values_from_function(_slice_for_replica)

### Loss functions


In [None]:
def f1_loss_with_logits(labels, predictions, from_logits=False):
  """Soft (differentiable) per-example F1 loss: `1 - F1`.

  True/false positives and false negatives are computed from probabilities
  rather than thresholded predictions, summed over the last axis.

  Args:
    labels: multi-hot targets, same shape as `predictions`.
    predictions: probabilities, or logits when `from_logits` is True.
    from_logits: apply a sigmoid to `predictions` first.

  Returns:
    Tensor of per-example losses (last axis reduced).
  """
  if from_logits:
    predictions = tf.nn.sigmoid(predictions)
  # Keeps precision/recall/F1 finite when a denominator is zero.
  epsilon = 1e-7

  tp = tf.reduce_sum(labels * predictions, axis=-1)
  fp = tf.reduce_sum((1 - labels) * predictions, axis=-1)
  fn = tf.reduce_sum(labels * (1 - predictions), axis=-1)
  precision = tp / (tp + fp + epsilon)
  recall = tp / (tp + fn + epsilon)

  f1 = 2 * precision * recall / (precision + recall + epsilon)
  return 1 - f1

def hybrid_loss(labels, predictions, from_logits=False):
  """Per-example `2 * soft-F1 loss + binary cross-entropy` (with LABEL_SMOOTHING)."""
  f1_term = f1_loss_with_logits(labels, predictions, from_logits=from_logits)
  bce_term = tf.keras.losses.binary_crossentropy(
      labels, predictions,
      from_logits=from_logits,
      label_smoothing=LABEL_SMOOTHING)
  return 2.0 * f1_term + bce_term

# Per-example loss functions selectable via LOSS_FUNC; all expect logits.
# Keyword arguments are required: newer TF versions insert `axis` before
# `reduction` in the Keras loss constructors, so a positional `reduction`
# would silently be taken as `axis`.
_loss_functions = {
    'f1_loss': functools.partial(f1_loss_with_logits, from_logits=True),
    'categorical_crossentropy': tf.keras.losses.CategoricalCrossentropy(
        from_logits=True,
        label_smoothing=LABEL_SMOOTHING,
        reduction=tf.keras.losses.Reduction.NONE),
    'binary_crossentropy': tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        label_smoothing=LABEL_SMOOTHING,
        reduction=tf.keras.losses.Reduction.NONE),
    'hybrid': functools.partial(hybrid_loss, from_logits=True),
}

with strategy.scope():
  loss_function = _loss_functions[LOSS_FUNC]

  def compute_loss(labels, predictions):
    """Average the per-example loss over the *global* batch (BATCH_SIZE).

    Dividing by the global size makes the SUM-reduced replica losses equal
    the true batch mean.
    """
    per_example = loss_function(labels, predictions)
    return tf.nn.compute_average_loss(per_example, global_batch_size=BATCH_SIZE)

### Metrics

In [None]:
# Training metric: macro-averaged F1 over the 7 classes, with predictions
# binarised at 0.5.
with strategy.scope():
  train_accuracy = tfa.metrics.F1Score(
      num_classes=7, average='macro', threshold=0.5, name='train_f1_score')

### Model and Optimizer

#### Define a Model

In [None]:
def MyModel(name):
  """Build the EfficientNet-based multi-label classifier.

  A pretrained EfficientNet{MODEL_CODE} backbone (noisy-student weights, no
  top) is followed by global average pooling, dropout 0.5, and a 7-unit
  linear float32 `logits` layer. Only the last `int(num_layers * OPEN_RATIO)`
  backbone layers are trainable.

  Args:
    name: Keras model name.

  Returns:
    An uncompiled `tf.keras.Model` mapping (INPUT_SIZE, INPUT_SIZE, 3) images
    to 7 logits.
  """
  efficientnet_builder = getattr(efn, f'EfficientNet{MODEL_CODE}')
  pretrained_model = efficientnet_builder(weights='noisy-student', include_top=False)
  num_layers = len(pretrained_model.layers)
  num_trainable_layers = int(num_layers * OPEN_RATIO)
  # Compute the split index explicitly: slicing with `[-0:]` would return
  # *every* layer and unfreeze the whole backbone when the ratio rounds to 0.
  first_trainable = num_layers - num_trainable_layers
  for index, layer in enumerate(pretrained_model.layers):
    layer.trainable = index >= first_trainable

  avg_pool = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')
  top_dropout = tf.keras.layers.Dropout(0.5, name='top_dropout')
  # float32 head keeps logits in full precision under mixed precision.
  logits = tf.keras.layers.Dense(7, activation='linear', dtype='float32', name='logits')

  # Connect layers.
  inputs = tf.keras.Input(shape=(INPUT_SIZE, INPUT_SIZE, 3))
  x = pretrained_model(inputs)
  x = avg_pool(x)
  x = top_dropout(x)
  outputs = logits(x)
  model = tf.keras.Model(inputs=inputs, outputs=outputs, name=name)

  return model

#### Learning Rate Schedule

In [None]:
def exponential_decay(initial_lr, global_step, decay_steps, decay_factor):
  """Staircase decay: `initial_lr * decay_factor ** floor(global_step / decay_steps)`."""
  num_decays = tf.floor(
      tf.cast(global_step, 'float32') / tf.cast(decay_steps, 'float32'))
  return tf.multiply(initial_lr, tf.pow(decay_factor, num_decays))

def cosine_decay(initial_lr, global_step, total_steps):
  """Cosine decay: `initial_lr` at step 0, falling to 0 at `total_steps`."""
  angle = math.pi * tf.cast(global_step, 'float32') / total_steps
  return 0.5 * initial_lr * (1.0 + tf.cos(angle))

def poly_decay(initial_lr, global_step, decay_steps, warmup_steps, end_lr=0.1, power=2.0):
  """Polynomial decay from `initial_lr` to `end_lr` over `decay_steps`.

  `lr = (initial_lr - end_lr) * (1 - p) ** power + end_lr`, where
  `p = global_step / decay_steps`, clamped to [0, 1].

  Args:
    initial_lr: starting learning rate.
    global_step: steps elapsed in the decay phase.
    decay_steps: length of the decay phase (int or float, tensor or Python).
    warmup_steps: accepted for API compatibility; not used.
    end_lr: learning rate reached when `p == 1`.
    power: polynomial exponent.
  """
  # Cast both operands: dividing a float32 tensor by an int64 tensor raises.
  p = tf.cast(global_step, 'float32') / tf.cast(decay_steps, 'float32')
  # Clamp so the LR stays at end_lr past the decay horizon instead of rising
  # again as (1 - p) ** power grows for p > 1.
  p = tf.clip_by_value(p, 0.0, 1.0)
  lr = tf.add(tf.multiply(initial_lr - end_lr, tf.pow(1 - p, power)), end_lr)
  return lr

class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Learning-rate schedule with an optional linear warm-up.

  `lr_decay_type` selects the post-warm-up shape:
    * 'exponential': staircase decay by `decay_factor` every `decay_epochs`.
    * 'cosine': cosine decay from `initial_lr` to 0 over `total_steps`.
    * 'constant': `initial_lr` throughout.
    * 'poly': polynomial decay (`poly_decay`) over the steps remaining after
      warm-up; requires `total_steps`.
  While `global_step < warmup_steps`, the rate ramps linearly from 0 to
  `initial_lr`.
  """

  def __init__(
      self,
      initial_lr,
      steps_per_epoch,
      total_steps=None,
      lr_decay_type='exponential',
      decay_factor=0.97,
      decay_epochs=2.4,
      warmup_epochs=5):
    super(LRSchedule, self).__init__()
    self.initial_lr = initial_lr
    self.total_steps = total_steps
    self.lr_decay_type = lr_decay_type
    self.decay_factor = decay_factor
    self.decay_steps = decay_epochs * steps_per_epoch
    self.warmup_steps = warmup_epochs * steps_per_epoch
    # Kept for backward compatibility; not read by this class.
    self.prev_lr = tf.cast(initial_lr, 'float32')

  def __call__(self, global_step):
    with tf.name_scope(self.lr_decay_type.capitalize() + 'Decay'):
      initial_lr = tf.cast(self.initial_lr, 'float32')
      # Work in float32 so comparisons against fractional warm-up step counts
      # do not hit int64/float dtype mismatches.
      step = tf.cast(global_step, 'float32')
      warmup_steps = tf.cast(self.warmup_steps, 'float32')
      if self.lr_decay_type == 'exponential':
        lr = exponential_decay(initial_lr, global_step, self.decay_steps, self.decay_factor)
      elif self.lr_decay_type == 'cosine':
        lr = cosine_decay(initial_lr, global_step, self.total_steps)
      elif self.lr_decay_type == 'constant':
        lr = initial_lr
      elif self.lr_decay_type == 'poly':
        assert self.total_steps, "'poly' decay requires total_steps"
        # Measure progress from the end of warm-up over the remaining steps,
        # so p runs 0 -> 1. Previously decay_steps was (step - warmup), which
        # made p > 1 and the LR blow up right after warm-up.
        decay_steps = tf.maximum(1.0, tf.cast(self.total_steps, 'float32') - warmup_steps)
        steps_after_warmup = tf.maximum(0.0, step - warmup_steps)
        lr = poly_decay(initial_lr, steps_after_warmup, decay_steps, self.warmup_steps)
      else:
        assert False, f'Unknown lr_decay_type: {self.lr_decay_type}'

      if self.warmup_steps:
        warmup_lr = initial_lr * step / warmup_steps
        lr = tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: lr)
      return lr

#### Build a model in scope of strategy.

In [None]:
with strategy.scope():
  model = MyModel(name=f'EfficientNet{MODEL_CODE}')
  lr_schedule = LRSchedule(
      INITIAL_LR,
      steps_per_epoch=len(train_data),
      total_steps=EPOCHS * len(train_data),
      lr_decay_type=LR_DECAY_TYPE,
      decay_epochs=LR_DECAY_EPOCHS,
      warmup_epochs=LR_WARMUP_EPOCHS)
  # Optimizer is chosen by the OPTIMIZER config string; all share lr_schedule.
  if OPTIMIZER == 'Adam':
    optimizer = tf.keras.optimizers.Adam(lr_schedule)
  elif OPTIMIZER == 'SGD':
    optimizer = tf.keras.optimizers.SGD(lr_schedule, momentum=0.9, nesterov=True)
  elif OPTIMIZER == 'RMSprop':
    optimizer = tf.keras.optimizers.RMSprop(lr_schedule, rho=0.9, momentum=0.9, centered=True)
  else:
    assert False, f'Unimplemented optimizer: {OPTIMIZER}'

Downloading data from https://github.com/qubvel/efficientnet/releases/download/v0.0.1/efficientnet-b7_noisy-student_notop.h5


### Step functions

In [None]:
def train_step(inputs):
  """One per-replica optimisation step; updates `train_accuracy` and returns the loss."""
  images, labels = inputs

  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = compute_loss(labels, logits)

  variables = model.trainable_variables
  grads = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(grads, variables))

  # The metric expects probabilities, while the model emits logits.
  train_accuracy.update_state(labels, tf.nn.sigmoid(logits))
  return loss


def test_step(images):
  """Per-replica inference-mode forward pass; returns raw logits."""
  return model(images, training=False)

In [None]:
@tf.function
def distributed_train_step(inputs_per):
  """Run `train_step` on every replica and return the summed loss."""
  replica_losses = strategy.run(train_step, args=(inputs_per,))
  total_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, replica_losses,
                               axis=None)
  return total_loss

@tf.function
def distributed_test_step(images_per):
  """Run `test_step` on every replica; returns per-replica logits."""
  replica_logits = strategy.run(test_step, args=(images_per,))
  return replica_logits

### Train function & Evaluation function

In [None]:
BEST_TEST_SCORE = 0
BEST_TEST_CLASSWISE_SCORE = 0


def _macro_f1(y_trues, y_preds, num_classes=7):
  """Unweighted mean of per-class binary F1 scores.

  Args:
    y_trues: NumPy array of 0/1 labels; columns are classes.
    y_preds: NumPy array of 0/1 predictions, same shape as `y_trues`.
    num_classes: number of leading columns to score.
  """
  score = 0
  for i in range(num_classes):
    score += f1_score(y_trues[:, i], y_preds[:, i])
  return score / float(num_classes)


def _best_threshold(score_fn, thresholds):
  """Return `(threshold, score)` for the threshold maximising `score_fn`.

  Ties keep the earliest threshold. If nothing scores above 0, returns (0, 0).
  """
  best_threshold, best_score = 0, 0
  for threshold in thresholds:
    score = score_fn(threshold)
    if best_score < score:
      best_threshold, best_score = threshold, score
  return best_threshold, best_score


def _predict_probs(data):
  """Return `(labels, sigmoid probabilities)` for `data` as NumPy arrays."""
  y_trues, y_logits = get_prediction(data)
  return y_trues.numpy(), tf.nn.sigmoid(y_logits).numpy()


def train(train_data,
          valid_data,
          test_data,
          epochs,
          strategy=strategy,
          store_ckpt=False,
          verbose=0):
  """Train `model` from the global EPOCH up to `epochs` (inclusive).

  Each epoch:
    * trains on `train_data`;
    * on `valid_data`, searches the best universal threshold and the best
      per-class thresholds;
    * scores `test_data` with those thresholds;
    * when `store_ckpt` is set, saves weights and thresholds on a new best.

  NOTE(review): checkpoints are selected by *test* F1, so the reported best
  test scores are optimistically biased. Consider selecting on validation.
  Assumes a global EPOCH is initialised elsewhere in the notebook.

  Returns:
    `(BEST_TEST_SCORE, BEST_TEST_CLASSWISE_SCORE)`.
  """
  global EPOCH
  global BEST_TEST_SCORE
  global BEST_TEST_CLASSWISE_SCORE

  def train_epoch(train_data):
    """Run one pass over `train_data`; return the mean per-batch loss."""
    total_loss = 0.0
    num_batches = 0
    # Defined up front so an empty dataset does not raise UnboundLocalError.
    train_loss = tf.constant(0.0)

    for x, y in train_data:
      num_batches += 1
      x_per = make_tensors_per_replica(x, strategy)
      y_per = make_tensors_per_replica(y, strategy)
      total_loss += distributed_train_step((x_per, y_per))
      train_loss = total_loss / num_batches
      if verbose:
        train_data.set_description(f'Epoch {EPOCH:3}/{epochs:3}')
        train_data.set_postfix(
            loss=train_loss.numpy(),
            f1_score=train_accuracy.result().numpy(),
        )
    return train_loss

  # Candidate decision thresholds: 0.0025, 0.005, ..., 0.5 (loop-invariant).
  nums = 200
  thresholds = np.arange(1, 1+nums) / nums * 0.5

  while EPOCH <= epochs:
    if WRITE_TB:
      write_scalar(train_writer, optimizer._decayed_lr('float32'), EPOCH-1, 'learning rate')

    # Train
    begin_time = time.perf_counter()
    if verbose:
      with tqdm(train_data, total=len(train_data), file=sys.stdout) as tqdm_train_data:
        train_loss = train_epoch(tqdm_train_data)
    else:
      print(f'Epoch {EPOCH:3}/{epochs:3}', end='')
      train_loss = train_epoch(train_data)

    # Validation ===============================================================
    # Explore the candidate thresholds to find the best universal threshold
    # and the best threshold for each class.
    y_trues, y_probs = _predict_probs(valid_data)

    valid_threshold, valid_score = _best_threshold(
        lambda t: _macro_f1(y_trues, np.where(y_probs >= t, 1.0, 0.0)),
        thresholds)

    ## classwise
    valid_classwise_thresholds = []
    valid_classwise_scores = []
    for i in range(7):
      y_true, y_prob = y_trues[:, i], y_probs[:, i]
      best_classwise_threshold, best_classwise_score = _best_threshold(
          lambda t, y_true=y_true, y_prob=y_prob: f1_score(
              y_true, np.where(y_prob >= t, 1.0, 0.0)),
          thresholds)
      valid_classwise_thresholds.append(best_classwise_threshold)
      valid_classwise_scores.append(best_classwise_score)
    valid_classwise_score = np.mean(valid_classwise_scores)
    # ==========================================================================

    # Test =====================================================================
    y_trues, y_probs = _predict_probs(test_data)
    test_score = _macro_f1(y_trues, np.where(y_probs >= valid_threshold, 1.0, 0.0))

    ## classwise
    test_classwise_scores = [
        f1_score(y_trues[:, i],
                 np.where(y_probs[:, i] >= valid_classwise_thresholds[i], 1.0, 0.0))
        for i in range(7)]
    test_classwise_score = np.mean(test_classwise_scores)

    # 0: no save
    # 1: saved for universal
    # 2: saved for classwise
    # 3: saved for both
    save_flag = 0
    ckpt_paths = []
    if BEST_TEST_SCORE < test_score and store_ckpt:
      BEST_TEST_SCORE = test_score
      save_flag += 1
      ckpt_paths.append(os.path.join(CKPT_PATH, 'universal'))
    if BEST_TEST_CLASSWISE_SCORE < test_classwise_score and store_ckpt:
      BEST_TEST_CLASSWISE_SCORE = test_classwise_score
      save_flag += 2
      ckpt_paths.append(os.path.join(CKPT_PATH, 'classwise'))

    # Save =====================================================================
    if save_flag:
      threshold = {
          'universal': valid_threshold,
          'classwise': valid_classwise_thresholds,
      }

      # Save the model architecture (only on the first epoch).
      if EPOCH == 1:
        model_path = os.path.join(CKPT_PATH, 'model.json')
        with tf.io.gfile.GFile(model_path, 'w') as f:
          json.dump(model.to_json(), f)

      # Save parameters of the model.
      for ckpt_path in ckpt_paths:
        model.save_weights(ckpt_path, overwrite=True, save_format='tf')
        # Save the threshold values next to the weights.
        with tf.io.gfile.GFile(f'{ckpt_path}.threshold', 'wb') as f:
          pickle.dump(threshold, f)
    # ==========================================================================

    # Tensorboard
    if WRITE_TB:
      with train_writer.as_default():
        tf.summary.scalar('loss', train_loss, step=EPOCH)
        tf.summary.scalar('f1_score', train_accuracy.result(), step=EPOCH)
      with valid_writer.as_default():
        tf.summary.scalar('universal_threshold', valid_threshold, step=EPOCH)
        for i in range(7):
          tf.summary.scalar(f'{OBJECT_LABELS[i]}_threshold', valid_classwise_thresholds[i], step=EPOCH)
          tf.summary.scalar(f'{OBJECT_LABELS[i]}_f1_score', valid_classwise_scores[i], step=EPOCH)
        tf.summary.scalar('universal_f1_score', valid_score, step=EPOCH)
        tf.summary.scalar('classwise_f1_score', valid_classwise_score, step=EPOCH)
      with test_writer.as_default():
        tf.summary.scalar('universal_f1_score', test_score, step=EPOCH)
        tf.summary.scalar('classwise_f1_score', test_classwise_score, step=EPOCH)
        for i in range(7):
          tf.summary.scalar(f'{OBJECT_LABELS[i]}_f1_score', test_classwise_scores[i], step=EPOCH)

    end_time = time.perf_counter()

    print((f' | Elapsed Time {end_time - begin_time:6.2f} secs'
           f' | Loss {train_loss.numpy():.4f}'
           f'  F1 {train_accuracy.result().numpy():.4f}'
           f' | Val th: {valid_threshold:.2f}'
           f'  Val F1: {valid_score:.4f}'
           f'  Val cw F1: {valid_classwise_score:.4f}'
           f' | Test F1: {test_score:.4f}'
           f'  Test cw F1: {test_classwise_score:.4f} | '), end='')
    if save_flag == 0:
      print()
    elif save_flag == 1:
      print('Saved for an universal f1 score.')
    elif save_flag == 2:
      print('Saved for a classwise f1 score.')
    elif save_flag == 3:
      print('Saved for both f1 scores.')

    train_accuracy.reset_states()
    EPOCH += 1
  return BEST_TEST_SCORE, BEST_TEST_CLASSWISE_SCORE

def get_prediction(valid_data):
  """Run the model over `valid_data` and gather labels and logits.

  Args:
    valid_data: dataset of `(images, labels)` batches.

  Returns:
    `(labels, logits)`: float32 tensors with all batches concatenated along
    axis 0, in dataset order.
  """
  ta_logits = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  ta_labels = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  for i, (x, y) in enumerate(valid_data):
    x_per = make_tensors_per_replica(x, strategy)
    # experimental_local_results handles both a PerReplica value (multi-replica)
    # and the plain tensor a single-replica strategy returns. `.values` only
    # works in the former case.
    replica_logits = strategy.experimental_local_results(distributed_test_step(x_per))
    ta_logits = ta_logits.write(i, tf.concat(replica_logits, axis=0))
    ta_labels = ta_labels.write(i, y)

  labels = ta_labels.concat()
  logits = ta_logits.concat()

  return labels, logits

def evaluate(valid_data, threshold=0.5, verbose=1):
  """Macro-averaged F1 of sigmoid predictions binarised at `threshold`.

  Prints the score when `verbose` is truthy and returns it.
  """
  labels, logits = get_prediction(valid_data)
  predicted = tf.where(tf.nn.sigmoid(logits) >= threshold, 1.0, 0.0).numpy()
  labels = labels.numpy()

  score = sum(f1_score(labels[:, c], predicted[:, c]) for c in range(7)) / 7.

  if verbose:
    print(f'F1 Score: {score}')
  return score


## Run

In [None]:
#@markdown # **TensorBoard**
# Optionally launch an inline TensorBoard rooted at <BUCKET>/ai/tensorboard,
# the parent directory of every per-TOKEN run written by the file writers.
if SHOW_TENSORBOARD:
  TB_PATH = os.path.join(BUCKET, 'ai', 'tensorboard')
  %reload_ext tensorboard
  %tensorboard --logdir $TB_PATH

In [None]:
# Run the training loop; the cell output is (best test F1, best classwise test F1).
train(train_data, valid_data, test_data,
      epochs=EPOCHS, store_ckpt=STORE_CKPT, verbose=0)

Epoch   1/150 | Elapsed Time 286.91 secs | Loss 0.6769  F1 0.3445 | Val th: 0.49  Val F1: 0.3924  Val cw F1: 0.4233 | Test F1: 0.3367  Test cw F1: 0.3779 | Saved for both f1 scores.
Epoch   2/150 | Elapsed Time  75.15 secs | Loss 0.6645  F1 0.3862 | Val th: 0.50  Val F1: 0.4443  Val cw F1: 0.4551 | Test F1: 0.3771  Test cw F1: 0.3945 | Saved for both f1 scores.
Epoch   3/150 | Elapsed Time  74.45 secs | Loss 0.6347  F1 0.4143 | Val th: 0.50  Val F1: 0.4449  Val cw F1: 0.4921 | Test F1: 0.3904  Test cw F1: 0.4239 | Saved for both f1 scores.
Epoch   4/150 | Elapsed Time  76.28 secs | Loss 0.5935  F1 0.3943 | Val th: 0.44  Val F1: 0.4682  Val cw F1: 0.5561 | Test F1: 0.4037  Test cw F1: 0.4266 | Saved for both f1 scores.
Epoch   5/150 | Elapsed Time  67.54 secs | Loss 0.5631  F1 0.3414 | Val th: 0.16  Val F1: 0.4804  Val cw F1: 0.5498 | Test F1: 0.4169  Test cw F1: 0.4143 | Saved for an universal f1 score.
Epoch   6/150 | Elapsed Time  74.63 secs | Loss 0.5528  F1 0.3116 | Val th: 0.10  V

(0.6517441047661832, 0.6562085534607675)

# Post-optimization process

## Find best crop ratio.

In [None]:
# crop_ratios = np.linspace(0.5, 0.9, num=20)
# for cr in crop_ratios:
#   cropped_valid_data = load_data(DATA_DIR,
#                                 split='validation',
#                                 crop_ratio=cr,
#                                 strategy=strategy)
#   print(f'CR: {cr}', evaluate(cropped_valid_data, 0.1733467))

## Find best classwise thresholds.

In [None]:
# y_trues, y_logits = get_prediction(valid_data)
# y_probs = tf.nn.sigmoid(y_logits)
# thresholds = np.linspace(0., 0.5, num=500, dtype='float32')
# ths = []
# ss = []
# for i in range(7):
#   y_prob = y_probs[:,i]
#   y_true = y_trues[:,i]

#   best_threshold = 0
#   best_score = 0  
#   for threshold in thresholds:
#     y_pred = tf.where(y_prob >= threshold, 1.0, 0.0)
#     np_y_true = y_true.numpy()
#     np_y_pred = y_pred.numpy()
#     score = f1_score(np_y_true, np_y_pred)
#     if best_score < score:
#       best_threshold = threshold
#       best_score = score
#   ths.append(best_threshold)
#   ss.append(best_score)

# print(np.mean(ss))