
adding flag option to only pad the last dim #185

Merged: 12 commits, Aug 11, 2023

Conversation

@btolooshams (Contributor):

Added a pad_only_last_dim flag to the DomainPadding class (default False) and updated the FNO class with a corresponding domain_padding_only_last_dim flag (default False). If set to True, padding is applied only to the last dimension of the input.
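For reference, a minimal usage sketch of the flag as described above; the import path, mode string, and exact constructor signature are assumptions, not code from this PR:

```python
import torch

from neuralop.layers.padding import DomainPadding  # import path is an assumption

# Hypothetical usage of the originally proposed flag; the keyword and the
# constructor signature are assumptions based on the description above.
data = torch.randn((2, 3, 10, 10))                 # (batch, channels, d0, d1)
padder = DomainPadding(0.2, 'one-sided', pad_only_last_dim=True)
padded = padder.pad(data)                          # only the last dim is padded
print(data.shape, '->', padded.shape)              # e.g. (2, 3, 10, 10) -> (2, 3, 10, 12)
```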

@JeanKossaifi (Member):

Thanks @btolooshams. Let's make it more general, how about having two use cases:

  • padding is True: in that case set padding = range(...) so that we pad all dimensions.
  • padding is a list of dimensions (for your use case that would be [4]): in that case we pad only those dimensions.

The only question is whether we count the batch and channel dimensions when numbering the dimensions (dim 0 would then be the batch size). Otherwise, 0 would be the first spatial dim.

What do you think? Could you make the change for the padding and update the docstrings in the models?
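A rough sketch of the two cases proposed above; the helper name and logic below are purely illustrative, not the library's API:

```python
def resolve_padding_dims(padding, n_spatial_dims):
    """Illustrative helper (hypothetical): turn the `padding` argument into
    the list of spatial dimensions to pad.

    - padding is True           -> pad every spatial dimension
    - padding is a list of ints -> pad only those dimensions

    Whether batch/channel dims are counted is the open question above;
    here 0 is taken to be the first spatial dim.
    """
    if padding is True:
        return list(range(n_spatial_dims))
    return list(padding)

print(resolve_padding_dims(True, 3))  # [0, 1, 2]: pad all spatial dims
print(resolve_padding_dims([2], 3))   # [2]: pad only the last spatial dim
```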

@btolooshams (Contributor Author):

I have modified the previous pull request. "pad_only_last_dim" is now changed to "padding_dim". If True, all dimensions (0, 1, 2), excluding batch and channel, will be padded; if "padding_dim" is a list, only the dimensions in the list will be padded.

An additional fix: F.pad applies padding in reverse order (starting from the last dimension), so I have added a line to reverse the padding list. Padding was being applied incorrectly before.
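To illustrate the F.pad ordering mentioned above (this is standard torch.nn.functional.pad behavior, not code from this PR):

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 4, 6)   # (batch, channels, height, width)

# F.pad reads the pad tuple starting from the *last* dimension:
# (left, right, top, bottom) pads width first, then height.
padded = F.pad(x, (1, 1, 2, 2))
print(padded.shape)           # torch.Size([1, 1, 8, 8])

# So a per-dimension padding list built in order (d0, d1, ...) has to be
# reversed before being flattened into F.pad's argument.
```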

@btolooshams (Contributor Author):

The current implementation excludes the batch and channel dimensions (0, 1) when counting which dimensions to pad, so 0 is the first spatial dimension.
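Concretely, with the naming at this point in the thread (the padding_dim keyword, the import path, and the exact signature are assumptions):

```python
import torch

from neuralop.layers.padding import DomainPadding  # import path is an assumption

x = torch.randn(2, 3, 8, 8, 8)   # (batch, channels, d0, d1, d2)

# Dimension 2 below means d2, the last spatial dimension; the batch and
# channel dimensions are never counted or padded.
padder = DomainPadding(0.25, 'one-sided', padding_dim=[2])  # hypothetical signature
print(padder.pad(x).shape)       # only d2 grows, e.g. (2, 3, 8, 8, 10)
```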

@JeanKossaifi (Member):

Thank you @btolooshams, this is great! Could we call it just padding?
It would also be great to add a quick test case to the unit tests, as a sanity check and to make sure it remains correct in the future.

@JeanKossaifi (Member) left a comment:

Thanks @btolooshams! I like the API like this; I left a few comments, mostly just style.

  padded = padder.pad(data)

  target_shape = list(padded.shape)
- target_shape[-1] = target_shape[-2] = out_size[mode]
+ for ctr in range(1, 3):
+     target_shape[-ctr] = out_size[mode][-ctr]
Member:

Can we just write this out, with a small inline comment, to make it easier to read? Also, I don't know what ctr means; it's good to use clear names.

Contributor Author:

"ctr" is changed to "pad_dim", and comment added.

  data = torch.randn((2, 3, 10, 10))
- padder = DomainPadding(0.2, mode)
+ padder = DomainPadding(padding_dim, mode)
Member:

Why do we call it padding_dim here?

Contributor Author:

"padding_dim" is renamed to "padding".

unpad_list = list()
for p in output_pad[::-1]:
    if p == 0:
        unpad_amount_pos = None
Member:

Why not leave 0 in this case?

Contributor Author:

slice behaves incorrectly for p=0 (slicing from p to -p with p=0 selects nothing). So if the padding is 0, p is set to None, and no slicing is applied to that dimension when unpadding.
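A quick illustration of why p=0 needs special-casing (plain Python/PyTorch slicing, not code from this PR):

```python
import torch

x = torch.arange(10)

p = 0
print(x[slice(p, -p)])         # tensor([]): slicing 0:-0 is 0:0 and selects nothing
print(x[slice(None, None)])    # full tensor: None means "do not trim this end"

p = 2
print(x[slice(p, -p)])         # tensor([2, 3, 4, 5, 6, 7]): trims p from each end
```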

    else:
        unpad_amount_pos = p
        unpad_amount_neg = -p
    unpad_list.append(slice(unpad_amount_pos, unpad_amount_neg, None))
Member:

I actually don't mind spelling it out but can we call it padding_start and padding_end or something clearer? :)

Contributor Author:

name modified :)
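Roughly what the renamed snippet might look like (illustrative; the exact code in the PR may differ, and the p == 0 branch is completed here by assumption):

```python
unpad_list = []
for p in output_pad[::-1]:
    if p == 0:
        padding_start = None   # no padding was added on this dim: keep it whole
        padding_end = None
    else:
        padding_start = p      # drop p entries from the start...
        padding_end = -p       # ...and p entries from the end
    unpad_list.append(slice(padding_start, padding_end, None))
```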

if isinstance(self.domain_padding, (float, int)):
    self.domain_padding = [float(self.domain_padding)] * len(resolution)

assert len(self.domain_padding) == len(resolution)
Member:

Good check to have! Can you add a message for when that's not the case, to explain the issue to users?

Contributor Author:

An explanation is added: "domain_padding length must match the number of spatial/time dimensions (excluding batch, ch)".
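The check with its message would then read roughly as follows (illustrative):

```python
if isinstance(self.domain_padding, (float, int)):
    self.domain_padding = [float(self.domain_padding)] * len(resolution)

assert len(self.domain_padding) == len(resolution), (
    "domain_padding length must match the number of spatial/time dimensions "
    "(excluding batch, ch)"
)
```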

padded = padder.pad(data)

target_shape = list(padded.shape)
for ctr in range(1,4):
Member:

Same comment as above: ideally tests should be self-explanatory and easy to follow just by reading the code!

@JeanKossaifi (Member):

Thank you @btolooshams, this is great, merging!

@JeanKossaifi merged commit 68d90b6 into neuraloperator:main on Aug 11, 2023. 1 check passed.
ziqi-ma pushed a commit to ziqi-ma/neuraloperator that referenced this pull request on Aug 29, 2023: "adding flag option to only pad the last dim".