
Change BN layer to use moving mean/var if frozen #9965

Closed
wants to merge 10 commits

Conversation

@datumbox
Contributor

datumbox commented Apr 17, 2018

During fine-tuning, if a Batch Normalization layer is frozen it uses the mini-batch statistics. I believe this is incorrect and that it can lead to reduced accuracy, especially when we use transfer learning. A better approach in this case would be to use the values of the moving mean and variance.

Changes on this PR:
In this PR I update the Batch Normalization layer to use the learned statistics if it is frozen during training. This is achieved by making the trainable flag part of the computational graph and by making the behavior of the BN depend not only on the learning_phase but also on the value of the trainable property.

Brief explanation:
Assume we use one of the pre-trained CNNs of Keras and we want to fine-tune it. Unfortunately, we get no guarantees that the mean and variance of our new dataset inside the BN layers will be similar to the ones of the original dataset. As a result, if we fine-tune the top layers, their weights will be adjusted to the mean/variance of the new dataset. Nevertheless, during inference the top layers will receive data which are scaled using the mean/variance of the original dataset. This discrepancy can lead to reduced accuracy.
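To make the discrepancy concrete, here is a toy numerical sketch (plain Python with made-up numbers, not Keras code): a frozen BN layer that normalizes with mini-batch statistics during training feeds the top layers activations on a very different scale than the same layer using the moving mean/variance at inference.

```python
# Hypothetical numbers: a frozen BN layer sees a mini-batch whose
# statistics differ from the moving statistics learned on the
# original dataset (e.g. ImageNet).
def bn(x, mean, var, gamma=1.0, beta=0.0, eps=1e-3):
    """Standard batch-norm transform: gamma * (x - mean) / sqrt(var + eps) + beta."""
    return [gamma * (v - mean) / (var + eps) ** 0.5 + beta for v in x]

batch = [4.0, 6.0, 8.0]                       # activations from the new dataset
batch_mean = sum(batch) / len(batch)          # 6.0
batch_var = sum((v - batch_mean) ** 2 for v in batch) / len(batch)  # ~2.67
moving_mean, moving_var = 0.0, 1.0            # statistics of the original dataset

train_out = bn(batch, batch_mean, batch_var)    # what the top layers see in training
infer_out = bn(batch, moving_mean, moving_var)  # what they see at inference

print(train_out)  # roughly [-1.22, 0.0, 1.22]
print(infer_out)  # roughly [4.0, 6.0, 8.0] -- a very different scale
```

The layers stacked after the BN are trained on the first distribution but evaluated on the second, which is the source of the accuracy drop this PR targets.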

I understand that this is a significant change that requires thorough review. To facilitate the review I've documented why making such a change is important and provided detailed comparisons before and after applying the patch on my blog.

EDIT: Since the fix was not merged on master, I maintain unofficial patches available for Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

@fchollet

Collaborator

fchollet commented Apr 19, 2018

Thanks for the effort.

You are misunderstanding the meaning of the "trainable" property of layers. Historically, it initially meant "this layer should not be trainable, i.e. the weights of this layer should not be updated during backprop (specifically, layer.trainable_weights should be empty)". It has since been extended to mean "the state of the layer should be frozen during training", which means that, in addition to the previous definition, layer updates are not run.
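In toy form (illustrative Python, not the actual Keras implementation), that two-part contract looks like this: freezing a layer both removes its weights from `trainable_weights` and suppresses its state updates.

```python
class ToyBNLayer:
    """Toy illustration of the two meanings of `trainable` described above.

    Not Keras code -- just the contract: when trainable is False,
    (1) trainable_weights is empty, and (2) state updates do not run.
    """
    def __init__(self):
        self.trainable = True
        self.weights = {"gamma": 1.0, "beta": 0.0}
        self.moving_mean = 0.0

    @property
    def trainable_weights(self):
        # Frozen layers expose no weights to the optimizer.
        return list(self.weights) if self.trainable else []

    def apply_updates(self, batch_mean, momentum=0.99):
        # Frozen layers also skip their state updates entirely.
        if not self.trainable:
            return
        self.moving_mean = momentum * self.moving_mean + (1 - momentum) * batch_mean

layer = ToyBNLayer()
layer.trainable = False
layer.apply_updates(batch_mean=5.0)
print(layer.trainable_weights, layer.moving_mean)  # [] 0.0
```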

What you want is a BN layer in inference mode. There is an argument to control training/inference mode in BN (and other layers): it's the training argument in call (boolean).

What you want is:

x = BatchNormalization()(y, training=False)

For fine-tuning, you could do something like:

# Set up inference-mode base
K.set_learning_phase(0)
inputs = Input(...)
x = layer1(...)(inputs)
x = layer2(...)(x)
...
x = layerN(...)(x)

# Add training-mode layers
K.set_learning_phase(1)
x = layerNp1(...)(x)
x = layerNp2(...)(x)
@fchollet fchollet closed this Apr 19, 2018
@datumbox

Contributor Author

datumbox commented Apr 19, 2018

Hi @fchollet,

First of all thanks for taking the time to review and respond. I was aware that this is a significant change in the default behaviour and that there would be debate. :)

I understand that your main concern is around the semantic meaning of the trainable property and how it is being used in this PR. I agree that semantically the training parameter that you proposed is closer to what I do; nevertheless, this parameter can't be changed after the network definition. For instance, when you use one of the pre-trained models of Keras or when you load a persisted model, you have no control over this variable. Would you be open to discussing a solution that would make the training variable changeable after the network definition (or perhaps another property)? If so, I could update my PR to reflect the agreed behaviour.

Concerning your second recommendation of updating the learning_phase as the network is defined, I see two limitations:

  1. Again, this will work only if the network is defined in code. It will not work for the pretrained models of Keras or when a model is loaded from disk. The latter is quite important; models are usually trained in multiple rounds after being restored from checkpoints.
  2. After setting learning_phase(1) in your example, the learning_phase will be static for the remainder of the session. This will override all the nice mechanisms that Keras has for switching between phases depending on whether it trains or predicts. Thus if we call fit() with validation data, the model will predict while in training mode.

I'm not sure if you had a look at the blog post (it is understandably a bit long), but you can see how significant a performance boost you get by making it possible to set the BN in inference mode. Without this, the trainable layers after the BNs adjust their weights to input that has a different scale compared to inference. I hope we can re-open this PR; I'm happy to update it until it satisfies the semantic definitions.

Cheers!

@fchollet

Collaborator

fchollet commented Apr 19, 2018

Again, there is an existing API that does exactly what you want: the training argument in call. There is no point in having two differently named APIs that do the exact same thing. layer.trainable = False is not what you need, therefore don't use it.

Additionally, your proposed PR adds a computational overhead (which might amount to a ~5% slowdown for a BN-heavy model like InceptionV3) to every single convnet that uses BN, fine-tuning or not. This is a heavy price to pay for supporting an incrementally simpler UX (disputable) for a very specific use case.

> For instance when you use one of the pre-trained models of keras or when you load a persisted model you have no control over this variable.

Typically if you want to heavily modify an existing model, rather than merely use it in inference mode, you should have access to the code for the model.

But even if you don't, you can still do your style of fine-tuning in this case:

  • set learning phase to 0
  • load model
  • retrieve features you want to train on
  • set learning phase to 1
  • add new layers on top
  • optionally load weights from initial model layers to corresponding new layers
  • train
@datumbox

Contributor Author

datumbox commented Apr 19, 2018

@fchollet My main point is that the training argument can't be changed after model definition, so the existing API does not cover this valid case. There are workarounds, but they are hacky/non-elegant, and the default behaviour causes much confusion for users. What you mention about the 5% slowdown is interesting; I would love to see the benchmarks; perhaps it can be resolved. Finally, something you don't address here is whether this discrepancy in the scaling makes sense (theoretically or otherwise) and whether the accuracy decrease is worth it.

In any case, let's agree to disagree. I do hope though that you will revise your decision in the future, as happened with the update of the mini-batch statistics in the BN.

@datumbox datumbox mentioned this pull request Apr 19, 2018
@fchollet

Collaborator

fchollet commented Apr 19, 2018

> I would love to see the benchmarks

This is based on something I've observed in the past for InceptionV3 with a static learning phase vs. a dynamic learning phase. The only difference between the two settings is the cond ops. Control flow seems pretty expensive, especially on GPU. Your PR adds the exact same number of cond ops, so I would expect the same overhead.

@datumbox

Contributor Author

datumbox commented Apr 20, 2018

Thanks for clarifying that you are referring to a different benchmark and not to something you ran on this PR. I can't comment on the results without seeing them, but when I ran comparisons on CIFAR10 the time difference was negligible (current branch: 4216 secs vs patched: 4251 secs); both ran on GPUs on the same server. Note that the snippet that I used (and listed in my article) comes from Keras' documentation on how to fine-tune a network.

Admittedly the above measurements are single-point estimates, but the 5-point accuracy increase I report in particular is consistent with what I've been observing for almost a year while applying workarounds (the first time I reported this is in #7177). I don't know if speed is currently your main concern for reopening this, but I would say that it is unlikely to affect the majority of Keras users. This is because by default the learning phase is dynamic and the training argument of call is None. This forces the in_train_phase method of the backend to use a switch statement that depends on the learning phase, so in a sense the "if" statement is already there.
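As a toy illustration of that switch (a simplified Python stand-in, not Keras' actual backend code): when training is None, the choice of statistics must be resolved at run time from the learning phase, so the conditional is part of the computation either way.

```python
_LEARNING_PHASE = 0  # module-level stand-in for K.learning_phase()

def set_learning_phase(value):
    global _LEARNING_PHASE
    _LEARNING_PHASE = value

def in_train_phase(train_fn, test_fn, training=None):
    """Toy stand-in for the backend's in_train_phase helper.

    training=True/False is resolved statically; training=None (the
    default) falls back to the dynamic learning phase -- i.e. the
    conditional is already there in the default configuration.
    """
    if training is None:
        training = bool(_LEARNING_PHASE)
    return train_fn() if training else test_fn()

batch_stats = lambda: "mini-batch stats"
moving_stats = lambda: "moving stats"

set_learning_phase(1)
print(in_train_phase(batch_stats, moving_stats))                  # mini-batch stats
print(in_train_phase(batch_stats, moving_stats, training=False))  # moving stats
```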

In any case, I don't insist that it should be me who changes this or that my current solution is the one we should use. I'm just raising a valid use case that is taken directly from Keras' documentation on how fine-tuning is performed. Currently there is no straightforward way to do what I describe (the current API doesn't cover it); nevertheless, if you provide specific guidelines on what tickboxes the update should check, that would be useful. Or perhaps some other longtime contributor to the BatchNormalization layer has an opinion or can offer a more elegant solution? @ozabluda @taehoonlee @farizrahman4u @Dref360

@ozabluda

Contributor

ozabluda commented Apr 23, 2018

Sorry for the late reply, I'm still trying to understand the issues. For example, I am trying to understand whether this is related at all to #9214

@ahundt

Contributor

ahundt commented Apr 23, 2018

What sort of batch sizes were you using in your linked experiments?

Some datasets are only viable with very small batch sizes of 1-4, for example image segmentation on a GPU with 8GB of memory. After briefly skimming this diff, I think the documentation would need to be updated to clearly delineate the different modes and when/why each should typically be chosen. In my case the current frozen behavior improved performance quite a lot over the previous behavior in which mean/var could shift when trainable=False, so I'm a bit hesitant about this, though I'll reiterate I haven't reviewed what's happening in full detail.

Here is a PR with some past discussion on BN #8616

@datumbox

Contributor Author

datumbox commented Apr 23, 2018

@ozabluda First of all, thank you for spending time on this. I wish I had provided in my PR the example that you posted in issue #9214; perhaps this would have built a stronger case for this patch. What you showed in your post is exactly what I've been observing on real-world non-opensource datasets for the last year (close to 100% accuracy in training mode and 50% during inference on the same dataset and on similar validation sets). As @fchollet said, there are lots of hacks that can help you avoid it, but none of them should have been necessary.

Based on the code you provided, I'm 100% certain you are being bitten by the behaviour of the BN layer that I'm trying to fix in this PR. In a nutshell, during training mode the frozen BN layers are scaled with different statistics than in inference mode. There is absolutely no theoretical foundation to support this behaviour. As a result, this can have devastating effects when you try to deploy the model or when you try to validate its accuracy. I am certain that the majority of people who face this believe they have overfitted the model while in reality this is just a side-effect of how Keras implements the Batch Normalization layer.

So let's test your example on my branch of Keras where the BN layer is patched:

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn

Below I run your code for ResNet50. As you can see the problem that you report is fixed once the BN behaviour is changed:

Epoch 1/100
50/50 [==============================] - 19s 387ms/step - loss: 0.8738 - acc: 0.5000 - val_loss: 1.3021 - val_acc: 0.5000
Epoch 2/100
50/50 [==============================] - 18s 367ms/step - loss: 1.3021 - acc: 0.5000 - val_loss: 0.9412 - val_acc: 0.5000
Epoch 3/100
50/50 [==============================] - 18s 363ms/step - loss: 0.9412 - acc: 0.5000 - val_loss: 0.6904 - val_acc: 0.5000
Epoch 4/100
50/50 [==============================] - 18s 364ms/step - loss: 0.6904 - acc: 0.5000 - val_loss: 0.9428 - val_acc: 0.5000
Epoch 5/100
50/50 [==============================] - 18s 361ms/step - loss: 0.9428 - acc: 0.5000 - val_loss: 0.9180 - val_acc: 0.5000
Epoch 6/100
50/50 [==============================] - 20s 401ms/step - loss: 0.9180 - acc: 0.5000 - val_loss: 0.7111 - val_acc: 0.5000
Epoch 7/100
50/50 [==============================] - 21s 415ms/step - loss: 0.7111 - acc: 0.5000 - val_loss: 0.6802 - val_acc: 0.5200
Epoch 8/100
50/50 [==============================] - 20s 406ms/step - loss: 0.6802 - acc: 0.5200 - val_loss: 0.8039 - val_acc: 0.5000
Epoch 9/100
50/50 [==============================] - 20s 391ms/step - loss: 0.8039 - acc: 0.5000 - val_loss: 0.8075 - val_acc: 0.5000
Epoch 10/100
50/50 [==============================] - 21s 425ms/step - loss: 0.8075 - acc: 0.5000 - val_loss: 0.6963 - val_acc: 0.5000
Epoch 11/100
50/50 [==============================] - 21s 417ms/step - loss: 0.6963 - acc: 0.5000 - val_loss: 0.6406 - val_acc: 0.7000
Epoch 12/100
50/50 [==============================] - 21s 419ms/step - loss: 0.6406 - acc: 0.7000 - val_loss: 0.7017 - val_acc: 0.5000
Epoch 13/100
50/50 [==============================] - 21s 425ms/step - loss: 0.7017 - acc: 0.5000 - val_loss: 0.7408 - val_acc: 0.5000
Epoch 14/100
50/50 [==============================] - 22s 441ms/step - loss: 0.7408 - acc: 0.5000 - val_loss: 0.6895 - val_acc: 0.5000
Epoch 15/100
50/50 [==============================] - 22s 432ms/step - loss: 0.6895 - acc: 0.5000 - val_loss: 0.6267 - val_acc: 0.7200
Epoch 16/100
50/50 [==============================] - 23s 460ms/step - loss: 0.6267 - acc: 0.7200 - val_loss: 0.6376 - val_acc: 0.5600
Epoch 17/100
50/50 [==============================] - 22s 439ms/step - loss: 0.6376 - acc: 0.5600 - val_loss: 0.6775 - val_acc: 0.5400
Epoch 18/100
50/50 [==============================] - 23s 456ms/step - loss: 0.6775 - acc: 0.5400 - val_loss: 0.6675 - val_acc: 0.5400
Epoch 19/100
50/50 [==============================] - 21s 414ms/step - loss: 0.6675 - acc: 0.5400 - val_loss: 0.6209 - val_acc: 0.6000
Epoch 20/100
50/50 [==============================] - 19s 375ms/step - loss: 0.6209 - acc: 0.6000 - val_loss: 0.6055 - val_acc: 0.7400
Epoch 21/100
50/50 [==============================] - 18s 367ms/step - loss: 0.6055 - acc: 0.7400 - val_loss: 0.6309 - val_acc: 0.5800
Epoch 22/100
50/50 [==============================] - 18s 370ms/step - loss: 0.6309 - acc: 0.5800 - val_loss: 0.6392 - val_acc: 0.5600
Epoch 23/100
50/50 [==============================] - 18s 369ms/step - loss: 0.6392 - acc: 0.5600 - val_loss: 0.6111 - val_acc: 0.6400
Epoch 24/100
50/50 [==============================] - 19s 390ms/step - loss: 0.6111 - acc: 0.6400 - val_loss: 0.5890 - val_acc: 0.7800
Epoch 25/100
50/50 [==============================] - 20s 394ms/step - loss: 0.5890 - acc: 0.7800 - val_loss: 0.5990 - val_acc: 0.6200
Epoch 26/100
50/50 [==============================] - 22s 445ms/step - loss: 0.5990 - acc: 0.6200 - val_loss: 0.6105 - val_acc: 0.5800
Epoch 27/100
50/50 [==============================] - 21s 413ms/step - loss: 0.6105 - acc: 0.5800 - val_loss: 0.5961 - val_acc: 0.6000
Epoch 28/100
50/50 [==============================] - 19s 388ms/step - loss: 0.5961 - acc: 0.6000 - val_loss: 0.5759 - val_acc: 0.8000
Epoch 29/100
50/50 [==============================] - 20s 391ms/step - loss: 0.5759 - acc: 0.8000 - val_loss: 0.5767 - val_acc: 0.7400
Epoch 30/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5767 - acc: 0.7400 - val_loss: 0.5857 - val_acc: 0.7400
Epoch 31/100
50/50 [==============================] - 22s 433ms/step - loss: 0.5857 - acc: 0.7400 - val_loss: 0.5785 - val_acc: 0.7600
Epoch 32/100
50/50 [==============================] - 19s 373ms/step - loss: 0.5785 - acc: 0.7600 - val_loss: 0.5627 - val_acc: 0.7800
Epoch 33/100
50/50 [==============================] - 21s 417ms/step - loss: 0.5627 - acc: 0.7800 - val_loss: 0.5597 - val_acc: 0.7800
Epoch 34/100
50/50 [==============================] - 21s 422ms/step - loss: 0.5597 - acc: 0.7800 - val_loss: 0.5651 - val_acc: 0.7000
Epoch 35/100
50/50 [==============================] - 18s 365ms/step - loss: 0.5651 - acc: 0.7000 - val_loss: 0.5606 - val_acc: 0.7200
Epoch 36/100
50/50 [==============================] - 18s 362ms/step - loss: 0.5606 - acc: 0.7200 - val_loss: 0.5488 - val_acc: 0.8000
Epoch 37/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5488 - acc: 0.8000 - val_loss: 0.5449 - val_acc: 0.7800
Epoch 38/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5449 - acc: 0.7800 - val_loss: 0.5473 - val_acc: 0.8000
Epoch 39/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5473 - acc: 0.8000 - val_loss: 0.5433 - val_acc: 0.8000
Epoch 40/100
50/50 [==============================] - 18s 368ms/step - loss: 0.5433 - acc: 0.8000 - val_loss: 0.5344 - val_acc: 0.8000
Epoch 41/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5344 - acc: 0.8000 - val_loss: 0.5311 - val_acc: 0.8600
Epoch 42/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5311 - acc: 0.8600 - val_loss: 0.5318 - val_acc: 0.7800
Epoch 43/100
50/50 [==============================] - 18s 366ms/step - loss: 0.5318 - acc: 0.7800 - val_loss: 0.5278 - val_acc: 0.7800
Epoch 44/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5278 - acc: 0.7800 - val_loss: 0.5208 - val_acc: 0.8800
Epoch 45/100
50/50 [==============================] - 18s 363ms/step - loss: 0.5208 - acc: 0.8800 - val_loss: 0.5181 - val_acc: 0.8200
Epoch 46/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5181 - acc: 0.8200 - val_loss: 0.5175 - val_acc: 0.8200
Epoch 47/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5175 - acc: 0.8200 - val_loss: 0.5131 - val_acc: 0.8400
Epoch 48/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5131 - acc: 0.8400 - val_loss: 0.5075 - val_acc: 0.8600
Epoch 49/100
50/50 [==============================] - 19s 384ms/step - loss: 0.5075 - acc: 0.8600 - val_loss: 0.5053 - val_acc: 0.9000
Epoch 50/100
50/50 [==============================] - 19s 382ms/step - loss: 0.5053 - acc: 0.9000 - val_loss: 0.5035 - val_acc: 0.8400
Epoch 51/100
50/50 [==============================] - 18s 369ms/step - loss: 0.5035 - acc: 0.8400 - val_loss: 0.4989 - val_acc: 0.9000
Epoch 52/100
50/50 [==============================] - 20s 394ms/step - loss: 0.4989 - acc: 0.9000 - val_loss: 0.4944 - val_acc: 0.8800
Epoch 53/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4944 - acc: 0.8800 - val_loss: 0.4920 - val_acc: 0.8800
Epoch 54/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4920 - acc: 0.8800 - val_loss: 0.4890 - val_acc: 0.8800
Epoch 55/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4890 - acc: 0.8800 - val_loss: 0.4845 - val_acc: 0.9000
Epoch 56/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4845 - acc: 0.9000 - val_loss: 0.4811 - val_acc: 0.8800
Epoch 57/100
50/50 [==============================] - 18s 362ms/step - loss: 0.4811 - acc: 0.8800 - val_loss: 0.4792 - val_acc: 0.9000
Epoch 58/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4792 - acc: 0.9000 - val_loss: 0.4759 - val_acc: 0.9000
Epoch 59/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4759 - acc: 0.9000 - val_loss: 0.4721 - val_acc: 0.8800
Epoch 60/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4721 - acc: 0.8800 - val_loss: 0.4695 - val_acc: 0.9200
Epoch 61/100
50/50 [==============================] - 18s 370ms/step - loss: 0.4695 - acc: 0.9200 - val_loss: 0.4670 - val_acc: 0.9000
Epoch 62/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4670 - acc: 0.9000 - val_loss: 0.4634 - val_acc: 0.9200
Epoch 63/100
50/50 [==============================] - 22s 433ms/step - loss: 0.4634 - acc: 0.9200 - val_loss: 0.4602 - val_acc: 0.9200
Epoch 64/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4602 - acc: 0.9200 - val_loss: 0.4578 - val_acc: 0.9200
Epoch 65/100
50/50 [==============================] - 19s 374ms/step - loss: 0.4578 - acc: 0.9200 - val_loss: 0.4548 - val_acc: 0.9200
Epoch 66/100
50/50 [==============================] - 19s 383ms/step - loss: 0.4548 - acc: 0.9200 - val_loss: 0.4515 - val_acc: 0.9400
Epoch 67/100
50/50 [==============================] - 20s 393ms/step - loss: 0.4515 - acc: 0.9400 - val_loss: 0.4488 - val_acc: 0.9200
Epoch 68/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4488 - acc: 0.9200 - val_loss: 0.4462 - val_acc: 0.9200
Epoch 69/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4462 - acc: 0.9200 - val_loss: 0.4431 - val_acc: 0.9400
Epoch 70/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4431 - acc: 0.9400 - val_loss: 0.4402 - val_acc: 0.9400
Epoch 71/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4402 - acc: 0.9400 - val_loss: 0.4376 - val_acc: 0.9800
Epoch 72/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4376 - acc: 0.9800 - val_loss: 0.4347 - val_acc: 0.9800
Epoch 73/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4347 - acc: 0.9800 - val_loss: 0.4317 - val_acc: 0.9400
Epoch 74/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4317 - acc: 0.9400 - val_loss: 0.4291 - val_acc: 0.9400
Epoch 75/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4291 - acc: 0.9400 - val_loss: 0.4264 - val_acc: 0.9400
Epoch 76/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4264 - acc: 0.9400 - val_loss: 0.4235 - val_acc: 0.9400
Epoch 77/100
50/50 [==============================] - 19s 376ms/step - loss: 0.4235 - acc: 0.9400 - val_loss: 0.4208 - val_acc: 0.9600
Epoch 78/100
50/50 [==============================] - 19s 377ms/step - loss: 0.4208 - acc: 0.9600 - val_loss: 0.4182 - val_acc: 0.9800
Epoch 79/100
50/50 [==============================] - 19s 381ms/step - loss: 0.4182 - acc: 0.9800 - val_loss: 0.4154 - val_acc: 0.9600
Epoch 80/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4154 - acc: 0.9600 - val_loss: 0.4127 - val_acc: 0.9400
Epoch 81/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4127 - acc: 0.9400 - val_loss: 0.4101 - val_acc: 0.9400
Epoch 82/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4101 - acc: 0.9400 - val_loss: 0.4075 - val_acc: 0.9400
Epoch 83/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4075 - acc: 0.9400 - val_loss: 0.4048 - val_acc: 0.9600
Epoch 84/100
50/50 [==============================] - 18s 365ms/step - loss: 0.4048 - acc: 0.9600 - val_loss: 0.4022 - val_acc: 0.9800
Epoch 85/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4022 - acc: 0.9800 - val_loss: 0.3996 - val_acc: 0.9800
Epoch 86/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3996 - acc: 0.9800 - val_loss: 0.3970 - val_acc: 0.9600
Epoch 87/100
50/50 [==============================] - 18s 370ms/step - loss: 0.3970 - acc: 0.9600 - val_loss: 0.3945 - val_acc: 0.9600
Epoch 88/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3945 - acc: 0.9600 - val_loss: 0.3919 - val_acc: 0.9600
Epoch 89/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3919 - acc: 0.9600 - val_loss: 0.3894 - val_acc: 0.9600
Epoch 90/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3894 - acc: 0.9600 - val_loss: 0.3869 - val_acc: 0.9800
Epoch 91/100
50/50 [==============================] - 19s 371ms/step - loss: 0.3869 - acc: 0.9800 - val_loss: 0.3844 - val_acc: 0.9800
Epoch 92/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3844 - acc: 0.9800 - val_loss: 0.3819 - val_acc: 0.9800
Epoch 93/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3819 - acc: 0.9800 - val_loss: 0.3795 - val_acc: 0.9800
Epoch 94/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3795 - acc: 1.0000 - val_loss: 0.3770 - val_acc: 1.0000
Epoch 95/100
50/50 [==============================] - 18s 369ms/step - loss: 0.3770 - acc: 1.0000 - val_loss: 0.3746 - val_acc: 1.0000
Epoch 96/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3746 - acc: 1.0000 - val_loss: 0.3722 - val_acc: 1.0000
Epoch 97/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3722 - acc: 1.0000 - val_loss: 0.3698 - val_acc: 1.0000
Epoch 98/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3698 - acc: 1.0000 - val_loss: 0.3674 - val_acc: 1.0000
Epoch 99/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3674 - acc: 1.0000 - val_loss: 0.3651 - val_acc: 1.0000
Epoch 100/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3651 - acc: 1.0000 - val_loss: 0.3627 - val_acc: 1.0000

I would love to know if you can reproduce my results and whether you can observe any speed degradation that @fchollet suspects.

@datumbox

Contributor Author

datumbox commented Apr 23, 2018

@ahundt Thanks for your comment!

In this very specific experiment I used a fixed batch size of 32. Nevertheless, in this dummy example I try to reproduce a behaviour we've been facing for over a year now on real-world datasets and problems. In those cases a large number of different batch sizes were tested and the results were comparable.

Please note that this PR DOES NOT undo the recent change where the mean/var no longer shifts when trainable=False. I 100% agree with you that that change is very beneficial. This PR actually takes it a step further and makes sure that the moving mean/var are used instead of the mini-batch statistics when trainable=False. This ensures that the non-frozen layers will be trained on data scaled the same way as in inference mode.
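In sketch form (illustrative only, not the PR's actual code), the proposed rule for which statistics a BN layer normalizes with would be:

```python
def bn_stats_source(training, trainable):
    """Which statistics a BN layer normalizes with, under this PR's proposal.

    Mini-batch statistics are used only when the layer is both in
    training mode and trainable; a frozen layer always uses the moving
    mean/variance, so training and inference see the same scaling.
    """
    if training and trainable:
        return "mini-batch"
    return "moving"

assert bn_stats_source(training=True,  trainable=True)  == "mini-batch"
assert bn_stats_source(training=True,  trainable=False) == "moving"   # the change
assert bn_stats_source(training=False, trainable=True)  == "moving"
assert bn_stats_source(training=False, trainable=False) == "moving"
```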

BTW thanks for pointing me to the discussion in #8616. Give me some time to read all the details to see how this is related.

@datumbox

Contributor Author

datumbox commented Apr 23, 2018

@ahundt I've read the discussion on #8616. I understand it focuses on the previous change on BN that correctly stopped the update of the moving mean/var when trainable=False. I totally agree with this change. As I said on my previous comment, this PR takes this a step further to ensure that the data after a frozen BN are scaled in the same way during training as during inference.

What I find interesting is that during the original discussion in #8616, @fchollet raised similar concerns about the semantic meaning of trainable as in this PR. Nevertheless, in that discussion he proposed the introduction of another property to extend the API. I also see he tried to implement another property called "updatable", which was reverted due to the increased complexity (and in the end we settled on extending the semantics of trainable). I wonder if in this case it makes sense to extend the API to cover this valid case OR update the semantics of trainable (preferred solution) OR update the documentation/examples.

I would love to have an opinion from @lukedeo on this, since he reviewed the code on the other PR.

@ahundt

Contributor

ahundt commented May 1, 2018

@datumbox Ok, I think I see what you are saying; I might try this out on my dataset. Do I need to change any settings like trainable in my training code, or can I just pull this in? In my example I use frozen vgg16 imagenet pretrained weights as a feature extractor with additional trainable layers afterwards.

One thing that might help with getting this through is a few improvements to the PR, variable names, and description. If you better separate the concepts and clarify the conditions under which different data is fixed vs changing, the reasons this improves performance may be more obvious.

@ahundt

Contributor

ahundt commented May 1, 2018

Ok, so based on the test_batchnorm_trainable() changes, this should be active by default in all cases except when both learning phase=1 and trainable=True.

# In all other cases we should use the moving mean and variance from BN.

Correct?

@datumbox

Contributor Author

datumbox commented May 1, 2018

@ahundt Thanks for looking into this. My PR affects only networks that use Batch Normalization layers, so VGG will not be affected. No additional configuration is required other than setting trainable=False on the BN layers. Pulling this in should work fine; just note that my fork is not synced with the latest stable version of Keras. I plan to do this soon, as other people are interested. I synced the patch on Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

> One thing that might help with getting this through is a few improvements to the PR, variable names, and description.

Sure thing, send me your comments and I'll make the changes. :-)

@ahundt

Contributor

ahundt commented May 3, 2018

Oh, yeah, sorry; my first run definitely wasn't configured correctly, since vgg makes no sense for this case and the BN layers I had were trained from scratch. I did have other models, including resnet and densenet, that didn't perform as well as vgg and that do use pretrained weights, and the fix in this PR might be why. I will try them out, but can you confirm the following steps will make use of your changes?

  1. load pretrained weights for densenet (trainable = false)
  2. add some additional layers on the end
  3. set all layers from (1) including bn to trainable = False, layers from (2) to trainable = True
  4. run training script

Should I expect the above sequence to change the performance when running with this PR applied?

edit: fixed typo mentioned in the next post

@ali-masoudi


ali-masoudi commented Jul 28, 2019

@ozabluda @taehoonlee @farizrahman4u @Dref360 @fchollet
When will the patch be merged?

@aethersis


aethersis commented Aug 5, 2019

@ozabluda @taehoonlee @farizrahman4u @Dref360 @fchollet Are there any plans to merge this patch? It would make my work on semantic segmentation so much easier!

@Toukenize


Toukenize commented Aug 8, 2019

Hi, may I know the correct way of fine-tuning pre-trained models with BN layers on a new dataset?

From what I have read so far, it seems that the Inception V3 fine-tuning example in the Keras documentation is not correct, due to the way BN layers behave when frozen.

@geometrikal


geometrikal commented Aug 8, 2019

@Toukenize I saw there is a note in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/normalization_v2.py that suggests that in TensorFlow 2.0, setting trainable=False should cause the BN layer to use the saved mean and variance rather than calculating new ones for the batch. So maybe give that a try.

In my case, I just want to do some simple transfer learning (output of ResNet50 plus a few trainable dense layers). My workaround was this:

  1. Create the base model (e.g. ResNet50 + ImageNet) and use it to generate the vectors
  2. Create a separate model for the Dense layers with vectors as input and train it. (Let's call it model_dense)
  3. Join them together using the functional interface.

E.g.

model_base = ResNet50(.....)
train_vectors = model_base.predict(...)
test_vectors = model_base.predict(...)
model_dense = ...
model_dense.fit(train_vectors, ...)
# the dense head consumes the base model's features (not the other way around)
joined = model_dense(model_base.outputs[0])
model_joined = Model(model_base.inputs[0], joined)

Note that if you want to freeze the resulting graph for production etc there is a bug in tf.keras (1.13.1, 1.14.0 and 1.15.0 nightly) where the conversion to frozen graph either doesn't work (1.14.0) or works but creates a frozen graph that gives errors on import (1.13.1 and 1.15.0.dev20190807 nightly). See my workaround here: tensorflow/tensorflow#31331 (comment)

Toukenize commented Aug 8, 2019

> @Toukenize I saw there is a note in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/normalization_v2.py that suggest in tensorflow 2.0, setting trainable=False should cause the bn layer to use the saved mean and variance rather than calculating a new one for the batch. So maybe give that a try.
>
> In my case, I just want to do some simple transfer learning (output of ResNet50 plus a few trainable dense layers). My workaround was this:
>
>   1. Create the base model (e.g. ResNet50 + ImageNet) and use it to generate the vectors
>   2. Create a separate model for the Dense layers with vectors as input and train it. (Let's call it model_dense)
>   3. Join them together using functional interface.
>
> E.g.
>
> model_base = ResNet50(.....)
> train_vectors = model_base.predict(...)
> test_vectors = model_base.predict(...)
> model_dense = ...
> model_dense.fit(train_vectors,...)
> joined = model_base(model_dense.outputs[0])
> model_joined = Model(model_base.inputs[0], joined)
>
> Note that if you want to freeze the resulting graph for production etc there is a bug in tf.keras (1.13.1, 1.14.0 and 1.15.0 nightly) where the conversion to frozen graph either doesn't work (1.14.0) or works but creates a frozen graph that gives errors on import (1.13.1 and 1.15.0.dev20190807 nightly). See my workaround here: tensorflow/tensorflow#31331 (comment)

@geometrikal Thanks for your prompt response. Does that mean I have to extract the train and test vectors every epoch if I'm doing data augmentation?

My use case is slightly more complicated, meant for a sequence of images (video classification), the actual model is as such:

pretrained = ResNet50(include_top=False, weights='imagenet', pooling='max')

for layer in pretrained.layers:
    layer.trainable = False

nn_input = Input(shape=(SEQ, 224, 224, 3))
x = TimeDistributed(pretrained)(nn_input)
x = Reshape(target_shape=(SEQ, pretrained.output_shape[-1]))(x)
x = CuDNNLSTM(128, return_sequences=True)(x)
x = GlobalAveragePooling1D()(x)
x = Dense(128, activation='relu')(x)
nn_output = Dense(1, activation='sigmoid')(x)   # To extend to multi-label, change the number of hidden units

model = Model(inputs=nn_input, outputs=nn_output)

Do you have any suggestions for other workarounds?

geometrikal commented Aug 8, 2019

@Toukenize I was just talking about this with a colleague, who said that they use plain Keras (not tensorflow.keras) ResNet50 for transfer learning and have no problem. I have not tried that myself, but this issue suggests that setting layer.trainable = False should work for plain Keras: #7085

Edit: Apologies. I thought I was commenting on the tensorflow.keras GitHub; it seems this issue exists in plain Keras too, and there are still problems (#9214).

As for suggestions, maybe you could pre-compute a bunch of augmented vectors as well? Or it might work to create a generator (to train with fit_generator) that does both the augmentation and the creation of the vectors.

I just looked at a friend's code and they just leave the base model trainable. Not sure what impact that makes.
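The generator idea mentioned above is straightforward to wire up; a minimal sketch (`vector_generator` and the dummy shapes are my own illustration, not code from this thread):

```python
def vector_generator(augmented_batches, base_model):
    """Wrap an augmenting (images, labels) generator so it yields
    (feature_vectors, labels) batches produced by the frozen base model."""
    for images, labels in augmented_batches:
        yield base_model.predict(images), labels
```

Something like `model_dense.fit_generator(vector_generator(image_gen, model_base), ...)` would then train the dense head on freshly augmented vectors each epoch, without ever running the base model (and its BN layers) in training mode.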

sandeepjana commented Aug 19, 2019

Does the issue affect models that are saved and loaded using model.save() and load_model()? If so, would saving the model with the modified BN layers, as suggested in @faustomorales's comment above, still be helpful?

ybsave commented Aug 22, 2019

Cannot believe that this patch hasn't been merged after 1.5 years! The Keras people seem very arrogant about accepting others' carefully crafted solutions: they ignore a 50% performance drop while citing a 5% extra cost without any benchmark evidence. Really angry. This stupid Keras issue cost me two weeks.

Does anyone know how I should apply @datumbox's patch to TensorFlow's Keras? Thank you.

TheGuywithTheHat commented Aug 22, 2019

@ybsave, datumbox has previously said to use the following command to install his fork:

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@fork/keras2.2.4
ybsave commented Aug 22, 2019

@TheGuywithTheHat Thank you for your help. After installing his fork, I tested the code from @datumbox's blog, but the results are still not fixed...

I tested TensorFlow's batch normalization. As long as is_training is set to False during testing, the performance is the same as during training.

Keras sucks. I will switch to pure TensorFlow instead. Thanks to @datumbox for the excellent solution and @TheGuywithTheHat for the help.

Shame on the Keras people for damaging TensorFlow. I feel sad that TensorFlow uses the stupid Keras as its default mode.

datumbox (Contributor, Author) commented Aug 22, 2019

@ybsave check whether you define the learning phase in your implementation. You should not define it; otherwise the behaviour of the patch won't be what you expect.

Also, I understand the frustration and your point of view, but if you think about it, Keras is a good solution, and at the end of the day it is an open-source project that anyone can fork and adjust to their needs. That's what we do here. It's up to the owner to merge PRs or not, depending on whether they meet their criteria. Maybe one day someone will provide a more elegant solution and it will become part of Keras. :)

ybsave commented Aug 22, 2019

@datumbox Thank you for your kind comments. My previous code used tensorflow.keras, not the patched keras. After fixing that, I got the same results as shown in your blog.

Thank you for your kind and warm words. I have just lost confidence in the Keras people. "After training, testing on the same training data should produce the same results as in the last training round." That is common sense to me, but I think the Keras people lack it. You gave very detailed explanations, experiments, fixes, and discussions, but they do not listen. I cannot understand them ignoring the huge performance drop (>30% in my own work) while attacking your solution over a tiny cost (<5% extra, with no experimental evidence).

If you produce a new Keras version, I will definitely use it. But without your patch, I will never use Keras again. Thank you so much for saving my world!

ybsave commented Aug 22, 2019

@datumbox Do you have a plan to also produce a fix for tensorflow.keras? Thank you.

geometrikal commented Aug 22, 2019

@ybsave Try TensorFlow 2.0, I think the behaviour has been changed there.

ybsave commented Aug 24, 2019

@geometrikal Thanks for the suggestion. Have you tried @datumbox's example code from his blog with keras changed to tf.keras? In my test, the problem still exists on TensorFlow 2.0. Does it disappear in your test?

geometrikal commented Aug 25, 2019

@ybsave I haven't done any more tests so I'm not sure of the status of this. I just pre-calculate the vectors as per my workaround.

ofgurcan commented Aug 27, 2019

@ybsave can you share your Keras code please? I haven't seen the complete code, and I still don't know where this patch fits in: "pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@fork/keras2.2.4". Thank you.

ybsave commented Aug 27, 2019

@ofgurcan you can just run the pip installation command above in the terminal. Just check your Keras version; if it is not 2.2.4, datumbox also has patches for other versions, so choose the correct one. You can see the revisions in the "Files changed" tab.

ybsave commented Aug 27, 2019

I tested the newest TensorFlow 2.0.0-rc0, and @datumbox's test code is correct now. It seems we can safely use tf.keras from the newest TensorFlow, instead of waiting for the arrogant Keras people to fix this.

lovejing0306 commented Sep 19, 2019

> It seems we could safely use tf.keras from the newest Tensorflow, instead of waiting for arrogant keras people fixing this.

I used TensorFlow 2.0.0-rc0, and when fine-tuning ResNet I have the same problem.

off99555 commented Oct 6, 2019

In TensorFlow 2 on GPU this problem still occurs; I have to use @faustomorales's code to fix the issue.
But this sounds like a hack to me, and my prediction time increases. Is there some article teaching fine-tuning with CNNs and dealing with the BatchNorm issues? I don't want to master BatchNorm hacking!

I checked and followed this TensorFlow guide on transfer learning, and it didn't even once mention an issue with batch norm: https://www.tensorflow.org/tutorials/images/transfer_learning It also doesn't exhibit the problem I have with my coordinate regression task. It seems like classifying dogs and cats is so easy that the batch norm issue isn't a big deal there.

It's already 2019, and almost 2020. Does an elegant solution exist yet?

Update: It seems the increase in prediction time is caused by a weird issue that happens after you call model.compile(), not by the hacked code; please check my question on Stack Overflow about that here: https://stackoverflow.com/questions/58378374/why-does-keras-model-predict-slower-after-compile
It means that you don't have to worry about prediction time increasing, as long as you measure prediction time after compiling.
All complaining aside, the solution from @faustomorales works fine.

captainst commented Oct 8, 2019

After some trying out, I think that a feasible yet simple solution is this (pseudo-code, taking InceptionV3 as an example):

1. K.set_learning_phase(0) # test mode, freezing everything
2. myModel = InceptionV3(weights='imagenet', include_top=False, input_shape=(299,299,3)) # load the pre-trained model and weights
3. nn_inputs = myModel.input # save the input for later use
4. for layer in myModel.layers:
        layer.trainable = False # freeze the weights in each layer. Notice that the BNs are also frozen
5. K.set_learning_phase(1) # switch to training mode
6. # build the top layers
    myModelOut = myModel.output
    myModelOut = GlobalAveragePooling2D()(myModelOut)
    myModelOut = Dense(1024, activation="relu")(myModelOut)
    myModelOut = Dense(10, activation="softmax")(myModelOut)
7. # build the whole model
    finalModel = Model(inputs=nn_inputs, outputs=myModelOut)
8. # verify the model structure and parameters
    print(finalModel.summary())
gjy1992 commented Oct 21, 2019

@captainst Hello, this approach can work. But after I save the finalModel and want to continue training by loading the saved model, I cannot set part of the finalModel to learning_phase=0. (╥╯^╰╥)

ec1841 commented Nov 2, 2019

@captainst your approach doesn't work when you want to fine-tune the top-k layers of your base model, which may contain BN layers.

Yet another workaround :), using @faustomorales's suggestion as the base soup to solve the top-k-layer fine-tuning.


from keras import layers

class FrozenBatchNormalization(layers.BatchNormalization):
    def call(self, inputs, training=None):
        return super().call(inputs=inputs, training=False)

def _is_batch_normalization(layer):
    return isinstance(layer, layers.BatchNormalization)

model = InceptionResNetV2(....)
if mode == 'training':
    _bottom_layers = model.layers[:-top_k_layers]
    _top_layers = model.layers[-top_k_layers:]
elif mode == 'inference':
    _bottom_layers = model.layers
    _top_layers = []

for _layer in _bottom_layers:
    _layer.trainable = False
    if _is_batch_normalization(_layer):
        print('Freezing BN layers ... {}'.format(_layer.name))
        # swap the instance's class; rebinding the loop variable alone has no effect
        _layer.__class__ = FrozenBatchNormalization

for _layer in _top_layers:
    _layer.trainable = True
    if _is_batch_normalization(_layer):
        print('Unfreezing BN layers ... {}'.format(_layer.name))
        _layer.__class__ = layers.BatchNormalization

Will this work?

rpeloff commented Nov 3, 2019

@lovejing0306

> I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.

@off99555

> In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.

For those using TensorFlow 2.0 and trying to fine-tune ResNet, InceptionV3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.x-behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).

Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0-behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications, simply add the argument layers=tf.keras.layers. For example:

import tensorflow as tf

pretrained = tf.keras.applications.inception_v3.InceptionV3(
    layers=tf.keras.layers, weights='imagenet')
Tauranis commented Nov 25, 2019

@datumbox,
If one day I meet you, I promise I'll buy you a beer.
You have no idea how much this thread saved me. I had spent two weeks struggling with transfer learning without any clue of why it was going completely wrong. It was simple transfer learning; it made no sense.
Unfortunately, none of the suggested workarounds worked for me. I'm currently using TF 2.0.0.
The only network that won't give you any headaches is VGG, since it has no batch norm layers.

For all the others, what works in my case is to do transfer learning as a 2-step process: extract the embeddings first (into tfrecord shards or not, it is up to you), then train a classifier on them.

@rpeloff, your workaround worked for me on TF 1.15.0 but not on TF 2.0.0, but thanks anyway.
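A minimal sketch of that 2-step process (the tiny stand-in backbone and the synthetic data are my own assumptions to keep it runnable; in practice the base would be e.g. ResNet50 with ImageNet weights):

```python
import numpy as np
import tensorflow as tf

# Step 1: run the frozen base once and cache the embeddings.
inputs = tf.keras.Input(shape=(8, 8, 3))
feats = tf.keras.layers.GlobalAveragePooling2D()(inputs)  # stand-in backbone
base = tf.keras.Model(inputs, feats)

x_train = np.random.rand(16, 8, 8, 3).astype('float32')
y_train = np.random.randint(0, 2, size=16)
embeddings = base.predict(x_train, verbose=0)

# Step 2: train only a small classifier on the cached embeddings;
# the base (and its BN statistics) is never run in training mode.
emb_in = tf.keras.Input(shape=(embeddings.shape[1],))
pred = tf.keras.layers.Dense(1, activation='sigmoid')(emb_in)
clf = tf.keras.Model(emb_in, pred)
clf.compile(optimizer='adam', loss='binary_crossentropy')
clf.fit(embeddings, y_train, epochs=1, verbose=0)
```

Because the backbone is only ever used for inference, the frozen-BN discrepancy this PR addresses cannot affect training at all.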

sedghi commented Dec 3, 2019

I can't believe this is not fixed yet.
I spent a week before finally finding what was going wrong with my code.
Thanks @datumbox, I'll buy you a beer too.

sameervk commented Jan 17, 2020

@sedghi, same here. Thanks @datumbox; this was not a problem before when working with custom layers, but when working with transfer learning, I just didn't know what on earth was going on until I came across your blog and then this.

sameervk commented Jan 17, 2020

@Tauranis would you mind elaborating on your 2-step process please? Thanks.

RomainSabathe commented Jan 25, 2020

> @lovejing0306
>
> > I used Tensorflow 2.0.0-rc0, when I fine-tuning resnet have same problem.
>
> @off99555
>
> > In tensorflow2 GPU there is still this problem occurring, I have to use @faustomorales code in order to fix the issue.
>
> For those using TensorFlow 2.0 and trying to fine-tune resnet, inceptionv3, etc., the problem seems to persist due to the injection of tensorflow.python.keras.layers. This references the TF 1.0 behaviour batch normalisation in keras_applications when calling layers.BatchNormalization (for example, in inceptionv3).
>
> Similar to what @faustomorales suggested, I found that simply injecting tf.keras.layers references the TF 2.0 behaviour batch normalisation and fixes this issue (see here for the change in behaviour). When loading models from tf.keras.applications simply add the argument layers=tf.keras.layers. For example:
>
> import tensorflow as tf
>
> pretrained = tf.keras.applications.inception_v3.InceptionV3(
>     layers=tf.keras.layers, weights='imagenet')

Thank you so much!!! It looks like that solved it for me. That is certainly strange behaviour though; one would think that they're using TF 2.x components when using the official TF 2.x release. Anyway, thanks for your reply and the explanation.
