Added function to compute activation statistics for BatchNormalization #82

Closed

Conversation

jfsantos
Contributor

This adds a function, set_activation_stats(model, X_train, batch_size, verbose), that computes the activation mean and standard deviation for each BatchNormalization layer in a model and stores them as layer attributes. It also changes the BatchNormalization layer to use this information at test/inference time. The only issue I see with this, which comes from using batch normalization rather than from this implementation, is that in theory you need to recompute the activation statistics after each epoch if you want to check validation or test performance; otherwise these values are not available and there is no way to normalize the data during testing.
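
Roughly, per layer it boils down to something like this (a simplified, self-contained NumPy sketch, not the exact code in the PR; in the PR the resulting mean/std are stored as attributes on each BatchNormalization layer and reused at test time):

```python
import numpy as np

def batchwise_feature_stats(activation_batches):
    """Accumulate per-feature mean and std over mini-batches.

    `activation_batches` yields 2-D arrays of shape (batch_size, n_features),
    standing in for the inputs that feed one BatchNormalization layer.
    """
    total, total_sq, n = 0.0, 0.0, 0
    for acts in activation_batches:
        total += acts.sum(axis=0)
        total_sq += np.square(acts).sum(axis=0)
        n += acts.shape[0]
    mean = total / n
    std = np.sqrt(total_sq / n - mean ** 2)
    return mean, std

# Toy usage with fake activations for a layer with 4 features.
rng = np.random.RandomState(0)
batches = (rng.randn(128, 4) * 2.0 + 1.0 for _ in range(10))
mean, std = batchwise_feature_stats(batches)
```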

@fchollet
Member

I'm thinking about how to handle this. Ideally, for the sake of good design and simplicity, we want the activation stats to be learned automatically by the layer at train time, without having to call any third-party method.

One way to achieve this would be to include the stats in the updates of the model._train() method. This can be done by concatenating the gradient updates computed by the optimizer with layer-specific updates.

I will have a shot at implementing something like that maybe tonight, and see if it works (unless you want to do it). The way I'm seeing it right now:

  • store sum of all inputs (with a mechanism to scale it down to avoid overflow) and number of samples inside the layer
  • expose a layer.updates list that computes the above, which will be included in the updates applied in model._train()
  • at test time, compute mean and std once using the local stats, and cache the values inside the layer. We need some mechanism to make sure we keep updating these values as new training data is fed (I assume that calling the layer with train=False will trigger computing/caching of the values, and calling it with train=True will reset them). A rough sketch of this mechanism follows below.
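
In pure-NumPy pseudocode, the mechanism would look something like this (class and method names are just illustrative, not existing Keras API; the real thing would use Theano shared variables whose updates are applied inside model._train()):

```python
import numpy as np

class BatchNormStats(object):
    """Conceptual sketch of the proposal above (NumPy stand-in for the
    shared-variable updates that model._train() would actually apply)."""

    def __init__(self, n_features):
        self.running_sum = np.zeros(n_features)
        self.running_sqsum = np.zeros(n_features)
        self.count = 0
        self._cached = None  # (mean, std), computed lazily at test time

    def updates(self, X_batch):
        """Layer-specific updates; these would be appended to the
        optimizer's gradient updates inside model._train()."""
        self.running_sum += X_batch.sum(axis=0)
        self.running_sqsum += np.square(X_batch).sum(axis=0)
        self.count += X_batch.shape[0]
        self._cached = None  # train=True resets the cached test-time stats

    def test_stats(self):
        """train=False: compute mean/std once from the local stats and cache."""
        if self._cached is None:
            mean = self.running_sum / self.count
            std = np.sqrt(self.running_sqsum / self.count - mean ** 2)
            self._cached = (mean, std)
        return self._cached
```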

@jfsantos
Contributor Author

The problem with storing the sum of all inputs during training is that the network parameters will be changing every iteration, so the activation statistics will change as well... or maybe I didn't understand exactly what you intend to do?

I think the way I implemented it is exactly what is described in the batch normalization paper. It is terribly slow, though. If you want to keep a moving average over batches/iterations, that may be possible, but you will only get approximations to the statistics past the first batch normalization layer (because once you start using the mean and std activations in that layer, all the activations after it will be different).

Please give it a go and let me know if you find a different solution. Otherwise, we could just include the function I implemented at the end of the fit method, and only call it if there's a BatchNormalization layer in model.layers.
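
Something like this at the end of fit, roughly (a sketch only; it reuses the set_activation_stats helper from this PR and the arguments already available inside fit):

```python
# Hypothetical hook at the end of Model.fit(); reuses the PR's helper.
has_batchnorm = any(layer.__class__.__name__ == 'BatchNormalization'
                    for layer in model.layers)
if has_batchnorm:
    set_activation_stats(model, X_train, batch_size, verbose)
```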

@fchollet
Member

The problem with storing the sum of all inputs during training is that the network parameters will be changing every iteration, so the activation statistics will change as well... or maybe I didn't understand exactly what you intend to do?

Network parameters stabilize fast when it comes to mean and std, so I would think doing an exponential average of the inputs during training should show no difference compared to doing exact computations after training (I will attempt to test this hypothesis in practice)...

Otherwise, we could just include the function I implemented at the end of the fit method, and only call it if there's a BatchNormalization layer in model.layers.

That would work for model.fit(), but we have to consider that batch training where the user builds their own batches will use model.train().

In general I think layers should be completely self-contained, for the sake of modularity and simplicity.

@jfsantos
Contributor Author

Network parameters stabilize fast when it comes to mean and std, so I would think doing an exponential average of the inputs during training should show no difference compared to doing exact computations after training (I will attempt to test this hypothesis in practice)...

That's a good idea. If that turns out to be true in most cases, it would save a lot of time. We can use my implementation to compare the exact values to the approximations, so at least I won't think I wasted my time if it works! :)

That would work for model.fit(), but we have to consider that batch training where the user builds their own batches will use model.train(). In general I think layers should be completely self-contained, for the sake of modularity and simplicity.

I completely agree. Having to remember to do a lot of extra things just because you decided to add or remove a layer is painful.

@pranv
Contributor

pranv commented Apr 24, 2015

As far as I understand, the average should be over the statistics of each epoch rather than each iteration. Please do correct me if I'm wrong.

@jfsantos
Contributor Author

Yes, the average is over all batches in the training data, which corresponds to a full epoch.

@fchollet
Member

We can use my implementation to compare the exact values to the approximations, so at least I won't think I wasted my time if it works! :)

I am doing just that!

By the way, I just noticed that you are computing mean and std over axis 0 (samples); I believe it should be the last axis instead (dimensions). At least the latter performs much better empirically (the former performs worse than no batch normalization at all).

@jfsantos
Contributor Author

I am computing it over samples because then we get the mean/std for each activation x^k, instead of a single mean/std for all the activations in a given layer. Again, this is what is done in the batch normalization paper, but in practice something else could perform better.
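
To make the axis distinction concrete, a small NumPy illustration (shapes chosen arbitrarily):

```python
import numpy as np

acts = np.random.randn(128, 64)        # 128 samples, 64 activations per sample

# axis=0: one mean/std per activation x^k, as in the batch normalization paper
per_feature_mean = acts.mean(axis=0)   # shape (64,)
per_feature_std = acts.std(axis=0)     # shape (64,)

# last axis: one mean/std per sample, across that sample's activations
per_sample_mean = acts.mean(axis=-1)   # shape (128,)
per_sample_std = acts.std(axis=-1)     # shape (128,)
```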

@fchollet
Member

You are right, my mistake. In any case, it is interesting to note that samplewise batch normalization has been performing better than featurewise batch normalization on the 3 tasks I have tried it on so far. I will keep investigating.

@soumith

soumith commented May 12, 2015

FWIW, something that works pretty well in practice for us is a running estimate of mean/std with a momentum term. A momentum of 0.1 works reasonably well (an estimate over roughly the last 1280 samples for a batch size of 128, for example), and the momentum is configurable in the constructor:

I've trained several imagenet models with this.
https://github.com/torch/nn/blob/master/BatchNormalization.lua#L26
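
For reference, the update rule boils down to roughly this (a sketch following the linked torch/nn convention of new = (1 - momentum) * old + momentum * batch statistic; the function name is illustrative, and whether std or variance is tracked is an implementation detail):

```python
import numpy as np

def update_running_stats(running_mean, running_std, X_batch, momentum=0.1):
    """Exponential running estimate of activation statistics.

    With momentum=0.1 and a batch size of 128, the estimate is dominated
    by roughly the last 1280 samples, as described above.
    """
    batch_mean = X_batch.mean(axis=0)
    batch_std = X_batch.std(axis=0)
    running_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    running_std = (1.0 - momentum) * running_std + momentum * batch_std
    return running_mean, running_std
```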

@fchollet
Member

@soumith Great, thank you for the input! I will try this.

@fchollet
Member

@soumith Just tried implementing this. On the task I tested it on (Kaggle Otto challenge), the performance seems to be essentially the same as what you get by using the statistics of the current batch (as long as batches are fairly large; 128 works well). I suppose in this case a sample of 128 is sufficiently representative of the data.

Still, conceptually, the rolling estimate seems like a good strategy. I'll test it on a couple more tasks and that's what we'll go with. Thanks for the pointer!

@fchollet fchollet closed this May 12, 2015