In [1]:
import argparse
from datetime import date
import os
import sys
import tensorflow as tf

In [2]:
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras.optimizers import Adam, SGD

from augmentor.color import VisualEffect
from augmentor.misc import MiscEffect
from model import efficientdet
from losses import smooth_l1, focal, smooth_l1_quad
from efficientnet import BASE_WEIGHTS_PATH, WEIGHTS_HASHES

In [3]:
# Restrict TensorFlow to a single physical GPU and let it grow GPU memory on
# demand instead of reserving the whole card up front. These settings can only
# be applied before TensorFlow has initialized its GPUs, so this cell must run
# before any model is built.
GPU_INDEX = 1  # index into the list of physical GPUs reported by TensorFlow

gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > GPU_INDEX:
    tf.config.experimental.set_visible_devices(gpus[GPU_INDEX], 'GPU')
    tf.config.experimental.set_memory_growth(device=gpus[GPU_INDEX], enable=True)
else:
    # Previously this indexed gpus[1] unconditionally and crashed with
    # IndexError on machines with fewer GPUs; keep TensorFlow's defaults instead.
    print('Requested GPU index {} but only {} GPU(s) found; '
          'leaving device visibility unchanged.'.format(GPU_INDEX, len(gpus)))

In [4]:
def makedirs(path):
    """
    Create ``path`` (including any missing parent directories) if needed.

    Succeeds silently when ``path`` already is a directory. Raises OSError
    (e.g. FileExistsError) when ``path`` exists but is not a directory, or when
    it cannot be created.

    Args
        path: Directory path to create.
    """
    # exist_ok=True is the Python 3 replacement for the old
    # "try makedirs / ignore OSError if it is already a directory" idiom.
    os.makedirs(path, exist_ok=True)

In [5]:
def get_session():
    """
    Construct a TF1-style session whose GPU memory grows on demand.

    ``tf.ConfigProto`` and ``tf.Session`` are not available at the top level
    of TensorFlow 2, which the rest of this notebook uses (``tf.config``,
    ``tf.summary.create_file_writer``), so the ``tf.compat.v1`` equivalents
    are used instead.

    Returns:
        A ``tf.compat.v1.Session`` configured with ``allow_growth=True``.
    """
    config = tf.compat.v1.ConfigProto()
    # Allocate GPU memory incrementally rather than grabbing it all at start-up.
    config.gpu_options.allow_growth = True
    return tf.compat.v1.Session(config=config)

In [6]:
def _tf_major_version(version):
    """Return the integer major component of a version string such as '2.3.0'."""
    return int(version.split('.')[0])


def create_callbacks(training_model, prediction_model, validation_generator, args):
    """
    Creates the callbacks to use during training.

    Args
        training_model: The model that is used for training (not used here;
            kept for interface compatibility).
        prediction_model: The model that should be used for validation.
        validation_generator: The generator for creating validation data, or None.
        args: parseargs args object. Reads tensorboard_dir, batch_size,
            evaluation, dataset_type, snapshots, snapshot_path and
            compute_val_loss.

    Returns:
        A list of callbacks used for training, in the order TensorBoard,
        evaluation, checkpoint (each one only when enabled).
    """
    callbacks = []

    tensorboard_callback = None

    if args.tensorboard_dir:
        # Compare the parsed major version: the previous string comparison
        # (VERSION > '2.0.0') was lexicographic and excluded TF 2.0.0 itself.
        if _tf_major_version(tf.version.VERSION) >= 2:
            # Make this the default writer so tf.summary calls made elsewhere
            # during training log into the same directory.
            file_writer = tf.summary.create_file_writer(args.tensorboard_dir)
            file_writer.set_as_default()
        tensorboard_callback = keras.callbacks.TensorBoard(
            log_dir=args.tensorboard_dir,
            histogram_freq=0,
            batch_size=args.batch_size,
            write_graph=True,
            write_grads=False,
            write_images=False,
            embeddings_freq=0,
            embeddings_layer_names=None,
            embeddings_metadata=None
        )
        callbacks.append(tensorboard_callback)

    if args.evaluation and validation_generator:
        # Import lazily so only the evaluator for the active dataset is loaded.
        if args.dataset_type == 'coco':
            from eval.coco import Evaluate
        else:
            from eval.pascal import Evaluate
        # Evaluate with the prediction model, not the training model.
        evaluation = Evaluate(validation_generator, prediction_model, tensorboard=tensorboard_callback)
        callbacks.append(evaluation)

    # save the model weights after every epoch
    if args.snapshots:
        # ensure directory created first; otherwise h5py will error after epoch.
        makedirs(args.snapshot_path)
        # {epoch}, {loss} and {val_loss} are filled in by ModelCheckpoint;
        # val_loss is only available when validation loss is computed.
        if args.compute_val_loss:
            file_pattern = f'{args.dataset_type}_{{epoch:02d}}_{{loss:.4f}}_{{val_loss:.4f}}.h5'
        else:
            file_pattern = f'{args.dataset_type}_{{epoch:02d}}_{{loss:.4f}}.h5'
        checkpoint = keras.callbacks.ModelCheckpoint(
            os.path.join(args.snapshot_path, file_pattern),
            verbose=1,
            save_weights_only=True,
        )
        callbacks.append(checkpoint)

    return callbacks


In [7]:
def create_generators(args):
    """
    Create generators for training and validation.

    Args
        args: parseargs object containing configuration for generators. Reads
            batch_size, phi, detect_text, detect_quadrangle, random_transform,
            dataset_type and the dataset-specific path arguments
            (pascal_path / annotations_path, classes_path, val_annotations_path /
            coco_path).

    Returns
        A ``(train_generator, validation_generator)`` tuple. For the 'csv'
        dataset, validation_generator is None when no --val-annotations-path
        was given.

    Raises
        ValueError: If args.dataset_type is not 'pascal', 'csv' or 'coco'.
    """
    common_args = {
        'batch_size': args.batch_size,
        'phi': args.phi,
        'detect_text': args.detect_text,
        'detect_quadrangle': args.detect_quadrangle
    }

    # Random augmentation is applied to the training generator only.
    if args.random_transform:
        misc_effect = MiscEffect()
        visual_effect = VisualEffect()
    else:
        misc_effect = None
        visual_effect = None

    if args.dataset_type == 'pascal':
        from generators.pascal import PascalVocGenerator
        # Single-class label map for this project, presumably overriding the
        # generator's default VOC class list. Each generator gets its own copy.
        pascal_classes = {'pack': 0}
        train_generator = PascalVocGenerator(
            args.pascal_path,
            'train',
            classes=dict(pascal_classes),
            skip_difficult=True,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            **common_args
        )

        validation_generator = PascalVocGenerator(
            args.pascal_path,
            'val',
            classes=dict(pascal_classes),
            skip_difficult=True,
            shuffle_groups=False,
            **common_args
        )
    elif args.dataset_type == 'csv':
        from generators.csv_ import CSVGenerator
        train_generator = CSVGenerator(
            args.annotations_path,
            args.classes_path,
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            **common_args
        )

        if args.val_annotations_path:
            validation_generator = CSVGenerator(
                args.val_annotations_path,
                args.classes_path,
                shuffle_groups=False,
                **common_args
            )
        else:
            validation_generator = None
    elif args.dataset_type == 'coco':
        # import here to prevent unnecessary dependency on cocoapi
        from generators.coco import CocoGenerator
        train_generator = CocoGenerator(
            args.coco_path,
            'train2017',
            misc_effect=misc_effect,
            visual_effect=visual_effect,
            group_method='random',
            **common_args
        )

        validation_generator = CocoGenerator(
            args.coco_path,
            'val2017',
            shuffle_groups=False,
            **common_args
        )
    else:
        raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

    return train_generator, validation_generator

In [8]:
def check_args(parsed_args):
    """
    Sanity-check parsed command-line arguments before any backend work starts.

    When a comma-separated --gpu list is supplied, the batch size must be at
    least the number of GPUs listed, so that every GPU receives some samples.

    Args
        parsed_args: Namespace returned by parser.parse_args().

    Returns
        parsed_args, unchanged.

    Raises
        ValueError: If batch_size is smaller than the number of listed GPUs.
    """
    # No --gpu (or an empty one) means there is nothing to validate.
    if not parsed_args.gpu:
        return parsed_args

    gpu_count = len(parsed_args.gpu.split(','))
    if parsed_args.batch_size < gpu_count:
        template = "Batch size ({}) must be equal to or higher than the number of GPUs ({})"
        raise ValueError(template.format(parsed_args.batch_size, gpu_count))

    return parsed_args


In [9]:
def parse_args(args):
    """
    Parse the command-line arguments for EfficientDet training.

    Args
        args: List of argument strings (e.g. sys.argv[1:]). Global options must
            come before the dataset sub-command (coco / pascal / csv) and its
            positional arguments, because argparse hands everything after the
            sub-command to the sub-parser.

    Returns
        The parsed argparse.Namespace, after validation by check_args().
    """
    today = str(date.today())
    parser = argparse.ArgumentParser(description='Simple training script for training an EfficientDet network.')
    subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')
    subparsers.required = True

    coco_parser = subparsers.add_parser('coco')
    coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')

    pascal_parser = subparsers.add_parser('pascal')
    pascal_parser.add_argument('pascal_path', help='Path to dataset directory (ie. /tmp/VOCdevkit).')

    csv_parser = subparsers.add_parser('csv')
    csv_parser.add_argument('annotations_path', help='Path to CSV file containing annotations for training.')
    csv_parser.add_argument('classes_path', help='Path to a CSV file containing class label mapping.')
    csv_parser.add_argument('--val-annotations-path',
                            help='Path to CSV file containing annotations for validation (optional).')
    parser.add_argument('--detect-quadrangle', help='If to detect quadrangle.', action='store_true', default=False)
    parser.add_argument('--detect-text', help='If is text detection task.', action='store_true', default=False)

    parser.add_argument('--snapshot',
                        help="Resume training from a snapshot, or 'imagenet' to load ImageNet backbone weights.")
    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true')
    parser.add_argument('--freeze-bn', help='Freeze training of BatchNormalization layers.', action='store_true')
    parser.add_argument('--weighted-bifpn', help='Use weighted BiFPN', action='store_true')

    parser.add_argument('--batch-size', help='Size of the batches.', default=1, type=int)
    parser.add_argument('--phi', help='Hyper parameter phi', default=0, type=int, choices=(0, 1, 2, 3, 4, 5, 6))
    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
    parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=50)
    parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=10000)
    parser.add_argument('--snapshot-path',
                        help='Path to store snapshots of models during training',
                        default='checkpoints/{}'.format(today))
    parser.add_argument('--tensorboard-dir', help='Log directory for Tensorboard output',
                        default='logs/{}'.format(today))
    parser.add_argument('--no-snapshots', help='Disable saving snapshots.', dest='snapshots', action='store_false')
    parser.add_argument('--no-evaluation', help='Disable per epoch evaluation.', dest='evaluation',
                        action='store_false')
    parser.add_argument('--random-transform', help='Randomly transform image and annotations.', action='store_true')
    parser.add_argument('--compute-val-loss', help='Compute validation loss during training', dest='compute_val_loss',
                        action='store_true')

    # Fit generator arguments
    parser.add_argument('--multiprocessing', help='Use multiprocessing in fit_generator.', action='store_true')
    parser.add_argument('--workers', help='Number of generator workers.', type=int, default=1)
    parser.add_argument('--max-queue-size', help='Queue length for multiprocessing workers in fit_generator.', type=int,
                        default=10)

    # Parse once (previously argv was parsed twice: once to print, once to return).
    parsed_args = parser.parse_args(args)
    # Echo the resolved configuration so the notebook output records the run settings.
    print(vars(parsed_args))
    return check_args(parsed_args)


In [10]:
def main(args=None):
    """
    Train an EfficientDet model from command-line style arguments.

    Args
        args: List of argument strings; defaults to sys.argv[1:].

    Returns
        The History object returned by model.fit().

    Raises
        ValueError: If --compute-val-loss is requested but no validation
            generator is available. This is checked before the model is built.
    """
    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # Optionally choose specific GPU(s). CUDA_VISIBLE_DEVICES only takes effect
    # if it is set before TensorFlow initializes its GPUs, so set it as early as
    # possible. NOTE(review): in this notebook the GPU setup cell has likely
    # already initialized the GPUs, in which case this assignment has no effect.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # create the generators
    train_generator, validation_generator = create_generators(args)

    # Fail fast, before the slow model build, if validation loss was requested
    # without any validation data.
    if args.compute_val_loss and validation_generator is None:
        raise ValueError('When you have no validation data, you should not specify --compute-val-loss.')

    num_classes = train_generator.num_classes()
    num_anchors = train_generator.num_anchors

    model, prediction_model = efficientdet(args.phi,
                                           num_classes=num_classes,
                                           num_anchors=num_anchors,
                                           weighted_bifpn=args.weighted_bifpn,
                                           freeze_bn=args.freeze_bn,
                                           detect_quadrangle=args.detect_quadrangle
                                           )
    # load pretrained weights
    if args.snapshot:
        if args.snapshot == 'imagenet':
            model_name = 'efficientnet-b{}'.format(args.phi)
            file_name = '{}_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'.format(model_name)
            file_hash = WEIGHTS_HASHES[model_name][1]
            weights_path = keras.utils.get_file(file_name,
                                                BASE_WEIGHTS_PATH + file_name,
                                                cache_subdir='models',
                                                file_hash=file_hash)
            model.load_weights(weights_path, by_name=True)
        else:
            print('Loading model, this may take a second...')
            model.load_weights(args.snapshot, by_name=True)

    # freeze backbone layers
    if args.freeze_backbone:
        # Exclusive end index of the layers to freeze for phi = 0..6; layer 0
        # is skipped. Presumably these counts cover the EfficientNet backbone.
        backbone_end_index = [227, 329, 329, 374, 464, 566, 656][args.phi]
        for layer in model.layers[1:backbone_end_index]:
            layer.trainable = False

    if args.gpu and len(args.gpu.split(',')) > 1:
        # NOTE(review): keras.utils.multi_gpu_model is deprecated and was removed
        # in later TF 2.x releases; tf.distribute.MirroredStrategy replaces it.
        model = keras.utils.multi_gpu_model(model, gpus=list(map(int, args.gpu.split(','))))

    # compile model; `learning_rate` is the TF2 Keras argument name (`lr` is deprecated).
    model.compile(optimizer=Adam(learning_rate=1e-3), loss={
        'regression': smooth_l1_quad() if args.detect_quadrangle else smooth_l1(),
        'classification': focal()
    }, )

    # create the callbacks (the evaluation callback uses validation_generator
    # regardless of --compute-val-loss)
    callbacks = create_callbacks(
        model,
        prediction_model,
        validation_generator,
        args,
    )

    # Keras only receives validation data when validation loss is requested.
    fit_validation_data = validation_generator if args.compute_val_loss else None

    # start training
    return model.fit(
        x=train_generator,
        steps_per_epoch=args.steps,
        initial_epoch=0,
        epochs=args.epochs,
        verbose=1,
        callbacks=callbacks,
        workers=args.workers,
        use_multiprocessing=args.multiprocessing,
        max_queue_size=args.max_queue_size,
        validation_data=fit_validation_data
    )


In [11]:
# Training run configuration, written as a command-line argument list.
# Global options must come before the dataset sub-command and its path.
args = [
    '--snapshot', 'imagenet',      # start from the efficientnet-b{phi} ImageNet weights
    '--phi', '5',
    '--gpu', '1',
    '--random-transform',
    '--compute-val-loss',
    '--freeze-backbone',
    '--epochs', '20',
    '--steps', '500',
    'pascal', '/ai/data/VOC2012',  # dataset sub-command and its root directory
]
main(args)

{'dataset_type': 'pascal', 'detect_quadrangle': False, 'detect_text': False, 'snapshot': 'imagenet', 'freeze_backbone': True, 'freeze_bn': False, 'weighted_bifpn': False, 'batch_size': 1, 'phi': 5, 'gpu': '1', 'epochs': 20, 'steps': 500, 'snapshot_path': 'checkpoints/2020-08-04', 'tensorboard_dir': 'logs/2020-08-04', 'snapshots': True, 'evaluation': True, 'random_transform': True, 'compute_val_loss': True, 'multiprocessing': False, 'workers': 1, 'max_queue_size': 10, 'pascal_path': '/ai/data/VOC2012'}
Epoch 1/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:46 Time:  0:00:46
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=9791.0, num_tp=109.0
7612 instances of class pack with average precision: 0.0003
mAP: 0.0003

Epoch 00001: saving model to checkpoints/2020-08-04/pascal_01_0.5363_1.5162.h5
Epoch 2/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=9789.0, num_tp=111.0
7612 instances of class pack with average precision: 0.0003
mAP: 0.0003

Epoch 00002: saving model to checkpoints/2020-08-04/pascal_02_0.2618_1.5086.h5
Epoch 3/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3380.0, num_tp=6520.0
7612 instances of class pack with average precision: 0.8452
mAP: 0.8452

Epoch 00003: saving model to checkpoints/2020-08-04/pascal_03_0.2225_0.3804.h5
Epoch 4/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3081.0, num_tp=6819.0
7612 instances of class pack with average precision: 0.8840
mAP: 0.8840

Epoch 00004: saving model to checkpoints/2020-08-04/pascal_04_0.1997_0.2577.h5
Epoch 5/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3063.0, num_tp=6837.0
7612 instances of class pack with average precision: 0.8896
mAP: 0.8896

Epoch 00005: saving model to checkpoints/2020-08-04/pascal_05_0.1677_0.2143.h5
Epoch 6/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3046.0, num_tp=6854.0
7612 instances of class pack with average precision: 0.8896
mAP: 0.8896

Epoch 00006: saving model to checkpoints/2020-08-04/pascal_06_0.1651_0.2407.h5
Epoch 7/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3068.0, num_tp=6832.0
7612 instances of class pack with average precision: 0.8874
mAP: 0.8874

Epoch 00007: saving model to checkpoints/2020-08-04/pascal_07_0.1552_0.1959.h5
Epoch 8/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3062.0, num_tp=6838.0
7612 instances of class pack with average precision: 0.8904
mAP: 0.8904

Epoch 00008: saving model to checkpoints/2020-08-04/pascal_08_0.1445_0.2214.h5
Epoch 9/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3103.0, num_tp=6797.0
7612 instances of class pack with average precision: 0.8843
mAP: 0.8843

Epoch 00009: saving model to checkpoints/2020-08-04/pascal_09_0.1414_0.2113.h5
Epoch 10/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3056.0, num_tp=6844.0
7612 instances of class pack with average precision: 0.8908
mAP: 0.8908

Epoch 00010: saving model to checkpoints/2020-08-04/pascal_10_0.1335_0.1892.h5
Epoch 11/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3042.0, num_tp=6858.0
7612 instances of class pack with average precision: 0.8933
mAP: 0.8933

Epoch 00011: saving model to checkpoints/2020-08-04/pascal_11_0.1312_0.1705.h5
Epoch 12/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3035.0, num_tp=6865.0
7612 instances of class pack with average precision: 0.8944
mAP: 0.8944

Epoch 00012: saving model to checkpoints/2020-08-04/pascal_12_0.1253_0.1971.h5
Epoch 13/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3006.0, num_tp=6894.0
7612 instances of class pack with average precision: 0.8990
mAP: 0.8990

Epoch 00013: saving model to checkpoints/2020-08-04/pascal_13_0.1190_0.1902.h5
Epoch 14/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3010.0, num_tp=6890.0
7612 instances of class pack with average precision: 0.8980
mAP: 0.8980

Epoch 00014: saving model to checkpoints/2020-08-04/pascal_14_0.1141_0.1893.h5
Epoch 15/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3009.0, num_tp=6891.0
7612 instances of class pack with average precision: 0.8977
mAP: 0.8977

Epoch 00015: saving model to checkpoints/2020-08-04/pascal_15_0.1152_0.1799.h5
Epoch 16/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:38 Time:  0:00:38
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3011.0, num_tp=6889.0
7612 instances of class pack with average precision: 0.8983
mAP: 0.8983

Epoch 00016: saving model to checkpoints/2020-08-04/pascal_16_0.1084_0.1674.h5
Epoch 17/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3022.0, num_tp=6878.0
7612 instances of class pack with average precision: 0.8961
mAP: 0.8961

Epoch 00017: saving model to checkpoints/2020-08-04/pascal_17_0.1099_0.1729.h5
Epoch 18/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=2999.0, num_tp=6901.0
7612 instances of class pack with average precision: 0.9003
mAP: 0.9003

Epoch 00018: saving model to checkpoints/2020-08-04/pascal_18_0.1062_0.1540.h5
Epoch 19/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=3059.0, num_tp=6841.0
7612 instances of class pack with average precision: 0.8893
mAP: 0.8893

Epoch 00019: saving model to checkpoints/2020-08-04/pascal_19_0.1007_0.1659.h5
Epoch 20/20

Running network: 100% (99 of 99) |#######| Elapsed Time: 0:00:37 Time:  0:00:37
Parsing annotations: 100% (99 of 99) |###| Elapsed Time: 0:00:00 Time:  0:00:00


num_fp=2998.0, num_tp=6902.0
7612 instances of class pack with average precision: 0.9003
mAP: 0.9003

Epoch 00020: saving model to checkpoints/2020-08-04/pascal_20_0.0928_0.1898.h5


<tensorflow.python.keras.callbacks.History at 0x7fdf8c25eba8>