In [None]:
import os
import sys
import tensorflow as tf
import tensorflow_addons as tfa
import logging
import contextlib
import cv2
import numpy as np

ROOT_DIR = os.path.abspath('.')
sys.path.append(os.path.join(ROOT_DIR, 'data'))
sys.path.append(os.path.join(ROOT_DIR, 'layers'))
sys.path.append(os.path.join(ROOT_DIR, 'protos'))
import yolact
import yolactloss
from data import coco_dataset

# Config

In [None]:
# The notebook reports progress through `logging.info`; the root logger
# defaults to WARNING, which would silently drop every one of those messages.
logging.basicConfig(level=logging.INFO)
logging.getLogger().setLevel(logging.INFO)

# --- Input / anchor geometry ------------------------------------------------
# Height and width handed to the model, the dataset pipeline and the loss.
image_height, image_width = 550, 550
# Number of object classes; the model and loss are given `number_of_classes + 1`.
number_of_classes = 100
aspect_ratios = [1, 0.5, 2]
scales = [24, 48, 96, 192, 384]

# --- Optimisation -----------------------------------------------------------
batch_size = 1
# Total number of training steps; also drives the learning-rate decay boundaries.
training_iterations = 1200000
learning_rate = 1e-3
momentum = 0.9
weight_decay = 5 * 1e-4
# Optional checkpoint used to initialise the backbone; '' means "none".
pretrained_checkpoints = ''

# --- Reporting / checkpoint cadence ------------------------------------------
# Both intervals are used as the right-hand side of `iterations % interval`
# in the training loop, so they must be positive integers (0 divides by zero).
print_interval = 10
save_interval = 10000
# Upper bound on the number of validation batches evaluated per checkpoint.
validation_iterations = 5000

# --- Paths (fill these in before running) ------------------------------------
training_data_path = ''
validation_data_path = ''
log_directory = ''
checkpoint_directory = ''
saved_model_directory = ''

# Training

In [None]:
logging.info('Creating the Yolact model instance')

# Constructor arguments are gathered first so the geometry and head settings
# can be read at a glance. The class count passed to the model is one more
# than `number_of_classes` (presumably a background slot — confirm in yolact).
yolact_settings = dict(
    image_height=image_height,
    image_width=image_width,
    fpn_channels=256,
    number_of_classes=number_of_classes + 1,
    number_of_masks=32,
    aspect_ratios=aspect_ratios,
    scales=scales,
    base_model_trainable=False,
)
model = yolact.Yolact(**yolact_settings)

In [None]:
logging.info(f'Creating the training dataloader from {training_data_path}...')

# The pipeline is given the model's feature-map and protonet output sizes
# together with the same anchor configuration the model was built with.
training_dataset_settings = {
    'image_height': image_height,
    'image_width': image_width,
    'feature_map_sizes': model.feature_map_size,
    'protonet_out_sizes': model.protonet_out_size,
    'aspect_ratios': aspect_ratios,
    'scales': scales,
    'tfrecord_directory': training_data_path,
    'batch_size': batch_size,
}
training_dataset = coco_dataset.prepare_dataset(**training_dataset_settings)


In [None]:
logging.info(f'Creating the validation dataloader from: {validation_data_path}...')

# Validation always uses a batch size of 1; aspect ratios and scales are
# passed as explicit floats / ints respectively.
validation_dataset_settings = {
    'image_height': image_height,
    'image_width': image_width,
    'feature_map_sizes': model.feature_map_size,
    'protonet_out_sizes': model.protonet_out_size,
    'aspect_ratios': list(map(float, aspect_ratios)),
    'scales': list(map(int, scales)),
    'tfrecord_directory': validation_data_path,
    'batch_size': 1,
}
validation_dataset = coco_dataset.prepare_dataset(**validation_dataset_settings)


In [None]:
# Step-wise learning-rate decay: the base rate is held until 80% of training,
# then divided by 10 at the 80%, 90% and 95% marks.
decay_boundaries = [int(fraction * training_iterations)
                    for fraction in (0.8, 0.9, 0.95)]
decay_values = [learning_rate] + [factor * learning_rate
                                  for factor in (0.1, 0.01, 0.001)]
learning_rate_schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    decay_boundaries, decay_values)

logging.info('Initiate the optimizer and loss function...')

# SGD with momentum and decoupled weight decay (tensorflow_addons).
optimizer = tfa.optimizers.SGDW(
    weight_decay=weight_decay,
    learning_rate=learning_rate_schedule,
    momentum=momentum)

# Loss object; it is told the input image size.
criterion = yolactloss.YOLACTLoss(img_h=image_height, img_w=image_width)

In [None]:
def _mean_metric(name):
    """Return a float32 running-mean Keras metric called `name`."""
    return tf.keras.metrics.Mean(name, dtype=tf.float32)


# Total losses for the training and validation passes.
training_loss = _mean_metric('train_loss')
validation_loss = _mean_metric('valid_loss')

# Per-component training losses.
location = _mean_metric('loc_loss')
config = _mean_metric('conf_loss')
mask = _mean_metric('mask_loss')
mask_iou = _mean_metric('mask_iou_loss')
segmentation = _mean_metric('seg_loss')

# Per-component validation losses.
v_loc = _mean_metric('vloc_loss')
v_conf = _mean_metric('vconf_loss')
v_mask = _mean_metric('vmask_loss')
v_mask_iou = _mean_metric('vmask_iou_loss')
v_seg = _mean_metric('vseg_loss')

In [None]:
logging.info('Setup tensorboard...')

# One event-file directory (and writer) per subset under `log_directory`.
train_log_dir, test_log_dir = (
    os.path.join(log_directory, subset) for subset in ('train', 'test'))
train_summary_writer, test_summary_writer = (
    tf.summary.create_file_writer(event_dir)
    for event_dir in (train_log_dir, test_log_dir))

In [None]:
logging.info('Start the training process')

# The checkpoint tracks the step counter together with the optimizer and the
# model, so the training loop can pick its iteration count up from `step`.
checkpoint = tf.train.Checkpoint(
    step=tf.Variable(1), optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(
    checkpoint, directory=checkpoint_directory, max_to_keep=5)

if manager.latest_checkpoint:
    # Actually load the newest checkpoint; without this call training would
    # silently restart from scratch while claiming to have been restored.
    checkpoint.restore(manager.latest_checkpoint)
    logging.info(f'Restored from {manager.latest_checkpoint}')
elif pretrained_checkpoints != '':
    # Only the feature-extraction parts of the network are initialised from
    # the pretrained weights.
    feature_extractor_model = tf.train.Checkpoint(
        backbone_resnet=model.backbone_resnet,
        backbone_fpn=model.backbone_fpn,
        protonet=model.protonet)
    # NOTE(review): assumes the pretrained checkpoint stores these objects
    # under a top-level `model` key, i.e. the layout this notebook itself
    # saves. Drop the wrapper if the file was written differently.
    ckpt = tf.train.Checkpoint(model=feature_extractor_model)
    # `expect_partial` silences warnings about checkpoint values that are not
    # used; `assert_existing_objects_matched` fails loudly if any object that
    # already exists here found no counterpart in the checkpoint.
    ckpt.restore(pretrained_checkpoints).expect_partial()\
        .assert_existing_objects_matched()
    logging.info(f'Backbone restored from {pretrained_checkpoints}')
else:
    logging.info('Initializing without checkpoints.')
        

In [None]:
@contextlib.contextmanager
def options(options):
    """Apply TensorFlow graph-optimizer options for the duration of a block.

    The experimental optimizer options that were active on entry are read
    first, `options` is installed, and the saved options are written back
    when the `with` body finishes, whether it exits normally or by raising.

    Args:
        options: mapping passed as-is to
            `tf.config.optimizer.set_experimental_options`.
    """
    optimizer_config = tf.config.optimizer
    previous = optimizer_config.get_experimental_options()
    optimizer_config.set_experimental_options(options)
    with contextlib.ExitStack() as restore_on_exit:
        # Registered only after the new options are installed, so a failure
        # while installing them triggers no restore.
        restore_on_exit.callback(
            optimizer_config.set_experimental_options, previous)
        yield

In [None]:
# Main training loop.
#
# Each step runs one optimisation step on a training batch, records the loss
# components and writes them to TensorBoard. Every `print_interval` steps the
# current losses are logged; every `save_interval` steps a checkpoint is
# written, the validation set is evaluated, and a SavedModel is exported if
# the validation loss improved. An interval of 0 (or less) disables the
# corresponding action instead of being used as a modulus.

# Graph-optimizer switches applied around every forward pass (training and
# validation use the same set).
GRAPH_OPTIMIZER_OPTIONS = {
    'constant_folding': True,
    'layout_optimize': True,
    'loop_optimization': True,
    'arithmetic_optimization': True,
    'remapping': True,
}


def _write_scalars(writer, tagged_metrics, step):
    """Write the current result of each (tag, metric) pair to `writer` at `step`."""
    with writer.as_default():
        for tag, metric in tagged_metrics:
            tf.summary.scalar(tag, metric.result(), step=step)


def _reset_metrics(metrics):
    """Clear the accumulated state of every metric in `metrics`."""
    for metric in metrics:
        metric.reset_states()


def _is_due(step, interval):
    """Return True when `step` falls on a positive `interval`; 0 disables it."""
    return interval > 0 and step % interval == 0


best_val = 1e10
# Continue counting from the step stored in the checkpoint object.
iterations = checkpoint.step.numpy()

for image, labels in training_dataset:
    if iterations > training_iterations:
        break

    checkpoint.step.assign_add(1)
    iterations += 1

    with options(GRAPH_OPTIMIZER_OPTIONS):
        with tf.GradientTape() as tape:
            output = model(image, training=True)

            loc_loss, conf_loss, mask_loss, mask_iou_loss, seg_loss, total_loss = \
                criterion(model, output, labels, number_of_classes + 1)
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        training_loss.update_state(total_loss)

    location.update_state(loc_loss)
    config.update_state(conf_loss)
    mask.update_state(mask_loss)
    mask_iou.update_state(mask_iou_loss)
    segmentation.update_state(seg_loss)

    _write_scalars(
        train_summary_writer,
        [('Total loss', training_loss),
         ('Loc loss', location),
         ('Conf loss', config),
         ('Mask loss', mask),
         ('Mask IOU loss', mask_iou),
         ('Seg loss', segmentation)],
        step=iterations)

    if _is_due(iterations, print_interval):
        logging.info(f'Iteration {iterations}, LR: {optimizer._decayed_lr(var_dtype=tf.float32)}, ' +
                     f'Total loss: {training_loss.result()}, B: {location.result()}, ' +
                     f'C: {config.result()}, M: {mask.result()}, I: {mask_iou.result()}, ' +
                     f'S: {segmentation.result()}')

    if _is_due(iterations, save_interval):
        save_path = manager.save()

        logging.info(
            f'Saved checkpoint for step {int(checkpoint.step)} to {save_path}')

        # Evaluate at most `validation_iterations` validation batches.
        for validation_iterator, (validation_image, validation_labels) in \
                enumerate(validation_dataset):
            if validation_iterator >= validation_iterations:
                break

            with options(GRAPH_OPTIMIZER_OPTIONS):
                output = model(validation_image, training=False)

                validation_location_loss, validation_config_loss, validation_mask_loss, \
                    validation_mask_iou_loss, validation_segmentation_loss, validation_total_loss = \
                    criterion(model, output, validation_labels,
                              number_of_classes + 1)

            # Only the loss values are consumed below. The per-image box and
            # mask post-processing that used to follow was never read, so it
            # is not computed here.
            validation_loss.update_state(validation_total_loss)
            v_loc.update_state(validation_location_loss)
            v_conf.update_state(validation_config_loss)
            v_mask.update_state(validation_mask_loss)
            v_mask_iou.update_state(validation_mask_iou_loss)
            v_seg.update_state(validation_segmentation_loss)

        _write_scalars(
            test_summary_writer,
            [('V Total loss', validation_loss),
             ('V Loc loss', v_loc),
             ('V Conf loss', v_conf),
             ('V Mask loss', v_mask),
             ('V Mask IOU loss', v_mask_iou),
             ('V Seg loss', v_seg)],
            step=iterations)

        train_template = ("Iteration {}, Train Loss: {}, Loc Loss: {},  "
                          "Conf Loss: {}, Mask Loss: {}, Mask IOU Loss: {}, Seg Loss: {}")

        valid_template = ("Iteration {}, Valid Loss: {}, V Loc Loss: {},  "
                          "V Conf Loss: {}, V Mask Loss: {}, V Mask IOU Loss: {}, "
                          "Seg Loss: {}")

        # Report the same step number that the checkpoint and TensorBoard use.
        logging.info(train_template.format(iterations,
                                           training_loss.result(),
                                           location.result(),
                                           config.result(),
                                           mask.result(),
                                           mask_iou.result(),
                                           segmentation.result()))
        logging.info(valid_template.format(iterations,
                                           validation_loss.result(),
                                           v_loc.result(),
                                           v_conf.result(),
                                           v_mask.result(),
                                           v_mask_iou.result(),
                                           v_seg.result()))

        # Export a SavedModel whenever the validation loss improves; the loss
        # value becomes part of the directory name.
        if validation_loss.result() < best_val:
            best_val = validation_loss.result()

            save_options = tf.saved_model.SaveOptions(
                namespace_whitelist=['Addons'])
            model.save(os.path.join(saved_model_directory, 'saved_model_'
                                    + str(validation_loss.result().numpy())), options=save_options)

    # Metrics are cleared after every step, so each logged value reflects only
    # the step (or validation pass) that has just finished.
    _reset_metrics([training_loss, location, config, mask, mask_iou,
                    segmentation])
    _reset_metrics([validation_loss, v_loc, v_conf, v_mask, v_mask_iou,
                    v_seg])


# Inference

In [None]:
# Run the exported SavedModel on a single image and write an annotated copy.
#
# NOTE: this rebinds `model` (the trainable Yolact instance created earlier)
# to the loaded SavedModel; re-run the model-creation cell before training
# again in the same kernel.
model = tf.saved_model.load('./saved_models/saved_model_0.17916931')
infer = model.signatures["serving_default"]

img = cv2.imread('test.jpg')
if img is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError("Could not read image 'test.jpg'")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Resized to 550x550, the image size configured at the top of the notebook.
img = cv2.resize(img, (550, 550)).astype(np.float32)
output = infer(tf.constant(img[None, ...]))

_h = img.shape[0]
_w = img.shape[1]

# Cast to a plain int so the value is valid for slicing and `range`
# regardless of the dtype the signature returns.
det_num = int(output['num_detections'][0].numpy())
det_boxes = output['detection_boxes'][0][:det_num]
# Boxes are scaled by [h, w, h, w] and drawn below as (y1, x1, y2, x2).
det_boxes = det_boxes.numpy()*np.array([_h,_w,_h,_w])
det_masks = output['detection_masks'][0][:det_num].numpy()

det_scores = output['detection_scores'][0][:det_num].numpy()
det_classes = output['detection_classes'][0][:det_num].numpy()

for i in range(det_num):
    score = det_scores[i]
    if score > 0.5:
        box = det_boxes[i].astype(int)
        _class = det_classes[i]
        cv2.rectangle(img, (box[1], box[0]), (box[3], box[2]), (0, 255, 0), 2)
        cv2.putText(img, str(_class)+'; '+str(round(score,2)), (box[1], box[0]), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), lineType=cv2.LINE_AA)
        # Named `instance_mask` so the `mask` loss metric used by the
        # training cells is not overwritten.
        instance_mask = cv2.resize(det_masks[i], (_w, _h))
        instance_mask = (instance_mask > 0.5)
        # Inside the mask keep only the third (blue) channel of the RGB image.
        img[instance_mask] = img[instance_mask].astype("uint8")*[0,0,1]

cv2.imwrite("out.jpg", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))