Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a helper function for building custom call lowering rules #21484

Merged
merged 1 commit into from
Jun 6, 2024

Conversation

dfm
Copy link
Member

@dfm dfm commented May 29, 2024

This function provides sensible defaults for custom call lowering rules with the goal of reducing the amount of boilerplate required for implementing custom calls. While all of the behavior can be overridden, this will be particularly useful when (a) the inputs and outputs are in row-major order because then the layouts and types can be evaluated from the context avals, and/or (b) the custom call supports api_version=4 attributes provided via a backend_config dictionary.

There are a number of lowering rules in the core library that could be rewritten using this helper, but I think it probably wouldn't make sense to port those as part of this PR.

For now the tests only check that lowering succeeds for parameters/attributes of different types, but it might be useful to check that the output is appropriate. I haven't figured out a reasonable way to do that yet - thoughts would be appreciated!

@superbobry — I'd love to hear your feedback. Since this is closely based on code you wrote previously, I added you as a co-author on the commit, but I'm happy to remove that if you'd prefer.

@dfm dfm requested a review from superbobry May 29, 2024 13:57
@dfm dfm self-assigned this May 29, 2024
@dfm dfm added the pull ready Ready for copybara import and testing label May 29, 2024
jax/__init__.py Outdated Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved

def _ir_attribute(obj: Any) -> ir.Attribute:
# TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these
# could be consolidated into mlir or similar.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to keep this here for now to keep this PR compact, then we could move this function later if it would be useful to consolidate.

jax/_src/custom_call.py Outdated Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved
@dfm dfm force-pushed the custom-call-lowering branch 4 times, most recently from 768c4a4 to 1044921 Compare June 3, 2024 00:23
jax/_src/extend/ffi.py Show resolved Hide resolved
jax/_src/custom_call.py Outdated Show resolved Hide resolved
jax/_src/extend/ffi.py Outdated Show resolved Hide resolved
jax/_src/extend/ffi.py Outdated Show resolved Hide resolved
jax/_src/extend/ffi.py Outdated Show resolved Hide resolved
jax/_src/extend/ffi.py Outdated Show resolved Hide resolved
kwargs["backend_config"] = dict(
backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()})

# TODO(dfm): This is a common pattern for supporting dynamic shapes, but I
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is really necessary here. CC @gnecula.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just removed this. We can revisit in the future if it's a feature that many users would find useful.

@dfm
Copy link
Member Author

dfm commented Jun 3, 2024

Thanks for the reviews, @superbobry! I think this is getting close to ready. Here are the main updates:

  1. After conversations with @superbobry and @froystig, I agreed that it makes sense for this to live somewhere under jax.extend. In particular, I think it actually makes the most sense for this to be in jax.extend.ffi, and then we can focus specifically on supporting the new XLA FFI interface, rather than general custom calls. With this in mind, I've renamed the function to jax.extend.ffi.ffi_lowering.
  2. I updated the docstring to omit any discussion of primitives (or binding) for now. As we settle on a public API for primitives, this could be updated.

There's one inline question about supporting dynamic shapes out of the box, but I'm not very familiar with that use case. Otherwise, I think this is good to go!

@dfm dfm force-pushed the custom-call-lowering branch 2 times, most recently from 1c4cef6 to 687f594 Compare June 6, 2024 15:32
This function provides sensible defaults for custom call lowering rules
with the goal of reducing the amount of boilerplate required for
implementing custom calls.

Co-authored-by: Sergei Lebedev <slebedev@google.com>
@dfm
Copy link
Member Author

dfm commented Jun 6, 2024

@superbobry — Can you take another look at this when you get a chance? I'm happy with it as it stands now: I think it's a useful feature (it would remove a lot of boilerplate from #21645, for example), but I don't think we're overpromising. Let me know what you think. Thanks!

@superbobry
Copy link
Member

Looks great, ship it!

@copybara-service copybara-service bot merged commit 90c83bb into google:main Jun 6, 2024
12 of 13 checks passed

Args:
call_target_name: The name of the custom call target.
operand_layouts: A sequence of layouts (dimension orders) for each operand.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I don't think this (layouts) is fully supported today, it's easy to add strides to capture buffer layout, but no one asked for it so far and we don't have any examples

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ezhulenev — Oh interesting - good to know! So, if I understand correctly, the custom call always receive buffers in row-major(?) order regardless of the layout values passed to mlir?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants