diff --git a/WORKSPACE b/WORKSPACE index 49f8f8e6654d..f97610efe249 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -27,8 +27,8 @@ http_archive( # path = "/path/to/tensorflow", # ) -load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo") -pocketfft() +load("//third_party/ducc:workspace.bzl", ducc = "repo") +ducc() # Initialize TensorFlow's external dependencies. load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") diff --git a/build/build_wheel.py b/build/build_wheel.py index 721819548d17..adf478d0cd0f 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -171,8 +171,8 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/lapack.py") copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}") copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py") - copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}") - copy_to_jaxlib("__main__/jaxlib/pocketfft.py") + copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}") + copy_to_jaxlib("__main__/jaxlib/ducc_fft.py") copy_to_jaxlib("__main__/jaxlib/gpu_prng.py") copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") copy_to_jaxlib("__main__/jaxlib/gpu_solver.py") diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 6f43f9243003..5a28569b5a43 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -31,6 +31,9 @@ from jax._src.lib.mlir.dialects import mhlo from jax._src.lib import mlir_api_version from jax._src.lib import xla_client +# TODO(phawkins): remove pocketfft references when the minimum jaxlib version +# is 0.3.17 or newer. +from jax._src.lib import ducc_fft from jax._src.lib import pocketfft from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact @@ -123,8 +126,12 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths): def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths): x_aval, = ctx.avals_in - return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type, + if ducc_fft: + return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type, fft_lengths=fft_lengths)] + else: + return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type, + fft_lengths=fft_lengths)] def _naive_rfft(x, fft_lengths): @@ -188,5 +195,4 @@ def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths): mlir.register_lowering(fft_p, _fft_lowering) ad.deflinear2(fft_p, _fft_transpose_rule) batching.primitive_batchers[fft_p] = _fft_batching_rule -if pocketfft: - mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu') +mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu') diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index a086503af9a0..97a77ffd6909 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -99,7 +99,17 @@ def _parse_version(v: str) -> Tuple[int, ...]: import jaxlib.xla_client as xla_client import jaxlib.lapack as lapack -import jaxlib.pocketfft as pocketfft + +# TODO(phawkins): remove pocketfft references when the minimum jaxlib version +# is 0.3.17 or newer. +try: + import jaxlib.pocketfft as pocketfft +except ImportError: + pocketfft = None +try: + import jaxlib.ducc_fft as ducc_fft +except ImportError: + ducc_fft = None xla_extension = xla_client._xla pytree = xla_client._xla.pytree diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 066b353aaa83..41baecd52c98 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -29,6 +29,7 @@ package(default_visibility = ["//:__subpackages__"]) py_library( name = "jaxlib", srcs = [ + "ducc_fft.py", "gpu_linalg.py", "gpu_prng.py", "gpu_solver.py", @@ -36,14 +37,13 @@ py_library( "init.py", "lapack.py", "mhlo_helpers.py", - "pocketfft.py", ":version", ":xla_client", ], data = [":xla_extension"], deps = [ + ":_ducc_fft", ":_lapack", - ":_pocketfft", ":cpu_feature_guard", "//jaxlib/mlir", "//jaxlib/mlir:builtin_dialect", @@ -148,8 +148,8 @@ cc_library( srcs = ["lapack_kernels.cc"], hdrs = ["lapack_kernels.h"], deps = [ - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", "@com_google_absl//absl/base:dynamic_annotations", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", ], ) @@ -176,40 +176,40 @@ pybind_extension( ], ) -# PocketFFT +# DUCC (CPU FFTs) flatbuffer_cc_library( - name = "pocketfft_flatbuffers_cc", - srcs = ["pocketfft.fbs"], + name = "ducc_fft_flatbuffers_cc", + srcs = ["ducc_fft.fbs"], ) cc_library( - name = "pocketfft_kernels", - srcs = ["pocketfft_kernels.cc"], - hdrs = ["pocketfft_kernels.h"], + name = "ducc_fft_kernels", + srcs = ["ducc_fft_kernels.cc"], + hdrs = ["ducc_fft_kernels.h"], copts = ["-fexceptions"], # PocketFFT may throw. features = ["-use_header_modules"], deps = [ - ":pocketfft_flatbuffers_cc", - "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", + ":ducc_fft_flatbuffers_cc", + "@ducc", "@flatbuffers//:runtime_cc", - "@pocketfft", + "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status", ], ) pybind_extension( - name = "_pocketfft", - srcs = ["pocketfft.cc"], + name = "_ducc_fft", + srcs = ["ducc_fft.cc"], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - module_name = "_pocketfft", + module_name = "_ducc_fft", deps = [ + ":ducc_fft_flatbuffers_cc", + ":ducc_fft_kernels", ":kernel_pybind11_helpers", - ":pocketfft_flatbuffers_cc", - ":pocketfft_kernels", "@flatbuffers//:runtime_cc", "@pybind11", ], @@ -220,9 +220,9 @@ cc_library( srcs = ["cpu_kernels.cc"], visibility = ["//visibility:public"], deps = [ + ":ducc_kernels", ":lapack_kernels", ":lapack_kernels_using_lapack", - ":pocketfft_kernels", "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry", ], alwayslink = 1, diff --git a/jaxlib/pocketfft.cc b/jaxlib/ducc_fft.cc similarity index 57% rename from jaxlib/pocketfft.cc rename to jaxlib/ducc_fft.cc index b8d09a61cf92..fef8d129aed6 100644 --- a/jaxlib/pocketfft.cc +++ b/jaxlib/ducc_fft.cc @@ -16,53 +16,53 @@ limitations under the License. #include #include -#include "jaxlib/kernel_pybind11_helpers.h" -#include "jaxlib/pocketfft_generated.h" -#include "jaxlib/pocketfft_kernels.h" #include "include/pybind11/pybind11.h" #include "include/pybind11/stl.h" +#include "jaxlib/ducc_fft_generated.h" +#include "jaxlib/ducc_fft_kernels.h" +#include "jaxlib/kernel_pybind11_helpers.h" namespace py = pybind11; namespace jax { namespace { -py::bytes BuildPocketFftDescriptor(const std::vector& shape, - bool is_double, int fft_type, - const std::vector& fft_lengths, - const std::vector& strides_in, - const std::vector& strides_out, - const std::vector& axes, - bool forward, double scale) { - PocketFftDescriptorT descriptor; +py::bytes BuildDuccFftDescriptor(const std::vector &shape, + bool is_double, int fft_type, + const std::vector &fft_lengths, + const std::vector &strides_in, + const std::vector &strides_out, + const std::vector &axes, + bool forward, double scale) { + DuccFftDescriptorT descriptor; descriptor.shape = shape; - descriptor.fft_type = static_cast(fft_type); + descriptor.fft_type = static_cast(fft_type); descriptor.dtype = - is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64; + is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64; descriptor.strides_in = strides_in; descriptor.strides_out = strides_out; descriptor.axes = axes; descriptor.forward = forward; descriptor.scale = scale; flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor)); - return py::bytes(reinterpret_cast(fbb.GetBufferPointer()), + fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor)); + return py::bytes(reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); } py::dict Registrations() { pybind11::dict dict; - dict["pocketfft"] = EncapsulateFunction(PocketFft); + dict["ducc_fft"] = EncapsulateFunction(DuccFft); return dict; } -PYBIND11_MODULE(_pocketfft, m) { +PYBIND11_MODULE(_ducc_fft, m) { m.def("registrations", &Registrations); - m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"), + m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"), py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"), py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"), py::arg("forward"), py::arg("scale")); } -} // namespace -} // namespace jax +} // namespace +} // namespace jax diff --git a/jaxlib/pocketfft.fbs b/jaxlib/ducc_fft.fbs similarity index 83% rename from jaxlib/pocketfft.fbs rename to jaxlib/ducc_fft.fbs index 8d3558eb52a6..2951d395483b 100644 --- a/jaxlib/pocketfft.fbs +++ b/jaxlib/ducc_fft.fbs @@ -15,20 +15,20 @@ limitations under the License. namespace jax; -enum PocketFftDtype : byte { +enum DuccFftDtype : byte { COMPLEX64 = 0, COMPLEX128 = 1, } -enum PocketFftType : byte { +enum DuccFftType : byte { C2C = 0, C2R = 1, R2C = 2, } -table PocketFftDescriptor { - dtype:PocketFftDtype; - fft_type:PocketFftType; +table DuccFftDescriptor { + dtype:DuccFftDtype; + fft_type:DuccFftType; shape:[uint64]; strides_in:[uint64]; strides_out:[uint64]; @@ -37,4 +37,4 @@ table PocketFftDescriptor { scale:double; } -root_type PocketFftDescriptor; +root_type DuccFftDescriptor; diff --git a/jaxlib/pocketfft.py b/jaxlib/ducc_fft.py similarity index 87% rename from jaxlib/pocketfft.py rename to jaxlib/ducc_fft.py index 5c1907f78f65..c9f76ab1d062 100644 --- a/jaxlib/pocketfft.py +++ b/jaxlib/ducc_fft.py @@ -19,12 +19,12 @@ from .mhlo_helpers import custom_call -from . import _pocketfft +from . import _ducc_fft import numpy as np from jaxlib import xla_client -for _name, _value in _pocketfft.registrations().items(): +for _name, _value in _ducc_fft.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="cpu") FftType = xla_client.FftType @@ -34,7 +34,7 @@ _C2R = 1 _R2C = 2 -def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, +def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType, fft_lengths: List[int]) -> bytes: n = len(shape) assert len(fft_lengths) >= 1 @@ -44,7 +44,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, forward = fft_type in (FftType.FFT, FftType.RFFT) is_double = np.finfo(dtype).dtype == np.float64 if fft_type == FftType.RFFT: - pocketfft_type = _R2C + ducc_fft_type = _R2C assert dtype in (np.float32, np.float64), dtype out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128) @@ -54,7 +54,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, out_shape[-1] = out_shape[-1] // 2 + 1 elif fft_type == FftType.IRFFT: - pocketfft_type = _C2R + ducc_fft_type = _C2R assert np.issubdtype(dtype, np.complexfloating), dtype out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64) @@ -64,7 +64,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, out_shape[-1] = fft_lengths[-1] assert (out_shape[-1] // 2 + 1) == shape[-1] else: - pocketfft_type = _C2C + ducc_fft_type = _C2C assert np.issubdtype(dtype, np.complexfloating), dtype out_dtype = dtype @@ -79,13 +79,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the # C++ kernel to describe the FFT to perform. strides_in = [] - stride = dtype.itemsize + stride = 1 for d in reversed(shape): strides_in.append(stride) stride *= d strides_out = [] - stride = out_dtype.itemsize + stride = 1 for d in reversed(out_shape): strides_out.append(stride) stride *= d @@ -93,10 +93,10 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))] scale = 1. if forward else (1. / np.prod(fft_lengths)) - descriptor = _pocketfft.pocketfft_descriptor( + descriptor = _ducc_fft.ducc_fft_descriptor( shape=shape if fft_type != FftType.IRFFT else out_shape, is_double=is_double, - fft_type=pocketfft_type, + fft_type=ducc_fft_type, fft_lengths=fft_lengths, strides_in=list(reversed(strides_in)), strides_out=list(reversed(strides_out)), @@ -107,13 +107,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType, return descriptor, out_dtype, out_shape -def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): - """PocketFFT kernel for CPU.""" +def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): + """DUCC FFT kernel for CPU.""" a_type = ir.RankedTensorType(a.type) n = len(a_type.shape) fft_lengths = list(fft_lengths) - descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor( + descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor( list(a_type.shape), dtype, fft_type, fft_lengths) if out_dtype == np.float32: @@ -141,7 +141,7 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)) layout = tuple(range(n - 1, -1, -1)) return custom_call( - "pocketfft", + "ducc_fft", [ir.RankedTensorType.get(out_shape, out_type)], [descriptor, a], operand_layouts=[[0], layout], diff --git a/jaxlib/pocketfft_kernels.cc b/jaxlib/ducc_fft_kernels.cc similarity index 74% rename from jaxlib/pocketfft_kernels.cc rename to jaxlib/ducc_fft_kernels.cc index 7d2a967052c7..c5ce67785a46 100644 --- a/jaxlib/pocketfft_kernels.cc +++ b/jaxlib/ducc_fft_kernels.cc @@ -16,27 +16,17 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" -#include "jaxlib/pocketfft_generated.h" -#include "pocketfft/src/ducc0/fft/fft.h" +#include "jaxlib/ducc_fft_generated.h" #include "tensorflow/compiler/xla/service/custom_call_status.h" +#include "third_party/ducc/src/ducc0/fft/fft.h" namespace jax { using shape_t = ducc0::fmav_info::shape_t; using stride_t = ducc0::fmav_info::stride_t; -void fixstrides(stride_t &str, size_t size) { - ptrdiff_t ssize = ptrdiff_t(size); - for (auto &s : str) { - auto tmp = s / ssize; - if (tmp * ssize != s) - throw "Bad stride"; - s = tmp; - } -} - -void PocketFft(void *out, void **in, XlaCustomCallStatus *) { - const PocketFftDescriptor *descriptor = GetPocketFftDescriptor(in[0]); +void DuccFft(void *out, void **in, XlaCustomCallStatus *) { + const DuccFftDescriptor *descriptor = GetDuccFftDescriptor(in[0]); shape_t shape(descriptor->shape()->begin(), descriptor->shape()->end()); stride_t stride_in(descriptor->strides_in()->begin(), descriptor->strides_in()->end()); @@ -45,10 +35,8 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { shape_t axes(descriptor->axes()->begin(), descriptor->axes()->end()); switch (descriptor->fft_type()) { - case PocketFftType_C2C: - if (descriptor->dtype() == PocketFftDtype_COMPLEX64) { - fixstrides(stride_in, sizeof(std::complex)); - fixstrides(stride_out, sizeof(std::complex)); + case DuccFftType_C2C: + if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { ducc0::cfmav> m_in( reinterpret_cast *>(in[1]), shape, stride_in); ducc0::vfmav> m_out( @@ -56,8 +44,6 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { ducc0::c2c(m_in, m_out, axes, descriptor->forward(), static_cast(descriptor->scale())); } else { - fixstrides(stride_in, sizeof(std::complex)); - fixstrides(stride_out, sizeof(std::complex)); ducc0::cfmav> m_in( reinterpret_cast *>(in[1]), shape, stride_in); ducc0::vfmav> m_out( @@ -66,10 +52,8 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { static_cast(descriptor->scale())); } break; - case PocketFftType_C2R: - if (descriptor->dtype() == PocketFftDtype_COMPLEX64) { - fixstrides(stride_in, sizeof(std::complex)); - fixstrides(stride_out, sizeof(float)); + case DuccFftType_C2R: + if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { auto shape_in = shape; shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; ducc0::cfmav> m_in( @@ -79,8 +63,6 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { ducc0::c2r(m_in, m_out, axes, descriptor->forward(), static_cast(descriptor->scale())); } else { - fixstrides(stride_in, sizeof(std::complex)); - fixstrides(stride_out, sizeof(double)); auto shape_in = shape; shape_in[axes.back()] = shape_in[axes.back()] / 2 + 1; ducc0::cfmav> m_in( @@ -91,10 +73,8 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { static_cast(descriptor->scale())); } break; - case PocketFftType_R2C: - if (descriptor->dtype() == PocketFftDtype_COMPLEX64) { - fixstrides(stride_in, sizeof(float)); - fixstrides(stride_out, sizeof(std::complex)); + case DuccFftType_R2C: + if (descriptor->dtype() == DuccFftDtype_COMPLEX64) { auto shape_out = shape; shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; ducc0::cfmav m_in(reinterpret_cast(in[1]), shape, @@ -104,8 +84,6 @@ void PocketFft(void *out, void **in, XlaCustomCallStatus *) { ducc0::r2c(m_in, m_out, axes, descriptor->forward(), static_cast(descriptor->scale())); } else { - fixstrides(stride_in, sizeof(double)); - fixstrides(stride_out, sizeof(std::complex)); auto shape_out = shape; shape_out[axes.back()] = shape_out[axes.back()] / 2 + 1; ducc0::cfmav m_in(reinterpret_cast(in[1]), shape, diff --git a/jaxlib/pocketfft_kernels.h b/jaxlib/ducc_fft_kernels.h similarity index 92% rename from jaxlib/pocketfft_kernels.h rename to jaxlib/ducc_fft_kernels.h index 7804ad1ded67..1760921b00fe 100644 --- a/jaxlib/pocketfft_kernels.h +++ b/jaxlib/ducc_fft_kernels.h @@ -17,6 +17,6 @@ limitations under the License. namespace jax { -void PocketFft(void* out, void** in, XlaCustomCallStatus*); +void DuccFft(void* out, void** in, XlaCustomCallStatus*); } // namespace jax diff --git a/third_party/ducc/BUILD.bazel b/third_party/ducc/BUILD.bazel new file mode 100644 index 000000000000..383c0acca67d --- /dev/null +++ b/third_party/ducc/BUILD.bazel @@ -0,0 +1,29 @@ +licenses(["notice"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "ducc", + srcs = [ + "src/ducc0/fft/fft1d.h", + "src/ducc0/infra/aligned_array.h", + "src/ducc0/infra/error_handling.h", + "src/ducc0/infra/mav.h", + "src/ducc0/infra/simd.h", + "src/ducc0/infra/threading.cc", + "src/ducc0/infra/threading.h", + "src/ducc0/infra/useful_macros.h", + "src/ducc0/math/cmplx.h", + "src/ducc0/math/unity_roots.h", + ], + hdrs = ["src/ducc0/fft/fft.h"], + copts = [ + "-fexceptions", + "-ffast-math", + ], + features = ["-use_header_modules"], + include_prefix = "third_party/ducc", + includes = [ + "src", + ], +) diff --git a/third_party/pocketfft/workspace.bzl b/third_party/ducc/workspace.bzl similarity index 88% rename from third_party/pocketfft/workspace.bzl rename to third_party/ducc/workspace.bzl index 9ce2c48e1562..6cb6910f1730 100644 --- a/third_party/pocketfft/workspace.bzl +++ b/third_party/ducc/workspace.bzl @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Bazel workspace for PocketFFT.""" +"""Bazel workspace for DUCC (CPU FFTs).""" load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def repo(): http_archive( - name = "pocketfft", + name = "ducc", strip_prefix = "ducc-356d619a4b5f6f8940d15913c14a043355ef23be", - sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61, + sha256 = "d23eb2d06f03604867ad40af4fe92dec7cccc2c59f5119e9f01b35b973885c61", urls = [ "https://github.com/mreineck/ducc/archive/356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz", "https://storage.googleapis.com/jax-releases/mirror/ducc/ducc-356d619a4b5f6f8940d15913c14a043355ef23be.tar.gz", ], - build_file = "@//third_party/pocketfft:BUILD.bazel", + build_file = "@//third_party/ducc:BUILD.bazel", ) diff --git a/third_party/pocketfft/BUILD.bazel b/third_party/pocketfft/BUILD.bazel deleted file mode 100644 index 29b8e2931327..000000000000 --- a/third_party/pocketfft/BUILD.bazel +++ /dev/null @@ -1,23 +0,0 @@ -licenses(["notice"]) - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = "pocketfft", - hdrs = ["src/ducc0/fft/fft.h"], - srcs = ["src/ducc0/fft/fft1d.h", - "src/ducc0/infra/aligned_array.h", - "src/ducc0/infra/error_handling.h", - "src/ducc0/infra/mav.h", - "src/ducc0/infra/simd.h", - "src/ducc0/infra/threading.h", - "src/ducc0/infra/threading.cc", - "src/ducc0/infra/useful_macros.h", - "src/ducc0/math/cmplx.h", - "src/ducc0/math/unity_roots.h", -], - copts = ["-fexceptions", "-ffast-math"], - features = ["-use_header_modules"], - include_prefix = "pocketfft", - includes = ["pocketfft", "src"], -)