This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

cudaErrorIllegalAddress error in DeviceSelect::If() #78

Closed
daktfi opened this issue Dec 23, 2016 · 5 comments

daktfi commented Dec 23, 2016

Being still new to CUDA and CUB, I'm not quite sure it's a CUB error and not mine, but this time I did my very best to exclude all other sources.

Setup: GeForce GTX 950, CUDA 8.0, CUB 1.6.4.

I allocated two arrays of uint3 (15 million elements each) plus the 313087 bytes of temp_storage that DeviceSelect::If() requested. The compare functor is adapted from the snippet to match the element type:

struct Check {
	int compare;
	CUB_RUNTIME_FUNCTION __forceinline__
	Check( int compare ) : compare( compare ) {}

	CUB_RUNTIME_FUNCTION __forceinline__
	bool operator()( const uint3 &a ) const {
		return ( a.x > compare );
	}
};

This fails to compile with the following error:

/home/daktfi/cub/cub/device/dispatch/../../agent/agent_select_if.cuh(279): error: calling a host function("Check::operator ()") from a device function("cub::AgentSelectIf< ::cub::DispatchSelectIf< ::uint3 *, ::cub::NullType *, ::uint3 *, unsigned int *, ::Check, ::cub::NullType, int, (bool)0> ::PtxSelectIfPolicyT, ::uint3 *, ::cub::NullType *, ::uint3 *, ::Check, ::cub::NullType, int, (bool)0> ::InitializeSelections<(bool)1, (bool)0> ") is not allowed

So, I changed the declaration to

struct Check {
	int compare;
	__host__ __device__ __forceinline__
	Check( int compare ) : compare( compare ) {}

	__host__ __device__ __forceinline__
	bool operator()( const uint3 &a ) const {
		return ( a.x > compare );
	}
};

specifying both host and device for the methods. This way the code compiles, but after launching DeviceSelect::If() and calling cudaDeviceSynchronize(), cudaPeekAtLastError() returns 77 (cudaErrorIllegalAddress). I'm fairly sure this isn't carried over from earlier asynchronous calls (after every device-related call I invoke cudaDeviceSynchronize() and check cudaPeekAtLastError()), but I'm at a loss as to where else it could come from. All device code is executed in non-default CUDA streams created at the very beginning.
The code snippet is:

Check ch0( 0 );
// Memory is allocated at d_ptr, size equals mem_req, counted ahead of time
uint3 *ptr_a = ( uint3 * )d_ptr, *ptr_b = ptr_a + row_count;
void *d_tmp = ptr_b + row_count;

fprintf( stderr, "DevSelect( %p, %lu, %p, %p, %lu ) %p of %p\n", d_tmp, tmp_select, ptr_a,
	 ptr_b, row_count, ( char * ) d_tmp + tmp_select, ( char * )d_ptr + mem_req );

size_t tmp_sel_ask;
cub::DeviceSelect::If( nullptr, tmp_sel_ask, ptr_a, ptr_b, &length, row_count, ch0, s, true );

cudaDeviceSynchronize();
rc = cudaPeekAtLastError();

if( rc ) {
	std::cerr << "Some cuda error " << cudaGetLastError() << std::endl;
	cudaFree( d_ptr );
	return rc;
} else
	std::cerr << "DevSelect ready, needs " << tmp_sel_ask << " of " << tmp_select << std::endl;

if( tmp_sel_ask > tmp_select ) {
	cudaFree( d_ptr );
	return rc;
} else
	std::cerr << "DevSelect clear to go" << std::endl;

cub::DeviceSelect::If( d_tmp, tmp_select, ptr_a, ptr_b, &length, row_count, ch0, s, true );

cudaDeviceSynchronize();
rc = cudaPeekAtLastError();

if( rc ) {
	std::cerr << "DevSelect error " << cudaGetLastError() << std::endl;
	cudaFree( d_ptr );
	return rc;
} else
	std::cerr << "DevSelect ok" << std::endl;

The output is:

DevSelect( 0xb188d2a00, 313087, 0xb03180000, 0xb0dd29500, 15000000 ) 0xb1891f0ff of 0xb1891f0ff
DevSelect ready, needs 313087 of 313087
DevSelect clear to go
Invoking scan_init_kernel<<<306, 128, 0, 140458720343936>>>()
DevSelect error 77

Could you please look into this and point me to the right way to fix it?
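The check-after-every-call pattern described above can be factored into one helper. The following is only a sketch: `sync_and_check` is a hypothetical name, not part of CUDA or CUB, and it assumes that synchronizing the whole device after each step is acceptable (which is usually true only while debugging).

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper (not a CUDA or CUB API): waits for all outstanding
// device work, then reports any error and returns its code.
inline cudaError_t sync_and_check(const char *what)
{
    // cudaDeviceSynchronize() itself returns errors raised by earlier
    // asynchronous work, so its return value is checked first.
    cudaError_t rc = cudaDeviceSynchronize();

    // cudaGetLastError() returns the last recorded error and resets it,
    // whereas cudaPeekAtLastError() leaves it in place.
    cudaError_t last = cudaGetLastError();
    if (rc == cudaSuccess)
        rc = last;

    if (rc != cudaSuccess)
        fprintf(stderr, "%s failed: %s (%d): %s\n", what,
                cudaGetErrorName(rc), (int)rc, cudaGetErrorString(rc));
    return rc;
}
```

A call such as `if (sync_and_check("DeviceSelect::If")) return 1;` after each step then tells you which step first reported the failure.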


daktfi commented Dec 26, 2016

Running my app under cuda-memcheck, I got the following message:
Invoking scan_init_kernel<<<306, 128, 0, 139992313554640>>>()

========= Invalid __global__ write of size 4
=========     at 0x00000128 in void cub::DeviceCompactInitKernel<cub::ScanTileState<int, bool=1>, unsigned int*>(int, int, bool=1)
=========     by thread (0,0,0) in block (0,0,0)
=========     Address 0x7f52921b9268 is out of bounds
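The faulting address in this report (0x7f52921b9268) is nowhere near the 0xb... device pointers printed in the DevSelect output earlier, which hints that a host address reached the kernel. One way to test a suspicious pointer is `cudaPointerGetAttributes`. The sketch below assumes CUDA 10 or newer, where `cudaPointerAttributes` has a `type` field (CUDA 8.0, used in this thread, exposes `memoryType` instead). `is_device_allocation` is a hypothetical helper name.

```cuda
#include <cuda_runtime.h>

// Hypothetical helper: true if `p` points into a device or managed
// allocation. Assumes CUDA 10+ (cudaPointerAttributes::type).
inline bool is_device_allocation(const void *p)
{
    cudaPointerAttributes attr;
    cudaError_t rc = cudaPointerGetAttributes(&attr, p);
    if (rc != cudaSuccess) {
        // Older toolkits report plain host pointers as an error;
        // clear it so later error checks are not confused.
        cudaGetLastError();
        return false;
    }
    return attr.type == cudaMemoryTypeDevice ||
           attr.type == cudaMemoryTypeManaged;
}
```

Running this over every pointer argument before the CUB call would flag any argument that is not a device or managed allocation.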


daktfi commented Dec 26, 2016

After compiling the code with the -lineinfo option, I got the line number:
at 0x00000128 in /home/daktfi/polymatica/cub/cub/device/dispatch/dispatch_scan.cuh:88:void cub::DeviceCompactInitKernel<cub::ScanTileState<int, bool=1>, unsigned int*>(int, int, bool=1)
Hope this helps.
The -G option did not work out: my code is part of a larger project and at the moment I'm unable to compile it separately, since according to NVIDIA/thrust#864 it requires an extra nvcc linking pass.


daktfi commented Dec 27, 2016

I finally have a working standalone project (attached) that reproduces the error, and I have commented it. The logic behind the issue is fairly complicated, so if you have questions, I'll gladly explain.

devselect.zip

dumerrill (Contributor) commented

Hi, I think the "an illegal memory access was encountered" error is probably caused by cub::DeviceSelect::If() expecting a pointer to device memory for where to output the number of items selected. (You are passing the address of the automatic variable length, which lives on the host thread's program stack and isn't visible to GPU threads.) You'll want to create a device allocation for that.

Let me know if that helps!
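A minimal sketch of that suggestion, under a few assumptions: `select_positive_x` is a hypothetical helper name, the temporary storage is allocated separately rather than carved out of one large buffer as in the original snippet, and the optional `debug_synchronous` argument is omitted. The `Check` functor is the one from the issue.

```cuda
#include <cuda_runtime.h>
#include <cub/cub.cuh>

// Predicate from the issue: keeps elements whose x component exceeds `compare`.
struct Check {
    int compare;
    __host__ __device__ __forceinline__
    explicit Check(int compare) : compare(compare) {}

    __host__ __device__ __forceinline__
    bool operator()(const uint3 &a) const { return a.x > compare; }
};

// Hypothetical helper: copies the elements of d_in with x > 0 to d_out on
// stream `s` and stores the number of selected items in *h_count.
// d_in and d_out must be device pointers holding num_items elements each.
cudaError_t select_positive_x(uint3 *d_in, uint3 *d_out, int num_items,
                              cudaStream_t s, int *h_count)
{
    int   *d_num_selected = nullptr;  // the counter must live in device memory
    void  *d_temp = nullptr;
    size_t temp_bytes = 0;
    Check  op(0);

    cudaError_t rc = cudaMalloc(&d_num_selected, sizeof(int));

    // Pass 1: with d_temp == nullptr, CUB only reports the temp storage size.
    if (rc == cudaSuccess)
        rc = cub::DeviceSelect::If(d_temp, temp_bytes, d_in, d_out,
                                   d_num_selected, num_items, op, s);
    if (rc == cudaSuccess)
        rc = cudaMalloc(&d_temp, temp_bytes);

    // Pass 2: the actual selection.
    if (rc == cudaSuccess)
        rc = cub::DeviceSelect::If(d_temp, temp_bytes, d_in, d_out,
                                   d_num_selected, num_items, op, s);

    // Copy the count back to the host once the stream has finished.
    if (rc == cudaSuccess)
        rc = cudaMemcpyAsync(h_count, d_num_selected, sizeof(int),
                             cudaMemcpyDeviceToHost, s);
    if (rc == cudaSuccess)
        rc = cudaStreamSynchronize(s);

    cudaFree(d_temp);           // cudaFree(nullptr) is a no-op
    cudaFree(d_num_selected);
    return rc;
}
```

The only change that matters for the reported error is that the fifth argument now points to device memory; the host reads the count with an explicit copy afterwards.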


daktfi commented Jan 13, 2017

Only one word: "oops". :-(
The most stupid bugs are hardest to find...
Worked like a charm! Thanks a lot!!!

daktfi closed this as completed Jan 13, 2017