
Improved fastmath code generation for trig, log, and exp/pow. #6619

Merged
merged 10 commits from testhound/cuda_fast_math into numba:master on Feb 23, 2021

Conversation

testhound (Contributor)

This pull request adds fastmath code generation support for the following functions:

sin, cos, tan, log, log2, log10, exp, and pow, when the fastmath option is used with 32-bit floating-point types.

This pull request addresses the poor code generation identified in #6183.
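
For illustration only (not part of this PR's diff), a minimal sketch of the kind of kernel affected; the kernel and variable names are arbitrary, and running it requires a CUDA-capable GPU:

import math

import numpy as np
from numba import cuda


# With fastmath=True and float32 operands, calls such as math.sin and math.exp
# below are the ones this PR lowers to the fast libdevice intrinsics
# (e.g. __nv_fast_sinf) instead of the precise versions.
@cuda.jit("void(float32[::1], float32[::1])", fastmath=True)
def fast_kernel(out, x):
    i = cuda.grid(1)
    if i < x.size:
        out[i] = math.sin(x[i]) + math.exp(x[i])


x = np.linspace(0, 1, 256, dtype=np.float32)
out = np.zeros_like(x)
fast_kernel[1, 256](out, x)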

testhound changed the title from "Testhound/cuda fast math" to "Improved fastmath code generation for trig, log, and exp/pow." on Jan 5, 2021
esc added the "3 - Ready for Review" and "CUDA" (CUDA related issue/PR) labels on Jan 5, 2021
esc (Member) commented on Jan 5, 2021

@testhound thank you for submitting this! I have added it to the queue for review.

stuartarchibald added the "Effort - medium" (Medium size effort needed) label on Jan 5, 2021
stuartarchibald (Contributor) left a comment

Thanks for the patch; it's great to see these fast math functions added. I've made a couple of suggestions, but in general it looks good. I think this probably needs documenting, both in the CUDA docs and in the cuda.jit docstrings. Thanks again!

@@ -76,8 +88,20 @@ def lower_boolean_impl(context, builder, sig, args):

 def impl_unary(key, ty, libfunc):
     def lower_unary_impl(context, builder, sig, args):
-        libfunc_impl = context.get_function(libfunc, typing.signature(ty, ty))
-        return libfunc_impl(builder, args)
+        if ty == float32 and context.fastmath is True:
stuartarchibald (Contributor):

Suggested change:
-if ty == float32 and context.fastmath is True:
+if ty == float32 and context.fastmath:

testhound (Contributor, Author):

@stuartarchibald thanks for the review, I have made the code changes locally. Can you point me to the documentation files that need to be changed?

Comment on lines 92 to 103
    fast_replacement = unarys_fastmath.get(libfunc.__name__)
    if fast_replacement is None:
        libfunc_impl = context.get_function(libfunc,
                                            typing.signature(ty, ty))
    else:
        new_libfunc = getattr(libdevice, fast_replacement)
        libfunc_impl = context.get_function(new_libfunc,
                                            typing.signature(ty, ty))
    return libfunc_impl(builder, args)
else:
    libfunc_impl = context.get_function(libfunc,
                                        typing.signature(ty, ty))
stuartarchibald (Contributor):

I think the 3x libfunc_impl = context.get_function(...) can be pulled out, as all paths use it; then the branch is just on 32-bit float with fastmath, switching the library via libfunc = getattr(libdevice, fast_replacement). Something like the (entirely untested!) code below, perhaps? What do you think?

if ty == float32 and context.fastmath:
    fast_replacement = unarys_fastmath.get(libfunc.__name__)
    if fast_replacement is not None:
        libfunc = getattr(libdevice, fast_replacement)

libfunc_impl = context.get_function(libfunc,
                                    typing.signature(ty, ty))
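
For context, unarys_fastmath (and the analogous binarys_fastmath used below) is the lookup table the PR adds to map precise libdevice function names to their fast counterparts. The exact contents below are inferred from the functions listed in the PR description, not copied verbatim from the diff:

# Assumed shape of the lookup tables: keys are the precise libdevice function
# names (libfunc.__name__), values are the fast replacements resolved via
# getattr(libdevice, ...).
unarys_fastmath = {
    'cosf': 'fast_cosf',
    'sinf': 'fast_sinf',
    'tanf': 'fast_tanf',
    'expf': 'fast_expf',
    'log2f': 'fast_log2f',
    'log10f': 'fast_log10f',
    'logf': 'fast_logf',
}

binarys_fastmath = {
    'powf': 'fast_powf',
}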

Comment on lines 129 to 139
if ty == float32 and context.fastmath is True:
    fast_replacement = binarys_fastmath.get(libfunc.__name__)
    if fast_replacement is None:
        libfunc_impl = context.get_function(libfunc,
                                            typing.signature(ty,
                                                             ty, ty))
    else:
        new_libfunc = getattr(libdevice, fast_replacement)
        libfunc_impl = context.get_function(new_libfunc,
                                            typing.signature(ty,
                                                             ty, ty))
stuartarchibald (Contributor):

Same comment/refactor suggestion as above.
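
Applied to the binary path, that refactor might look something like the following (equally untested, and assuming the surrounding lower_binary_impl structure):

# Sketch only, mirroring the unary suggestion above for the binary lowering.
if ty == float32 and context.fastmath:
    fast_replacement = binarys_fastmath.get(libfunc.__name__)
    if fast_replacement is not None:
        libfunc = getattr(libdevice, fast_replacement)

libfunc_impl = context.get_function(libfunc,
                                    typing.signature(ty, ty, ty))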

stuartarchibald added the "4 - Waiting on author" (Waiting for author to respond to review) label and removed the "3 - Ready for Review" label on Jan 8, 2021
stuartarchibald (Contributor):

RE #6619 (comment): I've taken a look through the current CUDA docs and I can't seem to find a section talking about the various options available to the cuda.jit decorator and the impact they have (it might be nice in future to have an FAQ entry https://numba.readthedocs.io/en/latest/cuda/faq.html or something to add under kernel declaration options here https://numba.readthedocs.io/en/latest/cuda/kernels.html#kernel-declaration). For now, I think it's just the fastmath part of this section that needs an update: https://numba.readthedocs.io/en/latest/cuda-reference/kernel.html#kernel-declaration, which I think comes from here:

:param fastmath: If true, enables flush-to-zero and fused-multiply-add,
    disables precise division and square root. This parameter has no effect
    on device functions, whose fastmath setting depends on the kernel function
    from which they are called.

Having reviewed the Numba documentation, I've also looked at the nvcc docs for "fast math", and I think that what this PR provides is a deviation from that behaviour. As a result, one concern I now have with this PR is whether it would cause regressions in existing code bases due to a behavioural change in an existing configuration, i.e. fastmath=True. It makes me wonder whether low-precision/fast-math intrinsic functions ought to be activated via a different option, or whether fastmath should be a dictionary like:

{
    'ftz': <bool>,
    'prec-div': <bool>,
    'prec-sqrt': <bool>,
    'fmad': <bool>,
    'fast_libm': <bool>,  # this being the thing that switches on the fast math intrinsic functions etc. as wired in in this PR
}

On the CPU target, fastmath is similar to this: it can be True/False, but it can also be specifically configured.

class FastMathOptions(object):
    """
    Options for controlling fast math optimization.
    """

    def __init__(self, value):
        # https://releases.llvm.org/7.0.0/docs/LangRef.html#fast-math-flags
        valid_flags = {
            'fast',
            'nnan', 'ninf', 'nsz', 'arcp',
            'contract', 'afn', 'reassoc',
        }
        if value is True:
            self.flags = {'fast'}
        elif value is False:
            self.flags = set()
        elif isinstance(value, set):
            invalid = value - valid_flags
            if invalid:
                raise ValueError("Unrecognized fastmath flags: %s" % invalid)
            self.flags = value
        elif isinstance(value, dict):
            invalid = set(value.keys()) - valid_flags
            if invalid:
                raise ValueError("Unrecognized fastmath flags: %s" % invalid)
            self.flags = {v for v, enable in value.items() if enable}
        else:
            msg = "Expected fastmath option(s) to be either a bool, dict or set"
            raise ValueError(msg)

    def __bool__(self):
        return bool(self.flags)

    __nonzero__ = __bool__
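
For instance, a quick sketch of how the class above interprets the different input forms (using the class as quoted here, rather than importing it from a particular Numba module):

# Exercising the FastMathOptions class quoted above.
assert FastMathOptions(True).flags == {'fast'}
assert FastMathOptions(False).flags == set()
assert FastMathOptions({'nnan', 'contract'}).flags == {'nnan', 'contract'}
assert FastMathOptions({'nnan': True, 'ninf': False}).flags == {'nnan'}
assert not bool(FastMathOptions(False))
# On the CPU target this is what e.g. @njit(fastmath={'nnan', 'contract'})
# ends up constructing.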

What do you think?

testhound (Contributor, Author):

@stuartarchibald interesting suggestion. Let me digest and respond later.

gmarkall (Member):

> Having reviewed the Numba documentation, I've also looked at the nvcc docs for "fast math" and think that what this PR is providing is a deviation from this behaviour.

As discussed out-of-band earlier, I think this is bringing Numba's behaviour into line with NVCC. An example of NVCC using a less precise cos implementation with the fast math flag can be seen in: https://github.com/gmarkall/nvcc-fastmath

> As a result, one concern I now have with this PR is over whether it would cause regressions in existing code bases due to a behavioural change in an existing configuration i.e. fastmath=True.

I think any code that was passing fastmath=True should not be relying on the usual precision guarantees - so if it does break some existing code, then that code was probably not safe to use with the fastmath flag.

> It makes me wonder if low precision/fast math intrinsic functions ought to be activated via a different option, or that fastmath should be a dictionary like:

In light of the above, would you agree that the fastmath dictionary suggestion wouldn't be required now?

gmarkall (Member) left a comment

This is looking great so far!


fastver = cuda.jit("void(float32[::1], float32)", fastmath=True)(f4)
slowver = cuda.jit("void(float32[::1], float32)")(f4)
self.assertNotIn('fma.rn.f32 ', fastver.ptx)
gmarkall (Member):

Noting here that this checks that the fast version doesn't have any FMA instructions in it. I noticed that the fast version makes use of ex2.approx.ftz.f32, which could be checked for instead, but I also think the test as-is is sufficient, because there shouldn't be any FMA instructions in the fast version.
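
For example, that alternative check would look roughly like the following (a sketch only, reusing fastver from the snippet above; untested):

# Assert that the fast version uses the approximate exp2 instruction, in
# addition to (or instead of) asserting the absence of FMA instructions.
self.assertIn('ex2.approx.ftz.f32 ', fastver.ptx)
self.assertNotIn('fma.rn.f32 ', fastver.ptx)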

testhound (Contributor, Author):

@gmarkall I have made and verified the first of the two changes you requested. I am unclear how to make the change to numba/cuda/decorators.py that will update the documentation; specifically, I updated 'decorators.py', but after generating the documentation with 'make html', I do not get an updated local page for 'CUDA Kernel API'. Is there another command to update this portion of the documentation, or did I update the wrong file?

gmarkall (Member):

I think Sphinx isn't very good at detecting when docstrings change. Every time I change something documentation-related, I run make clean html, which is a bit inconvenient because it takes a while to build from scratch, but I haven't found a better way.

If you do this, does your change now show up?


fastver = cuda.jit("void(float32[::1], float32)", fastmath=True)(f5)
slowver = cuda.jit("void(float32[::1], float32)")(f5)
self.assertIn('lg2.approx.ftz.f32 ', fastver.ptx)
gmarkall (Member):

(I guess this test is taking the approach that could have been used for exp)

gmarkall (Member) left a comment

Many thanks for the documentation additions - I have some further suggestions for the docs, which I've pushed in the commit gmarkall@f6e735f. To summarise the suggestions, they are:

  • Fix the docs build by moving the cuda-fast-math label ahead of the section title,
  • Add links to the documentation for the underlying libdevice functions and NVVM optimizations,
  • Explicitly state the math module functions affected by the transformation,
  • Refer to the docs from the docstring of the jit decorator for fastmath, as there is now a bit too much to easily summarise in the docstring.

Feel free to pull in any / all of the changes in the linked commit, or do let me know what you think of the suggestions.

Following resolution of the documentation suggestions, I think this will be all looking good!

gmarkall removed the "4 - Waiting on author" (Waiting for author to respond to review) label on Feb 17, 2021
gmarkall previously approved these changes on Feb 17, 2021
gmarkall (Member) left a comment

Many thanks for the update! This looks good to me!

@esc could this have a buildfarm run please? (The typeguard failure is a general issue, not related to this branch.)

gmarkall added the "4 - Waiting on CI" (Review etc. done, waiting for CI to finish) and "Pending BuildFarm" (For PRs that have been reviewed but pending a push through our buildfarm) labels on Feb 17, 2021
esc (Member) commented on Feb 18, 2021

Running on Farm as: numba_smoketest_cuda_yaml_13

esc added the "BuildFarm Passed" (For PRs that have been through the buildfarm and passed) label and removed the "Pending BuildFarm" label on Feb 18, 2021
esc (Member) commented on Feb 18, 2021

Build farm was fine: numba_smoketest_cuda_yaml_13.

gmarkall removed the "4 - Waiting on CI" (Review etc. done, waiting for CI to finish) label on Feb 18, 2021
gmarkall (Member):

@stuartarchibald @sklam Are you happy with this going RTM?

Comment on lines +29 to +36
- :func:`math.cos`: Implemented using `__nv_fast_cosf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_cosf.html>`_.
- :func:`math.sin`: Implemented using `__nv_fast_sinf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_sinf.html>`_.
- :func:`math.tan`: Implemented using `__nv_fast_tanf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_tanf.html>`_.
- :func:`math.exp`: Implemented using `__nv_fast_expf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_expf.html>`_.
- :func:`math.log2`: Implemented using `__nv_fast_log2f <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_log2f.html>`_.
- :func:`math.log10`: Implemented using `__nv_fast_log10f <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_log10f.html>`_.
- :func:`math.log`: Implemented using `__nv_fast_logf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_logf.html>`_.
- :func:`math.pow`: Implemented using `__nv_fast_powf <https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_fast_powf.html>`_.
stuartarchibald (Contributor):

Most of the :func:`math.FOO` references link to the developer/autogen_math_listing pages, and a couple to the intended Python docs. Not sure what can be done apart from explicit linking. See the rendered doc: https://numba--6619.org.readthedocs.build/en/6619/cuda/fastmath.html

stuartarchibald (Contributor):

This is a Numba doc problem, not for fixing in this PR. Deferring to ticket #6737.

stuartarchibald (Contributor):

> @stuartarchibald @sklam Are you happy with this going RTM?

Once #6619 (comment) is resolved, patch looks good. Thanks for working on this @testhound, thanks for reviewing @gmarkall.

stuartarchibald added the "4 - Waiting on author" (Waiting for author to respond to review) label on Feb 18, 2021
gmarkall added and then removed the "4 - Waiting on author" (Waiting for author to respond to review) label on Feb 18, 2021
stuartarchibald (Contributor):

/AzurePipelines run

azure-pipelines:

Azure Pipelines successfully started running 1 pipeline(s).

stuartarchibald added the "5 - Ready to merge" (Review and testing done, is ready to merge) label and removed the "4 - Waiting on author" label on Feb 19, 2021
stuartarchibald (Contributor):

@testhound please could you resolve the conflicts when you have a moment? Many Thanks.

testhound (Contributor, Author):

> @testhound please could you resolve the conflicts when you have a moment? Many Thanks.

@stuartarchibald I just resolved the conflicts.

stuartarchibald (Contributor) left a comment

Thanks for resolving conflicts, looks good.

sklam merged commit d825860 into numba:master on Feb 23, 2021
testhound deleted the testhound/cuda_fast_math branch on June 1, 2021
Labels: 5 - Ready to merge · BuildFarm Passed · CUDA · Effort - medium

5 participants