-
Notifications
You must be signed in to change notification settings - Fork 503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactors SpectralConv for simpler FNO #244
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, in the new API users should just have to set n_modes
(instead of incremental_n_modes
) to dynamically change the number of Fourier modes , right? Then what's max_n_modes
, and how did we fulfill that functionality before this refactoring?
if isinstance(n_modes, int): # Should happen for 1D FNO only | ||
n_modes = [n_modes] | ||
else: | ||
n_modes = list(n_modes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have more explicit type checking here? Something like:
if isinstance(n_modes, int): # Should happen for 1D FNO only | |
n_modes = [n_modes] | |
else: | |
n_modes = list(n_modes) | |
if isinstance(n_modes, int): # Should happen for 1D FNO only | |
n_modes = [n_modes] | |
elif all([isinstance(m, int) for m in n_modes]): # Assumes n_modes is a populated iterable | |
n_modes = list(n_modes) | |
else: | |
raise ValueError(msg) |
Otherwise a model choking on n_modes
further in execution could be pretty confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assumed it was clear modes should be int, it's also in the docstring: I don't check if it's an int for type check - just to allow users to call FNO(mode, ...) for the 1D FNO rather than FNO( (mode,), ...)
Number = Union[int, float] | ||
|
||
|
||
class SpectralConv(BaseSpectralConv): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we want to keep this legacy code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case we want to revert - for specific use cases. That version is harder to read but actually faster in some cases. Some users might want it for that reason
else: | ||
n_modes = list(n_modes) | ||
# The last mode has a redundacy as we use real FFT | ||
# As a design choice we do the operation here to avoid users dealing with the +1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is documented in the class docstring, though that's where users would readily see it to know about this design. Consider adding this in the docstring.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I want to polish the PR a little more
slice(None), # Equivalent to: [:, | ||
slice(None), # ............... :, | ||
slice(self.half_n_modes[0]), # :half_n_modes[0]] | ||
slice(None, self.n_modes[0]), # :half_n_modes[0]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: alignment
slice(None), # Equivalent to: [:, | |
slice(None), # ............... :, | |
slice(self.half_n_modes[0]), # :half_n_modes[0]] | |
slice(None, self.n_modes[0]), # :half_n_modes[0]] | |
slice(None), # Equivalent to: [:, | |
slice(None), # ................ :, | |
slice(None, self.n_modes[0]), # :n_modes[0]] |
Co-authored-by: Mogab <76572666+m4e7@users.noreply.github.com>
The goal was to make the role of the variable clear from their names: max_n_modes is the max possible number of mode, needs to be specified before we create the weights. n_modes is variable and can be changed during training. |
No description provided.