Skip to content

Commit

Permalink
Remove the monkey patch in jax2tf by moving the function to mlir.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 539266562
  • Loading branch information
yashk2810 authored and jax authors committed Jun 10, 2023
1 parent c287b2a commit 1a7336d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 28 deletions.
17 changes: 17 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -1960,3 +1961,19 @@ def custom_call(
operands = list(operands) + list(result_shapes)

return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)


def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
Given a module with static input shapes, but using dynamic shapes due to
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
"""
if xc.mlir_api_version < 50:
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")

refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
module_to_bytecode(module))
context = make_ir_context()
with context:
return ir.Module.parse(refined_module_str)
8 changes: 1 addition & 7 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,6 @@ def batched_device_put(aval: core.ShapedArray,
return xc.batched_device_put(aval, sharding, xs, devices, committed) # type: ignore


def refine_shape_polymorphism(module: ir.Module) -> ir.Module:
# In order to avoid depending on jax2tf/jax_export.py we will monkey patch
# this from jax_export to refine the polymorphic shapes in the module.
raise NotImplementedError("Compiling modules with shape polymorphism")


# NOTE(skye): we could refactor to generate _multi_slice parameters directly
# from the input ShardingSpec, rather than the indices. However, this would
# require duplicating the ordering logic of spec_to_indices, which is more
Expand Down Expand Up @@ -2631,7 +2625,7 @@ def from_hlo(name: str,
compiler_options=None
) -> MeshExecutable:
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = refine_shape_polymorphism(hlo)
hlo = mlir.refine_polymorphic_shapes(hlo)
compiler_options_keys = tuple(
compiler_options.keys()) if compiler_options is not None else None
compiler_options_values = tuple(
Expand Down
21 changes: 0 additions & 21 deletions jax/experimental/jax2tf/jax_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import jax
from jax import sharding
from jax.lib import xla_client as xc

from jax._src import core
from jax._src import dispatch
Expand All @@ -38,7 +37,6 @@
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_extension
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
Expand Down Expand Up @@ -854,22 +852,3 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
mlir.register_lowering(call_exported_p,
functools.partial(_call_exported_lowering, platform=_p),
platform=_p)


def _refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
"""Refine the polymorphic shapes inside a module.
Given a module with static input shapes, but using dynamic shapes due to
shape polymorphism, run shape refinement to resolve all the dynamic shapes.
"""
if xc.mlir_api_version < 50:
raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12")

refined_module_str = xla_extension.mlir.refine_polymorphic_shapes(
mlir.module_to_bytecode(module)
)
context = mlir.make_ir_context()
with context:
return ir.Module.parse(refined_module_str)

pxla.refine_shape_polymorphism = _refine_polymorphic_shapes

0 comments on commit 1a7336d

Please sign in to comment.