In [1]:
from absl import flags, logging, app
from absl.flags import FLAGS
import os
import sys
import yaml
import math
from itertools import product as product
import tensorflow as tf
import numpy as np

from modules.anchor import prior_box
from modules.utils import set_memory_growth, load_yaml, ProgressBar
from modules.models import RetinaFaceModel

In [2]:
for name in list(FLAGS):
      delattr(FLAGS, name)
flags.DEFINE_string('cfg_path', './configs/retinaface_mbv2.yaml', 'path to config file')

"""flags.DEFINE_string('gpu', '0', 'which gpu to use'): This line defines a command-line flag named 'gpu'. 
    The flags.DEFINE_string function is used to create a string-typed flag with a default value of '0' and 
    a description 'which gpu to use'. 
    This flag will allow you to specify the GPU to use when running the script.
    For example, if you run the script with the command: python script.py --gpu=1, 
    the value '1' will be assigned to CUDA_VISIBLE_DEVICES, and the script will use the GPU with the ID 1 for 
    TensorFlow operations. If you don't provide a value for the 'gpu' flag, it will default to '0', and the 
    script will use the GPU with ID 0."""

"flags.DEFINE_string('gpu', '0', 'which gpu to use'): This line defines a command-line flag named 'gpu'. \n    The flags.DEFINE_string function is used to create a string-typed flag with a default value of '0' and \n    a description 'which gpu to use'. \n    This flag will allow you to specify the GPU to use when running the script.\n    For example, if you run the script with the command: python script.py --gpu=1, \n    the value '1' will be assigned to CUDA_VISIBLE_DEVICES, and the script will use the GPU with the ID 1 for \n    TensorFlow operations. If you don't provide a value for the 'gpu' flag, it will default to '0', and the \n    script will use the GPU with ID 0."

In [3]:
flags.DEFINE_string('f', '', 'kernel')
FLAGS(sys.argv)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

"""os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' sets the environment variable TF_CPP_MIN_LOG_LEVEL to the value '3'.

This environment variable controls the logging verbosity of TensorFlow. 
Here's what each log level value represents:

    0: Shows all logs (default behavior).
    1: Shows only INFO-level logs.
    2: Shows only WARNING-level logs.
    3: Shows only ERROR-level logs.
    4: Disables all logs.

By setting TF_CPP_MIN_LOG_LEVEL to '3', you are instructing TensorFlow to display only ERROR-level logs and 
suppress INFO and WARNING-level logs. 
This is useful if you want to reduce the amount of console output and focus on the most critical log messages.

This environment variable is specific to TensorFlow and controls the verbosity of TensorFlow's own log messages.
"""



In [4]:
logger = tf.get_logger()
"""The tf.get_logger() function is a utility provided by TensorFlow to 
retrieve a logger object that can be used for logging messages within your TensorFlow code.
"""

logger.disabled = True
logger.setLevel(logging.FATAL)

In [5]:
set_memory_growth()

1


In [6]:
cfg = load_yaml(FLAGS.cfg_path)
print(cfg['init_lr'])

0.001


## define network

In [7]:
model = RetinaFaceModel(cfg, training=True)
model.summary(line_length=110)

Model: "RetinaFaceModel"
______________________________________________________________________________________________________________
 Layer (type)                       Output Shape            Param #      Connected to                         
 input_image (InputLayer)           [(None, 300, 400, 3)]   0            []                                   
                                                                                                              
 tf.math.truediv (TFOpLambda)       (None, 300, 400, 3)     0            ['input_image[0][0]']                
                                                                                                              
 tf.math.subtract (TFOpLambda)      (None, 300, 400, 3)     0            ['tf.math.truediv[0][0]']            
                                                                                                              
 MobileNetV2_extrator (Functional)  ((None, 38, 50, 192),   1518464      ['tf.math.subt

## define prior box

In [8]:
priors = prior_box((cfg['input_height'], cfg['input_width']),
                       cfg['min_sizes'],  cfg['steps'], cfg['clip'])

In [9]:
len(priors)

5010

## load dataset

In [10]:
def load_dataset(cfg, priors, shuffle=True, buffer_size=10240):
    """load dataset"""
    logging.info("load dataset from {}".format(cfg['dataset_path']))
    dataset = load_tfrecord_dataset(
        tfrecord_name=cfg['dataset_path'],
        batch_size=cfg['batch_size'],
        
        using_bin=cfg['using_bin'],
        using_flip=cfg['using_flip'],
        using_distort=cfg['using_distort'],
        using_encoding=True,
        priors=priors,
        match_thresh=cfg['match_thresh'],
        ignore_thresh=cfg['ignore_thresh'],
        variances=cfg['variances'],
        shuffle=shuffle,
        buffer_size=buffer_size)
    return dataset

In [11]:
from modules.anchor import encode_tf
def _parse_tfrecord(using_bin, using_flip, using_distort,
                    using_encoding, priors, match_thresh, ignore_thresh,
                    variances):
    def parse_tfrecord(tfrecord):
        features = {
            'image/img_name': tf.io.FixedLenFeature([], tf.string),
            'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
            'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
            'image/object/landmark0/x': tf.io.VarLenFeature(tf.float32),
            'image/object/landmark0/y': tf.io.VarLenFeature(tf.float32),
            'image/object/landmark1/x': tf.io.VarLenFeature(tf.float32),
            'image/object/landmark1/y': tf.io.VarLenFeature(tf.float32),
           
            'image/object/landmark/valid': tf.io.VarLenFeature(tf.float32)}
        if using_bin:
            features['image/encoded'] = tf.io.FixedLenFeature([], tf.string)
            x = tf.io.parse_single_example(tfrecord, features)
            img = tf.image.decode_image(x['image/encoded'], channels=3)
        else:
            features['image/img_path'] = tf.io.FixedLenFeature([], tf.string)
            x = tf.io.parse_single_example(tfrecord, features)
            image_encoded = tf.io.read_file(x['image/img_path'])
            img = tf.image.decode_image(image_encoded, channels=3)

        labels = tf.stack(
            [tf.sparse.to_dense(x['image/object/bbox/xmin']),
             tf.sparse.to_dense(x['image/object/bbox/ymin']),
             tf.sparse.to_dense(x['image/object/bbox/xmax']),
             tf.sparse.to_dense(x['image/object/bbox/ymax']),
             tf.sparse.to_dense(x['image/object/landmark0/x']),
             tf.sparse.to_dense(x['image/object/landmark0/y']),
             tf.sparse.to_dense(x['image/object/landmark1/x']),
             tf.sparse.to_dense(x['image/object/landmark1/y']),
             
             tf.sparse.to_dense(x['image/object/landmark/valid'])], axis=1)

        img, labels = _transform_data(using_flip, using_distort, using_encoding, priors,
            match_thresh, ignore_thresh, variances)(img, labels)

        return img, labels
    return parse_tfrecord


# In[ ]:


def _transform_data(using_flip, using_distort, using_encoding, priors,
                    match_thresh, ignore_thresh, variances):
    def transform_data(img, labels):
        img = tf.cast(img, tf.float32)

        # randomly crop
        #img, labels = _crop(img, labels)

        # padding to square
        #img = _pad_to_square(img)


        # randomly left-right flip
        if using_flip:
            img, labels = _flip(img, labels)

        # distort
        if using_distort:
            img = _distort(img)

        # encode labels to feature targets
        if using_encoding:
            labels = encode_tf(labels=labels, priors=priors,
                               match_thresh=match_thresh,
                               ignore_thresh=ignore_thresh,
                               variances=variances)

        return img, labels
    return transform_data


# In[ ]:


def load_tfrecord_dataset(tfrecord_name, batch_size, 
                          using_bin=True, using_flip=True, using_distort=True,
                          using_encoding=True, priors=None, match_thresh=0.45,
                          ignore_thresh=0.3, variances=[0.1, 0.2],
                          shuffle=True, buffer_size=10240):
    """load dataset from tfrecord"""
    if not using_encoding:
        assert batch_size == 1  # dynamic data len when using_encoding
    else:
        assert priors is not None

    raw_dataset = tf.data.TFRecordDataset(tfrecord_name)
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(using_bin, using_flip, using_distort,
                        using_encoding, priors, match_thresh, ignore_thresh,
                        variances),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(
        buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset


# In[ ]:


def _flip(img, labels):
    flip_case = tf.random.uniform([], 0, 2, dtype=tf.int32)

    def flip_func():
        flip_img = tf.image.flip_left_right(img)
        flip_labels = tf.stack([1 - labels[:, 2],  labels[:, 1],
                                1 - labels[:, 0],  labels[:, 3],
                                1 - labels[:, 6],  labels[:, 7],
                                1 - labels[:, 4],  labels[:, 5],
                                
                                labels[:, 8]], axis=1)

        return flip_img, flip_labels

    img, labels = tf.case([(tf.equal(flip_case, 0), flip_func)],
                          default=lambda: (img, labels))

    return img, labels


def _crop(img, labels, max_loop=250):
    shape = tf.shape(img)

    def matrix_iof(a, b):
        """
        return iof of a and b, numpy version for data augenmentation
        """
        lt = tf.math.maximum(a[:, tf.newaxis, :2], b[:, :2])
        rb = tf.math.minimum(a[:, tf.newaxis, 2:], b[:, 2:])

        area_i = tf.math.reduce_prod(rb - lt, axis=2) * \
            tf.cast(tf.reduce_all(lt < rb, axis=2), tf.float32)
        area_a = tf.math.reduce_prod(a[:, 2:] - a[:, :2], axis=1)
        return area_i / tf.math.maximum(area_a[:, tf.newaxis], 1)

    def crop_loop_body(i, img, labels):
        valid_crop = tf.constant(1, tf.int32)

        pre_scale = tf.constant([0.3, 0.45, 0.6, 0.8, 1.0], dtype=tf.float32)
        scale = pre_scale[tf.random.uniform([], 0, 5, dtype=tf.int32)]
        short_side = tf.cast(tf.minimum(shape[0], shape[1]), tf.float32)
        h = w = tf.cast(scale * short_side, tf.int32)
        h_offset = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        w_offset = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        roi = tf.stack([w_offset, h_offset, w_offset + w, h_offset + h])
        roi = tf.cast(roi, tf.float32)

        value = matrix_iof(labels[:, :4], roi[tf.newaxis])
        valid_crop = tf.cond(tf.math.reduce_any(value >= 1),
                             lambda: valid_crop, lambda: 0)

        centers = (labels[:, :2] + labels[:, 2:4]) / 2
        mask_a = tf.reduce_all(
            tf.math.logical_and(roi[:2] < centers, centers < roi[2:]),
            axis=1)
        labels_t = tf.boolean_mask(labels, mask_a)
        valid_crop = tf.cond(tf.reduce_any(mask_a),
                             lambda: valid_crop, lambda: 0)

        img_t = img[h_offset:h_offset + h, w_offset:w_offset + w, :]
        h_offset = tf.cast(h_offset, tf.float32)
        w_offset = tf.cast(w_offset, tf.float32)
        labels_t = tf.stack(
            [labels_t[:, 0] - w_offset,  labels_t[:, 1] - h_offset,
             labels_t[:, 2] - w_offset,  labels_t[:, 3] - h_offset,
             labels_t[:, 4] - w_offset,  labels_t[:, 5] - h_offset,
             labels_t[:, 6] - w_offset,  labels_t[:, 7] - h_offset,
             labels_t[:, 8] - w_offset,  labels_t[:, 9] - h_offset,
             labels_t[:, 10] - w_offset, labels_t[:, 11] - h_offset,
             labels_t[:, 12] - w_offset, labels_t[:, 13] - h_offset,
             labels_t[:, 14]], axis=1)

        return tf.cond(valid_crop == 1,
                       lambda: (max_loop, img_t, labels_t),
                       lambda: (i + 1, img, labels))

    _, img, labels = tf.while_loop(
        lambda i, img, labels: tf.less(i, max_loop),
        crop_loop_body,
        [tf.constant(-1), img, labels],
        shape_invariants=[tf.TensorShape([]),
                          tf.TensorShape([None, None, 3]),
                          tf.TensorShape([None, 15])])

    return img, labels


def _pad_to_square(img):
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]

    def pad_h():
        img_pad_h = tf.ones([width - height, width, 3]) * \
            tf.reduce_mean(img, axis=[0, 1], keepdims=True)
        return tf.concat([img, img_pad_h], axis=0)

    def pad_w():
        img_pad_w = tf.ones([height, height - width, 3]) * \
            tf.reduce_mean(img, axis=[0, 1], keepdims=True)
        return tf.concat([img, img_pad_w], axis=1)

    img = tf.case([(tf.greater(height, width), pad_w),
                   (tf.less(height, width), pad_h)], default=lambda: img)

    return img


def _resize(img, labels, img_dim):
    w_f = tf.cast(tf.shape(img)[1], tf.float32)
    h_f = tf.cast(tf.shape(img)[0], tf.float32)
    locs = tf.stack([labels[:, 0] / w_f,  labels[:, 1] / h_f,
                     labels[:, 2] / w_f,  labels[:, 3] / h_f,
                     labels[:, 4] / w_f,  labels[:, 5] / h_f,
                     labels[:, 6] / w_f,  labels[:, 7] / h_f,
                     labels[:, 8] / w_f,  labels[:, 9] / h_f,
                     labels[:, 10] / w_f, labels[:, 11] / h_f,
                     labels[:, 12] / w_f, labels[:, 13] / h_f], axis=1)
    locs = tf.clip_by_value(locs, 0, 1)
    labels = tf.concat([locs, labels[:, 14][:, tf.newaxis]], axis=1)

    resize_case = tf.random.uniform([], 0, 5, dtype=tf.int32)

    def resize(method):
        def _resize():
            return tf.image.resize(
                img, [img_dim, img_dim], method=method, antialias=True)
        return _resize

    img = tf.case([(tf.equal(resize_case, 0), resize('bicubic')),
                   (tf.equal(resize_case, 1), resize('area')),
                   (tf.equal(resize_case, 2), resize('nearest')),
                   (tf.equal(resize_case, 3), resize('lanczos3'))],
                  default=resize('bilinear'))

    return img, labels


def _distort(img):
    img = tf.image.random_brightness(img, 0.4)
    img = tf.image.random_contrast(img, 0.5, 1.5)
    img = tf.image.random_saturation(img, 0.5, 1.5)
    img = tf.image.random_hue(img, 0.1)

    return img

In [12]:
train_dataset = load_dataset(cfg, priors, shuffle=True)

## define optimizer

In [13]:
def MultiStepLR(initial_learning_rate, lr_steps, lr_rate, name='MultiStepLR'):
    """Multi-steps learning rate scheduler."""
    lr_steps_value = [initial_learning_rate]
    for _ in range(len(lr_steps)):
        lr_steps_value.append(lr_steps_value[-1] * lr_rate)
    return tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=lr_steps, values=lr_steps_value)


def MultiStepWarmUpLR(initial_learning_rate, lr_steps, lr_rate,
                      warmup_steps=0., min_lr=0.,
                      name='MultiStepWarmUpLR'):
    """Multi-steps warm up learning rate scheduler."""
    assert warmup_steps <= lr_steps[0]
    assert min_lr <= initial_learning_rate
    lr_steps_value = [initial_learning_rate]
    for _ in range(len(lr_steps)):
        lr_steps_value.append(lr_steps_value[-1] * lr_rate)
    return PiecewiseConstantWarmUpDecay(
        boundaries=lr_steps, values=lr_steps_value, warmup_steps=warmup_steps,
        min_lr=min_lr)

In [14]:
class PiecewiseConstantWarmUpDecay(
        tf.keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule wiht warm up schedule.
    Modified from tf.keras.optimizers.schedules.PiecewiseConstantDecay"""

    def __init__(self, boundaries, values, warmup_steps, min_lr,
                 name=None):
        super(PiecewiseConstantWarmUpDecay, self).__init__()

        if len(boundaries) != len(values) - 1:
            raise ValueError(
                    "The length of boundaries should be 1 less than the"
                    "length of values")

        self.boundaries = boundaries
        self.values = values
        self.name = name
        self.warmup_steps = warmup_steps
        self.min_lr = min_lr

    def __call__(self, step):
        with tf.name_scope(self.name or "PiecewiseConstantWarmUp"):
            step = tf.cast(tf.convert_to_tensor(step), tf.float32)
            pred_fn_pairs = []
            warmup_steps = self.warmup_steps
            boundaries = self.boundaries
            values = self.values
            min_lr = self.min_lr

            pred_fn_pairs.append(
                (step <= warmup_steps,
                 lambda: min_lr + step * (values[0] - min_lr) / warmup_steps))
            pred_fn_pairs.append(
                (tf.logical_and(step <= boundaries[0],
                                step > warmup_steps),
                 lambda: tf.constant(values[0])))
            pred_fn_pairs.append(
                (step > boundaries[-1], lambda: tf.constant(values[-1])))

            for low, high, v in zip(boundaries[:-1], boundaries[1:],
                                    values[1:-1]):
                pred = (step > low) & (step <= high)
                pred_fn_pairs.append((pred, lambda v=v: tf.constant(v)))

            # The default isn't needed here because our conditions are mutually
            # exclusive and exhaustive, but tf.case requires it.
            return tf.case(pred_fn_pairs, lambda: tf.constant(values[0]),
                           exclusive=True)

    def get_config(self):
        return {
                "boundaries": self.boundaries,
                "values": self.values,
                "warmup_steps": self.warmup_steps,
                "min_lr": self.min_lr,
                "name": self.name
        }


In [15]:
from tensorflow.keras.optimizers.legacy import SGD
from tensorflow.keras.optimizers.legacy import Adam

steps_per_epoch = cfg['dataset_len'] // cfg['batch_size']
learning_rate = MultiStepWarmUpLR(
    initial_learning_rate=cfg['init_lr'],
    lr_steps=[e * steps_per_epoch for e in cfg['lr_decay_epoch']],
    lr_rate=cfg['lr_rate'],
    warmup_steps=cfg['warmup_epoch'] * steps_per_epoch,
    min_lr=cfg['min_lr'])
optimizer = SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)

## define losses function

In [16]:
def _smooth_l1_loss(y_true, y_pred):
    t = tf.abs(y_pred - y_true)
    return tf.where(t < 1, 0.5 * t ** 2, t - 0.5)
"""The final output tensor will have the same shape as t. 
    If an element of t is less than 1, the corresponding element in the output tensor will be 0.5 * t ** 2. 
    If an element of t is greater than or equal to 1, the corresponding element will be t - 0.5."""


def MultiBoxLoss(num_class=2, neg_pos_ratio=3):
    """multi-box loss"""
    def multi_box_loss(y_true, y_pred):
        num_batch = tf.shape(y_true)[0]
        num_prior = tf.shape(y_true)[1]
        
        print(num_prior)

        loc_pred = tf.reshape(y_pred[0], [num_batch * num_prior, 4])
        landm_pred = tf.reshape(y_pred[1], [num_batch * num_prior, 4])
        class_pred = tf.reshape(y_pred[2], [num_batch * num_prior, num_class])
        loc_true = tf.reshape(y_true[..., :4], [num_batch * num_prior, 4])
        landm_true = tf.reshape(y_true[..., 4:8], [num_batch * num_prior, 4])
        landm_valid = tf.reshape(y_true[..., 8], [num_batch * num_prior, 1])
        class_true = tf.reshape(y_true[..., 9], [num_batch * num_prior, 1])

        # define filter mask: class_true = 1 (pos), 0 (neg), -1 (ignore)
        #                     landm_valid = 1 (w landm), 0 (w/o landm)
        mask_pos = tf.equal(class_true, 1)
        mask_neg = tf.equal(class_true, 0)
        mask_landm = tf.logical_and(tf.equal(landm_valid, 1), mask_pos)

        # landm loss (smooth L1)
        mask_landm_b = tf.broadcast_to(mask_landm, tf.shape(landm_true))
        loss_landm = _smooth_l1_loss(tf.boolean_mask(landm_true, mask_landm_b),
                                     tf.boolean_mask(landm_pred, mask_landm_b))
        loss_landm = tf.reduce_mean(loss_landm)

        # localization loss (smooth L1)
        mask_pos_b = tf.broadcast_to(mask_pos, tf.shape(loc_true))
        loss_loc = _smooth_l1_loss(tf.boolean_mask(loc_true, mask_pos_b),
                                   tf.boolean_mask(loc_pred, mask_pos_b))
        loss_loc = tf.reduce_mean(loss_loc)

        # classification loss (crossentropy)
        # 1. compute max conf across batch for hard negative mining
        loss_class = tf.where(mask_neg,
                              1 - class_pred[:, 0][..., tf.newaxis], 0)

        # 2. hard negative mining
        loss_class = tf.reshape(loss_class, [num_batch, num_prior])
        loss_class_idx = tf.argsort(loss_class, axis=1, direction='DESCENDING')
        loss_class_idx_rank = tf.argsort(loss_class_idx, axis=1)
        mask_pos_per_batch = tf.reshape(mask_pos, [num_batch, num_prior])
        num_pos_per_batch = tf.reduce_sum(
                tf.cast(mask_pos_per_batch, tf.float32), 1, keepdims=True)
        num_pos_per_batch = tf.maximum(num_pos_per_batch, 1)
        num_neg_per_batch = tf.minimum(neg_pos_ratio * num_pos_per_batch,
                                       tf.cast(num_prior, tf.float32) - 1)
        mask_hard_neg = tf.reshape(
            tf.cast(loss_class_idx_rank, tf.float32) < num_neg_per_batch,
            [num_batch * num_prior, 1])

        # 3. classification loss including positive and negative examples
        loss_class_mask = tf.logical_or(mask_pos, mask_hard_neg)
        loss_class_mask_b = tf.broadcast_to(loss_class_mask,
                                            tf.shape(class_pred))
        filter_class_true = tf.boolean_mask(tf.cast(mask_pos, tf.float32),
                                            loss_class_mask)
        filter_class_pred = tf.boolean_mask(class_pred, loss_class_mask_b)
        filter_class_pred = tf.reshape(filter_class_pred, [-1, num_class])
        loss_class = tf.keras.losses.sparse_categorical_crossentropy(
            y_true=filter_class_true, y_pred=filter_class_pred)
        loss_class = tf.reduce_mean(loss_class)

        return loss_loc, loss_landm, loss_class

    return multi_box_loss

In [17]:
multi_box_loss = MultiBoxLoss()

## load checkpoint

In [18]:
checkpoint_dir = './checkpoints_new/' + cfg['sub_name']

In [19]:
checkpoint = tf.train.Checkpoint(step=tf.Variable(0, name='step'),
                                 optimizer=optimizer,
                                 model=model)

In [20]:
manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                         directory=checkpoint_dir,
                                         max_to_keep=3)

In [21]:
if manager.latest_checkpoint:
    checkpoint.restore(manager.latest_checkpoint)
    print('[*] load ckpt from {} at step {}.'.format(
        manager.latest_checkpoint, checkpoint.step.numpy()))
else:
    print("[*] training from scratch.")

[*] load ckpt from ./checkpoints_new/retinaface_mbv2/ckpt-3 at step 3000.


## define training step function

In [22]:
@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)

        losses = {}
        losses['reg'] = tf.reduce_sum(model.losses)
        losses['loc'], losses['landm'], losses['class'] = \
            multi_box_loss(labels, predictions)
        total_loss = tf.add_n([l for l in losses.values()])

    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return total_loss, losses

## training loop

In [23]:
summary_writer = tf.summary.create_file_writer('./logs/' + cfg['sub_name'])
remain_steps = max(
    steps_per_epoch * cfg['epoch'] - checkpoint.step.numpy(), 0)
prog_bar = ProgressBar(steps_per_epoch,checkpoint.step.numpy() % steps_per_epoch)

for inputs, labels in train_dataset.take(remain_steps):
    checkpoint.step.assign_add(1)
    steps = checkpoint.step.numpy()

    total_loss, losses = train_step(inputs, labels)
    
    prog_bar.update("epoch={}/{}, loss={:.4f}, lr={:.1e}".format(
        ((steps - 1) // steps_per_epoch) + 1, cfg['epoch'],
        total_loss.numpy(), optimizer.lr(steps).numpy()))
    
    if steps % 10 == 0:
        with summary_writer.as_default():
            tf.summary.scalar(
                'loss/total_loss', total_loss, step=steps)
            for k, l in losses.items():
                tf.summary.scalar('loss/{}'.format(k), l, step=steps)
            tf.summary.scalar(
                'learning_rate', optimizer.lr(steps), step=steps)

    if steps % cfg['save_steps'] == 0:
        manager.save()
        print("\n[*] save ckpt file at {}".format(
            manager.latest_checkpoint))


    

Tensor("strided_slice_1:0", shape=(), dtype=int32)
Tensor("strided_slice_1:0", shape=(), dtype=int32)
Training [>>>>>>>>>>>>>>>>>>>>>    ] 689/805, epoch=4/100, loss=nan, lr=1.0e-03  5.0 step/secc

KeyboardInterrupt: 