COO GPU spmv kernel #57
Conversation
Does the code run for you? This is what I get from the tester:

works now - looks good to me!

@gflegar no conflicts here.
gpu/matrix/coo_kernels.cu
Outdated
__forceinline__ __device__ static cuDoubleComplex atomic_add(
    cuDoubleComplex *address, cuDoubleComplex val)
{
    // Seperate to real part and imag part
nit: s/seperate/separate : here and a few places below.
gpu/matrix/coo_kernels.cu
Outdated
template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(128) void spmv_kernel(
As this is the actual kernel, some documentation explaining how it works would be nice.
gpu/matrix/coo_kernels.cu
Outdated
    const matrix::Coo<ValueType, IndexType> *a,
    const matrix::Dense<ValueType> *b, matrix::Dense<ValueType> *c)
{
    int multiple = 8;
Hard-coding values like this is always difficult to warrant. I guess here it is probably okay, but it should always be accompanied by some metrics/reasoning as to why this value was chosen.
I agree with you that it shouldn't look like this.
However, AFAIK, these are tuned parameters obtained from running experiments (and they might differ between architectures). I also have a similar problem in #49, where they are hardcoded as well.
The problem is that we haven't solved this in Ginkgo yet - the idea is to be able to say something like: "this is a parameter that should be tuned, valid values are from the set S, and use x by default". Then the user should be able to do an autotuning run of Ginkgo, which will set the parameters to the proper values for their system.
However we do it right now, there's a good chance we'll have to change it when we implement autotuning, so I'm fine for now with having it like this, plus adding a "TODO" that this should be changed once we support autotuning.
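In the meantime, the compromise above could be written as a named constant carrying that TODO (a sketch; the constant name is made up here, and the value 8 is just the one currently in the code):

```cuda
// TODO: turn this into an autotuned parameter once Ginkgo supports
// autotuning; the value was obtained experimentally and may differ
// between architectures.
constexpr int oversubscription_multiple = 8;
```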
Looks good. However, there is some code duplication, see my comments for suggestions on how to resolve it.
gpu/matrix/coo_kernels.cu
Outdated
#include "gpu/components/shuffle.cuh"
#include "gpu/components/synchronization.cuh"

namespace gko {
nit: another empty line above this
gpu/matrix/coo_kernels.cu
Outdated
    cuDoubleComplex *cuval = reinterpret_cast<cuDoubleComplex *>(&val);
    atomic_add(cuaddr, *cuval);
    return *address;
}
Let's do the following for all overloads of `atomic_add`:

- All of them could be useful in other algorithms, so let's move them into a separate header, say `gpu/components/atomic.cuh`, and nest them inside the `gko::kernels::gpu` namespace.
- There is no need to have them in an anonymous namespace, `static`, and `__forceinline__` all at once. Any one of them will prevent symbol ambiguities. In this case, we can forget about the anonymous namespace and the `static` qualifier and just use `__forceinline__`.
- There's no need for the `cuComplex` and `cuDoubleComplex` overloads - we always use `thrust::complex`; the other types are only here for cuBLAS / cuSPARSE interoperability.
- The complex versions are incorrect, as they return the new value instead of the old one. I don't think it's possible to implement this correctly without significant performance penalties (basically, we would have to implement a mutex to do it properly), so I suggest we just remove the return value and have all our `atomic_add` functions return `void`.
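Under those suggestions, the complex overload could be sketched roughly as follows (an illustrative sketch, assuming compute capability >= 6.0 for the native double-precision `atomicAdd`, and that `thrust::complex<double>` is laid out as two contiguous doubles):

```cuda
#include <thrust/complex.h>

// Sketch of a void-returning atomic_add for thrust::complex<double>.
// Real and imaginary parts are accumulated by two independent atomicAdd
// calls: the pair is not updated atomically as a whole, but every
// contribution is eventually applied, and no (necessarily stale) old
// value is returned.
__forceinline__ __device__ void atomic_add(thrust::complex<double> *address,
                                           thrust::complex<double> val)
{
    auto addr = reinterpret_cast<double *>(address);
    atomicAdd(addr, val.real());
    atomicAdd(addr + 1, val.imag());
}
```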
gpu/matrix/coo_kernels.cu
Outdated
{
    ValueType add_val;
#pragma unroll
    for (int i = 1; i < 32; i <<= 1) {
s/32/cuda_config::warp_size
also do this everywhere in the code - a good rule of thumb is to never rely on "magic values" like `32` or `128`. Instead, define these values as constants somewhere and use the constants. This way the code is easier to understand (you see something like `warp_size`, or `warps_in_block * warp_size`, and not something opaque like `32` or `128`), and easier to maintain if you ever need to change the value.
gpu/matrix/coo_kernels.cu
Outdated
template <typename ValueType, typename IndexType>
__device__ __forceinline__ void segment_scan(IndexType *ind, ValueType *val,
nit: `ind` is not modified by this function, so it should be just a plain variable
gpu/matrix/coo_kernels.cu
Outdated
#pragma unroll
    for (int i = 1; i < 32; i <<= 1) {
        const IndexType add_ind = warp::shuffle_up(*ind, i);
        add_val = zero<ValueType>();
nit: In contrast to C and Fortran, C++ allows you to declare variables wherever you want, and the guidelines advise avoiding uninitialized variables and defining them as late as possible.
Here, you should remove `ValueType add_val;` from the beginning of the function, and in this line use: `auto add_val = zero<ValueType>();`
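Taken together, the suggested changes might look roughly like this (an illustrative sketch only: the head-flag bookkeeping of the real function is omitted, and `cuda_config::warp_size`, `warp::shuffle_up`, and `zero` are assumed from Ginkgo's GPU components):

```cuda
// Sketch of segment_scan after the suggested changes: `ind` taken by
// value, `add_val` declared at first use, and the warp size referenced
// through a named constant instead of the literal 32.
template <typename ValueType, typename IndexType>
__device__ __forceinline__ void segment_scan(const IndexType ind,
                                             ValueType *val)
{
#pragma unroll
    for (int i = 1; i < cuda_config::warp_size; i <<= 1) {
        const IndexType add_ind = warp::shuffle_up(ind, i);
        // shuffle unconditionally so every lane participates in the
        // warp-wide exchange, then accumulate only when the lane i steps
        // up belongs to the same row segment
        const auto up_val = warp::shuffle_up(*val, i);
        auto add_val = zero<ValueType>();
        if (threadIdx.x % cuda_config::warp_size >= i && add_ind == ind) {
            add_val = up_val;
        }
        *val += add_val;
    }
}
```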
gpu/matrix/coo_kernels.cu
Outdated
    IndexType next_row;
    for (; ind < ind_end; ind += 32) {
        temp_val += (ind < nnz) ? val[ind] * b[col[ind]] : zero<ValueType>();
        next_row = (ind + 32 < nnz) ? row[ind + 32] : row[nnz - 1];
nit: same as before - `auto next_row = ...`, and remove the declaration of `next_row` before the loop
gpu/matrix/coo_kernels.cu
Outdated
        next_row = (ind + 32 < nnz) ? row[ind + 32] : row[nnz - 1];
        // segmented scan
        const bool is_scan = temp_row != next_row;
        if (warp::any(is_scan)) {
nit: why not just `if (warp::any(curr_row != next_row))` - it reads just fine: "if any thread has a different `curr_row` than `next_row`". The version with a temporary bool value is more confusing than this.
gpu/matrix/coo_kernels.cu
Outdated
        if (warp::any(is_scan)) {
            atomichead = true;
            segment_scan(&temp_row, &temp_val, &atomichead);
            if (atomichead) {
nit: maybe s/atomichead/is_first_in_segment would make it clearer
gpu/matrix/coo_kernels.cu
Outdated
    int nwarps = config * multiple;
    if (nwarps > ceildiv(nnz, 32)) {
        nwarps = ceildiv(nnz, 32);
    }
This part is repeated in both `spmv` kernels. We should extract it into a separate function to avoid code duplication
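For illustration, the extracted helper could look roughly like this host-side sketch (`calculate_nwarps` is a made-up name; `ceildiv` is reimplemented here for self-containment, and `warp_size` stands in for the literal 32):

```cpp
#include <cstdint>

// Assumed warp size; in Ginkgo this would come from cuda_config.
constexpr std::int64_t warp_size = 32;

// Minimal stand-in for Ginkgo's ceildiv: integer division rounding up.
inline std::int64_t ceildiv(std::int64_t num, std::int64_t den)
{
    return (num + den - 1) / den;
}

// Sketch of the extracted function: cap the launch's warp count at the
// number of warp-sized chunks of nonzeros, so no warp is left idle.
inline std::int64_t calculate_nwarps(std::int64_t nnz, std::int64_t config,
                                     std::int64_t multiple)
{
    const auto warps_for_nnz = ceildiv(nnz, warp_size);
    const auto nwarps = config * multiple;
    return nwarps > warps_for_nnz ? warps_for_nnz : nwarps;
}
```

Both kernels would then call this one function instead of repeating the clamp inline.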
gpu/matrix/coo_kernels.cu
Outdated
            atomic_add(&(c[temp_row]), alpha_val * temp_val);
        }
    }
}
Is the only difference between this kernel and the other one in how `atomic_add` is called (this one with alpha, the other one without it)? If so, I propose we do the following to remove code duplication:

- extract this entire kernel into a device function,
- add a functor template parameter to it:

      template </* other template parameters */, typename Closure>
      __device__ /* specifiers */ spmv(/* parameters */, Closure scale)
      { /* body */ }

- replace `atomic_add(/* destination */, alpha_val * temp_val)` with `atomic_add(/* destination */, scale(temp_val));` in the device function body.
- implement the "simple" apply kernel as:

      spmv(/* parameters */, [](const ValueType &x) { return x; });

- implement the "advanced" apply kernel as:

      ValueType scale_factor = alpha[0];
      spmv(/* parameters */, [&scale_factor](const ValueType &x) { return scale_factor * x; });
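The scaling-closure idea can be demonstrated with a plain C++ host-side sketch (all names here are illustrative, not the actual Ginkgo code; the real version would be a `__device__` function taking the full spmv parameter list):

```cpp
// One shared "kernel body" parameterized by a scaling closure; both apply
// variants reuse it and differ only in the closure they pass.
template <typename ValueType, typename Closure>
ValueType spmv_body(ValueType temp_val, Closure scale)
{
    // in the real kernel this line would be:
    // atomic_add(&c[temp_row], scale(temp_val));
    return scale(temp_val);
}

// "simple" apply: identity closure
template <typename ValueType>
ValueType simple_apply(ValueType v)
{
    return spmv_body(v, [](const ValueType &x) { return x; });
}

// "advanced" apply: closure capturing the scaling factor alpha
template <typename ValueType>
ValueType advanced_apply(ValueType v, ValueType scale_factor)
{
    return spmv_body(v,
        [&scale_factor](const ValueType &x) { return scale_factor * x; });
}
```

The closure is inlined at compile time, so this removes the duplication without a runtime dispatch cost.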
Thanks for your help.
gpu/components/atomic.cuh
Outdated
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600))

__forceinline__ __device__ static double atomic_add(double *addr, double val)
this should also return void
and no need for static
fix it
@yhmtsai yes, as it is written now (with functions returning void), the complex versions of atomic add are "atomic enough" for our use-case (atomic enough = if we have multiple calls to atomic add, they will all accumulate their values, even though it won't exactly be done atomically). Everything looks good now. There's just this minor thing with the wrong return value in the Kepler implementation of double atomic add. We can merge this as soon as it is fixed.
LGTM
PR #57 introduced atomic operations for standard types, but not for custom value types, which causes compilation to fail for other value types (e.g. FloatX numbers). This PR adds a dummy implementation of the general atomic_add (which just causes an assertion failure), so Ginkgo at least compiles with custom types, even though parts which use it do not work properly.