Brick-based batch normalization. #941
Conversation
@rizar @dmitriy-serdyuk If this interests either of you, I think it's ready for an initial look. It mainly needs finishing touches now.
Tests occasionally fail with the sqlite backend (#942), but that's unrelated to this PR.
@dwf I forked your implementation before and had one problem: graph replacements in Blocks using `theano.clone` will create multiple copies of the same path.

```python
import time

from blocks.bricks import Rectifier, BatchNormalizedMLP, Softmax
from blocks.graph import ComputationGraph
from blocks.graph.bn import batch_normalize
from theano import tensor, function

x = tensor.matrix()
y = BatchNormalizedMLP(activations=[Rectifier(), Rectifier(), Rectifier(),
                                    Rectifier(), Rectifier(), Softmax()],
                       dims=[784, 1000, 1000, 1000, 1000, 1000, 10]).apply(x)
cg = ComputationGraph(y)
new_cg, _ = batch_normalize(cg)
y_hat = new_cg.outputs[0]

start = time.time()
function(inputs=[x], outputs=y)
print('inference graph:', time.time() - start)

start = time.time()
function(inputs=[x], outputs=y_hat)
print('training graph:', time.time() - start)
```

Compilation timings for a 6-layer MLP:

For an 8-layer MLP:

As a workaround, I used to create a copy of the graph and work directly with `ApplyNode` to update it. I am also attaching a simple PDF showing the difference between the graph before and after `batch_normalize` for a three-layer MLP.
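For readers unfamiliar with the failure mode, here is a toy model of the problem in plain Python (not Theano): the `clone` below mimics a clone without memoization, so a node reachable through two paths is copied once per path and sharing is destroyed. All names here are illustrative.

```python
class Node:
    """Minimal stand-in for a graph node: an op name and input nodes."""
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

def clone(node, replacements):
    # Naive recursive clone with no cache of already-cloned nodes:
    # shared subgraphs get duplicated once per reference.
    if node in replacements:
        return replacements[node]
    return Node(node.op, [clone(i, replacements) for i in node.inputs])

def count_nodes(root):
    # Count distinct node objects reachable from the root.
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if id(node) not in seen:
            seen.add(id(node))
            stack.extend(node.inputs)
    return len(seen)

x = Node('input')
h = x
for _ in range(4):
    h = Node('add', [h, h])  # every level shares its input twice

print(count_nodes(h))                   # 5 nodes: sharing is preserved
h2 = clone(h, {x: Node('new_input')})
print(count_nodes(h2))                  # 16 nodes: shared paths duplicated
```

Doing one such replacement per layer compounds the duplication, which is why compile times blow up with depth.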
+1 for refactoring the … Re combinatorial cloning, I believe a cheap way to merge clones before compilation would be to put a special identity attribute on each node at construction time (i.e. …). I've been working on an alternative graph substitution feature for Theano with which I've successfully been able to avoid combinatorial explosions due to cloning, but it's not ready for general use. Alternatively, it may be possible to do one big clone/replacement for all BN bricks rather than many small ones.
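The identity-attribute idea can be sketched in plain Python (hypothetical names, not the Theano API): each node gets a stable `identity` at construction, clones inherit it, and a pre-compilation pass merges all copies sharing an identity back into one object.

```python
import itertools

_fresh = itertools.count()

class Node:
    def __init__(self, op, inputs=(), identity=None):
        self.op = op
        self.inputs = list(inputs)
        # Assigned once at construction and preserved by clone(), so
        # duplicated copies of one original node can be recognized later.
        self.identity = next(_fresh) if identity is None else identity

def clone(node, replacements):
    # Naive clone (no memoization) that duplicates shared paths
    # but carries the identity attribute along.
    if node in replacements:
        return replacements[node]
    return Node(node.op, [clone(i, replacements) for i in node.inputs],
                identity=node.identity)

def merge_clones(node, canonical=None):
    # Rebuild the graph, keeping a single canonical node per identity.
    if canonical is None:
        canonical = {}
    if node.identity not in canonical:
        canonical[node.identity] = Node(
            node.op, [merge_clones(i, canonical) for i in node.inputs],
            identity=node.identity)
    return canonical[node.identity]

def count_nodes(root):
    seen, stack = set(), [root]
    while stack:
        node = stack.pop()
        if id(node) not in seen:
            seen.add(id(node))
            stack.extend(node.inputs)
    return len(seen)

x = Node('input')
h = x
for _ in range(3):
    h = Node('add', [h, h])

h2 = clone(h, {x: Node('new_input')})
print(count_nodes(h2))                  # 8: naive cloning blew the graph up
print(count_nodes(merge_clones(h2)))    # 4: merged back to the shared form
```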
I would think that `cg.replace` would do that if it were possible, no? Regardless, I think the implementation details of `cg.replace` can improve …
@aam-at I've made the implementation a bit more efficient (it only does one replace per brick, rather than two) and also added another facility for creating the training graph via a context manager.
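For readers skimming the thread, the context-manager pattern being discussed can be sketched generically (names hypothetical, not necessarily the exact Blocks API): the manager flips a private flag on the bricks you pass and restores it on exit, so `apply` builds the minibatch-statistics graph only inside the block.

```python
from contextlib import contextmanager

class BatchNormBrick:
    """Minimal stand-in for a batch-normalization brick."""
    def __init__(self):
        self._training_mode = False

    def apply(self, x):
        # Inference (population statistics) is the default; the training
        # graph (minibatch statistics) is built only when the flag is set.
        if self._training_mode:
            return ('minibatch_stats', x)
        return ('population_stats', x)

@contextmanager
def batch_normalization(*bricks):
    # Toggle the private flag for the duration of the block,
    # restoring it even if graph construction raises.
    for brick in bricks:
        brick._training_mode = True
    try:
        yield
    finally:
        for brick in bricks:
            brick._training_mode = False

brick = BatchNormBrick()
with batch_normalization(brick):
    y_train = brick.apply('x')   # ('minibatch_stats', 'x')
y_infer = brick.apply('x')       # ('population_stats', 'x')
```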
@dwf After your update, both …
Very nice! I once had another idea for a context manager: use the tags of input theano variables. To get the context you would walk the graph from one of the function inputs; by caching results in the tags of inputs, this is linear in graph size. The advantage is that you don't have a private field in a brick that you toggle to change its operation. All you do is tag one of the inputs to the top-level apply call. Pseudocode:

```python
def find_context(*vars):
    ctx = {}
    for var in vars:
        if hasattr(var.tag, 'context'):
            # maybe do something smart with name clashes in the context?
            ctx.update(var.tag.context)
        elif var.owner:
            ctx.update(find_context(*var.owner.inputs))
    # cache the merged context so repeated walks stay linear
    for var in vars:
        var.tag.context = ctx
    return ctx

def apply(in1, in2):
    ctx = find_context(in1, in2)
    if ctx.get('train'):
        pass  # build the training graph
    else:
        pass  # build the inference graph
```
@janchorowski Having a top-level apply call to pass context is the same as Lasagne …

@aam-at Thanks for pointing out the issues.

@dwf, I gave it a first quick look and it looks nice! Do I understand it right that you let the user choose whether they want to use replacements or a context manager to obtain the batch-normalized graph for training?
@rizar Yes, that's the idea. It sounds like doing output replacement, rather than replacement of quantities used in the internally created graph, has solved the problem @aam-at was seeing with graph size explosions, but it might be worth keeping both just in case, I don't know. It'd be worth comparing compile times with, say, a 100-layer MLP.

One thing I'm having trouble with: something you'll want to do in either case is get a list of (population shared variable, corresponding minibatch estimate) pairs, with which you do something, either a moving average, or something like what @Thrandis does in his extension (though this is prohibitively expensive for large datasets). But the crucial step is getting those pairs. I return a second value from … Writing such a function is not hard, but the question is what should it be called? Also, if we are tending towards #885 and asking that namespaces be flattened, I'd like to give it a descriptive name that isn't prohibitively long. One thought is to make it a static method of … Another problem is what to do when a given …
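Whatever the pairing function ends up being called, consuming those (population statistic, minibatch estimate) pairs with a moving average reduces to something like this numpy sketch (function name hypothetical):

```python
import numpy as np

def update_population_statistics(pairs, alpha=0.05):
    """Exponential moving average update, applied in place to each
    (population_statistic, minibatch_estimate) pair of arrays."""
    for population, minibatch in pairs:
        population *= 1.0 - alpha
        population += alpha * minibatch

pop_mean = np.zeros(3)
pop_var = np.ones(3)
update_population_statistics([(pop_mean, np.full(3, 2.0)),
                              (pop_var, np.full(3, 4.0))])
print(pop_mean)  # [0.1 0.1 0.1]
print(pop_var)   # [1.15 1.15 1.15]
```

In Blocks the population statistics would be Theano shared variables updated after each training step, but the arithmetic is the same.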
So after …

Buildbot currently failing due to Theano/Theano#3894. Sigh...
@@ -0,0 +1,395 @@
`"""Some of the simplest individual bricks."""`
BTW, I would rename `simple.py` into `basic.py`.
It's a little close to `base`, don't you think?

Given that the problem seems to be solved, do we still need the context manager?
The context manager gives you a bit more flexibility, i.e. you can choose to only activate it for portions of your graph depending on the bricks you pass. More than that, I know I'll be using …

@dwf by the way, the import issue in Theano has been resolved.

@dwf, can you please add an example to Blocks-examples, as we discussed? @cooijmanstim, can you please review this PR?
Okay, barring any further issues, and assuming the tests pass except for the stupid flaky sqlite one, I'd personally consider this ready for merge. 😂 @cooijmanstim Thanks for your very thorough reviewing up to this point!
Thanks, hope the granularity was right. :-P I really like how it turned out; especially having inference be the default makes handling the population statistics much smoother. The context manager is a nice touch and I could see it become the standard way to use this. There is still a class of bugs in the way the context nesting is carried out, though it is unlikely to happen in practice, and it will crash and burn rather than silently do the wrong thing. It's potentially triggered if one of the …

So LGTM.

Yeah, I don't see how … Thanks again for the timely reviewing. @Thrandis @rizar @vdumoulin If there are no objections I'll merge around noon tomorrow.
```
Parameters
----------
input_dim : int or tuple
```
This says that `input_dim` can be an `int`, but it looks like the code enforces it to be a sequence of length >= 2.
Also, should we put this check at the allocation stage? Lazy assignment is broken otherwise.
Hmm, interestingly lazy initialization works, but I don't know why. But yes, that check should be moved to `_allocate`.
Oh, I know why. `NoneAllocation` is not an instance of `collections.Sequence` (that's also why ints are accepted). But it should still be in `_allocate`.
d9171ff actually gets it right.
This is not a PR blocker, but in its current state, your PR doesn't have batch-normalized convnet support, right?

Never mind, I put my foot in my mouth. However, does the current code make it easy to apply batch normalization before the nonlinearity is applied in convnets?

There's something else I realized while trying this PR out: step rules that keep track of gradient statistics (e.g. …

The way you'd currently apply it before the nonlinearity is with a …
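Framework details aside, the computation being discussed, batch normalization of the pre-activation followed by the nonlinearity, is just the following (a self-contained numpy sketch, not the Blocks API):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    # Normalize each feature over the batch axis, then scale and shift.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def relu(x):
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
# Pre-activations with a deliberately shifted, scaled distribution.
pre_activation = rng.standard_normal((64, 8)) * 3.0 + 5.0
normalized = batch_norm(pre_activation,
                        gamma=np.ones(8), beta=np.zeros(8))
out = relu(normalized)  # BN applied before the nonlinearity
```

For convolutional feature maps the same idea applies with the mean and variance taken over the batch and spatial axes instead of the batch axis alone.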
That does the job for me, thanks!
Sorry, just amended as 90cc8bb.
That fixes it for me, thanks!

Also, I had a look at the latest commit, LGTM.

Okay, good. I'll merge once the tests eventually pass (Travis is having some issues again today according to their status page).

Tests pass except for the flaky sqlite build, so I'm merging. Thanks @vdumoulin, @cooijmanstim for the assistance.
Tries to distill the best ideas from @Thrandis, @cooijmanstim, and my own proposal into something that fits into Blocks.
To avoid circular imports, the bricks and graph packages had to be split up a bit. I think this is long overdue anyway, and will aid in #885.
~~`BatchNormExtension`~~ save for later!

Fixes #920.