Update moving_mean and moving_variance of BatchNormalization Layer use sess.run() #6752
Comments
I mean how to train moving_mean and moving_variance directly using sess.run().
@zzd1992, you might wanna take a look at the section "Collecting trainable weights and state updates" of https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html.
Actually, I think there's a small mistake in that tutorial, because "layer" there is just a TF tensor. You need to change it to something like this:

    import tensorflow as tf
    from keras.layers import BatchNormalization

    layer = BatchNormalization()
    blah = layer(x)  # x is an existing input tensor
    update_ops = []
    for old_value, new_value in layer.updates:
        update_ops.append(tf.assign(old_value, new_value))
Also, it seems that layer.updates already contains the assign ops, so a further change is needed:

    from keras.layers import BatchNormalization

    layer = BatchNormalization()
    blah = layer(x)  # x is an existing input tensor
    update_ops = []
    for assign_op in layer.updates:
        update_ops.append(assign_op)

Please correct me if I'm wrong :)
Is it possible to get a complete example of the above solution in order to clarify when the update_ops should be called? For instance, given the code of zzd1992 in the first post and the proposed solution, a training step would be run using
or do the update_ops need to be called separately?
I usually use only the layers from Keras and write the training code myself in TensorFlow.
I find that if I don't use the model.fit function to train a model, the moving_mean and moving_variance of the BatchNormalization layer are not updated. That is, moving_mean always stays at 0 and moving_variance always stays at 1.
Here is an example of my model:
When I use model.fit to train it, moving_mean and moving_variance are updated.
But when I train it using plain TensorFlow code like the following:
In this way, moving_mean and moving_variance are not updated.
I know we can see the moving_mean and moving_variance updates in model.updates, but I don't know how to run them during training if I don't want to use model.fit.
Is there a simple solution?
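For context, the values 0 and 1 are the initial values of moving_mean and moving_variance, so seeing them unchanged is what you'd expect if the update ops never run. Below is a minimal pure-Python sketch of the exponential-moving-average step those ops perform. It assumes the Keras default momentum of 0.99 and a plain population variance. The function names and the batch data are made up for illustration and are not part of Keras or TensorFlow.

```python
def moving_average_update(moving, batch_stat, momentum=0.99):
    """Blend the running statistic with the statistic of the current batch."""
    return momentum * moving + (1.0 - momentum) * batch_stat


def batch_mean_and_variance(batch):
    """Population mean and variance of a list of floats (simplified)."""
    n = len(batch)
    mean = sum(batch) / n
    variance = sum((v - mean) ** 2 for v in batch) / n
    return mean, variance


# Initial values of the moving statistics: mean 0, variance 1.
moving_mean, moving_variance = 0.0, 1.0

# Hypothetical activations for three training steps.
batches = [[3.0, 7.0], [4.0, 8.0], [2.0, 6.0]]

for batch in batches:
    mean, variance = batch_mean_and_variance(batch)
    # This is the step that is skipped when the update ops are not run,
    # which leaves the moving statistics at their initial 0 and 1.
    moving_mean = moving_average_update(moving_mean, mean)
    moving_variance = moving_average_update(moving_variance, variance)

print(moving_mean, moving_variance)
```

With a momentum this close to 1 the statistics move slowly, so even when the updates do run, the values stay near 0 and 1 for the first few steps.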