
Add low-precision to TFNO #172

Merged: 12 commits, Jul 11, 2023

Conversation

crwhite14 (Collaborator) commented:

This pull request adds low-precision options for TFNO. Here is a brief summary.

  1. opt.amp_autocast was previously in the yaml file but was a no-op. This PR gets opt.amp_autocast working (see the training-loop sketch after this list).
  • False (default): run the model in full precision
  • True: turn on torch.amp.autocast, torch's built-in mixed-precision mode. This runs many operations in half precision, with a few notable exceptions: reductions, weight updates (which are unstable in half precision), and complex-valued operations (for which half precision is not implemented).
  2. Add a new parameter to the yaml file, fno_block_precision (see the spectral-convolution sketch after this list).
  • 'full' (default): standard full precision
  • 'half': the FFT, contraction, and inverse FFT run in half precision
  • 'mixed': the FFT runs in full precision, and the contraction and inverse FFT run in half precision
  3. Add a new parameter to the yaml file, stabilizer.
  • None (default): no stabilizer
  • 'tanh': applies a tanh just before the FFT in the FNO block. This typically needs to be set when fno_block_precision='half' to keep the FFT from overflowing.
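As a rough illustration, here is a minimal sketch of how an opt.amp_autocast flag typically plugs into a PyTorch training loop; the names (train_epoch, loader, loss_fn) are illustrative, not the exact ones used in train_navier_stokes.py:

    import torch

    def train_epoch(model, loader, loss_fn, optimizer, amp_autocast=False):
        # GradScaler keeps the weight update itself in full precision
        scaler = torch.cuda.amp.GradScaler(enabled=amp_autocast)
        for x, y in loader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            # autocast runs most ops in float16; reductions and complex-valued
            # ops (e.g. FFTs) stay in float32
            with torch.autocast(device_type='cuda', enabled=amp_autocast):
                out = model(x)
                loss = loss_fn(out, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()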
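And a hedged sketch of where fno_block_precision and stabilizer act inside a spectral convolution. This follows the PR description rather than the exact library code; spectral_conv and weight are illustrative names:

    import torch

    def spectral_conv(x, weight, fno_block_precision='full', stabilizer=None):
        if stabilizer == 'tanh':
            # bound activations before the FFT; typically needed when
            # fno_block_precision == 'half' to avoid float16 overflow
            x = torch.tanh(x)
        if fno_block_precision == 'half':
            x = x.half()  # the FFT itself then runs in half precision
        x_ft = torch.fft.rfftn(x, dim=(-2, -1))
        if fno_block_precision == 'mixed':
            # the FFT ran in full precision; cast so the contraction and
            # inverse FFT run in (complex) half precision
            x_ft = x_ft.chalf()
        # contraction with the complex spectral weights
        out_ft = x_ft * weight.to(x_ft.dtype)
        return torch.fft.irfftn(out_ft, s=x.shape[-2:], dim=(-2, -1))

Note that half-precision FFTs in PyTorch require CUDA and power-of-two signal sizes, which is consistent with the speedups below being largest at 64x64 resolution and above.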

Running:

python train_navier_stokes.py --opt.amp_autocast=True --tfno2d.fno_block_precision='half' --tfno2d.stabilizer='tanh'
python train_navier_darcy.py --opt.amp_autocast=True --tfno2d.fno_block_precision='half' --tfno2d.stabilizer='tanh'

Improves runtime and memory usage by up to 30%, depending on the GPU, the resolution of the data (greater speedups at 64x64 resolution or higher), and other hyperparameters such as factorization and rank.

@JeanKossaifi (Member) left a comment:

This looks great, thanks!

    if self.fno_block_precision == 'half':
        x = x.half()
    else:
        x = x.float()
@JeanKossaifi (Member) commented on this diff:

Should we remove this one, or do you think we need to always explicitly cast here @crwhite14 @rtu715?

@crwhite14 (Collaborator, Author) replied:

We can remove the else: x = x.float() branch.
I just did that in this commit: bcc52e7#diff-5ae1e49af12ed16c75135c0043a08575110fd03d4c722a837b60aa0950b31e32L342-L343
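For reference, a sketch of the remaining cast after that commit, assuming only the half branch survives:

    if self.fno_block_precision == 'half':
        x = x.half()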

@JeanKossaifi (Member):
Thanks, great PR @crwhite14 @rtu715, merging!

@JeanKossaifi merged commit 1051112 into neuraloperator:main on Jul 11, 2023
1 check passed
ziqi-ma pushed a commit to ziqi-ma/neuraloperator that referenced this pull request Aug 29, 2023