From 6dbd3d401a18ab15765ff3a1e62352d368eb1e24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 27 Sep 2022 08:28:52 +0100 Subject: [PATCH 01/21] added dft interface and implementation stubs for mklgpu backend --- CMakeLists.txt | 3 + examples/dft/CMakeLists.txt | 18 ++ include/oneapi/mkl.hpp | 1 + include/oneapi/mkl/detail/backends_table.hpp | 10 +- include/oneapi/mkl/dft.hpp | 27 +++ include/oneapi/mkl/dft/backward.hpp | 91 +++++++++ include/oneapi/mkl/dft/descriptor.hpp | 56 ++++++ .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 127 +++++++++++++ include/oneapi/mkl/dft/forward.hpp | 92 +++++++++ include/oneapi/mkl/types.hpp | 72 ++++++- scripts/func_parser.py | 9 +- src/dft/CMakeLists.txt | 46 +++++ src/dft/backends/CMakeLists.txt | 22 +++ src/dft/backends/mklgpu/CMakeLists.txt | 70 +++++++ src/dft/backends/mklgpu/backward.cpp | 178 ++++++++++++++++++ src/dft/backends/mklgpu/descriptor.cpp | 51 +++++ src/dft/backends/mklgpu/forward.cpp | 178 ++++++++++++++++++ .../backends/mklgpu/mkl_dft_gpu_wrappers.cpp | 52 +++++ src/dft/dft_loader.cpp | 174 +++++++++++++++++ src/dft/function_table.hpp | 83 ++++++++ tests/unit_tests/CMakeLists.txt | 8 +- tests/unit_tests/dft/CMakeLists.txt | 20 ++ tests/unit_tests/dft/source/CMakeLists.txt | 49 +++++ tests/unit_tests/dft/source/tmp.cpp | 0 24 files changed, 1431 insertions(+), 6 deletions(-) create mode 100644 examples/dft/CMakeLists.txt create mode 100644 include/oneapi/mkl/dft.hpp create mode 100644 include/oneapi/mkl/dft/backward.hpp create mode 100644 include/oneapi/mkl/dft/descriptor.hpp create mode 100644 include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp create mode 100644 include/oneapi/mkl/dft/forward.hpp create mode 100644 src/dft/CMakeLists.txt create mode 100644 src/dft/backends/CMakeLists.txt create mode 100644 src/dft/backends/mklgpu/CMakeLists.txt create mode 100644 src/dft/backends/mklgpu/backward.cpp create mode 100644 src/dft/backends/mklgpu/descriptor.cpp create mode 100644 src/dft/backends/mklgpu/forward.cpp create mode 100644 src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp create mode 100644 src/dft/dft_loader.cpp create mode 100644 src/dft/function_table.hpp create mode 100644 tests/unit_tests/dft/CMakeLists.txt create mode 100644 tests/unit_tests/dft/source/CMakeLists.txt create mode 100644 tests/unit_tests/dft/source/tmp.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0d3a1cb7b..4320b79e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,9 @@ if(ENABLE_MKLCPU_BACKEND OR ENABLE_CURAND_BACKEND) list(APPEND DOMAINS_LIST "rng") endif() +if(ENABLE_MKLGPU_BACKEND) + list(APPEND DOMAINS_LIST "dft") +endif() # Define required CXX compilers before project if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") diff --git a/examples/dft/CMakeLists.txt b/examples/dft/CMakeLists.txt new file mode 100644 index 000000000..692461b1d --- /dev/null +++ b/examples/dft/CMakeLists.txt @@ -0,0 +1,18 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== diff --git a/include/oneapi/mkl.hpp b/include/oneapi/mkl.hpp index eac491793..a49c1ceda 100644 --- a/include/oneapi/mkl.hpp +++ b/include/oneapi/mkl.hpp @@ -23,6 +23,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/blas.hpp" +#include "oneapi/mkl/dft.hpp" #include "oneapi/mkl/lapack.hpp" #include "oneapi/mkl/rng.hpp" diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index a7c37efd2..c51f2589b 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -41,7 +41,7 @@ namespace oneapi { namespace mkl { enum class device : uint16_t { x86cpu, intelgpu, nvidiagpu, amdgpu }; -enum class domain : uint16_t { blas, lapack, rng }; +enum class domain : uint16_t { blas, dft, lapack, rng }; static std::map>> libraries = { { domain::blas, @@ -73,6 +73,14 @@ static std::map>> libraries = #endif } } } }, + { domain::dft, + { { device::intelgpu, + { +#ifdef ENABLE_MKLGPU_BACKEND + LIB_NAME("dft_mklgpu") +#endif + } } } }, + { domain::lapack, { { device::x86cpu, { diff --git a/include/oneapi/mkl/dft.hpp b/include/oneapi/mkl/dft.hpp new file mode 100644 index 000000000..6856cb198 --- /dev/null +++ b/include/oneapi/mkl/dft.hpp @@ -0,0 +1,27 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_HPP_ +#define _ONEMKL_DFT_HPP_ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/dft/forward.hpp" +#include "oneapi/mkl/dft/backward.hpp" + +#endif // _ONEMKL_DFT_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp new file mode 100644 index 000000000..2fb6911f6 --- /dev/null +++ b/include/oneapi/mkl/dft/backward.hpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_BACKWARD_HPP_ +#define _ONEMKL_DFT_BACKWARD_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/dft/descriptor.hpp" + +namespace oneapi::mkl::dft { + //Buffer version + + //In-place transform + template + void compute_backward( descriptor_type &desc, + sycl::buffer &inout ); + + //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + void compute_backward( descriptor_type &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im); + + //Out-of-place transform + template + void compute_backward( descriptor_type &desc, + sycl::buffer &in, + sycl::buffer &out); + + //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + void compute_backward( descriptor_type &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im); + + //USM version + + //In-place transform + template + sycl::event compute_backward( descriptor_type &desc, + data_type *inout, + const std::vector &dependencies = {}); + + //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + sycl::event compute_backward(descriptor_type &desc, + data_type *inout_re, + data_type *inout_im, + const std::vector &dependencies = {}); + + //Out-of-place transform + template + sycl::event compute_backward( descriptor_type &desc, + input_type *in, + output_type *out, + const std::vector &dependencies = {}); + + //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + sycl::event compute_backward( descriptor_type &desc, + input_type *in_re, + input_type *in_im, + output_type *out_re, + output_type *out_im, + const std::vector &dependencies = {}); +} + +#endif // _ONEMKL_DFT_BACKWARD_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp new file mode 100644 index 000000000..0f3330f1e --- /dev/null +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_DESCRIPTOR_HPP_ +#define _ONEMKL_DFT_DESCRIPTOR_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +namespace oneapi::mkl::dft { + + template + class descriptor { + private: + sycl::queue queue_; + public: + // Syntax for 1-dimensional DFT + descriptor(std::int64_t length); + // Syntax for d-dimensional DFT + descriptor(std::vector dimensions); + + ~descriptor(); + + + void set_value(config_param param, ...); + + void get_value(config_param param, ...); + + void commit(sycl::queue &queue); + + sycl::queue& get_queue(){return queue_; }; + }; +} + +#endif // _ONEMKL_DFT_DESCRIPTOR_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp new file mode 100644 index 000000000..f9fb9592f --- /dev/null +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -0,0 +1,127 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#pragma once + +#if __has_include() +#include +#else +#include +#endif + +#include +#include + +#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklgpu { + +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + \ +void commit_ ## EXT(descriptor &desc, sycl::queue& queue); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ +void compute_forward_buffer_inplace_ ## EXT(descriptor &desc, sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +void compute_forward_buffer_inplace_split_ ## EXT(descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ +void compute_forward_buffer_outofplace_ ## EXT(descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +void compute_forward_buffer_outofplace_split_ ## EXT(descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ +sycl::event compute_forward_usm_inplace_ ## EXT(descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +sycl::event compute_forward_usm_inplace_split_ ## EXT(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ +sycl::event compute_forward_usm_outofplace_ ## EXT(descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +sycl::event compute_forward_usm_outofplace_split_ ## EXT(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies = {}); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ +void compute_backward_buffer_inplace_ ## EXT(descriptor &desc, sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +void compute_backward_buffer_inplace_split_ ## EXT(descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ +void compute_backward_buffer_outofplace_ ## EXT(descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +void compute_backward_buffer_outofplace_split_ ## EXT(descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ +sycl::event compute_backward_usm_inplace_ ## EXT(descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +sycl::event compute_backward_usm_inplace_split_ ## EXT(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ +sycl::event compute_backward_usm_outofplace_ ## EXT(descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +sycl::event compute_backward_usm_outofplace_split_ ## EXT(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies = {}); \ + +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, std::complex) + +#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES + +} // namespace mklgpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp new file mode 100644 index 000000000..8a8a4f230 --- /dev/null +++ b/include/oneapi/mkl/dft/forward.hpp @@ -0,0 +1,92 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_FORWARD_HPP_ +#define _ONEMKL_DFT_FORWARD_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/dft/descriptor.hpp" + +namespace oneapi::mkl::dft { + + //Buffer version + + //In-place transform + template + void compute_forward( descriptor_type &desc, + sycl::buffer &inout); + + //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + void compute_forward( descriptor_type &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im); + + //Out-of-place transform + template + void compute_forward( descriptor_type &desc, + sycl::buffer &in, + sycl::buffer &out); + + //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + void compute_forward( descriptor_type &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im); + + //USM version + + //In-place transform + template + sycl::event compute_forward( descriptor_type &desc, + data_type *inout, + const std::vector &dependencies = {}); + + //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + sycl::event compute_forward(descriptor_type &desc, + data_type *inout_re, + data_type *inout_im, + const std::vector &dependencies = {}); + + //Out-of-place transform + template + sycl::event compute_forward( descriptor_type &desc, + input_type *in, + output_type *out, + const std::vector &dependencies = {}); + + //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format + template + sycl::event compute_forward( descriptor_type &desc, + input_type *in_re, + input_type *in_im, + output_type *out_re, + output_type *out_im, + const std::vector &dependencies = {}); +} + +#endif // _ONEMKL_DFT_FORWARD_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/types.hpp b/include/oneapi/mkl/types.hpp index 67f924dde..d53ce866d 100644 --- a/include/oneapi/mkl/types.hpp +++ b/include/oneapi/mkl/types.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2020-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -109,6 +109,76 @@ enum class order : char { E = 1, }; +//DFT flag types +namespace dft{ +enum class precision { + SINGLE, + DOUBLE +}; +enum class domain { + REAL, + COMPLEX +}; +enum class config_param { + FORWARD_DOMAIN, + DIMENSION, + LENGTHS, + PRECISION, + + FORWARD_SCALE, + BACKWARD_SCALE, + + NUMBER_OF_TRANSFORMS, + + COMPLEX_STORAGE, + REAL_STORAGE, + CONJUGATE_EVEN_STORAGE, + + PLACEMENT, + + INPUT_STRIDES, + OUTPUT_STRIDES, + + FWD_DISTANCE, + BWD_DISTANCE, + + WORKSPACE, + ORDERING, + TRANSPOSE, + PACKED_FORMAT, + COMMIT_STATUS +}; +enum class config_value { + // for config_param::COMMIT_STATUS + COMMITTED, + UNCOMMITTED, + + // for config_param::COMPLEX_STORAGE, + // config_param::REAL_STORAGE and + // config_param::CONJUGATE_EVEN_STORAGE + COMPLEX_COMPLEX, + REAL_COMPLEX, + REAL_REAL, + + // for config_param::PLACEMENT + INPLACE, + NOT_INPLACE, + + // for config_param::ORDERING + ORDERED, + BACKWARD_SCRAMBLED, + + // Allow/avoid certain usages + ALLOW, + AVOID, + NONE, + + // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers + CCE_FORMAT + +}; +} + } //namespace mkl } //namespace oneapi diff --git a/scripts/func_parser.py b/scripts/func_parser.py index cbaa26142..7f25dd2e3 100755 --- a/scripts/func_parser.py +++ b/scripts/func_parser.py @@ -149,9 +149,11 @@ def strip_line(l): """Delete all tabs""" return re.sub(' +',' ', l3) -def create_func_db(filename): - with open(filename, 'r') as f: - data = f.readlines() +def create_func_db(filenames): + data=[] + for filename in filenames.split(":"): + with open(filename, 'r') as f: + data.extend(f.readlines()) funcs_db = defaultdict(list) whole_line = "" idx = 0 @@ -170,6 +172,7 @@ def create_func_db(filename): else: stripped = whole_line.strip() whole_line = "" + print(stripped) parsed = parse_item(stripped) func_name, func_data = parsed[0], parsed[1:] funcs_db[func_name].append(to_dict(func_data)) diff --git a/src/dft/CMakeLists.txt b/src/dft/CMakeLists.txt new file mode 100644 index 000000000..d7f83cbc2 --- /dev/null +++ b/src/dft/CMakeLists.txt @@ -0,0 +1,46 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# Build backends +add_subdirectory(backends) + +# Recipe for DFT loader object +if(BUILD_SHARED_LIBS) +add_library(onemkl_dft OBJECT) +target_sources(onemkl_dft PRIVATE dft_loader.cpp) +target_include_directories(onemkl_dft + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/src/include + ${CMAKE_BINARY_DIR}/bin + $ +) + +target_compile_options(onemkl_dft PRIVATE ${ONEMKL_BUILD_COPT}) + +set_target_properties(onemkl_dft PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET onemkl_dft SOURCES dft_loader.cpp) +else() + target_link_libraries(onemkl_dft PUBLIC ONEMKL::SYCL::SYCL) +endif() + +endif() diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt new file mode 100644 index 000000000..9cbd4f603 --- /dev/null +++ b/src/dft/backends/CMakeLists.txt @@ -0,0 +1,22 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +if(ENABLE_MKLGPU_BACKEND) + add_subdirectory(mklgpu) +endif() \ No newline at end of file diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt new file mode 100644 index 000000000..00a5614eb --- /dev/null +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -0,0 +1,70 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_dft_mklgpu) +set(LIB_OBJ ${LIB_NAME}_obj) + +find_package(MKL REQUIRED) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + descriptor.cpp + forward.cpp + backward.cpp + $<$: mkl_dft_gpu_wrappers.cpp> +) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${MKL_INCLUDE} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/mklgpu/backward.cpp b/src/dft/backends/mklgpu/backward.cpp new file mode 100644 index 000000000..fe22c0498 --- /dev/null +++ b/src/dft/backends/mklgpu/backward.cpp @@ -0,0 +1,178 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklgpu { + +void compute_backward_buffer_inplace_f(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_c(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_d(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_z(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_backward_buffer_inplace_split_f(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_split_c(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_split_d(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_inplace_split_z(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_backward_buffer_outofplace_f(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_c(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_d(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_z(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_backward_buffer_outofplace_split_f(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_split_c(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_split_d(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_backward_buffer_outofplace_split_z(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_backward_usm_inplace_f(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_c(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_d(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_z(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_split_c(descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_inplace_split_z(descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_backward_usm_outofplace_f(descriptor &desc, std::complex *in, float *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_c(descriptor &desc, std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_d(descriptor &desc, std::complex *in, double *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_z(descriptor &desc, std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_backward_usm_outofplace_split_f(descriptor &desc, float *in_re, float *in_im, + float *out_re, float *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_split_c(descriptor &desc, float *in_re, float *in_im, + float *out_re, float *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_split_d(descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_backward_usm_outofplace_split_z(descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +} // namespace mklgpu +} // namespace mkl +} // namspace dft +} // namespace oneapi diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp new file mode 100644 index 000000000..c148c7ae5 --- /dev/null +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklgpu { + +void commit_f(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_c(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_d(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_z(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +} // namespace mklgpu +} // namespace mkl +} // namspace dft +} // namespace oneapi diff --git a/src/dft/backends/mklgpu/forward.cpp b/src/dft/backends/mklgpu/forward.cpp new file mode 100644 index 000000000..bb5efe2fe --- /dev/null +++ b/src/dft/backends/mklgpu/forward.cpp @@ -0,0 +1,178 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklgpu { + +void compute_forward_buffer_inplace_f(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_c(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_d(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_z(descriptor &desc, sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_forward_buffer_inplace_split_f(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_split_c(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_split_d(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_inplace_split_z(descriptor &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_forward_buffer_outofplace_f(descriptor &desc, sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_c(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_d(descriptor &desc, sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_z(descriptor &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +void compute_forward_buffer_outofplace_split_f(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_split_c(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_split_d(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void compute_forward_buffer_outofplace_split_z(descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_forward_usm_inplace_f(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_c(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_d(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_z(descriptor &desc, std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_split_c(descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_inplace_split_z(descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_forward_usm_outofplace_f(descriptor &desc, float *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_c(descriptor &desc, std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_d(descriptor &desc, double *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_z(descriptor &desc, std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +sycl::event compute_forward_usm_outofplace_split_f(descriptor &desc, float *in_re, float *in_im, + float *out_re, float *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_split_c(descriptor &desc, float *in_re, float *in_im, + float *out_re, float *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_split_d(descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} +sycl::event compute_forward_usm_outofplace_split_z(descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +} // namespace mklgpu +} // namespace mkl +} // namspace dft +} // namespace oneapi diff --git a/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp b/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp new file mode 100644 index 000000000..f0f837878 --- /dev/null +++ b/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp @@ -0,0 +1,52 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" +#include "dft/function_table.hpp" + +#define WRAPPER_VERSION 1 + +extern "C" dft_function_table_t mkl_dft_table = { + WRAPPER_VERSION, +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT) \ + oneapi::mkl::dft::mklgpu::commit_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_split_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_ ## EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_split_ ## EXT + + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f), + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c), + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d), + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z) + +#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES +}; diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp new file mode 100644 index 000000000..727f60a34 --- /dev/null +++ b/src/dft/dft_loader.cpp @@ -0,0 +1,174 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft.hpp" + +#include "function_table_initializer.hpp" +#include "dft/function_table.hpp" + +#include "oneapi/mkl/detail/get_device_id.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +namespace detail { +static oneapi::mkl::detail::table_initializer function_tables; +} // namespace detail + +#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + \ +template<> \ +void descriptor::commit(sycl::queue &queue){ \ + this->queue_ = queue; \ + detail::function_tables[get_device_id(queue)].commit_ ## EXT(*this, queue); \ +} \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ +template<> \ +void compute_forward, T_BACKWARD>(descriptor &desc, sycl::buffer &inout){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_inplace_ ## EXT(desc, inout); \ +} \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +void compute_forward, T_REAL>(descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_inplace_split_ ## EXT(desc, inout_re, inout_im); \ +} \ + \ + /*Out-of-place transform*/ \ +template<> \ +void compute_forward, T_FORWARD, T_BACKWARD>(descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_outofplace_ ## EXT(desc, in, out); \ +} \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +void compute_forward, T_REAL, T_REAL>(descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im); \ +} \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ +template<> \ +sycl::event compute_forward, T_BACKWARD>(descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_inplace_ ## EXT(desc, inout, dependencies); \ +} \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +sycl::event compute_forward, T_REAL>(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_inplace_split_ ## EXT(desc, inout_re, inout_im, dependencies); \ +} \ + \ + /*Out-of-place transform*/ \ +template<> \ +sycl::event compute_forward, T_FORWARD, T_BACKWARD>(descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_outofplace_ ## EXT(desc, in, out, dependencies); \ +} \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +sycl::event compute_forward, T_REAL, T_REAL>(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im, dependencies); \ +} \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ +template<> \ +void compute_backward, T_BACKWARD>(descriptor &desc, sycl::buffer &inout){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_inplace_ ## EXT(desc, inout); \ +} \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +void compute_backward, T_REAL>(descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_inplace_split_ ## EXT(desc, inout_re, inout_im); \ +} \ + \ + /*Out-of-place transform*/ \ +template<> \ +void compute_backward, T_BACKWARD, T_FORWARD>(descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_outofplace_ ## EXT(desc, in, out); \ +} \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +void compute_backward, T_REAL, T_REAL>(descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im){ \ + detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im); \ +} \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ +template<> \ +sycl::event compute_backward, T_BACKWARD>(descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_inplace_ ## EXT(desc, inout, dependencies); \ +} \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +sycl::event compute_backward, T_REAL>(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_inplace_split_ ## EXT(desc, inout_re, inout_im, dependencies); \ +} \ + \ + /*Out-of-place transform*/ \ +template<> \ +sycl::event compute_backward, T_BACKWARD, T_FORWARD>(descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_outofplace_ ## EXT(desc, in, out, dependencies); \ +} \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ +template<> \ +sycl::event compute_backward, T_REAL, T_REAL>(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies){ \ + return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im, dependencies); \ +} \ + +ONEAPI_MKL_DFT_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) +ONEAPI_MKL_DFT_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, std::complex) +ONEAPI_MKL_DFT_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, std::complex) +ONEAPI_MKL_DFT_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, std::complex) + +#undef ONEAPI_MKL_DFT_SIGNATURES + +} // namespace rng +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp new file mode 100644 index 000000000..978988f5c --- /dev/null +++ b/src/dft/function_table.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _DFT_FUNCTION_TABLE_HPP_ +#define _DFT_FUNCTION_TABLE_HPP_ + +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" + +typedef struct { + int version; + +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ +void (*commit_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::queue &queue); \ +void (*compute_forward_buffer_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout); \ +void (*compute_forward_buffer_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ +void (*compute_forward_buffer_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out); \ +void (*compute_forward_buffer_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ +sycl::event (*compute_forward_usm_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies); \ +sycl::event (*compute_forward_usm_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies); \ +sycl::event (*compute_forward_usm_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ + const std::vector &dependencies); \ +sycl::event (*compute_forward_usm_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies); \ +void (*compute_backward_buffer_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout); \ +void (*compute_backward_buffer_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ +void (*compute_backward_buffer_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in, \ + sycl::buffer &out); \ +void (*compute_backward_buffer_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ +sycl::event (*compute_backward_usm_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies); \ +sycl::event (*compute_backward_usm_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies); \ +sycl::event (*compute_backward_usm_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ + const std::vector &dependencies); \ +sycl::event (*compute_backward_usm_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ + T_REAL *out_re, T_REAL *out_im, \ + const std::vector &dependencies); \ + +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL, float, float, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX, float, std::complex, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL, double, double, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX, double, std::complex, std::complex) + +#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES +} dft_function_table_t; + +#endif //_DFT_FUNCTION_TABLE_HPP_ diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index c046def9b..ca4dad0ce 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -59,6 +59,12 @@ set(rng_TEST_LIST set(rng_TEST_LINK "") +# DFT config +set(dft_TEST_LIST + dft_source) + +set(dft_TEST_LINK "") + foreach(domain ${TARGET_DOMAINS}) # Generate RT and CT test lists set(${domain}_TEST_LIST_RT ${${domain}_TEST_LIST}) @@ -93,7 +99,7 @@ foreach(domain ${TARGET_DOMAINS}) endif() endif() - if(ENABLE_MKLCPU_BACKEND) + if((NOT domain STREQUAL "dft") AND ENABLE_MKLCPU_BACKEND) add_dependencies(test_main_${domain}_ct onemkl_${domain}_mklcpu) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklcpu) endif() diff --git a/tests/unit_tests/dft/CMakeLists.txt b/tests/unit_tests/dft/CMakeLists.txt new file mode 100644 index 000000000..3a12a42ed --- /dev/null +++ b/tests/unit_tests/dft/CMakeLists.txt @@ -0,0 +1,20 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_subdirectory(source) \ No newline at end of file diff --git a/tests/unit_tests/dft/source/CMakeLists.txt b/tests/unit_tests/dft/source/CMakeLists.txt new file mode 100644 index 000000000..be6ced78b --- /dev/null +++ b/tests/unit_tests/dft/source/CMakeLists.txt @@ -0,0 +1,49 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +#Build object from all test sources +set(DFT_SOURCES + tmp.cpp +) + +if(BUILD_SHARED_LIBS) + add_library(dft_source_rt OBJECT ${DFT_SOURCES}) + target_compile_options(dft_source_rt PRIVATE -DCALL_RT_API -DNOMINMAX) + target_include_directories(dft_source_rt + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + PUBLIC ${CBLAS_INCLUDE} + ) + target_link_libraries(dft_source_rt PUBLIC ONEMKL::SYCL::SYCL) +endif() + +add_library(dft_source_ct OBJECT ${DFT_SOURCES}) +target_compile_options(dft_source_ct PRIVATE -DNOMINMAX) +target_include_directories(dft_source_ct + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + PUBLIC ${CBLAS_INCLUDE} +) +target_link_libraries(dft_source_ct PUBLIC ONEMKL::SYCL::SYCL) diff --git a/tests/unit_tests/dft/source/tmp.cpp b/tests/unit_tests/dft/source/tmp.cpp new file mode 100644 index 000000000..e69de29bb From 32ddffec0f9010c87143e44ee717c837df0e0150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 27 Sep 2022 08:29:05 +0100 Subject: [PATCH 02/21] format --- include/oneapi/mkl/dft/backward.hpp | 85 +++-- include/oneapi/mkl/dft/descriptor.hpp | 32 +- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 180 ++++++----- include/oneapi/mkl/dft/forward.hpp | 85 +++-- include/oneapi/mkl/types.hpp | 96 +++--- src/dft/backends/mklgpu/backward.cpp | 167 ++++++---- src/dft/backends/mklgpu/descriptor.cpp | 2 +- src/dft/backends/mklgpu/forward.cpp | 170 ++++++---- .../backends/mklgpu/mkl_dft_gpu_wrappers.cpp | 42 ++- src/dft/dft_loader.cpp | 301 ++++++++++-------- src/dft/function_table.hpp | 105 +++--- 11 files changed, 686 insertions(+), 579 deletions(-) diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index 2fb6911f6..734aecdc8 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -29,63 +29,50 @@ #include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi::mkl::dft { - //Buffer version +//Buffer version - //In-place transform - template - void compute_backward( descriptor_type &desc, - sycl::buffer &inout ); +//In-place transform +template +void compute_backward(descriptor_type &desc, sycl::buffer &inout); - //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - void compute_backward( descriptor_type &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im); +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im); - //Out-of-place transform - template - void compute_backward( descriptor_type &desc, - sycl::buffer &in, - sycl::buffer &out); +//Out-of-place transform +template +void compute_backward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out); - //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - void compute_backward( descriptor_type &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im); +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +void compute_backward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im); - //USM version +//USM version - //In-place transform - template - sycl::event compute_backward( descriptor_type &desc, - data_type *inout, - const std::vector &dependencies = {}); +//In-place transform +template +sycl::event compute_backward(descriptor_type &desc, data_type *inout, + const std::vector &dependencies = {}); - //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - sycl::event compute_backward(descriptor_type &desc, - data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies = {}); +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, + const std::vector &dependencies = {}); - //Out-of-place transform - template - sycl::event compute_backward( descriptor_type &desc, - input_type *in, - output_type *out, - const std::vector &dependencies = {}); +//Out-of-place transform +template +sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &dependencies = {}); - //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - sycl::event compute_backward( descriptor_type &desc, - input_type *in_re, - input_type *in_im, - output_type *out_re, - output_type *out_im, - const std::vector &dependencies = {}); -} +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +sycl::event compute_backward(descriptor_type &desc, input_type *in_re, input_type *in_im, + output_type *out_re, output_type *out_im, + const std::vector &dependencies = {}); +} // namespace oneapi::mkl::dft #endif // _ONEMKL_DFT_BACKWARD_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 0f3330f1e..6bcac154b 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -30,27 +30,29 @@ namespace oneapi::mkl::dft { - template - class descriptor { - private: - sycl::queue queue_; - public: - // Syntax for 1-dimensional DFT - descriptor(std::int64_t length); - // Syntax for d-dimensional DFT - descriptor(std::vector dimensions); +template +class descriptor { +private: + sycl::queue queue_; - ~descriptor(); +public: + // Syntax for 1-dimensional DFT + descriptor(std::int64_t length); + // Syntax for d-dimensional DFT + descriptor(std::vector dimensions); + ~descriptor(); - void set_value(config_param param, ...); + void set_value(config_param param, ...); - void get_value(config_param param, ...); + void get_value(config_param param, ...); - void commit(sycl::queue &queue); + void commit(sycl::queue& queue); - sycl::queue& get_queue(){return queue_; }; + sycl::queue& get_queue() { + return queue_; }; -} +}; +} // namespace oneapi::mkl::dft #endif // _ONEMKL_DFT_DESCRIPTOR_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index f9fb9592f..2004da0cb 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -36,88 +36,106 @@ namespace mkl { namespace dft { namespace mklgpu { -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ -void commit_ ## EXT(descriptor &desc, sycl::queue& queue); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ -void compute_forward_buffer_inplace_ ## EXT(descriptor &desc, sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -void compute_forward_buffer_inplace_split_ ## EXT(descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ -void compute_forward_buffer_outofplace_ ## EXT(descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -void compute_forward_buffer_outofplace_split_ ## EXT(descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ -sycl::event compute_forward_usm_inplace_ ## EXT(descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -sycl::event compute_forward_usm_inplace_split_ ## EXT(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ -sycl::event compute_forward_usm_outofplace_ ## EXT(descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -sycl::event compute_forward_usm_outofplace_split_ ## EXT(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies = {}); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ -void compute_backward_buffer_inplace_ ## EXT(descriptor &desc, sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -void compute_backward_buffer_inplace_split_ ## EXT(descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ -void compute_backward_buffer_outofplace_ ## EXT(descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -void compute_backward_buffer_outofplace_split_ ## EXT(descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ -sycl::event compute_backward_usm_inplace_ ## EXT(descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -sycl::event compute_backward_usm_inplace_split_ ## EXT(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ -sycl::event compute_backward_usm_outofplace_ ## EXT(descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -sycl::event compute_backward_usm_outofplace_split_ ## EXT(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies = {}); \ +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + \ + void commit_##EXT(descriptor &desc, sycl::queue &queue); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + void compute_forward_buffer_inplace_##EXT(descriptor &desc, \ + sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_forward_buffer_inplace_split_##EXT(descriptor &desc, \ + sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ + void compute_forward_buffer_outofplace_##EXT(descriptor &desc, \ + sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_forward_buffer_outofplace_split_##EXT( \ + descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + sycl::event compute_forward_usm_inplace_##EXT( \ + descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_forward_usm_inplace_split_##EXT( \ + descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ + sycl::event compute_forward_usm_outofplace_##EXT( \ + descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_forward_usm_outofplace_split_##EXT( \ + descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ + T_REAL *out_im, const std::vector &dependencies = {}); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + void compute_backward_buffer_inplace_##EXT(descriptor &desc, \ + sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_backward_buffer_inplace_split_##EXT(descriptor &desc, \ + sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ + void compute_backward_buffer_outofplace_##EXT(descriptor &desc, \ + sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_backward_buffer_outofplace_split_##EXT( \ + descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + sycl::event compute_backward_usm_inplace_##EXT( \ + descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_backward_usm_inplace_split_##EXT( \ + descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ + sycl::event compute_backward_usm_outofplace_##EXT( \ + descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_backward_usm_outofplace_split_##EXT( \ + descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ + T_REAL *out_im, const std::vector &dependencies = {}); -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, + std::complex, std::complex) #undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 8a8a4f230..8b37185f3 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -30,63 +30,50 @@ namespace oneapi::mkl::dft { - //Buffer version +//Buffer version - //In-place transform - template - void compute_forward( descriptor_type &desc, - sycl::buffer &inout); +//In-place transform +template +void compute_forward(descriptor_type &desc, sycl::buffer &inout); - //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - void compute_forward( descriptor_type &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im); +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im); - //Out-of-place transform - template - void compute_forward( descriptor_type &desc, - sycl::buffer &in, - sycl::buffer &out); +//Out-of-place transform +template +void compute_forward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out); - //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - void compute_forward( descriptor_type &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im); +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +void compute_forward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im); - //USM version +//USM version - //In-place transform - template - sycl::event compute_forward( descriptor_type &desc, - data_type *inout, - const std::vector &dependencies = {}); +//In-place transform +template +sycl::event compute_forward(descriptor_type &desc, data_type *inout, + const std::vector &dependencies = {}); - //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - sycl::event compute_forward(descriptor_type &desc, - data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies = {}); +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, + const std::vector &dependencies = {}); - //Out-of-place transform - template - sycl::event compute_forward( descriptor_type &desc, - input_type *in, - output_type *out, - const std::vector &dependencies = {}); +//Out-of-place transform +template +sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &dependencies = {}); - //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format - template - sycl::event compute_forward( descriptor_type &desc, - input_type *in_re, - input_type *in_im, - output_type *out_re, - output_type *out_im, - const std::vector &dependencies = {}); -} +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +sycl::event compute_forward(descriptor_type &desc, input_type *in_re, input_type *in_im, + output_type *out_re, output_type *out_im, + const std::vector &dependencies = {}); +} // namespace oneapi::mkl::dft #endif // _ONEMKL_DFT_FORWARD_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/types.hpp b/include/oneapi/mkl/types.hpp index d53ce866d..faf39235c 100644 --- a/include/oneapi/mkl/types.hpp +++ b/include/oneapi/mkl/types.hpp @@ -110,74 +110,68 @@ enum class order : char { }; //DFT flag types -namespace dft{ -enum class precision { - SINGLE, - DOUBLE -}; -enum class domain { - REAL, - COMPLEX -}; +namespace dft { +enum class precision { SINGLE, DOUBLE }; +enum class domain { REAL, COMPLEX }; enum class config_param { - FORWARD_DOMAIN, - DIMENSION, - LENGTHS, - PRECISION, + FORWARD_DOMAIN, + DIMENSION, + LENGTHS, + PRECISION, - FORWARD_SCALE, - BACKWARD_SCALE, + FORWARD_SCALE, + BACKWARD_SCALE, - NUMBER_OF_TRANSFORMS, + NUMBER_OF_TRANSFORMS, - COMPLEX_STORAGE, - REAL_STORAGE, - CONJUGATE_EVEN_STORAGE, + COMPLEX_STORAGE, + REAL_STORAGE, + CONJUGATE_EVEN_STORAGE, - PLACEMENT, + PLACEMENT, - INPUT_STRIDES, - OUTPUT_STRIDES, + INPUT_STRIDES, + OUTPUT_STRIDES, - FWD_DISTANCE, - BWD_DISTANCE, + FWD_DISTANCE, + BWD_DISTANCE, - WORKSPACE, - ORDERING, - TRANSPOSE, - PACKED_FORMAT, - COMMIT_STATUS + WORKSPACE, + ORDERING, + TRANSPOSE, + PACKED_FORMAT, + COMMIT_STATUS }; enum class config_value { - // for config_param::COMMIT_STATUS - COMMITTED, - UNCOMMITTED, + // for config_param::COMMIT_STATUS + COMMITTED, + UNCOMMITTED, - // for config_param::COMPLEX_STORAGE, - // config_param::REAL_STORAGE and - // config_param::CONJUGATE_EVEN_STORAGE - COMPLEX_COMPLEX, - REAL_COMPLEX, - REAL_REAL, + // for config_param::COMPLEX_STORAGE, + // config_param::REAL_STORAGE and + // config_param::CONJUGATE_EVEN_STORAGE + COMPLEX_COMPLEX, + REAL_COMPLEX, + REAL_REAL, - // for config_param::PLACEMENT - INPLACE, - NOT_INPLACE, + // for config_param::PLACEMENT + INPLACE, + NOT_INPLACE, - // for config_param::ORDERING - ORDERED, - BACKWARD_SCRAMBLED, + // for config_param::ORDERING + ORDERED, + BACKWARD_SCRAMBLED, - // Allow/avoid certain usages - ALLOW, - AVOID, - NONE, + // Allow/avoid certain usages + ALLOW, + AVOID, + NONE, - // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers - CCE_FORMAT + // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers + CCE_FORMAT }; -} +} // namespace dft } //namespace mkl } //namespace oneapi diff --git a/src/dft/backends/mklgpu/backward.cpp b/src/dft/backends/mklgpu/backward.cpp index fe22c0498..d50efe915 100644 --- a/src/dft/backends/mklgpu/backward.cpp +++ b/src/dft/backends/mklgpu/backward.cpp @@ -32,147 +32,176 @@ namespace mkl { namespace dft { namespace mklgpu { -void compute_backward_buffer_inplace_f(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_backward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_c(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_backward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_d(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_backward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_z(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_backward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_split_f(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_backward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_split_c(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_backward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_split_d(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_backward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_inplace_split_z(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_backward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_f(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer &out) { +void compute_backward_buffer_outofplace_f(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_c(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { +void compute_backward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_d(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer &out) { +void compute_backward_buffer_outofplace_d(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_z(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { +void compute_backward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_split_f(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_backward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_split_c(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_backward_buffer_outofplace_split_c( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_split_d(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_backward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_backward_buffer_outofplace_split_z(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_backward_buffer_outofplace_split_z( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_f(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_c(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_d(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_z(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_split_c(descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_inplace_split_z(descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_f(descriptor &desc, std::complex *in, float *out, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_f(descriptor &desc, + std::complex *in, float *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_c(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_d(descriptor &desc, std::complex *in, double *out, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_d(descriptor &desc, + std::complex *in, double *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_z(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_split_f(descriptor &desc, float *in_re, float *in_im, - float *out_re, float *out_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_split_c(descriptor &desc, float *in_re, float *in_im, - float *out_re, float *out_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_split_d(descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_backward_usm_outofplace_split_z(descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, - const std::vector &dependencies) { +sycl::event compute_backward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } } // namespace mklgpu +} // namespace dft } // namespace mkl -} // namspace dft } // namespace oneapi diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp index c148c7ae5..656cc59eb 100644 --- a/src/dft/backends/mklgpu/descriptor.cpp +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -46,6 +46,6 @@ void commit_z(descriptor &desc, sycl::queue } } // namespace mklgpu +} // namespace dft } // namespace mkl -} // namspace dft } // namespace oneapi diff --git a/src/dft/backends/mklgpu/forward.cpp b/src/dft/backends/mklgpu/forward.cpp index bb5efe2fe..a2deb8879 100644 --- a/src/dft/backends/mklgpu/forward.cpp +++ b/src/dft/backends/mklgpu/forward.cpp @@ -32,147 +32,179 @@ namespace mkl { namespace dft { namespace mklgpu { -void compute_forward_buffer_inplace_f(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_forward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_c(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_forward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_d(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_forward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_z(descriptor &desc, sycl::buffer, 1> &inout) { +void compute_forward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_split_f(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_forward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_split_c(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_forward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_split_d(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_forward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_inplace_split_z(descriptor &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +void compute_forward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_f(descriptor &desc, sycl::buffer &in, - sycl::buffer, 1> &out) { +void compute_forward_buffer_outofplace_f(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_c(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { +void compute_forward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_d(descriptor &desc, sycl::buffer &in, - sycl::buffer, 1> &out) { +void compute_forward_buffer_outofplace_d(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_z(descriptor &desc, sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { +void compute_forward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_split_f(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_forward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_split_c(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_forward_buffer_outofplace_split_c(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_split_d(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_forward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -void compute_forward_buffer_outofplace_split_z(descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { +void compute_forward_buffer_outofplace_split_z(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_f(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_c(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_d(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_z(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_split_c(descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_inplace_split_z(descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_f(descriptor &desc, float *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_f(descriptor &desc, + float *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_c(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_d(descriptor &desc, double *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_d(descriptor &desc, + double *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_z(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_split_f(descriptor &desc, float *in_re, float *in_im, - float *out_re, float *out_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_split_c(descriptor &desc, float *in_re, float *in_im, - float *out_re, float *out_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_split_d(descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } -sycl::event compute_forward_usm_outofplace_split_z(descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, - const std::vector &dependencies) { +sycl::event compute_forward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } } // namespace mklgpu +} // namespace dft } // namespace mkl -} // namspace dft } // namespace oneapi diff --git a/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp b/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp index f0f837878..a26c8d4c5 100644 --- a/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp +++ b/src/dft/backends/mklgpu/mkl_dft_gpu_wrappers.cpp @@ -24,29 +24,27 @@ extern "C" dft_function_table_t mkl_dft_table = { WRAPPER_VERSION, -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT) \ - oneapi::mkl::dft::mklgpu::commit_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_split_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_ ## EXT, \ - oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_split_ ## EXT +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT) \ + oneapi::mkl::dft::mklgpu::commit_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_inplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_buffer_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_inplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_forward_usm_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_inplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_buffer_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_inplace_split_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_##EXT, \ + oneapi::mkl::dft::mklgpu::compute_backward_usm_outofplace_split_##EXT - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f), - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c), - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d), - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z) + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c), + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z) #undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES }; diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index 727f60a34..cb9b8aed7 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -29,146 +29,181 @@ namespace mkl { namespace dft { namespace detail { -static oneapi::mkl::detail::table_initializer function_tables; +static oneapi::mkl::detail::table_initializer + function_tables; } // namespace detail -#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ -template<> \ -void descriptor::commit(sycl::queue &queue){ \ - this->queue_ = queue; \ - detail::function_tables[get_device_id(queue)].commit_ ## EXT(*this, queue); \ -} \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ -template<> \ -void compute_forward, T_BACKWARD>(descriptor &desc, sycl::buffer &inout){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_inplace_ ## EXT(desc, inout); \ -} \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -void compute_forward, T_REAL>(descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_inplace_split_ ## EXT(desc, inout_re, inout_im); \ -} \ - \ - /*Out-of-place transform*/ \ -template<> \ -void compute_forward, T_FORWARD, T_BACKWARD>(descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_outofplace_ ## EXT(desc, in, out); \ -} \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -void compute_forward, T_REAL, T_REAL>(descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_forward_buffer_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im); \ -} \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ -template<> \ -sycl::event compute_forward, T_BACKWARD>(descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_inplace_ ## EXT(desc, inout, dependencies); \ -} \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -sycl::event compute_forward, T_REAL>(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_inplace_split_ ## EXT(desc, inout_re, inout_im, dependencies); \ -} \ - \ - /*Out-of-place transform*/ \ -template<> \ -sycl::event compute_forward, T_FORWARD, T_BACKWARD>(descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_outofplace_ ## EXT(desc, in, out, dependencies); \ -} \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -sycl::event compute_forward, T_REAL, T_REAL>(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_forward_usm_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im, dependencies); \ -} \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ -template<> \ -void compute_backward, T_BACKWARD>(descriptor &desc, sycl::buffer &inout){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_inplace_ ## EXT(desc, inout); \ -} \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -void compute_backward, T_REAL>(descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_inplace_split_ ## EXT(desc, inout_re, inout_im); \ -} \ - \ - /*Out-of-place transform*/ \ -template<> \ -void compute_backward, T_BACKWARD, T_FORWARD>(descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_outofplace_ ## EXT(desc, in, out); \ -} \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -void compute_backward, T_REAL, T_REAL>(descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im){ \ - detail::function_tables[get_device_id(desc.get_queue())].compute_backward_buffer_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im); \ -} \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ -template<> \ -sycl::event compute_backward, T_BACKWARD>(descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_inplace_ ## EXT(desc, inout, dependencies); \ -} \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -sycl::event compute_backward, T_REAL>(descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_inplace_split_ ## EXT(desc, inout_re, inout_im, dependencies); \ -} \ - \ - /*Out-of-place transform*/ \ -template<> \ -sycl::event compute_backward, T_BACKWARD, T_FORWARD>(descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_outofplace_ ## EXT(desc, in, out, dependencies); \ -} \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ -template<> \ -sycl::event compute_backward, T_REAL, T_REAL>(descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies){ \ - return detail::function_tables[get_device_id(desc.get_queue())].compute_backward_usm_outofplace_split_ ## EXT(desc, in_re, in_im, out_re, out_im, dependencies); \ -} \ +#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + \ + template <> \ + void descriptor::commit(sycl::queue &queue) { \ + this->queue_ = queue; \ + detail::function_tables[get_device_id(queue)].commit_##EXT(*this, queue); \ + } \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + template <> \ + void compute_forward, T_BACKWARD>( \ + descriptor & desc, sycl::buffer & inout) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_buffer_inplace_##EXT(desc, inout); \ + } \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + void compute_forward, T_REAL>( \ + descriptor & desc, sycl::buffer & inout_re, \ + sycl::buffer & inout_im) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ + } \ + \ + /*Out-of-place transform*/ \ + template <> \ + void compute_forward, T_FORWARD, T_BACKWARD>( \ + descriptor & desc, sycl::buffer & in, \ + sycl::buffer & out) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_buffer_outofplace_##EXT(desc, in, out); \ + } \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + void compute_forward, T_REAL, T_REAL>( \ + descriptor & desc, sycl::buffer & in_re, \ + sycl::buffer & in_im, sycl::buffer & out_re, \ + sycl::buffer & out_im) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ + } \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + template <> \ + sycl::event compute_forward, T_BACKWARD>( \ + descriptor & desc, T_BACKWARD * inout, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_usm_inplace_##EXT(desc, inout, dependencies); \ + } \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + sycl::event compute_forward, T_REAL>( \ + descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ + } \ + \ + /*Out-of-place transform*/ \ + template <> \ + sycl::event compute_forward, T_FORWARD, T_BACKWARD>( \ + descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_usm_outofplace_##EXT(desc, in, out, dependencies); \ + } \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + sycl::event compute_forward, T_REAL, T_REAL>( \ + descriptor & desc, T_REAL * in_re, T_REAL * in_im, T_REAL * out_re, \ + T_REAL * out_im, const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_forward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ + dependencies); \ + } \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + template <> \ + void compute_backward, T_BACKWARD>( \ + descriptor & desc, sycl::buffer & inout) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_buffer_inplace_##EXT(desc, inout); \ + } \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + void compute_backward, T_REAL>( \ + descriptor & desc, sycl::buffer & inout_re, \ + sycl::buffer & inout_im) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ + } \ + \ + /*Out-of-place transform*/ \ + template <> \ + void compute_backward, T_BACKWARD, T_FORWARD>( \ + descriptor & desc, sycl::buffer & in, \ + sycl::buffer & out) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_buffer_outofplace_##EXT(desc, in, out); \ + } \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + void compute_backward, T_REAL, T_REAL>( \ + descriptor & desc, sycl::buffer & in_re, \ + sycl::buffer & in_im, sycl::buffer & out_re, \ + sycl::buffer & out_im) { \ + detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ + } \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + template <> \ + sycl::event compute_backward, T_BACKWARD>( \ + descriptor & desc, T_BACKWARD * inout, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_usm_inplace_##EXT(desc, inout, dependencies); \ + } \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + sycl::event compute_backward, T_REAL>( \ + descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ + } \ + \ + /*Out-of-place transform*/ \ + template <> \ + sycl::event compute_backward, T_BACKWARD, T_FORWARD>( \ + descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ + const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_usm_outofplace_##EXT(desc, in, out, dependencies); \ + } \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + template <> \ + sycl::event compute_backward, T_REAL, T_REAL>( \ + descriptor & desc, T_REAL * in_re, T_REAL * in_im, T_REAL * out_re, \ + T_REAL * out_im, const std::vector &dependencies) { \ + return detail::function_tables[get_device_id(desc.get_queue())] \ + .compute_backward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ + dependencies); \ + } ONEAPI_MKL_DFT_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, std::complex) +ONEAPI_MKL_DFT_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, + std::complex) ONEAPI_MKL_DFT_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, std::complex) +ONEAPI_MKL_DFT_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, + std::complex) #undef ONEAPI_MKL_DFT_SIGNATURES -} // namespace rng +} // namespace dft } // namespace mkl } // namespace oneapi diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index 978988f5c..91af8fe04 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -35,47 +35,72 @@ typedef struct { int version; -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ -void (*commit_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::queue &queue); \ -void (*compute_forward_buffer_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout); \ -void (*compute_forward_buffer_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ -void (*compute_forward_buffer_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out); \ -void (*compute_forward_buffer_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ -sycl::event (*compute_forward_usm_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies); \ -sycl::event (*compute_forward_usm_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies); \ -sycl::event (*compute_forward_usm_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ - const std::vector &dependencies); \ -sycl::event (*compute_forward_usm_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies); \ -void (*compute_backward_buffer_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout); \ -void (*compute_backward_buffer_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ -void (*compute_backward_buffer_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in, \ - sycl::buffer &out); \ -void (*compute_backward_buffer_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ -sycl::event (*compute_backward_usm_inplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies); \ -sycl::event (*compute_backward_usm_inplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies); \ -sycl::event (*compute_backward_usm_outofplace_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies); \ -sycl::event (*compute_backward_usm_outofplace_split_ ## EXT)(oneapi::mkl::dft::descriptor &desc, T_REAL *in_re, T_REAL *in_im, \ - T_REAL *out_re, T_REAL *out_im, \ - const std::vector &dependencies); \ +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + void (*commit_##EXT)(oneapi::mkl::dft::descriptor & desc, \ + sycl::queue & queue); \ + void (*compute_forward_buffer_inplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, \ + sycl::buffer & inout); \ + void (*compute_forward_buffer_inplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, \ + sycl::buffer & inout_re, sycl::buffer & inout_im); \ + void (*compute_forward_buffer_outofplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, sycl::buffer & in, \ + sycl::buffer & out); \ + void (*compute_forward_buffer_outofplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, sycl::buffer & in_re, \ + sycl::buffer & in_im, sycl::buffer & out_re, \ + sycl::buffer & out_im); \ + sycl::event (*compute_forward_usm_inplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_BACKWARD * inout, \ + const std::vector &dependencies); \ + sycl::event (*compute_forward_usm_inplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_REAL * inout_re, \ + T_REAL * inout_im, const std::vector &dependencies); \ + sycl::event (*compute_forward_usm_outofplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ + const std::vector &dependencies); \ + sycl::event (*compute_forward_usm_outofplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ + T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); \ + void (*compute_backward_buffer_inplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, \ + sycl::buffer & inout); \ + void (*compute_backward_buffer_inplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, \ + sycl::buffer & inout_re, sycl::buffer & inout_im); \ + void (*compute_backward_buffer_outofplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, sycl::buffer & in, \ + sycl::buffer & out); \ + void (*compute_backward_buffer_outofplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, sycl::buffer & in_re, \ + sycl::buffer & in_im, sycl::buffer & out_re, \ + sycl::buffer & out_im); \ + sycl::event (*compute_backward_usm_inplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_BACKWARD * inout, \ + const std::vector &dependencies); \ + sycl::event (*compute_backward_usm_inplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_REAL * inout_re, \ + T_REAL * inout_im, const std::vector &dependencies); \ + sycl::event (*compute_backward_usm_outofplace_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ + const std::vector &dependencies); \ + sycl::event (*compute_backward_usm_outofplace_split_##EXT)( \ + oneapi::mkl::dft::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ + T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL, float, float, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::COMPLEX, float, std::complex, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL, double, double, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::COMPLEX, double, std::complex, std::complex) + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::precision::SINGLE, + oneapi::mkl::dft::domain::REAL, float, float, + std::complex) + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, oneapi::mkl::dft::precision::SINGLE, + oneapi::mkl::dft::domain::COMPLEX, float, std::complex, + std::complex) + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, oneapi::mkl::dft::precision::DOUBLE, + oneapi::mkl::dft::domain::REAL, double, double, + std::complex) + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, oneapi::mkl::dft::precision::DOUBLE, + oneapi::mkl::dft::domain::COMPLEX, double, + std::complex, std::complex) #undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES } dft_function_table_t; From e4810411bc45ba580ddebc0cabba1c5c2759db85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 28 Sep 2022 08:20:26 +0100 Subject: [PATCH 03/21] addressed comments from internal review --- include/oneapi/mkl/dft.hpp | 2 +- include/oneapi/mkl/dft/backward.hpp | 10 +++--- include/oneapi/mkl/dft/descriptor.hpp | 2 +- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 8 ++--- include/oneapi/mkl/dft/forward.hpp | 2 +- scripts/func_parser.py | 9 ++---- src/dft/backends/CMakeLists.txt | 2 +- src/dft/backends/mklgpu/backward.cpp | 32 +++++++++---------- src/dft/dft_loader.cpp | 8 ++--- src/dft/function_table.hpp | 8 ++--- tests/unit_tests/dft/CMakeLists.txt | 2 +- 11 files changed, 41 insertions(+), 44 deletions(-) diff --git a/include/oneapi/mkl/dft.hpp b/include/oneapi/mkl/dft.hpp index 6856cb198..9fd7b7ef6 100644 --- a/include/oneapi/mkl/dft.hpp +++ b/include/oneapi/mkl/dft.hpp @@ -24,4 +24,4 @@ #include "oneapi/mkl/dft/forward.hpp" #include "oneapi/mkl/dft/backward.hpp" -#endif // _ONEMKL_DFT_HPP_ \ No newline at end of file +#endif // _ONEMKL_DFT_HPP_ diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index 734aecdc8..4afe60505 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -56,23 +56,23 @@ void compute_backward(descriptor_type &desc, sycl::buffer &in_re, //In-place transform template sycl::event compute_backward(descriptor_type &desc, data_type *inout, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}); //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}); //Out-of-place transform template sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}); //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template sycl::event compute_backward(descriptor_type &desc, input_type *in_re, input_type *in_im, output_type *out_re, output_type *out_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}); } // namespace oneapi::mkl::dft -#endif // _ONEMKL_DFT_BACKWARD_HPP_ \ No newline at end of file +#endif // _ONEMKL_DFT_BACKWARD_HPP_ diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 6bcac154b..5e88bbfde 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -55,4 +55,4 @@ class descriptor { }; } // namespace oneapi::mkl::dft -#endif // _ONEMKL_DFT_DESCRIPTOR_HPP_ \ No newline at end of file +#endif // _ONEMKL_DFT_DESCRIPTOR_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index 2004da0cb..64ef639b3 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -111,22 +111,22 @@ namespace mklgpu { /*In-place transform*/ \ sycl::event compute_backward_usm_inplace_##EXT( \ descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ + const std::vector &dependencies = {}); \ \ /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ sycl::event compute_backward_usm_inplace_split_##EXT( \ descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ + const std::vector &dependencies = {}); \ \ /*Out-of-place transform*/ \ sycl::event compute_backward_usm_outofplace_##EXT( \ descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies = {}); \ + const std::vector &dependencies = {}); \ \ /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ sycl::event compute_backward_usm_outofplace_split_##EXT( \ descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ - T_REAL *out_im, const std::vector &dependencies = {}); + T_REAL *out_im, const std::vector &dependencies = {}); ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 8b37185f3..9093cda76 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -76,4 +76,4 @@ sycl::event compute_forward(descriptor_type &desc, input_type *in_re, input_type const std::vector &dependencies = {}); } // namespace oneapi::mkl::dft -#endif // _ONEMKL_DFT_FORWARD_HPP_ \ No newline at end of file +#endif // _ONEMKL_DFT_FORWARD_HPP_ diff --git a/scripts/func_parser.py b/scripts/func_parser.py index 7f25dd2e3..cbaa26142 100755 --- a/scripts/func_parser.py +++ b/scripts/func_parser.py @@ -149,11 +149,9 @@ def strip_line(l): """Delete all tabs""" return re.sub(' +',' ', l3) -def create_func_db(filenames): - data=[] - for filename in filenames.split(":"): - with open(filename, 'r') as f: - data.extend(f.readlines()) +def create_func_db(filename): + with open(filename, 'r') as f: + data = f.readlines() funcs_db = defaultdict(list) whole_line = "" idx = 0 @@ -172,7 +170,6 @@ def create_func_db(filenames): else: stripped = whole_line.strip() whole_line = "" - print(stripped) parsed = parse_item(stripped) func_name, func_data = parsed[0], parsed[1:] funcs_db[func_name].append(to_dict(func_data)) diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt index 9cbd4f603..70dd060c6 100644 --- a/src/dft/backends/CMakeLists.txt +++ b/src/dft/backends/CMakeLists.txt @@ -19,4 +19,4 @@ if(ENABLE_MKLGPU_BACKEND) add_subdirectory(mklgpu) -endif() \ No newline at end of file +endif() diff --git a/src/dft/backends/mklgpu/backward.cpp b/src/dft/backends/mklgpu/backward.cpp index d50efe915..8422df374 100644 --- a/src/dft/backends/mklgpu/backward.cpp +++ b/src/dft/backends/mklgpu/backward.cpp @@ -119,85 +119,85 @@ void compute_backward_buffer_outofplace_split_z( sycl::event compute_backward_usm_inplace_f(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_c(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_d(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_z(descriptor &desc, std::complex *inout, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_split_c( descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_inplace_split_z( descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_f(descriptor &desc, std::complex *in, float *out, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_c(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_d(descriptor &desc, std::complex *in, double *out, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_z(descriptor &desc, std::complex *in, std::complex *out, - const std::vector &dependencies) { + const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_split_f( descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_split_c( descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { + float *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_split_d( descriptor &desc, double *in_re, double *in_im, double *out_re, - double *out_im, const std::vector &dependencies) { + double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } sycl::event compute_backward_usm_outofplace_split_z( descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, const std::vector &dependencies) { + double *out_re, double *out_im, const std::vector &dependencies) { throw std::runtime_error("Not implemented for mklgpu"); } diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index cb9b8aed7..3066249a3 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -162,7 +162,7 @@ static oneapi::mkl::detail::table_initializer \ sycl::event compute_backward, T_BACKWARD>( \ descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies) { \ + const std::vector &dependencies) { \ return detail::function_tables[get_device_id(desc.get_queue())] \ .compute_backward_usm_inplace_##EXT(desc, inout, dependencies); \ } \ @@ -171,7 +171,7 @@ static oneapi::mkl::detail::table_initializer \ sycl::event compute_backward, T_REAL>( \ descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector &dependencies) { \ + const std::vector &dependencies) { \ return detail::function_tables[get_device_id(desc.get_queue())] \ .compute_backward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ } \ @@ -180,7 +180,7 @@ static oneapi::mkl::detail::table_initializer \ sycl::event compute_backward, T_BACKWARD, T_FORWARD>( \ descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector &dependencies) { \ + const std::vector &dependencies) { \ return detail::function_tables[get_device_id(desc.get_queue())] \ .compute_backward_usm_outofplace_##EXT(desc, in, out, dependencies); \ } \ @@ -189,7 +189,7 @@ static oneapi::mkl::detail::table_initializer \ sycl::event compute_backward, T_REAL, T_REAL>( \ descriptor & desc, T_REAL * in_re, T_REAL * in_im, T_REAL * out_re, \ - T_REAL * out_im, const std::vector &dependencies) { \ + T_REAL * out_im, const std::vector &dependencies) { \ return detail::function_tables[get_device_id(desc.get_queue())] \ .compute_backward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ dependencies); \ diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index 91af8fe04..a25bebe8b 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -78,16 +78,16 @@ typedef struct { sycl::buffer & out_im); \ sycl::event (*compute_backward_usm_inplace_##EXT)( \ oneapi::mkl::dft::descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies); \ + const std::vector &dependencies); \ sycl::event (*compute_backward_usm_inplace_split_##EXT)( \ oneapi::mkl::dft::descriptor & desc, T_REAL * inout_re, \ - T_REAL * inout_im, const std::vector &dependencies); \ + T_REAL * inout_im, const std::vector &dependencies); \ sycl::event (*compute_backward_usm_outofplace_##EXT)( \ oneapi::mkl::dft::descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector &dependencies); \ + const std::vector &dependencies); \ sycl::event (*compute_backward_usm_outofplace_split_##EXT)( \ oneapi::mkl::dft::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); + T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::precision::SINGLE, oneapi::mkl::dft::domain::REAL, float, float, diff --git a/tests/unit_tests/dft/CMakeLists.txt b/tests/unit_tests/dft/CMakeLists.txt index 3a12a42ed..4eddd205f 100644 --- a/tests/unit_tests/dft/CMakeLists.txt +++ b/tests/unit_tests/dft/CMakeLists.txt @@ -17,4 +17,4 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -add_subdirectory(source) \ No newline at end of file +add_subdirectory(source) From a48b583d273204a94db3dec717be53d74a0a212c Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 10 Oct 2022 00:28:35 -0700 Subject: [PATCH 04/21] pimpl descriptor class skeleton --- CMakeLists.txt | 3 +- examples/dft/CMakeLists.txt | 13 ++ .../compile_time_dispatching/CMakeLists.txt | 49 ++++ .../complex_fwd_usm_mklcpu.cpp | 142 ++++++++++++ .../dft/run_time_dispatching/CMakeLists.txt | 67 ++++++ .../run_time_dispatching/complex_fwd_usm.cpp | 118 ++++++++++ include/oneapi/mkl/dft/descriptor.hpp | 30 ++- .../oneapi/mkl/dft/detail/descriptor_impl.hpp | 45 ++++ include/oneapi/mkl/dft/detail/dft_loader.hpp | 36 +++ .../dft/detail/mklcpu/onemkl_dft_mklcpu.hpp | 145 ++++++++++++ .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 2 +- src/dft/backends/CMakeLists.txt | 4 + src/dft/backends/mklcpu/CMakeLists.txt | 71 ++++++ src/dft/backends/mklcpu/backward.cpp | 207 +++++++++++++++++ src/dft/backends/mklcpu/commit.cpp | 51 +++++ src/dft/backends/mklcpu/descriptor.cpp | 57 +++++ src/dft/backends/mklcpu/forward.cpp | 210 ++++++++++++++++++ .../backends/mklcpu/mkl_dft_cpu_wrappers.cpp | 50 +++++ src/dft/backends/mklgpu/CMakeLists.txt | 1 + src/dft/backends/mklgpu/commit.cpp | 51 +++++ src/dft/backends/mklgpu/descriptor.cpp | 63 +++--- 21 files changed, 1377 insertions(+), 38 deletions(-) create mode 100644 examples/dft/compile_time_dispatching/CMakeLists.txt create mode 100644 examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp create mode 100644 examples/dft/run_time_dispatching/CMakeLists.txt create mode 100644 examples/dft/run_time_dispatching/complex_fwd_usm.cpp create mode 100644 include/oneapi/mkl/dft/detail/descriptor_impl.hpp create mode 100644 include/oneapi/mkl/dft/detail/dft_loader.hpp create mode 100644 include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp create mode 100644 src/dft/backends/mklcpu/CMakeLists.txt create mode 100644 src/dft/backends/mklcpu/backward.cpp create mode 100644 src/dft/backends/mklcpu/commit.cpp create mode 100644 src/dft/backends/mklcpu/descriptor.cpp create mode 100644 src/dft/backends/mklcpu/forward.cpp create mode 100644 src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp create mode 100644 src/dft/backends/mklgpu/commit.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ff8355f26..424992483 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,8 @@ if(ENABLE_MKLCPU_BACKEND OR ENABLE_CURAND_BACKEND) list(APPEND DOMAINS_LIST "rng") endif() -if(ENABLE_MKLGPU_BACKEND) +if(ENABLE_MKLGPU_BACKEND + OR ENABLE_MKLCPU_BACKEND) list(APPEND DOMAINS_LIST "dft") endif() diff --git a/examples/dft/CMakeLists.txt b/examples/dft/CMakeLists.txt index 692461b1d..e43bea36d 100644 --- a/examples/dft/CMakeLists.txt +++ b/examples/dft/CMakeLists.txt @@ -16,3 +16,16 @@ # # SPDX-License-Identifier: Apache-2.0 #=============================================================================== + +# Note: compile-time example uses both MKLCPU and CURAND backends, therefore +# cmake in the sub-directory will only build it if CURAND backend is enabled +add_subdirectory(compile_time_dispatching) + +# Note: compile-time example uses both MKLCPU and CUSOLVER backends, therefore +# cmake in the sub-directory will only build it if CUSOLVER backend is enabled +# add_subdirectory(compile_time_dispatching) + +# runtime compilation is only possible with dynamic libraries +# if (BUILD_SHARED_LIBS) +# add_subdirectory(run_time_dispatching) +# endif() diff --git a/examples/dft/compile_time_dispatching/CMakeLists.txt b/examples/dft/compile_time_dispatching/CMakeLists.txt new file mode 100644 index 000000000..72b7e0701 --- /dev/null +++ b/examples/dft/compile_time_dispatching/CMakeLists.txt @@ -0,0 +1,49 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +#Build object from all sources +set(DFTI_CT_SOURCES "") +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DFTI_CT_SOURCES "complex_fwd_usm_mklcpu") +endif() + +if(domain STREQUAL "dft" AND ENABLE_MKLCPU_BACKEND) + find_library(OPENCL_LIBRARY NAMES OpenCL) + message(STATUS "Found OpenCL: ${OPENCL_LIBRARY}") +endif() + +foreach(dfti_ct_sources ${DFTI_CT_SOURCES}) + add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp) + target_include_directories(example_${domain}_${dfti_ct_sources} + PUBLIC ${PROJECT_SOURCE_DIR}/examples/include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + if(domain STREQUAL "dft" AND ENABLE_MKLCPU_BACKEND) + add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklcpu) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklcpu) + target_link_libraries(example_${domain}_${dfti_ct_sources} PUBLIC ${OPENCL_LIBRARY}) + endif() + target_link_libraries(example_${domain}_${dfti_ct_sources} PUBLIC + ${ONEMKL_LIBRARIES_${domain}} + ONEMKL::SYCL::SYCL + ) + # Register example as ctest + add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources}) +endforeach(dfti_ct_sources) diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp new file mode 100644 index 000000000..02009c3f7 --- /dev/null +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -0,0 +1,142 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* This example demonstrates use of oneapi::mkl::dft::getrf and +* oneapi::mkl::dft::getrs to perform LU factorization and compute +* the solution on both an Intel cpu device and NVIDIA cpu device. +* +* This example demonstrates only single precision (float) data type +* for matrix data +* +*******************************************************************************/ + +// STL includes +#include +#include +#include +#include + +// oneMKL/SYCL includes +#if __has_include() +#include +#else +#include +#endif +#include "oneapi/mkl.hpp" + +// local includes +#include "example_helper.hpp" + +void run_getrs_example(const sycl::device& cpu_device) { + // Matrix sizes and leading dimensions + constexpr std::size_t n = 10; + + // Catch asynchronous exceptions for cpu and cpu + auto cpu_error_handler = [&](sycl::exception_list exceptions) { + for (auto const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + // Handle not dft related exceptions that happened during asynchronous call + std::cerr + << "Caught asynchronous SYCL exception on cpu device during GETRF or GETRS:" + << std::endl; + std::cerr << "\t" << e.what() << std::endl; + } + } + std::exit(2); + }; + + // + // Preparation on cpu + // + sycl::queue cpu_queue(cpu_device, cpu_error_handler); + sycl::context cpu_context = cpu_queue.get_context(); + sycl::event cpu_getrf_done; + + double *x_usm = (double*) malloc_shared(n*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); + std::cout << "DFTI example" << std::endl; + oneapi::mkl::dft::descriptor desc(10); + // compute_forward(desc, x_usm); +} + +// +// Description of example setup, apis used and supported floating point type precisions +// + +void print_example_banner() { + std::cout << "" << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout + << "# DFTI complex in-place forward transform for USM/Buffer API's example: " + << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using APIs:" << std::endl; + std::cout << "# USM/BUffer forward complex in-place" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using single precision (float) data type" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Device will be selected during runtime." << std::endl; + std::cout << "# The environment variable SYCL_DEVICE_FILTER can be used to specify" + << std::endl; + std::cout << "# Using single precision (float) data type" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Running on both Intel cpu and NVIDIA cpu devices" << std::endl; + std::cout << "# " << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << std::endl; +} + +// +// Main entry point for example. +// +int main(int argc, char** argv) { + print_example_banner(); + + try { + sycl::device cpu_dev((sycl::cpu_selector())); + std::cout << "Running DFT Complex forward inplace USM example" << std::endl; + std::cout << "Running with single precision real data type on:" << std::endl; + std::cout << "\tcpu device :" << cpu_dev.get_info() << std::endl; + + run_getrs_example(cpu_dev); + std::cout << "DFT Complex USM example ran OK on MKLcpu" << std::endl; + } + catch (sycl::exception const& e) { + // Handle not dft related exceptions that happened during synchronous call + std::cerr << "Caught synchronous SYCL exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const& e) { + // Handle not SYCL related exceptions that happened during synchronous call + std::cerr << "Caught synchronous std::exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/examples/dft/run_time_dispatching/CMakeLists.txt b/examples/dft/run_time_dispatching/CMakeLists.txt new file mode 100644 index 000000000..dc947a9dc --- /dev/null +++ b/examples/dft/run_time_dispatching/CMakeLists.txt @@ -0,0 +1,67 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example (no need to specify backend when building with CMake) + +# Build object from all example sources +set(DFT_RT_SOURCES "complex_fwd_usm") + +# Set up for the right backend for run-time dispatching examples +# If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to +# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend +set(DEVICE_FILTERS "") +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DEVICE_FILTERS "cpu") +endif() +# RNG only supports mklcpu backend on Windows +if(UNIX AND ENABLE_MKLGPU_BACKEND) + list(APPEND DEVICE_FILTERS "gpu") +endif() + +message(STATUS "SYCL_DEVICE_FILTER will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") + +foreach(dft_rt_sources ${DFT_RT_SOURCES}) + add_executable(example_${domain}_${dft_rt_sources} ${dft_rt_sources}.cpp) + target_include_directories(example_${domain}_${dft_rt_sources} + PUBLIC ${PROJECT_SOURCE_DIR}/examples/include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + + add_dependencies(example_${domain}_${dft_rt_sources} onemkl) + + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET example_${domain}_${dft_rt_sources} SOURCES ${DFT_RT_SOURCES}) + endif() + + target_link_libraries(example_${domain}_${dft_rt_sources} PUBLIC + onemkl + ONEMKL::SYCL::SYCL + ${CMAKE_DL_LIBS} + ) + + # Register example as ctest + foreach(device_filter ${DEVICE_FILTERS}) + add_test(NAME ${domain}/EXAMPLE/RT/${dft_rt_sources}/${device_filter} COMMAND example_${domain}_${dft_rt_sources}) + set_property(TEST ${domain}/EXAMPLE/RT/${dft_rt_sources}/${device_filter} PROPERTY + ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} + ENVIRONMENT SYCL_DEVICE_FILTER=${device_filter}) + endforeach(device_filter) + +endforeach() diff --git a/examples/dft/run_time_dispatching/complex_fwd_usm.cpp b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp new file mode 100644 index 000000000..1b906afb6 --- /dev/null +++ b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp @@ -0,0 +1,118 @@ + +// stl includes +#include +#include +#include +#include + +// oneMKL/SYCL includes +#if __has_include() +#include +#else +#include +#endif +#include "oneapi/mkl.hpp" + +// local includes +#include "example_helper.hpp" + +constexpr int SUCCESS = 0; +constexpr int FAILURE = 1; +constexpr double TWOPI = 6.2831853071795864769; + +void run_uniform_example(const sycl::device& dev) { + + int N = 16; + int harmonic = 5; + int buffer_result = FAILURE; + int usm_result = FAILURE; + int result = FAILURE; + + // Catch asynchronous exceptions + auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + std::cerr << "Caught asynchronous SYCL exception during generation:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + } + } + std::exit(2); + }; + + sycl::queue queue(dev, exception_handler); + + double *x_usm = (double*) malloc_shared(N*2*sizeof(double), queue.get_device(), queue.get_context()); + + oneapi::mkl::dft::descriptor< + oneapi::mkl::dft::precision::DOUBLE, + oneapi::mkl::dft::domain::COMPLEX + > desc(N); +} + +// +// Description of example setup, APIs used and supported floating point type precisions +// +void print_example_banner() { + std::cout << "" << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout + << "# DFTI complex in-place forward transform for USM/Buffer API's example: " + << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using APIs:" << std::endl; + std::cout << "# USM/BUffer forward complex in-place" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using single precision (float) data type" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Device will be selected during runtime." << std::endl; + std::cout << "# The environment variable SYCL_DEVICE_FILTER can be used to specify" + << std::endl; + std::cout << "# SYCL device" << std::endl; + std::cout << "# " << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << std::endl; +} + +// +// Main entry point for example. +// + +int main(int argc, char** argv) { + print_example_banner(); + + try { + sycl::device my_dev((sycl::default_selector())); + + if (my_dev.is_gpu()) { + std::cout << "Running DFT complex forward example on GPU device" << std::endl; + std::cout << "Device name is: " << my_dev.get_info() + << std::endl; + } + else { + std::cout << "Running DFT complex forward example on CPU device" << std::endl; + std::cout << "Device name is: " << my_dev.get_info() + << std::endl; + } + std::cout << "Running with single precision real data type:" << std::endl; + + run_uniform_example(my_dev); + std::cout << "DFIT example ran OK" << std::endl; + } + catch (sycl::exception const& e) { + std::cerr << "Caught synchronous SYCL exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const& e) { + std::cerr << "Caught std::exception during generation:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + return 0; +} diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 5e88bbfde..3a9c59f53 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -27,21 +27,35 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" -namespace oneapi::mkl::dft { +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/dft_loader.hpp" -template +namespace oneapi { +namespace mkl { +namespace dft { + +template class descriptor { private: sycl::queue queue_; - + std::unique_ptr pimpl_; + int x; public: // Syntax for 1-dimensional DFT - descriptor(std::int64_t length); + descriptor(std::int64_t length) +#ifdef ENABLE_MKLCPU_BACKEND + : pimpl_(mklcpu::create_descriptor(length)) {} +#endif +#ifdef ENABLE_MKLGPU_BACKEND + : pimpl_(mklgpu::create_descriptor(length)) {} +#endif + // Syntax for d-dimensional DFT descriptor(std::vector dimensions); - ~descriptor(); + // ~descriptor(); void set_value(config_param param, ...); @@ -53,6 +67,10 @@ class descriptor { return queue_; }; }; -} // namespace oneapi::mkl::dft + +} //namespace dft +} //namespace mkl +} //namespace oneapi + #endif // _ONEMKL_DFT_DESCRIPTOR_HPP_ diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp new file mode 100644 index 000000000..2361a4750 --- /dev/null +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -0,0 +1,45 @@ +#ifndef _ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ +#define _ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" +#include "oneapi/mkl/types.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace detail { + +class descriptor_impl { +public: + descriptor_impl(std::size_t length) : length_(length) {} + + descriptor_impl(const descriptor_impl& other) : length_(other.length_) {} + + virtual descriptor_impl* copy_state() = 0; + + virtual ~descriptor_impl() {} + + sycl::queue& get_queue() { + return queue_; + } + +protected: + sycl::queue queue_; + std::size_t length_; +}; + +} // namespace detail +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif //_ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ + diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp new file mode 100644 index 000000000..5314058d3 --- /dev/null +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -0,0 +1,36 @@ +#ifndef _ONEMKL_DFT_LOADER_HPP_ +#define _ONEMKL_DFT_LOADER_HPP_ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" + +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +namespace mklcpu { + +ONEMKL_EXPORT oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); + +} // namespace mklcpu + +namespace mklgpu { + +ONEMKL_EXPORT oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); + +} // namespace mklgpu + +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif //_ONEMKL_DFT_LOADER_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp new file mode 100644 index 000000000..edf8d706a --- /dev/null +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -0,0 +1,145 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#pragma once + +#if __has_include() +#include +#else +#include +#endif + +#include +#include + +#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { + +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ + \ + void commit_##EXT(descriptor &desc, sycl::queue &queue); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + void compute_forward_buffer_inplace_##EXT(descriptor &desc, \ + sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_forward_buffer_inplace_split_##EXT(descriptor &desc, \ + sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ + void compute_forward_buffer_outofplace_##EXT(descriptor &desc, \ + sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_forward_buffer_outofplace_split_##EXT( \ + descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + sycl::event compute_forward_usm_inplace_##EXT( \ + descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_forward_usm_inplace_split_##EXT( \ + descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ + sycl::event compute_forward_usm_outofplace_##EXT( \ + descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_forward_usm_outofplace_split_##EXT( \ + descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ + T_REAL *out_im, const std::vector &dependencies = {}); \ + \ + /*Buffer version*/ \ + \ + /*In-place transform*/ \ + void compute_backward_buffer_inplace_##EXT(descriptor &desc, \ + sycl::buffer &inout); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_backward_buffer_inplace_split_##EXT(descriptor &desc, \ + sycl::buffer &inout_re, \ + sycl::buffer &inout_im); \ + \ + /*Out-of-place transform*/ \ + void compute_backward_buffer_outofplace_##EXT(descriptor &desc, \ + sycl::buffer &in, \ + sycl::buffer &out); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + void compute_backward_buffer_outofplace_split_##EXT( \ + descriptor &desc, sycl::buffer &in_re, \ + sycl::buffer &in_im, sycl::buffer &out_re, \ + sycl::buffer &out_im); \ + \ + /*USM version*/ \ + \ + /*In-place transform*/ \ + sycl::event compute_backward_usm_inplace_##EXT( \ + descriptor &desc, T_BACKWARD *inout, \ + const std::vector &dependencies = {}); \ + \ + /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_backward_usm_inplace_split_##EXT( \ + descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform*/ \ + sycl::event compute_backward_usm_outofplace_##EXT( \ + descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ + const std::vector &dependencies = {}); \ + \ + /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ + sycl::event compute_backward_usm_outofplace_split_##EXT( \ + descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ + T_REAL *out_im, const std::vector &dependencies = {}); + +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, + std::complex) +ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, + std::complex, std::complex) + +#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES + +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index 64ef639b3..e82de9656 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -28,8 +28,8 @@ #include #include -#include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/types.hpp" namespace oneapi { namespace mkl { diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt index 70dd060c6..c75086840 100644 --- a/src/dft/backends/CMakeLists.txt +++ b/src/dft/backends/CMakeLists.txt @@ -20,3 +20,7 @@ if(ENABLE_MKLGPU_BACKEND) add_subdirectory(mklgpu) endif() + +if(ENABLE_MKLCPU_BACKEND) + add_subdirectory(mklcpu) +endif() diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt new file mode 100644 index 000000000..4ff97cd54 --- /dev/null +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -0,0 +1,71 @@ +#=============================================================================== +# Copyright 2022 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_dft_mklcpu) +set(LIB_OBJ ${LIB_NAME}_obj) + +find_package(MKL REQUIRED) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + commit.cpp + descriptor.cpp + forward.cpp + backward.cpp + $<$: mkl_dft_cpu_wrappers.cpp> +) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${MKL_INCLUDE} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp new file mode 100644 index 000000000..369e64a78 --- /dev/null +++ b/src/dft/backends/mklcpu/backward.cpp @@ -0,0 +1,207 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { + +void compute_backward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_backward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_backward_buffer_outofplace_f(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_d(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_backward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_c( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_z( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_backward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_backward_usm_outofplace_f(descriptor &desc, + std::complex *in, float *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_d(descriptor &desc, + std::complex *in, double *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_backward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp new file mode 100644 index 000000000..d28a193d7 --- /dev/null +++ b/src/dft/backends/mklcpu/commit.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { + +void commit_f(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void commit_c(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void commit_d(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void commit_z(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp new file mode 100644 index 000000000..d495131ed --- /dev/null +++ b/src/dft/backends/mklcpu/descriptor.cpp @@ -0,0 +1,57 @@ +#include +#if __has_include() +#include +#else +#include +#endif + +#include "mkl_version.h" + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + +#include "mkl_dfti.h" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { + +class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { +public: + descriptor_derived_impl(std::size_t length) + : oneapi::mkl::dft::detail::descriptor_impl(length) { + std::cout << "special entry points" << std::endl; + DFTI_DESCRIPTOR_HANDLE hand = NULL; + } + + descriptor_derived_impl(const descriptor_derived_impl* other) + : oneapi::mkl::dft::detail::descriptor_impl(*other) { + std::cout << "special entry points copy const" << std::endl; + } + + virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { + return new descriptor_derived_impl(this); + } + + virtual ~descriptor_derived_impl() override { + std::cout << "descriptor_derived_impl descriptor" << std::endl; + } +private: +}; + +oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { + return new descriptor_derived_impl(length); +} + + + +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp new file mode 100644 index 000000000..765adbebf --- /dev/null +++ b/src/dft/backends/mklcpu/forward.cpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { + +void compute_forward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_forward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_forward_buffer_outofplace_f(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_d(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +void compute_forward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_c(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_z(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_forward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_forward_usm_outofplace_f(descriptor &desc, + float *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_d(descriptor &desc, + double *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +sycl::event compute_forward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} + +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp new file mode 100644 index 000000000..8a5374cd8 --- /dev/null +++ b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp @@ -0,0 +1,50 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "dft/function_table.hpp" + +#define WRAPPER_VERSION 1 + +extern "C" dft_function_table_t mkl_dft_table = { + WRAPPER_VERSION, +#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT) \ + oneapi::mkl::dft::mklcpu::commit_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_buffer_inplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_buffer_inplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_buffer_outofplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_buffer_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_usm_inplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_usm_inplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_usm_outofplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_forward_usm_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_buffer_inplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_buffer_inplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_buffer_outofplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_buffer_outofplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_usm_inplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_usm_inplace_split_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_usm_outofplace_##EXT, \ + oneapi::mkl::dft::mklcpu::compute_backward_usm_outofplace_split_##EXT + + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c), + ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z) + +#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES +}; diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index 00a5614eb..6a91a8c34 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -24,6 +24,7 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT + commit.cpp descriptor.cpp forward.cpp backward.cpp diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp new file mode 100644 index 000000000..656cc59eb --- /dev/null +++ b/src/dft/backends/mklgpu/commit.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklgpu { + +void commit_f(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_c(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_d(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} +void commit_z(descriptor &desc, sycl::queue &queue) { + throw std::runtime_error("Not implemented for mklgpu"); +} + +} // namespace mklgpu +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp index 656cc59eb..c2805345a 100644 --- a/src/dft/backends/mklgpu/descriptor.cpp +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -1,30 +1,18 @@ -/******************************************************************************* -* Copyright 2022 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions -* and limitations under the License. -* -* -* SPDX-License-Identifier: Apache-2.0 -*******************************************************************************/ - +#include #if __has_include() #include #else #include #endif +#include "mkl_version.h" + #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/exceptions.hpp" + #include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" namespace oneapi { @@ -32,19 +20,34 @@ namespace mkl { namespace dft { namespace mklgpu { -void commit_f(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklgpu"); -} -void commit_c(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklgpu"); -} -void commit_d(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklgpu"); -} -void commit_z(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklgpu"); +class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { +public: + descriptor_derived_impl(std::size_t length) + : oneapi::mkl::dft::detail::descriptor_impl(length) { + std::cout << "special entry points" << std::endl; + } + + descriptor_derived_impl(const descriptor_derived_impl* other) + : oneapi::mkl::dft::detail::descriptor_impl(*other) { + std::cout << "special entry points copy const" << std::endl; + } + + virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { + return new descriptor_derived_impl(this); + } + + virtual ~descriptor_derived_impl() override { + std::cout << "descriptor_derived_impl descriptor" << std::endl; + } +private: +}; + +oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { + return new descriptor_derived_impl(length); } + + } // namespace mklgpu } // namespace dft } // namespace mkl From 47af09ed5b9fcfdefff52117791665e074c964a7 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Tue, 11 Oct 2022 12:12:40 -0700 Subject: [PATCH 05/21] push descriptor outside of device specific impl + desc class w/ prec and dom --- .../complex_fwd_usm_mklcpu.cpp | 9 +- include/oneapi/mkl/dft/descriptor.hpp | 8 +- .../oneapi/mkl/dft/detail/descriptor_impl.hpp | 2 + include/oneapi/mkl/dft/detail/dft_loader.hpp | 15 +--- src/dft/backends/descriptor.cxx | 90 +++++++++++++++++++ src/dft/backends/mklcpu/CMakeLists.txt | 1 - src/dft/backends/mklcpu/descriptor.cpp | 57 ------------ .../backends/mklcpu/mkl_dft_cpu_wrappers.cpp | 1 + src/dft/backends/mklgpu/CMakeLists.txt | 2 +- src/dft/backends/mklgpu/descriptor.cpp | 54 ----------- 10 files changed, 106 insertions(+), 133 deletions(-) create mode 100644 src/dft/backends/descriptor.cxx delete mode 100644 src/dft/backends/mklcpu/descriptor.cpp delete mode 100644 src/dft/backends/mklgpu/descriptor.cpp diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 02009c3f7..4eea5e051 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -67,6 +67,7 @@ void run_getrs_example(const sycl::device& cpu_device) { std::exit(2); }; + std::cout << "DFTI example" << std::endl; // // Preparation on cpu // @@ -75,9 +76,13 @@ void run_getrs_example(const sycl::device& cpu_device) { sycl::event cpu_getrf_done; double *x_usm = (double*) malloc_shared(n*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); - std::cout << "DFTI example" << std::endl; + + // enabling oneapi::mkl::dft::descriptor desc(10); - // compute_forward(desc, x_usm); + // desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); + // [compile time] desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); + // [run time] desc.commit(cpu_queue); + // oneapi::mkl::dft::compute_forward(desc, x_usm); } // diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 3a9c59f53..5d9e31e02 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -41,16 +41,10 @@ class descriptor { private: sycl::queue queue_; std::unique_ptr pimpl_; - int x; public: // Syntax for 1-dimensional DFT descriptor(std::int64_t length) -#ifdef ENABLE_MKLCPU_BACKEND - : pimpl_(mklcpu::create_descriptor(length)) {} -#endif -#ifdef ENABLE_MKLGPU_BACKEND - : pimpl_(mklgpu::create_descriptor(length)) {} -#endif + : pimpl_(detail::create_descriptor(length)) {} // Syntax for d-dimensional DFT descriptor(std::vector dimensions); diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index 2361a4750..322f15e08 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -34,6 +34,8 @@ class descriptor_impl { protected: sycl::queue queue_; std::size_t length_; + oneapi::mkl::dft::precision prec_; + oneapi::mkl::dft::domain dom_; }; } // namespace detail diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index 5314058d3..32d73040d 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -16,19 +16,12 @@ namespace oneapi { namespace mkl { namespace dft { +namespace detail { -namespace mklcpu { - -ONEMKL_EXPORT oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); - -} // namespace mklcpu - -namespace mklgpu { - -ONEMKL_EXPORT oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); - -} // namespace mklgpu +template +oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); +} // namespace detail } // namespace dft } // namespace mkl } // namespace oneapi diff --git a/src/dft/backends/descriptor.cxx b/src/dft/backends/descriptor.cxx new file mode 100644 index 000000000..002c229e6 --- /dev/null +++ b/src/dft/backends/descriptor.cxx @@ -0,0 +1,90 @@ +#include +#if __has_include() +#include +#else +#include +#endif + +#include "mkl_version.h" + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "mkl_dfti.h" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace detail { + +class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { +public: + descriptor_derived_impl(std::size_t length) + : oneapi::mkl::dft::detail::descriptor_impl(length) { + std::cout << "special entry points" << std::endl; + } + + descriptor_derived_impl(const descriptor_derived_impl* other) + : oneapi::mkl::dft::detail::descriptor_impl(*other) { + std::cout << "special entry points copy const" << std::endl; + } + + virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { + return new descriptor_derived_impl(this); + } + + virtual ~descriptor_derived_impl() override { + std::cout << "descriptor_derived_impl descriptor" << std::endl; + } + + void set_precision(oneapi::mkl::dft::precision prec) {prec_ = prec;} + void set_domain(oneapi::mkl::dft::domain dom) {dom_ = dom;} + +private: + DFTI_DESCRIPTOR_HANDLE hand; +}; + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor(std::size_t length) { + auto desc_pimpl = new descriptor_derived_impl(length); + desc_pimpl->set_precision(oneapi::mkl::dft::precision::DOUBLE); + desc_pimpl->set_domain(oneapi::mkl::dft::domain::COMPLEX); + return desc_pimpl; +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor(std::size_t length) { + auto desc_pimpl = new descriptor_derived_impl(length); + desc_pimpl->set_precision(oneapi::mkl::dft::precision::DOUBLE); + desc_pimpl->set_domain(oneapi::mkl::dft::domain::REAL); + return desc_pimpl; +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor(std::size_t length) { + auto desc_pimpl = new descriptor_derived_impl(length); + desc_pimpl->set_precision(oneapi::mkl::dft::precision::SINGLE); + desc_pimpl->set_domain(oneapi::mkl::dft::domain::COMPLEX); + return desc_pimpl; +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor(std::size_t length) { + auto desc_pimpl = new descriptor_derived_impl(length); + desc_pimpl->set_precision(oneapi::mkl::dft::precision::SINGLE); + desc_pimpl->set_domain(oneapi::mkl::dft::domain::REAL); + return desc_pimpl; +} + +} // namespace detail +} // namespace dft +} // namespace mkl +} // namespace oneapi diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index 4ff97cd54..57ff6dd98 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -25,7 +25,6 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT commit.cpp - descriptor.cpp forward.cpp backward.cpp $<$: mkl_dft_cpu_wrappers.cpp> diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp deleted file mode 100644 index d495131ed..000000000 --- a/src/dft/backends/mklcpu/descriptor.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include -#if __has_include() -#include -#else -#include -#endif - -#include "mkl_version.h" - -#include "oneapi/mkl/types.hpp" - -#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" -#include "oneapi/mkl/dft/descriptor.hpp" -#include "oneapi/mkl/exceptions.hpp" - -#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" - -#include "mkl_dfti.h" - -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklcpu { - -class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { -public: - descriptor_derived_impl(std::size_t length) - : oneapi::mkl::dft::detail::descriptor_impl(length) { - std::cout << "special entry points" << std::endl; - DFTI_DESCRIPTOR_HANDLE hand = NULL; - } - - descriptor_derived_impl(const descriptor_derived_impl* other) - : oneapi::mkl::dft::detail::descriptor_impl(*other) { - std::cout << "special entry points copy const" << std::endl; - } - - virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { - return new descriptor_derived_impl(this); - } - - virtual ~descriptor_derived_impl() override { - std::cout << "descriptor_derived_impl descriptor" << std::endl; - } -private: -}; - -oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - - - -} // namespace mklcpu -} // namespace dft -} // namespace mkl -} // namespace oneapi diff --git a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp index 8a5374cd8..7ce1bbf63 100644 --- a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp +++ b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp @@ -19,6 +19,7 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" #include "dft/function_table.hpp" +#include "../descriptor.cxx" #define WRAPPER_VERSION 1 diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index 6a91a8c34..d373d2957 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -24,8 +24,8 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT + ../descriptor.cpp commit.cpp - descriptor.cpp forward.cpp backward.cpp $<$: mkl_dft_gpu_wrappers.cpp> diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp deleted file mode 100644 index c2805345a..000000000 --- a/src/dft/backends/mklgpu/descriptor.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include -#if __has_include() -#include -#else -#include -#endif - -#include "mkl_version.h" - -#include "oneapi/mkl/types.hpp" - -#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" -#include "oneapi/mkl/dft/descriptor.hpp" -#include "oneapi/mkl/exceptions.hpp" - -#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" - -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklgpu { - -class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { -public: - descriptor_derived_impl(std::size_t length) - : oneapi::mkl::dft::detail::descriptor_impl(length) { - std::cout << "special entry points" << std::endl; - } - - descriptor_derived_impl(const descriptor_derived_impl* other) - : oneapi::mkl::dft::detail::descriptor_impl(*other) { - std::cout << "special entry points copy const" << std::endl; - } - - virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { - return new descriptor_derived_impl(this); - } - - virtual ~descriptor_derived_impl() override { - std::cout << "descriptor_derived_impl descriptor" << std::endl; - } -private: -}; - -oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - - - -} // namespace mklgpu -} // namespace dft -} // namespace mkl -} // namespace oneapi From 0c629273ee61f7c1aca50b44ed68f15bf05b98ac Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Wed, 12 Oct 2022 01:17:27 -0700 Subject: [PATCH 06/21] initial implementation of set_value and corresponding desc_impl class changes --- .../complex_fwd_usm_mklcpu.cpp | 16 ++- include/oneapi/mkl/dft/descriptor.hpp | 9 +- .../oneapi/mkl/dft/detail/descriptor_impl.hpp | 52 +++++++- include/oneapi/mkl/dft/detail/dft_loader.hpp | 3 + include/oneapi/mkl/dft/types.hpp | 112 ++++++++++++++++++ include/oneapi/mkl/types.hpp | 64 ---------- src/dft/backends/descriptor.cxx | 78 ++++++++---- src/dft/backends/mklcpu/backward.cpp | 1 + src/dft/backends/mklcpu/commit.cpp | 1 + src/dft/backends/mklcpu/forward.cpp | 1 + src/dft/function_table.hpp | 1 + 11 files changed, 243 insertions(+), 95 deletions(-) create mode 100644 include/oneapi/mkl/dft/types.hpp diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 4eea5e051..170437253 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -48,7 +48,9 @@ void run_getrs_example(const sycl::device& cpu_device) { // Matrix sizes and leading dimensions - constexpr std::size_t n = 10; + constexpr std::size_t N = 10; + std::int64_t rs[3] {0, N, 1}; + // Catch asynchronous exceptions for cpu and cpu auto cpu_error_handler = [&](sycl::exception_list exceptions) { @@ -75,11 +77,17 @@ void run_getrs_example(const sycl::device& cpu_device) { sycl::context cpu_context = cpu_queue.get_context(); sycl::event cpu_getrf_done; - double *x_usm = (double*) malloc_shared(n*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); + double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); // enabling - oneapi::mkl::dft::descriptor desc(10); - // desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); + oneapi::mkl::dft::descriptor desc(N); + oneapi::mkl::dft::descriptor desc_vector({N,N}); + desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); + desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, 4); + desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs); + desc.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, N); + desc.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, N); + desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); // [compile time] desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); // [run time] desc.commit(cpu_queue); // oneapi::mkl::dft::compute_forward(desc, x_usm); diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 5d9e31e02..802345d44 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -27,6 +27,7 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" #include "oneapi/mkl/dft/detail/descriptor_impl.hpp" @@ -47,11 +48,15 @@ class descriptor { : pimpl_(detail::create_descriptor(length)) {} // Syntax for d-dimensional DFT - descriptor(std::vector dimensions); + descriptor(std::vector dimensions) + : pimpl_(detail::create_descriptor(dimensions)) {} // ~descriptor(); - void set_value(config_param param, ...); + template + void set_value(config_param param, Types... args) { + pimpl_->set_value(param, args...); + } void get_value(config_param param, ...); diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index 322f15e08..b7717db24 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -8,9 +8,11 @@ #include #endif +#include "oneapi/mkl/types.hpp" + #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/detail/get_device_id.hpp" -#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" namespace oneapi { namespace mkl { @@ -21,8 +23,52 @@ class descriptor_impl { public: descriptor_impl(std::size_t length) : length_(length) {} + descriptor_impl(std::vector dimension) : dimension_(dimension) {} + descriptor_impl(const descriptor_impl& other) : length_(other.length_) {} + void set_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + switch (param) + { + case config_param::INPUT_STRIDES: + // values.input_strides = va_arg(vl, std::vector); + break; + case config_param::OUTPUT_STRIDES: + // values.output_strides = va_arg(vl, std::vector); + break; + case config_param::FORWARD_SCALE: + values.fwd_scale = va_arg(vl, double); + break; + case config_param::BACKWARD_SCALE: + values.bwd_scale = va_arg(vl, double); + break; + case config_param::NUMBER_OF_TRANSFORMS: + values.number_of_transform = va_arg(vl, int64_t); + break; + case config_param::FWD_DISTANCE: + values.fwd_dist = va_arg(vl, int64_t); + break; + case config_param::BWD_DISTANCE: + values.bwd_dist = va_arg(vl, int64_t); + break; + case config_param::PLACEMENT: + values.placement = va_arg(vl, config_value); + break; + case config_param::COMPLEX_STORAGE: + values.complex_storage = va_arg(vl, config_value); + break; + case config_param::CONJUGATE_EVEN_STORAGE: + values.conj_even_storage = va_arg(vl, config_value); + break; + + default: err = 1; + } + va_end(vl); + } + virtual descriptor_impl* copy_state() = 0; virtual ~descriptor_impl() {} @@ -34,8 +80,12 @@ class descriptor_impl { protected: sycl::queue queue_; std::size_t length_; + std::vector dimension_; + + // descriptor configuration values and structs oneapi::mkl::dft::precision prec_; oneapi::mkl::dft::domain dom_; + oneapi::mkl::dft::dft_values values; }; } // namespace detail diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index 32d73040d..94e2562bd 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -21,6 +21,9 @@ namespace detail { template oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); +template +oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::vector dimensions); + } // namespace detail } // namespace dft } // namespace mkl diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp new file mode 100644 index 000000000..1c62297ae --- /dev/null +++ b/include/oneapi/mkl/dft/types.hpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2020-2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_TYPES_HPP_ +#define _ONEMKL_DFT_TYPES_HPP_ + +#include "oneapi/mkl/bfloat16.hpp" +#if __has_include() +#include +#else +#include +#endif + +namespace oneapi { +namespace mkl { +namespace dft { + +enum class precision { SINGLE, DOUBLE }; +enum class domain { REAL, COMPLEX }; +enum class config_param { + FORWARD_DOMAIN, + DIMENSION, + LENGTHS, + PRECISION, + + FORWARD_SCALE, + BACKWARD_SCALE, + + NUMBER_OF_TRANSFORMS, + + COMPLEX_STORAGE, + // WHAT IS THE FUTURE OF THIS ?? + REAL_STORAGE, + CONJUGATE_EVEN_STORAGE, + + PLACEMENT, + + INPUT_STRIDES, + OUTPUT_STRIDES, + + FWD_DISTANCE, + BWD_DISTANCE, + + WORKSPACE, + ORDERING, + TRANSPOSE, + PACKED_FORMAT, + COMMIT_STATUS +}; +enum class config_value { + // for config_param::COMMIT_STATUS + COMMITTED, + UNCOMMITTED, + + // for config_param::COMPLEX_STORAGE, + // config_param::REAL_STORAGE and + // config_param::CONJUGATE_EVEN_STORAGE + COMPLEX_COMPLEX, + REAL_COMPLEX, + REAL_REAL, + + // for config_param::PLACEMENT + INPLACE, + NOT_INPLACE, + + // for config_param::ORDERING + ORDERED, + BACKWARD_SCRAMBLED, + + // Allow/avoid certain usages + ALLOW, + AVOID, + NONE, + + // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers + CCE_FORMAT + +}; + +struct dft_values { + std::vector input_strides; + std::vector output_strides; + double bwd_scale; + double fwd_scale; + std::int64_t number_of_transform; + std::int64_t fwd_dist; + std::int64_t bwd_dist; + config_value placement; + config_value complex_storage; + config_value conj_even_storage; +}; +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif //_ONEMKL_TYPES_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/types.hpp b/include/oneapi/mkl/types.hpp index faf39235c..87503658f 100644 --- a/include/oneapi/mkl/types.hpp +++ b/include/oneapi/mkl/types.hpp @@ -109,70 +109,6 @@ enum class order : char { E = 1, }; -//DFT flag types -namespace dft { -enum class precision { SINGLE, DOUBLE }; -enum class domain { REAL, COMPLEX }; -enum class config_param { - FORWARD_DOMAIN, - DIMENSION, - LENGTHS, - PRECISION, - - FORWARD_SCALE, - BACKWARD_SCALE, - - NUMBER_OF_TRANSFORMS, - - COMPLEX_STORAGE, - REAL_STORAGE, - CONJUGATE_EVEN_STORAGE, - - PLACEMENT, - - INPUT_STRIDES, - OUTPUT_STRIDES, - - FWD_DISTANCE, - BWD_DISTANCE, - - WORKSPACE, - ORDERING, - TRANSPOSE, - PACKED_FORMAT, - COMMIT_STATUS -}; -enum class config_value { - // for config_param::COMMIT_STATUS - COMMITTED, - UNCOMMITTED, - - // for config_param::COMPLEX_STORAGE, - // config_param::REAL_STORAGE and - // config_param::CONJUGATE_EVEN_STORAGE - COMPLEX_COMPLEX, - REAL_COMPLEX, - REAL_REAL, - - // for config_param::PLACEMENT - INPLACE, - NOT_INPLACE, - - // for config_param::ORDERING - ORDERED, - BACKWARD_SCRAMBLED, - - // Allow/avoid certain usages - ALLOW, - AVOID, - NONE, - - // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers - CCE_FORMAT - -}; -} // namespace dft - } //namespace mkl } //namespace oneapi diff --git a/src/dft/backends/descriptor.cxx b/src/dft/backends/descriptor.cxx index 002c229e6..54e5a2320 100644 --- a/src/dft/backends/descriptor.cxx +++ b/src/dft/backends/descriptor.cxx @@ -8,6 +8,7 @@ #include "mkl_version.h" #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/detail/descriptor_impl.hpp" #include "oneapi/mkl/dft/descriptor.hpp" @@ -21,67 +22,96 @@ namespace mkl { namespace dft { namespace detail { +template class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { public: - descriptor_derived_impl(std::size_t length) - : oneapi::mkl::dft::detail::descriptor_impl(length) { - std::cout << "special entry points" << std::endl; + descriptor_derived_impl(std::size_t length) : oneapi::mkl::dft::detail::descriptor_impl(length) { + prec_ = prec; + dom_ = dom; } - descriptor_derived_impl(const descriptor_derived_impl* other) - : oneapi::mkl::dft::detail::descriptor_impl(*other) { + descriptor_derived_impl(std::vector dimensions) + : oneapi::mkl::dft::detail::descriptor_impl(dimensions) { + prec_ = prec; + dom_ = dom; + } + + descriptor_derived_impl(const descriptor_derived_impl* other) : oneapi::mkl::dft::detail::descriptor_impl(*other) { std::cout << "special entry points copy const" << std::endl; } + template + void set_value(config_param param, Types... args) { + printf("test... derived\n"); + } + virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { return new descriptor_derived_impl(this); } virtual ~descriptor_derived_impl() override { std::cout << "descriptor_derived_impl descriptor" << std::endl; + std::cout << values.bwd_scale << std::endl; } - void set_precision(oneapi::mkl::dft::precision prec) {prec_ = prec;} - void set_domain(oneapi::mkl::dft::domain dom) {dom_ = dom;} - private: DFTI_DESCRIPTOR_HANDLE hand; }; +// base constructor specialized template <> oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - auto desc_pimpl = new descriptor_derived_impl(length); - desc_pimpl->set_precision(oneapi::mkl::dft::precision::DOUBLE); - desc_pimpl->set_domain(oneapi::mkl::dft::domain::COMPLEX); - return desc_pimpl; + return new descriptor_derived_impl(length); } template <> oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - auto desc_pimpl = new descriptor_derived_impl(length); - desc_pimpl->set_precision(oneapi::mkl::dft::precision::DOUBLE); - desc_pimpl->set_domain(oneapi::mkl::dft::domain::REAL); - return desc_pimpl; + return new descriptor_derived_impl(length); } template <> oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - auto desc_pimpl = new descriptor_derived_impl(length); - desc_pimpl->set_precision(oneapi::mkl::dft::precision::SINGLE); - desc_pimpl->set_domain(oneapi::mkl::dft::domain::COMPLEX); - return desc_pimpl; + return new descriptor_derived_impl(length); } template <> oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length) { - auto desc_pimpl = new descriptor_derived_impl(length); - desc_pimpl->set_precision(oneapi::mkl::dft::precision::SINGLE); - desc_pimpl->set_domain(oneapi::mkl::dft::domain::REAL); - return desc_pimpl; + return new descriptor_derived_impl(length); +} + +// vectorized constructor specialized +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor( + std::vector dimensions) { + return new descriptor_derived_impl( + dimensions); +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor( + std::vector dimensions) { + return new descriptor_derived_impl(dimensions); +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor( + std::vector dimensions) { + return new descriptor_derived_impl( + dimensions); +} + +template <> +oneapi::mkl::dft::detail::descriptor_impl* +create_descriptor( + std::vector dimensions) { + return new descriptor_derived_impl(dimensions); } } // namespace detail diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp index 369e64a78..a7f86bc70 100644 --- a/src/dft/backends/mklcpu/backward.cpp +++ b/src/dft/backends/mklcpu/backward.cpp @@ -24,6 +24,7 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index d28a193d7..3bc0b330b 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -24,6 +24,7 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp index 765adbebf..c34683672 100644 --- a/src/dft/backends/mklcpu/forward.cpp +++ b/src/dft/backends/mklcpu/forward.cpp @@ -24,6 +24,7 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index a25bebe8b..de03ad365 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -30,6 +30,7 @@ #endif #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/descriptor.hpp" typedef struct { From d5f824b13f0e79fa48931a54edd74304b90872d0 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Wed, 12 Oct 2022 11:43:06 -0700 Subject: [PATCH 07/21] before refactoring the desc_impl class --- include/oneapi/mkl/dft/descriptor.hpp | 8 ++++++++ include/oneapi/mkl/dft/types.hpp | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 802345d44..7441cd567 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -62,6 +62,14 @@ class descriptor { void commit(sycl::queue& queue); +#ifdef ENABLE_MKLCPU_BACKEND + void commit(backend_selector selector); +#endif + +#ifdef ENABLE_MKLGPU_BACKEND + void commit(backend_selector selector); +#endif + sycl::queue& get_queue() { return queue_; }; diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index 1c62297ae..796da59ad 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -104,6 +104,10 @@ struct dft_values { config_value placement; config_value complex_storage; config_value conj_even_storage; + + std::int64_t dimension; + config_value domain; + config_value precision; }; } // namespace dft } // namespace mkl From 710a4e157c8d3fbbfe2144708e4886066084c4c0 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Thu, 13 Oct 2022 02:40:01 -0700 Subject: [PATCH 08/21] refactor and move desc class outside of pimpl & constructor generalization --- include/oneapi/mkl/dft/descriptor.hpp | 29 ++-- .../oneapi/mkl/dft/detail/descriptor_impl.hpp | 70 +------- include/oneapi/mkl/dft/detail/dft_loader.hpp | 32 ---- src/dft/backends/descriptor.cxx | 154 ++++++++---------- 4 files changed, 89 insertions(+), 196 deletions(-) delete mode 100644 include/oneapi/mkl/dft/detail/dft_loader.hpp diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 7441cd567..78a6a25c7 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -31,7 +31,6 @@ #include "oneapi/mkl/detail/backend_selector.hpp" #include "oneapi/mkl/dft/detail/descriptor_impl.hpp" -#include "oneapi/mkl/dft/detail/dft_loader.hpp" namespace oneapi { namespace mkl { @@ -39,24 +38,16 @@ namespace dft { template class descriptor { -private: - sycl::queue queue_; - std::unique_ptr pimpl_; public: // Syntax for 1-dimensional DFT - descriptor(std::int64_t length) - : pimpl_(detail::create_descriptor(length)) {} + descriptor(std::int64_t length); // Syntax for d-dimensional DFT - descriptor(std::vector dimensions) - : pimpl_(detail::create_descriptor(dimensions)) {} + descriptor(std::vector dimensions); - // ~descriptor(); + ~descriptor() {} - template - void set_value(config_param param, Types... args) { - pimpl_->set_value(param, args...); - } + void set_value(config_param param, ...); void get_value(config_param param, ...); @@ -72,7 +63,17 @@ class descriptor { sycl::queue& get_queue() { return queue_; - }; + } +private: + sycl::queue queue_; + std::unique_ptr pimpl_; + + std::int64_t rank_; + std::vector dimension_; + + // descriptor configuration values and structs + void* handle_; + oneapi::mkl::dft::dft_values values; }; } //namespace dft diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index b7717db24..c6f2b5824 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -21,73 +21,19 @@ namespace detail { class descriptor_impl { public: - descriptor_impl(std::size_t length) : length_(length) {} - - descriptor_impl(std::vector dimension) : dimension_(dimension) {} - - descriptor_impl(const descriptor_impl& other) : length_(other.length_) {} - - void set_value(config_param param, ...) { - int err = 0; - va_list vl; - va_start(vl, param); - switch (param) - { - case config_param::INPUT_STRIDES: - // values.input_strides = va_arg(vl, std::vector); - break; - case config_param::OUTPUT_STRIDES: - // values.output_strides = va_arg(vl, std::vector); - break; - case config_param::FORWARD_SCALE: - values.fwd_scale = va_arg(vl, double); - break; - case config_param::BACKWARD_SCALE: - values.bwd_scale = va_arg(vl, double); - break; - case config_param::NUMBER_OF_TRANSFORMS: - values.number_of_transform = va_arg(vl, int64_t); - break; - case config_param::FWD_DISTANCE: - values.fwd_dist = va_arg(vl, int64_t); - break; - case config_param::BWD_DISTANCE: - values.bwd_dist = va_arg(vl, int64_t); - break; - case config_param::PLACEMENT: - values.placement = va_arg(vl, config_value); - break; - case config_param::COMPLEX_STORAGE: - values.complex_storage = va_arg(vl, config_value); - break; - case config_param::CONJUGATE_EVEN_STORAGE: - values.conj_even_storage = va_arg(vl, config_value); - break; - - default: err = 1; - } - va_end(vl); - } - - virtual descriptor_impl* copy_state() = 0; - - virtual ~descriptor_impl() {} - - sycl::queue& get_queue() { - return queue_; - } + descriptor_impl(); + ~descriptor_impl() {} protected: sycl::queue queue_; - std::size_t length_; - std::vector dimension_; - - // descriptor configuration values and structs - oneapi::mkl::dft::precision prec_; - oneapi::mkl::dft::domain dom_; - oneapi::mkl::dft::dft_values values; + void* handle_; }; +template +oneapi::mkl::dft::detail::descriptor_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue) { + return new descriptor_impl(); +} + } // namespace detail } // namespace dft } // namespace mkl diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp deleted file mode 100644 index 94e2562bd..000000000 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _ONEMKL_DFT_LOADER_HPP_ -#define _ONEMKL_DFT_LOADER_HPP_ - -#include -#if __has_include() -#include -#else -#include -#endif - -#include "oneapi/mkl/detail/export.hpp" -#include "oneapi/mkl/detail/get_device_id.hpp" - -#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" - -namespace oneapi { -namespace mkl { -namespace dft { -namespace detail { - -template -oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::size_t length); - -template -oneapi::mkl::dft::detail::descriptor_impl* create_descriptor(std::vector dimensions); - -} // namespace detail -} // namespace dft -} // namespace mkl -} // namespace oneapi - -#endif //_ONEMKL_DFT_LOADER_HPP_ diff --git a/src/dft/backends/descriptor.cxx b/src/dft/backends/descriptor.cxx index 54e5a2320..5cac8f7d4 100644 --- a/src/dft/backends/descriptor.cxx +++ b/src/dft/backends/descriptor.cxx @@ -5,116 +5,94 @@ #include #endif -#include "mkl_version.h" - #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" -#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + #include "mkl_dfti.h" namespace oneapi { namespace mkl { namespace dft { -namespace detail { template -class descriptor_derived_impl : public oneapi::mkl::dft::detail::descriptor_impl { -public: - descriptor_derived_impl(std::size_t length) : oneapi::mkl::dft::detail::descriptor_impl(length) { - prec_ = prec; - dom_ = dom; - } - - descriptor_derived_impl(std::vector dimensions) - : oneapi::mkl::dft::detail::descriptor_impl(dimensions) { - prec_ = prec; - dom_ = dom; - } - - descriptor_derived_impl(const descriptor_derived_impl* other) : oneapi::mkl::dft::detail::descriptor_impl(*other) { - std::cout << "special entry points copy const" << std::endl; - } - - template - void set_value(config_param param, Types... args) { - printf("test... derived\n"); - } - - virtual oneapi::mkl::dft::detail::descriptor_impl* copy_state() override { - return new descriptor_derived_impl(this); +descriptor::descriptor(std::vector dimension) : + dimension_(dimension), + handle_(nullptr), + rank_(dimension.size()) + { + // TODO: initialize the device_handle, handle_buffer + auto handle = reinterpret_cast(handle_); } - virtual ~descriptor_derived_impl() override { - std::cout << "descriptor_derived_impl descriptor" << std::endl; - std::cout << values.bwd_scale << std::endl; - } - -private: - DFTI_DESCRIPTOR_HANDLE hand; -}; - -// base constructor specialized -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor(std::size_t length) { - return new descriptor_derived_impl(length); -} - -// vectorized constructor specialized -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor( - std::vector dimensions) { - return new descriptor_derived_impl( - dimensions); -} +template +descriptor::descriptor(std::int64_t length) : + descriptor(std::vector{length}) {} -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor( - std::vector dimensions) { - return new descriptor_derived_impl(dimensions); +template +void descriptor::set_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + switch (param) + { + case config_param::INPUT_STRIDES: + // values.input_strides = va_arg(vl, std::vector); + break; + case config_param::OUTPUT_STRIDES: + // values.output_strides = va_arg(vl, std::vector); + break; + case config_param::FORWARD_SCALE: + values.fwd_scale = va_arg(vl, double); + break; + case config_param::BACKWARD_SCALE: + values.bwd_scale = va_arg(vl, double); + break; + case config_param::NUMBER_OF_TRANSFORMS: + values.number_of_transform = va_arg(vl, int64_t); + break; + case config_param::FWD_DISTANCE: + values.fwd_dist = va_arg(vl, int64_t); + break; + case config_param::BWD_DISTANCE: + values.bwd_dist = va_arg(vl, int64_t); + break; + case config_param::PLACEMENT: + values.placement = va_arg(vl, config_value); + break; + case config_param::COMPLEX_STORAGE: + values.complex_storage = va_arg(vl, config_value); + break; + case config_param::CONJUGATE_EVEN_STORAGE: + values.conj_even_storage = va_arg(vl, config_value); + break; + + default: err = 1; + } + va_end(vl); } -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor( - std::vector dimensions) { - return new descriptor_derived_impl( - dimensions); +template +void descriptor::get_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + switch (param) + { + default: break; + } + va_end(vl); } -template <> -oneapi::mkl::dft::detail::descriptor_impl* -create_descriptor( - std::vector dimensions) { - return new descriptor_derived_impl(dimensions); -} +template class descriptor; +template class descriptor; +template class descriptor; +template class descriptor; -} // namespace detail } // namespace dft } // namespace mkl } // namespace oneapi From 3f9d0b8c3378d5cb3929349cedfce6c801d15f0a Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Thu, 13 Oct 2022 03:20:19 -0700 Subject: [PATCH 09/21] enable setting strides --- .../complex_fwd_usm_mklcpu.cpp | 2 +- include/oneapi/mkl/dft/descriptor.hpp | 2 +- src/dft/backends/descriptor.cxx | 75 +++++++++++-------- 3 files changed, 44 insertions(+), 35 deletions(-) diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 170437253..6101cb8f6 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -84,7 +84,7 @@ void run_getrs_example(const sycl::device& cpu_device) { oneapi::mkl::dft::descriptor desc_vector({N,N}); desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, 4); - desc.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs); + desc_vector.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs); desc.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, N); desc.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, N); desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 78a6a25c7..057b14fbb 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -45,7 +45,7 @@ class descriptor { // Syntax for d-dimensional DFT descriptor(std::vector dimensions); - ~descriptor() {} + ~descriptor(); void set_value(config_param param, ...); diff --git a/src/dft/backends/descriptor.cxx b/src/dft/backends/descriptor.cxx index 5cac8f7d4..94fc29fae 100644 --- a/src/dft/backends/descriptor.cxx +++ b/src/dft/backends/descriptor.cxx @@ -33,45 +33,54 @@ template descriptor::descriptor(std::int64_t length) : descriptor(std::vector{length}) {} +template +descriptor::~descriptor() { + // call DftiFreeDescriptor +} + +// impliment error class template void descriptor::set_value(config_param param, ...) { int err = 0; va_list vl; va_start(vl, param); - switch (param) - { - case config_param::INPUT_STRIDES: - // values.input_strides = va_arg(vl, std::vector); - break; - case config_param::OUTPUT_STRIDES: - // values.output_strides = va_arg(vl, std::vector); - break; - case config_param::FORWARD_SCALE: - values.fwd_scale = va_arg(vl, double); - break; - case config_param::BACKWARD_SCALE: - values.bwd_scale = va_arg(vl, double); - break; - case config_param::NUMBER_OF_TRANSFORMS: - values.number_of_transform = va_arg(vl, int64_t); - break; - case config_param::FWD_DISTANCE: - values.fwd_dist = va_arg(vl, int64_t); - break; - case config_param::BWD_DISTANCE: - values.bwd_dist = va_arg(vl, int64_t); - break; - case config_param::PLACEMENT: - values.placement = va_arg(vl, config_value); - break; - case config_param::COMPLEX_STORAGE: - values.complex_storage = va_arg(vl, config_value); - break; - case config_param::CONJUGATE_EVEN_STORAGE: - values.conj_even_storage = va_arg(vl, config_value); - break; + switch (param) { + case config_param::INPUT_STRIDES: + case config_param::OUTPUT_STRIDES: { + int64_t *strides = va_arg(vl, int64_t *); + if (strides == nullptr) break; + + if (param == config_param::INPUT_STRIDES) + std::copy(strides, strides+rank_+1, std::back_inserter(values.input_strides)); + if (param == config_param::OUTPUT_STRIDES) + std::copy(strides, strides+rank_+1, std::back_inserter(values.output_strides)); + } break; + case config_param::FORWARD_SCALE: + values.fwd_scale = va_arg(vl, double); + break; + case config_param::BACKWARD_SCALE: + values.bwd_scale = va_arg(vl, double); + break; + case config_param::NUMBER_OF_TRANSFORMS: + values.number_of_transform = va_arg(vl, int64_t); + break; + case config_param::FWD_DISTANCE: + values.fwd_dist = va_arg(vl, int64_t); + break; + case config_param::BWD_DISTANCE: + values.bwd_dist = va_arg(vl, int64_t); + break; + case config_param::PLACEMENT: + values.placement = va_arg(vl, config_value); + break; + case config_param::COMPLEX_STORAGE: + values.complex_storage = va_arg(vl, config_value); + break; + case config_param::CONJUGATE_EVEN_STORAGE: + values.conj_even_storage = va_arg(vl, config_value); + break; - default: err = 1; + default: err = 1; } va_end(vl); } From ba84e0edc13cf3335ad864f51feb138b14013e0b Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 31 Oct 2022 08:23:56 -0700 Subject: [PATCH 10/21] cpu commit+set_value --- .../complex_fwd_usm_mklcpu.cpp | 31 +- .../run_time_dispatching/complex_fwd_usm.cpp | 31 +- .../uniform_usm_mklcpu_curand.cpp | 1 + include/oneapi/mkl/dft.hpp | 12 + include/oneapi/mkl/dft/descriptor.hpp | 115 +++++- .../{descriptor_impl.hpp => commit_impl.hpp} | 32 +- include/oneapi/mkl/dft/detail/dft_loader.hpp | 48 +++ .../dft/detail/mklcpu/onemkl_dft_mklcpu.hpp | 106 +----- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 103 +----- include/oneapi/mkl/dft/types.hpp | 68 +++- src/dft/backends/CMakeLists.txt | 6 +- src/dft/backends/descriptor.cxx | 107 ------ src/dft/backends/mklcpu/CMakeLists.txt | 5 +- src/dft/backends/mklcpu/backward.cpp | 322 ++++++++--------- src/dft/backends/mklcpu/commit.cpp | 97 +++++- src/dft/backends/mklcpu/forward.cpp | 328 +++++++++--------- .../backends/mklcpu/mkl_dft_cpu_wrappers.cpp | 25 +- src/dft/backends/mklgpu/CMakeLists.txt | 1 - src/dft/dft_loader.cpp | 184 +--------- src/dft/function_table.hpp | 70 +--- 20 files changed, 724 insertions(+), 968 deletions(-) rename include/oneapi/mkl/dft/detail/{descriptor_impl.hpp => commit_impl.hpp} (50%) create mode 100644 include/oneapi/mkl/dft/detail/dft_loader.hpp delete mode 100644 src/dft/backends/descriptor.cxx diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 6101cb8f6..014617214 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -77,20 +77,23 @@ void run_getrs_example(const sycl::device& cpu_device) { sycl::context cpu_context = cpu_queue.get_context(); sycl::event cpu_getrf_done; - double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); - - // enabling - oneapi::mkl::dft::descriptor desc(N); - oneapi::mkl::dft::descriptor desc_vector({N,N}); - desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); - desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, 4); - desc_vector.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, rs); - desc.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, N); - desc.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, N); - desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); - // [compile time] desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); - // [run time] desc.commit(cpu_queue); - // oneapi::mkl::dft::compute_forward(desc, x_usm); +double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); + +// enabling +// 1. create descriptors +oneapi::mkl::dft::descriptor desc(N); + +// 2. variadic set_value +desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + +// 3. commit_descriptor (compile_time CPU) +desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); + +// 4. commit_descriptor (run_time xPU) unusable from libonemkl_dft_mklcpu.so +desc.commit(cpu_queue); + +// 5. compute_forward / compute_backward (CPU) +// oneapi::mkl::dft::compute_forward(desc, x_usm); } // diff --git a/examples/dft/run_time_dispatching/complex_fwd_usm.cpp b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp index 1b906afb6..5b44d3442 100644 --- a/examples/dft/run_time_dispatching/complex_fwd_usm.cpp +++ b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp @@ -44,12 +44,31 @@ void run_uniform_example(const sycl::device& dev) { sycl::queue queue(dev, exception_handler); - double *x_usm = (double*) malloc_shared(N*2*sizeof(double), queue.get_device(), queue.get_context()); - - oneapi::mkl::dft::descriptor< - oneapi::mkl::dft::precision::DOUBLE, - oneapi::mkl::dft::domain::COMPLEX - > desc(N); + std::cout << "DFTI example run_time dispatch" << std::endl; + // + // Preparation on cpu + // + sycl::queue cpu_queue(dev, exception_handler); + sycl::context cpu_context = cpu_queue.get_context(); + sycl::event cpu_getrf_done; + + double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); + + // enabling + // 1. create descriptors + oneapi::mkl::dft::descriptor desc_vector({N,N}); + + // 2. variadic set_value + desc_vector.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (double)(1.0/N)); + desc_vector.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, 4); + desc_vector.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, N); + desc_vector.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + + // 4. commit_descriptor (run_time xPU) + desc_vector.commit(cpu_queue); + + // 5. compute_forward / compute_backward (CPU) + // oneapi::mkl::dft::compute_forward(desc, x_usm); } // diff --git a/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp b/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp index cdfd6c765..7eb3ce83a 100644 --- a/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp +++ b/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp @@ -98,6 +98,7 @@ void run_uniform_example(const sycl::device& cpu_dev, const sycl::device& gpu_de // preparation on CPU device and GPU device sycl::queue cpu_queue(cpu_dev, cpu_exception_handler); sycl::queue gpu_queue(gpu_dev, gpu_exception_handler); + oneapi::mkl::rng::default_engine test_engine(cpu_queue, seed); oneapi::mkl::rng::default_engine cpu_engine( oneapi::mkl::backend_selector{ cpu_queue }, seed); oneapi::mkl::rng::default_engine gpu_engine( diff --git a/include/oneapi/mkl/dft.hpp b/include/oneapi/mkl/dft.hpp index 9fd7b7ef6..179e3e056 100644 --- a/include/oneapi/mkl/dft.hpp +++ b/include/oneapi/mkl/dft.hpp @@ -20,6 +20,18 @@ #ifndef _ONEMKL_DFT_HPP_ #define _ONEMKL_DFT_HPP_ +#if __has_include() +#include +#else +#include +#endif +#include +#include + +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" +#include "oneapi/mkl/dft/detail/dft_loader.hpp" + #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/forward.hpp" #include "oneapi/mkl/dft/backward.hpp" diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 057b14fbb..59c3226cb 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -30,8 +30,15 @@ #include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" -#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/dft_loader.hpp" +#ifdef ENABLE_MKLCPU_BACKEND +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#endif +#ifdef ENABLE_MKLGPU_BACKEND +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" +#endif namespace oneapi { namespace mkl { namespace dft { @@ -51,14 +58,20 @@ class descriptor { void get_value(config_param param, ...); - void commit(sycl::queue& queue); + void commit(sycl::queue& queue) { + pimpl_.reset(detail::create_commit(get_device_id(queue), queue, values)); + } #ifdef ENABLE_MKLCPU_BACKEND - void commit(backend_selector selector); + void commit(backend_selector selector) { + pimpl_.reset(mklcpu::create_commit(selector.get_queue(), values)); + } #endif #ifdef ENABLE_MKLGPU_BACKEND - void commit(backend_selector selector); + void commit(backend_selector selector) { + // pimpl_.reset(mklgpu::create_commit(selector.get_queue())); + } #endif sycl::queue& get_queue() { @@ -66,7 +79,7 @@ class descriptor { } private: sycl::queue queue_; - std::unique_ptr pimpl_; + std::unique_ptr pimpl_; // commit only std::int64_t rank_; std::vector dimension_; @@ -76,6 +89,98 @@ class descriptor { oneapi::mkl::dft::dft_values values; }; +template +descriptor::descriptor(std::vector dimension) : + dimension_(dimension), + handle_(nullptr), + rank_(dimension.size()) + { + // TODO: initialize the device_handle, handle_buffer + values.domain = dom; + values.precision = prec; + values.dimension = dimension_; + values.rank = rank_; + } + +template +descriptor::descriptor(std::int64_t length) : + descriptor(std::vector{length}) {} + +template +descriptor::~descriptor() { + // call DftiFreeDescriptor +} + +// impliment error class +template +void descriptor::set_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + printf("oneapi interface set_value\n"); + switch (param) { + case config_param::INPUT_STRIDES: + values.set_input_strides = true; + case config_param::OUTPUT_STRIDES: { + int64_t *strides = va_arg(vl, int64_t *); + if (strides == nullptr) break; + + if (param == config_param::INPUT_STRIDES) + std::copy(strides, strides+rank_+1, std::back_inserter(values.input_strides)); + if (param == config_param::OUTPUT_STRIDES) + std::copy(strides, strides+rank_+1, std::back_inserter(values.output_strides)); + values.set_output_strides = true; + } break; + case config_param::FORWARD_SCALE: + values.fwd_scale = va_arg(vl, double); + values.set_fwd_scale = true; + break; + case config_param::BACKWARD_SCALE: + values.bwd_scale = va_arg(vl, double); + values.set_bwd_scale = true; + break; + case config_param::NUMBER_OF_TRANSFORMS: + values.number_of_transforms = va_arg(vl, int64_t); + values.set_number_of_transforms = true; + break; + case config_param::FWD_DISTANCE: + values.fwd_dist = va_arg(vl, int64_t); + values.set_fwd_dist = true; + break; + case config_param::BWD_DISTANCE: + values.bwd_dist = va_arg(vl, int64_t); + values.set_bwd_dist = true; + break; + case config_param::PLACEMENT: + values.placement = va_arg(vl, config_value); + values.set_placement = true; + break; + case config_param::COMPLEX_STORAGE: + values.complex_storage = va_arg(vl, config_value); + values.set_complex_storage = true; + break; + case config_param::CONJUGATE_EVEN_STORAGE: + values.conj_even_storage = va_arg(vl, config_value); + values.set_conj_even_storage = true; + break; + + default: err = 1; + } + va_end(vl); +} + +template +void descriptor::get_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + switch (param) + { + default: break; + } + va_end(vl); +} + } //namespace dft } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp similarity index 50% rename from include/oneapi/mkl/dft/detail/descriptor_impl.hpp rename to include/oneapi/mkl/dft/detail/commit_impl.hpp index c6f2b5824..2cb556467 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -1,5 +1,5 @@ -#ifndef _ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ -#define _ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ +#ifndef _ONEMKL_DFT_COMMIT_IMPL_HPP_ +#define _ONEMKL_DFT_COMMIT_IMPL_HPP_ #include #if __has_include() @@ -8,36 +8,42 @@ #include #endif -#include "oneapi/mkl/types.hpp" - #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/detail/get_device_id.hpp" #include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/types.hpp" + namespace oneapi { namespace mkl { namespace dft { namespace detail { -class descriptor_impl { +class commit_impl { public: - descriptor_impl(); - ~descriptor_impl() {} + commit_impl(sycl::queue queue) : queue_(queue), handle(nullptr) {} + + commit_impl(const commit_impl& other) : queue_(other.queue_), handle(other.handle) {} + + virtual commit_impl* copy_state() = 0; + + virtual ~commit_impl() {} + + sycl::queue& get_queue() { + return queue_; + } protected: + bool status; sycl::queue queue_; - void* handle_; + void* handle; }; -template -oneapi::mkl::dft::detail::descriptor_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue) { - return new descriptor_impl(); -} } // namespace detail } // namespace dft } // namespace mkl } // namespace oneapi -#endif //_ONEMKL_DFT_DESCRIPTOR_IMPL_HPP_ +#endif //_ONEMKL_DFT_COMMIT_IMPL_HPP_ diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp new file mode 100644 index 000000000..7c939e36e --- /dev/null +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_LOADER_HPP_ +#define _ONEMKL_DFT_LOADER_HPP_ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" + +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace detail { + +ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue, dft_values values); + +} // namespace detail +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif //_ONEMKL_DFT_LOADER_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index edf8d706a..5063a2ba8 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -28,116 +28,16 @@ #include #include +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/types.hpp" -#include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ - void commit_##EXT(descriptor &desc, sycl::queue &queue); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - void compute_forward_buffer_inplace_##EXT(descriptor &desc, \ - sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_forward_buffer_inplace_split_##EXT(descriptor &desc, \ - sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ - void compute_forward_buffer_outofplace_##EXT(descriptor &desc, \ - sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_forward_buffer_outofplace_split_##EXT( \ - descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - sycl::event compute_forward_usm_inplace_##EXT( \ - descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_forward_usm_inplace_split_##EXT( \ - descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ - sycl::event compute_forward_usm_outofplace_##EXT( \ - descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_forward_usm_outofplace_split_##EXT( \ - descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ - T_REAL *out_im, const std::vector &dependencies = {}); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - void compute_backward_buffer_inplace_##EXT(descriptor &desc, \ - sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_backward_buffer_inplace_split_##EXT(descriptor &desc, \ - sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ - void compute_backward_buffer_outofplace_##EXT(descriptor &desc, \ - sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_backward_buffer_outofplace_split_##EXT( \ - descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - sycl::event compute_backward_usm_inplace_##EXT( \ - descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_backward_usm_inplace_split_##EXT( \ - descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ - sycl::event compute_backward_usm_outofplace_##EXT( \ - descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_backward_usm_outofplace_split_##EXT( \ - descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ - T_REAL *out_im, const std::vector &dependencies = {}); - -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, - std::complex, std::complex) - -#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values); } // namespace mklcpu } // namespace dft diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index e82de9656..81a45257d 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -36,108 +36,7 @@ namespace mkl { namespace dft { namespace mklgpu { -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ - void commit_##EXT(descriptor &desc, sycl::queue &queue); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - void compute_forward_buffer_inplace_##EXT(descriptor &desc, \ - sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_forward_buffer_inplace_split_##EXT(descriptor &desc, \ - sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ - void compute_forward_buffer_outofplace_##EXT(descriptor &desc, \ - sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_forward_buffer_outofplace_split_##EXT( \ - descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - sycl::event compute_forward_usm_inplace_##EXT( \ - descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_forward_usm_inplace_split_##EXT( \ - descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ - sycl::event compute_forward_usm_outofplace_##EXT( \ - descriptor &desc, T_FORWARD *in, T_BACKWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_forward_usm_outofplace_split_##EXT( \ - descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ - T_REAL *out_im, const std::vector &dependencies = {}); \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - void compute_backward_buffer_inplace_##EXT(descriptor &desc, \ - sycl::buffer &inout); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_backward_buffer_inplace_split_##EXT(descriptor &desc, \ - sycl::buffer &inout_re, \ - sycl::buffer &inout_im); \ - \ - /*Out-of-place transform*/ \ - void compute_backward_buffer_outofplace_##EXT(descriptor &desc, \ - sycl::buffer &in, \ - sycl::buffer &out); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - void compute_backward_buffer_outofplace_split_##EXT( \ - descriptor &desc, sycl::buffer &in_re, \ - sycl::buffer &in_im, sycl::buffer &out_re, \ - sycl::buffer &out_im); \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - sycl::event compute_backward_usm_inplace_##EXT( \ - descriptor &desc, T_BACKWARD *inout, \ - const std::vector &dependencies = {}); \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_backward_usm_inplace_split_##EXT( \ - descriptor &desc, T_REAL *inout_re, T_REAL *inout_im, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform*/ \ - sycl::event compute_backward_usm_outofplace_##EXT( \ - descriptor &desc, T_BACKWARD *in, T_FORWARD *out, \ - const std::vector &dependencies = {}); \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - sycl::event compute_backward_usm_outofplace_split_##EXT( \ - descriptor &desc, T_REAL *in_re, T_REAL *in_im, T_REAL *out_re, \ - T_REAL *out_im, const std::vector &dependencies = {}); - -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, - std::complex, std::complex) - -#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values); } // namespace mklgpu } // namespace dft diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index 796da59ad..857fff8e1 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -27,6 +27,25 @@ #include #endif +#ifdef NDEBUG +#define logf(...) +#else +#define logf(...) \ + printf("%s - (%s) : ", __FILE__, __FUNCTION__); \ + printf(__VA_ARGS__); \ + printf("\n"); +#endif + +template +std::ostream& operator<<(std::ostream& os, const std::vector& vector) { + if (vector.empty()) return os; + os.put('['); + for (auto element : vector) { + os << element << ", "; + } + return os << "\b\b]"; +} + namespace oneapi { namespace mkl { namespace dft { @@ -45,7 +64,6 @@ enum class config_param { NUMBER_OF_TRANSFORMS, COMPLEX_STORAGE, - // WHAT IS THE FUTURE OF THIS ?? REAL_STORAGE, CONJUGATE_EVEN_STORAGE, @@ -93,24 +111,60 @@ enum class config_value { }; +static std::unordered_map prec_map{ { precision::SINGLE, "SINGLE" }, + { precision::DOUBLE, "DOUBLE" } }; + +static std::unordered_map dom_map{ { domain::REAL, "REAL" }, + { domain::COMPLEX, "COMPLEX" } }; + struct dft_values { std::vector input_strides; std::vector output_strides; double bwd_scale; double fwd_scale; - std::int64_t number_of_transform; + std::int64_t number_of_transforms; std::int64_t fwd_dist; std::int64_t bwd_dist; config_value placement; config_value complex_storage; config_value conj_even_storage; - std::int64_t dimension; - config_value domain; - config_value precision; + bool set_input_strides = false; + bool set_output_strides = false; + bool set_bwd_scale = false; + bool set_fwd_scale = false; + bool set_number_of_transforms = false; + bool set_fwd_dist = false; + bool set_bwd_dist = false; + bool set_placement = false; + bool set_complex_storage = false; + bool set_conj_even_storage = false; + + std::vector dimension; + std::int64_t rank; + domain domain; + precision precision; + friend auto operator<<(std::ostream& os, dft_values const& val) -> std::ostream& { + os << "------------- oneAPI Descriptor ------------\n"; + os << "input_strides : " << val.input_strides << "\n"; + os << "output_strides : " << val.output_strides << "\n"; + os << "bwd_scale : " << val.bwd_scale << "\n"; + os << "fwd_scale : " << val.fwd_scale << "\n"; + os << "number_of_transforms : " << val.number_of_transforms << "\n"; + os << "fwd_dist : " << val.fwd_dist << "\n"; + os << "bwd_dist : " << val.bwd_dist << "\n"; + os << "placement : " << (int) val.placement << "\n"; + os << "complex_storage : " << (int) val.complex_storage << "\n"; + os << "conj_even_storage : " << (int) val.conj_even_storage << "\n"; + os << "dimension : " << val.dimension << "\n"; + os << "rank : " << val.rank << "\n"; + os << "domain : " << dom_map[val.domain] << "\n"; + os << "precision : " << prec_map[val.precision]; + return os; + } }; } // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace mkl +} // namespace oneapi #endif //_ONEMKL_TYPES_HPP_ \ No newline at end of file diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt index c75086840..5c65534b1 100644 --- a/src/dft/backends/CMakeLists.txt +++ b/src/dft/backends/CMakeLists.txt @@ -17,9 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -if(ENABLE_MKLGPU_BACKEND) - add_subdirectory(mklgpu) -endif() +# if(ENABLE_MKLGPU_BACKEND) +# add_subdirectory(mklgpu) +# endif() if(ENABLE_MKLCPU_BACKEND) add_subdirectory(mklcpu) diff --git a/src/dft/backends/descriptor.cxx b/src/dft/backends/descriptor.cxx deleted file mode 100644 index 94fc29fae..000000000 --- a/src/dft/backends/descriptor.cxx +++ /dev/null @@ -1,107 +0,0 @@ -#include -#if __has_include() -#include -#else -#include -#endif - -#include "oneapi/mkl/types.hpp" -#include "oneapi/mkl/dft/types.hpp" - -#include "oneapi/mkl/dft/descriptor.hpp" -#include "oneapi/mkl/exceptions.hpp" - -#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" - -#include "mkl_dfti.h" - -namespace oneapi { -namespace mkl { -namespace dft { - -template -descriptor::descriptor(std::vector dimension) : - dimension_(dimension), - handle_(nullptr), - rank_(dimension.size()) - { - // TODO: initialize the device_handle, handle_buffer - auto handle = reinterpret_cast(handle_); - } - -template -descriptor::descriptor(std::int64_t length) : - descriptor(std::vector{length}) {} - -template -descriptor::~descriptor() { - // call DftiFreeDescriptor -} - -// impliment error class -template -void descriptor::set_value(config_param param, ...) { - int err = 0; - va_list vl; - va_start(vl, param); - switch (param) { - case config_param::INPUT_STRIDES: - case config_param::OUTPUT_STRIDES: { - int64_t *strides = va_arg(vl, int64_t *); - if (strides == nullptr) break; - - if (param == config_param::INPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values.input_strides)); - if (param == config_param::OUTPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values.output_strides)); - } break; - case config_param::FORWARD_SCALE: - values.fwd_scale = va_arg(vl, double); - break; - case config_param::BACKWARD_SCALE: - values.bwd_scale = va_arg(vl, double); - break; - case config_param::NUMBER_OF_TRANSFORMS: - values.number_of_transform = va_arg(vl, int64_t); - break; - case config_param::FWD_DISTANCE: - values.fwd_dist = va_arg(vl, int64_t); - break; - case config_param::BWD_DISTANCE: - values.bwd_dist = va_arg(vl, int64_t); - break; - case config_param::PLACEMENT: - values.placement = va_arg(vl, config_value); - break; - case config_param::COMPLEX_STORAGE: - values.complex_storage = va_arg(vl, config_value); - break; - case config_param::CONJUGATE_EVEN_STORAGE: - values.conj_even_storage = va_arg(vl, config_value); - break; - - default: err = 1; - } - va_end(vl); -} - -template -void descriptor::get_value(config_param param, ...) { - int err = 0; - va_list vl; - va_start(vl, param); - switch (param) - { - default: break; - } - va_end(vl); -} - -template class descriptor; -template class descriptor; -template class descriptor; -template class descriptor; - -} // namespace dft -} // namespace mkl -} // namespace oneapi diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index 57ff6dd98..69978073d 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -20,6 +20,7 @@ set(LIB_NAME onemkl_dft_mklcpu) set(LIB_OBJ ${LIB_NAME}_obj) +set(USE_DPCPP_API ON) find_package(MKL REQUIRED) add_library(${LIB_NAME}) @@ -38,7 +39,9 @@ target_include_directories(${LIB_OBJ} ) target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) - +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) +endif() target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) set_target_properties(${LIB_OBJ} PROPERTIES diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp index a7f86bc70..4cafbe549 100644 --- a/src/dft/backends/mklcpu/backward.cpp +++ b/src/dft/backends/mklcpu/backward.cpp @@ -33,174 +33,174 @@ namespace mkl { namespace dft { namespace mklcpu { -void compute_backward_buffer_inplace_f(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_c(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_d(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_z(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_backward_buffer_inplace_f(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_c(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_d(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_z(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_backward_buffer_inplace_split_f(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_split_c(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_split_d(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_inplace_split_z(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_backward_buffer_inplace_split_f(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_split_c(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_split_d(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_inplace_split_z(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_backward_buffer_outofplace_f(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_c(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_d(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_z(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_backward_buffer_outofplace_f(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_c(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_d(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_z(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_backward_buffer_outofplace_split_f(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_split_c( - descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_split_d(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_backward_buffer_outofplace_split_z( - descriptor &desc, sycl::buffer &in_re, - sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_backward_buffer_outofplace_split_f(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_split_c( +// descriptor &desc, sycl::buffer &in_re, +// sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_split_d(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_backward_buffer_outofplace_split_z( +// descriptor &desc, sycl::buffer &in_re, +// sycl::buffer &in_im, sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_backward_usm_inplace_f(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_c(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_d(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_z(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_backward_usm_inplace_f(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_c(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_d(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_z(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, - float *inout_re, float *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_split_c( - descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, - double *inout_re, double *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_inplace_split_z( - descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, +// float *inout_re, float *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_split_c( +// descriptor &desc, float *inout_re, float *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, +// double *inout_re, double *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_inplace_split_z( +// descriptor &desc, double *inout_re, double *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_backward_usm_outofplace_f(descriptor &desc, - std::complex *in, float *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_c(descriptor &desc, - std::complex *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_d(descriptor &desc, - std::complex *in, double *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_z(descriptor &desc, - std::complex *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_backward_usm_outofplace_f(descriptor &desc, +// std::complex *in, float *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_c(descriptor &desc, +// std::complex *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_d(descriptor &desc, +// std::complex *in, double *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_z(descriptor &desc, +// std::complex *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_backward_usm_outofplace_split_f( - descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_split_c( - descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_split_d( - descriptor &desc, double *in_re, double *in_im, double *out_re, - double *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_backward_usm_outofplace_split_z( - descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_backward_usm_outofplace_split_f( +// descriptor &desc, float *in_re, float *in_im, float *out_re, +// float *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_split_c( +// descriptor &desc, float *in_re, float *in_im, float *out_re, +// float *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_split_d( +// descriptor &desc, double *in_re, double *in_im, double *out_re, +// double *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_backward_usm_outofplace_split_z( +// descriptor &desc, double *in_re, double *in_im, +// double *out_re, double *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 3bc0b330b..b006705f6 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -28,22 +28,97 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "mkl_service.h" +#include "mkl_dfti.h" + namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { -void commit_f(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void commit_c(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void commit_d(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void commit_z(descriptor &desc, sycl::queue &queue) { - throw std::runtime_error("Not implemented for mklcpu"); +class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { +public: + commit_derived_impl(sycl::queue queue, dft_values config_values) + : oneapi::mkl::dft::detail::commit_impl(queue), + status(0) { + logf("CPU impl, handle->%p", &handle); + + DFTI_DESCRIPTOR_HANDLE local_handle = nullptr; + + std::cout << config_values << std::endl; + if (config_values.rank == 1) { + status = DftiCreateDescriptor(&local_handle, precision_map[config_values.precision], + domain_map[config_values.domain], config_values.rank, + config_values.dimension[0]); + } + else { + status = DftiCreateDescriptor(&local_handle, precision_map[config_values.precision], + domain_map[config_values.domain], config_values.rank, + &config_values.dimension[0]); + } + if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); + + set_value(local_handle, config_values); + + status = DftiCommitDescriptor(local_handle); + if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); + + // commit_impl (pimpl_->handle) should return this handle + handle = local_handle; + } + + commit_derived_impl(const commit_derived_impl* other) + : oneapi::mkl::dft::detail::commit_impl(*other) { } + + virtual oneapi::mkl::dft::detail::commit_impl* copy_state() override { + return new commit_derived_impl(this); + } + + virtual ~commit_derived_impl() override { } + +private: + bool status; + std::unordered_map precision_map{ + { oneapi::mkl::dft::precision::SINGLE, DFTI_SINGLE }, + { oneapi::mkl::dft::precision::DOUBLE, DFTI_DOUBLE } + }; + std::unordered_map domain_map{ + { oneapi::mkl::dft::domain::REAL, DFTI_REAL }, + { oneapi::mkl::dft::domain::COMPLEX, DFTI_COMPLEX } + }; + + void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { + logf("address of cpu handle->%p", &descHandle); + logf("handle is_null? %s", (descHandle == nullptr) ? "yes" : "no"); + + // TODO : add complex storage and workspace + if (config.set_input_strides) + status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, &config.input_strides[0]); + if (config.set_output_strides) + status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, &config.output_strides[0]); + if (config.set_bwd_scale) + status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); + if (config.set_fwd_scale) + status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.fwd_scale); + if (config.set_number_of_transforms) + status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + if (config.set_fwd_dist) + status |= DftiSetValue(descHandle, DFTI_FWD_DISTANCE, config.fwd_dist); + if (config.set_bwd_dist) + status |= DftiSetValue(descHandle, DFTI_BWD_DISTANCE, config.bwd_dist); + if (config.set_placement) + status |= DftiSetValue(descHandle, DFTI_PLACEMENT, + (config.placement == oneapi::mkl::dft::config_value::INPLACE) + ? DFTI_INPLACE + : DFTI_NOT_INPLACE); + + if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); + } +}; + +oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values) { + return new commit_derived_impl(queue, values); } } // namespace mklcpu diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp index c34683672..634177398 100644 --- a/src/dft/backends/mklcpu/forward.cpp +++ b/src/dft/backends/mklcpu/forward.cpp @@ -33,177 +33,177 @@ namespace mkl { namespace dft { namespace mklcpu { -void compute_forward_buffer_inplace_f(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_c(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_d(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_z(descriptor &desc, - sycl::buffer, 1> &inout) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_forward_buffer_inplace_f(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_c(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_d(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_z(descriptor &desc, +// sycl::buffer, 1> &inout) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_forward_buffer_inplace_split_f(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_split_c(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_split_d(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_inplace_split_z(descriptor &desc, - sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_forward_buffer_inplace_split_f(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_split_c(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_split_d(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_inplace_split_z(descriptor &desc, +// sycl::buffer &inout_re, +// sycl::buffer &inout_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_forward_buffer_outofplace_f(descriptor &desc, - sycl::buffer &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_c(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_d(descriptor &desc, - sycl::buffer &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_z(descriptor &desc, - sycl::buffer, 1> &in, - sycl::buffer, 1> &out) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_forward_buffer_outofplace_f(descriptor &desc, +// sycl::buffer &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_c(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_d(descriptor &desc, +// sycl::buffer &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_z(descriptor &desc, +// sycl::buffer, 1> &in, +// sycl::buffer, 1> &out) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -void compute_forward_buffer_outofplace_split_f(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_split_c(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_split_d(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} -void compute_forward_buffer_outofplace_split_z(descriptor &desc, - sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// void compute_forward_buffer_outofplace_split_f(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_split_c(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_split_d(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// void compute_forward_buffer_outofplace_split_z(descriptor &desc, +// sycl::buffer &in_re, +// sycl::buffer &in_im, +// sycl::buffer &out_re, +// sycl::buffer &out_im) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_forward_usm_inplace_f(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_c(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_d(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_z(descriptor &desc, - std::complex *inout, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_forward_usm_inplace_f(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_c(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_d(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_z(descriptor &desc, +// std::complex *inout, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, - float *inout_re, float *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_split_c( - descriptor &desc, float *inout_re, float *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, - double *inout_re, double *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_inplace_split_z( - descriptor &desc, double *inout_re, double *inout_im, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, +// float *inout_re, float *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_split_c( +// descriptor &desc, float *inout_re, float *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, +// double *inout_re, double *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_inplace_split_z( +// descriptor &desc, double *inout_re, double *inout_im, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_forward_usm_outofplace_f(descriptor &desc, - float *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_c(descriptor &desc, - std::complex *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_d(descriptor &desc, - double *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_z(descriptor &desc, - std::complex *in, std::complex *out, - const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_forward_usm_outofplace_f(descriptor &desc, +// float *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_c(descriptor &desc, +// std::complex *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_d(descriptor &desc, +// double *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_z(descriptor &desc, +// std::complex *in, std::complex *out, +// const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } -sycl::event compute_forward_usm_outofplace_split_f( - descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_split_c( - descriptor &desc, float *in_re, float *in_im, float *out_re, - float *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_split_d( - descriptor &desc, double *in_re, double *in_im, double *out_re, - double *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} -sycl::event compute_forward_usm_outofplace_split_z( - descriptor &desc, double *in_re, double *in_im, - double *out_re, double *out_im, const std::vector &dependencies) { - throw std::runtime_error("Not implemented for mklcpu"); -} +// sycl::event compute_forward_usm_outofplace_split_f( +// descriptor &desc, float *in_re, float *in_im, float *out_re, +// float *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_split_c( +// descriptor &desc, float *in_re, float *in_im, float *out_re, +// float *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_split_d( +// descriptor &desc, double *in_re, double *in_im, double *out_re, +// double *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } +// sycl::event compute_forward_usm_outofplace_split_z( +// descriptor &desc, double *in_re, double *in_im, +// double *out_re, double *out_im, const std::vector &dependencies) { +// throw std::runtime_error("Not implemented for mklcpu"); +// } } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp index 7ce1bbf63..9fc897f3d 100644 --- a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp +++ b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp @@ -19,33 +19,10 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" #include "dft/function_table.hpp" -#include "../descriptor.cxx" #define WRAPPER_VERSION 1 extern "C" dft_function_table_t mkl_dft_table = { WRAPPER_VERSION, -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT) \ - oneapi::mkl::dft::mklcpu::commit_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_buffer_inplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_buffer_inplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_buffer_outofplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_buffer_outofplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_usm_inplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_usm_inplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_usm_outofplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_forward_usm_outofplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_buffer_inplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_buffer_inplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_buffer_outofplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_buffer_outofplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_usm_inplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_usm_inplace_split_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_usm_outofplace_##EXT, \ - oneapi::mkl::dft::mklcpu::compute_backward_usm_outofplace_split_##EXT - - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c), - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d), ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z) - -#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES + oneapi::mkl::dft::mklcpu::create_commit }; diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index d373d2957..c30bf1ecb 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -24,7 +24,6 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT - ../descriptor.cpp commit.cpp forward.cpp backward.cpp diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index 3066249a3..65e251efd 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -17,193 +17,23 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#include "oneapi/mkl/dft.hpp" +#include "oneapi/mkl/dft/detail/dft_loader.hpp" #include "function_table_initializer.hpp" #include "dft/function_table.hpp" - -#include "oneapi/mkl/detail/get_device_id.hpp" - + namespace oneapi { namespace mkl { namespace dft { - namespace detail { -static oneapi::mkl::detail::table_initializer - function_tables; -} // namespace detail - -#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ - template <> \ - void descriptor::commit(sycl::queue &queue) { \ - this->queue_ = queue; \ - detail::function_tables[get_device_id(queue)].commit_##EXT(*this, queue); \ - } \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - template <> \ - void compute_forward, T_BACKWARD>( \ - descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_buffer_inplace_##EXT(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - void compute_forward, T_REAL>( \ - descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - void compute_forward, T_FORWARD, T_BACKWARD>( \ - descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_buffer_outofplace_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - void compute_forward, T_REAL, T_REAL>( \ - descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - template <> \ - sycl::event compute_forward, T_BACKWARD>( \ - descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_usm_inplace_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - sycl::event compute_forward, T_REAL>( \ - descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - sycl::event compute_forward, T_FORWARD, T_BACKWARD>( \ - descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_usm_outofplace_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - sycl::event compute_forward, T_REAL, T_REAL>( \ - descriptor & desc, T_REAL * in_re, T_REAL * in_im, T_REAL * out_re, \ - T_REAL * out_im, const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_forward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ - dependencies); \ - } \ - \ - /*Buffer version*/ \ - \ - /*In-place transform*/ \ - template <> \ - void compute_backward, T_BACKWARD>( \ - descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_buffer_inplace_##EXT(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - void compute_backward, T_REAL>( \ - descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - void compute_backward, T_BACKWARD, T_FORWARD>( \ - descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_buffer_outofplace_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - void compute_backward, T_REAL, T_REAL>( \ - descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform*/ \ - template <> \ - sycl::event compute_backward, T_BACKWARD>( \ - descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_usm_inplace_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - sycl::event compute_backward, T_REAL>( \ - descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - sycl::event compute_backward, T_BACKWARD, T_FORWARD>( \ - descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_usm_outofplace_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - sycl::event compute_backward, T_REAL, T_REAL>( \ - descriptor & desc, T_REAL * in_re, T_REAL * in_im, T_REAL * out_re, \ - T_REAL * out_im, const std::vector &dependencies) { \ - return detail::function_tables[get_device_id(desc.get_queue())] \ - .compute_backward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ - dependencies); \ - } -ONEAPI_MKL_DFT_SIGNATURES(f, precision::SINGLE, domain::REAL, float, float, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(c, precision::SINGLE, domain::COMPLEX, float, std::complex, - std::complex) -ONEAPI_MKL_DFT_SIGNATURES(d, precision::DOUBLE, domain::REAL, double, double, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(z, precision::DOUBLE, domain::COMPLEX, double, std::complex, - std::complex) +static oneapi::mkl::detail::table_initializer function_tables; -#undef ONEAPI_MKL_DFT_SIGNATURES +commit_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue, dft_values values) { + return function_tables[libkey].create_commit_sycl(queue, values); +} +} // namespace detail } // namespace dft } // namespace mkl } // namespace oneapi diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index de03ad365..aa7039496 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -35,75 +35,7 @@ typedef struct { int version; - -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - void (*commit_##EXT)(oneapi::mkl::dft::descriptor & desc, \ - sycl::queue & queue); \ - void (*compute_forward_buffer_inplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_forward_buffer_inplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, \ - sycl::buffer & inout_re, sycl::buffer & inout_im); \ - void (*compute_forward_buffer_outofplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out); \ - void (*compute_forward_buffer_outofplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im); \ - sycl::event (*compute_forward_usm_inplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies); \ - sycl::event (*compute_forward_usm_inplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_REAL * inout_re, \ - T_REAL * inout_im, const std::vector &dependencies); \ - sycl::event (*compute_forward_usm_outofplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ - const std::vector &dependencies); \ - sycl::event (*compute_forward_usm_outofplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); \ - void (*compute_backward_buffer_inplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_backward_buffer_inplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, \ - sycl::buffer & inout_re, sycl::buffer & inout_im); \ - void (*compute_backward_buffer_outofplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out); \ - void (*compute_backward_buffer_outofplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im); \ - sycl::event (*compute_backward_usm_inplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_BACKWARD * inout, \ - const std::vector &dependencies); \ - sycl::event (*compute_backward_usm_inplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_REAL * inout_re, \ - T_REAL * inout_im, const std::vector &dependencies); \ - sycl::event (*compute_backward_usm_outofplace_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector &dependencies); \ - sycl::event (*compute_backward_usm_outofplace_split_##EXT)( \ - oneapi::mkl::dft::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector &dependencies); - - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::precision::SINGLE, - oneapi::mkl::dft::domain::REAL, float, float, - std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, oneapi::mkl::dft::precision::SINGLE, - oneapi::mkl::dft::domain::COMPLEX, float, std::complex, - std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, oneapi::mkl::dft::precision::DOUBLE, - oneapi::mkl::dft::domain::REAL, double, double, - std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, oneapi::mkl::dft::precision::DOUBLE, - oneapi::mkl::dft::domain::COMPLEX, double, - std::complex, std::complex) - -#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl)(sycl::queue queue, oneapi::mkl::dft::dft_values values); } dft_function_table_t; #endif //_DFT_FUNCTION_TABLE_HPP_ From 90dba9e67dfe8770853063262cdeb61e3bf61536 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 14 Nov 2022 22:42:49 -0800 Subject: [PATCH 11/21] address reviews --- .../complex_fwd_usm_mklcpu.cpp | 2 +- include/oneapi/mkl/dft/descriptor.hpp | 77 +++++++++--------- include/oneapi/mkl/dft/detail/commit_impl.hpp | 25 +++++- include/oneapi/mkl/dft/detail/dft_loader.hpp | 2 +- include/oneapi/mkl/dft/types.hpp | 50 +----------- src/dft/backends/mklcpu/commit.cpp | 79 +++++++++---------- 6 files changed, 100 insertions(+), 135 deletions(-) diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 014617214..a473ab333 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -90,7 +90,7 @@ desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::conf desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); // 4. commit_descriptor (run_time xPU) unusable from libonemkl_dft_mklcpu.so -desc.commit(cpu_queue); +// desc.commit(cpu_queue); // 5. compute_forward / compute_backward (CPU) // oneapi::mkl::dft::compute_forward(desc, x_usm); diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 59c3226cb..4f30619fd 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -59,12 +59,12 @@ class descriptor { void get_value(config_param param, ...); void commit(sycl::queue& queue) { - pimpl_.reset(detail::create_commit(get_device_id(queue), queue, values)); + pimpl_.reset(detail::create_commit(get_device_id(queue), queue, values_)); } #ifdef ENABLE_MKLCPU_BACKEND void commit(backend_selector selector) { - pimpl_.reset(mklcpu::create_commit(selector.get_queue(), values)); + pimpl_.reset(mklcpu::create_commit(selector.get_queue(), values_)); } #endif @@ -73,33 +73,43 @@ class descriptor { // pimpl_.reset(mklgpu::create_commit(selector.get_queue())); } #endif - - sycl::queue& get_queue() { - return queue_; - } private: - sycl::queue queue_; std::unique_ptr pimpl_; // commit only std::int64_t rank_; - std::vector dimension_; + std::vector dimensions_; - // descriptor configuration values and structs + // descriptor configuration values_ and structs void* handle_; - oneapi::mkl::dft::dft_values values; + oneapi::mkl::dft::dft_values values_; }; template -descriptor::descriptor(std::vector dimension) : - dimension_(dimension), +descriptor::descriptor(std::vector dimensions) : + dimensions_(dimensions), handle_(nullptr), - rank_(dimension.size()) + rank_(dimensions.size()) { - // TODO: initialize the device_handle, handle_buffer - values.domain = dom; - values.precision = prec; - values.dimension = dimension_; - values.rank = rank_; + // Compute default strides. + std::vector defaultStrides(rank_, 1); + for(int i = rank_ - 1; i < 0; --i){ + defaultStrides[i] = defaultStrides[i - 1] * dimensions_[i]; + } + defaultStrides[0] = 0; + values_.input_strides = defaultStrides; + values_.output_strides = defaultStrides; + values_.bwd_scale = 1.0; + values_.fwd_scale = 1.0; + values_.number_of_transforms = 1; + values_.fwd_dist = 1; + values_.bwd_dist = 1; + values_.placement = config_value::INPLACE; + values_.complex_storage = config_value::COMPLEX_COMPLEX; + values_.conj_even_storage = config_value::COMPLEX_COMPLEX; + values_.dimensions = dimensions_; + values_.rank = rank_; + values_.domain = dom; + values_.precision = prec; } template @@ -120,48 +130,37 @@ void descriptor::set_value(config_param param, ...) { printf("oneapi interface set_value\n"); switch (param) { case config_param::INPUT_STRIDES: - values.set_input_strides = true; case config_param::OUTPUT_STRIDES: { int64_t *strides = va_arg(vl, int64_t *); if (strides == nullptr) break; if (param == config_param::INPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values.input_strides)); + std::copy(strides, strides+rank_+1, std::back_inserter(values_.input_strides)); if (param == config_param::OUTPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values.output_strides)); - values.set_output_strides = true; + std::copy(strides, strides+rank_+1, std::back_inserter(values_.output_strides)); } break; case config_param::FORWARD_SCALE: - values.fwd_scale = va_arg(vl, double); - values.set_fwd_scale = true; + values_.fwd_scale = va_arg(vl, double); break; case config_param::BACKWARD_SCALE: - values.bwd_scale = va_arg(vl, double); - values.set_bwd_scale = true; + values_.bwd_scale = va_arg(vl, double); break; case config_param::NUMBER_OF_TRANSFORMS: - values.number_of_transforms = va_arg(vl, int64_t); - values.set_number_of_transforms = true; + values_.number_of_transforms = va_arg(vl, int64_t); break; case config_param::FWD_DISTANCE: - values.fwd_dist = va_arg(vl, int64_t); - values.set_fwd_dist = true; + values_.fwd_dist = va_arg(vl, int64_t); break; case config_param::BWD_DISTANCE: - values.bwd_dist = va_arg(vl, int64_t); - values.set_bwd_dist = true; + values_.bwd_dist = va_arg(vl, int64_t); break; case config_param::PLACEMENT: - values.placement = va_arg(vl, config_value); - values.set_placement = true; + values_.placement = va_arg(vl, config_value); break; case config_param::COMPLEX_STORAGE: - values.complex_storage = va_arg(vl, config_value); - values.set_complex_storage = true; - break; + values_.complex_storage = va_arg(vl, config_value); case config_param::CONJUGATE_EVEN_STORAGE: - values.conj_even_storage = va_arg(vl, config_value); - values.set_conj_even_storage = true; + values_.conj_even_storage = va_arg(vl, config_value); break; default: err = 1; diff --git a/include/oneapi/mkl/dft/detail/commit_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp index 2cb556467..b0d10de8a 100644 --- a/include/oneapi/mkl/dft/detail/commit_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -1,3 +1,22 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + #ifndef _ONEMKL_DFT_COMMIT_IMPL_HPP_ #define _ONEMKL_DFT_COMMIT_IMPL_HPP_ @@ -21,11 +40,9 @@ namespace detail { class commit_impl { public: - commit_impl(sycl::queue queue) : queue_(queue), handle(nullptr) {} - - commit_impl(const commit_impl& other) : queue_(other.queue_), handle(other.handle) {} + commit_impl(sycl::queue queue) : queue_(queue), status(false), handle(nullptr) {} - virtual commit_impl* copy_state() = 0; + commit_impl(const commit_impl& other) : queue_(other.queue_), status(other.status), handle(other.handle) {} virtual ~commit_impl() {} diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index 7c939e36e..a219d59bc 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2021 Intel Corporation +* Copyright 2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index 857fff8e1..b42a72543 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -27,25 +27,6 @@ #include #endif -#ifdef NDEBUG -#define logf(...) -#else -#define logf(...) \ - printf("%s - (%s) : ", __FILE__, __FUNCTION__); \ - printf(__VA_ARGS__); \ - printf("\n"); -#endif - -template -std::ostream& operator<<(std::ostream& os, const std::vector& vector) { - if (vector.empty()) return os; - os.put('['); - for (auto element : vector) { - os << element << ", "; - } - return os << "\b\b]"; -} - namespace oneapi { namespace mkl { namespace dft { @@ -129,39 +110,10 @@ struct dft_values { config_value complex_storage; config_value conj_even_storage; - bool set_input_strides = false; - bool set_output_strides = false; - bool set_bwd_scale = false; - bool set_fwd_scale = false; - bool set_number_of_transforms = false; - bool set_fwd_dist = false; - bool set_bwd_dist = false; - bool set_placement = false; - bool set_complex_storage = false; - bool set_conj_even_storage = false; - - std::vector dimension; + std::vector dimensions; std::int64_t rank; domain domain; precision precision; - friend auto operator<<(std::ostream& os, dft_values const& val) -> std::ostream& { - os << "------------- oneAPI Descriptor ------------\n"; - os << "input_strides : " << val.input_strides << "\n"; - os << "output_strides : " << val.output_strides << "\n"; - os << "bwd_scale : " << val.bwd_scale << "\n"; - os << "fwd_scale : " << val.fwd_scale << "\n"; - os << "number_of_transforms : " << val.number_of_transforms << "\n"; - os << "fwd_dist : " << val.fwd_dist << "\n"; - os << "bwd_dist : " << val.bwd_dist << "\n"; - os << "placement : " << (int) val.placement << "\n"; - os << "complex_storage : " << (int) val.complex_storage << "\n"; - os << "conj_even_storage : " << (int) val.conj_even_storage << "\n"; - os << "dimension : " << val.dimension << "\n"; - os << "rank : " << val.rank << "\n"; - os << "domain : " << dom_map[val.domain] << "\n"; - os << "precision : " << prec_map[val.precision]; - return os; - } }; } // namespace dft } // namespace mkl diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index b006705f6..7ce00e321 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -41,28 +41,29 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { public: commit_derived_impl(sycl::queue queue, dft_values config_values) : oneapi::mkl::dft::detail::commit_impl(queue), - status(0) { - logf("CPU impl, handle->%p", &handle); + status(false) { DFTI_DESCRIPTOR_HANDLE local_handle = nullptr; - std::cout << config_values << std::endl; if (config_values.rank == 1) { - status = DftiCreateDescriptor(&local_handle, precision_map[config_values.precision], - domain_map[config_values.domain], config_values.rank, - config_values.dimension[0]); + status = DftiCreateDescriptor(&local_handle, get_precision(config_values.precision), + get_domain(config_values.domain), config_values.rank, + config_values.dimensions[0]); + } else { + status = DftiCreateDescriptor(&local_handle, get_precision(config_values.precision), + get_domain(config_values.domain), config_values.rank, + config_values.dimensions.data()); } - else { - status = DftiCreateDescriptor(&local_handle, precision_map[config_values.precision], - domain_map[config_values.domain], config_values.rank, - &config_values.dimension[0]); + if(status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); } - if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); set_value(local_handle, config_values); status = DftiCommitDescriptor(local_handle); - if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); + if(status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); + } // commit_impl (pimpl_->handle) should return this handle handle = local_handle; @@ -71,49 +72,45 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { commit_derived_impl(const commit_derived_impl* other) : oneapi::mkl::dft::detail::commit_impl(*other) { } - virtual oneapi::mkl::dft::detail::commit_impl* copy_state() override { - return new commit_derived_impl(this); + virtual ~commit_derived_impl() override { + DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE *) &handle); } - virtual ~commit_derived_impl() override { } - private: bool status; - std::unordered_map precision_map{ - { oneapi::mkl::dft::precision::SINGLE, DFTI_SINGLE }, - { oneapi::mkl::dft::precision::DOUBLE, DFTI_DOUBLE } - }; - std::unordered_map domain_map{ - { oneapi::mkl::dft::domain::REAL, DFTI_REAL }, - { oneapi::mkl::dft::domain::COMPLEX, DFTI_COMPLEX } - }; + constexpr DFTI_CONFIG_VALUE get_domain(oneapi::mkl::dft::domain dom) { + if (dom == oneapi::mkl::dft::domain::COMPLEX) { + return DFTI_COMPLEX; + } else { + return DFTI_REAL; + } + } - void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { - logf("address of cpu handle->%p", &descHandle); - logf("handle is_null? %s", (descHandle == nullptr) ? "yes" : "no"); + constexpr DFTI_CONFIG_VALUE get_precision(oneapi::mkl::dft::precision prec) { + if (prec == oneapi::mkl::dft::precision::SINGLE) { + return DFTI_SINGLE; + } else { + return DFTI_DOUBLE; + } + } + void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { // TODO : add complex storage and workspace - if (config.set_input_strides) - status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, &config.input_strides[0]); - if (config.set_output_strides) - status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, &config.output_strides[0]); - if (config.set_bwd_scale) + status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); + status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - if (config.set_fwd_scale) - status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.fwd_scale); - if (config.set_number_of_transforms) + status |= DftiSetValue(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - if (config.set_fwd_dist) - status |= DftiSetValue(descHandle, DFTI_FWD_DISTANCE, config.fwd_dist); - if (config.set_bwd_dist) - status |= DftiSetValue(descHandle, DFTI_BWD_DISTANCE, config.bwd_dist); - if (config.set_placement) + status |= DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); + status |= DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); status |= DftiSetValue(descHandle, DFTI_PLACEMENT, (config.placement == oneapi::mkl::dft::config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); - if(status != DFTI_NO_ERROR) throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); + if(status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); + } } }; From 1852b137e1489ec485a237fb7dc0b021beb9ca58 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 21 Nov 2022 04:08:01 -0800 Subject: [PATCH 12/21] adress commit handle and error handling --- include/oneapi/mkl/dft/descriptor.hpp | 1 + include/oneapi/mkl/dft/detail/commit_impl.hpp | 5 ++--- include/oneapi/mkl/dft/types.hpp | 19 ++++++++++++++++++ src/dft/backends/mklcpu/commit.cpp | 20 ++++++++----------- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 4f30619fd..07e63e41b 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -130,6 +130,7 @@ void descriptor::set_value(config_param param, ...) { printf("oneapi interface set_value\n"); switch (param) { case config_param::INPUT_STRIDES: + [[fallthrough]]; case config_param::OUTPUT_STRIDES: { int64_t *strides = va_arg(vl, int64_t *); if (strides == nullptr) break; diff --git a/include/oneapi/mkl/dft/detail/commit_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp index b0d10de8a..06c88084a 100644 --- a/include/oneapi/mkl/dft/detail/commit_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -40,9 +40,9 @@ namespace detail { class commit_impl { public: - commit_impl(sycl::queue queue) : queue_(queue), status(false), handle(nullptr) {} + commit_impl(sycl::queue queue) : queue_(queue), status(false) {} - commit_impl(const commit_impl& other) : queue_(other.queue_), status(other.status), handle(other.handle) {} + commit_impl(const commit_impl& other) : queue_(other.queue_), status(other.status) {} virtual ~commit_impl() {} @@ -53,7 +53,6 @@ class commit_impl { protected: bool status; sycl::queue queue_; - void* handle; }; diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index b42a72543..9f471f1fe 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -31,6 +31,25 @@ namespace oneapi { namespace mkl { namespace dft { +typedef int DFT_ERROR; +// this could be gereralized to device specific impl +// (this works for gpu and cpu both) +enum class error_status { + DFT_NO_ERROR = 0, + DFT_MEMORY_ERROR = 1, + DFT_INVALID_CONFIGURATION = 2, + DFT_INCONSISTENT_CONFIGURATION = 3, + DFT_MULTITHREADED_ERROR = 4, + DFT_BAD_DESCRIPTOR = 5, + DFT_UNIMPLEMENTED = 6, + DFT_MKL_INTERNAL_ERROR = 7, + DFT_NUMBER_OF_THREADS_ERROR = 8, + DFT_1D_LENGTH_EXCEEDS_INT32 = 9, + DFT_1D_MEMORY_EXCEEDS_INT32 = 9, + DFT_NO_WORKSPACE = 11, +}; + + enum class precision { SINGLE, DOUBLE }; enum class domain { REAL, COMPLEX }; enum class config_param { diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 7ce00e321..44d0b37c4 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -41,16 +41,14 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { public: commit_derived_impl(sycl::queue queue, dft_values config_values) : oneapi::mkl::dft::detail::commit_impl(queue), - status(false) { - - DFTI_DESCRIPTOR_HANDLE local_handle = nullptr; + status(-1) { if (config_values.rank == 1) { - status = DftiCreateDescriptor(&local_handle, get_precision(config_values.precision), + status = DftiCreateDescriptor(&handle, get_precision(config_values.precision), get_domain(config_values.domain), config_values.rank, config_values.dimensions[0]); } else { - status = DftiCreateDescriptor(&local_handle, get_precision(config_values.precision), + status = DftiCreateDescriptor(&handle, get_precision(config_values.precision), get_domain(config_values.domain), config_values.rank, config_values.dimensions.data()); } @@ -58,15 +56,12 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); } - set_value(local_handle, config_values); + set_value(handle, config_values); - status = DftiCommitDescriptor(local_handle); + status = DftiCommitDescriptor(handle); if(status != DFTI_NO_ERROR) { throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); } - - // commit_impl (pimpl_->handle) should return this handle - handle = local_handle; } commit_derived_impl(const commit_derived_impl* other) @@ -77,7 +72,8 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } private: - bool status; + DFT_ERROR status; + DFTI_DESCRIPTOR_HANDLE handle = nullptr; constexpr DFTI_CONFIG_VALUE get_domain(oneapi::mkl::dft::domain dom) { if (dom == oneapi::mkl::dft::domain::COMPLEX) { return DFTI_COMPLEX; @@ -95,7 +91,7 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { - // TODO : add complex storage and workspace + // TODO : add complex storage and workspace, fix error handling status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); From c4ad8d8513690ef8c6b9ef8de96d96b9290aa14c Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 21 Nov 2022 04:13:10 -0800 Subject: [PATCH 13/21] remove the use of error enums --- include/oneapi/mkl/dft/types.hpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index 9f471f1fe..c3debd94d 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -32,23 +32,6 @@ namespace mkl { namespace dft { typedef int DFT_ERROR; -// this could be gereralized to device specific impl -// (this works for gpu and cpu both) -enum class error_status { - DFT_NO_ERROR = 0, - DFT_MEMORY_ERROR = 1, - DFT_INVALID_CONFIGURATION = 2, - DFT_INCONSISTENT_CONFIGURATION = 3, - DFT_MULTITHREADED_ERROR = 4, - DFT_BAD_DESCRIPTOR = 5, - DFT_UNIMPLEMENTED = 6, - DFT_MKL_INTERNAL_ERROR = 7, - DFT_NUMBER_OF_THREADS_ERROR = 8, - DFT_1D_LENGTH_EXCEEDS_INT32 = 9, - DFT_1D_MEMORY_EXCEEDS_INT32 = 9, - DFT_NO_WORKSPACE = 11, -}; - enum class precision { SINGLE, DOUBLE }; enum class domain { REAL, COMPLEX }; From 6cf88ba4eaeba1d2be10fb89c47777059955dd2f Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Tue, 22 Nov 2022 21:30:00 -0800 Subject: [PATCH 14/21] generalize descriptor passing+extra forward decl --- .../complex_fwd_usm_mklcpu.cpp | 2 +- include/oneapi/mkl/dft/descriptor.hpp | 120 +------------- include/oneapi/mkl/dft/detail/dft_loader.hpp | 14 +- .../dft/detail/mklcpu/onemkl_dft_mklcpu.hpp | 31 +++- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 30 +++- src/dft/CMakeLists.txt | 4 +- src/dft/backends/mklcpu/CMakeLists.txt | 1 + src/dft/backends/mklcpu/commit.cpp | 22 ++- .../backends/mklcpu/mkl_dft_cpu_wrappers.cpp | 3 + src/dft/backends/mklgpu/CMakeLists.txt | 1 + src/dft/descriptor.cpp | 146 ++++++++++++++++++ src/dft/dft_loader.cpp | 32 +++- src/dft/function_table.hpp | 13 +- 13 files changed, 283 insertions(+), 136 deletions(-) create mode 100644 src/dft/descriptor.cpp diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index a473ab333..486f31d19 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -132,7 +132,7 @@ int main(int argc, char** argv) { print_example_banner(); try { - sycl::device cpu_dev((sycl::cpu_selector())); + sycl::device cpu_dev((sycl::cpu_selector_v)); std::cout << "Running DFT Complex forward inplace USM example" << std::endl; std::cout << "Running with single precision real data type on:" << std::endl; std::cout << "\tcpu device :" << cpu_dev.get_info() << std::endl; diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 07e63e41b..f9c734813 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -31,14 +31,7 @@ #include "oneapi/mkl/detail/backend_selector.hpp" #include "oneapi/mkl/dft/detail/commit_impl.hpp" -#include "oneapi/mkl/dft/detail/dft_loader.hpp" -#ifdef ENABLE_MKLCPU_BACKEND -#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" -#endif -#ifdef ENABLE_MKLGPU_BACKEND -#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" -#endif namespace oneapi { namespace mkl { namespace dft { @@ -58,23 +51,21 @@ class descriptor { void get_value(config_param param, ...); - void commit(sycl::queue& queue) { - pimpl_.reset(detail::create_commit(get_device_id(queue), queue, values_)); - } + void commit(sycl::queue& queue); #ifdef ENABLE_MKLCPU_BACKEND - void commit(backend_selector selector) { - pimpl_.reset(mklcpu::create_commit(selector.get_queue(), values_)); - } + void commit(backend_selector selector); #endif #ifdef ENABLE_MKLGPU_BACKEND - void commit(backend_selector selector) { - // pimpl_.reset(mklgpu::create_commit(selector.get_queue())); - } + void commit(backend_selector selector); #endif + + sycl::queue& get_queue() { return queue_; }; + dft_values get_values() { return values_; }; private: std::unique_ptr pimpl_; // commit only + sycl::queue queue_; std::int64_t rank_; std::vector dimensions_; @@ -84,103 +75,6 @@ class descriptor { oneapi::mkl::dft::dft_values values_; }; -template -descriptor::descriptor(std::vector dimensions) : - dimensions_(dimensions), - handle_(nullptr), - rank_(dimensions.size()) - { - // Compute default strides. - std::vector defaultStrides(rank_, 1); - for(int i = rank_ - 1; i < 0; --i){ - defaultStrides[i] = defaultStrides[i - 1] * dimensions_[i]; - } - defaultStrides[0] = 0; - values_.input_strides = defaultStrides; - values_.output_strides = defaultStrides; - values_.bwd_scale = 1.0; - values_.fwd_scale = 1.0; - values_.number_of_transforms = 1; - values_.fwd_dist = 1; - values_.bwd_dist = 1; - values_.placement = config_value::INPLACE; - values_.complex_storage = config_value::COMPLEX_COMPLEX; - values_.conj_even_storage = config_value::COMPLEX_COMPLEX; - values_.dimensions = dimensions_; - values_.rank = rank_; - values_.domain = dom; - values_.precision = prec; - } - -template -descriptor::descriptor(std::int64_t length) : - descriptor(std::vector{length}) {} - -template -descriptor::~descriptor() { - // call DftiFreeDescriptor -} - -// impliment error class -template -void descriptor::set_value(config_param param, ...) { - int err = 0; - va_list vl; - va_start(vl, param); - printf("oneapi interface set_value\n"); - switch (param) { - case config_param::INPUT_STRIDES: - [[fallthrough]]; - case config_param::OUTPUT_STRIDES: { - int64_t *strides = va_arg(vl, int64_t *); - if (strides == nullptr) break; - - if (param == config_param::INPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values_.input_strides)); - if (param == config_param::OUTPUT_STRIDES) - std::copy(strides, strides+rank_+1, std::back_inserter(values_.output_strides)); - } break; - case config_param::FORWARD_SCALE: - values_.fwd_scale = va_arg(vl, double); - break; - case config_param::BACKWARD_SCALE: - values_.bwd_scale = va_arg(vl, double); - break; - case config_param::NUMBER_OF_TRANSFORMS: - values_.number_of_transforms = va_arg(vl, int64_t); - break; - case config_param::FWD_DISTANCE: - values_.fwd_dist = va_arg(vl, int64_t); - break; - case config_param::BWD_DISTANCE: - values_.bwd_dist = va_arg(vl, int64_t); - break; - case config_param::PLACEMENT: - values_.placement = va_arg(vl, config_value); - break; - case config_param::COMPLEX_STORAGE: - values_.complex_storage = va_arg(vl, config_value); - case config_param::CONJUGATE_EVEN_STORAGE: - values_.conj_even_storage = va_arg(vl, config_value); - break; - - default: err = 1; - } - va_end(vl); -} - -template -void descriptor::get_value(config_param param, ...) { - int err = 0; - va_list vl; - va_start(vl, param); - switch (param) - { - default: break; - } - va_end(vl); -} - } //namespace dft } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index a219d59bc..170292eff 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -30,15 +30,27 @@ #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/detail/get_device_id.hpp" +#include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { namespace mkl { namespace dft { namespace detail { -ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue, dft_values values); +ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc); } // namespace detail } // namespace dft diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index 5063a2ba8..53fb764c5 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -17,29 +17,48 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#pragma once +#ifndef _ONEMKL_DFT_MKLCPU_HPP_ +#define _ONEMKL_DFT_MKLCPU_HPP_ +#include #if __has_include() #include #else #include #endif -#include -#include +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" -#include "oneapi/mkl/dft/detail/commit_impl.hpp" -#include "oneapi/mkl/dft/types.hpp" #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values); +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); } // namespace mklcpu } // namespace dft } // namespace mkl } // namespace oneapi + +#endif // _ONEMKL_DFT_MKLCPU_HPP_ \ No newline at end of file diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index 81a45257d..569efefb3 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -17,28 +17,48 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#pragma once +#ifndef _ONEMKL_DFT_MKLGPU_HPP_ +#define _ONEMKL_DFT_MKLGPU_HPP_ +#include #if __has_include() #include #else #include #endif -#include -#include +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" -#include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { namespace mkl { namespace dft { namespace mklgpu { -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values); +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); + +ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor& desc); } // namespace mklgpu } // namespace dft } // namespace mkl } // namespace oneapi + +#endif // _ONEMKL_DFT_MKLGPU_HPP_ diff --git a/src/dft/CMakeLists.txt b/src/dft/CMakeLists.txt index d7f83cbc2..dc72a444d 100644 --- a/src/dft/CMakeLists.txt +++ b/src/dft/CMakeLists.txt @@ -23,7 +23,7 @@ add_subdirectory(backends) # Recipe for DFT loader object if(BUILD_SHARED_LIBS) add_library(onemkl_dft OBJECT) -target_sources(onemkl_dft PRIVATE dft_loader.cpp) +target_sources(onemkl_dft PRIVATE descriptor.cpp dft_loader.cpp) target_include_directories(onemkl_dft PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src @@ -38,7 +38,7 @@ set_target_properties(onemkl_dft PROPERTIES POSITION_INDEPENDENT_CODE ON ) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) - add_sycl_to_target(TARGET onemkl_dft SOURCES dft_loader.cpp) + add_sycl_to_target(TARGET onemkl_dft SOURCES descriptor.cpp dft_loader.cpp) else() target_link_libraries(onemkl_dft PUBLIC ONEMKL::SYCL::SYCL) endif() diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index 69978073d..266e8c0f3 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -26,6 +26,7 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT commit.cpp + ../../descriptor.cpp forward.cpp backward.cpp $<$: mkl_dft_cpu_wrappers.cpp> diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 44d0b37c4..65a1b34ae 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" @@ -110,8 +111,25 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } }; -oneapi::mkl::dft::detail::commit_impl* create_commit(sycl::queue queue, dft_values values) { - return new commit_derived_impl(queue, values); +oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor &desc +) { + return new commit_derived_impl(desc.get_queue(), desc.get_values()); +} +oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor &desc +) { + return new commit_derived_impl(desc.get_queue(), desc.get_values()); +} +oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor &desc +) { + return new commit_derived_impl(desc.get_queue(), desc.get_values()); +} +oneapi::mkl::dft::detail::commit_impl* create_commit( + oneapi::mkl::dft::descriptor &desc +) { + return new commit_derived_impl(desc.get_queue(), desc.get_values()); } } // namespace mklcpu diff --git a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp index 9fc897f3d..90d39ec1f 100644 --- a/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp +++ b/src/dft/backends/mklcpu/mkl_dft_cpu_wrappers.cpp @@ -24,5 +24,8 @@ extern "C" dft_function_table_t mkl_dft_table = { WRAPPER_VERSION, + oneapi::mkl::dft::mklcpu::create_commit, + oneapi::mkl::dft::mklcpu::create_commit, + oneapi::mkl::dft::mklcpu::create_commit, oneapi::mkl::dft::mklcpu::create_commit }; diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index c30bf1ecb..ca9cb9d09 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -24,6 +24,7 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT + ../../descriptor.cpp commit.cpp forward.cpp backward.cpp diff --git a/src/dft/descriptor.cpp b/src/dft/descriptor.cpp new file mode 100644 index 000000000..75b0763fc --- /dev/null +++ b/src/dft/descriptor.cpp @@ -0,0 +1,146 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" + +#ifdef ENABLE_MKLCPU_BACKEND +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#endif +#ifdef ENABLE_MKLGPU_BACKEND +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" +#endif +#include "oneapi/mkl/dft/detail/dft_loader.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(sycl::queue &queue) { + queue_ = queue; + pimpl_.reset(detail::create_commit(*this)); +} + +#ifdef ENABLE_MKLCPU_BACKEND +template +void descriptor::commit(backend_selector selector) { + queue_ = selector.get_queue(); + pimpl_.reset(mklcpu::create_commit(*this)); +} +#endif + +#ifdef ENABLE_MKLGPU_BACKEND +template +void descriptor::commit(backend_selector selector) { + queue_ = selector.get_queue(); + // pimpl_.reset(mklgpu::create_commit(*this)); +} +#endif + +// impliment error class +template +void descriptor::set_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + printf("oneapi interface set_value\n"); + switch (param) { + case config_param::INPUT_STRIDES: [[fallthrough]]; + case config_param::OUTPUT_STRIDES: { + int64_t *strides = va_arg(vl, int64_t *); + if (strides == nullptr) + break; + if (param == config_param::INPUT_STRIDES) + std::copy(strides, strides + rank_ + 1, std::back_inserter(values_.input_strides)); + if (param == config_param::OUTPUT_STRIDES) + std::copy(strides, strides + rank_ + 1, std::back_inserter(values_.output_strides)); + } break; + case config_param::FORWARD_SCALE: values_.fwd_scale = va_arg(vl, double); break; + case config_param::BACKWARD_SCALE: values_.bwd_scale = va_arg(vl, double); break; + case config_param::NUMBER_OF_TRANSFORMS: + values_.number_of_transforms = va_arg(vl, int64_t); + break; + case config_param::FWD_DISTANCE: values_.fwd_dist = va_arg(vl, int64_t); break; + case config_param::BWD_DISTANCE: values_.bwd_dist = va_arg(vl, int64_t); break; + case config_param::PLACEMENT: values_.placement = va_arg(vl, config_value); break; + case config_param::COMPLEX_STORAGE: + values_.complex_storage = va_arg(vl, config_value); + break; + case config_param::CONJUGATE_EVEN_STORAGE: + values_.conj_even_storage = va_arg(vl, config_value); + break; + default: err = 1; + } + va_end(vl); +} +template +descriptor::descriptor(std::vector dimensions) + : dimensions_(dimensions), + handle_(nullptr), + rank_(dimensions.size()) { + // Compute default strides. + std::vector defaultStrides(rank_, 1); + for (int i = rank_ - 1; i < 0; --i) { + defaultStrides[i] = defaultStrides[i - 1] * dimensions_[i]; + } + defaultStrides[0] = 0; + values_.input_strides = defaultStrides; + values_.output_strides = defaultStrides; + values_.bwd_scale = 1.0; + values_.fwd_scale = 1.0; + values_.number_of_transforms = 1; + values_.fwd_dist = 1; + values_.bwd_dist = 1; + values_.placement = config_value::INPLACE; + values_.complex_storage = config_value::COMPLEX_COMPLEX; + values_.conj_even_storage = config_value::COMPLEX_COMPLEX; + values_.dimensions = dimensions_; + values_.rank = rank_; + values_.domain = dom; + values_.precision = prec; +} + +template +descriptor::descriptor(std::int64_t length) + : descriptor(std::vector{ length }) {} + +template +descriptor::~descriptor() { + // call DftiFreeDescriptor +} + +template +void descriptor::get_value(config_param param, ...) { + int err = 0; + va_list vl; + va_start(vl, param); + switch (param) { + default: break; + } + va_end(vl); +} + +template class descriptor; +template class descriptor; +template class descriptor; +template class descriptor; + +} //namespace dft +} //namespace mkl +} //namespace oneapi \ No newline at end of file diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index 65e251efd..84523e533 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -21,19 +21,41 @@ #include "function_table_initializer.hpp" #include "dft/function_table.hpp" - +#include "oneapi/mkl/detail/get_device_id.hpp" + namespace oneapi { namespace mkl { namespace dft { namespace detail { -static oneapi::mkl::detail::table_initializer function_tables; +static oneapi::mkl::detail::table_initializer + function_tables; + +commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc) { + auto libkey = get_device_id(desc.get_queue()); + return function_tables[libkey].create_commit_sycl_fz(desc); +} + +commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc) { + auto libkey = get_device_id(desc.get_queue()); + return function_tables[libkey].create_commit_sycl_dz(desc); +} + +commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc) { + auto libkey = get_device_id(desc.get_queue()); + return function_tables[libkey].create_commit_sycl_fr(desc); +} -commit_impl* create_commit(oneapi::mkl::device libkey, sycl::queue queue, dft_values values) { - return function_tables[libkey].create_commit_sycl(queue, values); +commit_impl* create_commit(oneapi::mkl::dft::descriptor& desc) { + auto libkey = get_device_id(desc.get_queue()); + return function_tables[libkey].create_commit_sycl_dr(desc); } } // namespace detail } // namespace dft } // namespace mkl -} // namespace oneapi +} // namespace oneapi \ No newline at end of file diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index aa7039496..d075cb2a8 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -35,7 +35,18 @@ typedef struct { int version; - oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl)(sycl::queue queue, oneapi::mkl::dft::dft_values values); + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fz)( + oneapi::mkl::dft::descriptor& desc); + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dz)( + oneapi::mkl::dft::descriptor& desc); + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fr)( + oneapi::mkl::dft::descriptor& desc); + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dr)( + oneapi::mkl::dft::descriptor& desc); } dft_function_table_t; #endif //_DFT_FUNCTION_TABLE_HPP_ From 6de0cdd010898d87e116a1660b8f34536f2848f8 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Tue, 22 Nov 2022 21:34:00 -0800 Subject: [PATCH 15/21] revert testing --- .../rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp b/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp index 7eb3ce83a..cdfd6c765 100644 --- a/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp +++ b/examples/rng/compile_time_dispatching/uniform_usm_mklcpu_curand.cpp @@ -98,7 +98,6 @@ void run_uniform_example(const sycl::device& cpu_dev, const sycl::device& gpu_de // preparation on CPU device and GPU device sycl::queue cpu_queue(cpu_dev, cpu_exception_handler); sycl::queue gpu_queue(gpu_dev, gpu_exception_handler); - oneapi::mkl::rng::default_engine test_engine(cpu_queue, seed); oneapi::mkl::rng::default_engine cpu_engine( oneapi::mkl::backend_selector{ cpu_queue }, seed); oneapi::mkl::rng::default_engine gpu_engine( From 9385c865fb5ffa606dcf291533b908b51c8f36dc Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Wed, 23 Nov 2022 02:22:54 -0800 Subject: [PATCH 16/21] fix linking issues; update backend table --- examples/dft/CMakeLists.txt | 12 +-- .../run_time_dispatching/complex_fwd_usm.cpp | 2 +- include/oneapi/mkl/detail/backends_table.hpp | 9 +- src/dft/CMakeLists.txt | 2 +- src/dft/backends/mklcpu/CMakeLists.txt | 2 +- src/dft/backends/mklcpu/commit.cpp | 100 ++++++++++-------- src/dft/descriptor.cpp | 5 +- 7 files changed, 74 insertions(+), 58 deletions(-) diff --git a/examples/dft/CMakeLists.txt b/examples/dft/CMakeLists.txt index e43bea36d..06bd70859 100644 --- a/examples/dft/CMakeLists.txt +++ b/examples/dft/CMakeLists.txt @@ -17,15 +17,11 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -# Note: compile-time example uses both MKLCPU and CURAND backends, therefore -# cmake in the sub-directory will only build it if CURAND backend is enabled -add_subdirectory(compile_time_dispatching) - # Note: compile-time example uses both MKLCPU and CUSOLVER backends, therefore # cmake in the sub-directory will only build it if CUSOLVER backend is enabled -# add_subdirectory(compile_time_dispatching) +add_subdirectory(compile_time_dispatching) # runtime compilation is only possible with dynamic libraries -# if (BUILD_SHARED_LIBS) -# add_subdirectory(run_time_dispatching) -# endif() +if (BUILD_SHARED_LIBS) + add_subdirectory(run_time_dispatching) +endif() diff --git a/examples/dft/run_time_dispatching/complex_fwd_usm.cpp b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp index 5b44d3442..d504fdb0d 100644 --- a/examples/dft/run_time_dispatching/complex_fwd_usm.cpp +++ b/examples/dft/run_time_dispatching/complex_fwd_usm.cpp @@ -105,7 +105,7 @@ int main(int argc, char** argv) { print_example_banner(); try { - sycl::device my_dev((sycl::default_selector())); + sycl::device my_dev((sycl::default_selector_v)); if (my_dev.is_gpu()) { std::cout << "Running DFT complex forward example on GPU device" << std::endl; diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index c51f2589b..9432093e2 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -74,7 +74,13 @@ static std::map>> libraries = } } } }, { domain::dft, - { { device::intelgpu, + { { device::x86cpu, + { +#ifdef ENABLE_MKLCPU_BACKEND + LIB_NAME("dft_mklcpu") +#endif + } }, + { device::intelgpu, { #ifdef ENABLE_MKLGPU_BACKEND LIB_NAME("dft_mklgpu") @@ -130,6 +136,7 @@ static std::map>> libraries = static std::map table_names = { { domain::blas, "mkl_blas_table" }, { domain::lapack, "mkl_lapack_table" }, + { domain::dft, "mkl_dft_table" }, { domain::rng, "mkl_rng_table" } }; } //namespace mkl diff --git a/src/dft/CMakeLists.txt b/src/dft/CMakeLists.txt index dc72a444d..7b7e3816c 100644 --- a/src/dft/CMakeLists.txt +++ b/src/dft/CMakeLists.txt @@ -32,7 +32,7 @@ target_include_directories(onemkl_dft $ ) -target_compile_options(onemkl_dft PRIVATE ${ONEMKL_BUILD_COPT}) +target_compile_options(onemkl_dft PRIVATE ${ONEMKL_BUILD_COPT} -DBUILD_RUN) set_target_properties(onemkl_dft PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index 266e8c0f3..a9f3647a4 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -39,7 +39,7 @@ target_include_directories(${LIB_OBJ} ${MKL_INCLUDE} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT} -DBUILD_COMP) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 65a1b34ae..54f88dbbd 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -38,98 +38,108 @@ namespace mkl { namespace dft { namespace mklcpu { +template class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { public: commit_derived_impl(sycl::queue queue, dft_values config_values) : oneapi::mkl::dft::detail::commit_impl(queue), status(-1) { - if (config_values.rank == 1) { - status = DftiCreateDescriptor(&handle, get_precision(config_values.precision), - get_domain(config_values.domain), config_values.rank, - config_values.dimensions[0]); - } else { - status = DftiCreateDescriptor(&handle, get_precision(config_values.precision), - get_domain(config_values.domain), config_values.rank, - config_values.dimensions.data()); + status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), + config_values.rank, config_values.dimensions[0]); + } + else { + status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), + config_values.rank, config_values.dimensions.data()); } - if(status != DFTI_NO_ERROR) { + if (status != DFTI_NO_ERROR) { throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); } set_value(handle, config_values); status = DftiCommitDescriptor(handle); - if(status != DFTI_NO_ERROR) { + if (status != DFTI_NO_ERROR) { throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); } } commit_derived_impl(const commit_derived_impl* other) - : oneapi::mkl::dft::detail::commit_impl(*other) { } + : oneapi::mkl::dft::detail::commit_impl(*other) {} - virtual ~commit_derived_impl() override { - DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE *) &handle); + virtual ~commit_derived_impl() override { + DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle); } private: DFT_ERROR status; DFTI_DESCRIPTOR_HANDLE handle = nullptr; - constexpr DFTI_CONFIG_VALUE get_domain(oneapi::mkl::dft::domain dom) { - if (dom == oneapi::mkl::dft::domain::COMPLEX) { + + constexpr DFTI_CONFIG_VALUE get_domain(oneapi::mkl::dft::domain d) { + if (d == oneapi::mkl::dft::domain::COMPLEX) { return DFTI_COMPLEX; - } else { + } + else { return DFTI_REAL; } } - constexpr DFTI_CONFIG_VALUE get_precision(oneapi::mkl::dft::precision prec) { - if (prec == oneapi::mkl::dft::precision::SINGLE) { + constexpr DFTI_CONFIG_VALUE get_precision(oneapi::mkl::dft::precision p) { + if (p == oneapi::mkl::dft::precision::SINGLE) { return DFTI_SINGLE; - } else { + } + else { return DFTI_DOUBLE; } } void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { - // TODO : add complex storage and workspace, fix error handling - status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - status |= DftiSetValue(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - status |= DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - status |= DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - status |= DftiSetValue(descHandle, DFTI_PLACEMENT, - (config.placement == oneapi::mkl::dft::config_value::INPLACE) - ? DFTI_INPLACE - : DFTI_NOT_INPLACE); - - if(status != DFTI_NO_ERROR) { + // TODO : add complex storage and workspace, fix error handling + status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); + status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); + status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); + status |= DftiSetValue(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); + status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + status |= DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); + status |= DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); + status |= DftiSetValue(descHandle, DFTI_PLACEMENT, + (config.placement == oneapi::mkl::dft::config_value::INPLACE) + ? DFTI_INPLACE + : DFTI_NOT_INPLACE); + + if (status != DFTI_NO_ERROR) { throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); } } }; oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor &desc -) { - return new commit_derived_impl(desc.get_queue(), desc.get_values()); + oneapi::mkl::dft::descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor &desc -) { - return new commit_derived_impl(desc.get_queue(), desc.get_values()); + oneapi::mkl::dft::descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor &desc -) { - return new commit_derived_impl(desc.get_queue(), desc.get_values()); + oneapi::mkl::dft::descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor &desc -) { - return new commit_derived_impl(desc.get_queue(), desc.get_values()); + oneapi::mkl::dft::descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } } // namespace mklcpu diff --git a/src/dft/descriptor.cpp b/src/dft/descriptor.cpp index 75b0763fc..805866342 100644 --- a/src/dft/descriptor.cpp +++ b/src/dft/descriptor.cpp @@ -31,12 +31,15 @@ namespace oneapi { namespace mkl { namespace dft { +#ifdef BUILD_RUN template void descriptor::commit(sycl::queue &queue) { queue_ = queue; pimpl_.reset(detail::create_commit(*this)); } +#endif +#ifdef BUILD_COMP #ifdef ENABLE_MKLCPU_BACKEND template void descriptor::commit(backend_selector selector) { @@ -52,6 +55,7 @@ void descriptor::commit(backend_selector selector) { // pimpl_.reset(mklgpu::create_commit(*this)); } #endif +#endif // impliment error class template @@ -59,7 +63,6 @@ void descriptor::set_value(config_param param, ...) { int err = 0; va_list vl; va_start(vl, param); - printf("oneapi interface set_value\n"); switch (param) { case config_param::INPUT_STRIDES: [[fallthrough]]; case config_param::OUTPUT_STRIDES: { From 96ac5de2d3282443992513858f3575530c062e43 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Wed, 23 Nov 2022 02:41:44 -0800 Subject: [PATCH 17/21] revert some old comments --- .../complex_fwd_usm_mklcpu.cpp | 23 +- src/dft/backends/mklcpu/backward.cpp | 323 ++++++++--------- src/dft/backends/mklcpu/commit.cpp | 60 ++-- src/dft/backends/mklcpu/forward.cpp | 329 +++++++++--------- 4 files changed, 360 insertions(+), 375 deletions(-) diff --git a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp index 486f31d19..bc7f149fa 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu.cpp @@ -77,23 +77,20 @@ void run_getrs_example(const sycl::device& cpu_device) { sycl::context cpu_context = cpu_queue.get_context(); sycl::event cpu_getrf_done; -double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); + double *x_usm = (double*) malloc_shared(N*2*sizeof(double), cpu_queue.get_device(), cpu_queue.get_context()); -// enabling -// 1. create descriptors -oneapi::mkl::dft::descriptor desc(N); + // enabling + // 1. create descriptors + oneapi::mkl::dft::descriptor desc(N); -// 2. variadic set_value -desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + // 2. variadic set_value + desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); -// 3. commit_descriptor (compile_time CPU) -desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); + // 3. commit_descriptor (compile_time CPU) + desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); -// 4. commit_descriptor (run_time xPU) unusable from libonemkl_dft_mklcpu.so -// desc.commit(cpu_queue); - -// 5. compute_forward / compute_backward (CPU) -// oneapi::mkl::dft::compute_forward(desc, x_usm); + // 5. compute_forward / compute_backward (CPU) + // oneapi::mkl::dft::compute_forward(desc, x_usm); } // diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp index 4cafbe549..9d226a3a0 100644 --- a/src/dft/backends/mklcpu/backward.cpp +++ b/src/dft/backends/mklcpu/backward.cpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" namespace oneapi { @@ -33,174 +34,174 @@ namespace mkl { namespace dft { namespace mklcpu { -// void compute_backward_buffer_inplace_f(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_c(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_d(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_z(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_backward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_backward_buffer_inplace_split_f(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_split_c(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_split_d(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_inplace_split_z(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_backward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_backward_buffer_outofplace_f(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_c(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_d(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_z(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_backward_buffer_outofplace_f(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_d(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_backward_buffer_outofplace_split_f(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_split_c( -// descriptor &desc, sycl::buffer &in_re, -// sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_split_d(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_backward_buffer_outofplace_split_z( -// descriptor &desc, sycl::buffer &in_re, -// sycl::buffer &in_im, sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_backward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_c( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_backward_buffer_outofplace_split_z( + descriptor &desc, sycl::buffer &in_re, + sycl::buffer &in_im, sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_backward_usm_inplace_f(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_c(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_d(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_z(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_backward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, -// float *inout_re, float *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_split_c( -// descriptor &desc, float *inout_re, float *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, -// double *inout_re, double *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_inplace_split_z( -// descriptor &desc, double *inout_re, double *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_backward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_backward_usm_outofplace_f(descriptor &desc, -// std::complex *in, float *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_c(descriptor &desc, -// std::complex *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_d(descriptor &desc, -// std::complex *in, double *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_z(descriptor &desc, -// std::complex *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_backward_usm_outofplace_f(descriptor &desc, + std::complex *in, float *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_d(descriptor &desc, + std::complex *in, double *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_backward_usm_outofplace_split_f( -// descriptor &desc, float *in_re, float *in_im, float *out_re, -// float *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_split_c( -// descriptor &desc, float *in_re, float *in_im, float *out_re, -// float *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_split_d( -// descriptor &desc, double *in_re, double *in_im, double *out_re, -// double *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_backward_usm_outofplace_split_z( -// descriptor &desc, double *in_re, double *in_im, -// double *out_re, double *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_backward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_backward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 54f88dbbd..974e72588 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -38,11 +38,11 @@ namespace mkl { namespace dft { namespace mklcpu { -template -class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { +template +class commit_derived_impl : public detail::commit_impl { public: commit_derived_impl(sycl::queue queue, dft_values config_values) - : oneapi::mkl::dft::detail::commit_impl(queue), + : detail::commit_impl(queue), status(-1) { if (config_values.rank == 1) { status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), @@ -64,8 +64,7 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } } - commit_derived_impl(const commit_derived_impl* other) - : oneapi::mkl::dft::detail::commit_impl(*other) {} + commit_derived_impl(const commit_derived_impl* other) : detail::commit_impl(*other) {} virtual ~commit_derived_impl() override { DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle); @@ -75,8 +74,8 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { DFT_ERROR status; DFTI_DESCRIPTOR_HANDLE handle = nullptr; - constexpr DFTI_CONFIG_VALUE get_domain(oneapi::mkl::dft::domain d) { - if (d == oneapi::mkl::dft::domain::COMPLEX) { + constexpr DFTI_CONFIG_VALUE get_domain(domain d) { + if (d == domain::COMPLEX) { return DFTI_COMPLEX; } else { @@ -84,8 +83,8 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } } - constexpr DFTI_CONFIG_VALUE get_precision(oneapi::mkl::dft::precision p) { - if (p == oneapi::mkl::dft::precision::SINGLE) { + constexpr DFTI_CONFIG_VALUE get_precision(precision p) { + if (p == precision::SINGLE) { return DFTI_SINGLE; } else { @@ -102,10 +101,9 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); status |= DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); status |= DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - status |= DftiSetValue(descHandle, DFTI_PLACEMENT, - (config.placement == oneapi::mkl::dft::config_value::INPLACE) - ? DFTI_INPLACE - : DFTI_NOT_INPLACE); + status |= DftiSetValue( + descHandle, DFTI_PLACEMENT, + (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); if (status != DFTI_NO_ERROR) { throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); @@ -113,33 +111,21 @@ class commit_derived_impl : public oneapi::mkl::dft::detail::commit_impl { } }; -oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); +detail::commit_impl* create_commit(descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } -oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); +detail::commit_impl* create_commit(descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } -oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); +detail::commit_impl* create_commit(descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } -oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); +detail::commit_impl* create_commit(descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), + desc.get_values()); } } // namespace mklcpu diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp index 634177398..1f35ca3df 100644 --- a/src/dft/backends/mklcpu/forward.cpp +++ b/src/dft/backends/mklcpu/forward.cpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" namespace oneapi { @@ -33,177 +34,177 @@ namespace mkl { namespace dft { namespace mklcpu { -// void compute_forward_buffer_inplace_f(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_c(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_d(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_z(descriptor &desc, -// sycl::buffer, 1> &inout) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_forward_buffer_inplace_f(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_c(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_d(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_z(descriptor &desc, + sycl::buffer, 1> &inout) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_forward_buffer_inplace_split_f(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_split_c(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_split_d(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_inplace_split_z(descriptor &desc, -// sycl::buffer &inout_re, -// sycl::buffer &inout_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_forward_buffer_inplace_split_f(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_c(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_d(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_inplace_split_z(descriptor &desc, + sycl::buffer &inout_re, + sycl::buffer &inout_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_forward_buffer_outofplace_f(descriptor &desc, -// sycl::buffer &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_c(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_d(descriptor &desc, -// sycl::buffer &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_z(descriptor &desc, -// sycl::buffer, 1> &in, -// sycl::buffer, 1> &out) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_forward_buffer_outofplace_f(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_c(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_d(descriptor &desc, + sycl::buffer &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_z(descriptor &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// void compute_forward_buffer_outofplace_split_f(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_split_c(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_split_d(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// void compute_forward_buffer_outofplace_split_z(descriptor &desc, -// sycl::buffer &in_re, -// sycl::buffer &in_im, -// sycl::buffer &out_re, -// sycl::buffer &out_im) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +void compute_forward_buffer_outofplace_split_f(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_c(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_d(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} +void compute_forward_buffer_outofplace_split_z(descriptor &desc, + sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_forward_usm_inplace_f(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_c(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_d(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_z(descriptor &desc, -// std::complex *inout, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_forward_usm_inplace_f(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_c(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_d(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_z(descriptor &desc, + std::complex *inout, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, -// float *inout_re, float *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_split_c( -// descriptor &desc, float *inout_re, float *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, -// double *inout_re, double *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_inplace_split_z( -// descriptor &desc, double *inout_re, double *inout_im, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_forward_usm_inplace_split_f(descriptor &desc, + float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_c( + descriptor &desc, float *inout_re, float *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_d(descriptor &desc, + double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_inplace_split_z( + descriptor &desc, double *inout_re, double *inout_im, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_forward_usm_outofplace_f(descriptor &desc, -// float *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_c(descriptor &desc, -// std::complex *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_d(descriptor &desc, -// double *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_z(descriptor &desc, -// std::complex *in, std::complex *out, -// const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_forward_usm_outofplace_f(descriptor &desc, + float *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_c(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_d(descriptor &desc, + double *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_z(descriptor &desc, + std::complex *in, std::complex *out, + const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} -// sycl::event compute_forward_usm_outofplace_split_f( -// descriptor &desc, float *in_re, float *in_im, float *out_re, -// float *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_split_c( -// descriptor &desc, float *in_re, float *in_im, float *out_re, -// float *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_split_d( -// descriptor &desc, double *in_re, double *in_im, double *out_re, -// double *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } -// sycl::event compute_forward_usm_outofplace_split_z( -// descriptor &desc, double *in_re, double *in_im, -// double *out_re, double *out_im, const std::vector &dependencies) { -// throw std::runtime_error("Not implemented for mklcpu"); -// } +sycl::event compute_forward_usm_outofplace_split_f( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_c( + descriptor &desc, float *in_re, float *in_im, float *out_re, + float *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_d( + descriptor &desc, double *in_re, double *in_im, double *out_re, + double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} +sycl::event compute_forward_usm_outofplace_split_z( + descriptor &desc, double *in_re, double *in_im, + double *out_re, double *out_im, const std::vector &dependencies) { + throw std::runtime_error("Not implemented for mklcpu"); +} } // namespace mklcpu } // namespace dft From 04ed345d100eeb1722be7cdc73c994b920ac1415 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Wed, 30 Nov 2022 03:17:23 -0800 Subject: [PATCH 18/21] further comments --- include/oneapi/mkl/dft/descriptor.hpp | 2 -- include/oneapi/mkl/dft/detail/commit_impl.hpp | 6 ---- include/oneapi/mkl/dft/detail/dft_loader.hpp | 1 - .../dft/detail/mklcpu/onemkl_dft_mklcpu.hpp | 4 +-- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 2 -- include/oneapi/mkl/dft/types.hpp | 17 +++++----- src/dft/backends/mklcpu/commit.cpp | 34 ++++++++++++------- src/dft/descriptor.cpp | 15 +++----- 8 files changed, 35 insertions(+), 46 deletions(-) diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index f9c734813..586d8b7db 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -68,10 +68,8 @@ class descriptor { sycl::queue queue_; std::int64_t rank_; - std::vector dimensions_; // descriptor configuration values_ and structs - void* handle_; oneapi::mkl::dft::dft_values values_; }; diff --git a/include/oneapi/mkl/dft/detail/commit_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp index 06c88084a..93c201f8b 100644 --- a/include/oneapi/mkl/dft/detail/commit_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -27,12 +27,6 @@ #include #endif -#include "oneapi/mkl/detail/export.hpp" -#include "oneapi/mkl/detail/get_device_id.hpp" -#include "oneapi/mkl/dft/types.hpp" - -#include "oneapi/mkl/types.hpp" - namespace oneapi { namespace mkl { namespace dft { diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index 170292eff..d0d33682e 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -28,7 +28,6 @@ #endif #include "oneapi/mkl/detail/export.hpp" -#include "oneapi/mkl/detail/get_device_id.hpp" #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index 53fb764c5..cc53376b0 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -28,11 +28,9 @@ #endif #include "oneapi/mkl/detail/export.hpp" -#include "oneapi/mkl/detail/get_device_id.hpp" #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" -#include "oneapi/mkl/dft/detail/commit_impl.hpp" #include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { @@ -61,4 +59,4 @@ ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( } // namespace mkl } // namespace oneapi -#endif // _ONEMKL_DFT_MKLCPU_HPP_ \ No newline at end of file +#endif // _ONEMKL_DFT_MKLCPU_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index 569efefb3..0fbda4aae 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -28,11 +28,9 @@ #endif #include "oneapi/mkl/detail/export.hpp" -#include "oneapi/mkl/detail/get_device_id.hpp" #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/dft/types.hpp" -#include "oneapi/mkl/dft/detail/commit_impl.hpp" #include "oneapi/mkl/dft/descriptor.hpp" namespace oneapi { diff --git a/include/oneapi/mkl/dft/types.hpp b/include/oneapi/mkl/dft/types.hpp index c3debd94d..b811278dc 100644 --- a/include/oneapi/mkl/dft/types.hpp +++ b/include/oneapi/mkl/dft/types.hpp @@ -33,6 +33,8 @@ namespace dft { typedef int DFT_ERROR; +#define DFT_NOTSET -1 + enum class precision { SINGLE, DOUBLE }; enum class domain { REAL, COMPLEX }; enum class config_param { @@ -87,19 +89,14 @@ enum class config_value { // Allow/avoid certain usages ALLOW, AVOID, - NONE, - + NONE, + WORKSPACE_INTERNAL, + WORKSPACE_EXTERNAL, // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers CCE_FORMAT }; -static std::unordered_map prec_map{ { precision::SINGLE, "SINGLE" }, - { precision::DOUBLE, "DOUBLE" } }; - -static std::unordered_map dom_map{ { domain::REAL, "REAL" }, - { domain::COMPLEX, "COMPLEX" } }; - struct dft_values { std::vector input_strides; std::vector output_strides; @@ -111,14 +108,16 @@ struct dft_values { config_value placement; config_value complex_storage; config_value conj_even_storage; + config_value workspace; std::vector dimensions; std::int64_t rank; domain domain; precision precision; }; + } // namespace dft } // namespace mkl } // namespace oneapi -#endif //_ONEMKL_TYPES_HPP_ \ No newline at end of file +#endif //_ONEMKL_TYPES_HPP_ diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 974e72588..850b93174 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -43,7 +43,7 @@ class commit_derived_impl : public detail::commit_impl { public: commit_derived_impl(sycl::queue queue, dft_values config_values) : detail::commit_impl(queue), - status(-1) { + status(DFT_NOTSET) { if (config_values.rank == 1) { status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), config_values.rank, config_values.dimensions[0]); @@ -64,8 +64,6 @@ class commit_derived_impl : public detail::commit_impl { } } - commit_derived_impl(const commit_derived_impl* other) : detail::commit_impl(*other) {} - virtual ~commit_derived_impl() override { DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle); } @@ -92,21 +90,31 @@ class commit_derived_impl : public detail::commit_impl { } } + template + DFT_ERROR set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) { + DFT_ERROR value_err = DFT_NOTSET; + value_err = DftiSetValue(hand, name, args...); + if (value_err != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft", "set_value_item", std::to_string(name)); + } + + return value_err; + } + void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { - // TODO : add complex storage and workspace, fix error handling - status |= DftiSetValue(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - status |= DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - status |= DftiSetValue(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - status |= DftiSetValue(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - status |= DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - status |= DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - status |= DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - status |= DftiSetValue( + status |= set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); + status |= set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); + status |= set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); + status |= set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); + status |= set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + status |= set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); + status |= set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); + status |= set_value_item( descHandle, DFTI_PLACEMENT, (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception("dft", "commit", "DftiSetValue failed"); + throw oneapi::mkl::exception("dft", "set_value", "failed"); } } }; diff --git a/src/dft/descriptor.cpp b/src/dft/descriptor.cpp index 805866342..986eb86f0 100644 --- a/src/dft/descriptor.cpp +++ b/src/dft/descriptor.cpp @@ -57,7 +57,6 @@ void descriptor::commit(backend_selector selector) { #endif #endif -// impliment error class template void descriptor::set_value(config_param param, ...) { int err = 0; @@ -94,13 +93,11 @@ void descriptor::set_value(config_param param, ...) { } template descriptor::descriptor(std::vector dimensions) - : dimensions_(dimensions), - handle_(nullptr), - rank_(dimensions.size()) { + : rank_(dimensions.size()) { // Compute default strides. std::vector defaultStrides(rank_, 1); for (int i = rank_ - 1; i < 0; --i) { - defaultStrides[i] = defaultStrides[i - 1] * dimensions_[i]; + defaultStrides[i] = defaultStrides[i - 1] * dimensions[i]; } defaultStrides[0] = 0; values_.input_strides = defaultStrides; @@ -113,7 +110,7 @@ descriptor::descriptor(std::vector dimensions) values_.placement = config_value::INPLACE; values_.complex_storage = config_value::COMPLEX_COMPLEX; values_.conj_even_storage = config_value::COMPLEX_COMPLEX; - values_.dimensions = dimensions_; + values_.dimensions = dimensions; values_.rank = rank_; values_.domain = dom; values_.precision = prec; @@ -124,9 +121,7 @@ descriptor::descriptor(std::int64_t length) : descriptor(std::vector{ length }) {} template -descriptor::~descriptor() { - // call DftiFreeDescriptor -} +descriptor::~descriptor() { } template void descriptor::get_value(config_param param, ...) { @@ -146,4 +141,4 @@ template class descriptor; } //namespace dft } //namespace mkl -} //namespace oneapi \ No newline at end of file +} //namespace oneapi From 820118e99c6b5df1c8afa0fab9f2517686252593 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Mon, 5 Dec 2022 10:02:39 -0800 Subject: [PATCH 19/21] refactor descriptor source --- src/dft/CMakeLists.txt | 4 +-- src/dft/backends/descriptor.cpp | 41 +++++++++++++++++++++ src/dft/backends/mklcpu/CMakeLists.txt | 2 +- src/dft/backends/mklcpu/commit.cpp | 20 +++++------ src/dft/backends/mklcpu/descriptor.cpp | 42 ++++++++++++++++++++++ src/dft/backends/mklgpu/CMakeLists.txt | 2 +- src/dft/backends/mklgpu/descriptor.cpp | 42 ++++++++++++++++++++++ src/dft/{descriptor.cpp => descriptor.cxx} | 42 +++------------------- src/dft/dft_loader.cpp | 2 +- 9 files changed, 142 insertions(+), 55 deletions(-) create mode 100644 src/dft/backends/descriptor.cpp create mode 100644 src/dft/backends/mklcpu/descriptor.cpp create mode 100644 src/dft/backends/mklgpu/descriptor.cpp rename src/dft/{descriptor.cpp => descriptor.cxx} (79%) diff --git a/src/dft/CMakeLists.txt b/src/dft/CMakeLists.txt index 7b7e3816c..65c966bbf 100644 --- a/src/dft/CMakeLists.txt +++ b/src/dft/CMakeLists.txt @@ -23,7 +23,7 @@ add_subdirectory(backends) # Recipe for DFT loader object if(BUILD_SHARED_LIBS) add_library(onemkl_dft OBJECT) -target_sources(onemkl_dft PRIVATE descriptor.cpp dft_loader.cpp) +target_sources(onemkl_dft PRIVATE backends/descriptor.cpp dft_loader.cpp) target_include_directories(onemkl_dft PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src @@ -38,7 +38,7 @@ set_target_properties(onemkl_dft PROPERTIES POSITION_INDEPENDENT_CODE ON ) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) - add_sycl_to_target(TARGET onemkl_dft SOURCES descriptor.cpp dft_loader.cpp) + add_sycl_to_target(TARGET onemkl_dft SOURCES backends/descriptor.cxx dft_loader.cpp) else() target_link_libraries(onemkl_dft PUBLIC ONEMKL::SYCL::SYCL) endif() diff --git a/src/dft/backends/descriptor.cpp b/src/dft/backends/descriptor.cpp new file mode 100644 index 000000000..e79ef2230 --- /dev/null +++ b/src/dft/backends/descriptor.cpp @@ -0,0 +1,41 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "oneapi/mkl/dft/detail/dft_loader.hpp" + +#include "../descriptor.cxx" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(sycl::queue &queue) { + queue_ = queue; + pimpl_.reset(detail::create_commit(*this)); +} +template void descriptor::commit(sycl::queue &); +template void descriptor::commit(sycl::queue &); +template void descriptor::commit(sycl::queue &); +template void descriptor::commit(sycl::queue &); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index a9f3647a4..52979f286 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -26,7 +26,7 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT commit.cpp - ../../descriptor.cpp + descriptor.cpp forward.cpp backward.cpp $<$: mkl_dft_cpu_wrappers.cpp> diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 850b93174..44433f9a5 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -102,20 +102,16 @@ class commit_derived_impl : public detail::commit_impl { } void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, dft_values config) { - status |= set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - status |= set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - status |= set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - status |= set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - status |= set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - status |= set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - status |= set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - status |= set_value_item( + set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); + set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); + set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); + set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); + set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); + set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); + set_value_item( descHandle, DFTI_PLACEMENT, (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); - - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception("dft", "set_value", "failed"); - } } }; diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp new file mode 100644 index 000000000..c912bf0f7 --- /dev/null +++ b/src/dft/backends/mklcpu/descriptor.cpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(backend_selector selector) { + queue_ = selector.get_queue(); + pimpl_.reset(mklcpu::create_commit(*this)); +} + +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index ca9cb9d09..5e0738b19 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -24,7 +24,7 @@ find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT - ../../descriptor.cpp + descriptor.cpp commit.cpp forward.cpp backward.cpp diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp new file mode 100644 index 000000000..6bebfdef4 --- /dev/null +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright 2022 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(backend_selector selector) { + queue_ = selector.get_queue(); + pimpl_.reset(mklgpu::create_commit(*this)); +} + +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); +template void descriptor::commit(backend_selector); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/descriptor.cpp b/src/dft/descriptor.cxx similarity index 79% rename from src/dft/descriptor.cpp rename to src/dft/descriptor.cxx index 986eb86f0..312805e96 100644 --- a/src/dft/descriptor.cpp +++ b/src/dft/descriptor.cxx @@ -19,44 +19,10 @@ #include "oneapi/mkl/dft/descriptor.hpp" -#ifdef ENABLE_MKLCPU_BACKEND -#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" -#endif -#ifdef ENABLE_MKLGPU_BACKEND -#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" -#endif -#include "oneapi/mkl/dft/detail/dft_loader.hpp" - namespace oneapi { namespace mkl { namespace dft { -#ifdef BUILD_RUN -template -void descriptor::commit(sycl::queue &queue) { - queue_ = queue; - pimpl_.reset(detail::create_commit(*this)); -} -#endif - -#ifdef BUILD_COMP -#ifdef ENABLE_MKLCPU_BACKEND -template -void descriptor::commit(backend_selector selector) { - queue_ = selector.get_queue(); - pimpl_.reset(mklcpu::create_commit(*this)); -} -#endif - -#ifdef ENABLE_MKLGPU_BACKEND -template -void descriptor::commit(backend_selector selector) { - queue_ = selector.get_queue(); - // pimpl_.reset(mklgpu::create_commit(*this)); -} -#endif -#endif - template void descriptor::set_value(config_param param, ...) { int err = 0; @@ -72,7 +38,8 @@ void descriptor::set_value(config_param param, ...) { std::copy(strides, strides + rank_ + 1, std::back_inserter(values_.input_strides)); if (param == config_param::OUTPUT_STRIDES) std::copy(strides, strides + rank_ + 1, std::back_inserter(values_.output_strides)); - } break; + break; + } case config_param::FORWARD_SCALE: values_.fwd_scale = va_arg(vl, double); break; case config_param::BACKWARD_SCALE: values_.bwd_scale = va_arg(vl, double); break; case config_param::NUMBER_OF_TRANSFORMS: @@ -92,8 +59,7 @@ void descriptor::set_value(config_param param, ...) { va_end(vl); } template -descriptor::descriptor(std::vector dimensions) - : rank_(dimensions.size()) { +descriptor::descriptor(std::vector dimensions) : rank_(dimensions.size()) { // Compute default strides. std::vector defaultStrides(rank_, 1); for (int i = rank_ - 1; i < 0; --i) { @@ -121,7 +87,7 @@ descriptor::descriptor(std::int64_t length) : descriptor(std::vector{ length }) {} template -descriptor::~descriptor() { } +descriptor::~descriptor() {} template void descriptor::get_value(config_param param, ...) { diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index 84523e533..296c051fc 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -58,4 +58,4 @@ commit_impl* create_commit(oneapi::mkl::dft::descriptor Date: Mon, 5 Dec 2022 11:12:46 -0800 Subject: [PATCH 20/21] efficient resource handle --- include/oneapi/mkl/dft/descriptor.hpp | 2 ++ src/dft/descriptor.cxx | 32 ++++++++++++++++++--------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/include/oneapi/mkl/dft/descriptor.hpp b/include/oneapi/mkl/dft/descriptor.hpp index 586d8b7db..0d2941b0e 100644 --- a/include/oneapi/mkl/dft/descriptor.hpp +++ b/include/oneapi/mkl/dft/descriptor.hpp @@ -68,6 +68,8 @@ class descriptor { sycl::queue queue_; std::int64_t rank_; + std::vector dimensions_; + // descriptor configuration values_ and structs oneapi::mkl::dft::dft_values values_; diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx index 312805e96..2dd3d0a78 100644 --- a/src/dft/descriptor.cxx +++ b/src/dft/descriptor.cxx @@ -39,15 +39,25 @@ void descriptor::set_value(config_param param, ...) { if (param == config_param::OUTPUT_STRIDES) std::copy(strides, strides + rank_ + 1, std::back_inserter(values_.output_strides)); break; - } - case config_param::FORWARD_SCALE: values_.fwd_scale = va_arg(vl, double); break; - case config_param::BACKWARD_SCALE: values_.bwd_scale = va_arg(vl, double); break; + } + case config_param::FORWARD_SCALE: + values_.fwd_scale = va_arg(vl, double); + break; + case config_param::BACKWARD_SCALE: + values_.bwd_scale = va_arg(vl, double); + break; case config_param::NUMBER_OF_TRANSFORMS: values_.number_of_transforms = va_arg(vl, int64_t); break; - case config_param::FWD_DISTANCE: values_.fwd_dist = va_arg(vl, int64_t); break; - case config_param::BWD_DISTANCE: values_.bwd_dist = va_arg(vl, int64_t); break; - case config_param::PLACEMENT: values_.placement = va_arg(vl, config_value); break; + case config_param::FWD_DISTANCE: + values_.fwd_dist = va_arg(vl, int64_t); + break; + case config_param::BWD_DISTANCE: + values_.bwd_dist = va_arg(vl, int64_t); + break; + case config_param::PLACEMENT: + values_.placement = va_arg(vl, config_value); + break; case config_param::COMPLEX_STORAGE: values_.complex_storage = va_arg(vl, config_value); break; @@ -59,15 +69,17 @@ void descriptor::set_value(config_param param, ...) { va_end(vl); } template -descriptor::descriptor(std::vector dimensions) : rank_(dimensions.size()) { +descriptor::descriptor(std::vector dimensions) + : dimensions_(std::move(dimensions)), + rank_(dimensions.size()) { // Compute default strides. std::vector defaultStrides(rank_, 1); for (int i = rank_ - 1; i < 0; --i) { - defaultStrides[i] = defaultStrides[i - 1] * dimensions[i]; + defaultStrides[i] = defaultStrides[i - 1] * dimensions_[i]; } defaultStrides[0] = 0; values_.input_strides = defaultStrides; - values_.output_strides = defaultStrides; + values_.output_strides = std::move(defaultStrides); values_.bwd_scale = 1.0; values_.fwd_scale = 1.0; values_.number_of_transforms = 1; @@ -76,7 +88,7 @@ descriptor::descriptor(std::vector dimensions) : rank_( values_.placement = config_value::INPLACE; values_.complex_storage = config_value::COMPLEX_COMPLEX; values_.conj_even_storage = config_value::COMPLEX_COMPLEX; - values_.dimensions = dimensions; + values_.dimensions = dimensions_; values_.rank = rank_; values_.domain = dom; values_.precision = prec; From aa22374b2296e54167d0a0030e10f60a153dd4a7 Mon Sep 17 00:00:00 2001 From: "Anant, Srivastava" Date: Tue, 6 Dec 2022 01:13:54 -0800 Subject: [PATCH 21/21] create_commit explicit instance --- .../dft/detail/mklcpu/onemkl_dft_mklcpu.hpp | 16 ++--------- .../dft/detail/mklgpu/onemkl_dft_mklgpu.hpp | 16 ++--------- src/dft/backends/mklcpu/commit.cpp | 28 ++++++++----------- 3 files changed, 15 insertions(+), 45 deletions(-) diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index cc53376b0..005591836 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -38,21 +38,9 @@ namespace mkl { namespace dft { namespace mklcpu { +template ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); + oneapi::mkl::dft::descriptor& desc); } // namespace mklcpu } // namespace dft diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index 0fbda4aae..6c49dc833 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -38,21 +38,9 @@ namespace mkl { namespace dft { namespace mklgpu { +template ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); - -ONEMKL_EXPORT oneapi::mkl::dft::detail::commit_impl* create_commit( - oneapi::mkl::dft::descriptor& desc); + oneapi::mkl::dft::descriptor& desc); } // namespace mklgpu } // namespace dft diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 44433f9a5..425d94e43 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -90,8 +90,9 @@ class commit_derived_impl : public detail::commit_impl { } } - template - DFT_ERROR set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) { + template + DFT_ERROR set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, + Args... args) { DFT_ERROR value_err = DFT_NOTSET; value_err = DftiSetValue(hand, name, args...); if (value_err != DFTI_NO_ERROR) { @@ -115,23 +116,16 @@ class commit_derived_impl : public detail::commit_impl { } }; -detail::commit_impl* create_commit(descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); -} -detail::commit_impl* create_commit(descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); -} -detail::commit_impl* create_commit(descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); -} -detail::commit_impl* create_commit(descriptor& desc) { - return new commit_derived_impl(desc.get_queue(), - desc.get_values()); +template +detail::commit_impl* create_commit(descriptor& desc) { + return new commit_derived_impl(desc.get_queue(), desc.get_values()); } +template detail::commit_impl* create_commit(descriptor&); +template detail::commit_impl* create_commit(descriptor&); +template detail::commit_impl* create_commit(descriptor&); +template detail::commit_impl* create_commit(descriptor&); + } // namespace mklcpu } // namespace dft } // namespace mkl