Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TimeDistributed BatchNormalization #7467

Merged
merged 5 commits into from
Aug 4, 2017

Conversation

waleedka
Copy link
Contributor

@waleedka waleedka commented Jul 29, 2017

This PR fixes issue #7466. The root cause of the problem is that TimeDistributed reshapes the inputs before passing them to BatchNormalization, and batch normalization has update ops that are conditioned on the inputs. Since the inputs to TimeDistributed are different from those passed to BatchNormalization, the conditional updates don't get executed and the mean and variance don't get updated.

Wrappers could modify the inputs in many different ways, so I implemented a generic solution on the Wrapper class to map wrapper inputs to inner layer inputs. This solution should fix the issue with BatchNormaliztion and also with any other wrapped layer that might have update ops conditioned on its inputs.

@QuantumLiu
Copy link

Thanks for your job.But it got 1 failing check in TEST_MODE=PEP8?

@waleedka
Copy link
Contributor Author

Thanks for the heads up! I had a line over indented. It's fixed, and I also fixed another indentation issue that was already in the original code.

@fchollet
Copy link
Member

fchollet commented Jul 31, 2017

  • Please add unit tests (a test that would have been failing with the previous implementation but would be passing now, as well as enough tests to have full coverage of the new code)
  • Don't break lines with \

@milani
Copy link
Contributor

milani commented Jul 31, 2017

Is this the same as BatchNormalization's mode=2 in keras 1.x?

@ahundt
Copy link
Contributor

ahundt commented Aug 1, 2017

is there any chance #7033 might be fixed by this or could be updated to do so?

@waleedka
Copy link
Contributor Author

waleedka commented Aug 1, 2017

@fchollet Done! Added the unit tests and removed the \

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Three more things:

  • please keep lines under 80 chars (approx), there are long lines in the test
  • make input_map a private attribute (_input_map)
  • in the test, you use " as quote character, but the rest of the file uses '

Thanks!

The CI test fails with W503 error, but actually
the testing tool is enforcing PEP8 wrong.
PyCQA/pycodestyle#498
@waleedka
Copy link
Contributor Author

waleedka commented Aug 4, 2017

@fchollet Done! Please review.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants