Skip to content

Commit

Permalink
Add FFT related operators and APIs (PaddlePaddle#35665)
Browse files Browse the repository at this point in the history
* 1. add interface for fft;
2. add data type predicate;
3. fix paddle.roll.

* add fft c2c cufft kernel

* implement argument checking & op calling parts for fft_c2c and fftn_c2c

* add operator and opmaker definitions

* only register float and double for cpu.

* add common code for implementing FFT, add pocketfft as a dependency

* add fft c2c cufft kernel function

* fix bugs in python interface

* add support for c2r, r2c operators, op makers, kernels and kernel functors.

* test and fix bugs

* 1. fft_c2c function: add support for onesided=False;
2. add complex<float>, complex<double> support for concat and flip.

* 1. fft: fix python api bugs;
2. shape_op: add support for complex data types.

* fft c2c cufft kernel done with complie and link

* fix shape_op, add mkl placeholder

* remove mkl

* complete fft c2c in gpu

* 1. implement mkl-based fft, FFTC2CFunctor and common function exec_fft;
2. change the design, add input and output typename as template parameter for all FFTFunctors, update pocketfft-based implementation.

* complete fft c2c on gpu in ND

* complete fft c2c on gpu in ND

* complete fft c2c backward in ND

* fix MKL-based implementation

* Add frame op and CPU/GPU kernels.

* Add frame op forward unittest.

* Add frame op forward unittest.

* Remove axis parameter in FrameFunctor.

* Add frame op grad CPU/GPU kernels and unittest.

* Add frame op grad CPU/GPU kernels and unittest.

* Update doc string.

* Update after review and remove librosa requirement in unittest.

* Update grad kernel.

* add fft_c2r op

* Remove data allocation in TransCompute function.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* last fft c2r functor

* fix C2R and R2C for cufft, becase the direction is not an option in these cases.

* add fft r2c onesided with cpu(pocketfft/mkl) and gpu

* fix bugs in python APIs

* fix fft_c2r grad kernal

* fix bugs in python APIs

* add cuda fft c2r grad kernal functor

* clean code

* fix fft_c2r python API

* fill fft r2c result with conjugate symmetry (#19)

fill fft r2c result with conjugate symmetry

* add placeholder for unittests (#24)

* simple parameterize test function by auto generate test case from parm list (#25)

* miscellaneous fixes for python APIs (#26)

* add placeholder for unittests

* resize fft inputs before computation is n or s is provided.

* add complex kernels for pad and pad_grad

* simplify argument checking.

* add type promotion

* add int to float or complex promotion

* fix output data type for static mode

* fix fft's input dtype dispatch, import fft to paddle

* fix typos in axes checking (#27)

* fix typos in axes checking

* fix argument checking (#28)

* fix argument checking

* Add C2R Python layer normal and abnormal use cases (#29)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* complete rfft,rfft2,rfftn,ihfft,ihfft2,ihfftn unittest and doc string (PaddlePaddle#30)

* Documentation of the common interfaces of c2r and c2c (PaddlePaddle#31)

* Documentation of the common interfaces of c2r and c2c

* clean c++ code  (PaddlePaddle#32)

* clean code

* Add numpy-based implementation of spectral ops (PaddlePaddle#33)

* add numpy reference implementation of spectral ops

* Add fft_c2r numpy based implementation for unittest. (PaddlePaddle#34)

* add fft_c2r numpy implementation

* Add deframe op and stft/istft api. (#23)

* Add frame api

* Add deframe op and kernels.

* Add stft and istft apis.

* Add deframe api. Update stft and istft apis.

* Fix bug in frame_from_librosa function when input dims >= 3

* Rename deframe to overlap_add.

* Update istft.

* Update after code review.

* Add overlap_add op and stft/istft api unittest (PaddlePaddle#35)

* Add overlap_add op unittest.

* Register complex kernels of squeeze/unsquuze op.

* Add stft/istft api unittest.

* Add unittest for fft helper functions (PaddlePaddle#36)

* add unittests for fft helper functions. add complex kernel for roll op.

* complete static graph unittest for all public api (PaddlePaddle#37)

* Unittest of op with FFT C2C, C2R and r2c added (PaddlePaddle#38)

* documents and single case

* test c2r case

* New C2R Python layer normal and exception use cases

* Documentation of the common interfaces of c2r and c2c

* Unittest of op with FFT C2C, C2R and r2c added

Co-authored-by: lijiaqi <lijiaqi0612@163.com>

* add fft related options to CMakeLists.txt

* fix typos and clean code (PaddlePaddle#39)

* fix invisible character in mkl branch and fix error in error message

* clean code: remove docstring from unittest for signal.py.

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype. (PaddlePaddle#40)

* always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.

* fix CI Errors: numpy dtype comparison, thrust when cuda is not available (PaddlePaddle#41)

1. always convert numpy array to paddle.Tensor to avoid comparing numpy dtype with paddle dtype.
2. promote floating point tensor to complex tensor ior fft_c2c and fft_c2r;
3. fix unittest to catch UnImplementedError and RuntimeError;
4. fix compile error by avoid using thrust when cuda is not available.
5.  fix sample code, use paddle.fft instead of paddle.tensor.fft

* remove inclusion of thrust, add __all__ list for fft (PaddlePaddle#42)

* Add api doc and update unittest. (PaddlePaddle#43)

* Add doc strings.
* Update overlap_add op unittest

* fix MKL-based FFT implementation (PaddlePaddle#44)

* fix MKL-based FFT implementation, MKL CDFT's FORWARD DOMAIN is always REAL for R2C and C2R

* remove code for debug (PaddlePaddle#45)

* use dynload for cufft (PaddlePaddle#46)

* use std::ptrdiff_t as datatype of stride (instead of int64_t) to avoid argument mismatch on some platforms.

* add complex support for fill_zeros_like

* use dynload for cufft

* Update doc and unittest. (PaddlePaddle#47)

* Add doc of frame op and overlap_add op.

* Update unittest.

* use dynload for cufft (PaddlePaddle#48)

1. use dynload for cufft
2. fix unittest;
3. temporarily disable Rocm.

* fix conflicts and merge upstream (PaddlePaddle#49)

fix conflicts and merge upstream

* fix compile error: only link dyload_cuda when cuda is available (PaddlePaddle#50)

* fix compile error: only link dyload_cuda when cuda is available

* fix dynload for cufft on windows (PaddlePaddle#51)

1. fix dynload for cufft on windows;
2. fix unittests.

* add NOMINMAX to compile on windows (PaddlePaddle#52)

 add NOMINMAX to compile on windows

* explicitly specify capture mode for lambdas (PaddlePaddle#55)

 explicitly specify capture mode for lambdas

* fix fft sample (PaddlePaddle#53)

* fix fft sample

* update scipy and numpy version for unittests of fft (PaddlePaddle#56)

update scipy and numpy version for unittests of fft

* Add static graph unittests of frame and overlap_add api. (PaddlePaddle#57)

* Remove cache of cuFFT & Disable ONEMKL (PaddlePaddle#59)

1. replace numpy.fft with scipy.fft as numpy<1.20 not support ortho norm
2. remove cache of cufft plans;
3. enhance error checking.
4. default WITH_ONEMKL to OFF

Co-authored-by: jeff41404 <jeff41404@gmail.com>
Co-authored-by: root <root@bjyz-sys-gpu-kongming9.bjyz.baidu.com>
Co-authored-by: KP <109694228@qq.com>
Co-authored-by: lijiaqi <lijiaqi0612@163.com>
Co-authored-by: Xiaoxu Chen <chenxx_id@163.com>
Co-authored-by: lijiaqi0612 <33169170+lijiaqi0612@users.noreply.github.com>
  • Loading branch information
7 people authored and niuliling123 committed Sep 29, 2021
1 parent fb9ebd0 commit 3a65093
Show file tree
Hide file tree
Showing 57 changed files with 9,413 additions and 44 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ project(paddle CXX C)
# enable language CUDA
# TODO(Shibo Tao): remove find_package(CUDA) completely.
find_package(CUDA QUIET)
find_package(MKL CONFIG QUIET)
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
Expand Down Expand Up @@ -225,6 +227,7 @@ option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)

# PY_VERSION
if(NOT PY_VERSION)
Expand Down Expand Up @@ -373,6 +376,10 @@ if (WITH_MIPS)
add_definitions(-DPADDLE_WITH_MIPS)
endif()

if (WITH_ONEMKL)
add_definitions(-DPADDLE_WITH_ONEMKL)
endif()

if (WITH_HETERPS)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
Expand Down
2 changes: 1 addition & 1 deletion cmake/FindGperftools.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
find_library(GPERFTOOLS_TCMALLOC
NAMES tcmalloc
HINTS ${Gperftools_ROOT_DIR}/lib)

find_library(GPERFTOOLS_PROFILER
NAMES profiler
HINTS ${Gperftools_ROOT_DIR}/lib)
Expand Down
44 changes: 44 additions & 0 deletions cmake/external/pocketfft.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)


set(POCKETFFT_PATH "${THIRD_PARTY_PATH}/pocketfft" CACHE STRING "A path setting for external_pocketfft path.")
set(POCKETFFT_PREFIX_DIR ${POCKETFFT_PATH})

set(POCKETFFT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft.git)
set(POCKETFFT_TAG release_for_eigen)

SET(POCKETFFT_INCLUDE_DIR ${POCKETFFT_PREFIX_DIR}/src)
message("POCKETFFT_INCLUDE_DIR is ${POCKETFFT_INCLUDE_DIR}")
include_directories(${POCKETFFT_INCLUDE_DIR})

ExternalProject_Add(
extern_pocketfft
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
GIT_REPOSITORY ${POCKETFFT_REPOSITORY}
GIT_TAG ${POCKETFFT_TAG}
PREFIX ${POCKETFFT_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)

add_library(pocketfft INTERFACE)

add_dependencies(pocketfft extern_pocketfft)
6 changes: 6 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,10 @@ if (WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO)
endif (WITH_CRYPTO)

if (WITH_POCKETFFT)
include(external/pocketfft)
list(APPEND third_party_deps extern_pocketfft)
add_definitions(-DPADDLE_WITH_POCKETFFT)
endif (WITH_POCKETFFT)

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
18 changes: 17 additions & 1 deletion paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <string>
#include <typeindex>

Expand Down Expand Up @@ -170,11 +171,26 @@ extern inline proto::VarType::Type ToComplexType(proto::VarType::Type t) {
return proto::VarType::COMPLEX128;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support float32 and "
"Unknown real value data type (%s), now only support float32 and "
"float64.",
DataTypeToString(t)));
}
}

extern inline proto::VarType::Type ToRealType(proto::VarType::Type t) {
switch (t) {
case proto::VarType::COMPLEX64:
return proto::VarType::FP32;
case proto::VarType::COMPLEX128:
return proto::VarType::FP64;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unknown complex value data type (%s), now only support complex64 "
"and "
"complex128.",
DataTypeToString(t)));
}
}

} // namespace framework
} // namespace paddle
12 changes: 11 additions & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ if (WITH_GPU)
endif()
endif()

if (WITH_POCKETFFT)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} pocketfft)
endif()


SET(OP_MKL_DEPS "")
if (NOT WITH_MKL OR NOT WITH_AVX)
Expand All @@ -75,7 +79,7 @@ if(WITH_UNITY_BUILD)
endif()

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})

Expand All @@ -94,6 +98,12 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()

if (WITH_GPU AND (NOT WITH_ROCM))
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()

op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
op_library(eye_op DEPS ${OP_HEADER_DEPS})
op_library(recurrent_op DEPS ${OP_HEADER_DEPS})
Expand Down
13 changes: 11 additions & 2 deletions paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"

#include <paddle/fluid/platform/complex.h>
#include <memory>
#include <string>
#include <vector>
Expand Down Expand Up @@ -237,7 +238,11 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
Expand All @@ -247,4 +252,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>);
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
13 changes: 11 additions & 2 deletions paddle/fluid/operators/concat_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/concat_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
Expand All @@ -24,7 +25,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
concat_grad,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
Expand All @@ -33,4 +38,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>);
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<float>>,
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);
3 changes: 3 additions & 0 deletions paddle/fluid/operators/eigen/scale.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -42,6 +43,8 @@ template struct EigenScale<Eigen::DefaultDevice, int8_t>;
template struct EigenScale<Eigen::DefaultDevice, int16_t>;
template struct EigenScale<Eigen::DefaultDevice, int>;
template struct EigenScale<Eigen::DefaultDevice, int64_t>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<float>>;
template struct EigenScale<Eigen::DefaultDevice, platform::complex<double>>;

} // namespace operators
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/operators/eigen/scale.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
Expand Down Expand Up @@ -41,6 +42,8 @@ template struct EigenScale<Eigen::GpuDevice, int16_t>;
template struct EigenScale<Eigen::GpuDevice, int>;
template struct EigenScale<Eigen::GpuDevice, int64_t>;
template struct EigenScale<Eigen::GpuDevice, platform::float16>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<float>>;
template struct EigenScale<Eigen::GpuDevice, platform::complex<double>>;

} // namespace operators
} // namespace paddle
13 changes: 11 additions & 2 deletions paddle/fluid/operators/fill_zeros_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -93,12 +94,20 @@ REGISTER_OP_CPU_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CPU_KERNEL(
fill_zeros_like2,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
13 changes: 11 additions & 2 deletions paddle/fluid/operators/fill_zeros_like_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
Expand All @@ -25,7 +26,11 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
fill_zeros_like2,
Expand All @@ -35,4 +40,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
6 changes: 5 additions & 1 deletion paddle/fluid/operators/flip_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -145,6 +146,7 @@ class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>);
Expand All @@ -153,7 +155,9 @@ REGISTER_OP_CPU_KERNEL(
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip)
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/operators/flip_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -163,4 +164,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CUDADeviceContext,
plat::complex<double>>);

0 comments on commit 3a65093

Please sign in to comment.