diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
index a897fe505d4e2..1894e6102cbea 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cc
@@ -107,6 +107,10 @@ class FusedSeqpoolCVMOpMaker : public framework::OpProtoAndCheckerMaker {
               "(float, default 0.0) The value to pad for empty sequence.")
         .SetDefault(0.0);
     AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
+    AddAttr<bool>("need_filter", "(bool, default false)").SetDefault(false);
+    AddAttr<float>("show_coeff", "(float, default 0.2)").SetDefault(0.2);
+    AddAttr<float>("clk_coeff", "(float, default 1)").SetDefault(1);
+    AddAttr<float>("threshold", "(float, default 0.96)").SetDefault(0.96);
     AddComment(R"DOC(
 Fuse multiple pairs of Sequence Pool and CVM Operator.
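The four attributes added above configure an optional instance filter inside the fused kernel: when need_filter is true, an embedding row is skipped whenever its weighted show/click score falls below threshold. A minimal sketch of that predicate, written as a standalone helper for illustration only (the helper name is not part of the operator):

    // Illustrative only: the predicate the forward kernel applies per instance.
    // `show` and `click` are the first two columns of the embedding row.
    inline bool KeepInstance(float show, float click, float show_coeff,
                             float clk_coeff, float threshold) {
      return (show - click) * show_coeff + click * clk_coeff >= threshold;
    }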
diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
index 17bf4874f6b6b..eae1120200c43 100644
--- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
+++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op.cu
@@ -26,59 +26,59 @@ using platform::PADDLE_CUDA_NUM_THREADS;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-__global__ void FillKey2Slot(const int total_len, const int64_t *slot_lens,
-                             const int slot_num, int *key2slot) {
-  CUDA_KERNEL_LOOP(i, total_len) {
-    int low = 0;
-    int high = slot_num - 1;
-    while (low < high) {
-      int mid = (low + high) / 2;
-      if (i < slot_lens[mid]) {
-        high = mid;
-      } else {
-        low = mid + 1;
-      }
-    }
-    key2slot[i] = low;
-  }
-}
-
-__global__ void FusedSeqpoolKernel(float **input_values,
-                                   float **seqpool_output_values,
-                                   size_t **lods_values,
-                                   const int64_t *data_lens, int *key2slot,
-                                   int64_t total_len, const int embedding_size,
-                                   const float pad_value) {
-  CUDA_KERNEL_LOOP(i, total_len * embedding_size) {
-    int key = i / embedding_size;
-    int offset = i % embedding_size;
-    int x = key2slot[key];
-    int y = key - (x ? data_lens[x - 1] : 0);
-
-    int start = *(lods_values[x] + y);
-    int end = *(lods_values[x] + y + 1);
-
+template <typename T>
+__global__ void FusedSeqpoolKernel(
+    T **input_values, T **seqpool_output_values, size_t **lods_values,
+    const int64_t *data_lens, const int batch_size, const int embedding_size,
+    const float pad_value, bool need_filter, float show_coeff, float clk_coeff,
+    float threshold) {
+  int bId = blockIdx.y * gridDim.x + blockIdx.x;
+  int x = bId / batch_size;
+  int y = bId - (x ? data_lens[x - 1] : 0);
+  int start = *(lods_values[x] + y);
+  int end = *(lods_values[x] + y + 1);
+
+  for (int tid = threadIdx.x; tid < embedding_size; tid += blockDim.x) {
     if (start == end) {
-      *(seqpool_output_values[x] + y * embedding_size + offset) = pad_value;
+      *(seqpool_output_values[x] + y * embedding_size + tid) = pad_value;
     } else {
-      float val = 0;
-      for (int k = start; k < end; k++) {
-        val += *(input_values[x] + k * embedding_size + offset);
+      if (need_filter) {
+        T val = static_cast<T>(0);
+        for (int k = start; k < end; k++) {
+          float show = *(input_values[x] + k * embedding_size);
+          float click = *(input_values[x] + k * embedding_size + 1);
+          if ((show - click) * show_coeff + click * clk_coeff < threshold) {
+            continue;
+          }
+          if (tid <= 1) {  // show & click
+            val += *(input_values[x] + k * embedding_size + tid);
+          } else {
+            val += ((int)(*(input_values[x] + k * embedding_size + tid) * 128 +
+                          0.5)) /
+                   128.0;
+          }
+        }
+        *(seqpool_output_values[x] + y * embedding_size + tid) = val;
+      } else {
+        T val = static_cast<T>(0);
+        for (int k = start; k < end; k++) {
+          val += *(input_values[x] + k * embedding_size + tid);
+        }
+        *(seqpool_output_values[x] + y * embedding_size + tid) = val;
       }
-      *(seqpool_output_values[x] + y * embedding_size + offset) = val;
     }
   }
 }
 
-__global__ void FusedCVMKernel(float **output_values,
-                               float **seqpool_output_values,
-                               const int64_t *data_lens, int *key2slot,
+template <typename T>
+__global__ void FusedCVMKernel(T **output_values, T **seqpool_output_values,
+                               const int64_t *data_lens, const int batch_size,
                                int64_t total_len, const int embedding_size,
                                bool use_cvm) {
   CUDA_KERNEL_LOOP(i, total_len * embedding_size) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
-    int x = key2slot[key];
+    int x = key / batch_size;
     int y = key - (x ? data_lens[x - 1] : 0);
     int cvm_offset = 2;
     if (use_cvm) {
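In the need_filter branch of the forward kernel above, the show and click columns (tid <= 1) are accumulated exactly, while every other embedding column is first snapped to a 1/128 grid before being added to the pooled sum. A standalone sketch of that rounding step, assuming the same truncating int cast the kernel uses:

    // Illustrative sketch of the kernel's rounding: snap a non-negative value
    // to the nearest multiple of 1/128 before it is accumulated.
    inline float RoundTo128th(float v) {
      return ((int)(v * 128 + 0.5)) / 128.0f;
    }
    // e.g. RoundTo128th(0.3f) == 38 / 128.0f == 0.296875f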
@@ -103,15 +103,16 @@ __global__ void FusedCVMKernel(float **output_values,
   }
 }
 
+template <typename T>
 __global__ void FusedSeqpoolCVMGradKernel(
-    float **out_grads_values, float **out_seqpool_grads_values,
-    float **in_grads_values, float **cvm_values, size_t **lods_values,
-    const int64_t *data_lens, int *key2slot, int64_t total_len,
-    const int embedding_size, bool use_cvm) {
+    T **out_grads_values, T **out_seqpool_grads_values, T **in_grads_values,
+    T **cvm_values, size_t **lods_values, const int64_t *data_lens,
+    const int batch_size, int64_t total_len, const int embedding_size,
+    bool use_cvm) {
   CUDA_KERNEL_LOOP(i, total_len * embedding_size) {
     int key = i / embedding_size;
     int offset = i % embedding_size;
-    int x = key2slot[key];
+    int x = key / batch_size;
     int y = key - (x ? data_lens[x - 1] : 0);
     int cvm_offset = 2;
 
@@ -125,8 +126,8 @@ __global__ void FusedSeqpoolCVMGradKernel(
           *(out_grads_values[x] + y * embedding_size + offset);
     } else {
       *(out_seqpool_grads_values[x] + y * embedding_size + offset) =
-          *(out_grads_values[x] + y * (embedding_size - cvm_offset) +
-           offset - cvm_offset);
+          *(out_grads_values[x] + y * (embedding_size - cvm_offset) + offset -
+          cvm_offset);
     }
   }
 
@@ -139,44 +140,44 @@ __global__ void FusedSeqpoolCVMGradKernel(
   }
 }
 
+template <typename T>
 void DoFusedSeqpoolCVM(const paddle::platform::Place &place,
-                       float **gpu_input_values, float **gpu_output_values,
-                       float **gpu_seqpool_output_values, size_t **lods_values,
-                       const int64_t *data_lens, int slot_num, int *key2slot,
+                       T **gpu_input_values, T **gpu_output_values,
+                       T **gpu_seqpool_output_values, size_t **lods_values,
+                       const int64_t *data_lens, int slot_num,
                        int64_t total_len, const int embedding_size,
-                       const float padding_value, bool use_cvm) {
+                       const float padding_value, bool use_cvm,
+                       bool need_filter, float show_coeff, float clk_coeff,
+                       float threshold) {
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
-                    platform::DeviceContextPool::Instance().Get(
-                        BOOST_GET_CONST(platform::CUDAPlace, place)))
-                    ->stream();
-
-  FillKey2Slot<<<(total_len + PADDLE_CUDA_NUM_THREADS - 1) /
-                     PADDLE_CUDA_NUM_THREADS,
-                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(total_len, data_lens,
-                                                       slot_num, key2slot);
+                    platform::DeviceContextPool::Instance().Get(
+                        BOOST_GET_CONST(platform::CUDAPlace, place)))
+                    ->stream();
 
-  FusedSeqpoolKernel<<<(total_len * embedding_size + PADDLE_CUDA_NUM_THREADS -
-                        1) /
-                           PADDLE_CUDA_NUM_THREADS,
-                       PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+  int batch_size = total_len / slot_num;
+  dim3 grid(batch_size, slot_num);
+  FusedSeqpoolKernel<<<grid, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
       gpu_input_values, gpu_seqpool_output_values, lods_values, data_lens,
-      key2slot, total_len, embedding_size, padding_value);
+      batch_size, embedding_size, padding_value, need_filter, show_coeff,
+      clk_coeff, threshold);
 
   FusedCVMKernel<<<(total_len * embedding_size + PADDLE_CUDA_NUM_THREADS - 1) /
-                       PADDLE_CUDA_NUM_THREADS,
-                   PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
-      gpu_output_values, gpu_seqpool_output_values, data_lens, key2slot,
-      total_len, embedding_size, use_cvm);
+                       PADDLE_CUDA_NUM_THREADS,
+                   PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
+      gpu_output_values, gpu_seqpool_output_values, data_lens, batch_size,
+      total_len, embedding_size, use_cvm);
 }
 
+template <typename T>
 void FusedSeqpoolCVM(const paddle::platform::Place &place,
-                     const std::vector<const float *> &input_data,
-                     const std::vector<float *> &output_data,
-                     const std::vector<float *> &seqpool_output_data,
+                     const std::vector<const T *> &input_data,
+                     const std::vector<T *> &output_data,
+                     const std::vector<T *> &seqpool_output_data,
                      std::vector<const size_t *> lods,
                      const std::vector<int64_t> &data_lengths,
                      const int embedding_size, const float padding_value,
-                     const bool use_cvm) {
+                     const bool use_cvm, float need_filter, float show_coeff,
+                     float clk_coeff, float threshold) {
   auto data_lengths_lod = data_lengths;
   int slot_num = static_cast<int>(data_lengths.size());
   for (int i = 1; i < slot_num; i++) {
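With the launch scheme introduced above, every slot contributes exactly batch_size rows, so the old FillKey2Slot binary search is no longer needed: one block per (slot, row) pair and a single division recover both indices. A sketch of the index math under that assumption; note that the exact block dimension passed at launch is an assumption of this reconstruction (PADDLE_CUDA_NUM_THREADS is used because the file already defines it), not something the hunk pins down:

    // Illustrative index math for a dim3(batch_size, slot_num) grid.
    // data_lens holds cumulative row counts, so data_lens[x - 1] == x * batch_size.
    __device__ inline void DecodeBlock(int bId, int batch_size,
                                       const int64_t *data_lens, int *slot,
                                       int *row) {
      int x = bId / batch_size;                  // slot index
      int y = bId - (x ? data_lens[x - 1] : 0);  // row within slot x
      *slot = x;
      *row = y;
    }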
@@ -184,16 +185,11 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
   }
   int64_t total_length = data_lengths_lod[slot_num - 1];
-  int64_t total_bytes = total_length * embedding_size * sizeof(float);
 
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
-                    platform::DeviceContextPool::Instance().Get(
-                        BOOST_GET_CONST(platform::CUDAPlace, place)))
-                    ->stream();
-
-  LoDTensor key2slot_tensor;
-  int *key2slot = reinterpret_cast<int *>(
-      key2slot_tensor.mutable_data<int>({total_length, 1}, place));
+                    platform::DeviceContextPool::Instance().Get(
+                        BOOST_GET_CONST(platform::CUDAPlace, place)))
+                    ->stream();
 
   LoDTensor data_lens_tensor;
   int64_t *data_lens = reinterpret_cast<int64_t *>(
@@ -203,25 +199,25 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
                   cudaMemcpyHostToDevice, stream);
 
   auto gpu_input_ptr =
-      memory::AllocShared(place, input_data.size() * sizeof(float *));
-  float **gpu_input_values = reinterpret_cast<float **>(gpu_input_ptr->ptr());
+      memory::AllocShared(place, input_data.size() * sizeof(T *));
+  T **gpu_input_values = reinterpret_cast<T **>(gpu_input_ptr->ptr());
   cudaMemcpyAsync(gpu_input_values, input_data.data(),
-                  input_data.size() * sizeof(float *), cudaMemcpyHostToDevice,
+                  input_data.size() * sizeof(T *), cudaMemcpyHostToDevice,
                   stream);
 
   auto gpu_output_ptr =
-      memory::AllocShared(place, output_data.size() * sizeof(float *));
-  float **gpu_output_values = reinterpret_cast<float **>(gpu_output_ptr->ptr());
+      memory::AllocShared(place, output_data.size() * sizeof(T *));
+  T **gpu_output_values = reinterpret_cast<T **>(gpu_output_ptr->ptr());
   cudaMemcpyAsync(gpu_output_values, output_data.data(),
-                  output_data.size() * sizeof(float *), cudaMemcpyHostToDevice,
+                  output_data.size() * sizeof(T *), cudaMemcpyHostToDevice,
                   stream);
 
   auto gpu_seqpool_output_ptr =
-      memory::AllocShared(place, seqpool_output_data.size() * sizeof(float *));
-  float **gpu_seqpool_output_values =
-      reinterpret_cast<float **>(gpu_seqpool_output_ptr->ptr());
+      memory::AllocShared(place, seqpool_output_data.size() * sizeof(T *));
+  T **gpu_seqpool_output_values =
+      reinterpret_cast<T **>(gpu_seqpool_output_ptr->ptr());
   cudaMemcpyAsync(gpu_seqpool_output_values, seqpool_output_data.data(),
-                  seqpool_output_data.size() * sizeof(float *),
+                  seqpool_output_data.size() * sizeof(T *),
                   cudaMemcpyHostToDevice, stream);
 
   auto lods_ptr = memory::AllocShared(place, lods.size() * sizeof(size_t *));
@@ -231,8 +227,8 @@ void FusedSeqpoolCVM(const paddle::platform::Place &place,
 
   DoFusedSeqpoolCVM(place, gpu_input_values, gpu_output_values,
                     gpu_seqpool_output_values, lods_values, data_lens, slot_num,
-                    key2slot, total_length, embedding_size, padding_value,
-                    use_cvm);
+                    total_length, embedding_size, padding_value, use_cvm,
+                    need_filter, show_coeff, clk_coeff, threshold);
 }
 
 template <typename T>
@@ -251,6 +247,10 @@ static void FusedSeqpoolCVMFunctor(const framework::ExecutionContext &ctx) {
   auto padding_value = ctx.Attr<float>("pad_value");
   auto use_cvm = ctx.Attr<bool>("use_cvm");
+  bool need_filter = ctx.Attr<bool>("need_filter");
+  float show_coeff = ctx.Attr<float>("show_coeff");
+  float clk_coeff = ctx.Attr<float>("clk_coeff");
+  float threshold = ctx.Attr<float>("threshold");
 
   int embedding_size = inputs[0]->numel() / inputs[0]->dims()[0];
 
@@ -262,7 +262,7 @@ static void FusedSeqpoolCVMFunctor(const framework::ExecutionContext &ctx) {
     auto lod_level = lod.size();
     int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
-    input_data[i] = reinterpret_cast<const float *>(input->data<T>());
+    input_data[i] = reinterpret_cast<const T *>(input->data<T>());
     auto *output = outputs[i];
     if (use_cvm) {
       output->Resize({batch_size, embedding_size});
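The host wrappers above hand per-slot device pointers to the kernels by staging a pointer table on the GPU: allocate room for N pointers, then async-copy the host-side array of device pointers into it. A generic sketch of that pattern (the helper name and the holder type are illustrative, and error handling is elided):

    // Illustrative sketch: upload a host array of device pointers so a kernel
    // can index them as T** (one entry per slot).
    template <typename T>
    T **UploadPointerTable(const paddle::platform::Place &place,
                           const std::vector<T *> &host_ptrs,
                           cudaStream_t stream,
                           std::shared_ptr<paddle::memory::Allocation> *holder) {
      *holder = paddle::memory::AllocShared(place, host_ptrs.size() * sizeof(T *));
      T **table = reinterpret_cast<T **>((*holder)->ptr());
      cudaMemcpyAsync(table, host_ptrs.data(), host_ptrs.size() * sizeof(T *),
                      cudaMemcpyHostToDevice, stream);
      return table;  // valid only while *holder stays alive
    }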
@@ -274,47 +274,43 @@ static void FusedSeqpoolCVMFunctor(const framework::ExecutionContext &ctx) {
     data_lens[i] = lod[lod_level - 1].size() - 1;
     lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
 
-    LoDTensor seqpool_output_tensor;
-    seqpool_outputs.push_back(seqpool_output_tensor);
     seqpool_output_data[i] =
         reinterpret_cast<T *>(seqpool_outputs[i].mutable_data<T>(
             {batch_size, embedding_size}, ctx.GetPlace()));
   }
 
   FusedSeqpoolCVM(ctx.GetPlace(), input_data, output_data, seqpool_output_data,
-                  lods_data, data_lens, embedding_size, padding_value, use_cvm);
+                  lods_data, data_lens, embedding_size, padding_value, use_cvm,
+                  need_filter, show_coeff, clk_coeff, threshold);
 }
 
+template <typename T>
 void DoFusedSeqpoolCVMGrad(const paddle::platform::Place &place,
-                           float **out_grads_values,
-                           float **out_seqpool_grads_values,
-                           float **in_grads_values, float **gpu_cvm_values,
+                           T **out_grads_values, T **out_seqpool_grads_values,
+                           T **in_grads_values, T **gpu_cvm_values,
                            size_t **lods_values, const int64_t *slot_lens,
-                           int slot_num, int *key2slot, int64_t total_len,
+                           int slot_num, int64_t total_len,
                            const int embedding_size, bool use_cvm) {
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
-                    platform::DeviceContextPool::Instance().Get(
-                        BOOST_GET_CONST(platform::CUDAPlace, place)))
-                    ->stream();
-  FillKey2Slot<<<(total_len + PADDLE_CUDA_NUM_THREADS - 1) /
-                     PADDLE_CUDA_NUM_THREADS,
-                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(total_len, slot_lens,
-                                                       slot_num, key2slot);
-
+                    platform::DeviceContextPool::Instance().Get(
+                        BOOST_GET_CONST(platform::CUDAPlace, place)))
+                    ->stream();
+  const int batch_size = total_len / slot_num;
   FusedSeqpoolCVMGradKernel<<<(total_len * embedding_size +
                                PADDLE_CUDA_NUM_THREADS - 1) /
                                   PADDLE_CUDA_NUM_THREADS,
                               PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
       out_grads_values, out_seqpool_grads_values, in_grads_values,
-      gpu_cvm_values, lods_values, slot_lens, key2slot, total_len,
+      gpu_cvm_values, lods_values, slot_lens, batch_size, total_len,
      embedding_size, use_cvm);
 }
 
+template <typename T>
 void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
-                         const std::vector<const float *> &out_grads_data,
-                         const std::vector<float *> &out_seqpool_grads_data,
-                         const std::vector<float *> &in_grads_data,
-                         const std::vector<const float *> &cvm_data,
+                         const std::vector<const T *> &out_grads_data,
+                         const std::vector<T *> &out_seqpool_grads_data,
+                         const std::vector<T *> &in_grads_data,
+                         const std::vector<const T *> &cvm_data,
                          std::vector<const size_t *> &lods,
                          const std::vector<int64_t> &data_lengths,
                          const int embedding_size, const bool use_cvm) {
@@ -325,16 +321,11 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
   }
   int64_t total_length = data_lengths_lod[slot_num - 1];
-  int64_t total_bytes = total_length * embedding_size * sizeof(float);
 
   auto stream = dynamic_cast<platform::CUDADeviceContext *>(
-                    platform::DeviceContextPool::Instance().Get(
-                        BOOST_GET_CONST(platform::CUDAPlace, place)))
-                    ->stream();
-
-  LoDTensor keys2slot_tensor;
-  int *keys2slot = reinterpret_cast<int *>(
-      keys2slot_tensor.mutable_data<int>({total_length, 1}, place));
+                    platform::DeviceContextPool::Instance().Get(
+                        BOOST_GET_CONST(platform::CUDAPlace, place)))
+                    ->stream();
 
   LoDTensor data_lens_tensor;
   int64_t *data_lens = reinterpret_cast<int64_t *>(
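Both the forward and the backward wrapper derive their sizes the same way: data_lengths carries one row count per slot, an in-place running sum turns it into the cumulative table copied to data_lens, and the last entry is total_length (so batch_size == total_length / slot_num, which DoFusedSeqpoolCVM and DoFusedSeqpoolCVMGrad both rely on). A self-contained sketch of that bookkeeping:

    #include <cstdint>
    #include <vector>

    // Illustrative: per-slot row counts -> cumulative table and total row count.
    int64_t TotalRows(std::vector<int64_t> lengths) {  // one batch_size per slot
      for (size_t i = 1; i < lengths.size(); ++i) {
        lengths[i] += lengths[i - 1];  // what gets copied into data_lens
      }
      return lengths.empty() ? 0 : lengths.back();  // total_length
    }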
@@ -344,34 +335,31 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
                   cudaMemcpyHostToDevice, stream);
 
   auto gpu_out_grads_ptr =
-      memory::AllocShared(place, out_grads_data.size() * sizeof(float *));
-  float **gpu_out_grads_values =
-      reinterpret_cast<float **>(gpu_out_grads_ptr->ptr());
+      memory::AllocShared(place, out_grads_data.size() * sizeof(T *));
+  T **gpu_out_grads_values = reinterpret_cast<T **>(gpu_out_grads_ptr->ptr());
   cudaMemcpyAsync(gpu_out_grads_values, out_grads_data.data(),
-                  out_grads_data.size() * sizeof(float *),
-                  cudaMemcpyHostToDevice, stream);
+                  out_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice,
+                  stream);
 
-  auto gpu_out_seqpool_grads_ptr = memory::AllocShared(
-      place, out_seqpool_grads_data.size() * sizeof(float *));
-  float **gpu_out_seqpool_grads_values =
-      reinterpret_cast<float **>(gpu_out_seqpool_grads_ptr->ptr());
+  auto gpu_out_seqpool_grads_ptr =
+      memory::AllocShared(place, out_seqpool_grads_data.size() * sizeof(T *));
+  T **gpu_out_seqpool_grads_values =
+      reinterpret_cast<T **>(gpu_out_seqpool_grads_ptr->ptr());
   cudaMemcpyAsync(gpu_out_seqpool_grads_values, out_seqpool_grads_data.data(),
-                  out_seqpool_grads_data.size() * sizeof(float *),
+                  out_seqpool_grads_data.size() * sizeof(T *),
                   cudaMemcpyHostToDevice, stream);
 
   auto gpu_in_grads_ptr =
-      memory::AllocShared(place, in_grads_data.size() * sizeof(float *));
-  float **gpu_in_grads_values =
-      reinterpret_cast<float **>(gpu_in_grads_ptr->ptr());
+      memory::AllocShared(place, in_grads_data.size() * sizeof(T *));
+  T **gpu_in_grads_values = reinterpret_cast<T **>(gpu_in_grads_ptr->ptr());
   cudaMemcpyAsync(gpu_in_grads_values, in_grads_data.data(),
-                  in_grads_data.size() * sizeof(float *),
-                  cudaMemcpyHostToDevice, stream);
+                  in_grads_data.size() * sizeof(T *), cudaMemcpyHostToDevice,
+                  stream);
 
-  auto gpu_cvm_ptr =
-      memory::AllocShared(place, cvm_data.size() * sizeof(float *));
-  float **gpu_cvm_values = reinterpret_cast<float **>(gpu_cvm_ptr->ptr());
+  auto gpu_cvm_ptr = memory::AllocShared(place, cvm_data.size() * sizeof(T *));
+  T **gpu_cvm_values = reinterpret_cast<T **>(gpu_cvm_ptr->ptr());
   cudaMemcpyAsync(gpu_cvm_values, cvm_data.data(),
-                  cvm_data.size() * sizeof(float *), cudaMemcpyHostToDevice,
+                  cvm_data.size() * sizeof(T *), cudaMemcpyHostToDevice,
                   stream);
 
   auto lods_ptr = memory::AllocShared(place, lods.size() * sizeof(size_t *));
@@ -382,7 +370,7 @@ void FusedSeqpoolCVMGrad(const paddle::platform::Place &place,
   DoFusedSeqpoolCVMGrad(place, gpu_out_grads_values,
                         gpu_out_seqpool_grads_values, gpu_in_grads_values,
                         gpu_cvm_values, lods_values, data_lens, slot_num,
-                        keys2slot, total_length, embedding_size, use_cvm);
+                        total_length, embedding_size, use_cvm);
 }
 
 template <typename T>
@@ -395,14 +383,14 @@ static void FusedSeqpoolCVMGradFunctor(const framework::ExecutionContext &ctx) {
   auto use_cvm = ctx.Attr<bool>("use_cvm");
 
   const auto slot_size = in_grads.size();
-  std::vector<const float *> out_grads_data(slot_size);
-  std::vector<float *> in_grads_data(slot_size);
-  std::vector<const float *> cvm_data(slot_size);
+  std::vector<const T *> out_grads_data(slot_size);
+  std::vector<T *> in_grads_data(slot_size);
+  std::vector<const T *> cvm_data(slot_size);
   std::vector<const size_t *> lods_data(slot_size);
   std::vector<int64_t> data_lengths(slot_size);
 
   std::vector<LoDTensor> out_seqpool_grads(slot_size);
-  std::vector<float *> out_seqpool_grads_data(slot_size);
+  std::vector<T *> out_seqpool_grads_data(slot_size);
 
   int embedding_size = in_grads[0]->numel() / in_grads[0]->dims()[0];
 
@@ -415,18 +403,16 @@ static void FusedSeqpoolCVMGradFunctor(const framework::ExecutionContext &ctx) {
     int batch_size = lod[lod_level - 1].size() - 1;  // -1 to real batch size
     auto *out_grad = out_grads[i];
-    out_grads_data[i] = reinterpret_cast<const float *>(out_grad->data<T>());
+    out_grads_data[i] = reinterpret_cast<const T *>(out_grad->data<T>());
     in_grads_data[i] =
-        reinterpret_cast<float *>(in_grad->mutable_data<T>(ctx.GetPlace()));
+        reinterpret_cast<T *>(in_grad->mutable_data<T>(ctx.GetPlace()));
     lods_data[i] = lod[lod_level - 1].CUDAData(ctx.GetPlace());
     data_lengths[i] = lod[lod_level - 1].size() - 1;
-    cvm_data[i] = reinterpret_cast<const float *>(cvm->data<T>());
+    cvm_data[i] = reinterpret_cast<const T *>(cvm->data<T>());
 
-    LoDTensor out_seqpool_grad_tensor;
-    out_seqpool_grads.push_back(out_seqpool_grad_tensor);
     out_seqpool_grads_data[i] =
-        reinterpret_cast<float *>(out_seqpool_grads[i].mutable_data<T>(
+        reinterpret_cast<T *>(out_seqpool_grads[i].mutable_data<T>(
             {batch_size, embedding_size}, ctx.GetPlace()));
   }
 
@@ -456,7 +442,7 @@ class FusedSeqpoolCVMGradCUDAKernel : public framework::OpKernel {
 namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm,
-                        ops::FusedSeqpoolCVMCUDAKernel);
+                        ops::FusedSeqpoolCVMCUDAKernel<float>);
 REGISTER_OP_CUDA_KERNEL(fused_seqpool_cvm_grad,
-                        ops::FusedSeqpoolCVMGradCUDAKernel);
+                        ops::FusedSeqpoolCVMGradCUDAKernel<float>);
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index c7da95b45c478..caabf77c0c4da 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -1411,7 +1411,8 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'):
     return outs, outs_extend
 
 
-def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True):
+def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True,
+                      need_filter=False, show_coeff=0.2, clk_coeff=1.0, threshold=0.96):
     """
     **Notes: The Op only receives List of LoDTensor as input, only support SUM pooling
     now. :attr:`input`.
@@ -1453,7 +1454,12 @@ def fused_seqpool_cvm(input, pool_type, cvm, pad_value=0.0, use_cvm=True):
         attrs={
             "pooltype": pool_type.upper(),
             "pad_value": pad_value,
-            "use_cvm": use_cvm
+            "use_cvm": use_cvm,
+            "need_filter": need_filter,
+            "show_coeff": show_coeff,
+            "clk_coeff": clk_coeff,
+            "threshold": threshold
        })
     return outs
+
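With the Python wrapper extended, the filter can be switched on per call. A minimal usage sketch; the slot names, shapes, and sizes below are made up for illustration, and the op still expects a list of LoDTensor inputs with SUM pooling:

    import paddle.fluid as fluid
    from paddle.fluid.contrib.layers import fused_seqpool_cvm

    # Three hypothetical sparse slots, each an embedding of width 11 whose
    # first two columns are show and click.
    slots = [
        fluid.data(name="slot_%d" % i, shape=[None, 11], dtype="float32",
                   lod_level=1) for i in range(3)
    ]
    show_click = fluid.data(name="cvm", shape=[None, 2], dtype="float32")

    pooled = fused_seqpool_cvm(slots, "SUM", show_click,
                               need_filter=True, show_coeff=0.2,
                               clk_coeff=1.0, threshold=0.96)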