Added function to compute activation statistics for BatchNormalization #82
Conversation
I'm thinking about how to handle this. Ideally, for the sake of good design and simplicity, we want activation stats to be learned automatically by the layer at train time, without having to call any 3rd-party method. One way to achieve this would be to include the stats in the updates of the model._train() method. This can be done by concatenating the gradient updates computed by the optimizer with layer-specific updates. I will take a shot at implementing something like that, maybe tonight, and see if it works (unless you want to do it). The way I'm seeing it right now:
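A hypothetical sketch of that pattern (not the original snippet; `collect_updates` and the per-layer `updates` attribute are illustrative names): the training function would apply the optimizer's gradient updates together with any layer-specific state updates, such as the running activation statistics of a BatchNormalization layer.

```python
# Hypothetical sketch: concatenate optimizer updates with layer-specific
# updates, so that each layer stays self-contained.
def collect_updates(optimizer_updates, layers):
    # `layer.updates` is an assumed attribute holding (state, new_value)
    # pairs, e.g. the running mean/std of a BatchNormalization layer.
    updates = list(optimizer_updates)
    for layer in layers:
        updates += getattr(layer, "updates", [])
    return updates
```

Layers without state updates simply contribute nothing, so adding or removing a layer requires no extra bookkeeping from the user.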
The problem with storing the sum of all inputs during training is that the network parameters will be changing at every iteration, so the activation statistics will change as well... or maybe I didn't understand exactly what you intend to do? I think the way I implemented it is exactly what is described in the batch normalization paper. It is terribly slow, though. If you want to keep a moving average over batches/iterations, that may be possible, but you will only have approximations to all of the statistics past the first batch normalization layer (because once you start using the mean and std activations on it, all the activations after that layer will be different). Please give it a go and let me know if you find a different solution. Otherwise, we could just include the function I implemented at the end of the
Network parameters stabilize fast when it comes to mean and std, so I would think doing an exponential average of the inputs during training should show no difference compared to doing exact computations after training (I will attempt to test this hypothesis in practice)...
That would work for model.fit(), but we have to consider that batch training, where the user builds their own batches, will use model.train(). In general I think layers should be completely self-contained, for the sake of modularity and simplicity.
That's a good idea. If that turns out to be true in most cases, it would save a lot of time. We can use my implementation to compare the exact values to the approximations, so at least I won't think I wasted my time if it works! :)
I completely agree. Having to remember to do a lot of things just because you decided to add or remove a layer is painful.
As far as I understand, the average should be over the statistics of each epoch rather than each iteration. Please do correct me if I'm wrong.
Yes, the average is over all batches in the training data, which corresponds to a full epoch.
I am doing just that! By the way, I just noticed that you are computing mean and std over axis 0 (samples); I believe it should be the last axis instead (dimensions). At least the latter performs much better empirically (the former performs worse than no batch normalization at all).
I am computing it over samples because then we get the mean/std for each activation x^k, instead of a single mean/std for all the activations in a given layer. Again, this is what is done in the batch normalization paper, but in practice something else could perform better.
You are right, my mistake. In any case, it is interesting to notice that samplewise batch normalization has been performing better than featurewise batch normalization on the 3 tasks I tried it on so far. I will keep investigating.
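The axis distinction discussed above can be made concrete with a tiny numpy example (illustrative data only):

```python
import numpy as np

X = np.array([[1.0, 2.0, 3.0],
              [3.0, 4.0, 5.0]])  # batch of 2 samples, 3 activations each

# Featurewise (axis 0): one mean/std per activation x^k, across the
# batch dimension; this is what the batch normalization paper describes.
feature_mean = X.mean(axis=0)  # shape (3,)

# Samplewise (last axis): one mean/std per sample, across its activations.
sample_mean = X.mean(axis=-1)  # shape (2,)
```

Featurewise statistics have one entry per unit in the layer, while samplewise statistics have one entry per example in the batch, which is why the two choices behave so differently in practice.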
FWIW, something that works pretty well in practice for us is a running estimate of mean/std with a momentum term. A momentum of 0.1 works reasonably well (an estimate over the last 1280 samples, for a batch size of 128, for example), and the momentum is configurable in the constructor. I've trained several ImageNet models with this.
@soumith Great, thank you for the input! I will try this.
@soumith Just tried implementing this. On the task I tested it on (the Kaggle Otto challenge), the performance seems to be essentially the same as what you get by using the statistics of the current batch (as long as batches are fairly large; 128 works well). I suppose in this case a sample of 128 is sufficiently representative of the data. Still, conceptually, the rolling estimate seems like a good strategy. I'll test it on a couple more tasks, and that's what we'll go with. Thanks for the pointer!
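The running-estimate idea described above can be sketched as an exponential moving average of per-feature statistics; the function name here is hypothetical:

```python
import numpy as np

def update_running_stats(running_mean, running_std, batch, momentum=0.1):
    # Exponential moving average of per-feature mean/std. With
    # momentum=0.1 and a batch size of 128, the estimate roughly
    # reflects the last ~1280 samples, as noted above.
    batch_mean = batch.mean(axis=0)
    batch_std = batch.std(axis=0)
    new_mean = (1.0 - momentum) * running_mean + momentum * batch_mean
    new_std = (1.0 - momentum) * running_std + momentum * batch_std
    return new_mean, new_std
```

Each training batch nudges the stored statistics toward its own, so the estimate tracks the slowly stabilizing network parameters without a separate pass over the data after training.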
This adds a function,
set_activation_stats(model, X_train, batch_size, verbose)
that computes the activation mean and standard deviation for each BatchNormalization layer in a model and stores them as layer attributes. It also changes the BatchNormalization layer to use this information at test/inference time. The only issue I see with this, which comes from batch normalization itself and not from the implementation, is that in theory one needs to compute the activation statistics after each epoch in order to check validation or test performance. Otherwise, without these values, there is no way to normalize the data during testing.
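A rough sketch of what such a function could look like, under stated assumptions: `get_input_for` is a hypothetical helper returning the activations that feed a given layer, and `running_mean`/`running_std` are assumed names for the stored attributes.

```python
import numpy as np

def set_activation_stats(model, X_train, batch_size=128, verbose=0):
    # For each BatchNormalization layer, run the training data through
    # the model in batches and record the mean/std of the activations
    # reaching that layer, averaged over the whole epoch.
    for layer in model.layers:
        if layer.__class__.__name__ != "BatchNormalization":
            continue
        means, stds = [], []
        for i in range(0, len(X_train), batch_size):
            batch = X_train[i:i + batch_size]
            # hypothetical helper: activations feeding `layer` for `batch`
            acts = model.get_input_for(layer, batch)
            means.append(acts.mean(axis=0))
            stds.append(acts.std(axis=0))
        # stored as layer attributes, used at test/inference time
        layer.running_mean = np.mean(means, axis=0)
        layer.running_std = np.mean(stds, axis=0)
```

Because each layer's statistics depend on the parameters of every layer before it, this pass only reflects the network's current weights, which is why it would need to be rerun after each epoch to evaluate validation or test performance.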