Skip to content

Commit

Permalink
Fixes for PocketFFT->ducc migration.
Browse files Browse the repository at this point in the history
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
  • Loading branch information
hawkinsp committed Aug 26, 2022
1 parent 024ae47 commit b63801b
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 126 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Expand Up @@ -27,8 +27,8 @@ http_archive(
# path = "/path/to/tensorflow",
# )

load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
pocketfft()
load("//third_party/ducc:workspace.bzl", ducc = "repo")
ducc()

# Initialize TensorFlow's external dependencies.
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
Expand Down
4 changes: 2 additions & 2 deletions build/build_wheel.py
Expand Up @@ -171,8 +171,8 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/lapack.py")
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
copy_to_jaxlib(f"__main__/jaxlib/_ducc_fft.{pyext}")
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
copy_to_jaxlib("__main__/jaxlib/gpu_solver.py")
Expand Down
12 changes: 9 additions & 3 deletions jax/_src/lax/fft.py
Expand Up @@ -31,6 +31,9 @@
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import mlir_api_version
from jax._src.lib import xla_client
# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
# is 0.3.17 or newer.
from jax._src.lib import ducc_fft
from jax._src.lib import pocketfft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact

Expand Down Expand Up @@ -123,8 +126,12 @@ def _fft_lowering(ctx, x, *, fft_type, fft_lengths):

def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
x_aval, = ctx.avals_in
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
if ducc_fft:
return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
else:
return [pocketfft.pocketfft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]


def _naive_rfft(x, fft_lengths):
Expand Down Expand Up @@ -188,5 +195,4 @@ def _fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
mlir.register_lowering(fft_p, _fft_lowering)
ad.deflinear2(fft_p, _fft_transpose_rule)
batching.primitive_batchers[fft_p] = _fft_batching_rule
if pocketfft:
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
mlir.register_lowering(fft_p, _fft_lowering_cpu, platform='cpu')
12 changes: 11 additions & 1 deletion jax/_src/lib/__init__.py
Expand Up @@ -99,7 +99,17 @@ def _parse_version(v: str) -> Tuple[int, ...]:

import jaxlib.xla_client as xla_client
import jaxlib.lapack as lapack
import jaxlib.pocketfft as pocketfft

# TODO(phawkins): remove pocketfft references when the minimum jaxlib version
# is 0.3.17 or newer.
try:
import jaxlib.pocketfft as pocketfft
except ImportError:
pocketfft = None
try:
import jaxlib.ducc_fft as ducc_fft
except ImportError:
ducc_fft = None

xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
Expand Down
36 changes: 18 additions & 18 deletions jaxlib/BUILD
Expand Up @@ -29,21 +29,21 @@ package(default_visibility = ["//:__subpackages__"])
py_library(
name = "jaxlib",
srcs = [
"ducc_fft.py",
"gpu_linalg.py",
"gpu_prng.py",
"gpu_solver.py",
"gpu_sparse.py",
"init.py",
"lapack.py",
"mhlo_helpers.py",
"pocketfft.py",
":version",
":xla_client",
],
data = [":xla_extension"],
deps = [
":_ducc_fft",
":_lapack",
":_pocketfft",
":cpu_feature_guard",
"//jaxlib/mlir",
"//jaxlib/mlir:builtin_dialect",
Expand Down Expand Up @@ -148,8 +148,8 @@ cc_library(
srcs = ["lapack_kernels.cc"],
hdrs = ["lapack_kernels.h"],
deps = [
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@com_google_absl//absl/base:dynamic_annotations",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
],
)

Expand All @@ -176,40 +176,40 @@ pybind_extension(
],
)

# PocketFFT
# DUCC (CPU FFTs)

flatbuffer_cc_library(
name = "pocketfft_flatbuffers_cc",
srcs = ["pocketfft.fbs"],
name = "ducc_fft_flatbuffers_cc",
srcs = ["ducc_fft.fbs"],
)

cc_library(
name = "pocketfft_kernels",
srcs = ["pocketfft_kernels.cc"],
hdrs = ["pocketfft_kernels.h"],
name = "ducc_fft_kernels",
srcs = ["ducc_fft_kernels.cc"],
hdrs = ["ducc_fft_kernels.h"],
copts = ["-fexceptions"], # PocketFFT may throw.
features = ["-use_header_modules"],
deps = [
":pocketfft_flatbuffers_cc",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
":ducc_fft_flatbuffers_cc",
"@ducc",
"@flatbuffers//:runtime_cc",
"@pocketfft",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
],
)

pybind_extension(
name = "_pocketfft",
srcs = ["pocketfft.cc"],
name = "_ducc_fft",
srcs = ["ducc_fft.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_pocketfft",
module_name = "_ducc_fft",
deps = [
":ducc_fft_flatbuffers_cc",
":ducc_fft_kernels",
":kernel_pybind11_helpers",
":pocketfft_flatbuffers_cc",
":pocketfft_kernels",
"@flatbuffers//:runtime_cc",
"@pybind11",
],
Expand All @@ -220,9 +220,9 @@ cc_library(
srcs = ["cpu_kernels.cc"],
visibility = ["//visibility:public"],
deps = [
":ducc_kernels",
":lapack_kernels",
":lapack_kernels_using_lapack",
":pocketfft_kernels",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
],
alwayslink = 1,
Expand Down
40 changes: 20 additions & 20 deletions jaxlib/pocketfft.cc → jaxlib/ducc_fft.cc
Expand Up @@ -16,53 +16,53 @@ limitations under the License.
#include <complex>
#include <vector>

#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/pocketfft_generated.h"
#include "jaxlib/pocketfft_kernels.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "jaxlib/ducc_fft_generated.h"
#include "jaxlib/ducc_fft_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"

namespace py = pybind11;

namespace jax {
namespace {

py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
bool is_double, int fft_type,
const std::vector<uint64_t>& fft_lengths,
const std::vector<uint64_t>& strides_in,
const std::vector<uint64_t>& strides_out,
const std::vector<uint32_t>& axes,
bool forward, double scale) {
PocketFftDescriptorT descriptor;
py::bytes BuildDuccFftDescriptor(const std::vector<uint64_t> &shape,
bool is_double, int fft_type,
const std::vector<uint64_t> &fft_lengths,
const std::vector<uint64_t> &strides_in,
const std::vector<uint64_t> &strides_out,
const std::vector<uint32_t> &axes,
bool forward, double scale) {
DuccFftDescriptorT descriptor;
descriptor.shape = shape;
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
descriptor.fft_type = static_cast<DuccFftType>(fft_type);
descriptor.dtype =
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
is_double ? DuccFftDtype_COMPLEX128 : DuccFftDtype_COMPLEX64;
descriptor.strides_in = strides_in;
descriptor.strides_out = strides_out;
descriptor.axes = axes;
descriptor.forward = forward;
descriptor.scale = scale;
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
fbb.Finish(DuccFftDescriptor::Pack(fbb, &descriptor));
return py::bytes(reinterpret_cast<char *>(fbb.GetBufferPointer()),
fbb.GetSize());
}

py::dict Registrations() {
pybind11::dict dict;
dict["pocketfft"] = EncapsulateFunction(PocketFft);
dict["ducc_fft"] = EncapsulateFunction(DuccFft);
return dict;
}

PYBIND11_MODULE(_pocketfft, m) {
PYBIND11_MODULE(_ducc_fft, m) {
m.def("registrations", &Registrations);
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
m.def("ducc_fft_descriptor", &BuildDuccFftDescriptor, py::arg("shape"),
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
py::arg("forward"), py::arg("scale"));
}

} // namespace
} // namespace jax
} // namespace
} // namespace jax
12 changes: 6 additions & 6 deletions jaxlib/pocketfft.fbs → jaxlib/ducc_fft.fbs
Expand Up @@ -15,20 +15,20 @@ limitations under the License.

namespace jax;

enum PocketFftDtype : byte {
enum DuccFftDtype : byte {
COMPLEX64 = 0,
COMPLEX128 = 1,
}

enum PocketFftType : byte {
enum DuccFftType : byte {
C2C = 0,
C2R = 1,
R2C = 2,
}

table PocketFftDescriptor {
dtype:PocketFftDtype;
fft_type:PocketFftType;
table DuccFftDescriptor {
dtype:DuccFftDtype;
fft_type:DuccFftType;
shape:[uint64];
strides_in:[uint64];
strides_out:[uint64];
Expand All @@ -37,4 +37,4 @@ table PocketFftDescriptor {
scale:double;
}

root_type PocketFftDescriptor;
root_type DuccFftDescriptor;
28 changes: 14 additions & 14 deletions jaxlib/pocketfft.py → jaxlib/ducc_fft.py
Expand Up @@ -19,12 +19,12 @@


from .mhlo_helpers import custom_call
from . import _pocketfft
from . import _ducc_fft
import numpy as np

from jaxlib import xla_client

for _name, _value in _pocketfft.registrations().items():
for _name, _value in _ducc_fft.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")

FftType = xla_client.FftType
Expand All @@ -34,7 +34,7 @@
_C2R = 1
_R2C = 2

def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
fft_lengths: List[int]) -> bytes:
n = len(shape)
assert len(fft_lengths) >= 1
Expand All @@ -44,7 +44,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
forward = fft_type in (FftType.FFT, FftType.RFFT)
is_double = np.finfo(dtype).dtype == np.float64
if fft_type == FftType.RFFT:
pocketfft_type = _R2C
ducc_fft_type = _R2C

assert dtype in (np.float32, np.float64), dtype
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
Expand All @@ -54,7 +54,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
out_shape[-1] = out_shape[-1] // 2 + 1

elif fft_type == FftType.IRFFT:
pocketfft_type = _C2R
ducc_fft_type = _C2R
assert np.issubdtype(dtype, np.complexfloating), dtype

out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
Expand All @@ -64,7 +64,7 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
out_shape[-1] = fft_lengths[-1]
assert (out_shape[-1] // 2 + 1) == shape[-1]
else:
pocketfft_type = _C2C
ducc_fft_type = _C2C

assert np.issubdtype(dtype, np.complexfloating), dtype
out_dtype = dtype
Expand All @@ -79,24 +79,24 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
# C++ kernel to describe the FFT to perform.
strides_in = []
stride = dtype.itemsize
stride = 1
for d in reversed(shape):
strides_in.append(stride)
stride *= d

strides_out = []
stride = out_dtype.itemsize
stride = 1
for d in reversed(out_shape):
strides_out.append(stride)
stride *= d

axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]

scale = 1. if forward else (1. / np.prod(fft_lengths))
descriptor = _pocketfft.pocketfft_descriptor(
descriptor = _ducc_fft.ducc_fft_descriptor(
shape=shape if fft_type != FftType.IRFFT else out_shape,
is_double=is_double,
fft_type=pocketfft_type,
fft_type=ducc_fft_type,
fft_lengths=fft_lengths,
strides_in=list(reversed(strides_in)),
strides_out=list(reversed(strides_out)),
Expand All @@ -107,13 +107,13 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
return descriptor, out_dtype, out_shape


def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""PocketFFT kernel for CPU."""
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)

fft_lengths = list(fft_lengths)
descriptor_bytes, out_dtype, out_shape = _pocketfft_descriptor(
descriptor_bytes, out_dtype, out_shape = _ducc_fft_descriptor(
list(a_type.shape), dtype, fft_type, fft_lengths)

if out_dtype == np.float32:
Expand Down Expand Up @@ -141,7 +141,7 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(n - 1, -1, -1))
return custom_call(
"pocketfft",
"ducc_fft",
[ir.RankedTensorType.get(out_shape, out_type)],
[descriptor, a],
operand_layouts=[[0], layout],
Expand Down

0 comments on commit b63801b

Please sign in to comment.