Skip to content

Conversation

@froystig
Copy link
Member

@froystig froystig commented Sep 24, 2021

This is an initial take on #7733 corresponding to jax.jit

Example usage:

>>> import jax; from jax import numpy as jnp
>>> def f(x): return jnp.sqrt(x ** 2) + 1.
>>> f_jit = jax.jit(f)

>>> f_low = f_jit.lower(1.)
>>> print(f_low._xla_computation().as_hlo_text())
HloModule jit_f__1.8

ENTRY jit_f__1.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  multiply.3 = f32[] multiply(parameter.1, parameter.1)
  sqrt.4 = f32[] sqrt(multiply.3)
  constant.5 = f32[] constant(1)
  add.6 = f32[] add(sqrt.4, constant.5)
  ROOT tuple.7 = (f32[]) tuple(add.6)
}

>>> f_exe = f_low.compile()
>>> print(f_exe._xla_executable().hlo_modules()[0].to_string())
HloModule jit_f__1.8

%fused_computation (param_0.2: f32[]) -> f32[] {
  %param_0.2 = f32[] parameter(0)
  %abs.1 = f32[] abs(f32[] %param_0.2), metadata={op_type="sqrt" op_name="jit(f)/sqrt" source_file="<ipython-input-2-f5f7990cdb84>" source_line=2}
  %constant.0 = f32[] constant(1), metadata={op_type="add" op_name="jit(f)/add" source_file="<ipython-input-2-f5f7990cdb84>" source_line=2}
  ROOT %add.0 = f32[] add(f32[] %abs.1, f32[] %constant.0), metadata={op_type="add" op_name="jit(f)/add" source_file="<ipython-input-2-f5f7990cdb84>" source_line=2}
}

ENTRY %jit_f__1.8 (parameter.1: f32[]) -> (f32[]) {
  %parameter.1 = f32[] parameter(0)
  %fusion = f32[] fusion(f32[] %parameter.1), kind=kLoop, calls=%fused_computation, metadata={op_type="add" op_name="jit(f)/add" source_file="<ipython-input-2-f5f7990cdb84>" source_line=2}
  ROOT %tuple.7 = (f32[]) tuple(f32[] %fusion)
}

>>> f_exe(3.)
DeviceArray(4., dtype=float32, weak_type=True)

>>> f_exe([3.])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-048b0cc42154> in <module>
----> 1 f_exe([3.])

~/ws/jax/jax/_src/api.py in __call__(self, *args, **kwargs)
    540     if in_tree != self.in_tree:
    541       # TODO(frostig): provide more info about the source function
--> 542       raise TypeError(
    543           f'function compiled for {self.in_tree}, called with {in_tree}')
    544     out_flat = self.computation.call(*args_flat)

TypeError: function compiled for PyTreeDef(((*,), {})), called with PyTreeDef((([*],), {}))

Note the compiler optimization between lowering and compilation, from XLA.

We have yet to determine exactly how to expose (functions of) the lowered computation and compiled executable, so for now the example above reaches for the underlying computation and executable directly, via a hidden method (leading underscore).

@froystig froystig self-assigned this Sep 24, 2021
@google-cla google-cla bot added the cla: yes label Sep 24, 2021
Copy link
Member

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have more caching? I would expect calling jit(f).lower(*args).compile() to cause cache hits when I jit(f)(*args) later.

@froystig
Copy link
Member Author

Yep, that's part of the plan. I might scope this PR to a cache-free version of things first, to get the refactoring and basic functionality in place.

@froystig froystig marked this pull request as ready for review September 30, 2021 02:02
@froystig froystig added the pull ready Ready for copybara import and testing label Sep 30, 2021
Copy link
Member

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK this is only supposed to be a first step and to that end LGTM. I left some comments that you might want to address, but it's fine by me to do it later.

out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
name=flat_fun.__name__, donated_invars=donated_invars, inline=inline)
return tree_unflatten(out_tree(), out)
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider returning a dict you later splat as kwargs? that's what I do with xmap and pjit to avoid messing up the order of returns/binders

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also called it _infer_params, which I think is a better description for this function than _prepare_jit, but I don't have a super strong preference

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla: yes pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants