-
Notifications
You must be signed in to change notification settings - Fork 2
/
pybind11_kernel_helpers.h
35 lines (26 loc) · 1.18 KB
/
pybind11_kernel_helpers.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// This header extends kernel_helpers.h with the pybind11-specific interface for
// serializing descriptors. It also adds a pybind11 function for wrapping our
// custom calls in a Python capsule. This is kept separate from kernel_helpers.h
// so that the CUDA code itself doesn't include pybind11. I don't think that this
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
#ifndef _JAX_FINUFFT_PYBIND11_KERNEL_HELPERS_H_
#define _JAX_FINUFFT_PYBIND11_KERNEL_HELPERS_H_
#include <pybind11/pybind11.h>
#include "kernel_helpers.h"
namespace jax_finufft {
// Serializes `descriptor` with pack_descriptor_as_string() and hands the result
// back to Python as a pybind11::bytes object.
//
// NOTE(review): pack_descriptor_as_string() is not defined in this header; it is
// presumably provided by kernel_helpers.h -- confirm there for the exact layout
// of the serialized data.
template <typename T>
pybind11::bytes pack_descriptor(const T& descriptor) {
  const auto serialized = pack_descriptor_as_string(descriptor);
  return pybind11::bytes(serialized);
}
// Wraps the function pointer `fn` in a pybind11::capsule whose name is
// "xla._CUSTOM_CALL_TARGET", so that a custom call can be passed to Python.
//
// The pointer is converted to `void*` with bit_cast() before being stored.
// NOTE(review): bit_cast() is not defined in this header; presumably it comes
// from kernel_helpers.h.
template <typename T>
pybind11::capsule encapsulate_function(T* fn) {
  void* const raw_target = bit_cast<void*>(fn);
  return pybind11::capsule(raw_target, "xla._CUSTOM_CALL_TARGET");
}
// Builds a NufftDescriptor<T> from the given arguments and returns it
// serialized as a pybind11::bytes object (via pack_descriptor()).
//
// The arguments are forwarded, in order, to the aggregate initializer of
// NufftDescriptor<T>: eps, iflag, n_tot, n_transf and n_j individually, followed
// by n_k_1, n_k_2 and n_k_3 grouped together as one nested brace-initialized
// member.
//
// NOTE(review): NufftDescriptor is declared outside this header (presumably in
// kernel_helpers.h); see its definition for the meaning of each field.
template <typename T>
pybind11::bytes build_descriptor(T eps, int iflag, int64_t n_tot, int n_transf, int64_t n_j,
                                 int64_t n_k_1, int64_t n_k_2, int64_t n_k_3) {
  const NufftDescriptor<T> descriptor{eps, iflag, n_tot, n_transf, n_j, {n_k_1, n_k_2, n_k_3}};
  return pack_descriptor(descriptor);
}
} // namespace jax_finufft
#endif