Brick-based batch normalization solution #920
@dwf, can you explain this part:
What is the alternative? I would think that … Another point: we already have a mechanism to store the updates in annotations. Why isn't it possible to store the batch norm updates there?
He just means that, by default, the graph generated is the inference graph, not the training graph!
The alternative would be for BatchNorm to do, well, batch normalization by default. Recall that there are two "modes" to batch-normalization:

- normalizing with the minibatch mean and standard deviation (what you want at training time), and
- normalizing with population estimates of the mean and standard deviation (what you want at inference time).
My proposal is essentially that by default, graphs that include BatchNorm use the second option, and we provide a function that performs the transformation to get to the first option. A hypothetical BatchNorm brick would contain the population statistics as shared variables, alongside the learned shift and scale parameters.
Having the inference graph be the default has another useful side effect: let's say you deserialize a model and want to use it for something. Once you get hold of the brick objects you want to use, calling their apply methods gives you a graph that already uses the population statistics, so the model is ready for inference without any further graph transformation.
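To make the two "modes" concrete, here is a minimal Theano sketch (this is not the proposed brick API; the variable names, feature count, and the `1e-6` epsilon are placeholders assumed for illustration):

```python
import numpy
import theano
import theano.tensor as T

n_features = 100
x = T.matrix('x')  # a (batch, features) minibatch

# Learned shift/scale and population statistics, all as shared variables.
gamma = theano.shared(numpy.ones(n_features, dtype=theano.config.floatX), name='gamma')
beta = theano.shared(numpy.zeros(n_features, dtype=theano.config.floatX), name='beta')
pop_mean = theano.shared(numpy.zeros(n_features, dtype=theano.config.floatX), name='pop_mean')
pop_std = theano.shared(numpy.ones(n_features, dtype=theano.config.floatX), name='pop_std')

# Mode 1 (training): normalize with the minibatch statistics.
mb_mean = x.mean(axis=0)
mb_std = x.std(axis=0)
train_output = gamma * (x - mb_mean) / (mb_std + 1e-6) + beta

# Mode 2 (inference): normalize with the population statistics; under this
# proposal, this is what the default graph would contain.
infer_output = gamma * (x - pop_mean) / (pop_std + 1e-6) + beta
```

Getting from the second expression to the first is a pure graph substitution, which is why a single transformation function suffices.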
Right, I see now. @Thrandis promised to publish his brick-based implementation.
Re: updates on annotations, I was under the impression from this comment that #514 might stand in the way. Seeing the implementation from @Thrandis would definitely be a useful contribution to the discussion, I agree.
@dwf @dmitriy-serdyuk
Great! No great rush, and also no real need for it to be super-duper clean.
@dwf here is my code: https://github.com/Thrandis/batch_norm

It seems to be roughly the solution that you suggest (with some minor changes). I don't use the running average to estimate the population statistics (I don't like to have this …). I could definitely finish this code as a ccw ticket, after I'm done with the new serialization!
+1. In my own code (pre #851) I've been using the Activation approach, mixing in Feedforward to make sure the sizes are available, and passing a …
I wrote an implementation of how I see this working before NIPS; I can write some tests and throw it up for comment once I'm done writing my reviews for ICLR. Re: convolution, I've been gradually decreasing entropy over there, and in truth I think it's easy enough to just make BatchNormalization bricks support the API to let them be wedged into a ConvolutionalSequence (I'd like to do this for activations eventually anyway).
Great job! @vdumoulin, can you please elaborate on your issues with …
@rizar it actually wasn't the population statistics. 90cc8bb introduces …
Thanks for the explanation!
#513 and #851 have now been in progress for 8 months. I think this is a testament to the fact that the current approach is proving hard to get right.
Below is my plan to accomplish this by explicitly introducing batch normalization bricks into the graph at construction. I think this is somewhat simpler, and it solves the issue of how BN population parameters get serialized. Indeed, the abstract of Ioffe & Szegedy (2015) refers to their method as making normalization "part of the model". It seems only appropriate that this be reflected in the graph that is constructed.
### New brick

`BatchNormalization` will house the learned shift and scale parameters (betas and gammas), as well as shared variables for the population mean and standard deviation statistics. The latter can't be tagged as `PARAMETER`, because code that takes all `PARAMETER`s in a graph and blindly adapts them according to a cost will do entirely the wrong thing. I thus suggest a superclass role `ADAPTABLE` (I also like `LEARNED` or `LEARNED_QUANTITY`), with the tree looking like this (a rough sketch of the declarations follows the list):

- `ADAPTABLE`
  - `PARAMETER`
    - `BATCH_NORMALIZATION_PARAMETER`
  - `BATCH_NORMALIZATION_POPULATION_MEAN`
  - `BATCH_NORMALIZATION_POPULATION_STDEV`
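A rough sketch of how these roles might be declared, following the style of the existing definitions in `blocks.roles` (the class/instance layout is an assumption about that module; re-parenting the existing `ParameterRole` under `ADAPTABLE` would mean editing its current definition):

```python
from blocks.roles import VariableRole


class AdaptableRole(VariableRole):
    """Any shared variable that gets adapted during training."""

ADAPTABLE = AdaptableRole()


class ParameterRole(AdaptableRole):
    """The existing PARAMETER role, re-parented under ADAPTABLE."""

PARAMETER = ParameterRole()


class BatchNormalizationParameterRole(ParameterRole):
    """The learned shifts and scales (betas and gammas)."""

BATCH_NORMALIZATION_PARAMETER = BatchNormalizationParameterRole()


class BatchNormalizationPopulationMeanRole(AdaptableRole):
    """Population mean estimates: adaptable, but deliberately not PARAMETER."""

BATCH_NORMALIZATION_POPULATION_MEAN = BatchNormalizationPopulationMeanRole()


class BatchNormalizationPopulationStdevRole(AdaptableRole):
    """Population standard deviation estimates."""

BATCH_NORMALIZATION_POPULATION_STDEV = BatchNormalizationPopulationStdevRole()
```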
Serialization code that looks for `PARAMETER`s to serialize should instead look for `ADAPTABLE`s.

`BatchNormalization` will have an apply method that uses the Theano `batch_normalization` function. By default, the graph generated will use the population mean and standard deviation in `BatchNormalization`'s apply method, for two reasons (a sketch of such an apply method follows the list):

- the transformation from this inference graph to the training graph, replacing population statistics with minibatch statistics, is a straightforward graph substitution that a support function (described below) can perform;
- a freshly deserialized model is immediately usable for inference, with no extra graph surgery required.
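A hedged sketch of such an apply method, assuming the shared variables carrying the roles above have already been allocated on the brick (allocation and dimension handling are elided; `theano.tensor.nnet.bn.batch_normalization` is the Theano function referred to above):

```python
from theano.tensor.nnet.bn import batch_normalization
from blocks.bricks.base import Brick, application


class BatchNormalization(Brick):
    @application
    def apply(self, input_):
        # Default behaviour: normalize with the *population* statistics, so
        # the graph constructed here is already the inference graph.
        # self.gamma, self.beta, self.population_mean and
        # self.population_stdev are assumed to be shared variables tagged
        # with the roles described above.
        return batch_normalization(
            inputs=input_, gamma=self.gamma, beta=self.beta,
            mean=self.population_mean, std=self.population_stdev)
```

The brick would additionally need to expose the minibatch mean and standard deviation of `input_` (e.g. as application annotations) so that the training-graph transformation described below can find and substitute them.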
### Modifications to existing Bricks

As batch normalization becomes more and more ubiquitous, we want to make it easy to add it to models anywhere an affine transformation takes place. I'm not entirely sure about the best way to accomplish this, but I think I have a pretty good idea.
I think the cleanest way is to introduce `BatchNormalizedLinear` and `BatchNormalizedConvolutional` classes. These would be drop-in replacements that include batch normalization as a child brick. In order to use these within an `MLP` (or a `Convolutional*` brick), these could grow keyword arguments not unlike the "prototype" keyword in the parallel bricks; you could pass either a lazily-allocated `BatchNormalized*` brick or simply the class object itself, e.g. `ConvolutionalActivation(..., convolution_prototype=BatchNormalizedConvolutional())` vs. `ConvolutionalActivation(..., convolutional_class=BatchNormalizedConvolutional)`.

Note that you could use this scheme with `Parallel` bricks by passing `BatchNormalizedLinear(use_bias=False)` as the prototype.
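For the fully-connected case, usage might look something like this (entirely hypothetical: `BatchNormalizedLinear` does not exist yet, and the `prototype` keyword on `MLP` is only proposed here by analogy with the parallel bricks; the other arguments follow the current `MLP` API):

```python
from blocks.bricks import MLP, Rectifier, Softmax
from blocks.initialization import IsotropicGaussian, Constant

# An MLP whose affine layers are the proposed drop-in replacements, each of
# which would carry a BatchNormalization child brick internally.
# BatchNormalizedLinear is the proposed (not yet existing) brick.
mlp = MLP(activations=[Rectifier(), Rectifier(), Softmax()],
          dims=[784, 500, 500, 10],
          prototype=BatchNormalizedLinear(use_bias=False),
          weights_init=IsotropicGaussian(0.01),
          biases_init=Constant(0))
mlp.initialize()
```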
### Support functions

To actually train with batch normalization, two support functions are provided:

- `batch_normalization_training_graph(graph)` replaces `BATCH_NORMALIZATION_POPULATION_MEAN`s with the minibatch-wise means from the application method of the corresponding bricks, and likewise for the standard deviations.
- `batch_normalization_population_updates(graph, alpha)` returns an `OrderedDict` of updates which perform moving-average updates on the population statistics (the `BATCH_NORMALIZATION_POPULATION_MEAN` and `BATCH_NORMALIZATION_POPULATION_STDEV` variables) in the graph, with lag parameter alpha (i.e. `(1 - alpha) * old_value + alpha * value_estimated_from_minibatch`). You pass this to `GradientDescent.add_updates` and you're on your way (a usage sketch follows below).

Users are free not to use this, or to do something fancier if they please, but this is the easy thing that Christian Szegedy told me works just fine. We can also provide a lower-level function that simply returns pairs of `(population, minibatch_estimate)` so that users can apply a learning rule of their choice or do whatever manner of sophisticated nonsense they like, but we should have one obvious way to adapt these things during training.
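A hedged usage sketch tying the two proposed functions together (neither exists yet; the surrounding `ComputationGraph` and `GradientDescent` calls follow the current Blocks API, though exact argument names may differ between versions):

```python
from blocks.algorithms import GradientDescent, Scale
from blocks.graph import ComputationGraph

# `cost` is assumed to be a cost variable built from a model containing
# BatchNormalization bricks; by default this is the *inference* graph.
inference_graph = ComputationGraph([cost])

# Proposed: swap population statistics for minibatch statistics.
training_graph = batch_normalization_training_graph(inference_graph)
training_cost = training_graph.outputs[0]

algorithm = GradientDescent(cost=training_cost,
                            parameters=training_graph.parameters,
                            step_rule=Scale(learning_rate=0.1))

# Proposed: moving-average updates for the population statistics,
# new_value = (1 - alpha) * old_value + alpha * minibatch_estimate.
pop_updates = batch_normalization_population_updates(training_graph, alpha=0.05)
algorithm.add_updates(pop_updates)
```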
### Conclusion

I think that all of these things are doable with the framework we have today, and don't depend on the resolution of fundamental design issues like #514. I'd appreciate folks' thoughts.