
Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose #1661

Merged
merged 4 commits into main on Nov 22, 2021

Conversation

sgrigory
Contributor

@sgrigory sgrigory commented Nov 5, 2021

What does this PR do?

This PR implements a CIRCULAR padding option for flax.linen.Conv and flax.linen.ConvTranspose, in addition to the existing VALID and SAME options. This allows creating convolutional and transposed-convolutional layers with periodic boundary conditions. The code added to flax.linen.Conv is based on the snippet from #903 (comment). Tests are also added for the new padding option.

Fixes #971, #903
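As a quick illustration of what CIRCULAR padding means (a plain-NumPy sketch, not the Flax API): the input is extended periodically before the convolution, so the layer sees the data as if it lived on a ring.

```python
import numpy as np

# Circular (periodic) padding: values wrap around the boundary,
# in contrast to SAME-style padding, which pads with zeros.
x = np.array([1, 2, 3, 4, 5])
print(np.pad(x, (1, 1), mode="wrap"))  # → [5 1 2 3 4 5 1]
print(np.pad(x, (1, 1)))               # zero padding: [0 1 2 3 4 5 0]
```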

Checklist

Additional note

Please note that this implementation differs slightly from the one suggested in #971 (comment): here the circular padding is applied symmetrically, from both sides of the input, as in #903 (comment).

For example, on the 1D data [1, 2, 3, 4, 5] with kernel size 3 and stride 3 there will be 2 filter applications: one at [5, 1, 2] and the other at [3, 4, 5], instead of [1, 2, 3] and [4, 5, 1] as suggested in #971 (comment). I feel that this is better aligned with how the other padding options, VALID and SAME, work.
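The example above can be checked with a small NumPy sketch (an illustration of the scheme only, not the actual Flax code; the helper name circular_conv1d is made up for this sketch): wrap-pad the input by roughly half a kernel on each side, then apply an ordinary strided VALID convolution.

```python
import numpy as np

def circular_conv1d(x, kernel, stride):
    # Symmetric circular padding: (k - 1) // 2 on the left and k // 2 on
    # the right, filled with wrapped (periodic) values, followed by a
    # plain strided sliding-window dot product (VALID convolution).
    k = len(kernel)
    padded = np.pad(x, ((k - 1) // 2, k // 2), mode="wrap")
    n_out = -(-len(x) // stride)  # ceil(len(x) / stride), as for SAME
    return np.array([padded[i * stride:i * stride + k] @ kernel
                     for i in range(n_out)])

x = np.array([1, 2, 3, 4, 5])
# Kernel size 3, stride 3: the two windows are [5, 1, 2] and [3, 4, 5].
print(circular_conv1d(x, np.ones(3, dtype=int), stride=3))  # → [ 8 12]
```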

@VolodyaCO, please let me know if that makes sense for your use case

@google-cla google-cla bot added the cla: yes label Nov 5, 2021
@jheek
Member

jheek commented Nov 5, 2021

Should ConvTranspose have a circular padding option as well?

@VolodyaCO

VolodyaCO commented Nov 5, 2021 via email

@grigory-sizov

Should ConvTranspose have a circular padding option as well?

@jheek That would make sense, at least that seems to be implemented in neural_tangents: https://github.com/google/neural-tangents/blob/9f2ebc88905c46d60b7c4a9da25636924acc9d45/neural_tangents/stax.py#L1449

Would you recommend adding it to this PR or first have this one reviewed and then create a follow-up?

@jheek
Member

jheek commented Nov 5, 2021

Would you recommend adding it to this PR or first have this one reviewed and then create a follow-up?

I think it makes sense to add it to both conv options in this PR. This way we make sure that both conv layers are consistent and that they are indeed transposable.

@grigory-sizov

It does make sense to me. I was just wondering, in your mini example, what if the filter size is 4 (or any even number)?

@VolodyaCO Indeed, if the filter size is even, one can't perfectly centre the filter around each element of the input, so there needs to be a convention: shift it by one element either forwards or backwards. The current implementation of SAME padding seems to choose to shift it forwards, and I follow the same logic.

So, for example, for input (1, 2, 3, 4, 5), kernel (1, 1, 1, 1), stride 3, and SAME padding I get (6, 12) as a result, which makes sense: 6 == 0 + 1 + 2 + 3, 12 == 3 + 4 + 5 + 0
If instead we use CIRCULAR padding, my implementation gives (11, 13), which again makes sense: 11 == 5 + 1 + 2 + 3, 13 == 3 + 4 + 5 + 1.

Again, as I said, shifting the kernel one element forwards or backwards for even-sized kernels is purely a convention; I just made sure it's consistent between the SAME and CIRCULAR paddings.
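Both results quoted above can be reproduced in NumPy (a sketch of the convention only; strided_dot is a hypothetical helper, not Flax code):

```python
import numpy as np

def strided_dot(padded, kernel, stride, n_out):
    # Slide the kernel over an already padded input with the given stride.
    k = len(kernel)
    return np.array([padded[i * stride:i * stride + k] @ kernel
                     for i in range(n_out)])

x = np.array([1, 2, 3, 4, 5])
kernel = np.ones(4, dtype=int)

# SAME: total zero padding (2 - 1) * 3 + 4 - 5 = 2, split as (1, 1).
print(strided_dot(np.pad(x, (1, 1)), kernel, stride=3, n_out=2))
# → [ 6 12]   (6 == 0+1+2+3, 12 == 3+4+5+0)

# CIRCULAR: wrap-pad ((k - 1) // 2, k // 2) = (1, 2); the even-sized
# kernel is shifted forwards, consistently with SAME above.
print(strided_dot(np.pad(x, (1, 2), mode="wrap"), kernel, stride=3, n_out=2))
# → [11 13]   (11 == 5+1+2+3, 13 == 3+4+5+1)
```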

@sgrigory
Contributor Author

sgrigory commented Nov 5, 2021

Would you recommend adding it to this PR or first have this one reviewed and then create a follow-up?

I think it makes sense to add it to both conv options in this PR. This way we make sure that both conv layers are consistent and that they are indeed transposable.

@jheek Ok, perfect, let me add it then

@VolodyaCO

VolodyaCO commented Nov 5, 2021 via email

@sgrigory sgrigory changed the title Add circular padding to flax.linen.Conv Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose Nov 9, 2021
@sgrigory
Contributor Author

sgrigory commented Nov 9, 2021

@jheek I've added circular padding to ConvTranspose as well - feel free to have a look

@@ -204,6 +204,174 @@ def test_group_conv(self):
self.assertEqual(initial_params['params']['kernel'].shape, (3, 2, 4))
np.testing.assert_allclose(y, np.full((1, 6, 4), 7.))

@parameterized.product(
Member

really extensive tests! Thanks!

rhs_dilation=self.kernel_dilation,
precision=self.precision)

if self.padding == "CIRCULAR":
Member

this is pretty complicated. Is there a reference for this?

Contributor Author

@jheek I didn't use any reference (except for verifying the results against neural_tangents.stax.ConvTranspose), but let me explain the logic here:
After the ConvTranspose has been computed with "VALID" padding into the variable y in this line, we need to

  1. identify the size of the final output along each spatial dimension (let's call it "period")
  2. pad each dimension to a certain integer number of periods
  3. wrap the array periodically around each spatial dimension

Step by step:

  1. The size of the final output is stride * input size.

  2. For padding, we need to work out how much to pad from the left/right for each spatial dimension. The padding should satisfy three criteria:

  • the total size after padding should be an integer number of periods
  • the padding should be symmetric (the same from the right and from the left)
  • the element corresponding to the beginning of the original input inside the padded array should be located at an integer number of periods (otherwise we'll get a correct answer, but circularly shifted)

The above is satisfied if we compute the padding as I did in the code:

  • Compute the difference between the size of y and the desired size of the final output
  • Compute the complement of this difference to an even number of periods
  • Divide the complement into two for the left and right padding

  3. After the padding is done, we can wrap the array around each spatial dimension. I did this separately for each dimension, reshaping the array into (..., -1, period, ...) and summing over the corresponding axis. Let me know if there is a simpler way to do it; I haven't found one (see e.g. https://stackoverflow.com/questions/42297115/numpy-split-cube-into-cubes)

There is a subtlety with even-sized kernels: the choice to write ((size_diff + 1) // 2, size_diff // 2) (and not (size_diff // 2, (size_diff + 1) // 2)) here reflects the alignment convention I talked about in a comment above.
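The pad-and-wrap procedure described above can be condensed into a 1D NumPy sketch (my own reading of the steps under the stated conventions, not the literal PR code; the function name is made up):

```python
import numpy as np

def wrap_valid_transpose_output(y, n_in, stride):
    # Fold the VALID transposed-convolution output y back onto a
    # periodic domain of size period = stride * n_in.
    period = stride * n_in
    # Complement the size difference up to an even number of periods so
    # that the zero padding below is symmetric and keeps the start of
    # the original input at an integer number of periods.
    size_diff = -(len(y) - period) % (2 * period)
    pad = ((size_diff + 1) // 2, size_diff // 2)  # forward-shift convention
    y = np.pad(y, pad)
    # Wrap: reshape into (-1, period) and sum the periods together.
    return y.reshape(-1, period).sum(axis=0)

# Check: input [1, 2, 3, 4, 5], kernel [1, 1, 1], stride 1.
# np.convolve(..., mode="full") plays the role of the VALID ConvTranspose.
y = np.convolve(np.array([1, 2, 3, 4, 5]), np.ones(3, dtype=int), mode="full")
print(wrap_valid_transpose_output(y, n_in=5, stride=1))  # → [ 8  6  9 12 10]
```

The first output element is 8 == 5 + 1 + 2, i.e. the kernel wraps around the boundary exactly as a circular convolution should.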

What's the best way to proceed, should I add some of this description to inline comments?

Member

Yes, I think it makes sense to add (a summary of) this in inline comments

Contributor Author

@jheek Sure, I've expanded the comments in ConvTranspose - please have a look

Member

Thanks, looks great!

@codecov-commenter

codecov-commenter commented Nov 10, 2021

Codecov Report

Merging #1661 (c4e1250) into main (6da4a00) will increase coverage by 0.06%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main    #1661      +/-   ##
==========================================
+ Coverage   83.10%   83.17%   +0.06%     
==========================================
  Files          69       69              
  Lines        5836     5853      +17     
==========================================
+ Hits         4850     4868      +18     
+ Misses        986      985       -1     
Impacted Files Coverage Δ
flax/linen/linear.py 99.01% <100.00%> (+0.62%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@sgrigory sgrigory requested a review from jheek November 16, 2021 20:56
@copybara-service copybara-service bot merged commit df26680 into google:main Nov 22, 2021

Successfully merging this pull request may close these issues.

Circular padding in convolutional neural networks