
Brick-based batch normalization. #941

Merged · 36 commits merged into mila-iqia:master from batch_norm_bricks on Jan 22, 2016

Conversation

@dwf (Contributor) commented Jan 16, 2016

Tries to distill the best ideas from @Thrandis, @cooijmanstim, and my own proposal into something that fits into Blocks.

To avoid circular imports, the bricks and graph packages had to be split up a bit. I think this is long overdue anyway, and will aid in #885.

  • ~~Utility code for inference network updates: provide at least a simple function to calculate moving averages, and maybe some analog to César's BatchNormExtension.~~ Save for later!
  • Examples of usage, at least via docstrings. Narrative docs can be a CCW ticket.
  • Some code for easily accessing the population <-> minibatch relationships when you create the graph using the context manager.

Fixes #920.

@dwf dwf force-pushed the batch_norm_bricks branch 4 times, most recently from 6d9c7df to 9d7ef6e Compare January 17, 2016 01:32
@dwf (Contributor Author) commented Jan 17, 2016

@rizar @dmitriy-serdyuk If this interests either of you, I think it's ready for an initial look. Mainly needs finishing touches now.

@dwf dwf force-pushed the batch_norm_bricks branch 2 times, most recently from 0fb5442 to 9daba11 Compare January 18, 2016 05:46
@dwf (Contributor Author) commented Jan 18, 2016

Tests occasionally fail with the sqlite backend (#942), but that's unrelated to this PR.

@aam-at commented Jan 18, 2016

@dwf I forked your implementation before and had one problem: graph replacements in Blocks using theano.clone create multiple copies of the same path.
With multiple BN layers, you get a combinatorial explosion of paths in the graph. The duplicates are optimized away during function compilation, but they drastically increase compilation time.

import time

from blocks.bricks import Rectifier, BatchNormalizedMLP, Softmax
from blocks.graph import ComputationGraph
from blocks.graph.bn import batch_normalize
from theano import tensor, function

x = tensor.matrix()
y = BatchNormalizedMLP(activations=[Rectifier(), Rectifier(), Rectifier(),
                                    Rectifier(), Rectifier(), Softmax()],
                       dims=[784, 1000, 1000, 1000, 1000, 1000, 10]).apply(x)
cg = ComputationGraph(y)
new_cg, _ = batch_normalize(cg)  # rewrite the graph with BN applied
y_hat = new_cg.outputs[0]

# time compilation of the original (inference) graph
start = time.time()
function(inputs=[x], outputs=y)
end = time.time()
print 'inference graph:', end - start

# time compilation of the batch-normalized (training) graph
start = time.time()
function(inputs=[x], outputs=y_hat)
end = time.time()
print 'training graph:', end - start

Compilation timings for a 6-layer MLP:

inference graph: 0.382107973099
training graph: 7.72867417336

For an 8-layer MLP:

inference graph: 0.478291988373
training graph: 370.310412884

As a workaround, I created a copy of the graph and worked directly with the ApplyNodes to update it.

I am also attaching simple PDFs showing the difference between the graphs before and after batch_normalize for a three-layer MLP.

y.pdf
y_hat.pdf

@cooijmanstim:

+1 for refactoring the __init__.pys.

Re combinatorial cloning, I believe a cheap way to merge clones before compilation would be to put a special identity attribute on each node at construction time (i.e. self.tag.original_id = id(self)) and use that to identify clones. This would be faster than traversing the graph like equal_computations does. I guess this is outside the scope of Blocks, though.
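
A minimal sketch of that idea (hypothetical helper names; tag variables at construction, then collect replacements that merge clones before compiling):

    def tag_original_id(var):
        # record the variable's own identity at construction time; clones
        # produced later can be traced back through this tag
        var.tag.original_id = id(var)
        return var

    def find_clone_replacements(variables):
        # keep one representative per original_id and map every other
        # clone onto it, so duplicated paths collapse before compilation
        representative = {}
        replacements = []
        for var in variables:
            orig = getattr(var.tag, 'original_id', None)
            if orig is None:
                continue
            if orig in representative:
                replacements.append((var, representative[orig]))
            else:
                representative[orig] = var
        return replacements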

I've been working on an alternative graph substitution feature for Theano with which I've successfully been able to avoid combinatorial explosions due to cloning, but it's not ready for general use.

Alternatively it may be possible to do one big clone/replacement for all BN bricks rather than many small ones.

@dwf (Contributor Author) commented Jan 18, 2016

I would think that cg.replace would do that if it were possible, no?

Regardless, I think the implementation details of cg.replace can improve underneath this implementation once we figure out the best way to go about it. It may involve complementary features in Theano and Blocks.


@dwf (Contributor Author) commented Jan 19, 2016

@aam-at I've made the implementation a bit more efficient (it only does one replace per brick, rather than two) and also added another facility for creating the training graph via a context manager.
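
For reference, the two routes look roughly like this (a sketch using the names from this PR; mlp, x, and cost are placeholders, and exact signatures may change during review):

    # Route 1: build the inference graph, then rewrite it for training.
    # At this point in the PR, apply_batch_normalization also returns the
    # (population, minibatch) pairs discussed further down.
    cg = ComputationGraph([cost])
    training_cg, pairs = apply_batch_normalization(cg)

    # Route 2: build the training graph directly; inside the block the
    # BatchNormalization bricks under mlp are switched into training mode.
    with batch_normalization(mlp):
        training_cost = mlp.apply(x)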

@dwf dwf force-pushed the batch_norm_bricks branch 2 times, most recently from f0d14db to a4d0728 Compare January 19, 2016 00:50
@aam-at commented Jan 19, 2016

@dwf After your update, both apply_batch_normalization and batch_normalization produce identical graphs and do not suffer from the compilation problem. Nice!
And I like your context manager idea. With a little bit of extra logic, we could change the context of the apply call (like kwargs in Lasagne) but without the problem of who gets what. In the current implementation you change the context for training via self._training_mode=True, but it's easy to extend, if desired, to something like fine-tuning the population statistics after training.

@janchorowski (Contributor):

Very nice!

I once had another idea for a context manager: use the tags of the input Theano variables. To get the context you walk the graph from one of the function inputs; by caching results in the tags, this is linear in graph size. The advantage is that you don't have a private field in a brick that you toggle to change its operation: all you do is tag one of the inputs to the top-level apply call.

pseudocode:

    def find_context(*vars):
        ctx = {}
        for var in vars:
            if hasattr(var.tag, 'context'):
                # maybe do something smart with name clashes in the context?
                ctx.update(var.tag.context)
            elif var.owner:
                ctx.update(find_context(*var.owner.inputs))
        for var in vars:
            # cache the merged context on the inputs so the walk stays
            # linear in graph size
            var.tag.context = ctx
        return ctx

    def apply(in1, in2):
        ctx = find_context(in1, in2)
        if ctx['train']:
            do_this()
        else:
            do_that()

@aam-at commented Jan 19, 2016

@janchorowski Having a top-level apply call pass the context is the same as Lasagne's kwargs in get_output, and the problems with this approach are the same: mistyped strings and silently unused arguments.
The advantage of a private field on the brick is that you get autocompletion for setting and getting training_mode, so the code isn't prone to mistyped strings, e.g. 'trian'. One more advantage: you can change the context for only some bricks in the hierarchy.

@janchorowski (Contributor):

@aam-at Thanks for pointing out the issues.

@rizar (Contributor) commented Jan 19, 2016

@dwf, I gave it a first quick look and it looks nice! Do I understand correctly that you let the user choose whether to use replacements or the context manager to obtain the batch-normalized training graph?

@dwf (Contributor Author) commented Jan 19, 2016

@rizar Yes, that's the idea. It sounds like replacing the outputs, rather than the quantities used in the internally created graph, has solved the graph-size explosion @aam-at was seeing, but it might be worth keeping both just in case, I don't know. It'd be worth comparing compile times with, say, a 100-layer MLP.

One thing I'm having trouble with: something you'll want to do in either case is get a list of (population shared variable, corresponding minibatch estimate) pairs, with which you do something -- either a moving average, or something like what @Thrandis does in his extension (though that is prohibitively expensive for large datasets). The crucial step is getting those pairs. I currently return a second value from apply_batch_normalization, but there's no easy facility for this if you use the context manager.

Writing such a function is not hard, but the question is what it should be called. Also, if we are tending towards #885 and asking that namespaces be flattened, I'd like to give it a descriptive name that isn't prohibitively long; gather_batch_normalization_statistics or something is getting pretty long.

One thought is to make it a static method of BatchNormalization, like BatchNormalization.gather_statistics. Static methods aren't very common, but I think it could work here.

Another problem is what to do when a given BatchNormalization brick appears twice in a graph, so that you may have two different minibatch estimates for the mean/standard deviation. I sort of consider this the user's problem, but I should probably add an allow_duplicates keyword that is False by default (with the False behaviour being to raise an exception).
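
For illustration, once those pairs are available, consuming them could look like this (a sketch; pop_mb_pairs and the decay value are assumptions, not API from this PR):

    # pop_mb_pairs: list of (population shared variable, minibatch estimate)
    alpha = 0.05  # assumed exponential-moving-average decay rate
    extra_updates = [(pop, (1 - alpha) * pop + alpha * mb)
                     for pop, mb in pop_mb_pairs]
    # these updates would then be attached to the training function,
    # e.g. through the training algorithm's update mechanism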

@dwf (Contributor Author) commented Jan 20, 2016

So after sys.setrecursionlimit(100000) to take care of the interpreter's silly stack issues, a 100-layer batch-normalized net (with graph replacement) takes 18 seconds to compile, so I think it's officially Not A Big Deal anymore.

@dwf (Contributor Author) commented Jan 20, 2016

Buildbot currently failing due to Theano/Theano#3894. Sigh...

@@ -0,0 +1,395 @@
"""Some of the simplest individual bricks."""
Contributor:

BTW, I would rename simple.py into basic.py.

@dwf (Contributor Author):

It's a little close to base, don't you think?

@rizar (Contributor) commented Jan 20, 2016

Given that the problem seems to be solved, do we still need the context manager?

@dwf (Contributor Author) commented Jan 20, 2016

The context manager gives you a bit more flexibility: you can choose to activate it for only portions of your graph, depending on the bricks you pass.

More than that, I know I'll be using find_bricks for other things (it scratches an itch I often have, and it might even be useful in speeding up the Selector, since it does a breadth-first search without using the stack), and the additional maintenance burden on top of that to provide the context manager is 7 lines, plus docs and tests.
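
For example, selective activation might look like this (a sketch; the brick names are hypothetical):

    # only BatchNormalization bricks found under the bricks you pass are
    # switched to training mode; any others keep using population statistics
    with batch_normalization(encoder):   # hypothetical brick
        cost = full_model.apply(x)       # the decoder's BN stays in inference mode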

@vdumoulin (Contributor):

@dwf by the way, the import issue in Theano has been resolved.

@rizar (Contributor) commented Jan 20, 2016

@dwf, can you please add an example to Blocks-examples, as we discussed?

@cooijmanstim , can you please review this PR?

@dwf (Contributor Author) commented Jan 22, 2016

Okay, barring any further issues, and assuming the tests pass except for the stupid flaky sqlite one, I'd personally consider this ready for merge. 😂

@cooijmanstim Thanks for your very thorough reviewing up to this point!

@dwf dwf changed the title from "WIP: Brick-based batch normalization." to "Brick-based batch normalization." Jan 22, 2016
@cooijmanstim:

Thanks, I hope the granularity was right. :-P I really like how it turned out; in particular, having inference be the default makes handling the population statistics much smoother. The context manager is a nice touch and I could see it becoming the standard way to use this.

There is still a class of bugs in the way the context nesting is carried out, though in practice it shouldn't occur, and when it does it will crash and burn rather than silently do the wrong thing. It's potentially triggered if one of the __enter__ calls fails, such that control goes to the finally block before all of the bricks have been entered. All of the bricks will then be exited, which is incorrect. However, if there are too many __exit__ calls on a brick, one of them will fail to pop the training_mode stack because it is empty, and the code will crash.
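
Roughly, the nesting pattern in question (a simplified sketch, not the PR's exact code):

    from contextlib import contextmanager

    @contextmanager
    def batch_normalization(*bricks):
        # simplification: assume bricks is already the list of
        # BatchNormalization bricks to switch into training mode
        try:
            for brick in bricks:
                brick.__enter__()   # push training mode
            yield
        finally:
            # if an __enter__ above failed partway through, this still
            # calls __exit__ on every brick; the surplus __exit__ tries to
            # pop an empty training_mode stack and raises, so the failure
            # is loud rather than silent
            for brick in bricks:
                brick.__exit__(None, None, None)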

So LGTM.

@dwf (Contributor Author) commented Jan 22, 2016

Yeah, I don't see how __enter__ could reasonably fail, and as you said, it would die with a pop failure. So I don't think there's a reason to worry too much.

Thanks again for the timely reviewing.

@Thrandis @rizar @vdumoulin If there are no objections I'll merge around noon tomorrow.


Parameters
----------
input_dim : int or tuple
Contributor:

This says that input_dim can be an int, but it looks like the code enforces it to be a sequence of length >= 2.

Contributor:

Also, should we put this check at the allocation stage? Lazy initialization is broken otherwise.

@dwf (Contributor Author):

Hmm, interestingly lazy initialization works, but I don't know why. But yes, that should be moved to _allocate.

@dwf (Contributor Author):

Oh, I know why. NoneAllocation is not an instance of collections.Sequence (that's also why ints are accepted). But it should still be in _allocate.

@dwf (Contributor Author):

d9171ff actually gets it right.

@vdumoulin (Contributor):

This is not a PR blocker, but in its current state, your PR doesn't have batch-normalized convnet support, right?

@vdumoulin (Contributor):

Never mind, I put my foot in my mouth. However, does the current code make it easy to apply batch normalization before the nonlinearity in convnets?

@vdumoulin (Contributor):

There's something else I realized while trying this PR out: step rules that keep track of gradient statistics (e.g. Momentum, RMSProp) initialize those shared variables without paying attention to the broadcasting pattern of the corresponding parameter, which makes graph compilation fail.

@dwf (Contributor Author) commented Jan 22, 2016

The way you'd currently apply it before the nonlinearity is with a ConvolutionalActivation whose activation is Sequence([SpatialBatchNormalization().apply, Rectifier().apply]).apply. But this brick can also go right into a ConvolutionalSequence, and there's no reason we can't make the trivial change to allow activations to do the same.
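
In code, that suggestion would look something like this (a sketch; the constructor arguments other than the activation are illustrative and would normally be pushed by an enclosing ConvolutionalSequence):

    # apply batch normalization before the nonlinearity inside a
    # convolutional layer
    activation = Sequence([SpatialBatchNormalization().apply,
                           Rectifier().apply]).apply
    layer = ConvolutionalActivation(activation,
                                    filter_size=(3, 3),
                                    num_filters=32)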

@vdumoulin (Contributor):

That does the job for me, thanks!

@dwf (Contributor Author) commented Jan 22, 2016

@vdumoulin d129aa7

@dwf (Contributor Author) commented Jan 22, 2016

Sorry, just amended as 90cc8bb

@vdumoulin (Contributor):

That fixes it for me, thanks!

@vdumoulin (Contributor):

Also, I had a look at the latest commit, LGTM.

@dwf (Contributor Author) commented Jan 22, 2016

Okay, good. I'll merge once the tests eventually pass (Travis is having some issues again today according to their status page).

@dwf (Contributor Author) commented Jan 22, 2016

Tests pass except for the flaky sqlite build, so I'm merging. Thanks @vdumoulin and @cooijmanstim for the assistance.

dwf added a commit that referenced this pull request Jan 22, 2016
Brick-based batch normalization.
@dwf dwf merged commit 98797a1 into mila-iqia:master Jan 22, 2016
@dwf dwf mentioned this pull request Jan 31, 2016
@dwf dwf deleted the batch_norm_bricks branch June 4, 2016 03:48