-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
_site_initialize_0.cc
31 lines (27 loc) · 1.01 KB
/
_site_initialize_0.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// Registers MLIR dialects used by JAX.
// This module is called by mlir/__init__.py during initialization.
#include "mlir-c/Dialect/Arith.h"
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/Dialect/Math.h"
#include "mlir-c/Dialect/MemRef.h"
#include "mlir-c/Dialect/SCF.h"
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"
// Inserts the MLIR dialect `name` into the MlirDialectRegistry named
// `registry`, which must be in scope wherever the macro is expanded.
// `mlirGetDialectHandle__<name>__` must be declared, typically by the matching
// mlir-c/Dialect/<Name>.h header.
//
// Expands to a single call expression rather than a declaration plus a call.
// It is therefore safe as the body of an unbraced `if`/loop. It also declares
// no identifiers that could collide when expanded twice in one scope. The
// caller supplies the trailing `;`.
#define REGISTER_DIALECT(name) \
  mlirDialectHandleInsertDialect(mlirGetDialectHandle__##name##__(), registry)
// Python extension module `_site_initialize_0`.
//
// Exposes one function, `register_dialects(registry)`. It inserts the arith,
// func, math, memref, scf and vector dialects into the given
// MlirDialectRegistry, then registers the upstream MLIR Transforms passes and
// the strip-debuginfo transform.
PYBIND11_MODULE(_site_initialize_0, m) {
  m.doc() = "Registers MLIR dialects used by JAX.";

  m.def("register_dialects", [](MlirDialectRegistry registry) {
    // Each dialect's mlirGetDialectHandle__<name>__ accessor comes from its
    // mlir-c/Dialect/<Name>.h header (scf needs mlir-c/Dialect/SCF.h).
    REGISTER_DIALECT(arith);
    REGISTER_DIALECT(func);
    REGISTER_DIALECT(math);
    REGISTER_DIALECT(memref);
    REGISTER_DIALECT(scf);
    REGISTER_DIALECT(vector);

    // Pass registration is separate from the dialect registry: these calls do
    // not take `registry` as an argument.
    mlirRegisterTransformsPasses();
    // Transforms used by JAX.
    // NOTE(review): strip-debuginfo may already be covered by
    // mlirRegisterTransformsPasses(); confirm before removing this call.
    mlirRegisterTransformsStripDebugInfo();
  });
}