
Support Records in CUDA Const Memory #3186

Merged
merged 7 commits into numba:master Sep 10, 2018
Conversation

njwhite
Contributor

@njwhite njwhite commented Jul 30, 2018

(Really any type.) Don't try to interpret the data in the lowering step; just serialize it as a stream of bytes.
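The byte-serialization idea can be sketched like this (a minimal illustration with a hypothetical record dtype, not the PR's actual lowering code):

```python
import numpy as np

# Hypothetical record dtype; any dtype works the same way, because we
# only ever look at the raw bytes, never at the field structure.
rec = np.zeros(2, dtype=np.dtype([('x', np.float64), ('y', np.int32)],
                                 align=True))
rec['x'] = [1.0, 3.0]
rec['y'] = [2, 4]

# Flatten in memory order ('A') and take the raw bytes; each record
# occupies dtype.itemsize bytes, alignment padding included, so this
# is exactly the payload to emit as an [N x i8] constant.
raw = rec.flatten(order='A').tobytes()
assert len(raw) == rec.dtype.itemsize * rec.size
```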

@njwhite njwhite changed the title Support Records in CUDA Const Memory [WIP] Support Records in CUDA Const Memory Jul 31, 2018
@codecov-io

codecov-io commented Jul 31, 2018

Codecov Report

Merging #3186 into master will decrease coverage by 0.02%.
The diff coverage is n/a.

@@            Coverage Diff            @@
##           master   #3186      +/-   ##
=========================================
- Coverage   81.13%   81.1%   -0.03%     
=========================================
  Files         384     386       +2     
  Lines       75110   76398    +1288     
  Branches     8434    8590     +156     
=========================================
+ Hits        60937   61964    +1027     
- Misses      12884   13114     +230     
- Partials     1289    1320      +31

@stuartarchibald stuartarchibald added the CUDA CUDA related issue/PR label Jul 31, 2018
@njwhite njwhite changed the title [WIP] Support Records in CUDA Const Memory Support Records in CUDA Const Memory Aug 1, 2018
@stuartarchibald stuartarchibald added 3 - Ready for Review Pending BuildFarm For PRs that have been reviewed but pending a push through our buildfarm labels Aug 1, 2018
@sklam sklam self-requested a review August 2, 2018 16:20
@sklam
Member

sklam commented Aug 2, 2018

Don't try to interpret data in the lowering step - just serialize as a stream of bytes.

This approach may have performance issues. Since GPUs are very sensitive to memory load latency/stalls, we need to verify what the emitted instruction sequence is, i.e. we don't want to load a float32 as four single-byte load instructions.

@njwhite
Contributor Author

njwhite commented Aug 3, 2018

@sklam I will add that to the tests. Is a vector load of 4 bytes operationally equivalent to an int32 load? Or do vector instructions have different performance characteristics?

@sklam
Member

sklam commented Aug 3, 2018

Is a vector load of 4 bytes operationally equivalent to an int32 load? Or do vector instructions have different performance characteristics?

I don't know for certain; I don't think that's documented publicly. But I would guess that they are equivalent.

@seibert
Contributor

seibert commented Aug 3, 2018

I don't remember there ever being a documented difference, except in really early CUDA C, where casting a pointer to the vector types before loading was used as a hack to work around memory controller behavior. (Loading float2 to get better bandwidth, etc...) I strongly suspect it doesn't matter any more.

@njwhite
Contributor Author

njwhite commented Aug 4, 2018

@sklam see the tests on the commit I pushed: nvcc guesses the correct instructions to use anyway!

constvals = [
    context.get_constant(types.byte, i)
    for i in arr.flatten(order='A').data.tobytes()
]
constary = lc.Constant.array(constvals[0].type, constvals)
Member

@sklam sklam Aug 9, 2018


This is causing error on windows and python2.7:

test_const_array (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
test_const_array_2d (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
test_const_array_3d (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
test_const_record (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
test_const_record_align (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)



File "..\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 60:
def cuconstRecAlign(A, B, C, D, E):
    Z = cuda.const.array_like(CONST_RECORD_ALIGN)
    ^
[1] During: lowering "$0.5 = call ptx.cmem.arylike($0.4, kws=[], args=[Var($0.4, c:\conda64\conda-bld\numba_1533830313780\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py (60))], func=ptx.cmem.arylike, vararg=None)" at c:\conda64\conda-bld\numba_1533830313780\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py (60)

'buffer' object has no attribute 'tobytes'
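For context, the failure comes from a Python 2/3 API difference; a minimal Python 3 sketch of it (not the PR's code):

```python
import numpy as np

arr = np.ones(4, dtype=np.float32)

# On Python 3, arr.data is a memoryview, which has .tobytes(); on
# Python 2 it was a buffer object, which does not - hence the
# AttributeError above. Calling tobytes() on the array itself avoids
# touching .data at all.
assert arr.data.tobytes() == arr.tobytes()
```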

Member


likely a python2.7 error

Contributor Author


I've pushed a fix for this - it uses this instead

@sklam
Member

sklam commented Aug 9, 2018

There's also an error on python3.7:

======================================================================
FAIL: test_const_record (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "c:\conda64\conda-bld\numba_1533826688819\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 116, in test_const_record
    "the compiler realises it doesn't even need to " \
AssertionError: 'ld.const.v2.u64' not found in '//\n// Generated by NVIDIA NVVM Compiler\n//\n// Compiler Build ID: CL-21373419\n// Cuda compilation tools, release 8.0, V8.0.55\n// Based on LLVM 3.4svn\n//\n\n.version 5.0\n.target 

@njwhite
Contributor Author

njwhite commented Aug 9, 2018

@sklam re: the Python 3.7 error: why is it using the v8 toolkit and not v9 / v9.1? Can you include the full PTX so I can see what nvcc generated? Maybe the tests only pass with a recent nvcc version...

@seibert
Contributor

seibert commented Aug 13, 2018

We test with CUDA 8, 9, and 9.1, so it is possible that NVVM from one of those versions behaves differently from the others. I'll see if I can dig out the full PTX result.

@seibert seibert added this to In Progress in Minor Features Aug 27, 2018
@stuartarchibald
Contributor

Build farm is failing on:

  • windows64 + Python 3, all CUDA toolkits (8, 9.0, 9.1)
  • windows and linux + Python 2, all CUDA toolkits (8, 9.0, 9.1)

Python 2 tests are failing because of this:

======================================================================
ERROR: test_cuda (__main__.TestCase)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<path>/conda-bld/numba_1536016014099/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/lib/python2.7/site-packages/numba/tests/test_runtests.py", line 77, in test_cuda
    self.check_testsuite_size(['numba.cuda.tests'], 1, 470)
  File "<path>/conda-bld/numba_1536016014099/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/lib/python2.7/site-packages/numba/tests/test_runtests.py", line 57, in check_testsuite_size
    e.output.decode('UTF-8').splitlines()])
  File "<path>/conda-bld/numba_1536016014099/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_/lib/python2.7/encodings/utf_8.py", line 16, in decode
    return codecs.utf_8_decode(input, errors, True)
UnicodeDecodeError: 'utf8' codec can't decode byte 0xef in position 43116: invalid continuation byte

Without looking, I'd guess either that the test suite is larger than it thinks it ought to be due to the new tests, or that there's a really long error message (or similar) appearing in test discovery that can't be decoded as UTF-8.

There's also an error message like this on windows 64 + python 2:

	======================================================================
	ERROR: test_const_array (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)
	----------------------------------------------------------------------
	Traceback (most recent call last):
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 88, in test_const_array
	    jcuconst = cuda.jit('void(float64[:])')(cuconst)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\decorators.py", line 92, in kernel_jit
	    inline=inline, fastmath=fastmath)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\compiler.py", line 39, in core
	    return fn(*args, **kwargs)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\compiler.py", line 78, in compile_kernel
	    cres = compile_cuda(pyfunc, types.void, args, debug=debug, inline=inline)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\compiler.py", line 39, in core
	    return fn(*args, **kwargs)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\cuda\compiler.py", line 67, in compile_cuda
	    locals={})
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 903, in compile_extra
	    return pipeline.compile_extra(func)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 369, in compile_extra
	    return self._compile_bytecode()
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 834, in _compile_bytecode
	    return self._compile_core()
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 821, in _compile_core
	    res = pm.run(self.status)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 255, in run
	    raise patched_exception
	RuntimeError: Caused By:
	Traceback (most recent call last):
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 246, in run
	    stage()
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 699, in stage_nopython_backend
	    self._backend(lowerfn, objectmode=False)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 649, in _backend
	    lowered = lowerfn()
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 636, in backend_nopython_mode
	    self.flags)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\compiler.py", line 1022, in native_lowering_stage
	    lower.lower()
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\lowering.py", line 198, in lower
	    self.library.add_ir_module(self.module)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\numba\targets\codegen.py", line 184, in add_ir_module
	    ll_module = ll.parse_assembly(ir)
	  File "<path>\conda-bld\numba_1536018978198\_test_env\lib\site-packages\llvmlite\binding\module.py", line 22, in parse_assembly
	    raise RuntimeError("LLVM IR parsing error\n{0}".format(errmsg))
	RuntimeError: LLVM IR parsing error
	<string>:235:64: error: expected value token
	@"_cudapy_cmem" = internal addrspace(4) constant [80 x i8] [i8 
	                                                               ^
	
	
	LLVM IR parsing error

Output excerpt for failure on Python 3 + Windows 64:

======================================================================

FAIL: test_const_record (numba.cuda.tests.cudapy.test_constmem.TestCudaConstantMemory)

----------------------------------------------------------------------

Traceback (most recent call last):

  File "<path>\conda-bld\numba_1536016443225\_test_env\lib\site-packages\numba\cuda\tests\cudapy\test_constmem.py", line 144, in test_const_record

    "the compiler realises it doesn't even need to " 
AssertionError: 'ld.const.v2.u64' not found in '//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-23083092
// Cuda compilation tools, release 9.1, V9.1.85
// Based on LLVM 3.4svn
//

.version 6.1
.target sm_35
.address_size 64

	// .globl	_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__errcode__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__tidx__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__ctaidx__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__tidy__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__ctaidy__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__tidz__;
.visible .global .align 4 .u32 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE__ctaidz__;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets7numbers16complex_div_impl12$3clocals$3e17complex_div$24148E9complex649complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba6unsafe7ndarray14to_fixed_tuple12$3clocals$3e7codegen12$3clocals$3e8impl$249E5ArrayIxLi1E1C7mutable7alignedEx8UniTupleIxLi3EE;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba6unsafe7ndarray14to_fixed_tuple12$3clocals$3e7codegen12$3clocals$3e8impl$245E5ArrayIxLi1E1C7mutable7alignedEx8UniTupleIxLi2EE;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba6unsafe7ndarray14to_fixed_tuple12$3clocals$3e7codegen12$3clocals$3e8impl$247E5ArrayIxLi1E1C7mutable7alignedEx8UniTupleIxLi1EE;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10atanh_impl12$3clocals$3e15atanh_impl$2490E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10acosh_impl12$3clocals$3e15acosh_impl$2471E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets8arrayobj13_change_dtype12$3clocals$3e8imp$2421Ex5ArrayIxLi1E1C7mutable7alignedE5ArrayIxLi1E1C7mutable7alignedExxa;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9sqrt_impl12$3clocals$3e14sqrt_impl$2465E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9acos_impl12$3clocals$3e14acos_impl$2468E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9acos_impl12$3clocals$3e14acos_impl$2464E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10asinh_impl12$3clocals$3e15asinh_impl$2482E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10asinh_impl12$3clocals$3e15asinh_impl$2478E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10acosh_impl12$3clocals$3e15acosh_impl$2474E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9asin_impl12$3clocals$3e14asin_impl$2477E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9asin_impl12$3clocals$3e14asin_impl$2481E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl13log_base_impl12$3clocals$3e14log_base$24143E10complex12810complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets7numbers16complex_div_impl12$3clocals$3e17complex_div$24144E10complex12810complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9cosh_impl12$3clocals$3e15cosh_impl$24106E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10atanh_impl12$3clocals$3e15atanh_impl$2494E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9atan_impl12$3clocals$3e14atan_impl$2489E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9cosh_impl12$3clocals$3e15cosh_impl$24102E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9atan_impl12$3clocals$3e14atan_impl$2493E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8cos_impl12$3clocals$3e14cos_impl$24101E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10log10_impl12$3clocals$3e16log10_impl$24140E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl14exp_impl$24116Effbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8cos_impl12$3clocals$3e14cos_impl$24105E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl14exp_impl$24113Eddbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl10log10_impl12$3clocals$3e16log10_impl$24137E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl14log_impl$24131Eddbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl14log_impl$24134Effbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9sinh_impl12$3clocals$3e15sinh_impl$24174E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl16phase_impl$24151Eddbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl13log_base_impl12$3clocals$3e14log_base$24147E9complex649complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8sin_impl12$3clocals$3e14sin_impl$24169E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl16phase_impl$24154Effbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9rect_impl12$3clocals$3e10rect$24163Eddb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl16polar_impl$24160Effbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl16polar_impl$24157Eddbb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9rect_impl12$3clocals$3e10rect$24166Effb;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9sinh_impl12$3clocals$3e15sinh_impl$24170E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8sin_impl12$3clocals$3e14sin_impl$24173E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8tan_impl12$3clocals$3e14tan_impl$24186E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9tanh_impl12$3clocals$3e15tanh_impl$24187E10complex128;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9sqrt_impl12$3clocals$3e15sqrt_impl$24183E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl9tanh_impl12$3clocals$3e15tanh_impl$24191E9complex64;
.common .global .align 8 .u64 _ZN08NumbaEnv5numba7targets9cmathimpl8tan_impl12$3clocals$3e14tan_impl$24190E9complex64;
.const .align 1 .b8 _cudapy_cmem[24] = {0, 0, 0, 0, 0, 0, 240, 63, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 64, 4, 0, 0, 0};

.visible .entry _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE(
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_0,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_1,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_2,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_3,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_4,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_5,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_6,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_7,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_8,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_9,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_10,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_11,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_12,
	.param .u64 _ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_13
)
{
	.reg .pred 	%p<2>;
	.reg .b32 	%r<8>;
	.reg .b64 	%rd<25>;


	ld.param.u64 	%rd1, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_4];
	ld.param.u64 	%rd2, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_5];
	ld.param.u64 	%rd3, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_11];
	ld.param.u64 	%rd4, [_ZN6cudapy5numba4cuda5tests6cudapy13test_constmem16cuconstRec$24214E5ArrayIdLi1E1C7mutable7alignedE5ArrayIiLi1E1C7mutable7alignedE_param_12];
	cvta.to.global.u64 	%rd5, %rd3;
	cvta.to.global.u64 	%rd6, %rd1;
	mov.u32 	%r1, %tid.x;
	mov.u32 	%r2, %ntid.x;
	mov.u32 	%r3, %ctaid.x;
	mad.lo.s32 	%r4, %r3, %r2, %r1;
	cvt.s64.s32	%rd7, %r4;
	setp.lt.s32	%p1, %r4, 0;
	shr.u32 	%r5, %r4, 30;
	and.b32  	%r6, %r5, 2;
	cvt.u64.u32	%rd8, %r6;
	add.s64 	%rd9, %rd8, %rd7;
	mul.lo.s64 	%rd10, %rd9, 12;
	mov.u64 	%rd11, _cudapy_cmem;
	add.s64 	%rd12, %rd11, %rd10;
	ld.const.u32 	%rd13, [%rd12];
	ld.const.u32 	%rd14, [%rd12+4];
	shl.b64 	%rd15, %rd14, 32;
	or.b64  	%rd16, %rd15, %rd13;
	selp.b64	%rd17, %rd2, 0, %p1;
	add.s64 	%rd18, %rd17, %rd7;
	shl.b64 	%rd19, %rd18, 3;
	add.s64 	%rd20, %rd6, %rd19;
	st.global.u64 	[%rd20], %rd16;
	ld.const.u32 	%r7, [%rd12+8];
	selp.b64	%rd21, %rd4, 0, %p1;
	add.s64 	%rd22, %rd21, %rd7;
	shl.b64 	%rd23, %rd22, 2;
	add.s64 	%rd24, %rd5, %rd23;
	st.global.u32 	[%rd24], %r7;
	ret;
}


' : the compiler realises it doesn't even need to interpret the bytes as float!



----------------------------------------------------------------------

Ran 475 tests in 84.786s

@njwhite
Contributor Author

njwhite commented Sep 4, 2018

@stuartarchibald thanks for that; I've pushed fixes to my branch (ndarray.tobytes returns a str in Python 2, so it was inserting odd non-unicode byte values into the LLVM IR instead of each byte's integer ordinal value!). I've also expanded the ld.const.v2.u64 test to accept the ld.const.u32 that Win64 / Py3 / CUDA 9.1 generated; it's not loading byte-by-byte, but it's not fusing the loads of the record's two fields into a single instruction either :/
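The str-vs-bytes pitfall described above can be sketched as follows (illustrative only, not the exact fix in the branch):

```python
import numpy as np

arr = np.arange(3, dtype=np.int32)

# In Python 2, ndarray.tobytes() returns a str, so iterating it yields
# one-character strings that must go through ord() before being turned
# into LLVM i8 constants; in Python 3, iterating bytes yields ints
# directly. Wrapping in bytearray() yields ints on both versions.
byte_vals = list(bytearray(arr.tobytes()))
assert all(isinstance(b, int) for b in byte_vals)
assert len(byte_vals) == arr.nbytes
```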

@seibert seibert mentioned this pull request Sep 10, 2018
@seibert seibert added BuildFarm Passed For PRs that have been through the buildfarm and passed and removed Pending BuildFarm For PRs that have been reviewed but pending a push through our buildfarm labels Sep 10, 2018
@seibert
Contributor

seibert commented Sep 10, 2018

This is passing on the build farm and ready to merge.

@seibert seibert merged commit 73fc6fe into numba:master Sep 10, 2018
Minor Features automation moved this from In Progress to Done Sep 10, 2018