diff --git a/WORKSPACE b/WORKSPACE index 41522a914547..4ac6d7c079a5 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # and update the sha256 with the result. http_archive( name = "xla", - sha256 = "4ec16aff3862c5a243db956ce558d7a62eb79f5e20747b0e80802a3b0d12e419", - strip_prefix = "xla-12de6ec958419b57be248d0acd2d9f757e71748c", + sha256 = "5aefcdcffec86005ef4c9ebb1220ab8f6d7389a49274e290ab685ef55d6fd954", + strip_prefix = "xla-d2da05dc41a0fc583505d0ad2e9a40779aee9f90", urls = [ - "https://github.com/openxla/xla/archive/12de6ec958419b57be248d0acd2d9f757e71748c.tar.gz", + "https://github.com/openxla/xla/archive/d2da05dc41a0fc583505d0ad2e9a40779aee9f90.tar.gz", ], ) diff --git a/jax/BUILD b/jax/BUILD index 66cdc4439d74..f84e167ccde6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -569,6 +569,7 @@ pytype_strict_library( srcs = ["_src/tpu_custom_call.py"], visibility = [":internal"], deps = [ + ":config", ":core", ":jax", "//jax/_src/lib", diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 24f365c3c394..5eb6e34f92e3 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -112,10 +112,10 @@ def _xla_gc_callback(*args): import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error -try: + +if version >= (0, 4, 14): import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error -except ImportError: - # TODO(sharadmv): Remove this when minimum jaxlib version >= 0.4.14 +else: # Jaxlib doesn't contain Mosaic bindings tpu_mosaic = None # type: ignore diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 805e56ed9f3e..3655e7f67d2d 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -24,15 +24,15 @@ import io from typing import Any, Callable -from absl import flags import jax from jax import core from jax.interpreters import mlir from jax.interpreters import xla -from mlir import ir -from mlir.dialects import mhlo -from mlir.dialects import stablehlo -from mlir.passmanager import PassManager +from jax._src.config import config +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import mhlo +from jaxlib.mlir.dialects import stablehlo +from jaxlib.mlir.passmanager import PassManager from jax._src.lib import tpu_mosaic import numpy as np @@ -43,8 +43,10 @@ apply_vector_layout = tpu_mosaic.apply_vector_layout infer_memref_layout = tpu_mosaic.infer_memref_layout -_ALLOW_HLO = flags.DEFINE_bool( - "jax_mosaic_allow_hlo", False, "Allow hlo dialect in mosaic" +config.define_bool_state( + name="jax_mosaic_allow_hlo", + default=False, + help="Allow hlo dialects in Mosaic", ) tpu_custom_call_p = core.Primitive("tpu_custom_call") @@ -179,7 +181,7 @@ def _lower_tpu_kernel(module: ir.Module, hardware_generation: int) -> ir.Module: module.operation.get_asm(binary=True, enable_debug_info=True) ) - if _ALLOW_HLO.value: + if config.jax_mosaic_allow_hlo: # Run hlo dialect conversion: hlo -> linalg -> vector. pipeline = [ "hlo-legalize-to-arithmetic", diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 6258c9c23057..93edfb3fea9d 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -59,12 +59,14 @@ py_library_providing_imports_info( "//jaxlib/mlir:chlo_dialect", "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:ir", + "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:mhlo_dialect", "//jaxlib/mlir:ml_program_dialect", "//jaxlib/mlir:pass_manager", + "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:sparse_tensor_dialect", "//jaxlib/mlir:stablehlo_dialect", - # placeholder for Mosaic target + "//jaxlib/mosaic", ], ) diff --git a/jaxlib/mlir/BUILD.bazel b/jaxlib/mlir/BUILD.bazel index 1849d7acc23e..b7abed018dfa 100644 --- a/jaxlib/mlir/BUILD.bazel +++ b/jaxlib/mlir/BUILD.bazel @@ -61,6 +61,71 @@ symlink_inputs( ], ) +symlink_inputs( + name = "vector_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:VectorOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "math_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:MathOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "arithmetic_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:ArithOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "memref_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:MemRefOpsPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + +symlink_inputs( + name = "scf_dialect", + rule = py_library, + symlinked_inputs = {"srcs": {"dialects": [ + "@llvm-project//mlir/python:SCFPyFiles", + ]}}, + deps = [ + ":core", + ":ir", + ":mlir", + ], +) + symlink_inputs( name = "ml_program_dialect", rule = py_library, diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index e6a079a9dfa7..a271acdb0f87 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -106,9 +106,14 @@ py_extension( copts = COPTS, linkopts = LINKOPTS, deps = [ + ":jax_dialects_capi_headers", ":jaxlib_mlir_capi_shared_library", + "@llvm-project//mlir:CAPIArithHeaders", "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:CAPIMathHeaders", + "@llvm-project//mlir:CAPIMemRefHeaders", "@llvm-project//mlir:CAPITransformsHeaders", + "@llvm-project//mlir:CAPIVectorHeaders", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", @@ -193,11 +198,36 @@ cc_library( }), ) +cc_library( + name = "jax_dialects_capi", + srcs = ["jax_dialects.cc"], + hdrs = ["jax_dialects.h"], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:SCFDialect", + ], + alwayslink = 1, +) + +cc_library( + name = "jax_dialects_capi_headers", + hdrs = ["jax_dialects.h"], + deps = [ + "@llvm-project//mlir:CAPIIRHeaders", + ], +) + cc_library( name = "jaxlib_mlir_capi_objects", deps = [ + ":jax_dialects_capi", + "//jaxlib/mosaic:tpu_dialect_capi_objects", + "@llvm-project//mlir:CAPIArithObjects", + "@llvm-project//mlir:CAPIMathObjects", + "@llvm-project//mlir:CAPIMemRefObjects", "@llvm-project//mlir:CAPISparseTensorObjects", "@llvm-project//mlir:CAPITransformsObjects", + "@llvm-project//mlir:CAPIVectorObjects", "@llvm-project//mlir:MLIRBindingsPythonCAPIObjects", "@stablehlo//:chlo_capi_objects", "@stablehlo//:stablehlo_capi_objects", diff --git a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc index fa2e23f1ae0a..31aea97040bc 100644 --- a/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc +++ b/jaxlib/mlir/_mlir_libs/_site_initialize_0.cc @@ -1,18 +1,31 @@ // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. +#include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" +#include "mlir-c/Dialect/Math.h" +#include "mlir-c/Dialect/MemRef.h" +#include "mlir-c/Dialect/Vector.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include "jaxlib/mlir/_mlir_libs/jax_dialects.h" + +#define REGISTER_DIALECT(name) \ + MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ + mlirDialectHandleInsertDialect(name##_dialect, registry) PYBIND11_MODULE(_site_initialize_0, m) { m.doc() = "Registers MLIR dialects used by JAX."; m.def("register_dialects", [](MlirDialectRegistry registry) { - MlirDialectHandle func_dialect = mlirGetDialectHandle__func__(); - mlirDialectHandleInsertDialect(func_dialect, registry); - + REGISTER_DIALECT(arith); + REGISTER_DIALECT(func); + REGISTER_DIALECT(math); + REGISTER_DIALECT(memref); + REGISTER_DIALECT(scf); + REGISTER_DIALECT(vector); + mlirRegisterTransformsPasses(); // Transforms used by JAX. mlirRegisterTransformsStripDebugInfo(); }); -} \ No newline at end of file +} diff --git a/jaxlib/mlir/_mlir_libs/jax_dialects.cc b/jaxlib/mlir/_mlir_libs/jax_dialects.cc new file mode 100644 index 000000000000..b082352f6bd7 --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/jax_dialects.cc @@ -0,0 +1,25 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/mlir/_mlir_libs/jax_dialects.h" + +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +extern "C" { + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SCF, scf, mlir::scf::SCFDialect) + +} \ No newline at end of file diff --git a/jaxlib/mlir/_mlir_libs/jax_dialects.h b/jaxlib/mlir/_mlir_libs/jax_dialects.h new file mode 100644 index 000000000000..7e060a784c1f --- /dev/null +++ b/jaxlib/mlir/_mlir_libs/jax_dialects.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAX_DIALECTS_H +#define JAX_DIALECTS_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SCF, scf); + +#ifdef __cplusplus +} +#endif + +#endif // JAX_DIALECTS_H \ No newline at end of file diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index e18fca9dd365..3fa7112f7b20 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -45,7 +45,7 @@ cc_library( "dialect/tpu/layout.h", "dialect/tpu/tpu_dialect.h", ], - compatible_with = ["//buildenv/target:libtpu"], + # compatible with libtpu deps = [ ":tpu_inc_gen", "@llvm-project//llvm:Support", @@ -72,7 +72,7 @@ cc_library( gentbl_cc_library( name = "tpu_inc_gen", - compatible_with = ["//buildenv/target:libtpu"], + # compatible with libtpu tbl_outs = [ ( ["-gen-op-decls"], @@ -146,7 +146,7 @@ td_library( srcs = [ "dialect/tpu/tpu.td", ], - compatible_with = ["//buildenv/target:libtpu"], + # compatible with libtpu deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 2014ca9102ed..3e10481dbcfc 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -60,20 +60,15 @@ using ImplicitDim = VectorLayout::ImplicitDim; static constexpr int kLayoutLog = 10; -template class Print { public: - explicit Print(T* t) : payload_(t) {} - T* payload_; + explicit Print(Operation* t) : payload_(t) {} + Operation* payload_; private: - friend std::ostream &operator<<(std::ostream &, Print &&); + friend std::ostream &operator<<(std::ostream &, Print); }; -template -Print(T& t) -> Print; - -template -std::ostream &operator<<(std::ostream &os, Print p) { +std::ostream &operator<<(std::ostream &os, Print p) { std::string s; llvm::raw_string_ostream tmp_os(s); p.payload_->print(tmp_os); diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 4f83a3ef61ec..998993ad2f21 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -39,7 +39,7 @@ genrule( name = "tpu_python_gen", srcs = ["_tpu_gen_raw.py"], outs = ["_tpu_gen.py"], - cmd = "cat $(location _tpu_gen_raw.py) | sed -e 's/^from \\./from mlir\\.dialects\\./g' > $@", + cmd = "cat $(location _tpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@", ) py_library( @@ -60,9 +60,9 @@ pybind_extension( name = "_tpu_ext", srcs = ["tpu_ext.cc"], deps = [ - "//jaxlib/mosaic:tpu_dialect_capi", + "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library", + "//jaxlib/mosaic:tpu_dialect_capi_headers", "@llvm-project//mlir:CAPIIR", - "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", "@pybind11", ], @@ -83,6 +83,7 @@ py_library( "//jaxlib/mlir:func_dialect", "//jaxlib/mlir:ir", "//jaxlib/mlir:math_dialect", + "//jaxlib/mlir:memref_dialect", "//jaxlib/mlir:pass_manager", "//jaxlib/mlir:scf_dialect", "//jaxlib/mlir:stablehlo_dialect", diff --git a/jaxlib/mosaic/python/apply_vector_layout.py b/jaxlib/mosaic/python/apply_vector_layout.py index 69cb96bccbaf..e62306723468 100644 --- a/jaxlib/mosaic/python/apply_vector_layout.py +++ b/jaxlib/mosaic/python/apply_vector_layout.py @@ -27,14 +27,14 @@ import functools import math import re -from typing import Any, Callable, Literal, Union, overload - -from mlir import ir -from mlir.dialects import arith -from mlir.dialects import func -from mlir.dialects import math as math_dialect -from mlir.dialects import scf -from mlir.dialects import vector +from typing import Any, Callable, Literal, Optional, Union, overload + +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import math as math_dialect +from jaxlib.mlir.dialects import scf +from jaxlib.mlir.dialects import vector import numpy as np from . import infer_memref_layout @@ -72,7 +72,7 @@ def __bool__(self): REPLICATED = Replicated.REPLICATED -Offset = int | Literal[REPLICATED] +Offset = Union[int, Literal[REPLICATED]] class ImplicitDim(enum.IntEnum): @@ -176,7 +176,7 @@ class VectorLayout: bitwidth: int offsets: tuple[Offset, Offset] # Replication applies only within a tile. tiling: tuple[int, int] - implicit_dim: ImplicitDim | None + implicit_dim: Optional[ImplicitDim] def __post_init__(self): # TODO(b/275751535): Allow more bitwidths. @@ -289,7 +289,7 @@ def tile_array_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: raise AssertionError(f"Invalid implicit dim: {self.implicit_dim}") def generalizes(self, other: "VectorLayout", - shape: tuple[int, ...] | None = None) -> bool: + shape: Optional[tuple[int, ...]] = None) -> bool: """Returns True if the other layout is a special case of this one. In here, other is considered "a special case" when the set of vector @@ -345,7 +345,7 @@ def generalizes(self, other: "VectorLayout", return True def equivalent_to(self, other: "VectorLayout", - shape: tuple[int, ...] | None = None) -> bool: + shape: Optional[tuple[int, ...]] = None) -> bool: """Returns True if the two layouts are equivalent. That is, when all potential vector entries where the value can be stored @@ -804,7 +804,7 @@ def get_sublane_mask(self) -> ir.Attribute: return ir.DenseBoolArrayAttr.get(mask) -Layout = VectorLayout | None +Layout = Optional[VectorLayout] PATTERN = re.compile( r'#tpu.vpad<"([0-9]+),{([*0-9]+),([*0-9]+)},\(([0-9]+),([0-9]+)\)(,-1|,-2)?">' @@ -1113,17 +1113,17 @@ class RewriteContext: func: func.FuncOp hardware_generation: int - def erase(self, op: ir.Operation | ir.OpView): + def erase(self, op: Union[ir.Operation, ir.OpView]): if isinstance(op, ir.OpView): op = op.operation op.erase() - def replace(self, old: ir.Operation | ir.OpView, new: ValueLike): + def replace(self, old: Union[ir.Operation, ir.OpView], new: ValueLike): self.replace_all_uses_with(old, new) self.erase(old) def replace_all_uses_with( - self, old: ir.Operation | ir.OpView, new: ValueLike + self, old: Union[ir.Operation, ir.OpView], new: ValueLike ): if isinstance(new, (ir.Operation, ir.OpView)): new = new.results @@ -2649,7 +2649,7 @@ def type_bitwidth(ty: ir.Type) -> int: raise NotImplementedError(ty) -def get_constant(ty: ir.Type, value: int | float) -> ir.Attribute: +def get_constant(ty: ir.Type, value: Union[int, float]) -> ir.Attribute: if ir.IntegerType.isinstance(ty): return ir.IntegerAttr.get(ty, value) elif ty == ir.IndexType.get(): diff --git a/jaxlib/mosaic/python/infer_memref_layout.py b/jaxlib/mosaic/python/infer_memref_layout.py index 7e49b4910567..4217bc466c06 100644 --- a/jaxlib/mosaic/python/infer_memref_layout.py +++ b/jaxlib/mosaic/python/infer_memref_layout.py @@ -15,8 +15,8 @@ """Inference for memref layout and memory space.""" # mypy: ignore-errors -from mlir import ir -from mlir.dialects import func +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import func import numpy as np from . import tpu diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 0e7370847689..1bd4ddd69946 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -88,6 +88,9 @@ def has_ext_modules(self): 'cpu/*', 'cuda/*', 'cuda/nvvm/libdevice/libdevice*', + 'mosaic/*.py', + 'mosaic/python/*.py', + 'mosaic/python/*.so', 'mlir/*.py', 'mlir/dialects/*.py', 'mlir/_mlir_libs/*.dll', diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index b685de813cc1..54721db03a43 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -84,6 +84,16 @@ def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True): shutil.copy(src_file, dst_file) +def patch_copy_mlir_import(src_file, dst_dir): + src_file = r.Rlocation(src_file) + src_filename = os.path.basename(src_file) + with open(src_file) as f: + src = f.read() + + with open(os.path.join(dst_dir, src_filename), 'w') as f: + replaced = re.sub(r'^from mlir(\..*)? import (.*)', r'from jaxlib.mlir\1 import \2', src, flags=re.MULTILINE) + f.write(replaced) + _XLA_EXTENSION_STUBS = [ "__init__.pyi", "jax_jit.pyi", @@ -169,6 +179,7 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/gpu_triton.py") copy_to_jaxlib("__main__/jaxlib/gpu_solver.py") copy_to_jaxlib("__main__/jaxlib/gpu_sparse.py") + copy_to_jaxlib("__main__/jaxlib/tpu_mosaic.py") copy_to_jaxlib("__main__/jaxlib/version.py") copy_to_jaxlib("__main__/jaxlib/xla_client.py") copy_to_jaxlib(f"__main__/jaxlib/xla_extension.{pyext}") @@ -200,6 +211,17 @@ def prepare_wheel(sources_path): if exists(f"__main__/jaxlib/rocm/_sparse.{pyext}"): copy_file(f"__main__/jaxlib/rocm/_sparse.{pyext}", dst_dir=rocm_dir) + mosaic_dir = os.path.join(jaxlib_dir, "mosaic") + mosaic_python_dir = os.path.join(mosaic_dir, "python") + os.makedirs(mosaic_dir) + os.makedirs(mosaic_python_dir) + copy_to_jaxlib("__main__/jaxlib/mosaic/python/apply_vector_layout.py", dst_dir=mosaic_python_dir) + copy_to_jaxlib("__main__/jaxlib/mosaic/python/infer_memref_layout.py", dst_dir=mosaic_python_dir) + copy_to_jaxlib("__main__/jaxlib/mosaic/python/tpu.py", dst_dir=mosaic_python_dir) + copy_file(f"__main__/jaxlib/mosaic/python/_tpu_ext.{pyext}", dst_dir=mosaic_python_dir) + copy_file("__main__/jaxlib/mosaic/python/_tpu_ops_ext.py", dst_dir=mosaic_python_dir) + # TODO (sharadmv,skyewm): can we avoid patching this file? + patch_copy_mlir_import("__main__/jaxlib/mosaic/python/_tpu_gen.py", dst_dir=mosaic_python_dir) mlir_dir = os.path.join(jaxlib_dir, "mlir") mlir_dialects_dir = os.path.join(jaxlib_dir, "mlir", "dialects") @@ -223,6 +245,19 @@ def prepare_wheel(sources_path): copy_file("__main__/jaxlib/mlir/dialects/sparse_tensor.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/builtin.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/chlo.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/arith.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_gen.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_arith_ops_ext.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/math.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_math_ops_gen.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/memref.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_gen.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_memref_ops_ext.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/scf.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_gen.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_scf_ops_ext.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/vector.py", dst_dir=mlir_dialects_dir) + copy_file("__main__/jaxlib/mlir/dialects/_vector_ops_gen.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/mhlo.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/stablehlo.py", dst_dir=mlir_dialects_dir) copy_file("__main__/jaxlib/mlir/dialects/func.py", dst_dir=mlir_dialects_dir)