Skip to content

Commit

Permalink
Enable building jaxlib w/ Mosaic
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551159246
  • Loading branch information
sharadmv authored and jax authors committed Jul 26, 2023
1 parent f66d3cf commit 3baa6e7
Show file tree
Hide file tree
Showing 17 changed files with 257 additions and 53 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Expand Up @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
# and update the sha256 with the result.
http_archive(
name = "xla",
sha256 = "4ec16aff3862c5a243db956ce558d7a62eb79f5e20747b0e80802a3b0d12e419",
strip_prefix = "xla-12de6ec958419b57be248d0acd2d9f757e71748c",
sha256 = "5aefcdcffec86005ef4c9ebb1220ab8f6d7389a49274e290ab685ef55d6fd954",
strip_prefix = "xla-d2da05dc41a0fc583505d0ad2e9a40779aee9f90",
urls = [
"https://github.com/openxla/xla/archive/12de6ec958419b57be248d0acd2d9f757e71748c.tar.gz",
"https://github.com/openxla/xla/archive/d2da05dc41a0fc583505d0ad2e9a40779aee9f90.tar.gz",
],
)

Expand Down
1 change: 1 addition & 0 deletions jax/BUILD
Expand Up @@ -569,6 +569,7 @@ pytype_strict_library(
srcs = ["_src/tpu_custom_call.py"],
visibility = [":internal"],
deps = [
":config",
":core",
":jax",
"//jax/_src/lib",
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/lib/__init__.py
Expand Up @@ -112,10 +112,10 @@ def _xla_gc_callback(*args):

import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error
try:

if version >= (0, 4, 14):
import jaxlib.tpu_mosaic as tpu_mosaic # pytype: disable=import-error
except ImportError:
# TODO(sharadmv): Remove this when minimum jaxlib version >= 0.4.14
else:
# Jaxlib doesn't contain Mosaic bindings
tpu_mosaic = None # type: ignore

Expand Down
18 changes: 10 additions & 8 deletions jax/_src/tpu_custom_call.py
Expand Up @@ -24,15 +24,15 @@
import io
from typing import Any, Callable

from absl import flags
import jax
from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from mlir import ir
from mlir.dialects import mhlo
from mlir.dialects import stablehlo
from mlir.passmanager import PassManager
from jax._src.config import config
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import mhlo
from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
from jax._src.lib import tpu_mosaic
import numpy as np

Expand All @@ -43,8 +43,10 @@
apply_vector_layout = tpu_mosaic.apply_vector_layout
infer_memref_layout = tpu_mosaic.infer_memref_layout

_ALLOW_HLO = flags.DEFINE_bool(
"jax_mosaic_allow_hlo", False, "Allow hlo dialect in mosaic"
config.define_bool_state(
name="jax_mosaic_allow_hlo",
default=False,
help="Allow hlo dialects in Mosaic",
)

tpu_custom_call_p = core.Primitive("tpu_custom_call")
Expand Down Expand Up @@ -179,7 +181,7 @@ def _lower_tpu_kernel(module: ir.Module, hardware_generation: int) -> ir.Module:
module.operation.get_asm(binary=True, enable_debug_info=True)
)

if _ALLOW_HLO.value:
if config.jax_mosaic_allow_hlo:
# Run hlo dialect conversion: hlo -> linalg -> vector.
pipeline = [
"hlo-legalize-to-arithmetic",
Expand Down
4 changes: 3 additions & 1 deletion jaxlib/BUILD
Expand Up @@ -59,12 +59,14 @@ py_library_providing_imports_info(
"//jaxlib/mlir:chlo_dialect",
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:memref_dialect",
"//jaxlib/mlir:mhlo_dialect",
"//jaxlib/mlir:ml_program_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:sparse_tensor_dialect",
"//jaxlib/mlir:stablehlo_dialect",
# placeholder for Mosaic target
"//jaxlib/mosaic",
],
)

Expand Down
65 changes: 65 additions & 0 deletions jaxlib/mlir/BUILD.bazel
Expand Up @@ -61,6 +61,71 @@ symlink_inputs(
],
)

symlink_inputs(
name = "vector_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:VectorOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)

symlink_inputs(
name = "math_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:MathOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)

symlink_inputs(
name = "arithmetic_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:ArithOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)

symlink_inputs(
name = "memref_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:MemRefOpsPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)

symlink_inputs(
name = "scf_dialect",
rule = py_library,
symlinked_inputs = {"srcs": {"dialects": [
"@llvm-project//mlir/python:SCFPyFiles",
]}},
deps = [
":core",
":ir",
":mlir",
],
)

symlink_inputs(
name = "ml_program_dialect",
rule = py_library,
Expand Down
30 changes: 30 additions & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Expand Up @@ -106,9 +106,14 @@ py_extension(
copts = COPTS,
linkopts = LINKOPTS,
deps = [
":jax_dialects_capi_headers",
":jaxlib_mlir_capi_shared_library",
"@llvm-project//mlir:CAPIArithHeaders",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:CAPIMathHeaders",
"@llvm-project//mlir:CAPIMemRefHeaders",
"@llvm-project//mlir:CAPITransformsHeaders",
"@llvm-project//mlir:CAPIVectorHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeaders",
"@local_config_python//:headers",
"@pybind11",
Expand Down Expand Up @@ -193,11 +198,36 @@ cc_library(
}),
)

cc_library(
name = "jax_dialects_capi",
srcs = ["jax_dialects.cc"],
hdrs = ["jax_dialects.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:SCFDialect",
],
alwayslink = 1,
)

cc_library(
name = "jax_dialects_capi_headers",
hdrs = ["jax_dialects.h"],
deps = [
"@llvm-project//mlir:CAPIIRHeaders",
],
)

cc_library(
name = "jaxlib_mlir_capi_objects",
deps = [
":jax_dialects_capi",
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIMathObjects",
"@llvm-project//mlir:CAPIMemRefObjects",
"@llvm-project//mlir:CAPISparseTensorObjects",
"@llvm-project//mlir:CAPITransformsObjects",
"@llvm-project//mlir:CAPIVectorObjects",
"@llvm-project//mlir:MLIRBindingsPythonCAPIObjects",
"@stablehlo//:chlo_capi_objects",
"@stablehlo//:stablehlo_capi_objects",
Expand Down
21 changes: 17 additions & 4 deletions jaxlib/mlir/_mlir_libs/_site_initialize_0.cc
@@ -1,18 +1,31 @@
// Registers MLIR dialects used by JAX.
// This module is called by mlir/__init__.py during initialization.

#include "mlir-c/Dialect/Arith.h"
#include "mlir-c/Dialect/Func.h"
#include "mlir-c/Dialect/Math.h"
#include "mlir-c/Dialect/MemRef.h"
#include "mlir-c/Dialect/Vector.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"

#define REGISTER_DIALECT(name) \
MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \
mlirDialectHandleInsertDialect(name##_dialect, registry)

PYBIND11_MODULE(_site_initialize_0, m) {
m.doc() = "Registers MLIR dialects used by JAX.";

m.def("register_dialects", [](MlirDialectRegistry registry) {
MlirDialectHandle func_dialect = mlirGetDialectHandle__func__();
mlirDialectHandleInsertDialect(func_dialect, registry);

REGISTER_DIALECT(arith);
REGISTER_DIALECT(func);
REGISTER_DIALECT(math);
REGISTER_DIALECT(memref);
REGISTER_DIALECT(scf);
REGISTER_DIALECT(vector);
mlirRegisterTransformsPasses();
// Transforms used by JAX.
mlirRegisterTransformsStripDebugInfo();
});
}
}
25 changes: 25 additions & 0 deletions jaxlib/mlir/_mlir_libs/jax_dialects.cc
@@ -0,0 +1,25 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mlir/_mlir_libs/jax_dialects.h"

#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

extern "C" {

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SCF, scf, mlir::scf::SCFDialect)

}
32 changes: 32 additions & 0 deletions jaxlib/mlir/_mlir_libs/jax_dialects.h
@@ -0,0 +1,32 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAX_DIALECTS_H
#define JAX_DIALECTS_H

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SCF, scf);

#ifdef __cplusplus
}
#endif

#endif // JAX_DIALECTS_H
6 changes: 3 additions & 3 deletions jaxlib/mosaic/BUILD
Expand Up @@ -45,7 +45,7 @@ cc_library(
"dialect/tpu/layout.h",
"dialect/tpu/tpu_dialect.h",
],
compatible_with = ["//buildenv/target:libtpu"],
# compatible with libtpu
deps = [
":tpu_inc_gen",
"@llvm-project//llvm:Support",
Expand All @@ -72,7 +72,7 @@ cc_library(

gentbl_cc_library(
name = "tpu_inc_gen",
compatible_with = ["//buildenv/target:libtpu"],
# compatible with libtpu
tbl_outs = [
(
["-gen-op-decls"],
Expand Down Expand Up @@ -146,7 +146,7 @@ td_library(
srcs = [
"dialect/tpu/tpu.td",
],
compatible_with = ["//buildenv/target:libtpu"],
# compatible with libtpu
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
Expand Down
13 changes: 4 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Expand Up @@ -60,20 +60,15 @@ using ImplicitDim = VectorLayout::ImplicitDim;

static constexpr int kLayoutLog = 10;

template<typename T>
class Print {
public:
explicit Print(T* t) : payload_(t) {}
T* payload_;
explicit Print(Operation* t) : payload_(t) {}
Operation* payload_;
private:
friend std::ostream &operator<<(std::ostream &, Print<T> &&);
friend std::ostream &operator<<(std::ostream &, Print);
};

template<typename T>
Print(T& t) -> Print<T>;

template<typename T>
std::ostream &operator<<(std::ostream &os, Print<T> p) {
std::ostream &operator<<(std::ostream &os, Print p) {
std::string s;
llvm::raw_string_ostream tmp_os(s);
p.payload_->print(tmp_os);
Expand Down
7 changes: 4 additions & 3 deletions jaxlib/mosaic/python/BUILD
Expand Up @@ -39,7 +39,7 @@ genrule(
name = "tpu_python_gen",
srcs = ["_tpu_gen_raw.py"],
outs = ["_tpu_gen.py"],
cmd = "cat $(location _tpu_gen_raw.py) | sed -e 's/^from \\./from mlir\\.dialects\\./g' > $@",
cmd = "cat $(location _tpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
)

py_library(
Expand All @@ -60,9 +60,9 @@ pybind_extension(
name = "_tpu_ext",
srcs = ["tpu_ext.cc"],
deps = [
"//jaxlib/mosaic:tpu_dialect_capi",
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi_shared_library",
"//jaxlib/mosaic:tpu_dialect_capi_headers",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@pybind11",
],
Expand All @@ -83,6 +83,7 @@ py_library(
"//jaxlib/mlir:func_dialect",
"//jaxlib/mlir:ir",
"//jaxlib/mlir:math_dialect",
"//jaxlib/mlir:memref_dialect",
"//jaxlib/mlir:pass_manager",
"//jaxlib/mlir:scf_dialect",
"//jaxlib/mlir:stablehlo_dialect",
Expand Down

0 comments on commit 3baa6e7

Please sign in to comment.