
QuickNet no-top models' pretrained weights are not working as expected #157

Closed
itayalfia opened this issue Apr 24, 2020 · 1 comment · Fixed by #158
Labels
bug Something isn't working

Comments

@itayalfia

Describe the bug

I tried freezing the no-top QuickNet models and training a linear classifier on top of them, in order to classify images from the Imagenette dataset (10 easy classes from ImageNet).

Because the pretrained zoo models are trained on ImageNet, a superset of this dataset, I expected the pretrained embedders to perform very well, but they did not reach even 50% accuracy.

However, when I manually cut the full models just before global pooling, the resulting embedders work as expected and easily reach 95%, suggesting the problem lies in the no-top pretrained weights.
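
One quick way to confirm this, not part of the original report, is to compare the two sets of pretrained weights directly. A minimal sketch, assuming both variants enumerate their weight tensors in the same order and that "activation" is the last layer before global pooling (as in the reproduction script below):

import numpy as np
from tensorflow import keras
from larq_zoo.sota import QuickNet

# Build both pretrained variants of the same architecture.
no_top = QuickNet(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
full = QuickNet(input_shape=(224, 224, 3), include_top=True, weights="imagenet")

# Cut the full model just before global pooling.
cut = keras.Model(inputs=full.inputs, outputs=full.get_layer("activation").output)

# With correct no-top weights, every tensor should match.
matches = [np.allclose(a, b) for a, b in zip(no_top.get_weights(), cut.get_weights())]
print(f"{sum(matches)}/{len(matches)} weight tensors match")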

To Reproduce

Run the code below with the following configurations:
QuickNetBugTest cut_full_model=True
QuickNetBugTest cut_full_model=False
QuickNetLargeBugTest cut_full_model=True
QuickNetLargeBugTest cut_full_model=False
QuickNetXLBugTest cut_full_model=True
QuickNetXLBugTest cut_full_model=False

(The cut_full_model parameter determines whether the pretrained no-top model is used, or the pretrained full model is loaded and cut just before global pooling. The models are trained for 3 epochs in this example, but the no-top variant does not improve much even with longer training. An example invocation follows.)
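
For example, assuming the script below is saved as repro.py (the filename is my assumption; the report does not name it), each configuration runs through the zookeeper CLI as:

python repro.py QuickNetBugTest cut_full_model=True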

import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
from larq_zoo.training.data import preprocess_image_bytes
from larq_zoo.sota import QuickNet, QuickNetLarge, QuickNetXL
from zookeeper import cli, task, Field
from typing import Callable, Tuple, Optional


class EmbedderWrapperModel(keras.Model):
    def __init__(self, zoo_class: Callable[..., keras.Model],
                 input_shape: int, num_classes: int, dynamic=False,
                 finetune_basenet=True, pretrained_basenet=True, cut_layer_name: Optional[str] = None):
        super(EmbedderWrapperModel, self).__init__(dynamic=dynamic)

        self.basenet = self._get_basenet(zoo_class, input_shape, finetune_basenet, pretrained_basenet, cut_layer_name)
        # Pool over the full spatial extent of the basenet's output feature map.
        global_pool_shape = self.basenet.output_shape[1], self.basenet.output_shape[2]

        self.batch_norm = keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
        self.global_pool = keras.layers.AveragePooling2D(pool_size=global_pool_shape)
        self.flatten = keras.layers.Flatten()
        self.dense_softmax = keras.layers.Dense(num_classes, activation=tf.nn.softmax)

    def _get_basenet(self, zoo_class: Callable[..., keras.Model], input_shape: int,
                     finetune_basenet: bool, pretrained_basenet: bool, cut_layer_name: Optional[str]) -> keras.Model:
        weights = "imagenet" if pretrained_basenet else None

        if not cut_layer_name:
            # Use the published pretrained no-top model directly.
            basenet = zoo_class(input_shape=(input_shape, input_shape, 3), include_top=False, weights=weights)
        else:
            # Load the full pretrained model and cut it at the given layer's output.
            full_zoo_model = zoo_class(input_shape=(input_shape, input_shape, 3), include_top=True, weights=weights)
            inputs, outputs = full_zoo_model.inputs, full_zoo_model.get_layer(cut_layer_name).output
            basenet = keras.Model(inputs=inputs, outputs=outputs)

        basenet.trainable = finetune_basenet

        return basenet

    def call(self, inputs, training=False, mask=None):
        x = self.basenet(inputs, training=training)
        x = self.batch_norm(x, training=training)
        x = self.global_pool(x)
        x = self.flatten(x)
        x = self.dense_softmax(x)

        return x


def wrap_preprocessing(preprocessing: Callable, training=False) -> Callable:
    return lambda x, y: (preprocessing(x, training), y)


def get_imagenette_dataset(batch_size: int, preprocessing: Callable,
                           parallel=True) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    decoders = {"image": tfds.decode.SkipDecoding()}
    total_dataset = tfds.load('imagenette', split=None, shuffle_files=True, as_supervised=True, decoders=decoders)
    train, test = total_dataset['train'], total_dataset['validation']

    parallelism = tf.data.experimental.AUTOTUNE if parallel else None

    train_dataset = (
        train.cache()
            .shuffle(10 * batch_size, reshuffle_each_iteration=True)
            .map(wrap_preprocessing(preprocessing, training=True), num_parallel_calls=parallelism)
            .batch(batch_size)
    )

    test_dataset = (
        test.cache()
            .map(wrap_preprocessing(preprocessing), num_parallel_calls=parallelism)
            .batch(batch_size)
    )

    return train_dataset, test_dataset


class BugTest:
    INPUT_SHAPE = 224
    CLASSES_NUM = 10

    EPOCHS = 3
    BATCH_SIZE = 256
    LEARNING_RATE = 1e-2

    LR_DECAY = 0.1
    DECAY_EVERY = 30

    FINETUNE_BASENET = False
    PRETRAINED_BASENET = True

    def test_bug(self, cut_full_model: bool, model_class: Callable):
        # `activation` is the last ReLU layer before global pooling
        cut_layer_name = "activation" if cut_full_model else None

        model = EmbedderWrapperModel(model_class, self.INPUT_SHAPE, self.CLASSES_NUM,
                                     finetune_basenet=self.FINETUNE_BASENET, pretrained_basenet=self.PRETRAINED_BASENET,
                                     cut_layer_name=cut_layer_name)

        train_dataset, test_dataset = get_imagenette_dataset(self.BATCH_SIZE, preprocess_image_bytes)

        optimizer = keras.optimizers.Adam(self.LEARNING_RATE)
        loss = keras.losses.SparseCategoricalCrossentropy()
        metrics = [keras.metrics.SparseCategoricalAccuracy()]

        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        model.fit(train_dataset, epochs=self.EPOCHS, validation_data=test_dataset)


@task
class QuickNetBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNet)


@task
class QuickNetLargeBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNetLarge)


@task
class QuickNetXLBugTest(BugTest):
    cut_full_model: bool = Field(False)

    def run(self):
        self.test_bug(self.cut_full_model, QuickNetXL)


if __name__ == "__main__":
    cli()

Expected behavior

I expected the pretrained no-top models and the cut pretrained full models to perform the same; instead I got the following discrepancy:

Model          no_top accuracy    full_model_cut accuracy
QuickNet       47.5%              95.4%
QuickNetLarge  30.2%              96%
QuickNetXL     27.7%              97.7%

Environment

TensorFlow version: 2.2.0rc3
tensorflow-datasets version: 3.0.0
Larq version: 0.9.4
Larq-Zoo version: 1.0.b4

lgeiger added the bug label on Apr 24, 2020

lgeiger (Member) commented Apr 24, 2020

Thanks for catching this and sorry for the confusing behaviour! It looks like some of our published no-top weights are broken. I added a test in #158 to verify that some of the pretrained weights have mismatches.
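
Until fixed weights are published, a possible workaround, following the reporter's approach above (the helper name and defaults are hypothetical), is to build the no-top embedder from the full pretrained model rather than via include_top=False:

from typing import Callable

from tensorflow import keras


def cut_pretrained_embedder(zoo_class: Callable[..., keras.Model],
                            input_shape: int = 224,
                            cut_layer_name: str = "activation") -> keras.Model:
    # Load the full pretrained model, then cut it just before global pooling.
    full = zoo_class(input_shape=(input_shape, input_shape, 3),
                     include_top=True, weights="imagenet")
    return keras.Model(inputs=full.inputs,
                       outputs=full.get_layer(cut_layer_name).output)

For instance, cut_pretrained_embedder(QuickNet) can stand in for QuickNet(include_top=False, weights="imagenet") until the published no-top weights are fixed.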
