diff --git a/azure-pipelines.yml b/azure-pipelines.yml
new file mode 100644
index 00000000..60404a33
--- /dev/null
+++ b/azure-pipelines.yml
@@ -0,0 +1,51 @@
+trigger:
+  - master
+
+pr:
+  - master
+
+jobs:
+  - job: "Test"
+    pool:
+      vmImage: "Ubuntu-16.04"
+    strategy:
+      matrix:
+        Python37:
+          python.version: "3.7"
+          tensorflow.version: "1.13.1"
+        Python37TF2:
+          python.version: "3.7"
+          tensorflow.version: "2.0.0-alpha0"
+          coverage: "true"
+
+    steps:
+      - task: UsePythonVersion@0
+        inputs:
+          versionSpec: "$(python.version)"
+          architecture: "x64"
+
+      - script: |
+          pip install tensorflow==$(tensorflow.version)
+          pip install -e .[test]
+        displayName: "Install dependencies"
+
+      - script: pytest . --junitxml=junit/test-results.xml
+        displayName: "pytest"
+        condition: ne(variables['coverage'], 'true')
+
+      - script: pytest . --junitxml=junit/test-results.xml --cov=larq_zoo --cov-report=xml --cov-report=html --cov-config=.coveragerc
+        displayName: "pytest coverage"
+        condition: eq(variables['coverage'], 'true')
+
+      - task: PublishTestResults@2
+        condition: succeededOrFailed()
+        inputs:
+          testResultsFiles: "**/test-*.xml"
+          testRunTitle: "Publish test results for Python $(python.version) and TF $(tensorflow.version)"
+
+      - task: PublishCodeCoverageResults@1
+        condition: eq(variables['coverage'], 'true')
+        inputs:
+          codeCoverageTool: Cobertura
+          summaryFileLocation: "$(System.DefaultWorkingDirectory)/**/coverage.xml"
+          reportDirectory: "$(System.DefaultWorkingDirectory)/**/htmlcov"
diff --git a/larq_zoo/__init__.py b/larq_zoo/__init__.py
index d9966a63..a6705232 100644
--- a/larq_zoo/__init__.py
+++ b/larq_zoo/__init__.py
@@ -1,4 +1,4 @@
-from larq_zoo.binarynet import BinaryNet
+from larq_zoo.binarynet import BinaryAlexNet
 from larq_zoo.data import default as preprocess_imagenet
 
-__all__ = ["BinaryNet", "preprocess_imagenet"]
+__all__ = ["BinaryAlexNet", "preprocess_imagenet"]
diff --git a/larq_zoo/binarynet.py b/larq_zoo/binarynet.py
index c14d84f2..c6d79969 100644
--- a/larq_zoo/binarynet.py
+++ b/larq_zoo/binarynet.py
@@ -1,15 +1,145 @@
 from larq_flock import registry, HParams
+import larq as lq
+import tensorflow as tf
+from larq_zoo import utils
 
 
 @registry.register_model
-def binarynet(hparams, dataset):
-    pass
+def binary_alex_net(hparams, dataset, input_tensor=None, include_top=True):
+    kwargs = dict(
+        input_quantizer="ste_sign",
+        kernel_quantizer="ste_sign",
+        kernel_constraint="weight_clip",
+        use_bias=False,
+    )
+    img_input = utils.get_input_layer(dataset.input_shape, input_tensor)
+    x = lq.layers.QuantConv2D(
+        hparams.filters,
+        11,
+        strides=4,
+        padding="same",
+        kernel_quantizer="ste_sign",
+        kernel_constraint="weight_clip",
+        use_bias=False,
+    )(img_input)
+    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2)(x)
+    x = tf.keras.layers.BatchNormalization(scale=False)(x)
 
-@registry.register_hparams(binarynet)
+    x = lq.layers.QuantConv2D(hparams.filters * 3, 5, padding="same", **kwargs)(x)
+    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2)(x)
+    x = tf.keras.layers.BatchNormalization(scale=False)(x)
+
+    x = lq.layers.QuantConv2D(6 * hparams.filters, 3, padding="same", **kwargs)(x)
+    x = tf.keras.layers.BatchNormalization(scale=False)(x)
+
+    x = lq.layers.QuantConv2D(4 * hparams.filters, 3, padding="same", **kwargs)(x)
+    x = tf.keras.layers.BatchNormalization(scale=False)(x)
+
+    x = lq.layers.QuantConv2D(4 * hparams.filters, 3, padding="same", **kwargs)(x)
+    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2)(x)
+    x = tf.keras.layers.BatchNormalization(scale=False)(x)
+
+    if include_top:
+        x = tf.keras.layers.Flatten()(x)
+        x = lq.layers.QuantDense(hparams.dense_units, **kwargs)(x)
+        x = tf.keras.layers.BatchNormalization(scale=False)(x)
+        x = lq.layers.QuantDense(hparams.dense_units, **kwargs)(x)
+        x = tf.keras.layers.BatchNormalization(scale=False)(x)
+        x = lq.layers.QuantDense(dataset.num_classes, **kwargs)(x)
+        x = tf.keras.layers.BatchNormalization(scale=False)(x)
+        x = tf.keras.layers.Activation("softmax", name="predictions")(x)
+
+    # Ensure that the model takes into account
+    # any potential predecessors of `input_tensor`.
+    if input_tensor is not None:
+        inputs = tf.keras.utils.get_source_inputs(input_tensor)
+    else:
+        inputs = img_input
+
+    return tf.keras.models.Model(inputs, x, name="binary_alex_net")
+
+
+@registry.register_hparams(binary_alex_net)
 def default():
-    return HParams()
+    def lr_schedule(epoch):
+        if epoch < 20:
+            return 5e-3
+        elif epoch < 30:
+            return 1e-3
+        elif epoch < 35:
+            return 5e-4
+        elif epoch < 40:
+            return 1e-4
+        else:
+            return 1e-5
+
+    return HParams(
+        optimizer=tf.keras.optimizers.Adam(5e-3),
+        learning_rate_schedule=lr_schedule,
+        batch_size=256,
+        filters=64,
+        dense_units=4096,
+    )
+
+
+def BinaryAlexNet(
+    include_top=True,
+    weights="imagenet",
+    input_tensor=None,
+    input_shape=None,
+    classes=1000,
+):
+    """Instantiates the BinaryAlexNet architecture.
+
+    Optionally loads weights pre-trained on ImageNet.
+
+    # Arguments
+    include_top: whether to include the fully-connected layer at the top of the network.
+    weights: one of `None` (random initialization), "imagenet" (pre-training on
+        ImageNet), or the path to the weights file to be loaded.
+    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as
+        image input for the model.
+    input_shape: optional shape tuple, only to be specified if `include_top` is False
+        (otherwise the input shape has to be `(224, 224, 3)` (with `channels_last` data
+        format) or `(3, 224, 224)` (with `channels_first` data format)).
+        It should have exactly 3 input channels.
+    classes: optional number of classes to classify images into, only to be specified
+        if `include_top` is True, and if no `weights` argument is specified.
+
+    # Returns
+    A Keras model instance.
+
+    # Raises
+    ValueError: in case of invalid argument for `weights`, or invalid input shape.
+    """
+    input_shape = utils.validate_input(input_shape, weights, include_top, classes)
+    model = binary_alex_net(
+        default(),
+        utils.ImagenetDataset(input_shape, classes),
+        input_tensor=input_tensor,
+        include_top=include_top,
+    )
 
-def BinaryNet():
-    raise NotImplementedError()
+    # Load weights.
+    if weights == "imagenet":
+        raise NotImplementedError()
+        # if include_top:
+        #     weights_path = tf.keras.utils.get_file(
+        #         "vgg16_weights_tf_dim_ordering_tf_kernels.h5",
+        #         WEIGHTS_PATH,
+        #         cache_subdir="models",
+        #         file_hash="64373286793e3c8b2b4e3219cbf3544b",
+        #     )
+        # else:
+        #     weights_path = tf.keras.utils.get_file(
+        #         "vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5",
+        #         WEIGHTS_PATH_NO_TOP,
+        #         cache_subdir="models",
+        #         file_hash="6d6bbae143d832006294945121d1f1fc",
+        #     )
+        # model.load_weights(weights_path)
+    elif weights is not None:
+        model.load_weights(weights)
+    return model
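
For reference, the new `BinaryAlexNet` entry point can be smoke-tested as below — a minimal sketch (not part of the diff), using random weights since `weights="imagenet"` still raises `NotImplementedError`:

    import numpy as np
    import larq_zoo as lqz

    # Build with random weights; the pretrained ImageNet weights are not wired up yet.
    model = lqz.BinaryAlexNet(weights=None)

    # Forward pass on a dummy batch to sanity-check the (None, 1000) output shape.
    dummy_batch = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
    assert model.predict(dummy_batch).shape == (1, 1000)
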
diff --git a/larq_zoo/data.py b/larq_zoo/data.py
index 72018e5c..3e6c5087 100644
--- a/larq_zoo/data.py
+++ b/larq_zoo/data.py
@@ -1,6 +1,236 @@
+"""Provides utilities to preprocess images.
+
+The preprocessing steps for VGG were introduced in the following technical
+report:
+
+  Very Deep Convolutional Networks For Large-Scale Image Recognition
+  Karen Simonyan and Andrew Zisserman
+  arXiv technical report, 2015
+  PDF: http://arxiv.org/pdf/1409.1556.pdf
+  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
+  CC-BY-4.0
+
+More information can be obtained from the VGG website:
+www.robots.ox.ac.uk/~vgg/research/very_deep/
+"""
+
+import tensorflow as tf
 from larq_flock import registry
 
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94
+
+_R_STD = 0.229 * 255
+_G_STD = 0.224 * 255
+_B_STD = 0.225 * 255
+
+_RESIZE_SIDE_MIN = 256
+_RESIZE_SIDE_MAX = 512
+
+
+@registry.register_preprocess("imagenet2012", (224, 224, 3))
+# @registry.register_preprocess("oxford_iiit_pet", (224, 224, 3))
+def default(image, training):
+    return preprocess_image(
+        image=image, output_height=224, output_width=224, is_training=training
+    )
+
+
+def _get_h_w(image):
+    """Convenience for grabbing the height and width of an image.
+    """
+    shape = tf.shape(image)
+    return shape[0], shape[1]
+
+
+def _random_crop_and_flip(image, crop_height, crop_width):
+    """Crops the given image to a random part of the image, and randomly flips.
+
+    Args:
+      image: a 3-D image tensor
+      crop_height: the new height.
+      crop_width: the new width.
+
+    Returns:
+      3-D tensor with cropped image.
+
+    """
+    height, width = _get_h_w(image)
+
+    # Create a random bounding box.
+    #
+    # Use tf.random.uniform and not numpy.random.rand as doing the former would
+    # generate random numbers at graph eval time, unlike the latter which
+    # generates random numbers at graph definition time.
+    total_crop_height = height - crop_height
+    crop_top = tf.random.uniform([], maxval=total_crop_height + 1, dtype=tf.int32)
+    total_crop_width = width - crop_width
+    crop_left = tf.random.uniform([], maxval=total_crop_width + 1, dtype=tf.int32)
+
+    cropped = tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+
+    cropped = tf.image.random_flip_left_right(cropped)
+    return cropped
+
+
+def _central_crop(image, crop_height, crop_width):
+    """Performs a central crop of the given image.
+
+    Args:
+      image: a 3-D image tensor
+      crop_height: the height of the image following the crop.
+      crop_width: the width of the image following the crop.
+
+    Returns:
+      3-D tensor with cropped image.
+    """
+    height, width = _get_h_w(image)
+
+    total_crop_height = height - crop_height
+    crop_top = total_crop_height // 2
+    total_crop_width = width - crop_width
+    crop_left = total_crop_width // 2
+    return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+
+
+def _mean_image_subtraction(image, means):
+    """Subtracts the given means from each image channel.
+
+    For example:
+      means = [123.68, 116.779, 103.939]
+      image = _mean_image_subtraction(image, means)
+
+    Note that the rank of `image` must be known.
+
+    Args:
+      image: a tensor of size [height, width, C].
+      means: a C-vector of values to subtract from each channel.
+
+    Returns:
+      the centered image.
+
+    Raises:
+      ValueError: If the rank of `image` is unknown, if `image` has a rank other
+        than three or if the number of channels in `image` doesn't match the
+        number of values in `means`.
+    """
+    if image.get_shape().ndims != 3:
+        raise ValueError("Input must be of size [height, width, C>0]")
+    num_channels = image.get_shape().as_list()[-1]
+    if len(means) != num_channels:
+        raise ValueError("len(means) must match the number of channels")
+
+    # We have a 1-D tensor of means; convert to 3-D.
+    means = tf.expand_dims(tf.expand_dims(means, 0), 0)
+
+    return image - means
+
+
+def _scale_normalization(image, stds):
+    # We have a 1-D tensor of stds; convert to 3-D.
+    stds = tf.expand_dims(tf.expand_dims(stds, 0), 0)
+
+    return image / stds
+
+
+def _smallest_size_at_least(height, width, smallest_side):
+    """Computes new shape with the smallest side equal to `smallest_side`.
+
+    Computes new shape with the smallest side equal to `smallest_side` while
+    preserving the original aspect ratio.
+
+    Args:
+      height: an int32 scalar tensor indicating the current height.
+      width: an int32 scalar tensor indicating the current width.
+      smallest_side: A python integer or scalar `Tensor` indicating the size of
+        the smallest side after resize.
+
+    Returns:
+      new_height: an int32 scalar tensor indicating the new height.
+      new_width: an int32 scalar tensor indicating the new width.
+    """
+    smallest_side = tf.cast(smallest_side, tf.float32)
+
+    height = tf.cast(height, tf.float32)
+    width = tf.cast(width, tf.float32)
+
+    smaller_dim = tf.minimum(height, width)
+    scale_ratio = smallest_side / smaller_dim
+    new_height = tf.cast(height * scale_ratio, tf.int32)
+    new_width = tf.cast(width * scale_ratio, tf.int32)
+
+    return new_height, new_width
+
+
+def _aspect_preserving_resize(image, smallest_side):
+    """Resize images preserving the original aspect ratio.
+
+    Args:
+      image: A 3-D image `Tensor`.
+      smallest_side: A python integer or scalar `Tensor` indicating the size of
+        the smallest side after resize.
+
+    Returns:
+      resized_image: A 3-D tensor containing the resized image.
+    """
+    smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
+
+    height, width = _get_h_w(image)
+    new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
+
+    resized_image = tf.compat.v1.image.resize(
+        image,
+        [new_height, new_width],
+        method=tf.image.ResizeMethod.BILINEAR,
+        align_corners=False,
+    )
+    return resized_image
+
+
+def preprocess_image(
+    image,
+    output_height,
+    output_width,
+    is_training=False,
+    resize_side_min=_RESIZE_SIDE_MIN,
+    resize_side_max=_RESIZE_SIDE_MAX,
+):
+    """Preprocesses the given image.
+
+    Args:
+      image: A `Tensor` representing an image of arbitrary size.
+      output_height: The height of the image after preprocessing.
+      output_width: The width of the image after preprocessing.
+      is_training: `True` if we're preprocessing the image for training and
+        `False` otherwise.
+      resize_side_min: The lower bound for the smallest side of the image for
+        aspect-preserving resizing. If `is_training` is `False`, then this value
+        is used for rescaling.
+      resize_side_max: The upper bound for the smallest side of the image for
+        aspect-preserving resizing. If `is_training` is `False`, this value is
+        ignored. Otherwise, the resize side is sampled from
+        [resize_side_min, resize_side_max].
+
+    Returns:
+      A preprocessed image.
+    """
+    if is_training:
+        # For training, we want to randomize some of the distortions.
+        resize_side = tf.random.uniform(
+            [], minval=resize_side_min, maxval=resize_side_max + 1, dtype=tf.int32
+        )
+        crop_fn = _random_crop_and_flip
+    else:
+        resize_side = resize_side_min
+        crop_fn = _central_crop
+
+    num_channels = image.get_shape().as_list()[-1]
+    image = _aspect_preserving_resize(image, resize_side)
+    image = crop_fn(image, output_height, output_width)
+
+    image.set_shape([output_height, output_width, num_channels])
 
-@registry.register_preprocess("imagenet2012")
-def default(image):
-    return image
+    image = tf.cast(image, tf.float32)
+    image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
+    return _scale_normalization(image, [_R_STD, _G_STD, _B_STD])
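
The pipeline above follows the standard VGG recipe: aspect-preserving resize, crop (random crop plus flip for training, central crop for evaluation), then per-channel mean/std normalization. A minimal sketch of calling it directly (the dummy image and its size are illustrative, not part of the diff):

    import tensorflow as tf
    from larq_zoo.data import preprocess_image

    # Dummy 300x400 RGB image standing in for a decoded ImageNet example.
    image = tf.random.uniform([300, 400, 3], maxval=255, dtype=tf.float32)

    # Training: resize so the smallest side lands in [256, 512], then take a
    # random 224x224 crop with a random horizontal flip.
    train_image = preprocess_image(image, 224, 224, is_training=True)

    # Evaluation: deterministic resize to smallest side 256, then a central crop.
    eval_image = preprocess_image(image, 224, 224, is_training=False)

    # Both branches end with mean subtraction and std normalization.
    assert train_image.shape.as_list() == [224, 224, 3]
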
diff --git a/larq_zoo/train.py b/larq_zoo/train.py
index 33684bfa..3e4c8535 100644
--- a/larq_zoo/train.py
+++ b/larq_zoo/train.py
@@ -1,10 +1,52 @@
 from larq_flock import cli, build_train
+from os import path
+import click
 
 
 @cli.command()
+@click.option("--tensorboard/--no-tensorboard", default=True)
 @build_train
-def train(build_model, dataset, hparams, output_dir, epochs):
-    pass
+def train(build_model, dataset, hparams, output_dir, epochs, tensorboard):
+    import larq as lq
+    from larq_zoo.utils import get_distribution_scope
+    import tensorflow as tf
+
+    callbacks = [
+        tf.keras.callbacks.ModelCheckpoint(
+            filepath=path.join(output_dir, "model.ckpt"), save_weights_only=True
+        )
+    ]
+    if hasattr(hparams, "learning_rate_schedule"):
+        callbacks.append(
+            tf.keras.callbacks.LearningRateScheduler(hparams.learning_rate_schedule)
+        )
+    if tensorboard:
+        callbacks.append(
+            tf.keras.callbacks.TensorBoard(log_dir=output_dir, profile_batch=0, write_graph=False)
+        )
+
+    with get_distribution_scope(hparams.batch_size):
+        model = build_model(hparams, dataset)
+        model.compile(
+            optimizer=hparams.optimizer,
+            loss="categorical_crossentropy",
+            metrics=["categorical_accuracy", "top_k_categorical_accuracy"],
+        )
+
+    lq.models.summary(model)
+
+    model.fit(
+        dataset.train_data(hparams.batch_size),
+        epochs=epochs,
+        steps_per_epoch=dataset.train_examples // hparams.batch_size,
+        validation_data=dataset.validation_data(hparams.batch_size),
+        validation_steps=dataset.validation_examples // hparams.batch_size,
+        verbose=2 if tensorboard else 1,
+        callbacks=callbacks,
+    )
+
+    model.save(path.join(output_dir, f"{build_model.__name__}.h5"))
+    model.save_weights(path.join(output_dir, f"{build_model.__name__}_weights.h5"))
 
 
 if __name__ == "__main__":
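
Because the learning-rate schedule travels with the hyperparameters, it can be checked in isolation before committing to a full training run — a small sketch, assuming the registered `default` hparams from `larq_zoo.binarynet` are constructed directly:

    from larq_zoo.binarynet import default

    schedule = default().learning_rate_schedule

    # Piecewise-constant decay from 5e-3 down to 1e-5 at epochs 20/30/35/40.
    assert schedule(0) == 5e-3
    assert schedule(25) == 1e-3
    assert schedule(34) == 5e-4
    assert schedule(39) == 1e-4
    assert schedule(50) == 1e-5
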
+ """ + if image.get_shape().ndims != 3: + raise ValueError("Input must be of size [height, width, C>0]") + num_channels = image.get_shape().as_list()[-1] + if len(means) != num_channels: + raise ValueError("len(means) must match the number of channels") + + # We have a 1-D tensor of means; convert to 3-D. + means = tf.expand_dims(tf.expand_dims(means, 0), 0) + + return image - means + + +def _scale_normalization(image, stds): + # We have a 1-D tensor of means; convert to 3-D. + stds = tf.expand_dims(tf.expand_dims(stds, 0), 0) + + return image / stds + + +def _smallest_size_at_least(height, width, smallest_side): + """Computes new shape with the smallest side equal to `smallest_side`. + + Computes new shape with the smallest side equal to `smallest_side` while + preserving the original aspect ratio. + + Args: + height: an int32 scalar tensor indicating the current height. + width: an int32 scalar tensor indicating the current width. + smallest_side: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. + + Returns: + new_height: an int32 scalar tensor indicating the new height. + new_width: and int32 scalar tensor indicating the new width. + """ + smallest_side = tf.cast(smallest_side, tf.float32) + + height = tf.cast(height, tf.float32) + width = tf.cast(width, tf.float32) + + smaller_dim = tf.minimum(height, width) + scale_ratio = smallest_side / smaller_dim + new_height = tf.cast(height * scale_ratio, tf.int32) + new_width = tf.cast(width * scale_ratio, tf.int32) + + return new_height, new_width + + +def _aspect_preserving_resize(image, smallest_side): + """Resize images preserving the original aspect ratio. + + Args: + image: A 3-D image `Tensor`. + smallest_side: A python integer or scalar `Tensor` indicating the size of + the smallest side after resize. + + Returns: + resized_image: A 3-D tensor containing the resized image. + """ + smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) + + height, width = _get_h_w(image) + new_height, new_width = _smallest_size_at_least(height, width, smallest_side) + + resized_image = tf.compat.v1.image.resize( + image, + [new_height, new_width], + method=tf.image.ResizeMethod.BILINEAR, + align_corners=False, + ) + return resized_image + + +def preprocess_image( + image, + output_height, + output_width, + is_training=False, + resize_side_min=_RESIZE_SIDE_MIN, + resize_side_max=_RESIZE_SIDE_MAX, +): + """Preprocesses the given image. + + Args: + image: A `Tensor` representing an image of arbitrary size. + output_height: The height of the image after preprocessing. + output_width: The width of the image after preprocessing. + is_training: `True` if we're preprocessing the image for training and + `False` otherwise. + resize_side_min: The lower bound for the smallest side of the image for + aspect-preserving resizing. If `is_training` is `False`, then this value + is used for rescaling. + resize_side_max: The upper bound for the smallest side of the image for + aspect-preserving resizing. If `is_training` is `False`, this value is + ignored. Otherwise, the resize side is sampled from + [resize_size_min, resize_size_max]. + + Returns: + A preprocessed image. + """ + if is_training: + # For training, we want to randomize some of the distortions. 
diff --git a/setup.py b/setup.py
index 64368513..9f56c97b 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@ def readme():
     extras_require={
         "tensorflow": ["tensorflow>=1.13.1"],
         "tensorflow_gpu": ["tensorflow-gpu>=1.13.1"],
+        "test": ["pytest>=4.3.1", "pytest-cov>=2.6.1"],
     },
     entry_points="""
     [console_scripts]
diff --git a/tests/models_test.py b/tests/models_test.py
new file mode 100644
index 00000000..533e88db
--- /dev/null
+++ b/tests/models_test.py
@@ -0,0 +1,52 @@
+import pytest
+import functools
+import larq_zoo as lqz
+from tensorflow.keras.backend import clear_session
+
+
+def keras_test(func):
+    """Function wrapper to clean up after TensorFlow tests.
+    # Arguments
+        func: test function to clean up after.
+    # Returns
+        A function wrapping the input function.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        output = func(*args, **kwargs)
+        clear_session()
+        return output
+
+    return wrapper
+
+
+def parametrize(func):
+    func = keras_test(func)
+    return pytest.mark.parametrize("app,last_dim", [(lqz.BinaryAlexNet, 256)])(func)
+
+
+@parametrize
+def test_basic(app, last_dim):
+    model = app(weights=None)
+    assert model.output_shape == (None, 1000)
+
+
+@parametrize
+def test_no_top(app, last_dim):
+    model = app(weights=None, include_top=False)
+    assert model.output_shape == (None, None, None, last_dim)
+
+
+@parametrize
+def test_no_top_variable_shape_1(app, last_dim):
+    input_shape = (None, None, 1)
+    model = app(weights=None, include_top=False, input_shape=input_shape)
+    assert model.output_shape == (None, None, None, last_dim)
+
+
+@parametrize
+def test_no_top_variable_shape_4(app, last_dim):
+    input_shape = (None, None, 4)
+    model = app(weights=None, include_top=False, input_shape=input_shape)
+    assert model.output_shape == (None, None, None, last_dim)