Add SyncBatchNormalization layer for TensorFlow. #2075
Conversation
@tgaddair As discussed, would we be open to removing support for the very old TF release in our tests (v1.6) to unblock this PR?
Sure, I can take a stab at that today.
Signed-off-by: Josh Romero <joshr@nvidia.com>
LGTM! I will create an issue around the comment I suggested. Feel free to merge in if you're ready.
worker_mean, worker_variance = super(SyncBatchNormalization, self)._moments(
    inputs, reduction_axes, keep_dims=keep_dims)

if size() > 1:
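The snippet above computes per-worker moments and only synchronizes them when running on more than one Horovod worker. As a rough illustration of the math behind that synchronization (this is a NumPy sketch, not the PR's actual Horovod code; `sync_moments` is a hypothetical helper, and the averaging step stands in for an allreduce, assuming equal local batch sizes):

```python
import numpy as np

def sync_moments(worker_means, worker_vars):
    # Averaging across the worker axis plays the role of hvd.allreduce.
    group_mean = np.mean(worker_means, axis=0)
    # Per-worker E[x^2] is var + mean^2; average it across workers,
    # then subtract the squared group mean to get the global variance.
    mean_sq = np.mean(worker_vars + np.square(worker_means), axis=0)
    group_var = mean_sq - np.square(group_mean)
    return group_mean, group_var

# Simulate two workers holding equal-sized shards of one global batch.
rng = np.random.default_rng(0)
shards = [rng.normal(size=(8, 4)) for _ in range(2)]
means = np.stack([s.mean(axis=0) for s in shards])
variances = np.stack([s.var(axis=0) for s in shards])

g_mean, g_var = sync_moments(means, variances)

# The combined moments match those of the full concatenated batch.
full = np.concatenate(shards, axis=0)
assert np.allclose(g_mean, full.mean(axis=0))
assert np.allclose(g_var, full.var(axis=0))
```

This equal-batch-size assumption is also why a dynamic worker count (discussed below) needs care: with uneven shards, the plain average of worker moments no longer equals the global moments.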
We may want to make this work with dynamic worker count in a follow-up PR.
Yes, sounds good to me. Thanks for pointing that out and opening the issue.
I tested batch norm in TF 1.14 with graph mode, and it does not work. Why?
Checklist before submitting
Description
This PR adds a SyncBatchNormalization layer implementation using Horovod for TensorFlow. Fixes #2066.
Review process to land