
NaN gradient when using BatchNormalization with sequential input #5377

Closed
MaigoAkisame opened this issue Feb 13, 2017 · 5 comments
MaigoAkisame commented Feb 13, 2017

I was training a neural network that works with sequential data. It consisted of convolutional and recurrent layers, as well as BatchNormalization layers in between. When I trained the model, I got NaN loss on the first minibatch. A closer inspection revealed that the gradient started becoming NaN for the gamma parameter of the topmost BatchNormalization layer.

I have stripped all the convolutional and recurrent layers off the model, and ended up with the following minimal code that reproduces the problem:

import numpy
import keras.backend as K
from keras.layers import *
from keras.models import *

# Build a model with a single BatchNormalization layer.
# Each sequence in the input has 10 time steps, and each step is a 1-dimensional vector
model = Sequential([BatchNormalization(name = 'bn', input_shape = (10, 1))])
model.compile(optimizer = 'sgd', loss = 'mse')

# Build computation graph for the gradient, see https://github.com/fchollet/keras/issues/2226
model = model.model  # Get the internal model
weights = model.trainable_weights  # Trainable weights, including gamma and beta of the BatchNormalization layer
gradients = model.optimizer.get_gradients(model.total_loss, weights)  # Symbolic variables for the gradients
inputs = model.inputs + model.targets + model.sample_weights + [K.learning_phase()]  # All symbolic variables the gradients depend on
grad_func = K.function(inputs, gradients)  # Function that computes the gradients

# Generate dummy data
x = numpy.random.uniform(size = (40, 10, 1))  # Input: 40 sequences of 10 time steps
y = numpy.random.uniform(size = (40, 10, 1))  # Output: 40 sequences of 10 time steps
w = numpy.ones(40)  # Sequence weights: equal weights for all sequences

# Compute and print gradient
grad = grad_func([x, y, w, 0])
print(list(zip(weights, grad)))

The output is:

[(bn_gamma, array([ nan], dtype=float32)), (bn_beta, array([-0.03533024], dtype=float32))]

The gradient for the gamma parameter is NaN.

By the way, if I do not use sequential data -- i.e. replace the input_shape of the BatchNormalization layer with (1,), and change the shape of x and y to (40, 1) -- the NaN gradient problem disappears.

My questions:

  1. Is the first line of my code the correct way to build a BatchNormalization layer, if I want the normalization to happen across all the 40 * 10 time steps in a minibatch?
  2. If it is, how do I get rid of the NaN gradient problem?
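To make question 1 concrete, here is what normalizing across all 40 * 10 time steps would mean, sketched in plain NumPy (an illustration of the intended statistics, not Keras' actual implementation; the 1e-3 epsilon is assumed to match Keras' default):

```python
import numpy as np

# Dummy minibatch shaped like the one above: 40 sequences, 10 steps, 1 feature
x = np.random.uniform(size=(40, 10, 1))

# Normalizing across all 40 * 10 time steps means computing one mean and one
# variance per feature, pooling over both the batch axis (0) and the time axis (1)
mean = x.mean(axis=(0, 1), keepdims=True)   # shape (1, 1, 1)
var = x.var(axis=(0, 1), keepdims=True)
x_norm = (x - mean) / np.sqrt(var + 1e-3)   # epsilon assumed from Keras' default

# Each feature now has zero mean and (approximately) unit variance
```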

I am using Keras 1.2.1, Theano 0.9.0b1 (0.9.0beta1.dev-e5d51daf5ac03bfd8bd076075ee587311dff6f48), and CUDA 8.0.


MaigoAkisame commented Feb 14, 2017

I have found that the problem lies with Theano 0.9.0b1.

In the Theano backend, Keras calls the following batch normalization functions in Theano, if they are available:

  • T.nnet.bn.batch_normalization_train, T.nnet.bn.batch_normalization_test

or the following functions, implemented in Keras (backend/theano_backend.py), if Theano does not provide the functions above:

  • _old_normalize_batch_in_training, _old_batch_normalization

Theano 0.9.0b1 provides these functions, but it seems that they produce NaN gradients unless the "axis" argument is 1.
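For reference, the "axis" argument selects which axis holds the features whose statistics stay separate; every other axis is pooled into the mean and variance. A quick NumPy illustration (not Theano's code) of how that choice changes the statistics for a (batch, time, features) input:

```python
import numpy as np

x = np.random.uniform(size=(40, 10, 4))  # (batch, time, features)

# Treating the last axis as the feature axis pools statistics over batch and time:
per_feature_mean = x.mean(axis=(0, 1))   # shape (4,): one mean per feature

# Treating axis 1 as the feature axis instead pools over batch and features:
per_step_mean = x.mean(axis=(0, 2))      # shape (10,): one mean per time step
```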

Switching back to Theano 0.8.2, which does not provide these functions, makes Keras use its own implementation, and solves the problem.


nouiz commented Feb 22, 2017

The issue in Theano was closed. So I think this issue can also be closed.

@MaigoAkisame

Yes. I resolved it by disabling fastmath.
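In case anyone hits the same thing: fast math can be disabled through Theano's configuration. The flag name below is taken from Theano's config options and should be double-checked against your version; `train_model.py` is just a stand-in for whatever script you run:

```shell
# One-off, on the command line (train_model.py is a placeholder):
THEANO_FLAGS='nvcc.fastmath=False' python train_model.py

# Or persistently, in ~/.theanorc:
# [nvcc]
# fastmath = False
```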

fengwang commented Sep 4, 2018

@MaigoAkisame I have exactly the same NaN problem when employing batch normalization layers with the TensorFlow backend. Would you please reopen this issue?

@MaigoAkisame

@fengwang OK, I'm reopening it.

MaigoAkisame reopened this Sep 4, 2018