# DETR for TensorFlow

This notebook is a friendly tool for implementing my DETR object detection and multi-instance classification models on the COCO dataset.

My models are coded in TensorFlow from first principles, as presented in the paper [End-to-End Object Detection with Transformers](https://ai.facebook.com/research/publications/end-to-end-object-detection-with-transformers) by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, and Sergey Zagoruyko.

In [None]:
#"""
# Automatically reload imported modules when their source changes
# (useful while debugging the custom imports).
%load_ext autoreload
%autoreload 2
#"""

In [None]:
# Google Drive integration (Colab only): mounts Drive at /content/drive.
# Used for model checkpointing (also for data loading if not using GCS).
from google.colab import drive
drive.mount('/content/drive')
#"""

Mounted at /content/drive


In [None]:
# Tensorflow
import tensorflow as tf
!pip install -q tensorflow-addons
import tensorflow_addons as tfa
tf.config.optimizer.set_jit(enabled=True)
!pip install -U tensorboard-plugin-profile

# computation
import pandas as pd
import numpy as np

# file system
import sys
import os
import glob
import shutil
import json
from zipfile import ZipFile
!pip install -q wget

# Visualization
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
%matplotlib inline

[K     |████████████████████████████████| 1.1 MB 4.0 MB/s 
[?25hCollecting tensorboard-plugin-profile
  Downloading tensorboard_plugin_profile-2.5.0-py3-none-any.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 4.2 MB/s 
Collecting gviz-api>=1.9.0
  Downloading gviz_api-1.10.0-py2.py3-none-any.whl (13 kB)
Installing collected packages: gviz-api, tensorboard-plugin-profile
Successfully installed gviz-api-1.10.0 tensorboard-plugin-profile-2.5.0
  Building wheel for wget (setup.py) ... [?25l[?25hdone


In [None]:
# custom imports
sys.path.insert(0, '/content/drive/MyDrive/GitHub/DETR_for_TF/ModelComponents')  # if using Google Drive
import model
import model_pretrainer
import learning_rate_schedulers
import parameters
import datasets
import pipeline

Load Modules

In [None]:
# Dataset name shared by the parameter and file-path helpers below.
dataset_name = 'COCO'
model_parameters = parameters.ModelParameters(dataset_name=dataset_name)
# Default settings; unpacked later into pipeline.Pipeline and model.DETR.
params = model_parameters.default_params()


# Queried later for the checkpoint save directory.
filepaths = parameters.Filepaths(dataset_name=dataset_name)
# Strategy options requested with mixed precision enabled. In this notebook
# STRATEGY is only mentioned by the disabled `with STRATEGY.scope():` block.
strategies = parameters.StrategyOptions(mixed_precision=True)
STRATEGY = strategies.strategy()

## Load data

In [None]:
# COCO dataset helper. Judging by the argument names, archives are read from
# the mounted Drive folder and the working copy lives under /content
# (TODO confirm against datasets.COCOStandard).
coco = datasets.COCOStandard(local_base_dir='/content',
                             archive_base_dir='/content/drive/MyDrive/datasets/')

In [None]:
# Prepare the local COCO files with downloading off and unzipping on.
# Per the log output, folders that were already extracted are reused unless
# force_rebuild=True.
coco.get_data(download=False, unzip=True, force_rebuild=False)

extracting: train2017.zip
/content/COCO/images/train found. Using previously extracted data. (Note: set force_rebuild=True to override)
/content/COCO/annotations/train found. Using previously extracted data. (Note: set force_rebuild=True to override)
/content/COCO/annotations/train found. Using previously extracted data. (Note: set force_rebuild=True to override)
/content/COCO/annotations/train found. Using previously extracted data. (Note: set force_rebuild=True to override)
extracting: val2017.zip
/content/COCO/images/test found. Using previously extracted data. (Note: set force_rebuild=True to override)
extracting: test2017.zip
/content/COCO/images/test found. Using previously extracted data. (Note: set force_rebuild=True to override)


Prepare dataframes

In [None]:
# Per-subset info built from the COCO JSON files. Later cells read the
# 'annotations_df' entry of each result.
# NOTE(review): force_rebuild=False presumably reuses a cached result — confirm.
all_info_train = coco.prepare_COCO_from_json(subset='train', force_rebuild=False)
all_info_valid = coco.prepare_COCO_from_json(subset='val', force_rebuild=False)

In [None]:
# Report the number of rows in each split's annotations dataframe.
for split_label, split_info in (('train', all_info_train), ('valid', all_info_valid)):
    print(f'{split_label} samples:', len(split_info['annotations_df']))

Create TF Datasets

In [None]:
# Data pipeline configured from the default parameters; used below to build
# the training and validation datasets.
data_pipeline = pipeline.Pipeline(**params)
# Augmentation helper applied to the training dataset only (see next cell).
image_augmentations = pipeline.Augmentations()

In [None]:
# Options shared by the training and validation generators.
generator_options = dict(decode_images=True, stream_from_directory=False)

# Training data: a plain version and an augmented version derived from it.
ds_train = data_pipeline.data_generator(
    labels_df=all_info_train['annotations_df'],
    **generator_options,
)
ds_train_augmented = image_augmentations.apply_image_augmentations(ds_train)

# Validation data is never augmented.
ds_valid = data_pipeline.data_generator(
    labels_df=all_info_valid['annotations_df'],
    **generator_options,
)

## Prepare Model

Set Checkpoints

In [None]:
# Both models checkpoint under the same root, each in its own sub-folder.
checkpoint_root = filepaths.default_params('checkpoint_save_dir')

CLASS_CHECKPOINT_DIR = os.path.join(checkpoint_root, 'classification')
DETECTION_CHECKPOINT_DIR = os.path.join(checkpoint_root, 'detection')

class_checkpoint_path = os.path.join(CLASS_CHECKPOINT_DIR, 'coco_class.ckpt')
detection_checkpoint_path = os.path.join(DETECTION_CHECKPOINT_DIR, 'coco_detect.ckpt')

# Keras callbacks that save weights only (not the full model) during fit().
class_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    class_checkpoint_path,
    save_weights_only=True,
)
detection_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    detection_checkpoint_path,
    save_weights_only=True,
)

Examine Data

In [None]:
#"""
# Print the feature names (dict keys) of one augmented training example.
for val in ds_train_augmented.take(1):
    print(val.keys())

# Last expression of the cell: displays the repr of the batched dataset.
# The result is not assigned, so ds_train itself is unchanged.
ds_train.batch(1)
#"""

Box Visualization

In [None]:
def show_example(val, verbose=False):
    """Plot one un-batched example together with its bounding boxes.

    Args:
        val: dict with the entries 'image_id', 'num_objects', 'image',
            'bbox', 'category', 'attribute', 'width' and 'height'. Only the
            first ``num_objects`` rows of bbox/category/attribute are used.
        verbose: when True, print the image id, the object count and the
            shapes of the image/bbox/category/attribute tensors first.

    Returns:
        The ``matplotlib.pyplot`` module (after ``plt.show()`` was called).

    Each box row is read as (xmin, ymin, width, height) and scaled by the
    displayed image's width/height, i.e. the values are treated as fractions
    of the image size. Rows whose category equals ``[b'<PAD>']`` are skipped;
    for every drawn box the category and attribute are printed.
    """
    example_id = val['image_id'].numpy()
    object_count = val['num_objects'].numpy()
    picture = val['image']
    box_tensor = val['bbox'][:object_count, ...]
    category_tensor = val['category'][:object_count, ...]
    attribute_tensor = val['attribute'][:object_count, ...]
    source_width = val['width'].numpy()
    source_height = val['height'].numpy()

    if verbose:
        print('image_id:', example_id, 'num_obj:', object_count)
        print('image:', picture.shape, 'bbox:', box_tensor.shape,
              'category:', category_tensor.shape, 'attribute:', attribute_tensor.shape)

    # display image
    fig = plt.figure()
    axes = plt.gca()
    # NOTE(review): aspect is the original height/width ratio; this presumably
    # assumes the stored image was resized to a square — confirm.
    plt.imshow(picture, aspect=source_height/source_width)

    # pixel size of the displayed image and plain-Python copies of the labels
    pixel_height = picture.shape[-3]
    pixel_width = picture.shape[-2]
    box_rows = box_tensor.numpy().tolist()
    category_rows = category_tensor.numpy().tolist()
    attribute_rows = attribute_tensor.numpy().tolist()

    # add boxes to image
    for idx, box_row in enumerate(box_rows):
        label = category_rows[idx]

        # skip paddings
        if label == [b'<PAD>']:
            continue
        attrs = attribute_rows[idx]

        # report info
        print(f'box {idx}', 'category:', label)
        print('attribute:', attrs)

        # convert fractional box values to pixel units for matplotlib
        rel_left, rel_top, rel_width, rel_height = box_row
        left = rel_left * pixel_width
        box_width = rel_width * pixel_width
        top = rel_top * pixel_height
        box_height = rel_height * pixel_height

        outline = Rectangle((left, top), box_width, box_height,
                            alpha=1, fill=False,
                            label=label,
                            color=np.random.random(3))
        axes.add_patch(outline)

    fig.legend(loc='lower right', ncol=4)
    plt.show()

    return plt

In [None]:
#"""
# training image: draw the first (un-augmented) training example with its boxes
for val in ds_train.take(1):
    show_example(val, verbose=True)
#"""

In [None]:
# Smoke test: push a dummy batch of zeros through a detection model.
# The model is created here so this cell also runs on a fresh kernel; it is
# rebuilt, compiled and restored from checkpoint in the model-loading cell below.
detection_model = model.DETR(**params, attribute_weight=0.0)  # attributes not provided in COCO
smoke_outputs = detection_model({'image': tf.zeros([5, 350, 650, 3])})

# Display only the output shapes instead of dumping every predicted value.
tf.nest.map_structure(lambda tensor: tensor.shape, smoke_outputs)

(<tf.Tensor: shape=(5, 96, 1), dtype=string, numpy=
 array([[[b'<PAD>'],
         [b'<PAD>'],
         [b'<PAD>'],
         [b'<PAD>'],
         [b'sandwich'],
         [b'sandwich'],
         [b'sandwich'],
         [b'vase'],
         [b'<PAD>'],
         [b'parking meter'],
         [b'refrigerator'],
         [b'<PAD>'],
         [b'hot dog'],
         [b'hot dog'],
         [b'hot dog'],
         [b'orange'],
         [b'hot dog'],
         [b'hot dog'],
         [b'orange'],
         [b'orange'],
         [b'hot dog'],
         [b'bottle'],
         [b'hot dog'],
         [b'hot dog'],
         [b'<PAD>'],
         [b'motorcycle'],
         [b'sandwich'],
         [b'hot dog'],
         [b'<PAD>'],
         [b'motorcycle'],
         [b'wine glass'],
         [b'hot dog'],
         [b'spoon'],
         [b'<PAD>'],
         [b'spoon'],
         [b'motorcycle'],
         [b'bird'],
         [b'toothbrush'],
         [b'toothbrush'],
         [b'bird'],
         [b'bird'],
         [

### Load Detection & Classification Models

In [None]:
LOAD_CLASS_WEIGHTS = False
# If True, the classifier's latest checkpoint is also restored at the end of this cell.
# If False, the models use only the most recent detection checkpoint weights.

# Distributed computing is disabled. To enable it, indent the rest of this
# cell under:
#     with STRATEGY.scope():

# DETECTION MODEL
# load base model and build
detection_model = model.DETR(**params, attribute_weight=0.0)  # attributes not provided in COCO

# build: run one small validation batch through the inference and training paths
for val in ds_valid.batch(3).take(1):
    out_detect_0 = detection_model(val)
    out_detect_1 = detection_model(val, training=True)

# learning rate
lr = tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=.001, first_decay_steps=4000, m_mul=.95, alpha=0.1)

#optimizer_detect = tfa.optimizers.AdamW(learning_rate=.0001, weight_decay=.001, clipnorm=0.1)  # NOTE: suspect that using this with mixed precision causes NaNs from underflow/overflow
optimizer_detect = tf.keras.optimizers.SGD(learning_rate=lr, momentum=.9,
                                           nesterov=True, clipnorm=0.1)


# compile
detection_model.compile(optimizer=optimizer_detect)  # loss functions are built in

# load weights
# tf.train.latest_checkpoint returns None when the directory holds no
# checkpoint yet (e.g. first run); skip the restore instead of crashing.
detect_checkpoint_filename = tf.train.latest_checkpoint(DETECTION_CHECKPOINT_DIR)
if detect_checkpoint_filename is None:
    print(f'No detection checkpoint found in {DETECTION_CHECKPOINT_DIR}; '
          'keeping freshly initialized weights.')
else:
    detection_model.load_weights(detect_checkpoint_filename)

# CLASSIFICATION
# initialize: the classifier wraps the detection model as its base model
classification_model = model_pretrainer.DETR_MultiClassifier(base_model=detection_model,
                                                    vocab_dict=model_parameters.vocab_dict('COCO'),
                                                    hidden_dim=128,
                                                    name='COCO_Classifier_DETR')
# build
for val in ds_valid.batch(3).take(1):
    out_class_0 = classification_model(val)

# compile
lr = learning_rate_schedulers.LRScheduleAIAYN(100.0)
optimizer_class = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=.001)

classification_model.compile(optimizer=optimizer_class)

# load weights
if LOAD_CLASS_WEIGHTS:
    class_checkpoint_filename = tf.train.latest_checkpoint(CLASS_CHECKPOINT_DIR)
    # Same guard as above: nothing to restore if no classifier checkpoint exists.
    if class_checkpoint_filename is None:
        print(f'No classification checkpoint found in {CLASS_CHECKPOINT_DIR}; '
              'classifier weights were not restored.')
    else:
        classification_model.load_weights(class_checkpoint_filename)

In [None]:
# Disabled cell: the code sits inside a string literal and does not run.
# Remove the surrounding triple quotes to print the classifier summary.
"""
# examine
classification_model.summary()
"""

In [None]:
# examine: print the detection model's summary
detection_model.summary()

# Training

### Train Classifier Model

In [None]:
# Disabled cell: the classifier training code sits inside a string literal
# and does not run. Remove the surrounding triple quotes to enable it.
"""
# train classifier
NUM_EPOCHS = 50
BATCH_SIZE = 16

ds = ds_train_augmented.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
validation_data = ds_valid.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE) 

classification_model.fit(ds,
                         epochs=NUM_EPOCHS, 
                         #validation_data=validation_data,
                         callbacks=[class_checkpoint, tf.keras.callbacks.TerminateOnNaN()],
                         #steps_per_epoch=2
                         )
"""

In [None]:
# Disabled cell: the code sits inside a string literal and does not run.
# Remove the surrounding triple quotes to enable it.
"""
save_weights_to_base = True

# save updated weights into base model
if save_weights_to_base:
    # Model.save_weights returns None, so nothing may be chained onto it
    # (expect_partial() only exists on the status returned by load_weights).
    detection_model.save_weights(detection_checkpoint_path)
"""

### Train Detection Model

In [None]:
# Launch TensorBoard inline, reading event files from ./logs.
# NOTE(review): this should match the log_dir of the TensorBoard callback
# created (with default arguments) in the training cell — confirm.
%load_ext tensorboard
%tensorboard --logdir logs

In [None]:
## Train Detection Model
NUM_EPOCHS = 300
BATCH_SIZE = 8

# Batched, prefetched inputs: augmented training data and plain validation data.
ds = ds_train_augmented.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
#ds = ds_train.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
validation_data = ds_valid.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

# Checkpoint the weights, stop on NaN losses, and log for TensorBoard.
detection_callbacks = [
    detection_checkpoint,
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.TensorBoard(),
]

detection_model.fit(
    ds,
    epochs=NUM_EPOCHS,
    validation_data=validation_data,
    callbacks=detection_callbacks,
)

In [None]:
# Run the detection model in inference mode on one batch of two training
# examples and print the raw outputs (this leaves the global `val` bound to
# that batch).
for val in ds_train.batch(2).take(1):
    print(detection_model(val, training=False))


In [None]:
def show_prediction(val, model):
    """Run ``model`` on a batch and plot the predictions for its first example.

    Args:
        val: batched example dict (leading batch dimension on every entry),
            containing at least 'image_id', 'image', 'bbox', 'category',
            'attribute', 'width' and 'height'. It is not modified.
        model: callable returning ``(category, attribute, box_coords)`` for
            the batch, each with a leading batch dimension.

    Returns:
        Whatever ``show_example`` returns for the first example.
    """
    print('True Values:')
    print('image_id:', val['image_id'].shape)
    print('image:', val['image'].shape, 'bbox:', val['bbox'].shape,
          'category:', val['category'].shape, 'attribute:', val['attribute'].shape)

    category, attribute, box_coord_preds = model(val)

    # Work on a shallow copy so the caller's batch dict is left untouched.
    example = dict(val)

    # Take the first example of the batch and swap in the predictions.
    example['image_id'] = val['image_id'][0, ...]
    example['image'] = val['image'][0, ...]
    example['bbox'] = box_coord_preds[0, ...]
    example['category'] = category[0, ...]
    example['attribute'] = attribute[0, ...]

    # show_example reads width/height as un-batched values, so strip the
    # batch dimension from them as well.
    example['width'] = val['width'][0, ...]
    example['height'] = val['height'][0, ...]

    # Show every predicted slot (padding rows are skipped by show_example)
    # instead of relying on a hard-coded upper bound.
    num_predictions = int(tf.shape(category)[1])
    example['num_objects'] = tf.constant(num_predictions)

    outs = show_example(example, verbose=True)
    return outs

In [None]:
#"""
# Plot the detection model's predictions for the first training example
# (batch of one, as show_prediction expects a batched input).
for val in ds_train.batch(1).take(1):
    show_prediction(val, detection_model)
#"""