"Fine-tune InceptionV3/ResNet50 on a new set of classes" doesn't work, while VGG16 works (suspect BN) #9214

Closed
ozabluda opened this issue Jan 28, 2018 · 40 comments



@ozabluda ozabluda commented Jan 28, 2018

The following code works as expected with vgg16 (no BN) but not with resnet50 or inception_v3 (BN). My hypothesis is that it's due to BN. The code follows "Fine-tune InceptionV3 on a new set of classes" from https://keras.io/applications/#usage-examples-for-image-classification-models

from keras.preprocessing import image
from keras.applications import resnet50, inception_v3, vgg16
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Input
from keras.optimizers import Adam
import numpy as np

batch_size = 50
num_classes = 2

#base_model = resnet50.ResNet50
#base_model = inception_v3.InceptionV3
base_model = vgg16.VGG16

base_model = base_model(weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
for layer in base_model.layers:
    layer.trainable = False

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=Adam(lr=0.0001),
              metrics=['acc'])

x_train = np.random.normal(loc=127, scale=127, size=(50, 224,224,3))
y_train = np.array([0,1]*25)
x_train = resnet50.preprocess_input(x_train)

print(model.evaluate(x_train, y_train, batch_size=batch_size, verbose=0))
model.fit(x_train, y_train,
          epochs=100,
          batch_size=batch_size,
          shuffle=False,
          validation_data=(x_train, y_train))

@ozabluda ozabluda commented Jan 28, 2018

VGG16 output (works as expected):

  1. both validation loss and acc before training (from evaluate()) are exactly equal to those in the first training iteration
  2. validation loss/acc correspond to training loss/acc
  3. training loss->0, acc->1, validation loss->0, acc->1.
[1.1005926132202148, 0.5]
Train on 50 samples, validate on 50 samples
Epoch 1/100
50/50 [==============================] - 1s 11ms/step - loss: 1.1006 - acc: 0.5000 - val_loss: 0.7771 - val_acc: 0.5800
Epoch 2/100
50/50 [==============================] - 0s 8ms/step - loss: 0.7771 - acc: 0.5800 - val_loss: 0.8947 - val_acc: 0.4600
Epoch 3/100
50/50 [==============================] - 0s 9ms/step - loss: 0.8947 - acc: 0.4600 - val_loss: 0.9511 - val_acc: 0.4800
Epoch 4/100
50/50 [==============================] - 0s 9ms/step - loss: 0.9511 - acc: 0.4800 - val_loss: 0.8385 - val_acc: 0.4600
Epoch 5/100
50/50 [==============================] - 0s 9ms/step - loss: 0.8385 - acc: 0.4600 - val_loss: 0.7341 - val_acc: 0.5400
Epoch 6/100
50/50 [==============================] - 0s 9ms/step - loss: 0.7341 - acc: 0.5400 - val_loss: 0.7455 - val_acc: 0.5600
Epoch 7/100
50/50 [==============================] - 0s 8ms/step - loss: 0.7455 - acc: 0.5600 - val_loss: 0.7991 - val_acc: 0.6000
Epoch 8/100
50/50 [==============================] - 0s 9ms/step - loss: 0.7991 - acc: 0.6000 - val_loss: 0.7902 - val_acc: 0.6000
Epoch 9/100
50/50 [==============================] - 0s 9ms/step - loss: 0.7902 - acc: 0.6000 - val_loss: 0.7258 - val_acc: 0.5800
Epoch 10/100
50/50 [==============================] - 0s 9ms/step - loss: 0.7258 - acc: 0.5800 - val_loss: 0.6727 - val_acc: 0.6400
[...]
Epoch 98/100
50/50 [==============================] - 0s 9ms/step - loss: 0.2272 - acc: 1.0000 - val_loss: 0.2246 - val_acc: 1.0000
Epoch 99/100
50/50 [==============================] - 0s 9ms/step - loss: 0.2246 - acc: 1.0000 - val_loss: 0.2221 - val_acc: 1.0000
Epoch 100/100
50/50 [==============================] - 0s 9ms/step - loss: 0.2221 - acc: 1.0000 - val_loss: 0.2196 - val_acc: 1.0000


@ozabluda ozabluda commented Jan 28, 2018

resnet50 output (does not work as expected):

  1. validation loss before training (from evaluate()) is nowhere near that in the first training iteration (BN?)
  2. validation loss/acc does not correspond to training loss/acc at all (BN?)
  3. training loss->0, acc->1 very quickly (acc=1.0 starting from epoch 5), validation loss stays huge forever, acc=0.5 (random) forever.
[2.3405368328094482, 0.5]
Train on 50 samples, validate on 50 samples
Epoch 1/100
50/50 [==============================] - 1s 21ms/step - loss: 0.6806 - acc: 0.5400 - val_loss: 1.6767 - val_acc: 0.5000
Epoch 2/100
50/50 [==============================] - 0s 8ms/step - loss: 0.6061 - acc: 0.6400 - val_loss: 1.8632 - val_acc: 0.5000
Epoch 3/100
50/50 [==============================] - 0s 9ms/step - loss: 0.5088 - acc: 0.9000 - val_loss: 2.0533 - val_acc: 0.5000
Epoch 4/100
50/50 [==============================] - 0s 9ms/step - loss: 0.4437 - acc: 0.9200 - val_loss: 1.9083 - val_acc: 0.5000
Epoch 5/100
50/50 [==============================] - 0s 9ms/step - loss: 0.3799 - acc: 1.0000 - val_loss: 1.5847 - val_acc: 0.5000
Epoch 6/100
50/50 [==============================] - 0s 9ms/step - loss: 0.3222 - acc: 1.0000 - val_loss: 1.3209 - val_acc: 0.5000
Epoch 7/100
50/50 [==============================] - 0s 8ms/step - loss: 0.2816 - acc: 1.0000 - val_loss: 1.2207 - val_acc: 0.5000
Epoch 8/100
50/50 [==============================] - 0s 8ms/step - loss: 0.2439 - acc: 1.0000 - val_loss: 1.2348 - val_acc: 0.5000
Epoch 9/100
50/50 [==============================] - 0s 9ms/step - loss: 0.2089 - acc: 1.0000 - val_loss: 1.2679 - val_acc: 0.5000
Epoch 10/100
50/50 [==============================] - 0s 9ms/step - loss: 0.1824 - acc: 1.0000 - val_loss: 1.2359 - val_acc: 0.5000
[...]
Epoch 98/100
50/50 [==============================] - 0s 8ms/step - loss: 0.0032 - acc: 1.0000 - val_loss: 2.2686 - val_acc: 0.5000
Epoch 99/100
50/50 [==============================] - 0s 8ms/step - loss: 0.0032 - acc: 1.0000 - val_loss: 2.2791 - val_acc: 0.5000
Epoch 100/100
50/50 [==============================] - 0s 8ms/step - loss: 0.0031 - acc: 1.0000 - val_loss: 2.2894 - val_acc: 0.5000


@ozabluda ozabluda commented Jan 28, 2018

The problem does not occur without

for layer in base_model.layers:
    layer.trainable = False

Note that this is repo master, i.e. with 24246ea already merged, see #8616 (comment)


@ozabluda ozabluda commented Feb 13, 2018

Just checked with TF 1.5 and Keras master, and the behavior is unchanged. Also identical on CPU (which I didn't check before).

@fchollet, this appears to be a serious problem, because AFAIK I followed the (simple) docs to the letter.


@GougeC GougeC commented Mar 14, 2018

I am having the exact same issue. Inception does not work but VGG does fine. InceptionV3 picks the same class every time, no matter what the test set is.


@majiaji majiaji commented Apr 2, 2018

same issue when I try ResNet50 in keras


@Zamirquito Zamirquito commented Apr 6, 2018

I have tried different optimizers (Adam, SGD, SGD with momentum and so on) when training ResNet50, both fine-tuning and freezing all the layers except the FC layer. My training loss decreases, but the validation accuracy increases at first and then, just as you say, stalls at 50%-60%.

I have tried random sampling of the data and some data augmentation tricks, but those didn't work.


@NProdanova NProdanova commented Apr 6, 2018

I am having the same problem with Inception-v3, while VGG19 works. I can also confirm that when I remove

layer.trainable = False

validation accuracy starts improving.
I posted a question for a workaround on stackoverflow:
https://stackoverflow.com/questions/49689122/keras-inception-v3-fine-tuning-workaraound

Maybe somebody has a suggestion


@ciprianfocsaneanu ciprianfocsaneanu commented Apr 18, 2018

I am having the same problem with ResNet50. I am doing transfer learning and the same dataset/code works for InceptionV3 and DenseNet121, but ResNet seems to always predict one class


@datumbox datumbox commented Apr 23, 2018

For all of you who are affected by this, please have a look at PR #9965. This problem is caused by the way the Batch Normalization layer is implemented in Keras.

To understand why this happens we need to understand how the BN works. When the network is in training mode, the mini-batch statistics of BN are used for training the network; when the network is in inference mode, we use the moving mean/var learned during the training. That's all good. The problem is how the layer behaves when it is frozen. Its side-effects are more profound when we use fine-tuning and Transfer Learning.

You see, when frozen and while in training mode, the BN continues to use the mini-batch statistics for scaling the training data. This causes the unfrozen/trainable layers to adapt to the scale of the data. Unfortunately, during inference mode (predictions) the network switches to the moving mean/var. If the moving mean/var is different from the mini-batch statistics, the data are scaled differently, causing massive discrepancies in accuracy. If you want more info, have a look at the PR.
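A minimal NumPy sketch of that mismatch. This is not Keras' implementation: the moving statistics and data are made up, and gamma/beta are left at their identity values. The same activations come out with roughly zero mean and unit variance in training mode. In inference mode they keep their offset and scale, because the stored statistics came from different data.

```python
import numpy as np

EPS = 1e-3  # Keras' BatchNormalization default epsilon

def batch_norm(x, gamma, beta, moving_mean, moving_var, training):
    """Normalize x (shape: batch x features) the way a BN layer does."""
    if training:
        # training mode: statistics of the current mini-batch
        mean, var = x.mean(axis=0), x.var(axis=0)
    else:
        # inference mode: statistics accumulated during (pre-)training
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + EPS) + beta

rng = np.random.default_rng(0)
# hypothetical stored statistics from the original (pre-training) data
moving_mean, moving_var = np.zeros(4), np.ones(4)
gamma, beta = np.ones(4), np.zeros(4)

# new dataset whose activations have a different offset and scale
x_new = rng.normal(loc=5.0, scale=3.0, size=(50, 4))

y_train_mode = batch_norm(x_new, gamma, beta, moving_mean, moving_var, training=True)
y_infer_mode = batch_norm(x_new, gamma, beta, moving_mean, moving_var, training=False)

print(y_train_mode.mean(), y_train_mode.std())  # roughly 0 and 1
print(y_infer_mode.mean(), y_infer_mode.std())  # roughly 5 and 3
```

The layers trained on top only ever see the first distribution during fit(), but receive the second one at predict() time.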


@jksmither jksmither commented Apr 28, 2018

@ozabluda I'm facing the same problem on Keras 2.1.6 and TensorFlow 1.7. Did you find a solution?

@datumbox I tried installing your branch and it seems to work well. Unfortunately it is not synced with the latest version of Keras. Any plans to merge it with 2.1.6? If not I can do it.


@datumbox datumbox commented May 1, 2018

@jksmither Sorry for the late response. I just synced my branch with the latest master and provided a patched fork of 2.1.6. Honestly I would like to see this fixed on master as maintaining a separate fork with the patch is not a viable solution on the long term. I'll probably keep syncing it for as long as we use Keras at work but I can't make any promises.


@shazamkash shazamkash commented Jun 5, 2018

Did anyone find a concrete solution to this problem? I am also affected by this problem and I am working on Keras 2.1.6 and TensorFlow 1.7 to train and test my data using InceptionV3 and Resnet50.

I am very new to deep learning and any help will be appreciated.


@izharikov izharikov commented Jun 5, 2018

@shazamkash

This is a temporary fix:

pip install -U --force-reinstall --no-dependencies git+https://github.com/datumbox/keras@bugfix/trainable_bn 

This command reinstalls Keras with the fixes provided by @datumbox.
This works for me (Inception started training normally).


@shazamkash shazamkash commented Jun 6, 2018

@izharikov @datumbox
Thank you so much for providing this information and the patch. Are we expecting a permanent fix anytime soon, though?
Also, has anyone faced this problem using just TensorFlow? Any experience with that?


@datumbox datumbox commented Jun 6, 2018

Thanks for the input. I would advise using the 2.1.6 fork instead of the trainable_bn branch. The fork is synced with the latest stable version of Keras, while the trainable_bn branch, even though it's fresher, is not based on a finalized release.

Unfortunately there are no plans for a permanent fix at the moment. My PR #9965 was rejected (you can read the rationale at the link) because it modifies the semantics of trainable. It's not the first time that the BatchNormalization layer has forced us to update the semantics of trainable (see version 2.1.3), but it can take a while until such a change gets enough momentum. So maybe in the future, if enough people complain about it, it will be reopened. Until then I'll do my best to maintain the fork for those of you who are brave enough to mess with custom implementations.


@drsxr drsxr commented Jun 13, 2018

Wow. I was spinning my wheels for a while trying to fine-tune ResNet50 until I found these threads. Same problems. So batchnorm in Keras = no fine-tuning? Either paste the FC layer on top of pretrained weights (imagenet) or train from scratch. I'm working with smaller N's, so training from scratch with augmentation is not a de facto solution. A shame too, because Transfer Learning is looking more attractive lately, see: "Do Better ImageNet Models Transfer Better?"

Apart from @datumbox's patch, or moving over to another framework, are there any other workarounds?
@fchollet suggested this in the main discussion thread on Vasils' PR (#9965):

But even if you don't, you can still do your style of fine-tuning in this case:

set learning phase to 0
load model
retrieve features you want to train on
set learning phase to 1
add new layers on top
optionally load weights from initial model layers to corresponding new layers
train
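The steps above boil down to three things. First, run the frozen base once in inference mode. Second, cache its outputs. Third, train only the new layers on those cached features, so the new layers never see mini-batch-normalized activations. Below is a framework-agnostic NumPy sketch of that idea on synthetic data. It is not Keras code. `frozen_base` is a hypothetical stand-in for the pretrained model: fixed weights plus a BN that always uses its stored moving statistics. A logistic-regression head stands in for the new Dense layers.

```python
import numpy as np

EPS = 1e-3
rng = np.random.default_rng(0)

# Hypothetical stand-in for the pretrained base: fixed weights followed by a
# BN layer that always uses its stored moving statistics (learning phase 0).
W_base = rng.normal(size=(8, 16))
moving_mean, moving_var = np.zeros(16), np.full(16, 8.0)

def frozen_base(x):
    h = x @ W_base
    return (h - moving_mean) / np.sqrt(moving_var + EPS)

def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))  # numerically stable logistic

# synthetic two-class data (not from the thread)
x = rng.normal(size=(200, 8))
y = (x[:, 0] + x[:, 1] > 0).astype(float)

# 1) retrieve the features once, in inference mode, and cache them
features = frozen_base(x)

# 2) train only the new layer (here: logistic regression) on the cached features
w, b, lr = np.zeros(features.shape[1]), 0.0, 0.5
for _ in range(2000):
    grad = sigmoid(features @ w + b) - y
    w -= lr * features.T @ grad / len(y)
    b -= lr * grad.mean()

# 3) at prediction time the base produces exactly the same features,
#    so the head sees the same input distribution it was trained on
acc = ((sigmoid(frozen_base(x) @ w + b) > 0.5) == y).mean()
print(acc)
```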


@joeyearsley joeyearsley commented Jun 18, 2018

from keras import backend as K

for layer in base_model.layers:
    if hasattr(layer, 'moving_mean') and hasattr(layer, 'moving_variance'):
        # BN layer: keep it trainable and reset its moving statistics to zero
        layer.trainable = True
        K.eval(K.update(layer.moving_mean, K.zeros_like(layer.moving_mean)))
        K.eval(K.update(layer.moving_variance, K.zeros_like(layer.moving_variance)))
    else:
        layer.trainable = False

Reset the batch norm moving averages and allow them to update to the new dataset - you'll see it transfer.

I'm writing a longer update on this matter and will open issues (easy PRs) so people can help contribute to fixing the documentation and the like.
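Resetting only pays off once enough training steps have run. Keras updates the moving statistics as an exponential moving average, with momentum 0.99 by default. After a reset to zeros, they therefore reach a fraction 1 - 0.99^k of the new data's statistics after k steps. Below is a NumPy sketch of that update on synthetic data (mean 2, variance 9). Until the moving variance has recovered, inference-mode outputs are divided by a very small sqrt(moving_var + eps).

```python
import numpy as np

MOMENTUM = 0.99  # Keras' BatchNormalization default momentum

def update_moving(moving, batch_stat, momentum=MOMENTUM):
    """One exponential-moving-average step, applied once per training batch."""
    return momentum * moving + (1.0 - momentum) * batch_stat

rng = np.random.default_rng(0)
moving_mean, moving_var = 0.0, 0.0  # the reset to zeros

history = {}
for step in range(1, 1001):
    batch = rng.normal(loc=2.0, scale=3.0, size=64)  # new data: mean 2, variance 9
    moving_mean = update_moving(moving_mean, batch.mean())
    moving_var = update_moving(moving_var, batch.var())
    if step in (10, 100, 1000):
        history[step] = (moving_mean, moving_var)
        print(step, round(moving_mean, 2), round(moving_var, 2))
```

After 10 steps the moving statistics have reached only about 10% of their targets (1 - 0.99^10 ≈ 0.096). After 1000 steps the remaining bias (0.99^1000 ≈ 4e-5) is negligible.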


@datumbox datumbox commented Jun 18, 2018

Unfreezing the BNs while keeping the subsequent Convolutions frozen can have negative effects on accuracy. I describe this in more detail here.


@AndreGuerra123 AndreGuerra123 commented Aug 18, 2018

Just a simple question: why use 1024 units in the last Dense layer?
x = Dense(1024, activation='relu')(x)


@ozabluda ozabluda commented Aug 19, 2018

x = Dense(1024, activation='relu')(x) is because

The code follows "Fine-tune InceptionV3 on a new set of classes" from https://keras.io/applications/#usage-examples-for-image-classification-models

For the purpose of this Issue, I was following the docs literally for stronger effect to make my point.


@xpngzhng xpngzhng commented Aug 20, 2018

I fine-tuned a Keras pretrained model on my own dataset, freezing some of the early layers. I got decent validation accuracy with VGG, but bad validation accuracy with ResNet50.

VGG
epoch 1: train_acc 0.546514682730133, val_acc 0.6607583973804312
epoch 16: train_acc 0.9279250402631126, val_acc 0.7440402661072892
There is overfitting.

ResNet50
epoch 1: train_acc 0.7301661501087283, val_acc 0.04389513340162522
Then I terminated the training.

I think this may be caused by BatchNormalization.

I once used keras-retinanet https://github.com/fizyr/keras-retinanet to train on my own dataset, which worked very well. So I want to find out the reason. RetinaNet uses ResNet as backbone, and BatchNormalization layers are frozen, see https://github.com/fizyr/keras-retinanet/blob/master/keras_retinanet/models/resnet.py#L98

The ResNet in that project is borrowed from another repo, keras-resnet https://github.com/broadinstitute/keras-resnet. In this ResNet implementation, the authors customize the BatchNormalization layer, see
https://github.com/broadinstitute/keras-resnet/blob/master/keras_resnet/layers/_batch_normalization.py

    def call(self, *args, **kwargs):
        # return super.call, but set training
        return super(BatchNormalization, self).call(training=(not self.freeze), *args, **kwargs)
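Here is a NumPy stand-in that makes the effect of that override concrete. It is not the keras-resnet class itself: gamma/beta are omitted and the moving statistics are hypothetical. Because `training` is derived only from `freeze`, a frozen layer always normalizes with its moving statistics. A sample's output then no longer depends on fit() vs predict(), or on the rest of the batch.

```python
import numpy as np

EPS = 1e-3

class FreezableBatchNorm:
    """NumPy stand-in for a BN layer whose mode is fixed by a `freeze` flag,
    mirroring `training=(not self.freeze)` in the snippet above."""

    def __init__(self, moving_mean, moving_var, freeze):
        self.moving_mean = moving_mean
        self.moving_var = moving_var
        self.freeze = freeze

    def __call__(self, x):
        training = not self.freeze  # the global learning phase is ignored
        if training:
            mean, var = x.mean(axis=0), x.var(axis=0)
        else:
            mean, var = self.moving_mean, self.moving_var
        return (x - mean) / np.sqrt(var + EPS)

rng = np.random.default_rng(0)
bn = FreezableBatchNorm(moving_mean=np.zeros(3), moving_var=np.ones(3), freeze=True)

batch = rng.normal(loc=5.0, scale=2.0, size=(32, 3))
# a frozen layer gives the same output for a sample whether it is processed
# alone or as part of a larger batch
out_full = bn(batch)
out_single = bn(batch[:1])
print(np.allclose(out_full[:1], out_single))  # True
```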

It seems that this is what @fchollet recommends in @datumbox's PR
#9965 (comment)

I think it would be better to use keras-resnet https://github.com/broadinstitute/keras-resnet for fine-tuning. I have not tried it yet.

Yesterday I tried fine-tuning InceptionV3 on the same dataset, with half of the layers set untrainable. Somewhat strangely, the validation accuracy is quite good.

InceptionV3
epoch 1: train_acc 0.7295633685318796, val_acc 0.7712759384936045
epoch 10: train_acc 0.9254250409199613, val_acc 0.7986968871106656

The code I use is something like this https://gist.github.com/XupingZHENG/1e20d54a70c8e04912c0b37fa7e7b931


@cesarorosco cesarorosco commented Sep 18, 2018

I have the same problem with Resnet50.

This seems to work.

- Set the learning phase to 1
- In every batch normalization layer, set training=False

After that I get the correct accuracy.

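from keras import backend as K
from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

K.set_learning_phase(1)

base_model = ResNet50(weights='imagenet', include_top=False)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(2, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:
    layer.trainable = False
    # BN layers in Keras' ResNet50 are named 'bn_conv1', 'bn2a_branch2a', ...
    if layer.name.startswith('bn'):
        layer.call(layer.input, training=False)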


@datumbox datumbox commented Sep 28, 2018

@cesarorosco This means that your network always runs in training mode. Even when you make predictions, you use the mini-batch statistics. This is not great, as your predictions will change depending on which images you pass in the batch.
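The batch dependence is easy to reproduce with a NumPy stand-in for a BN layer stuck in training mode (synthetic data; gamma/beta omitted). The same sample gets different activations depending on which other images share its batch. A batch of one is normalized to all zeros.

```python
import numpy as np

EPS = 1e-3

def bn_training_mode(x):
    # training mode: normalize with the statistics of whatever batch is passed in
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + EPS)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 4))
batch_a = np.vstack([sample, rng.normal(loc=0.0, size=(15, 4))])
batch_b = np.vstack([sample, rng.normal(loc=3.0, size=(15, 4))])

# the same image (row 0) gets different normalized activations in each batch
out_a = bn_training_mode(batch_a)[0]
out_b = bn_training_mode(batch_b)[0]
print(np.allclose(out_a, out_b))  # False

# predicting a single image: it is its own batch mean, so everything becomes 0
print(bn_training_mode(sample))  # all zeros
```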


@lpenet lpenet commented Dec 2, 2018

If I proceed as @cesarorosco mentioned, things get even worse and I get stuck around ~0.5 acc on the training and validation sets...

It is a bit disappointing to have this kind of problem, especially when this kind of thing is presented as so simple when using Keras...

It is such a great tool, anyway.


@chabir chabir commented Jan 8, 2019

All,
I don't know if this is completely related, but I have been referring to this issue for several days now. I used the patch mentioned by @datumbox and observed the following:

Keras version: 2.2.4

from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

case 1: model converges:
base_model = ResNet50(input_shape=(img_size, img_size, 3), weights=None, classes=5004)

case 2: model never ever converges:
res50 = ResNet50(input_shape=(img_size, img_size, 3), weights=None, include_top=False)
x = GlobalAveragePooling2D()(res50.output)
output = Dense(5004)(x)
base_model = Model(inputs=[res50.input], outputs=[output])

Any ideas?

Edit: I forgot the softmax activation in case 2...


@Wazaki-Ou Wazaki-Ou commented Feb 20, 2019

I think I might be affected by a similar issue, except that I am using VGG16 instead. The training goes really well (to be honest, a bit unexpectedly well) and reaches beyond 90%. I checked the model accuracy on a test set and it gives good results, and the same happens when I check with a confusion matrix. The problem is that when I use the model's predict_classes() on my test set (the same one that gave good results for accuracy and the confusion matrix), the predictions are awfully bad. One class seems to be preferred over the others, and I get 0 accuracy in 2 of the 5 classes. I was asked to check this post and am wondering if anyone could help. Thanks a lot!!


@rohan19250 rohan19250 commented Feb 21, 2019

I am having the same issue as well. Inception v3 gives low validation accuracy but high training accuracy. What would be the suggestion to fix this?

Epoch 1/30
185/185 [==============================] - 5478s 30s/step - loss: 0.1161 - acc: 0.9493 - val_loss: 2.4898 - val_acc: 0.5678
Epoch 2/30
185/185 [==============================] - 5453s 29s/step - loss: 0.0362 - acc: 0.9861 - val_loss: 1.1530 - val_acc: 0.7678
Epoch 3/30
185/185 [==============================] - 5457s 29s/step - loss: 0.0280 - acc: 0.9902 - val_loss: 5.4614 - val_acc: 0.4506
Epoch 4/30
185/185 [==============================] - 5458s 30s/step - loss: 0.0184 - acc: 0.9934 - val_loss: 5.2297 - val_acc: 0.5117
Epoch 5/30
185/185 [==============================] - 5474s 30s/step - loss: 0.0146 - acc: 0.9954 - val_loss: 4.2587 - val_acc: 0.5586
Epoch 6/30
185/185 [==============================] - 5463s 30s/step - loss: 0.0113 - acc: 0.9965 - val_loss: 4.5049 - val_acc: 0.6019
Epoch 7/30
185/185 [==============================] - 5467s 30s/step - loss: 0.0099 - acc: 0.9972 - val_loss: 6.9422 - val_acc: 0.3551
Epoch 8/30
185/185 [==============================] - 5467s 30s/step - loss: 0.0099 - acc: 0.9969 - val_loss: 5.8211 - val_acc: 0.4901
Epoch 9/30
185/185 [==============================] - 5466s 30s/step - loss: 0.0112 - acc: 0.9965 - val_loss: 5.2108 - val_acc: 0.5518
Epoch 10/30
185/185 [==============================] - 5471s 30s/step - loss: 0.0113 - acc: 0.9964 - val_loss: 6.1660 - val_acc: 0.5092
Epoch 11/30
104/185 [===============>..............] - ETA: 37:44 - loss: 0.0140 - acc: 0.9958


@pchris24 pchris24 commented Mar 7, 2019

same issue when I try ResNet50 in keras

Same here too

I use ResNet50 for fine-tuning and want to predict two classes. For one class the validation and training accuracy is 45%, and for the other it's 0%. I unfroze the final set of conv layers.


@AkshayRoy AkshayRoy commented Apr 2, 2019

I tried to run a ResNet50 model for classifying colors of clothes. I used ImageNet weights and added a global average pooling layer, 2 dense layers with dropout, and a final output layer with sigmoid/softmax. I froze all the layers except the newly added ones, then started training. Training goes well for some time and I managed to get some accuracy, but when I tested my model, the predictions were all wrong. Can anybody help me solve this?


@xhm1014 xhm1014 commented Aug 12, 2019

I tried transfer learning on resnet50 in keras for two class classification problem. I only fine-tune the top fully-connected layer, while all other layers are frozen. I encounter the same problem: training accuracy is increasing as expected, but validation accuracy is only between 50-60%.

Has anyone found a good solution to this problem in Keras?


@geometrikal geometrikal commented Aug 25, 2019

@xhm1014 Pre-compute the vectors using resnet50, then train model with only the dense layers on the vectors. After training, join the dense layers to the resnet50 layers if you want to save the whole network.


@BraveDistribution BraveDistribution commented Aug 25, 2019

@geometrikal

Could you provide an example of how to do that? I know that this workaround was suggested by the author of Keras, but I couldn't find any way to do it.

It is a real shame that the official docs don't mention this. Any model other than VGG (without BN) is useless if you want to freeze any of the layers.


@geometrikal geometrikal commented Aug 25, 2019

@BraveDistribution Yea it is a bit strange.

This is how I do it, (taken from my repo here: https://github.com/microfossil/particle-classification/blob/master/miso/training/model_trainer.py )

Functions to make the head and tail:

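import tensorflow as tf
from tensorflow.keras.applications import resnet50
from tensorflow.keras.layers import Dense, Dropout, Input, Lambda
from tensorflow.keras.models import Model, Sequential


def resnet50_head(input_shape):
    inputs = Input(shape=input_shape)
    # RGB -> BGR channel order
    x = Lambda(lambda y: tf.reverse(y, axis=[-1]))(inputs)
    # [0, 1] -> [0, 255], then subtract the ImageNet channel means (BGR order)
    x = Lambda(lambda y: y * tf.constant(255.0)
                         - tf.reshape(tf.constant([103.939, 116.779, 123.68]),
                                      [1, 1, 1, 3]))(x)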
    x = resnet50.ResNet50(include_top=False,
                          weights='imagenet',
                          pooling='avg')(x)
    model = Model(inputs=inputs, outputs=x)
    model.get_layer('resnet50').trainable = False
    return model


def tl_tail(nb_classes):
    model = Sequential()
    model.add(Dropout(0.05))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.15))
    model.add(Dense(512, activation='relu'))
    model.add(Dense(nb_classes, activation='softmax'))
    return model

Note the two Lambda layers in the head. The pre-trained ResNet50 expects its inputs pre-processed by flipping the channels to BGR and subtracting the (ImageNet) channel averages. In my datasets I use images scaled to the range [0,1] by dividing by 255, so these two layers convert from the [0,1] range to the pre-processing used by the pre-trained network. You can remove them if you are using resnet50.preprocess_input() to create the images.
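The arithmetic the two Lambda layers perform can be checked outside the model. Below is a NumPy sketch that assumes the standard 'caffe'-style means Keras uses for ResNet50 ([103.939, 116.779, 123.68] in BGR order). The helper name `unit_rgb_to_caffe` is hypothetical. For RGB images in [0, 1], it should match `resnet50.preprocess_input(images * 255.0)`.

```python
import numpy as np

# ImageNet per-channel means in BGR order ('caffe'-style preprocessing)
BGR_MEAN = np.array([103.939, 116.779, 123.68])

def unit_rgb_to_caffe(images):
    """Map RGB images in [0, 1] (shape: batch x h x w x 3) to the input the
    pretrained ResNet50 weights expect: BGR order, [0, 255] scale, mean removed."""
    bgr = images[..., ::-1]        # RGB -> BGR, like tf.reverse(y, axis=[-1])
    return bgr * 255.0 - BGR_MEAN  # rescale, then subtract the per-channel mean

red = np.zeros((1, 2, 2, 3))
red[..., 0] = 1.0  # a pure-red RGB image in [0, 1]
out = unit_rgb_to_caffe(red)
print(out[0, 0, 0])  # B and G only lose their means; R becomes 255 - 123.68
```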

Now create the head and make the vectors:

model_head = resnet50_head(input_shape=(224,224,3))
train_vectors = model_head.predict(train_images)
test_vectors = model_head.predict(test_images)

Make the tail and train with these vectors:

model_tail = tl_tail(nb_classes=N)
model_tail.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
history = model_tail.fit(train_vectors,
                         train_onehots,
                         validation_data=(test_vectors, test_onehots),
                         epochs=max_epochs,
                         batch_size=batch_size,
                         shuffle=True,
                         verbose=0)

Then if you want to have a network that takes image as input, join them:

outputs = model_tail(model_head.outputs[0])
model = Model(model_head.inputs[0], outputs)

and then you can do

train_preds = model.predict(train_images) etc

Pre-calculating the vectors like this makes training very fast. You can even do it on a CPU.

By the way, my repo is designed for a project where we are enabling non-ML people to train networks. If you are interested, see the README here: https://github.com/microfossil/particle-classification-examples and especially the Google Colab tutorial (and use resnet50_tl as the cnn type).


@digital-thinking digital-thinking commented Oct 6, 2019

I also noticed this while training the EfficientNet model, which includes BatchNormalization. I observed that freezing the BN layers seems to lead to bad accuracy (wrong predictions) in the validation phase, while in the training phase everything looks fine. When unfreezing the BN layers, the test accuracy recovers.

Here is an example notebook and here is some additional information:


@TimbusCalin TimbusCalin commented Nov 20, 2019

@digital-thinking I think it would be interesting to see the results with a very small batch_size (1 or 2); that is the specific case in which freezing the BN layers of pre-trained backbones (especially for segmentation) is recommended.


@Nestak2 Nestak2 commented Jul 1, 2020

@digital-thinking Thanks, can you point out the precise parts of the code you are referring to? Do I understand you correctly:

  • first we need to train the model a few epochs with the setting
for layer in base_model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = True
    else:
        layer.trainable = False
  • then we need to run the very same trained model more epochs with the settings
for layer in model.layers:
    layer.trainable = True

Is this what you are saying in your post?


@digital-thinking digital-thinking commented Jul 1, 2020

Hi @Nestak2, yes, this is what I had to do to get reasonable validation metrics. If you train only the top layer and don't make the Batch Normalization layers trainable, you won't get correct results. If the whole model is trainable, there is no problem. I guess it's because the BatchNormalization layer is completely fixed and therefore does not normalize the batch anymore.


@Nestak2 Nestak2 commented Jul 1, 2020

@digital-thinking Thanks for clarifying! Unfortunately this strategy didn't work out for me: the validation loss and accuracy were bad in both the "warm-up" and the proper training phases. Here are the core parts of my code and the training metrics. Can you spot what I am doing wrong? Thanks

image_input = Input(shape=data_shape)
basemodel = ResNet50(input_tensor=image_input, include_top=True,weights='imagenet')
x = basemodel.get_layer('avg_pool').output

x = Flatten(name='flatten')(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.5)(x)

out = Dense(num_classes, activation='softmax', name='output_layer')(x)
custom_resnet_model = Model(inputs=image_input,outputs= out)

custom_resnet_model.compile(Adam(lr=0.001),loss='categorical_crossentropy',metrics=['accuracy'])
model = custom_resnet_model

######################
#### do the warm-up training for the head of the model for a few epochs.
#### the tail of the model is not being trained 
for layer in model.layers:
    if isinstance(layer, BatchNormalization):
        layer.trainable = True
    else:
        layer.trainable = False

H = model.fit_generator(
    generator = train_generator,
    steps_per_epoch=train_generator.n//train_generator.batch_size,
    validation_data=valid_generator,
    validation_steps=valid_generator.n//valid_generator.batch_size,
    # class_weight=class_weights,
    epochs=15
    )
#######################


#######################
#### do the proper training of all the layers
for layer in model.layers:
    layer.trainable = True

print('check of validation ims 2:', valid_generator.filepaths[:5])

H = model.fit_generator(
    generator = train_generator,
    steps_per_epoch=train_generator.n//train_generator.batch_size,
    validation_data=valid_generator,
    validation_steps=valid_generator.n//valid_generator.batch_size,
    # class_weight=class_weights,
    epochs=num_epochs
    )
########################

This gives me the following training metrics output, which shows decent learning performance but bad validation:

 9/10 [==========================>...] - ETA: 7s - loss: 0.9657 - acc: 0.7240 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 13s 171ms/step - loss: 26.6340 - acc: 0.1757
10/10 [==============================] - 85s 9s/step - loss: 0.9069 - acc: 0.7391 - val_loss: 29.8742 - val_acc: 0.1757
Epoch 2/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.2085 - acc: 0.9446 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 121ms/step - loss: 40.7067 - acc: 0.1081
10/10 [==============================] - 67s 7s/step - loss: 0.2051 - acc: 0.9487 - val_loss: 51.1770 - val_acc: 0.1081
Epoch 3/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.2235 - acc: 0.9464 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 121ms/step - loss: 222.0811 - acc: 0.1757
10/10 [==============================] - 62s 6s/step - loss: 0.2124 - acc: 0.9490 - val_loss: 284.9632 - val_acc: 0.1757
Epoch 4/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.1671 - acc: 0.9554 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 27.6185 - acc: 0.1892
10/10 [==============================] - 64s 6s/step - loss: 0.1536 - acc: 0.9599 - val_loss: 50.7521 - val_acc: 0.1892
Epoch 5/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0500 - acc: 0.9844 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 30.1733 - acc: 0.1757
10/10 [==============================] - 65s 7s/step - loss: 0.0565 - acc: 0.9828 - val_loss: 40.2771 - val_acc: 0.1757
Epoch 6/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0466 - acc: 0.9853 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 13.8834 - acc: 0.1757
10/10 [==============================] - 62s 6s/step - loss: 0.0429 - acc: 0.9868 - val_loss: 18.6932 - val_acc: 0.1757
Epoch 7/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0214 - acc: 0.9946 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 122ms/step - loss: 20.4458 - acc: 0.1757
10/10 [==============================] - 64s 6s/step - loss: 0.0388 - acc: 0.9904 - val_loss: 25.5468 - val_acc: 0.1757
Epoch 8/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0701 - acc: 0.9821 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 121ms/step - loss: 14.2723 - acc: 0.1892
10/10 [==============================] - 65s 6s/step - loss: 0.0732 - acc: 0.9824 - val_loss: 24.9597 - val_acc: 0.1892
Epoch 9/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0761 - acc: 0.9768 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 121ms/step - loss: 2.8525 - acc: 0.1622
10/10 [==============================] - 64s 6s/step - loss: 0.0689 - acc: 0.9792 - val_loss: 3.2069 - val_acc: 0.1622
Epoch 10/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0368 - acc: 0.9861 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 123ms/step - loss: 6.5042 - acc: 0.1622
10/10 [==============================] - 64s 6s/step - loss: 0.0332 - acc: 0.9872 - val_loss: 6.5715 - val_acc: 0.1622
Epoch 11/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0753 - acc: 0.9826 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 122ms/step - loss: 2.4013 - acc: 0.1622
10/10 [==============================] - 66s 7s/step - loss: 0.1494 - acc: 0.9750 - val_loss: 2.6977 - val_acc: 0.1622
Epoch 12/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.1294 - acc: 0.9653 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 115ms/step - loss: 17.0063 - acc: 0.1622
10/10 [==============================] - 65s 6s/step - loss: 0.1289 - acc: 0.9656 - val_loss: 18.5074 - val_acc: 0.1622
Epoch 13/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0855 - acc: 0.9816 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 180.1986 - acc: 0.1622
10/10 [==============================] - 66s 7s/step - loss: 0.0788 - acc: 0.9836 - val_loss: 174.2110 - val_acc: 0.1622
Epoch 14/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.0823 - acc: 0.9792 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 211.2853 - acc: 0.1622
10/10 [==============================] - 66s 7s/step - loss: 0.0802 - acc: 0.9797 - val_loss: 199.9975 - val_acc: 0.1622
Epoch 15/15
 9/10 [==========================>...] - ETA: 5s - loss: 0.1303 - acc: 0.9821 Epoch 1/15
74/10 [==============================================================================================================================================================================================================================] - 9s 120ms/step - loss: 17.8702 - acc: 0.1622
10/10 [==============================] - 63s 6s/step - loss: 0.1164 - acc: 0.9840 - val_loss: 18.2201 - val_acc: 0.1622
check of validation ims 2: ['./ims_train2/bottle/20200623_202617.jpg', './ims_train2/bottle/20200623_202619.jpg', './ims_train2/bottle/20200623_202624.jpg', './ims_train2/bottle/20200623_202629.jpg', './ims_train2/bottle/20200623_202636.jpg']
Epoch 1/60
 9/10 [==========================>...] - ETA: 6s - loss: 0.1154 - acc: 0.9653 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 116ms/step - loss: 4.5718 - acc: 0.1622
10/10 [==============================] - 74s 7s/step - loss: 0.1057 - acc: 0.9688 - val_loss: 7.2388 - val_acc: 0.1622
Epoch 2/60
 9/10 [==========================>...] - ETA: 5s - loss: 0.1150 - acc: 0.9724 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 121ms/step - loss: 5.0241 - acc: 0.1892
10/10 [==============================] - 66s 7s/step - loss: 0.1047 - acc: 0.9753 - val_loss: 7.6464 - val_acc: 0.1892
Epoch 3/60
 9/10 [==========================>...] - ETA: 5s - loss: 0.4038 - acc: 0.9714 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 122ms/step - loss: 9.9850 - acc: 0.1757
10/10 [==============================] - 64s 6s/step - loss: 0.3883 - acc: 0.9679 - val_loss: 11.0730 - val_acc: 0.1757
Epoch 4/60
 9/10 [==========================>...] - ETA: 5s - loss: 0.3822 - acc: 0.8982 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 125ms/step - loss: 2.7105 - acc: 0.1892
10/10 [==============================] - 64s 6s/step - loss: 0.4216 - acc: 0.8926 - val_loss: 5.3383 - val_acc: 0.1892
Epoch 5/60
 9/10 [==========================>...] - ETA: 5s - loss: 0.2122 - acc: 0.9427 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 125ms/step - loss: 2599176.3548 - acc: 0.1892
10/10 [==============================] - 66s 7s/step - loss: 0.2137 - acc: 0.9438 - val_loss: 4664818.5203 - val_acc: 0.1892
Epoch 6/60
 9/10 [==========================>...] - ETA: 5s - loss: 0.2319 - acc: 0.9761 Epoch 1/60
74/10 [==============================================================================================================================================================================================================================] - 9s 124ms/step - loss: 821284.6876 - acc: 0.1892
10/10 [==============================] - 63s 6s/step - loss: 0.2161 - acc: 0.9770 - val_loss: 1461603.6436 - val_acc: 0.1892
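The pattern in this log fits the BN hypothesis in this issue. Training accuracy quickly reaches 0.9 to 0.98, while validation accuracy only ever takes four values, each a multiple of 1/74 (8/74 ≈ 0.1081, 12/74 ≈ 0.1622, 13/74 ≈ 0.1757, 14/74 ≈ 0.1892). If the 74 in the validation progress line counts single-image validation steps (an assumption), that is consistent with the model predicting almost the same class for every validation image. One explanation discussed in this thread is that a frozen BatchNormalization layer in this Keras version still normalizes with per-batch statistics during fit() but with its stored moving statistics during evaluation. The new head is then trained on differently scaled features than it sees at validation time. The NumPy sketch below illustrates only that mismatch for a single hypothetical layer, with made-up moving statistics and a shifted input distribution. It is not Keras code.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 1e-3  # default epsilon of keras.layers.BatchNormalization

# Hypothetical frozen BN layer: gamma/beta and moving statistics left over
# from pre-training (made-up values for illustration)
gamma, beta = 1.0, 0.0
moving_mean, moving_var = 0.0, 1.0

# Activations for the new dataset come from a shifted, wider distribution
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4))

def bn_batch_stats(x):
    """Training-mode BN: normalize with the current batch's mean and variance."""
    mu, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def bn_moving_stats(x):
    """Inference-mode BN: normalize with the stored moving mean and variance."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

# Under the hypothesis, the new head sees y_fit during fit() ...
y_fit = bn_batch_stats(x)
# ... but sees y_eval for the very same inputs during validation
y_eval = bn_moving_stats(x)

print("fit():      mean %.2f, std %.2f" % (y_fit.mean(), y_fit.std()))
print("validation: mean %.2f, std %.2f" % (y_eval.mean(), y_eval.std()))
```

The head learns on roughly zero-mean, unit-variance features, then receives features with a mean near 5 and a standard deviation near 3. In the real network, a mismatch like this would be repeated across every BN layer of the frozen base.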


@therealansh therealansh commented May 21, 2021

Did anyone find a code-side workaround, rather than patching the backend? I also tried @digital-thinking's answer, but it doesn't work for me: the loss is -inf/nan in every epoch, while the accuracy increases only very slowly.
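One code-side approach assumes TensorFlow 2.x's tf.keras rather than the Keras 2.1-era standalone package used in the original report. In TF 2.x, setting trainable = False on a BatchNormalization layer, including through a frozen parent model, makes it run in inference mode with its moving statistics. The official Keras transfer-learning guide also recommends calling the base model with training=False. The sketch below rebuilds the ResNet50 model from the top of this issue that way. This is not the approach from the earlier answers in this thread. The weights=None setting and the 64x64 input size are only there to keep the check fast and offline; they are not part of the recipe.

```python
import numpy as np
from tensorflow import keras

num_classes = 2
input_shape = (64, 64, 3)  # small size for a quick check; 224x224 in the original

# weights=None keeps this sketch offline; use weights="imagenet" to actually fine-tune
base_model = keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=input_shape)
# In TF 2.x, trainable=False also makes BatchNormalization use its moving statistics
base_model.trainable = False

inputs = keras.Input(shape=input_shape)
# Call the base with training=False, as the Keras transfer-learning guide recommends
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(1024, activation="relu")(x)
outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              metrics=["accuracy"])

# Check: with BN in inference mode, training-mode and inference-mode calls agree
x_check = np.random.normal(size=(4,) + input_shape).astype("float32")
same = np.allclose(model(x_check, training=True),
                   model(x_check, training=False), atol=1e-5)
print("train/inference outputs match:", same)
```

If the two calls agree, the frozen base can no longer feed the head differently scaled features during fit() and evaluate(). With the original Keras 2.1.x, this check would be expected to fail for a frozen ResNet50.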

@fchollet fchollet closed this Jun 24, 2021