Skip to content

Commit

Permalink
Fix scatter_nd_add and gather bug (PaddlePaddle#35544)
Browse files Browse the repository at this point in the history
* fix scatter_add_nd and gather bug

* fix gather compile error
  • Loading branch information
sneaxiy authored and niuliling123 committed Sep 29, 2021
1 parent 6d3a29d commit 5ba9e21
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 117 deletions.
104 changes: 53 additions & 51 deletions paddle/fluid/operators/gather.cu.h
Expand Up @@ -32,23 +32,23 @@ template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
T* output, size_t index_size,
size_t slice_size) {
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
IndexT params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}

template <typename T, typename IndexT = int>
__global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
__global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
const IndexT* indices, T* output,
size_t remain_size, size_t slice_size,
size_t end_size) {
CUDA_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
Expand Down Expand Up @@ -92,23 +92,23 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
}

// index size
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];

auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();

int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;

GatherCUDAKernel<T, IndexT><<<
grid, block, 0,
Expand Down Expand Up @@ -143,21 +143,21 @@ void GPUGatherNd(const framework::ExecutionContext& context,
slice_size *= input_dims[i];
}
// source dim
std::vector<int> v_input_dims(input_dims_size);
std::vector<int64_t> v_input_dims(input_dims_size);
for (int i = 0; i < input_dims_size; ++i) {
v_input_dims[i] = static_cast<int>(input_dims[i]);
v_input_dims[i] = input_dims[i];
}

auto& dev_ctx = context.cuda_device_context();
int bytes = input_dims_size * sizeof(int);
int64_t bytes = input_dims_size * sizeof(int64_t);
auto p_input_dims = memory::Alloc(dev_ctx, bytes);
int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
int64_t* g_input_dims = reinterpret_cast<int64_t*>(p_input_dims->ptr());
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
ctx.stream());

int block = 512;
int n = slice_size * remain_numel;
int grid = (n + block - 1) / block;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;

GatherNdCUDAKernel<T, IndexT><<<
grid, block, 0,
Expand All @@ -168,16 +168,16 @@ void GPUGatherNd(const framework::ExecutionContext& context,

template <typename T, typename U>
__global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int outer_dim_size, int inner_dim_size,
int out_index_dim_size,
int input_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int outer_size = outer_dim_size * out_index_dim_size;
int64_t outer_dim_size, int64_t inner_dim_size,
int64_t out_index_dim_size,
int64_t input_index_dim_size, int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t outer_size = outer_dim_size * out_index_dim_size;
for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / outer_size;
int next_idx = idx - outer_size * inner_dim_index;
int index_dim_index = next_idx / outer_dim_size;
int index_val = index[index_dim_index];
int64_t inner_dim_index = idx / outer_size;
int64_t next_idx = idx - outer_size * inner_dim_index;
int64_t index_dim_index = next_idx / outer_dim_size;
U index_val = index[index_dim_index];

PADDLE_ENFORCE(
index_val >= 0 && index_val < input_index_dim_size,
Expand All @@ -187,8 +187,8 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
"be less than [%d] and greater than or equal to 0, but received [%d]",
input_index_dim_size, index_val);

int out_dim_index = next_idx - outer_dim_size * index_dim_index;
int input_index =
int64_t out_dim_index = next_idx - outer_dim_size * index_dim_index;
int64_t input_index =
inner_dim_index * (outer_dim_size * input_index_dim_size) +
index_val * outer_dim_size + out_dim_index;
out[idx] = input[input_index];
Expand All @@ -197,17 +197,19 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,

template <typename T, typename U>
__global__ void GatherGradGPUKernel(const T* input, const U* index, T* out,
int outer_dim_size, int inner_dim_size,
int input_index_dim_size,
int out_index_dim_size, int size) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int64_t outer_dim_size,
int64_t inner_dim_size,
int64_t input_index_dim_size,
int64_t out_index_dim_size, int64_t size) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < size; idx += blockDim.x * gridDim.x) {
int inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
int next_idx = idx % (outer_dim_size * input_index_dim_size);
int index_dim_index = next_idx / (outer_dim_size);
int out_dim_index = next_idx % outer_dim_size;
int out_index = inner_dim_index * (outer_dim_size * out_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
int64_t inner_dim_index = idx / (outer_dim_size * input_index_dim_size);
int64_t next_idx = idx % (outer_dim_size * input_index_dim_size);
int64_t index_dim_index = next_idx / (outer_dim_size);
int64_t out_dim_index = next_idx % outer_dim_size;
int64_t out_index =
inner_dim_index * (outer_dim_size * out_index_dim_size) +
index[index_dim_index] * outer_dim_size + out_dim_index;
paddle::platform::CudaAtomicAdd(out + out_index, *(input + idx));
}
}
Expand All @@ -217,20 +219,20 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
const int axis, Tensor* out,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();
auto* index_data = index->data<U>();

if (input->numel() == 0) return;

int axis_index = axis;
int index_dim_size = input_dim[axis_index];
int64_t index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
int outer_dim_size = 1;
std::vector<int> out_dim_vec;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
std::vector<int64_t> out_dim_vec;

for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
Expand All @@ -245,7 +247,7 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,

out->Resize(out_dim);
auto* out_data = out->mutable_data<T>(place);
int out_size = out->numel();
int64_t out_size = out->numel();

platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
Expand All @@ -262,17 +264,17 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
const paddle::platform::Place& place,
const framework::ExecutionContext& ctx) {
auto* index_data = index->data<U>();
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();

if (input->numel() == 0) return;
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
int outer_dim_size = 1;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;

for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
Expand All @@ -284,7 +286,7 @@ void GatherV2GradCUDAFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims();
int out_index_dim_size = out_dim[axis_index];
int64_t out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0);

platform::GpuLaunchConfig config =
Expand Down
46 changes: 23 additions & 23 deletions paddle/fluid/operators/gather.h
Expand Up @@ -65,10 +65,10 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
T* p_output = output->data<T>();

// slice size
int slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// input size
int input_size = src_dims[0] * slice_size;
int64_t input_size = src_dims[0] * slice_size;

const size_t slice_bytes = slice_size * sizeof(T);

Expand Down Expand Up @@ -144,16 +144,16 @@ template <typename T, typename U>
void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
Tensor* out, const paddle::platform::Place& place) {
auto* index_data = index->data<U>();
int index_size = index->numel();
int input_size = input->numel();
int64_t index_size = index->numel();
int64_t input_size = input->numel();
auto input_dim = input->dims();
auto* input_data = input->data<T>();

if (input->numel() == 0) return;
int axis_index = axis;

int input_index_dim_size = input_dim[axis_index];
for (int i = 0; i < index_size; i++) {
int64_t input_index_dim_size = input_dim[axis_index];
for (int64_t i = 0; i < index_size; i++) {
PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size,
platform::errors::OutOfRange(
"The element of Index must be less than the size of "
Expand All @@ -168,9 +168,9 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
index_data[i], i));
}

int inner_dim_size = 1;
int outer_dim_size = 1;
std::vector<int> out_dim_vec;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
std::vector<int64_t> out_dim_vec;

for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
Expand All @@ -187,11 +187,11 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
auto* out_data = out->mutable_data<T>(place);

int out_index = 0;
for (int i = 0; i < inner_dim_size; i++) {
for (int j = 0; j < index_size; j++) {
for (int k = 0; k < outer_dim_size; k++) {
int index = k + index_data[j] * outer_dim_size +
(i * input_size / inner_dim_size);
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < index_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = k + index_data[j] * outer_dim_size +
(i * input_size / inner_dim_size);
out_data[out_index] = input_data[index];
out_index++;
}
Expand All @@ -210,10 +210,10 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,

if (input->numel() == 0) return;
int axis_index = axis;
int input_index_dim_size = input_dim[axis_index];
int64_t input_index_dim_size = input_dim[axis_index];

int inner_dim_size = 1;
int outer_dim_size = 1;
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;

for (int i = 0; i < axis_index; i++) {
inner_dim_size *= input_dim[i];
Expand All @@ -225,14 +225,14 @@ void GatherV2GradFunction(const Tensor* input, const Tensor* index,
auto* out_data = out->mutable_data<T>(place);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto out_dim = out->dims();
int out_index_dim_size = out_dim[axis_index];
int64_t out_index_dim_size = out_dim[axis_index];
operators::math::set_constant(*dev_ctx, out, 0.0);

for (int i = 0; i < inner_dim_size; i++) {
for (int j = 0; j < input_index_dim_size; j++) {
for (int k = 0; k < outer_dim_size; k++) {
int index = k + index_data[j] * outer_dim_size +
i * outer_dim_size * out_index_dim_size;
for (int64_t i = 0; i < inner_dim_size; i++) {
for (int64_t j = 0; j < input_index_dim_size; j++) {
for (int64_t k = 0; k < outer_dim_size; k++) {
int64_t index = k + index_data[j] * outer_dim_size +
i * outer_dim_size * out_index_dim_size;
out_data[index] += input_data[j * outer_dim_size + k];
}
}
Expand Down

0 comments on commit 5ba9e21

Please sign in to comment.