Padding Affecting Batch Norm #2
The published batch norm LSTM paper has pretty stellar results. I've followed the Keras issue thread where you stated that gamma and beta are shared across all timesteps, yet the actual statistics should be kept for each timestep separately. Thanks for clarifying this.
Unfortunately, I implemented the batch norm LSTM (gamma = 0.1) in TensorFlow, and it does not seem to perform as well as a regular LSTM. I'm applying it to sequences that contain padding.
My question is this: is it possible that the padding is throwing everything off? Towards the end of the sequence, padding probably disrupts the batch mean and variance, as Laurent et al. suggested: https://arxiv.org/pdf/1510.01378v1.pdf
Did you try padding in any of your experiments? If so, how did you bridge that gap? I'm thinking that towards the end of the sequence, the mean and variance should be kept from timestep 10, where clearly no padding is occurring yet.
Yes, this is something we've noticed as well. We're experimenting with two possibilities:

1. Padding the sequences with repetitions of the data rather than with zeros, so the padded frames are less disruptive to the statistics.
2. Normalizing the input sequence-wise, i.e. estimating the statistics over both the batch and time axes.

We started out doing variant 1, but I now believe variant 2 is the better choice.
Also, the validation curves we report all use batch statistics, and we don't compute the population statistics until the very end of training. We make a final pass over the training set to estimate the population statistics exactly (as opposed to by moving average) and use that to perform our test evaluation.
Thanks for asking, I hope that helps!
Thanks Tim for your feedback.
Padding with repetitions of data is an interesting idea. However, I feel that it would throw off the network's performance, as it would have to learn to "throw out" the padded material and not consider it.
For option 2, do you mean normalizing just the input and NOT the hidden state into the LSTM?
Because we are using an LSTM, we cannot know the future hidden states until we have computed the entire previous timestep. Normally you want to apply batch norm after you have matrix-multiplied the input by the weights.
For simplicity's sake, I feel that it would be easier to normalize just the input before it is multiplied by the weight matrix if you're going to normalize both by batch and time.
The third option would be to take timestep 10's mean and variance and use them for all the remaining timesteps. That way, you're using means and variances with no padding exposure AND you maintain a different set of statistics for the beginning timesteps. As your paper shows, it's important to keep separate statistics at the beginning timesteps. Thoughts on this?
Edit: I tried the third option described above and it slightly helped, but it's very clear that the network is not learning as it should.
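For anyone following along, the third option could be sketched like this in NumPy. This is a minimal illustration, not anyone's actual implementation; `per_timestep_stats` and the `t_cap` cutoff are made-up names (the thread suggests a cutoff around timestep 10):

```python
import numpy as np

def per_timestep_stats(x, t_cap, eps=1e-5):
    """Normalize with one mean/variance per timestep, but reuse the
    statistics of timestep t_cap for every later (possibly padded) step.
    x has shape (batch, time, features)."""
    mean = x.mean(axis=0)            # (time, features): one mean per timestep
    var = x.var(axis=0)              # (time, features): one variance per timestep
    mean[t_cap:] = mean[t_cap]       # freeze stats past the padding-free point
    var[t_cap:] = var[t_cap]
    return (x - mean) / np.sqrt(var + eps)
```

Early timesteps keep their own statistics; everything from `t_cap` onward is normalized with `t_cap`'s statistics, so padded frames never pollute the estimates.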
I recognize that this may be a stupid question, but when you say normalize the input sequence-wise: you would still be using a mask to make sure you don't run on the padding, right? The padding would be used only for the purpose of estimating the statistics.
I mean normalizing only the input
You can do this and it would help to some extent, but as you say it's better to normalize after the weight matrix. The idea is that the distribution of the input to the nonlinearity is what you want to standardize.
That would work as well, but whether it's time step 10 or some other number might depend on your task. If you don't mind tuning it then this is a good solution.
You'll definitely want to normalize after the weight matrix. I'm not sure what difficulties you're thinking of; if you're working with Theano, you'd do something like this:
```python
embedding = T.dot(x, W)
mean = (embedding * mask[:, :, None]).sum(axis=[0, 1], keepdims=True) / mask.sum(axis=[0, 1], keepdims=True)
variance = ((embedding * mask[:, :, None] - mean)**2).sum(axis=[0, 1], keepdims=True) / mask.sum(axis=[0, 1], keepdims=True)
embedding = beta + gamma * (embedding - mean) / T.sqrt(variance + epsilon)
```
With axes 0 and 1 being batch and time in any order.
Tim, Thanks for your extensive reply. I really appreciate your time and feedback.
Yes, I totally forgot about masking; that would solve the issue raised earlier. You said that you started out with this idea, but you don't think it's optimal. Why do you feel this way? It seems like it should work.
Your combo of normalizing frame-wise the hidden state, and time-wise + frame-wise for the input seems the most logical to me. When you normalize the input sequence-wise, would you include or exclude padded inputs when you compute the mean and variance? From the code you provided, it looks like you would keep the padded frames?
Normalizing the input from the embedding layer is very easy as you described above. However, for layers 2 and 3, the process would need to be repeated which would take some work to implement. Definitely possible.
You're right -- you would have to estimate the max time step where padding isn't there.
I will test this option out and report back here with the results in case others come across this thread. I will apply this continued average only to the timesteps past the padding-free point.
Will post back later with results.
EDIT: I have found that if I just batch normalize the hidden state input, it improves the network! If I apply BN to the tanh(new_c) term, it seems to hurt. Applying BN to the W_x input term also seems to hurt.
Will try normalizing input by time-wise and see if this makes a difference and report back.
The problem I see is that the timestep-wise normalization destroys the dynamics of the input data. E.g. if your input is a one-dimensional signal such as an audio waveform, the normalization will amplify the quiet parts and attenuate the loud parts. The model won't know which is which anymore, and will easily confuse noise for signal.
More generally, timestep-wise estimation works for stationary input signals. If the distribution is not stationary (e.g. there are loud parts and quiet parts), a mean/variance estimate based on such a narrow temporal window is a bad estimate of the global mean/variance.
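A toy illustration of this point. The "audio" batch below is entirely made up for the sketch: a quiet half followed by a loud half. Timestep-wise normalization erases exactly the loud/quiet structure that carries information:

```python
import numpy as np

np.random.seed(0)
# Hypothetical audio-like batch: quiet first half, loud second half.
B, T = 32, 40
scale = np.concatenate([np.full(20, 0.1), np.full(20, 10.0)])  # per-timestep loudness
x = np.random.randn(B, T) * scale

# Timestep-wise normalization: one mean/variance per timestep, over the batch.
y = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5)

print(x[:, :20].std(), x[:, 20:].std())   # ~0.1 vs ~10: clear dynamics
print(y[:, :20].std(), y[:, 20:].std())   # both ~1: dynamics destroyed
```

After normalization the model can no longer tell the quiet frames from the loud ones, which is the "confusing noise for signal" failure mode described above.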
I noticed a bug in the code I provided; the variance computation should be
```python
variance = ((embedding - mean)**2 * mask[:, :, None]).sum(axis=[0, 1], keepdims=True) / mask.sum(axis=[0, 1], keepdims=True)
```
That is, the multiplication by the mask should be moved outside the squared difference. I multiply by the mask and divide by the number of ones in the mask, so the padded elements do not contribute to the estimate.
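Here is a small NumPy check of that claim (not from the original code; the shapes and names are illustrative). With the mask outside the squared difference, the masked estimates match the statistics of the valid frames exactly, no matter what garbage sits in the padded positions:

```python
import numpy as np

np.random.seed(1)
B, T, F = 4, 6, 3
embedding = np.random.randn(B, T, F)
mask = np.ones((B, T))
mask[:, 4:] = 0                       # last two timesteps are padding
embedding[mask == 0] = 99.0           # garbage in the padded frames

n = mask.sum()                        # number of valid (unpadded) frames
mean = (embedding * mask[:, :, None]).sum(axis=(0, 1)) / n
variance = ((embedding - mean)**2 * mask[:, :, None]).sum(axis=(0, 1)) / n

# Same result as computing stats on the valid frames only:
valid = embedding[mask.astype(bool)]  # (num_valid, F)
print(np.allclose(mean, valid.mean(axis=0)))      # True
print(np.allclose(variance, valid.var(axis=0)))   # True
```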
I'm very curious to hear more about your findings. What kind of data are you working with?
Would be happy to help. I run 5 separate Titan X's/980 Tis, so I try to test as rapidly as possible. My models usually have two or three layers of 512 units. I'm primarily working with English text data tokenized into words (usually 120 timesteps).
If we are indeed going to normalize sequence-wise and batch-wise, then the words that are "really loud" will be somewhat dampened. This would be skewed even more if you used a small batch size.
Do you think one potential solution is to simply always apply a running mean and variance to each timestep? Instead of normalizing by the specific batch's stats, it would be better to normalize by a running mean and variance. Of course, you could keep a separate running mean and variance for each timestep. Once you pass timestep 10, you apply the same running mean and var throughout.
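The running-statistics proposal could be sketched roughly as follows. This is a hypothetical NumPy sketch: `RunningStats`, `t_cap`, and the momentum value are all made up for illustration, and note the later caveat in the thread about not being able to backprop through these statistics:

```python
import numpy as np

class RunningStats:
    """Hypothetical per-timestep running statistics. Timesteps past
    t_cap (e.g. 10, as suggested above) all share one slot, so the same
    running mean/variance is applied throughout the padded tail."""
    def __init__(self, num_features, t_cap, momentum=0.9):
        self.t_cap = t_cap
        self.momentum = momentum
        self.mean = np.zeros((t_cap + 1, num_features))
        self.var = np.ones((t_cap + 1, num_features))

    def update(self, x):
        # x: (batch, time, features)
        for t in range(x.shape[1]):
            s = min(t, self.t_cap)          # shared slot past t_cap
            m = self.momentum
            self.mean[s] = m * self.mean[s] + (1 - m) * x[:, t].mean(axis=0)
            self.var[s] = m * self.var[s] + (1 - m) * x[:, t].var(axis=0)

    def normalize(self, x, eps=1e-5):
        idx = np.minimum(np.arange(x.shape[1]), self.t_cap)
        return (x - self.mean[idx]) / np.sqrt(self.var[idx] + eps)
```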
Also of note: you can raise the learning rate even if you just normalize the hidden state input. I'm using Adam with a 0.005 LR (which is pretty high) along with a learning rate schedule.
I have found that when I apply batch norm to attention, it hurts performance, which makes me think there's something crucially wrong with my implementation. Will comment later when I have more findings.
I think if your words are one-hot encoded you should be fine (though if your vocabulary is large, you may need a larger batch size, or do this sequence-wise normalization, which effectively increases your sample size). In audio data, on the other hand, the input is typically real-valued and its absolute value varies a lot; this variability is highly informative, so you don't want to lose it. I suspect this is part of why batch normalization doesn't seem to help on speech (recurrently or otherwise).
The problem with that is that you can't backprop through the mean and variance, which is crucial. Batch normalization just doesn't seem to work if you don't do this. The accepted explanation for this is that the gradient should take into account the effect of the parameter update on the statistics, or optimization may go around in circles. There may be more to it.
I will be busy preparing our NIPS submission and moving, so I may be less responsive in the next week, but I do appreciate the discussion! :-)
No problem, just message back whenever convenient. Appreciate your thoughts.
I usually train an embedding layer with the model as it performs better than using word2vec or glove. Usually have a vocab size of 40k.
Did not know this -- so thank you. It would explain so many findings I have had lately. I tried applying the running mean and variance and things did not improve at all.
I think I'm going to try to apply batch norming to associative lstm http://arxiv.org/abs/1602.03032 and see what happens.
I tried sequence-wise AND batch-wise norming for input and it did help some, but not significantly. Was sure to exclude padded frames. Really, the main benefit I've seen is from batch-norming the hidden state. I tried stacking 4 or 5 layers with batch norming the input and unfortunately it did not help.
@LeavesBreathe Hello, I am implementing a simple batch normalized LSTM in TensorFlow as well. Could you explain: when you batch normalized the hidden state, are you normalizing just the 'h', not the 'Wh'? And are they normalized timestep-wise, meaning that for each layer-and-timestep combo there is one mean and variance? Do you share the gamma and beta across layers?
BN doesn't really slow down training for me: maybe a 10% slowdown in step times.
Didn't mean to be confusing, but I normally do BN(Wh). I normalize it at each timestep for each layer separately. I share one gamma per layer.
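For concreteness, that scheme might look roughly like this in NumPy. This is a minimal sketch with made-up names and shapes; `mean_t`/`var_t` would be kept separately for each timestep and layer, while gamma is shared across timesteps within a layer (beta is omitted, since a bias can live in the LSTM gate pre-activations instead):

```python
import numpy as np

def bn(z, mean, var, gamma, eps=1e-5):
    # One shared gamma per layer; per-timestep mean/var; no beta.
    return gamma * (z - mean) / np.sqrt(var + eps)

np.random.seed(2)
B, H = 16, 8
W_h = np.random.randn(H, H) * 0.1
h = np.random.randn(B, H)
gamma = 0.1                      # small gamma init, as used in the thread

Wh = h @ W_h                     # recurrent pre-activation, the input to BN(Wh)
mean_t = Wh.mean(axis=0)         # statistics for this timestep/layer only
var_t = Wh.var(axis=0)
Wh_bn = bn(Wh, mean_t, var_t, gamma)
```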
Let us know if you get any improvements with BN(Wx)!