
Brick-based batch normalization solution #920

Closed
dwf opened this issue Nov 24, 2015 · 14 comments

@dwf
Contributor

dwf commented Nov 24, 2015

#513 and #851 have now been in progress for 8 months. I think this is a testament to the fact that the current approach is proving hard to get right.

Below is my plan to accomplish this by explicitly introducing batch normalization bricks into the graph at construction time. I think this is somewhat simpler, and it solves the issue of how the BN population parameters get serialized. Indeed, the abstract of Ioffe & Szegedy (2015) refers to their method as making normalization "part of the model". It seems only appropriate that this be reflected in the graph that is constructed.

New brick

BatchNormalization will house the learned shift and scale parameters (betas and gammas), as well as shared variables for the population mean and standard deviation statistics. The latter can't be tagged as PARAMETER, because code that takes all PARAMETERs in a graph and blindly adapts them according to a cost would do entirely the wrong thing. I thus suggest a superclass role ADAPTABLE (I also like LEARNED or LEARNED_QUANTITY), with the tree looking like this:

  • ADAPTABLE
    • PARAMETER
    • BATCH_NORMALIZATION_PARAMETER
      • BATCH_NORMALIZATION_POPULATION_MEAN
      • BATCH_NORMALIZATION_POPULATION_STDEV

Serialization code that looks for PARAMETERs to serialize should instead look for ADAPTABLEs.
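
A minimal sketch of how the hierarchy could be declared, following the class-per-role pattern in blocks.roles (none of these names exist in Blocks yet, and ParameterRole would have to be re-parented onto the new superclass):

```python
# Hypothetical additions to blocks.roles; the names are the ones proposed above.
from blocks.roles import VariableRole


class AdaptableRole(VariableRole):
    """Any learned quantity, whether or not it is adapted w.r.t. a cost."""


# ParameterRole (and hence PARAMETER) would become a subclass of AdaptableRole.


class BatchNormalizationParameterRole(AdaptableRole):
    """Population statistics owned by BatchNormalization bricks."""


class BatchNormalizationPopulationMeanRole(BatchNormalizationParameterRole):
    """Running estimate of the population mean."""


class BatchNormalizationPopulationStdevRole(BatchNormalizationParameterRole):
    """Running estimate of the population standard deviation."""


ADAPTABLE = AdaptableRole()
BATCH_NORMALIZATION_PARAMETER = BatchNormalizationParameterRole()
BATCH_NORMALIZATION_POPULATION_MEAN = BatchNormalizationPopulationMeanRole()
BATCH_NORMALIZATION_POPULATION_STDEV = BatchNormalizationPopulationStdevRole()
```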

BatchNormalization will have an apply method that uses Theano's batch_normalization function. By default, the graph it generates will use the population means and standard deviations, for two reasons (a sketch follows the list below):

  • It's kind of icky to produce a graph in which, by default upon construction, the output for each example depends on the entire minibatch it arrived in.
  • If the population means and standard deviations are initialized to zeros and ones respectively, this makes it a no-op relative to what the graph would do without these additional bricks.
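
To make this concrete, here is a rough sketch of what the brick could look like. It is not a final implementation: the attribute names, the use of Theano's batch_normalization from theano.tensor.nnet.bn, and the assumption of a 2D (batch, feature) input are all mine.

```python
import numpy
from theano.tensor.nnet.bn import batch_normalization

from blocks.bricks.base import Brick, application, lazy
from blocks.roles import add_role, PARAMETER
from blocks.utils import shared_floatx_zeros


class BatchNormalization(Brick):
    @lazy(allocation=['dim'])
    def __init__(self, dim, epsilon=1e-4, **kwargs):
        super(BatchNormalization, self).__init__(**kwargs)
        self.dim = dim
        self.epsilon = epsilon

    def _allocate(self):
        # Learned shift (beta) and scale (gamma): ordinary PARAMETERs.
        self.shift = shared_floatx_zeros((self.dim,), name='shift')
        self.scale = shared_floatx_zeros((self.dim,), name='scale')
        self.scale.set_value(numpy.ones((self.dim,), dtype=self.scale.dtype))
        add_role(self.shift, PARAMETER)
        add_role(self.scale, PARAMETER)
        self.parameters.extend([self.shift, self.scale])
        # Population statistics: deliberately *not* PARAMETERs; they would
        # carry the BATCH_NORMALIZATION_* roles proposed above.
        self.population_mean = shared_floatx_zeros(
            (self.dim,), name='population_mean')
        self.population_stdev = shared_floatx_zeros(
            (self.dim,), name='population_stdev')
        self.population_stdev.set_value(
            numpy.ones((self.dim,), dtype=self.population_stdev.dtype))

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_):
        # Normalize with the *population* statistics by default, so a freshly
        # constructed graph is an inference graph, and a no-op while the
        # statistics are still at their 0/1 initialization.
        return batch_normalization(
            input_, gamma=self.scale, beta=self.shift,
            mean=self.population_mean,
            std=self.population_stdev + self.epsilon, mode='low_mem')
```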

Modifications to existing Bricks

As batch normalization becomes more and more ubiquitous, we want to make it easy to add it to models anywhere an affine transformation takes place. I'm not entirely sure about the best way to accomplish this, but I think I have a pretty good idea.

I think the cleanest way is to introduce BatchNormalizedLinear and BatchNormalizedConvolutional classes. These would be drop-in replacements that include batch normalization as a child brick. In order to use these within an MLP (or a Convolutional*), the container bricks could grow keyword arguments not unlike the "prototype" keyword in the parallel bricks: you could either pass a lazily-allocated BatchNormalized* brick or simply the class object itself, e.g. ConvolutionalActivation(..., convolution_prototype=BatchNormalizedConvolutional()) vs. ConvolutionalActivation(..., convolutional_class=BatchNormalizedConvolutional).

Note that you could use this scheme with Parallel bricks by passing BatchNormalizedLinear(use_bias=False) as the prototype.
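
For illustration, one possible shape for the Linear variant, written by composing child bricks rather than subclassing Linear (the class and attribute names are placeholders from this proposal, and it assumes the BatchNormalization brick sketched above):

```python
from blocks.bricks import Initializable, Linear
from blocks.bricks.base import application, lazy


class BatchNormalizedLinear(Initializable):
    """A Linear brick (without bias) followed by batch normalization."""
    @lazy(allocation=['input_dim', 'output_dim'])
    def __init__(self, input_dim, output_dim, **kwargs):
        super(BatchNormalizedLinear, self).__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.linear = Linear(use_bias=False)
        self.batch_norm = BatchNormalization()
        self.children = [self.linear, self.batch_norm]

    def _push_allocation_config(self):
        # The container knows the sizes, so the BatchNormalization child
        # never has to be configured by hand.
        self.linear.input_dim = self.input_dim
        self.linear.output_dim = self.output_dim
        self.batch_norm.dim = self.output_dim

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_):
        return self.batch_norm.apply(self.linear.apply(input_))
```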

Support functions

To actually train with batch normalization, two support functions are provided (a usage sketch follows the list):

  • batch_normalization_training_graph(graph) replaces BATCH_NORMALIZATION_POPULATION_MEANs with the minibatch-wise means from the application method of the corresponding bricks, ditto for standard deviations.
  • batch_normalization_population_updates(graph, alpha) returns an OrderedDict of updates which perform moving-average updates on the population statistics for all the BATCH_NORMALIZATION_PARAMETERs in the graph, with lag parameter alpha (i.e. (1 - alpha) * old_value + alpha * value_estimated_from_minibatch). You pass this to GradientDescent.add_updates and you're on your way. Users are free not to use this, or to do something fancier if they please, but this is the easy thing that Christian Szegedy told me works just fine. We could also provide a lower-level function that simply returns pairs of (population, minibatch_estimate), so that users can apply a LearningRule of their choice or do whatever manner of sophisticated nonsense they like, but we should have one obvious way to adapt these things during training.
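
Training would then look roughly like this; the two helper names and their graph-in/graph-out behaviour are the ones proposed above (assumed here to operate on a ComputationGraph), and cost stands for whatever cost variable the model's bricks produce:

```python
from blocks.algorithms import GradientDescent, Scale
from blocks.graph import ComputationGraph

# `cost` is the usual Theano scalar built by applying bricks that contain
# BatchNormalization children; by construction it is the inference graph.
inference_graph = ComputationGraph([cost])

# Swap the population statistics for minibatch statistics.
training_graph = batch_normalization_training_graph(inference_graph)
training_cost = training_graph.outputs[0]

algorithm = GradientDescent(cost=training_cost,
                            parameters=training_graph.parameters,
                            step_rule=Scale(learning_rate=0.1))
# Moving-average updates for the population statistics, with lag alpha.
algorithm.add_updates(
    batch_normalization_population_updates(training_graph, alpha=0.05))
```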

Conclusion

I think that all of these things are doable with the framework we have today, and don't depend on the resolution of fundamental design issues like #514. I'd appreciate folks' thoughts.

@dmitriy-serdyuk
Contributor

@dwf , can you explain this part:

By default, the graph generated will use the population mean and standard deviations in BatchNormalization's apply method

What is an alternative?

I would think that BatchNorm should be an activation; it doesn't depend on the layer before it. And there should be an easy way to combine several activations, e.g. Composition([BatchNorm(), ReLU()]). I remember @Thrandis implemented batch norm like this.

Another point: we already have a mechanism to store updates in annotations. Why isn't it possible to store the batch norm updates there?

@Thrandis
Contributor

He just means that, by default, the graph generated is the inference graph, not the training graph!

@dwf
Contributor Author

dwf commented Nov 24, 2015

@dmitriy-serdyuk

What is an alternative?

The alternative would be for BatchNorm to do, well, batch normalization by default.

Recall that there are two "modes" to batch-normalization:

  • out = gamma * (in - in.mean(axis=batch_axis, keepdims=True)) / (in.std(axis=batch_axis, keepdims=True) + eps) + beta
  • out = gamma * (in - mu) / (sigma + eps) + beta where mu and sigma are population statistics estimated from the data.

My proposal is essentially that by default, graphs that include BatchNorm use the second option, and we provide a function that performs the transformation to get to the first option.
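
The transformation itself is mechanically simple; sketched with theano.clone (the bookkeeping for discovering the right pairs from the annotated graph is left out, and the function name is made up):

```python
import theano


def substitute_minibatch_statistics(outputs, input_, brick):
    """Clone `outputs`, replacing one brick's population statistics with the
    minibatch statistics of `input_` (the input to its apply call).

    A real implementation would discover the (population, minibatch) pairs
    from the annotated graph instead of taking them as arguments.
    """
    replacements = {
        brick.population_mean: input_.mean(axis=0),
        brick.population_stdev: input_.std(axis=0),
    }
    return theano.clone(outputs, replace=replacements)
```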

A hypothetical BatchNorm brick would contain the gamma and beta parameters, as well as the mu and sigma not-quite-model-parameters. All of these depend on the size of the preceding layer's output, so mixing with the "activation" is tricky. You're almost never going to have an activation function that does not follow a linear transformation, and so I think it makes more sense to have the linear transformations incorporate batch norm, as these bricks already know about the size of the output they produce and can pass that to their BatchNorm child appropriately.

@dwf
Contributor Author

dwf commented Nov 24, 2015

Having the inference graph be the default has another useful side effect: let's say you deserialize a model and want to use it for something. Once you get a hold of the brick objects you want to use, calling apply() should automatically do the right thing.

@dmitriy-serdyuk
Contributor

Right, I see now.

@Thrandis promised to publish his brick-based implementation.

@dwf
Contributor Author

dwf commented Nov 24, 2015

Re: updates on annotations, I was under the impression from this comment that #514 might stand in the way.

Seeing the implementation from @Thrandis would definitely make a useful point in the discussion, I agree.

@Thrandis
Contributor

@dwf @dmitriy-serdyuk
Yeah, I actually already have a brick implementation of batch norm that is really similar to what you are proposing! I'll push it somewhere as soon as I have time (I need to clean it up a bit, and I also have homework for Aaron's class :-p)

@dwf
Contributor Author

dwf commented Nov 24, 2015

Great! No great rush, and also no real need for it to be super-duper clean.

@dwf dwf added the discussion label Nov 24, 2015
@Thrandis
Contributor

@dwf here is my code: https://github.com/Thrandis/batch_norm

It seems to be roughly the solution that you suggest (with some minor changes).

I don't use the running average to estimate the population statistics (I don't like having this alpha parameter to tune: I wasn't able to set it properly for my experiments, so I always had a validation curve that oscillated a lot!). Instead, I implemented an extension that takes some training minibatches and forwards them through the network to update the statistics. This method is really effective, gives nice and smooth validation curves, and is not that computationally intensive. So I left my extension there in case you find it useful!
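
For reference, the kind of extension described here might look roughly like the sketch below. It is only an illustration of the idea, not the code in the linked repository, and the constructor arguments are made up:

```python
import numpy
import theano

from blocks.extensions import SimpleExtension


class EstimatePopulationStatistics(SimpleExtension):
    """Re-estimate population statistics from a few training minibatches.

    `stat_pairs` is a list of (population_shared_variable, minibatch_stat_expr)
    pairs, and `inputs` the Theano inputs needed to compute the expressions.
    """
    def __init__(self, data_stream, inputs, stat_pairs, num_batches=100,
                 **kwargs):
        kwargs.setdefault('before_epoch', True)
        super(EstimatePopulationStatistics, self).__init__(**kwargs)
        self.data_stream = data_stream
        self.num_batches = num_batches
        self.stat_pairs = stat_pairs
        self._compute = theano.function(
            inputs, [expr for _, expr in stat_pairs])

    def do(self, which_callback, *args):
        sums, count = None, 0
        for batch in self.data_stream.get_epoch_iterator():
            values = self._compute(*batch)
            sums = values if sums is None else [
                s + v for s, v in zip(sums, values)]
            count += 1
            if count >= self.num_batches:
                break
        # Plain averaging of the minibatch statistics; averaging variances
        # rather than standard deviations would be slightly more principled.
        for (shared, _), total in zip(self.stat_pairs, sums):
            shared.set_value(numpy.asarray(total / count, dtype=shared.dtype))
```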

I could definitely finish this code as a ccw ticket, after I'm done with the new serialization!

@cooijmanstim

+1. In my own code (pre #851) I've been using the Activation approach, mixing in Feedforward to make sure the sizes are available, and passing a broadcastable tuple to the constructor to indicate which axes to reduce over (to deal with e.g. convolution in 2d or 3d). I would prefer this approach over further complicating the blocks.bricks.conv mess, but I'm also very much in favor of having batch normalization included in blocks at all.

A minor thing: for the role I'd prefer the name inferred over either adaptable or learned. Never mind: if it's a superclass of parameter, then adaptable or learned makes more sense.

@dwf
Contributor Author

dwf commented Dec 14, 2015

I wrote an implementation of how I see this working before NIPS; I can write some tests and throw it up for comment once I'm done writing my reviews for ICLR.

Re: convolution, I've been gradually decreasing entropy over there, and in truth I think it's easy enough to just make BatchNormalization bricks support the API to let them be wedged into a ConvolutionalSequence (I'd like to do this for activations eventually anyway).

@rizar
Contributor

rizar commented Jan 22, 2016

Great job!

@vdumoulin, can you please elaborate on your issues with Momentum? I don't understand why the step rule is using population statistics...

On 22 January 2016 at 15:37, David Warde-Farley wrote:

Closed #920 via #941.

@dwf
Contributor Author

dwf commented Jan 22, 2016

@rizar it actually wasn't the population statistics. BatchNormalization contains scale and shift parameters that preserve the expressiveness of the function, so that the output is scale * (x - mu) / sigma + shift. scale and shift have the same broadcastable flags as the population/minibatch statistics, and they are PARAMETERs that get adapted by GradientDescent. The only problem was that the algorithm buffers were not being created with the same broadcastable flags, and so their steps had the wrong broadcastable pattern too.

90cc8bb introduces shared_floatx_zeros_matching, which takes a shared variable as an argument and matches its shape and broadcastable flags, and adds some tests to ensure that StepRules preserve the broadcastable pattern.
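
For concreteness, such a helper only needs to do something like the following (a sketch of the idea, not the actual code from 90cc8bb):

```python
import numpy
import theano
from theano import config


def shared_floatx_zeros_matching(shared_variable, name=None):
    """A zero-filled shared variable with the same shape *and* the same
    broadcastable pattern as an existing shared variable."""
    value = numpy.zeros(shared_variable.get_value(borrow=True).shape,
                        dtype=config.floatX)
    return theano.shared(value, name=name,
                         broadcastable=shared_variable.broadcastable)
```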

@rizar
Contributor

rizar commented Jan 22, 2016

Thanks for the explanation!

