Change BN layer to use moving mean/var if frozen #9965

Closed
wants to merge 10 commits into base: master

Conversation

@datumbox
Contributor

datumbox commented Apr 17, 2018

During fine-tuning, if a Batch Normalization layer is frozen it still uses the mini-batch statistics. I believe this is incorrect and can lead to reduced accuracy, especially when we use transfer learning. A better approach in this case would be to use the values of the moving mean and variance.

Changes on this PR:
In this PR I update the Batch Normalization layer to use the learned statistics if it is frozen during training. This is achieved by making the trainable flag part of the computational graph and by making the behavior of the BN depend not only on the learning_phase but also on the value of the trainable property.

Brief explanation:
Assume we use one of the pre-trained CNNs of Keras and we want to fine-tune it. Unfortunately, we get no guarantees that the mean and variance of our new dataset inside the BN layers will be similar to the ones of the original dataset. As a result, if we fine-tune the top layers, their weights will be adjusted to the mean/variance of the new dataset. Nevertheless, during inference the top layers will receive data which are scaled using the mean/variance of the original dataset. This discrepancy can lead to reduced accuracy.

I understand that this is a significant change that requires thorough review. To facilitate the review I've documented why making such a change is important and provided detailed comparisons before and after applying the patch on my blog.

EDIT: Since the fix was not merged on master, I maintain unofficial patches available for Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

@fchollet

Collaborator

fchollet commented Apr 19, 2018

Thanks for the effort.

You are misunderstanding the meaning of the "trainable" property of layers. Historically, it initially meant "this layer should not be trainable, i.e. the weights of this layer should not be updated during backprop (specifically layer.trainable_weights should be empty)". It has since been extended to mean "the state of the layer should be frozen during training" (which means that, in addition to the previous definition, layer updates are not run).

What you want is a BN layer in inference mode. There is an argument to control training/inference mode in BN (and other layers): it's the training argument in call (boolean).

What you want is:

x = BatchNormalization()(y, training=False)

For fine-tuning, you could do something like:

# Set up inference-mode base
K.set_learning_phase(0)
inputs = Input(...)
x = layer1(...)(inputs)
x = layer2(...)(x)
...
x = layerN(...)(x)

# Add training-mode layers
K.set_learning_phase(1)
x = layerNp1(...)(x)
x = layerNp2(...)(x)

@fchollet fchollet closed this Apr 19, 2018

@datumbox

Contributor

datumbox commented Apr 19, 2018

Hi @fchollet,

First of all thanks for taking the time to review and respond. I was aware that this is a significant change in the default behaviour and that there would be debate. :)

I understand that your main concern is around the semantic meaning of the trainable property and how it is being used in this PR. I agree that semantically the training parameter that you proposed is closer to what I do; nevertheless, this parameter can't be changed after the network definition. For instance, when you use one of the pre-trained models of Keras or when you load a persisted model you have no control over this variable. Would you be open to discussing a solution that would make the training variable changeable after the network definition (or perhaps another property)? If you are open to this, I could update my PR to reflect the agreed behaviour.
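
To illustrate the first point, here is a minimal sketch (ResNet50 below just stands in for any pre-built or loaded model; the snippet is only illustrative):

from keras.applications.resnet50 import ResNet50

base_model = ResNet50(weights='imagenet', include_top=False)
# The BN layers inside base_model have already been called without training=False,
# and the built model exposes no public way to change that argument afterwards.
for layer in base_model.layers:
    layer.trainable = False  # freezes the weights and the updates, but the BNs
                             # still normalize with mini-batch statistics in training mode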

Concerning your second recommendation of updating the learning_phase as the network is defined, I see two limitations:

  1. Again, this will work only if the network is defined in code. It will not work for the pretrained models of Keras or when a model is loaded from disk. The latter is quite important; models are usually trained in multiple rounds after being restored from checkpoints.
  2. After setting learning_phase(1) in your example, the learning_phase will be static for the remainder of the session. This will override all the nice mechanisms that Keras has for switching between phases depending on whether it trains or predicts. Thus if we call fit() with validation data, the model will predict while still being in training mode.

I'm not sure if you had a look at the blog post (it is understandably a bit long), but it shows how significant a performance boost you get by making it possible to set the BN in inference mode. Without this, the trainable layers after the BNs adjust their weights to input that has a different scale compared to inference. I hope we can re-open this PR; I'm happy to update it until it satisfies the semantic definitions.

Cheers!

@fchollet

Collaborator

fchollet commented Apr 19, 2018

Again, there is an existing API that does exactly what you want: the training argument in call. There is no point in having two differently named APIs that do the exact same thing. layer.trainable = False is not what you need, therefore don't use it.

Additionally, your proposed PR adds a computational overhead (which might amount to a ~5% slowdown for a BN-heavy model like InceptionV3) to every single convnet that uses BN, fine-tuning or not. This is a heavy price to pay for supporting an incrementally simpler UX (disputable) for a very specific use case.

For instance when you use one of the pre-trained models of keras or when you load a persisted model you have no control over this variable.

Typically if you want to heavily modify an existing model, rather than merely use it in inference mode, you should have access to the code for the model.

But even if you don't, you can still do your style of fine-tuning in this case (a rough sketch follows the list):

  • set learning phase to 0
  • load model
  • retrieve features you want to train on
  • set learning phase to 1
  • add new layers on top
  • optionally load weights from initial model layers to corresponding new layers
  • train
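
A rough sketch of those steps (the file name, layer name and new head below are placeholders, not anything from this PR):

from keras import backend as K
from keras.models import load_model, Model
from keras.layers import Dense

K.set_learning_phase(0)                               # inference-mode graph for the base
base = load_model('pretrained_model.h5')              # load the persisted model
features = base.get_layer('my_feature_layer').output  # features to train on

K.set_learning_phase(1)                               # training mode for the new layers
outputs = Dense(10, activation='softmax')(features)   # new head (placeholder)
model = Model(inputs=base.input, outputs=outputs)
model.compile(optimizer='sgd', loss='categorical_crossentropy')  # then train as usual
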
@datumbox

Contributor

datumbox commented Apr 19, 2018

@fchollet My main point is that the training argument can't be changed after model definition, so the existing API does not cover this valid case. I don't dispute that there are workarounds, but they are hacky/inelegant and the default behaviour causes much confusion for users. The 5% slowdown you mention is interesting; I would love to see the benchmarks, as perhaps it can be resolved. Finally, something you don't address here is whether this discrepancy in the scaling makes sense (theoretically or otherwise) and whether the accuracy decrease is worth it.

In any case, let's agree to disagree. I do hope though that you will revise your decision in the future, as happened with the update of the mini-batch statistics on the BN.

@datumbox datumbox referenced this pull request Apr 19, 2018

Merged

Chenta/cntk bn #9952

@fchollet

Collaborator

fchollet commented Apr 19, 2018

I would love to see the benchmarks

This is based on something I've observed in the past for InceptionV3 with a static learning phase vs. a dynamic learning phase. The only difference between the two settings is the cond ops. Control flow seems pretty expensive, especially on GPU. Your PR adds the exact same number of cond ops, so I would expect the same overhead.

@datumbox

Contributor

datumbox commented Apr 20, 2018

Thanks for clarifying that you are referring to a different benchmark and not to something you ran on this PR. I can't comment on the results without seeing them, but when I ran comparisons on CIFAR10 the time difference was negligible (current branch: 4216 secs vs patched: 4251 secs); both ran on GPUs on the same server. Note that the snippet I used (and listed in my article) comes from Keras' documentation on how to fine-tune a network.

Admittedly the above measurements are single-point estimates, but the 5-point accuracy increase I report in particular is consistent with what I've been observing for almost a year while applying workarounds (the first time I reported this was in #7177). I don't know if speed is currently your main concern for reopening this, but I would say it is unlikely to affect the majority of Keras users. This is because by default the learning phase is dynamic and the training argument of call is None. This forces the in_train_phase method on the backend to use a switch statement that depends on the learning phase, so in a sense the "if" statement is already there.
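
To illustrate that last point, this is roughly what happens in the backend when training is None (a sketch of the idea, not the actual BN implementation):

from keras import backend as K

def bn_like_branch(minibatch_normed, inference_normed, training=None):
    # With training=None (the default for layer calls), in_train_phase falls back to
    # the dynamic learning phase, i.e. a cond/switch node is already part of the graph.
    return K.in_train_phase(minibatch_normed, inference_normed, training=training)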

In any case, I don't insist that it should be me who changes this or that my current solution is the one we should use. I'm just raising a valid use case that is taken directly from Keras' documentation on how fine-tuning is performed. Currently there is no straightforward way to do what I describe (the current API doesn't cover it); nevertheless, if you could provide specific guidelines on what boxes the update should tick, that would be useful. Or perhaps some other longtime contributor to the BatchNormalization layer has an opinion or can offer a more elegant solution? @ozabluda @taehoonlee @farizrahman4u @Dref360

@ozabluda

Contributor

ozabluda commented Apr 23, 2018

Sorry for the late reply, still trying to understand the issues. For example, I am trying to understand whether this is related at all to #9214.

@ahundt

Contributor

ahundt commented Apr 23, 2018

What sort of batch sizes were you using in your linked experiments?

Some datasets are only viable with very small batch sizes of 1-4, for example image segmentation on a GPU with 8GB of memory. After briefly skimming this diff, I think the documentation would need to be updated to clearly delineate the different modes and when/why each should typically be chosen. In my case the current frozen behavior improved performance quite a lot over the previous behavior in which mean/var could shift when trainable=False, so I'm a bit hesitant about this, though I'll reiterate that I haven't reviewed what's happening in full detail.

Here is a PR with some past discussion on BN #8616

@datumbox

Contributor

datumbox commented Apr 23, 2018

@ozabluda First of all, thank you for spending time on this. I wish I had provided in my PR the example that you posted on issue #9214; perhaps this would have built a stronger case for this patch. What you showed in your post is exactly what I've been observing on real-world, non-open-source datasets for the last year (close to 100% accuracy in training mode and 50% during inference on the same dataset and on similar validation sets). As @fchollet said, there are lots of hacks that can help you avoid it, but none of them should have been necessary.

Based on the code you provided, I'm 100% certain you are being bitten by the behaviour of the BN layer that I'm trying to fix in this PR. In a nutshell, in training mode the frozen BN layers scale the data with different statistics than in inference mode. There is absolutely no theoretical foundation to support this behaviour. As a result, it can have devastating effects when you try to deploy the model or when you try to validate its accuracy. I am certain that the majority of people who face this believe they have overfitted the model, while in reality it is just a side-effect of how Keras implements the Batch Normalization layer.

So let's test your example on my branch of Keras where the BN layer is patched:

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn

Below I run your code for ResNet50. As you can see the problem that you report is fixed once the BN behaviour is changed:

Epoch 1/100
50/50 [==============================] - 19s 387ms/step - loss: 0.8738 - acc: 0.5000 - val_loss: 1.3021 - val_acc: 0.5000
Epoch 2/100
50/50 [==============================] - 18s 367ms/step - loss: 1.3021 - acc: 0.5000 - val_loss: 0.9412 - val_acc: 0.5000
Epoch 3/100
50/50 [==============================] - 18s 363ms/step - loss: 0.9412 - acc: 0.5000 - val_loss: 0.6904 - val_acc: 0.5000
Epoch 4/100
50/50 [==============================] - 18s 364ms/step - loss: 0.6904 - acc: 0.5000 - val_loss: 0.9428 - val_acc: 0.5000
Epoch 5/100
50/50 [==============================] - 18s 361ms/step - loss: 0.9428 - acc: 0.5000 - val_loss: 0.9180 - val_acc: 0.5000
Epoch 6/100
50/50 [==============================] - 20s 401ms/step - loss: 0.9180 - acc: 0.5000 - val_loss: 0.7111 - val_acc: 0.5000
Epoch 7/100
50/50 [==============================] - 21s 415ms/step - loss: 0.7111 - acc: 0.5000 - val_loss: 0.6802 - val_acc: 0.5200
Epoch 8/100
50/50 [==============================] - 20s 406ms/step - loss: 0.6802 - acc: 0.5200 - val_loss: 0.8039 - val_acc: 0.5000
Epoch 9/100
50/50 [==============================] - 20s 391ms/step - loss: 0.8039 - acc: 0.5000 - val_loss: 0.8075 - val_acc: 0.5000
Epoch 10/100
50/50 [==============================] - 21s 425ms/step - loss: 0.8075 - acc: 0.5000 - val_loss: 0.6963 - val_acc: 0.5000
Epoch 11/100
50/50 [==============================] - 21s 417ms/step - loss: 0.6963 - acc: 0.5000 - val_loss: 0.6406 - val_acc: 0.7000
Epoch 12/100
50/50 [==============================] - 21s 419ms/step - loss: 0.6406 - acc: 0.7000 - val_loss: 0.7017 - val_acc: 0.5000
Epoch 13/100
50/50 [==============================] - 21s 425ms/step - loss: 0.7017 - acc: 0.5000 - val_loss: 0.7408 - val_acc: 0.5000
Epoch 14/100
50/50 [==============================] - 22s 441ms/step - loss: 0.7408 - acc: 0.5000 - val_loss: 0.6895 - val_acc: 0.5000
Epoch 15/100
50/50 [==============================] - 22s 432ms/step - loss: 0.6895 - acc: 0.5000 - val_loss: 0.6267 - val_acc: 0.7200
Epoch 16/100
50/50 [==============================] - 23s 460ms/step - loss: 0.6267 - acc: 0.7200 - val_loss: 0.6376 - val_acc: 0.5600
Epoch 17/100
50/50 [==============================] - 22s 439ms/step - loss: 0.6376 - acc: 0.5600 - val_loss: 0.6775 - val_acc: 0.5400
Epoch 18/100
50/50 [==============================] - 23s 456ms/step - loss: 0.6775 - acc: 0.5400 - val_loss: 0.6675 - val_acc: 0.5400
Epoch 19/100
50/50 [==============================] - 21s 414ms/step - loss: 0.6675 - acc: 0.5400 - val_loss: 0.6209 - val_acc: 0.6000
Epoch 20/100
50/50 [==============================] - 19s 375ms/step - loss: 0.6209 - acc: 0.6000 - val_loss: 0.6055 - val_acc: 0.7400
Epoch 21/100
50/50 [==============================] - 18s 367ms/step - loss: 0.6055 - acc: 0.7400 - val_loss: 0.6309 - val_acc: 0.5800
Epoch 22/100
50/50 [==============================] - 18s 370ms/step - loss: 0.6309 - acc: 0.5800 - val_loss: 0.6392 - val_acc: 0.5600
Epoch 23/100
50/50 [==============================] - 18s 369ms/step - loss: 0.6392 - acc: 0.5600 - val_loss: 0.6111 - val_acc: 0.6400
Epoch 24/100
50/50 [==============================] - 19s 390ms/step - loss: 0.6111 - acc: 0.6400 - val_loss: 0.5890 - val_acc: 0.7800
Epoch 25/100
50/50 [==============================] - 20s 394ms/step - loss: 0.5890 - acc: 0.7800 - val_loss: 0.5990 - val_acc: 0.6200
Epoch 26/100
50/50 [==============================] - 22s 445ms/step - loss: 0.5990 - acc: 0.6200 - val_loss: 0.6105 - val_acc: 0.5800
Epoch 27/100
50/50 [==============================] - 21s 413ms/step - loss: 0.6105 - acc: 0.5800 - val_loss: 0.5961 - val_acc: 0.6000
Epoch 28/100
50/50 [==============================] - 19s 388ms/step - loss: 0.5961 - acc: 0.6000 - val_loss: 0.5759 - val_acc: 0.8000
Epoch 29/100
50/50 [==============================] - 20s 391ms/step - loss: 0.5759 - acc: 0.8000 - val_loss: 0.5767 - val_acc: 0.7400
Epoch 30/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5767 - acc: 0.7400 - val_loss: 0.5857 - val_acc: 0.7400
Epoch 31/100
50/50 [==============================] - 22s 433ms/step - loss: 0.5857 - acc: 0.7400 - val_loss: 0.5785 - val_acc: 0.7600
Epoch 32/100
50/50 [==============================] - 19s 373ms/step - loss: 0.5785 - acc: 0.7600 - val_loss: 0.5627 - val_acc: 0.7800
Epoch 33/100
50/50 [==============================] - 21s 417ms/step - loss: 0.5627 - acc: 0.7800 - val_loss: 0.5597 - val_acc: 0.7800
Epoch 34/100
50/50 [==============================] - 21s 422ms/step - loss: 0.5597 - acc: 0.7800 - val_loss: 0.5651 - val_acc: 0.7000
Epoch 35/100
50/50 [==============================] - 18s 365ms/step - loss: 0.5651 - acc: 0.7000 - val_loss: 0.5606 - val_acc: 0.7200
Epoch 36/100
50/50 [==============================] - 18s 362ms/step - loss: 0.5606 - acc: 0.7200 - val_loss: 0.5488 - val_acc: 0.8000
Epoch 37/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5488 - acc: 0.8000 - val_loss: 0.5449 - val_acc: 0.7800
Epoch 38/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5449 - acc: 0.7800 - val_loss: 0.5473 - val_acc: 0.8000
Epoch 39/100
50/50 [==============================] - 18s 361ms/step - loss: 0.5473 - acc: 0.8000 - val_loss: 0.5433 - val_acc: 0.8000
Epoch 40/100
50/50 [==============================] - 18s 368ms/step - loss: 0.5433 - acc: 0.8000 - val_loss: 0.5344 - val_acc: 0.8000
Epoch 41/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5344 - acc: 0.8000 - val_loss: 0.5311 - val_acc: 0.8600
Epoch 42/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5311 - acc: 0.8600 - val_loss: 0.5318 - val_acc: 0.7800
Epoch 43/100
50/50 [==============================] - 18s 366ms/step - loss: 0.5318 - acc: 0.7800 - val_loss: 0.5278 - val_acc: 0.7800
Epoch 44/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5278 - acc: 0.7800 - val_loss: 0.5208 - val_acc: 0.8800
Epoch 45/100
50/50 [==============================] - 18s 363ms/step - loss: 0.5208 - acc: 0.8800 - val_loss: 0.5181 - val_acc: 0.8200
Epoch 46/100
50/50 [==============================] - 18s 367ms/step - loss: 0.5181 - acc: 0.8200 - val_loss: 0.5175 - val_acc: 0.8200
Epoch 47/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5175 - acc: 0.8200 - val_loss: 0.5131 - val_acc: 0.8400
Epoch 48/100
50/50 [==============================] - 19s 372ms/step - loss: 0.5131 - acc: 0.8400 - val_loss: 0.5075 - val_acc: 0.8600
Epoch 49/100
50/50 [==============================] - 19s 384ms/step - loss: 0.5075 - acc: 0.8600 - val_loss: 0.5053 - val_acc: 0.9000
Epoch 50/100
50/50 [==============================] - 19s 382ms/step - loss: 0.5053 - acc: 0.9000 - val_loss: 0.5035 - val_acc: 0.8400
Epoch 51/100
50/50 [==============================] - 18s 369ms/step - loss: 0.5035 - acc: 0.8400 - val_loss: 0.4989 - val_acc: 0.9000
Epoch 52/100
50/50 [==============================] - 20s 394ms/step - loss: 0.4989 - acc: 0.9000 - val_loss: 0.4944 - val_acc: 0.8800
Epoch 53/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4944 - acc: 0.8800 - val_loss: 0.4920 - val_acc: 0.8800
Epoch 54/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4920 - acc: 0.8800 - val_loss: 0.4890 - val_acc: 0.8800
Epoch 55/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4890 - acc: 0.8800 - val_loss: 0.4845 - val_acc: 0.9000
Epoch 56/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4845 - acc: 0.9000 - val_loss: 0.4811 - val_acc: 0.8800
Epoch 57/100
50/50 [==============================] - 18s 362ms/step - loss: 0.4811 - acc: 0.8800 - val_loss: 0.4792 - val_acc: 0.9000
Epoch 58/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4792 - acc: 0.9000 - val_loss: 0.4759 - val_acc: 0.9000
Epoch 59/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4759 - acc: 0.9000 - val_loss: 0.4721 - val_acc: 0.8800
Epoch 60/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4721 - acc: 0.8800 - val_loss: 0.4695 - val_acc: 0.9200
Epoch 61/100
50/50 [==============================] - 18s 370ms/step - loss: 0.4695 - acc: 0.9200 - val_loss: 0.4670 - val_acc: 0.9000
Epoch 62/100
50/50 [==============================] - 18s 368ms/step - loss: 0.4670 - acc: 0.9000 - val_loss: 0.4634 - val_acc: 0.9200
Epoch 63/100
50/50 [==============================] - 22s 433ms/step - loss: 0.4634 - acc: 0.9200 - val_loss: 0.4602 - val_acc: 0.9200
Epoch 64/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4602 - acc: 0.9200 - val_loss: 0.4578 - val_acc: 0.9200
Epoch 65/100
50/50 [==============================] - 19s 374ms/step - loss: 0.4578 - acc: 0.9200 - val_loss: 0.4548 - val_acc: 0.9200
Epoch 66/100
50/50 [==============================] - 19s 383ms/step - loss: 0.4548 - acc: 0.9200 - val_loss: 0.4515 - val_acc: 0.9400
Epoch 67/100
50/50 [==============================] - 20s 393ms/step - loss: 0.4515 - acc: 0.9400 - val_loss: 0.4488 - val_acc: 0.9200
Epoch 68/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4488 - acc: 0.9200 - val_loss: 0.4462 - val_acc: 0.9200
Epoch 69/100
50/50 [==============================] - 19s 373ms/step - loss: 0.4462 - acc: 0.9200 - val_loss: 0.4431 - val_acc: 0.9400
Epoch 70/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4431 - acc: 0.9400 - val_loss: 0.4402 - val_acc: 0.9400
Epoch 71/100
50/50 [==============================] - 18s 366ms/step - loss: 0.4402 - acc: 0.9400 - val_loss: 0.4376 - val_acc: 0.9800
Epoch 72/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4376 - acc: 0.9800 - val_loss: 0.4347 - val_acc: 0.9800
Epoch 73/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4347 - acc: 0.9800 - val_loss: 0.4317 - val_acc: 0.9400
Epoch 74/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4317 - acc: 0.9400 - val_loss: 0.4291 - val_acc: 0.9400
Epoch 75/100
50/50 [==============================] - 19s 372ms/step - loss: 0.4291 - acc: 0.9400 - val_loss: 0.4264 - val_acc: 0.9400
Epoch 76/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4264 - acc: 0.9400 - val_loss: 0.4235 - val_acc: 0.9400
Epoch 77/100
50/50 [==============================] - 19s 376ms/step - loss: 0.4235 - acc: 0.9400 - val_loss: 0.4208 - val_acc: 0.9600
Epoch 78/100
50/50 [==============================] - 19s 377ms/step - loss: 0.4208 - acc: 0.9600 - val_loss: 0.4182 - val_acc: 0.9800
Epoch 79/100
50/50 [==============================] - 19s 381ms/step - loss: 0.4182 - acc: 0.9800 - val_loss: 0.4154 - val_acc: 0.9600
Epoch 80/100
50/50 [==============================] - 19s 370ms/step - loss: 0.4154 - acc: 0.9600 - val_loss: 0.4127 - val_acc: 0.9400
Epoch 81/100
50/50 [==============================] - 18s 369ms/step - loss: 0.4127 - acc: 0.9400 - val_loss: 0.4101 - val_acc: 0.9400
Epoch 82/100
50/50 [==============================] - 19s 371ms/step - loss: 0.4101 - acc: 0.9400 - val_loss: 0.4075 - val_acc: 0.9400
Epoch 83/100
50/50 [==============================] - 18s 364ms/step - loss: 0.4075 - acc: 0.9400 - val_loss: 0.4048 - val_acc: 0.9600
Epoch 84/100
50/50 [==============================] - 18s 365ms/step - loss: 0.4048 - acc: 0.9600 - val_loss: 0.4022 - val_acc: 0.9800
Epoch 85/100
50/50 [==============================] - 18s 367ms/step - loss: 0.4022 - acc: 0.9800 - val_loss: 0.3996 - val_acc: 0.9800
Epoch 86/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3996 - acc: 0.9800 - val_loss: 0.3970 - val_acc: 0.9600
Epoch 87/100
50/50 [==============================] - 18s 370ms/step - loss: 0.3970 - acc: 0.9600 - val_loss: 0.3945 - val_acc: 0.9600
Epoch 88/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3945 - acc: 0.9600 - val_loss: 0.3919 - val_acc: 0.9600
Epoch 89/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3919 - acc: 0.9600 - val_loss: 0.3894 - val_acc: 0.9600
Epoch 90/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3894 - acc: 0.9600 - val_loss: 0.3869 - val_acc: 0.9800
Epoch 91/100
50/50 [==============================] - 19s 371ms/step - loss: 0.3869 - acc: 0.9800 - val_loss: 0.3844 - val_acc: 0.9800
Epoch 92/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3844 - acc: 0.9800 - val_loss: 0.3819 - val_acc: 0.9800
Epoch 93/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3819 - acc: 0.9800 - val_loss: 0.3795 - val_acc: 0.9800
Epoch 94/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3795 - acc: 1.0000 - val_loss: 0.3770 - val_acc: 1.0000
Epoch 95/100
50/50 [==============================] - 18s 369ms/step - loss: 0.3770 - acc: 1.0000 - val_loss: 0.3746 - val_acc: 1.0000
Epoch 96/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3746 - acc: 1.0000 - val_loss: 0.3722 - val_acc: 1.0000
Epoch 97/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3722 - acc: 1.0000 - val_loss: 0.3698 - val_acc: 1.0000
Epoch 98/100
50/50 [==============================] - 18s 366ms/step - loss: 0.3698 - acc: 1.0000 - val_loss: 0.3674 - val_acc: 1.0000
Epoch 99/100
50/50 [==============================] - 18s 367ms/step - loss: 0.3674 - acc: 1.0000 - val_loss: 0.3651 - val_acc: 1.0000
Epoch 100/100
50/50 [==============================] - 18s 368ms/step - loss: 0.3651 - acc: 1.0000 - val_loss: 0.3627 - val_acc: 1.0000

I would love to know if you can reproduce my results and whether you can observe any speed degradation that @fchollet suspects.

@datumbox

Contributor

datumbox commented Apr 23, 2018

@ahundt Thanks for your comment!

In this very specific experiment I used a fixed batch size of 32. Nevertheless in this dummy example I try to reproduce a behaviour we've been facing for over a year now on real-world datasets and problems. In those cases a large number of different batch sizes were tested and the results were comparable.

Please note that this PR DOES NOT undo the recent change where the mean/var no longer shifts when trainable=False. I 100% agree with you that this change is very beneficial. This PR actually takes it a step further and makes sure that the moving mean/var are used instead of the mini-batch statistics when trainable=False. This ensures that the non-frozen layers will be trained on data scaled the same way as in inference mode.

BTW thanks for sending me the discussion on #8616. Give me some time to read all the details to see how this is related.

@datumbox

Contributor

datumbox commented Apr 23, 2018

@ahundt I've read the discussion on #8616. I understand it focuses on the previous change on BN that correctly stopped the update of the moving mean/var when trainable=False. I totally agree with this change. As I said on my previous comment, this PR takes this a step further to ensure that the data after a frozen BN are scaled in the same way during training as during inference.

What I find interesting is that during the original discussion on #8616, @fchollet raised similar concerns about the semantic meaning of trainable as in this PR. Nevertheless, in that discussion he proposed the introduction of another property to extend the API. I also see he tried to implement another property called "updatable", which was reverted due to the increased complexity (and in the end we settled on extending the semantics of trainable). I wonder if in this case it makes sense to extend the API to cover this valid case OR update the semantics of trainable (my preferred solution) OR update the documentation/examples.

I would love to have an opinion from @lukedeo on this since he reviewed the code on the other PR.

@ahundt

Contributor

ahundt commented May 1, 2018

@datumbox Ok, I think I see what you are saying; I might try this out on my dataset. Do I need to change any settings like trainable in my training code, or can I just pull this in? In my example I use frozen VGG16 ImageNet-pretrained weights as a feature extractor with additional trainable layers afterwards.

One thing that might help with getting this through is a few improvements to the PR, variable names, and description. If you better separate the concepts and clarify the conditions under which different data is fixed vs. changing, the reasons this improves performance may be more obvious.

@ahundt

Contributor

ahundt commented May 1, 2018

Ok, so based on the test_batchnorm_trainable() changes, this should be active by default in all cases except when both learning_phase=1 and trainable=True.

# In all other cases we should use the moving mean and variance from BN.
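
In other words, something like this in plain Python (my restatement of the condition, not the PR code):

def uses_minibatch_statistics(trainable, learning_phase):
    # Mini-batch statistics only when the layer is trainable AND we are in the
    # training phase; in every other combination the moving mean/variance are used.
    return trainable and learning_phase == 1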

Correct?

@datumbox

Contributor

datumbox commented May 1, 2018

@ahundt Thanks for looking into this. My PR affects only networks that use Batch Normalization layers, so VGG will not be affected. No additional configuration is required other than setting trainable=False on the BN layers. Pulling this in should work fine; just note that my fork is not synced with the latest stable version of Keras. I plan to do this soon, as other people are interested. I have synced the patch to Keras 2.1.6, Keras 2.2.2 and Keras 2.2.4.

One thing that might help with getting this through is a few improvements to the PR, variable names, and description.

Sure thing, send me your comments and I'll make the changes. :-)

@ahundt

Contributor

ahundt commented May 3, 2018

Oh, yeah, sorry, my first run definitely wasn't configured correctly, since VGG makes no sense for this case and the BN layers I had were trained from scratch. I did have other models that use pretrained weights, including ResNet and DenseNet, that didn't perform as well as VGG, and the fix in this PR might be why. I will try them out, but can you confirm that the following steps will make use of your changes?

  1. load pretrained weights for densenet (trainable = false)
  2. add some additional layers on the end
  3. set all layers from (1) including bn to trainable = False, layers from (2) to trainable = True
  4. run training script

Should I expect the above sequence to change the performance when running with this PR applied?

edit: fixed typo mentioned in the next post
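
In code, the sequence I have in mind is roughly the following (DenseNet121 and the head below are just placeholders for my actual models):

from keras.applications.densenet import DenseNet121
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Model

num_classes = 2  # placeholder

# (1) pretrained base
base = DenseNet121(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# (3) freeze every layer from the base, BN included
for layer in base.layers:
    layer.trainable = False

# (2) additional trainable layers on the end
x = GlobalAveragePooling2D()(base.output)
outputs = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base.input, outputs=outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# (4) run the training script, i.e. model.fit(...)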

@datumbox

Contributor

datumbox commented Jun 18, 2018

@joeyearsley thanks for your post. I checked the suggested solution but I still feel this does not cover all cases. Here are the two main points that in my opinion remain not addressed:

  1. Setting the BNs to trainable and resetting their statistics is not a good solution, because the next convolution weights (which remain with trainable=False) were estimated using the initial BN statistics. Resetting the statistics will adapt them to the distribution of the new data, but as I describe in the blog post, the next convolutions and non-linearities will never get to adapt to this change because they are frozen. This can have very bad effects, especially if the new BN stats cause an output to flip from negative to positive, as this will affect the next ReLUs. Also, since the default momentum of the BNs can't be changed after their definition, you will require a large number of iterations to adapt the moving mean/var of the BNs.
  2. The problem does not only appear with the pretrained models included in applications, but also when you load any other pretrained model stored on disk. Imagine that you have trained a classifier from scratch on one domain and now want to use its weights as initialization for a different problem. Just updating the applications API will not fix the problem, as most likely you would need to apply the same patch to load_model().

Recently there was an interesting discussion on Twitter; other people have pointed out this problem in the past, and not all deep learning frameworks are affected. In any case, I think it is clear that the documentation is very misleading and, as you said, needs to be updated. Personally I think there is no reason to scale the data differently when a BN layer is frozen; there is no theoretical or mathematical justification for why such a different rescaling policy should take place when the layer is frozen. In any case, I hope someone will patch this in the future, as the problem is likely to affect lots of people who might not be aware of it, since it has similar symptoms to overfitting.

@ppwwyyxx

ppwwyyxx commented Jun 18, 2018

Let me put an answer here: the main issue is that there is NOT a "correct" thing to do for BN in fine-tuning. There are many valid things to do for BN in fine-tuning. You can:

  1. Freeze gamma, beta; Keep mean/variance unchanged. So it becomes a constant affine
  2. Keep mean/variance unchanged. Tune gamma/beta.
  3. Tune gamma/beta, update mean/variance with the new dataset
  4. Recompute mean/variance on the new dataset first; tune gamma/beta; update mean/variance with the new dataset.

They are all valid options. The original Faster-RCNN/Mask-RCNN are based on (1). Using (2) is similar to (1). Some papers report (3) to be useful, and large-batch Faster-RCNN/Mask-RCNN can also benefit from (3). Some papers argue (4) to be useful for cross-domain knowledge transfer. And I won't be surprised if there is a (5) or a (6).

So there is nothing broken and nothing is "wrong". What people should do is design a good API such that all these different options can be supported, not argue about which one is better and ask a library to switch from one option to another.

Or, you can try our paper Group Normalization if you don't want to deal with these crazy issues any more.

@ahundt

Contributor

ahundt commented Jun 18, 2018

I think this all still indicates there is an issue: the Keras API could be clearer and easier to use when fine-tuning, with an explicit and easy-to-use choice between the options mentioned by @ppwwyyxx (plus perhaps Group Normalization). Anyone have time to design one?

@ppwwyyxx are there any publicly available pretrained weights with group normalization?

I believe typical users won't have the resources to train ImageNet weights from scratch. GroupNorm is currently only a pull request in keras-contrib, and not in Keras.

It sounds like you're fairly certain that GroupNorm will work better for most use cases; you made a pretty strong statement. :-)

@ppwwyyxx

ppwwyyxx commented Jun 19, 2018

A bit off-topic, but yes, there are pretrained ImageNet weights in Detectron. And I'm not saying GroupNorm will work better; it's usually similar to BatchNorm in the applications we've tested. But one thing is certain: everyone can agree on what to do with it in fine-tuning.

@titu1994

Contributor

titu1994 commented Jun 19, 2018

I believe Keras used to have a parameter called mode for Batch Normalization. At that time, mode acted as a flag to dictate whether BN was used in inference mode or training mode (this specific functionality is now performed by the training argument in BN's call method). During the migration from Keras 1 to 2, the mode argument was dropped.

@fchollet If mode can be re-implemented, this time taking into account the above cases, with clear and precise documentation of how each mode behaves, then it can alleviate this fine-tuning issue fairly easily. In that case, we can either drop the training argument altogether or, for backward compatibility, use it to dictate a mode of operation (the default mode when True, and one of the other modes when set to False).

As to performance, the current BN operation can be the default mode, and the documentation can explicitly state that the other modes will incur more cost at training/evaluation time as necessary. Using the mode argument, the computation graph can be built so that fewer switch statements have to be made by the backend, although some will be unavoidable.

@ahundt

Contributor

ahundt commented Jun 19, 2018

This discussion is at a good decision point; @titu1994, that sounds reasonable. I only wish there were a better word than mode, which is used in so many programming situations that it can lose all meaning.

François may not notice this again unless he is pointed to the discussion through some other channel; he probably gets hundreds or thousands of @ mentions each day at this point.

@titu1994

Contributor

titu1994 commented Jun 19, 2018

@ahundt I agree with the semantics of mode. If the general idea about that functionality is pushed forward, then some better keyword can be discussed.

Whatever it is called, I feel it should not be an integer number like before. That caused confusion about whether mode=1 was equal to layer norm or not. This time, in my opinion, semantically meaningful word(s) which describe what the mode will do should be used.

@datumbox

Contributor

datumbox commented Jun 19, 2018

Good ideas. If there is consensus about the API, I am happy to update the PR. It would be good if someone could ping @fchollet to give his input; I want to avoid wasting time again if he has no plans to merge this.

@ahundt

Contributor

ahundt commented Jun 20, 2018

@titu1994 A few well-named string modes would be more than good enough for most users; I think that's the best way to go.

@datumbox The easiest option may be creating a new branch and a new PR, with a link pointing back to this.

Considering the variability in approaches, an easy way to "code your mode" may have value but I think that can be left as future work.

@shunjiangxu

shunjiangxu commented Jul 3, 2018

I'd agree with datumbox's original patch of changing the current behavior of the Keras BN layer during training when the layer is frozen, i.e., using the pretrained moving averages instead of the mini-batch statistics. By making this the default when freezing the layer, at least in transfer learning you do not see that huge difference between training and predicting on the same training dataset. I've seen a lot of confusion on different discussion forums when people try to use pre-trained ImageNet models, and I have encountered it myself. The current behavior is really not intuitive at all, at least to me. If people want a different mode switch, that's fine to add on top of this patch. Just my two cents.

@bhack

bhack commented Jul 17, 2018

We are still seeing new and old issues related to BN/fine-tuning:
#6977
#9214
#10554
#10214
etc., etc.

Probably there is something not user-friendly enough, or under-documented, for the level of usability Keras users expect:

@ppwwyyxx made a good analysis in #9965 (comment)

Also there is some "downstream" effect in tf.keras etc...

@jlussi

jlussi commented Jul 23, 2018

Sorry, beginner Keras user here. I read through the thread and couldn't find a clear answer.
I am encountering this exact problem. I want to fine-tune a pre-existing model; it works very well in the training phase but fails completely when I predict on the training data (where I had almost 100% accuracy). What is, in your opinion, the best and quickest fix right now, and how would I apply it (maybe with an example)? Thanks a lot.

@datumbox

Contributor

datumbox commented Aug 7, 2018

For those of you who use the fork, there is a new sync with Keras 2.2.2 here.

@30yavash

30yavash commented Aug 12, 2018

The Batch Normalization layer of Keras is NOT broken!

Please just change these lines (in your source code written on blog.datumbox):

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)

to this:

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
input = layers.Input(shape=(224, 224, 3))
x = base_model(input)
l = Flatten()(x)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=input, outputs=predictions)

You should define the input as a separate layer.

@JeremyBYU

JeremyBYU commented Aug 14, 2018

Just wanted to chime in and share my experience. I ran some code last year training a ResNet50 model (Keras 2.0.2, TensorFlow 1.2) and got pretty good results, with a validation set accuracy of 98%. Then this year I wanted to make sure I could still reproduce the results (same data). I reran the code in my old Python virtual environment (Keras 2.0.2, TF 1.2) and got 97% accuracy. Good enough for me. I then wanted to check whether I could reproduce the results with an updated Keras and TF.

I created a new Python virtual environment with Keras and TensorFlow updated (Keras 2.2.2, TF 1.9), updated some of the dependencies, and resolved the breaking API changes (the ResNet50 model's last layer changed). I trained but noticed that my model was overfitting very quickly (after about 2 epochs). I tried multiple times, tweaking the model, but it never reached a validation accuracy greater than 92%.

I searched around and found #8430, which led me to look into batch normalization, which led me to this thread. I applied the patch in this pull request and reran. I am getting a validation set accuracy of about 96.6% with the exact same model.

There is still a discrepancy that I can try to hunt down, but for now I am happy. Thanks for this work!

@datumbox

Contributor

datumbox commented Aug 14, 2018

@30yavash I think you misunderstood what I wrote on the blog. How you pass the input has nothing to do with this problem.

@ozabluda

Contributor

ozabluda commented Aug 14, 2018

@JeremyBYU,

  1. The difference between 97% and 98% may not be statistically significant. What is the size of your validation set?

  2. Try reducing your learning rate. There is a general trend in TF/cuDNN to make kernels faster at the cost of reduced numerical accuracy.

@JeremyBYU

JeremyBYU commented Aug 14, 2018

Thanks @ozabluda. Yeah, the 97 vs. 98 difference doesn't bother me (nor even the 96.6). It was the drop to 92% on the updated Keras 2.2.2 that bothered me. The validation set is kind of small (about 400 images, 6 classes). I will try your suggestions, thanks!

@ozabluda

Contributor

ozabluda commented Aug 14, 2018

Even with only 400 images, the difference between 92% and 97% is probably statistically significant (it depends on the confusion matrix). 97 vs. 98 is not.

@shacharm2

shacharm2 commented Sep 22, 2018

@30yavash, even if what you suggest had been a valid solution (it is not related to the issue at all, according to @datumbox), your suggestion yields the following error:

AttributeError: Layer resnet50 has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use get_input_at(node_index) instead.
(Keras 2.2.2)

@davidkarl

davidkarl commented Sep 22, 2018

@datumbox thank you for your contribution, I will try it out. I have a question though: why not simply reinitialize the layer completely when importing a model and making it trainable?

You say that one should use the statistics the BN layer learned whilst training the original model, and that resetting the statistics and having BN learn the new statistics is not good "because the next convolution weights (which remain with trainable=False) were estimated using the initial BN statistics".

But if I understand correctly, the next layers are "expecting" "whitened" data; it's the data distribution that matters, not the rescaling operation itself.
So why wouldn't it be more proper to make the BN layer trainable and reinitialize its statistics and momentum, so that the BN layer adjusts to the new data distribution quickly?

However, no one can argue with the results of your approach, and I would like to know whether there is a specific small change to the BN layer definition that applies your patch, so as to avoid being dependent on you publishing a new fork for each Keras version.

Thank you very much!

@shacharm2

shacharm2 commented Sep 22, 2018

@30yavash, an update: your code works as-is when I copy-pasted it.
However, the validation accuracy is extremely low, much lower than with VGG16:

129/129 [==============================] - 203s 2s/step - loss: 0.1954 - acc: 0.9651 - val_loss: 9.8208 - val_acc: 0.0334
129/129 [==============================] - 192s 1s/step - loss: 0.0266 - acc: 0.9929 - val_loss: 10.1189 - val_acc: 0.0276

Final validation accuracy is 0.0275 (with loss 10.120)
Final train accuracy is 0.997 (with loss 0.00973)

All using your code.
Are you quite sure about your statement?

@jmhessel

Contributor

jmhessel commented Oct 9, 2018

Am I understanding/summarizing the issue correctly?

  1. Assume one has loaded in a pretrained model with batchnorm, e.g., ResNet50.

  2. If a BatchNormalization's .trainable property is set to False (and training=True, as per the default settings), none of its weights will be updated, i.e., the rescaling variables and the mean/variance variables will be static. At training time, a second set of mean/variance statistics will be tracked, updated with each new batch, and used.

  3. At test time, the old mean/variance from the original pretrained model will be used instead of the newly computed mean/variance from training time. This results in the accuracy at test time being lower than at training time because all of the scales of the variables are messed up.

  4. The solution to this is to use K.set_learning_phase(0) and/or set training=False when loading/constructing a BatchNormalization layer. If training=False, then the moving average/std from the preloaded model are used at training time.
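
For example, just to make point 4 concrete (the surrounding layers are placeholders):

from keras.layers import Input, Conv2D, BatchNormalization

inputs = Input(shape=(32, 32, 3))
x = Conv2D(16, 3, padding='same')(inputs)
x = BatchNormalization()(x, training=False)  # moving mean/var are used even during fit()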

If this is correct, then that works, but it does invite the question: is there a use case for the current behaviour when setting BatchNormalization's .trainable property to False? Semantics aside, I can't think of any reason why someone would want to use the old mean/variance instead of the ones for the training set they are trying to use.

@datumbox

Contributor

datumbox commented Oct 9, 2018

@jmhessel

On point 1: During training time, the mean/variance is estimated based on mini-batch statistics. The rest of your comments are correct.

On point 2: Spot on.

On point 3: Setting the learning phase to 0 is not exactly a solution, as you won't be able to train the network at all. Nevertheless, you are right to say that the BN will use the mean/variance of the preloaded model.

The reason why you want to use the old mean/variance of the network for those BNs that are frozen is that the layers after them were trained with a specific mean/variance in mind. Changing them means that the weights of the next convolutions won't adjust to the changed statistics. Even worse, the ReLU non-linearities will produce completely different outputs. This is the reason why many people report such low accuracies.

@jmhessel

Contributor

jmhessel commented Oct 11, 2018

@datumbox Thanks for the info!!! I think that the main thing to avoid, then, is using training=True and trainable=False, which, unfortunately, seems to be the default API usage if you don't think about it. I guess for now I will freeze everything at training and test time in the BN layers from pretrained convnets and hope that the optimization process can fix this weirdness.

@RichardSieg

RichardSieg commented Nov 14, 2018

I have been following this discussion for quite a while now and it still gives me headaches.

My current situation is the following: we would like to use DenseNet as a classifier, freeze the first blocks, and leave the remaining blocks trainable. When we first tried this approach we were completely surprised by the poor results (it actually did not detect anything), and it took us several working days to find out that the BatchNormalization layer was causing this issue. In conclusion, this behavior is not user-friendly, especially since it is not even documented anywhere. So thank you @datumbox for clarifying things!

Now I wonder how I can freeze the blocks in the DenseNet. I do not really want to copy & paste the code from the Keras GitHub, but it seems like this is the only solution. Should I rather use the set_learning_phase method, or set the training parameter to False for those BN layers which I want to freeze?
