CUDA: Don't optimize IR before sending it to NVVM #6030

Merged - 10 commits merged into numba:master on Jul 30, 2020

Conversation

gmarkall (Member)

There are various issues created by optimizing the IR before sending it to NVVM, in particular:

  • Issue #5576: Optimizing the IR before sending it to NVVM results in code that makes out-of-bounds accesses.
  • Issue #6022: Optimization generates memset intrinsics incompatible with LLVM 3.4 (NVVM).
  • IR auto-upgrade breaks the IR for LLVM 3.4 due to changes in atomic instructions. This requires a patch to the LLVM version used for llvmlite, which is problematic because many distributions won't carry this patch (ref: numba/llvmlite#593).

Optimizing IR prior to sending it to NVVM was originally added to work around an NVVM bug documented as Issue #1341. The cause of Issue #1341 has been resolved in NVVM since CUDA 7.5: the test script at https://gist.github.com/sklam/f62f1f48bb0be78f9ceb now executes in a fraction of a second and consumes around 100MB of RAM, whereas the bug manifested itself as a long compilation time with gigabytes of RAM consumed.

There are some side-effects to not optimizing before sending to NVVM, which are reflected in the tests:

  • TestCudaConstantMemory.test_const_array_3d: Instead of reading the complex value as a vector of u32s, there are two individual loads of the real and imaginary parts as f32s. This is not expected to have a negative impact on performance, since coalesced reading of f32s will use memory bandwidth effectively.
  • TestCudaConstantMemory.test_const_record: The generated code here appears less optimal than it previously was - instead of loading three 8-bit members in a single 32-bit read, the reads are now individual 8-bit reads.
  • TestCudaConstantMemory.test_const_record_align: A similar issue to test_const_record.
  • TestNvvmWithoutCuda.test_nvvm_memset_fixup: Now that IR is not optimized prior to sending to NVVM, memsets no longer appear in the IR. It is desirable to keep the memset fixup functionality and test so that Numba and its extensions will be able to use memsets (e.g. from cgutils.memset), so this test is modified to use strings of IR rather than relying on memsets naturally being produced as part of the compilation process. It is possible to generate a memset without an alignment attribute on the destination with cgutils.memset, so a check for this error and a test of the check are also added (a sketch of such a check appears after this description).

For the less optimal code generated in TestCudaConstantMemory, the tests for the optimal code have been moved into new test cases marked as expected failures - it seems better to keep these tests and note the sub-optimal behaviour than to delete them, in the expectation that a future version of NVVM may improve on the code it currently generates.
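As an illustration of the alignment check mentioned in the memset bullet above, here is a minimal sketch in Python; the function name, the regex, and the error message are assumptions for illustration only, not the PR's actual code:

```python
import re

# Matches new-style llvm.memset calls, e.g.
#   call void @llvm.memset.p0i8.i64(i8* align 8 %p, i8 0, i64 40, i1 false)
MEMSET_CALL = re.compile(r"call void @llvm\.memset\.[^(]+\((?P<args>[^)]*)\)")

def check_memset_alignment(llvm_ir: str) -> None:
    """Raise if a memset destination lacks an `align` attribute, since a
    rewrite to the older form expected by LLVM 3.4 (NVVM) needs an explicit
    alignment to insert."""
    for match in MEMSET_CALL.finditer(llvm_ir):
        destination = match.group("args").split(",")[0]
        if "align" not in destination:
            raise ValueError("memset destination has no alignment attribute; "
                             "cannot rewrite the call for NVVM (LLVM 3.4)")

# A call whose destination carries an alignment passes silently:
check_memset_alignment(
    "call void @llvm.memset.p0i8.i64(i8* align 8 %p, i8 0, i64 40, i1 false)"
)
```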

@esc esc added this to the PR Backlog milestone Jul 24, 2020
gmarkall added a commit to gmarkall/llvmlite that referenced this pull request Jul 24, 2020
This reverts commit 3f66129.

Once numba/numba#6030 is merged, it will no
longer be necessary to disable the autoupgrade of atomic intrinsics for
NVPTX, because LLVM from llvmlite will not be used to optimize the IR
before sending it to NVVM.
@gmarkall gmarkall modified the milestones: PR Backlog, Numba 0.51 RC Jul 24, 2020
@gmarkall gmarkall mentioned this pull request Jul 24, 2020
@gmarkall gmarkall added the "CUDA" label Jul 24, 2020
@stuartarchibald (Contributor) left a comment:

Thanks for the patch - a couple of minor things to resolve, otherwise it looks good.

Review threads: numba/cuda/codegen.py (resolved), numba/cuda/cudadrv/nvvm.py (outdated, resolved)
@stuartarchibald stuartarchibald added the "4 - Waiting on author" and "Pending BuildFarm" labels and removed the "3 - Ready for Review" label Jul 24, 2020
@stuartarchibald stuartarchibald self-assigned this Jul 24, 2020
esc (Member) commented Jul 27, 2020

BFID: numba_smoketest_cuda_69.

- Add comments about why we don't optimize LLVM IR in the CUDA target.
- Change wording of error about requiring alignment on memset
  destination and update test accordingly.
@gmarkall gmarkall added the "4 - Waiting on reviewer" label and removed the "4 - Waiting on author" label Jul 27, 2020
sklam (Member) commented Jul 27, 2020

Re-build BFID: numba_smoketest_cuda_73

@sklam (Member) left a comment:

LGTM. Waiting on buildfarm to rerun.

@esc esc added the "4 - Waiting on author" label and removed the "4 - Waiting on reviewer" label Jul 28, 2020
esc (Member) commented Jul 28, 2020

BF Report for BFID: numba_smoketest_cuda_73: All builds failed; Windows and Linux builds fail in the same way.

CUDA 8 and 9 tests fail with:

[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO - ======================================================================
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO - FAIL: test_const_array_3d (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO - ----------------------------------------------------------------------
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO - Traceback (most recent call last):
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO -   File "F:\ci_envs\64\Miniconda3\envs\testenv_bd4fe4a3-c31c-4d1d-82a0-7535ec9cdee6\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 139, in test_const_array_3d
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO -     "load each half of the complex as f32")
[2020-07-27 21:39:43,881] {bash_operator.py:137} INFO - AssertionError: 'ld.const.f32' not found in '//\n// Generated by NVIDIA NVVM Compiler\n//\n// Compiler Build ID: CL-21373419\n// Cuda compilation tools, release 8.0, V8.0.55\n// Based on LLVM 3.4svn\n//\n\n.version 5.0\n.target sm_35\n.address_size 64\n\n\t// .globl\t_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__errcode__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__tidx__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__ctaidx__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__tidy__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__ctaidy__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__tidz__;\n.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE__ctaidz__;\n.common .global .align 8 .u64 _ZN08NumbaEnv5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE;\n.const .align 8 .b8 _cudapy_cmem[1000] = {0, 0, 0, 63, 0, 0, 0, 0, 0, 0, 0, 63, 0, 0, 0, 191, 0, 0, 0, 63, 0, 0, 128, 191, 0, 0, 0, 63, 0, 0, 192, 191, 0, 0, 0, 63, 0, 0, 0, 192, 0, 0, 0, 63, 0, 0, 32, 192, 0, 0, 0, 63, 0, 0, 64, 192, 0, 0, 0, 63, 0, 0, 96, 192, 0, 0, 0, 63, 0, 0, 128, 192, 0, 0, 0, 63, 0, 0, 144, 192, 0, 0, 0, 63, 0, 0, 160, 192, 0, 0, 0, 63, 0, 0, 176, 192, 0, 0, 0, 63, 0, 0, 192, 192, 0, 0, 0, 63, 0, 0, 208, 192, 0, 0, 0, 63, 0, 0, 224, 192, 0, 0, 0, 63, 0, 0, 240, 192, 0, 0, 0, 63, 0, 0, 0, 193, 0, 0, 0, 63, 0, 0, 8, 193, 0, 0, 0, 63, 0, 0, 16, 193, 0, 0, 0, 63, 0, 0, 24, 193, 0, 0, 0, 63, 0, 0, 32, 193, 0, 0, 0, 63, 0, 0, 40, 193, 0, 0, 0, 63, 0, 0, 48, 193, 0, 0, 0, 63, 0, 0, 56, 193, 0, 0, 0, 63, 0, 0, 64, 193, 0, 0, 0, 63, 0, 0, 72, 193, 0, 0, 0, 63, 0, 0, 80, 193, 0, 0, 0, 63, 0, 0, 88, 193, 0, 0, 0, 63, 0, 0, 96, 193, 0, 0, 0, 63, 0, 0, 104, 193, 0, 0, 0, 63, 0, 0, 112, 193, 0, 0, 0, 63, 0, 0, 120, 193, 0, 0, 0, 63, 0, 0, 128, 193, 0, 0, 0, 63, 0, 0, 132, 193, 0, 0, 0, 63, 0, 0, 136, 193, 0, 0, 0, 63, 0, 0, 140, 193, 0, 0, 0, 63, 0, 0, 144, 193, 0, 0, 0, 63, 0, 0, 148, 193, 0, 0, 0, 63, 0, 0, 152, 193, 0, 0, 0, 63, 0, 0, 156, 193, 0, 0, 0, 63, 0, 0, 160, 193, 0, 0, 0, 63, 0, 0, 164, 193, 0, 0, 0, 63, 0, 0, 168, 193, 0, 0, 0, 63, 0, 0, 172, 193, 0, 0, 0, 63, 0, 0, 176, 193, 0, 0, 0, 63, 0, 0, 180, 193, 0, 0, 0, 63, 0, 0, 184, 193, 0, 0, 0, 63, 0, 0, 188, 193, 0, 0, 0, 63, 0, 0, 192, 193, 0, 0, 0, 63, 0, 0, 196, 193, 0, 0, 0, 63, 0, 0, 200, 193, 0, 0, 0, 63, 0, 0, 204, 193, 0, 0, 0, 63, 0, 0, 208, 193, 0, 0, 0, 63, 0, 0, 212, 193, 0, 0, 0, 63, 0, 0, 216, 193, 0, 0, 0, 63, 0, 0, 220, 193, 0, 0, 0, 63, 0, 0, 224, 193, 0, 0, 0, 63, 0, 0, 228, 193, 0, 0, 0, 63, 0, 0, 232, 193, 0, 0, 0, 63, 0, 0, 236, 193, 0, 0, 0, 63, 0, 0, 240, 193, 0, 0, 0, 63, 0, 0, 244, 193, 0, 0, 0, 63, 0, 0, 248, 193, 0, 0, 0, 63, 0, 0, 252, 193, 0, 0, 0, 63, 0, 0, 0, 194, 0, 0, 0, 63, 0, 0, 2, 194, 
0, 0, 0, 63, 0, 0, 4, 194, 0, 0, 0, 63, 0, 0, 6, 194, 0, 0, 0, 63, 0, 0, 8, 194, 0, 0, 0, 63, 0, 0, 10, 194, 0, 0, 0, 63, 0, 0, 12, 194, 0, 0, 0, 63, 0, 0, 14, 194, 0, 0, 0, 63, 0, 0, 16, 194, 0, 0, 0, 63, 0, 0, 18, 194, 0, 0, 0, 63, 0, 0, 20, 194, 0, 0, 0, 63, 0, 0, 22, 194, 0, 0, 0, 63, 0, 0, 24, 194, 0, 0, 0, 63, 0, 0, 26, 194, 0, 0, 0, 63, 0, 0, 28, 194, 0, 0, 0, 63, 0, 0, 30, 194, 0, 0, 0, 63, 0, 0, 32, 194, 0, 0, 0, 63, 0, 0, 34, 194, 0, 0, 0, 63, 0, 0, 36, 194, 0, 0, 0, 63, 0, 0, 38, 194, 0, 0, 0, 63, 0, 0, 40, 194, 0, 0, 0, 63, 0, 0, 42, 194, 0, 0, 0, 63, 0, 0, 44, 194, 0, 0, 0, 63, 0, 0, 46, 194, 0, 0, 0, 63, 0, 0, 48, 194, 0, 0, 0, 63, 0, 0, 50, 194, 0, 0, 0, 63, 0, 0, 52, 194, 0, 0, 0, 63, 0, 0, 54, 194, 0, 0, 0, 63, 0, 0, 56, 194, 0, 0, 0, 63, 0, 0, 58, 194, 0, 0, 0, 63, 0, 0, 60, 194, 0, 0, 0, 63, 0, 0, 62, 194, 0, 0, 0, 63, 0, 0, 64, 194, 0, 0, 0, 63, 0, 0, 66, 194, 0, 0, 0, 63, 0, 0, 68, 194, 0, 0, 0, 63, 0, 0, 70, 194, 0, 0, 0, 63, 0, 0, 72, 194, 0, 0, 0, 63, 0, 0, 74, 194, 0, 0, 0, 63, 0, 0, 76, 194, 0, 0, 0, 63, 0, 0, 78, 194, 0, 0, 0, 63, 0, 0, 80, 194, 0, 0, 0, 63, 0, 0, 82, 194, 0, 0, 0, 63, 0, 0, 84, 194, 0, 0, 0, 63, 0, 0, 86, 194, 0, 0, 0, 63, 0, 0, 88, 194, 0, 0, 0, 63, 0, 0, 90, 194, 0, 0, 0, 63, 0, 0, 92, 194, 0, 0, 0, 63, 0, 0, 94, 194, 0, 0, 0, 63, 0, 0, 96, 194, 0, 0, 0, 63, 0, 0, 98, 194, 0, 0, 0, 63, 0, 0, 100, 194, 0, 0, 0, 63, 0, 0, 102, 194, 0, 0, 0, 63, 0, 0, 104, 194, 0, 0, 0, 63, 0, 0, 106, 194, 0, 0, 0, 63, 0, 0, 108, 194, 0, 0, 0, 63, 0, 0, 110, 194, 0, 0, 0, 63, 0, 0, 112, 194, 0, 0, 0, 63, 0, 0, 114, 194, 0, 0, 0, 63, 0, 0, 116, 194, 0, 0, 0, 63, 0, 0, 118, 194, 0, 0, 0, 63, 0, 0, 120, 194};\n\n.visible .entry _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE(\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_0,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_1,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_2,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_3,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_4,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_5,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_6,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_7,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_8,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_9,\n\t.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_10\n)\n{\n\t.reg .pred \t%p<4>;\n\t.reg .f32 \t%f<5>;\n\t.reg .b32 \t%r<4>;\n\t.reg .b64 \t%rd<36>;\n\n\n\tld.param.u64 \t%rd1, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_4];\n\tld.param.u64 \t%rd2, 
[_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_5];\n\tld.param.u64 \t%rd3, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_6];\n\tld.param.u64 \t%rd4, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_7];\n\tld.param.u64 \t%rd5, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_8];\n\tld.param.u64 \t%rd6, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_9];\n\tld.param.u64 \t%rd7, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem15cuconst3d$24232E5ArrayI9complex64Li3E1A7mutable7alignedE_param_10];\n\tmov.u32 \t%r1, %tid.x;\n\tmov.u32 \t%r2, %tid.y;\n\tmov.u32 \t%r3, %tid.z;\n\tcvt.s64.s32\t%rd8, %r1;\n\tcvt.s64.s32\t%rd9, %r2;\n\tcvt.s64.s32\t%rd10, %r3;\n\tsetp.lt.s32\t%p1, %r1, 0;\n\tselp.b64\t%rd11, 5, 0, %p1;\n\tadd.s64 \t%rd12, %rd11, %rd8;\n\tsetp.lt.s32\t%p2, %r2, 0;\n\tselp.b64\t%rd13, 5, 0, %p2;\n\tadd.s64 \t%rd14, %rd13, %rd9;\n\tsetp.lt.s32\t%p3, %r3, 0;\n\tselp.b64\t%rd15, 5, 0, %p3;\n\tmul.lo.s64 \t%rd16, %rd12, 25;\n\tmul.lo.s64 \t%rd17, %rd14, 5;\n\tadd.s64 \t%rd18, %rd16, %rd10;\n\tadd.s64 \t%rd19, %rd18, %rd15;\n\tadd.s64 \t%rd20, %rd19, %rd17;\n\tshl.b64 \t%rd21, %rd20, 3;\n\tmov.u64 \t%rd22, _cudapy_cmem;\n\tadd.s64 \t%rd23, %rd22, %rd21;\n\tld.const.v2.f32 \t{%f1, %f2}, [%rd23];\n\tselp.b64\t%rd24, %rd2, 0, %p1;\n\tadd.s64 \t%rd25, %rd24, %rd8;\n\tselp.b64\t%rd26, %rd3, 0, %p2;\n\tadd.s64 \t%rd27, %rd26, %rd9;\n\tselp.b64\t%rd28, %rd4, 0, %p3;\n\tadd.s64 \t%rd29, %rd28, %rd10;\n\tmul.lo.s64 \t%rd30, %rd25, %rd5;\n\tmul.lo.s64 \t%rd31, %rd27, %rd6;\n\tmul.lo.s64 \t%rd32, %rd29, %rd7;\n\tadd.s64 \t%rd33, %rd30, %rd1;\n\tadd.s64 \t%rd34, %rd33, %rd31;\n\tadd.s64 \t%rd35, %rd34, %rd32;\n\tst.f32 \t[%rd35], %f1;\n\tst.f32 \t[%rd35+4], %f2;\n\tret;\n}\n\n\n\x00' : load each half of the complex as f32
[2020-07-27 21:39:43,928] {bash_operator.py:137} INFO - 

There is also one unexpected success, which is unfortunately not displayed in the log.

CUDA 10 tests fail with two unexpected successes, which are also unfortunately not displayed in the log.

gmarkall (Member, Author):

Many thanks @esc

This passes on some toolkits (e.g. 9.0), so in subsequent commits the
expectation is made conditional on the toolkit version.
gmarkall (Member, Author):

@esc With the added commits I get no failures on Linux with CUDA 8.0, 9.0, 9.1, 9.2, 10.0, 10.1, 10.2, or 11.0 (as long as I've made no mistakes in testing) - could you run this through the buildfarm again please?

esc (Member) commented Jul 29, 2020

Sure - BFID: numba_smoketest_cuda_81, test is in progress.

esc (Member) commented Jul 29, 2020

The Windows tests failed on CUDA 8 & 9 with unexpected successes, but the testing script doesn't seem to output these.

Switching from an expected failure to checking the behaviour for various
toolkits and platforms.
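For illustration, a minimal sketch of the kind of toolkit-conditional expectation the commit above describes; the function name, the version threshold, and the exact PTX patterns are assumptions, not the PR's actual test code:

```python
def complex_const_load_ok(ptx: str, toolkit_version: tuple) -> bool:
    """Check that the PTX loads the complex constant in the form expected for
    the given CUDA toolkit version, passed as a (major, minor) tuple."""
    if toolkit_version < (9, 0):
        # e.g. the CUDA 8 log above shows a vectorized ld.const.v2.f32 load
        return "ld.const.v2.f32" in ptx
    # Other toolkits load the real and imaginary parts as separate f32 loads.
    return ptx.count("ld.const.f32") >= 2
```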
gmarkall (Member, Author):

@esc Many thanks - this is ready for another buildfarm run.

@gmarkall gmarkall removed the "4 - Waiting on author" label Jul 29, 2020
@gmarkall gmarkall mentioned this pull request Jul 29, 2020
esc (Member) commented Jul 29, 2020

numba_smoketest_cuda_85

esc (Member) commented Jul 30, 2020

BF reports that CUDA 10 tests on Windows fail with:

[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - ======================================================================
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - FAIL: test_const_record_optimization (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - ----------------------------------------------------------------------
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - Traceback (most recent call last):
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO -   File "F:\ci_envs\64\Miniconda3\envs\testenv_3f30b38e-ab5c-4bbf-8317-b8102356130f\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 189, in test_const_record_optimization
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO -     self.assertGreaterEqual(u8_load_count, 16,
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - AssertionError: 12 not greater than or equal to 16 : load record values as individual bytes
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - 
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - ----------------------------------------------------------------------
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - Ran 642 tests in 147.462s
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - 
[2020-07-29 18:31:16,285] {bash_operator.py:137} INFO - FAILED (failures=1, skipped=16)
[2020-07-29 18:31:16,769] {bash_operator.py:141} INFO - Command exited with return code 1
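For illustration, a minimal sketch of how a count such as `u8_load_count` in the assertion above might be derived from the generated PTX; the exact pattern counted is an assumption, not the test's actual code:

```python
import re

def count_u8_const_loads(ptx: str) -> int:
    # Count PTX instructions that load individual bytes from constant memory,
    # e.g. "ld.const.u8  %rs1, [%rd5];".
    return len(re.findall(r"\bld\.const\.u8\b", ptx))
```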

gmarkall (Member, Author):

Conda-forge Numba feedstock issue that will also be resolved by this PR: conda-forge/numba-feedstock#55

gmarkall (Member, Author):

@esc Could this have another buildfarm run please?

esc (Member) commented Jul 30, 2020

numba_smoketest_cuda_88

@esc esc added the "BuildFarm Passed" and "5 - Ready to merge" labels and removed the "Pending BuildFarm" label Jul 30, 2020
esc (Member) commented Jul 30, 2020

Reckon this is good to go now.

@sklam sklam merged commit fa1d922 into numba:master Jul 30, 2020
esc pushed a commit to esc/llvmlite that referenced this pull request Aug 4, 2020
This reverts commit 3f66129.

Once numba/numba#6030 is merged, it will no
longer be necessary to disable the autoupgrade of atomic intrinsics for
NVPTX, because LLVM from llvmlite will not be used to optimize the IR
before sending it to NVVM.
sklam pushed a commit to numba/llvmlite that referenced this pull request Aug 5, 2020
* Use std::make_unique on LLVM 10

LLVM 10 removes llvm::make_unique in favor of std::make_unique.
However, this requires C++14 and is therefore unsuitable for LLVM 9,
which forces -std=c++11. Update the code to use both conditionally.
This fixes all issues with LLVM 10.

* grab updated SVML patch and new fastmath patch

https://github.com/conda-forge/llvmdev-feedstock/blob/c706309/recipe/patches/intel-D47188-svml-VF.patch

https://github.com/conda-forge/llvmdev-feedstock/blob/c706309/recipe/patches/expect-fastmath-entrypoints-in-add-TLI-mappings.ll.patch

* update to LLVM 10 in recipe and tests

As title

* update documented version numbers in README.rst

As title.

* update appveyor.yml comment to include LLVM 10

As title.

* update to llvmdev 10.0* on Azure pipelines

As title.

* update to llvmdev 10.0 in buildscripts

As title

* update llvmdev_manylinux1 conda-recipe

As title

* update Sphinx documentation for correct LLVM version

As title.

* temporarily set the VERSION_SUFFIX for CI

This is needed as the llvmlite package we are building needs to
differentiate itself somehow from the 'other' dev builds.

* specify the correct compiler in the build requirements

* workaround potentially broken LLVM CMake setup

With LLVM 10 the following CMake line no longer appears to work.

```
llvm_map_components_to_libnames(llvm_libs all)
```

... and this commit is a suitable workaround.

Current working hypothesis is that this might have been caused by the
change in:

llvm/llvm-project@ab41180#diff-cebfde545aa260d20ca6b0cdc3da08daR270

However the docs at https://llvm.org/docs/CMake.html are not very clear
about this. And in addition the output of `llvm-config --components`
does list `all` as an option.

Many thanks to @angloyna for debugging this!!

* update the comment and links for the svml patch

* add a note about broken LLVM cmake setup

As title

* update compatibility matrix

As title.

* install cmake on armv7l too

* remove TODO

* bump to LLVM 10.0.1

As title.

* Make sure llvmlite build is unique.

As title.

* remove outdated comment

As title.

* implement LLVM version split

The next llvmlite release will have a split LLVM dependency. The
`aarch64` target will use LLVM 9 and all other targets will use LLVM 10.

* Revert "Fix CUDA with LLVM9"

This reverts commit 3f66129.

Once numba/numba#6030 is merged, it will no
longer be necessary to disable the autoupgrade of atomic intrinsics for
NVPTX, because LLVM from llvmlite will not be used to optimize the IR
before sending it to NVVM.

* revert LLVM documentation link to 10.0.0

The page at:

https://releases.llvm.org/10.0.1/docs/

Currently (Tue Aug  4 10:48:42 2020 GMT+2) returns a 404 (page not
found), and this breaks our build as Sphinx can't set up the intersphinx
links.

* accept LLVM 9.0 once again

As we are doing a split release for 0.34.0 and `aarch64` will need LLVM
9.0.* -- we allow building with this version again.

* Apply suggestions from code review

Fix missing word and missing punctuation.

Co-authored-by: stuartarchibald <stuartarchibald@users.noreply.github.com>

* adding link to LLVM bug report

* adding link to potential LLVM Cmake bug

As title.

* stop testing older LLVMs

We only support 10.0.* and 9.0.* from now on.

* getting ready to release 0.34.0: remove VERSION_SUFFIX

Final builds should not have a version suffix.

Co-authored-by: Michał Górny <mgorny@gentoo.org>
Co-authored-by: Graham Markall <graham@big-grey.co.uk>
Co-authored-by: stuartarchibald <stuartarchibald@users.noreply.github.com>
gmarkall added a commit to gmarkall/numba that referenced this pull request Jan 25, 2021
Starting with CUDA 11.2, a new version of NVVM is provided that is based
on LLVM 7.0. This requires a number of changes to support, which must be
maintained in parallel with the existing support for NVVM based on LLVM
3.4. This PR adds these changes, which consist of:

- Addition of a function to query the NVVM IR version, and a property
  indicating whether the NVVM in use is based on LLVM 3.4 or 7.0
  (`is_nvvm70`).
- The CAS hack (inserting a text-based implementation of `cmpxchg` with
  pre-LLVM 3.5 semantics in a function) is only needed with NVVM 3.4 -
  on NVVM 7.0, llvmlite is used to build `cmpxchg` instructions directly
  instead.
- Templates for other atomics (inc, dec, min, max) have the right form
  of the `cmpxchg` instruction inserted depending on the NVVM version.
- The datalayout shorthand is now only replaced for NVVM 3.4.
- There are now two variants of the functions to rewrite the IR -
  `llvm100_to_70_ir` and `llvm100_to_34_ir`. `llvm100_to_34_ir` is the
  old `llvm_39_to_34_ir` with a name reflecting what it currently does.
- `llvm100_to_70_ir` removes the `willreturn` attribute from functions,
  as it is not supported by LLVM 7.0. It also converts DISPFlags to main
  subprogram DIFlags. For example, `spflags: DISPFlagDefinition |
  DISPFlagOptimized` is rewritten as `isDefinition: true, isOptimized:
  true` (a sketch of this kind of rewrite follows after this list).
- For NVVM 7.0, the `DIBuilder` also used for the CPU target can be used,
  instead of the `NvvmDIBuilder` that was needed to support NVVM 3.4.
- Some tests are updated to support modified function names, and also to
  expect a CUDA version of 11.2.
- `test_nvvm_driver` is updated to include appropriate IR for both NVVM
  3.4 and 7.0. Some refactoring also makes its code clearer (e.g.
  renaming `get_ptx()` to `get_nvvmir()`, because it returns NVVM IR
  and not PTX).
- Some optimizations in LLVM 7.0 result in different code generation in
  `test_constmem`, so alternative expected results are added for when
  NVVM 7.0 is used. Note that this recovers some optimizations that were
  lost when IR optimization using llvmlite was switched off (PR numba#6030,
  "Don't optimize IR before sending it to NVVM").
- `test_debuginfo` is updated to match the format of the debuginfo
  section produced by both NVVM 3.4 and 7.0 (there is some variation in
  whitespace between these versions).
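For illustration, a minimal sketch of the kind of spflags rewrite described in the list above; the helper name and the set of flags handled are assumptions, and this is not the PR's actual `llvm100_to_70_ir`:

```python
import re

# Matches e.g. "spflags: DISPFlagDefinition | DISPFlagOptimized" inside a
# !DISubprogram metadata line.
SPFLAGS = re.compile(r"spflags:\s*(?P<flags>[\w| ]+)")

def rewrite_spflags(line: str) -> str:
    """Rewrite DISPFlags into the older per-flag booleans LLVM 7.0 expects."""
    match = SPFLAGS.search(line)
    if match is None:
        return line
    flags = {flag.strip() for flag in match.group("flags").split("|")}
    booleans = []
    if "DISPFlagDefinition" in flags:
        booleans.append("isDefinition: true")
    if "DISPFlagOptimized" in flags:
        booleans.append("isOptimized: true")
    if "DISPFlagLocalToUnit" in flags:
        booleans.append("isLocal: true")
    if not booleans:
        return line
    return SPFLAGS.sub(", ".join(booleans), line)

# Example:
#   rewrite_spflags('!5 = distinct !DISubprogram(name: "kernel", '
#                   'spflags: DISPFlagDefinition | DISPFlagOptimized)')
# returns
#   '!5 = distinct !DISubprogram(name: "kernel", isDefinition: true, isOptimized: true)'
```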
Labels: 5 - Ready to merge, BuildFarm Passed, CUDA

4 participants