Flax avg_pool also includes padding tokens when computing the mean #2401

Closed
marcvanzee opened this issue Aug 16, 2022 · 2 comments
Labels
Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)

Comments

@marcvanzee (Collaborator)

I am not sure if this is a bug, but it is definitely a possible source of confusion: we currently implement pooling.avg_pool as avg_pool(x) = lax.reduce_window(lax.add, x) / prod(window_shape). When padding is used, we always divide by the full window size, even if the window contains padding tokens.

Example:

import numpy as np
from flax.linen import avg_pool

xs = np.array([1, 2]).reshape((1, 2, 1))
avg_pool(xs, window_shape=(2,), strides=(1,), padding="SAME").reshape((2,))
# Result: [1.5, 1. ]

Is this what we want? The first result, (1+2)/2 = 1.5, makes sense, but the second result, 2/2 = 1.0, is a bit odd. Shouldn't it be 2/1 = 2?

Other frameworks handle this differently; see the follow-up comment below for how TensorFlow and PyTorch behave.

Personally I feel that including padding tokens with value 0 is wrong (0 seems like an arbitrary constant). At the very least we should be explicit about our choice and document it.

A possible way to implement average pooling that counts only non-padding tokens is to do a second sum pool over an all-ones array of the same shape, padded with 0s. Returning sum_pool / sum_pool2 then correctly ignores the padding tokens.
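
A minimal sketch of that idea, built directly on lax.reduce_window (the function name avg_pool_exclude_padding and its exact signature are illustrative assumptions, not Flax's actual API):

import jax.numpy as jnp
from jax import lax

def avg_pool_exclude_padding(x, window_shape, strides, padding="SAME"):
    # Illustrative sketch, not the Flax implementation. Assumes x has
    # shape (batch, spatial..., features), like flax.linen.avg_pool.
    dims = (1,) + window_shape + (1,)      # do not pool batch/feature dims
    strides = (1,) + strides + (1,)
    # Windowed sum of the inputs; padded positions contribute 0.
    summed = lax.reduce_window(x, 0.0, lax.add, dims, strides, padding)
    # Windowed count of real elements: pooling an all-ones array means
    # each padded position adds 0 to the count.
    counts = lax.reduce_window(jnp.ones_like(x), 0.0, lax.add,
                               dims, strides, padding)
    return summed / counts

xs = jnp.array([1.0, 2.0]).reshape((1, 2, 1))
avg_pool_exclude_padding(xs, (2,), (1,)).reshape((2,))
# Result: [1.5, 2. ]  -- the second window now divides by 1, not 2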

@marcvanzee (Collaborator, Author)

Discussed this offline with @jheek and @cgarciae. We agreed that the current behavior is not desirable: we assume that the padding tokens for avg_pool are 0s and include them when computing the average, but we do not document this anywhere. TensorFlow has chosen to implement this differently, namely by excluding the padding tokens, and similarly does not document this in its APIs. PyTorch seems to have the best of both worlds: it lets the user specify the behavior via a flag. This seems like something we could do as well.
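
For reference, PyTorch exposes this choice through the count_include_pad argument on its pooling layers (PyTorch uses explicit integer padding rather than "SAME"); the example input here mirrors the one above:

import torch

x = torch.tensor([[[1.0, 2.0]]])  # (batch, channels, length)
incl = torch.nn.AvgPool1d(kernel_size=2, stride=1, padding=1,
                          count_include_pad=True)
excl = torch.nn.AvgPool1d(kernel_size=2, stride=1, padding=1,
                          count_include_pad=False)
incl(x)  # tensor([[[0.5000, 1.5000, 1.0000]]]) -- padding counted in divisor
excl(x)  # tensor([[[1.0000, 1.5000, 2.0000]]]) -- divisor = real elements only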

@cgarciae added the Priority: P2 - no schedule label on Aug 26, 2022
@levskaya (Collaborator)

I'm going to close this since the changes in #2448 seem to address the main point; feel free to reopen.
