QuickNet model and flip_ratio metric do not work together #148

Closed
itayalfia opened this issue Apr 11, 2020 · 3 comments · Fixed by #149
Labels
bug Something isn't working

@itayalfia

Describe the bug

When building a model that includes QuickNet inside a flip_ratio metrics scope,
model creation fails because of mismatched dimensions:
Dimensions must be equal, but are 64 and 128 for 'Equal' (op: 'Equal') with input shapes: [3,3,64,64], [3,3,128,128].

My suspicion is that one quantizer is created and reused for the entire model, and flip_ratio looks at the same quantizer with inputs of different shapes and fails because of this.
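
The suspected failure mode can be illustrated with a toy, framework-free sketch. `ToyFlipRatio` below is a hypothetical stand-in and not Larq's metric: it only assumes that the metric remembers the previous values it saw and compares new values against them elementwise, which is what the `Equal` op in the error message suggests.

```python
class ToyFlipRatio:
    """Hypothetical stand-in for a stateful flip-ratio metric (not Larq's code)."""

    def __init__(self):
        self.previous = None  # internal state, sized by the first input seen

    def update(self, values):
        if self.previous is None:
            self.previous = list(values)
            return 0.0
        if len(values) != len(self.previous):
            raise ValueError(
                f"Dimensions must be equal, but are {len(self.previous)} and {len(values)}"
            )
        # Fraction of entries that differ from the previously stored values.
        flips = sum(a != b for a, b in zip(self.previous, values))
        self.previous = list(values)
        return flips / len(values)


# One metric shared by two "layers" with different weight counts fails.
shared = ToyFlipRatio()
shared.update([1, -1, 1, -1])
try:
    shared.update([1, -1, 1, -1, 1, -1, 1, -1])
    shared_error = None
except ValueError as err:
    shared_error = str(err)

# One metric per layer keeps each state consistent with its own shape.
per_layer = [ToyFlipRatio(), ToyFlipRatio()]
per_layer[0].update([1, -1, 1, -1])
per_layer[1].update([1] * 8)
ratio = per_layer[0].update([1, 1, 1, -1])  # one of four signs changed
```

If the suspicion is right, the 64-filter and 128-filter kernels in the traceback play the role of the two differently sized lists here.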

To Reproduce

from functools import partial
from typing import Callable, Tuple

import larq
import tensorflow as tf
from tensorflow import keras
from larq_zoo.sota import QuickNet

INPUT_SHAPE = 32
CLASSES_NUM = 10

EPOCHS = 100
BATCH_SIZE = 128
LEARNING_RATE = 5e-3


def quicknet(input_shape: int, num_classes: int) -> keras.Model:
    quicknet_spatial_reduce_factor = 32
    # Integer division: pool_size must be an int, not a float.
    global_pool_shape = input_shape // quicknet_spatial_reduce_factor

    quicknet_pretrained_base = QuickNet(input_shape=(input_shape, input_shape, 3), include_top=False)
    quicknet_pretrained_base.trainable = True

    return tf.keras.models.Sequential([
        quicknet_pretrained_base,

        keras.layers.AveragePooling2D(pool_size=(global_pool_shape, global_pool_shape)),
        keras.layers.Flatten(),
        keras.layers.Dense(num_classes, kernel_initializer="glorot_normal"),
        tf.keras.layers.Activation("softmax", dtype="float32")
    ])


def get_dataset(batch_size: int, preprocessing: Callable) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    train, test = tf.keras.datasets.cifar10.load_data()

    train_dataset = (
        tf.data.Dataset.from_tensor_slices(train).cache()
            .shuffle(10 * batch_size, reshuffle_each_iteration=True)
            .map(partial(preprocessing, training=True))
            .batch(batch_size)
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices(test).cache()
            .map(preprocessing)
            .batch(batch_size)
    )

    return train_dataset, test_dataset


def identity_preprocess(x, y, training=False):
    return x, y


if __name__ == "__main__":
    with larq.context.metrics_scope(['flip_ratio']):
        model = quicknet(INPUT_SHAPE, CLASSES_NUM)

    larq.models.summary(model)

    train_dataset, test_dataset = get_dataset(BATCH_SIZE, identity_preprocess)

    optimizer = keras.optimizers.Adam(LEARNING_RATE)
    loss = keras.losses.SparseCategoricalCrossentropy()

    model.compile(
        optimizer=optimizer, loss=loss, metrics=[keras.metrics.SparseCategoricalAccuracy()]
    )

    model.fit(
        train_dataset, epochs=EPOCHS, validation_data=test_dataset
    )

Expected behavior

I would have expected the example to run without problems, as happens when quicknet is replaced with a series of binary operations, but instead I get the following error:

Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/BNN-Playground/bug_replicate.py", line 59, in <module>
    model = quicknet(INPUT_SHAPE, CLASSES_NUM)
  File "C:/Users/User/PycharmProjects/BNN-Playground/bug_replicate.py", line 21, in quicknet
    quicknet_pretrained_base = QuickNet(input_shape=(input_shape, input_shape, 3), include_top=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\larq_zoo\sota\quicknet.py", line 327, in QuickNet
    num_classes=num_classes,
  File "C:\Users\User\Anaconda3\lib\site-packages\zookeeper\core\factory.py", line 20, in wrapped_fn
    result = fn(factory_instance)
  File "C:\Users\User\Anaconda3\lib\site-packages\larq_zoo\sota\quicknet.py", line 189, in build
    model = super().build()
  File "C:\Users\User\Anaconda3\lib\site-packages\zookeeper\core\factory.py", line 20, in wrapped_fn
    result = fn(factory_instance)
  File "C:\Users\User\Anaconda3\lib\site-packages\larq_zoo\sota\quicknet.py", line 156, in build
    x = self.residual_block(x, use_squeeze_and_excite)
  File "C:\Users\User\Anaconda3\lib\site-packages\larq_zoo\sota\quicknet.py", line 104, in residual_block
    x = self.conv_block(x, infilters, use_squeeze_and_excite)
  File "C:\Users\User\Anaconda3\lib\site-packages\larq_zoo\sota\quicknet.py", line 92, in conv_block
    )(x)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 773, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\autograph\impl\api.py", line 237, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in converted code:

    C:\Users\User\Anaconda3\lib\site-packages\larq\layers_base.py:37 call  *
        return super().call(inputs)
    C:\Users\User\Anaconda3\lib\site-packages\larq\layers_base.py:153 call  *
        return super().call(inputs)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\keras\layers\convolutional.py:209 call
        outputs = self._convolution_op(inputs, self.kernel)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\nn_ops.py:1135 __call__
        return self.conv_op(inp, filter)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\nn_ops.py:640 __call__
        return self.call(inp, filter)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\nn_ops.py:239 __call__
        name=self.name)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\nn_ops.py:2011 conv2d
        name=name)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_nn_ops.py:969 conv2d
        data_format=data_format, dilations=dilations, name=name)
    C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\op_def_library.py:486 _apply_op_helper
        (input_name, err))

    ValueError: Tried to convert 'filter' to a tensor and failed. Error: in converted code:
    
        C:\Users\User\Anaconda3\lib\site-packages\larq\quantizers.py:249 call  *
            return super().call(outputs)
        C:\Users\User\Anaconda3\lib\site-packages\larq\quantizers.py:160 call  *
            self.add_metric(self.flip_ratio(inputs))
        C:\Users\User\Anaconda3\lib\site-packages\larq\metrics.py:43 __call__  *
            return super().__call__(inputs, **kwargs)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\keras\metrics.py:196 __call__
            replica_local_fn, *args, **kwargs)
        C:\Users\User\Anaconda3\lib\site-packages\larq\metrics.py:71 update_state  *
            unchanged_values = tf.math.count_nonzero(
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\util\dispatch.py:180 wrapper
            return target(*args, **kwargs)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\math_ops.py:1305 equal
            return gen_math_ops.equal(x, y, name=name)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py:3240 equal
            name=name)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\op_def_library.py:742 _apply_op_helper
            attrs=attr_protos, op_def=op_def)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\func_graph.py:595 _create_op_internal
            compute_device)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\ops.py:3322 _create_op_internal
            op_def=op_def)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\ops.py:1786 __init__
            control_input_ops)
        C:\Users\User\Anaconda3\lib\site-packages\tensorflow_core\python\framework\ops.py:1622 _create_c_op
            raise ValueError(str(e))
    
        ValueError: Dimensions must be equal, but are 64 and 128 for 'Equal' (op: 'Equal') with input shapes: [3,3,64,64], [3,3,128,128].

Environment

TensorFlow version: 2.1.0
Larq version: 0.9.3
Larq-Zoo version: 1.0b4

@lgeiger
Member

lgeiger commented Apr 11, 2020

Thanks for the detailed issue.

My suspicion is that one quantizer is created and reused for the entire model, and flip_ratio looks at the same quantizer with inputs of different shapes and fails because of this.

I noticed that in a different case a few days ago, too. I think it is because we set the quantizers like this, so they are stored on the class, which fails since the metric has internal variables:

input_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
kernel_quantizer = Field(lambda: lq.quantizers.SteSign(clip_value=1.25))
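
The caching behaviour described above can be sketched without Zookeeper. `CachedField`, `ToyQuantizer`, and `ToyFactory` are hypothetical names written for this illustration; the sketch assumes only that a field evaluates its lambda once and then returns the stored result, and it is not Zookeeper's implementation.

```python
class CachedField:
    """Hypothetical field: calls the default factory once and caches the result."""

    def __init__(self, default_factory):
        self.default_factory = default_factory

    def __set_name__(self, owner, name):
        self.cache_name = "_cached_" + name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        # Only the first access runs the lambda; later accesses reuse the object.
        if not hasattr(instance, self.cache_name):
            setattr(instance, self.cache_name, self.default_factory())
        return getattr(instance, self.cache_name)


class ToyQuantizer:
    """Placeholder for a quantizer object that carries internal state."""


class ToyFactory:
    kernel_quantizer = CachedField(lambda: ToyQuantizer())


factory = ToyFactory()
first = factory.kernel_quantizer   # e.g. handed to the first conv layer
second = factory.kernel_quantizer  # e.g. handed to the second conv layer
same_object = first is second      # both layers would share one quantizer
```

Under that assumption, every layer built by the factory receives the same quantizer object, and with it the same metric state.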

@lgeiger
Member

lgeiger commented Apr 11, 2020

@AdamHillier do you think we should change zookeeper.Field to have the same semantics as @property if initialized with a lambda function or used as a decorator, so that we get a new instance on every access to prevent issues like this? I came across a similar issue last week as well when trying to use quantizers with learnable parameters.

Another solution would be to change quantizers here to normal @property instead of having them as fields, so that they are not cached on the class instance.
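
A minimal sketch of the `@property` alternative, again with hypothetical toy classes rather than the real QuickNet factory or Larq quantizers: because the property body runs on every access, each caller gets its own quantizer and its own internal state.

```python
class ToyQuantizer:
    """Placeholder for a quantizer with internal, shape-dependent state."""

    def __init__(self, clip_value):
        self.clip_value = clip_value
        self.state = None


class ToyFactory:
    @property
    def kernel_quantizer(self):
        # Evaluated on every access, so nothing is cached on the instance.
        return ToyQuantizer(clip_value=1.25)


factory = ToyFactory()
q1 = factory.kernel_quantizer
q2 = factory.kernel_quantizer
q1.state = "shape [3, 3, 64, 64]"  # state set by one layer
distinct = q1 is not q2            # the second layer's quantizer is unaffected
```

The trade-off is that two accesses never return the same object, so any code that relies on sharing one quantizer would have to hold on to the instance explicitly.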

@itayalfia We'll take a closer look at this next week, since this is something we've come across before and should fix properly. For now you can use #149, which should provide an intermediate fix for the issue.

@AdamHillier
Contributor

@AdamHillier do you think we should change zookeeper.Field to have the same semantics as @property if initialized with a lambda function or used as a decorator, so that we get a new instance on every access to prevent issues like this? I came across a similar issue last week as well when trying to use quantizers with learnable parameters.

I've made a Zookeeper issue for this here: larq/zookeeper#134.
