-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a helper function for building custom call lowering rules #21484
Conversation
jax/_src/custom_call.py
Outdated
|
||
def _ir_attribute(obj: Any) -> ir.Attribute: | ||
# TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these | ||
# could be consolidated into mlir or similar. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm inclined to keep this here for now to keep this PR compact, then we could move this function later if it would be useful to consolidate.
768c4a4
to
1044921
Compare
jax/_src/extend/ffi.py
Outdated
kwargs["backend_config"] = dict( | ||
backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()}) | ||
|
||
# TODO(dfm): This is a common pattern for supporting dynamic shapes, but I |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is really necessary here. CC @gnecula.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just removed this. We can revisit in the future if it's a feature that many users would find useful.
Thanks for the reviews, @superbobry! I think this is getting close to ready. Here are the main updates:
There's one inline question about supporting dynamic shapes out of the box, but I'm not very familiar with that use case. Otherwise, I think this is good to go! |
1c4cef6
to
687f594
Compare
This function provides sensible defaults for custom call lowering rules with the goal of reducing the amount of boilerplate required for implementing custom calls. Co-authored-by: Sergei Lebedev <slebedev@google.com>
@superbobry — Can you take another look at this when you get a chance? I'm happy with it as it stands now: I think it's a useful feature (it would remove a lot of boilerplate from #21645, for example), but I don't think we're overpromising. Let me know what you think. Thanks! |
Looks great, ship it! |
|
||
Args: | ||
call_target_name: The name of the custom call target. | ||
operand_layouts: A sequence of layouts (dimension orders) for each operand. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW I don't think this (layouts) is fully supported today, it's easy to add strides to capture buffer layout, but no one asked for it so far and we don't have any examples
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezhulenev — Oh interesting - good to know! So, if I understand correctly, the custom call always receive buffers in row-major(?) order regardless of the layout
values passed to mlir?
This function provides sensible defaults for custom call lowering rules with the goal of reducing the amount of boilerplate required for implementing custom calls. While all of the behavior can be overridden, this will be particularly useful when (a) the inputs and outputs are in row-major order because then the layouts and types can be evaluated from the context avals, and/or (b) the custom call supports
api_version=4
attributes provided via abackend_config
dictionary.There are a number of lowering rules in the core library that could be rewritten using this helper, but I think it probably wouldn't make sense to port those as part of this PR.
For now the tests only check that lowering succeeds for parameters/attributes of different types, but it might be useful to check that the output is appropriate. I haven't figured out a reasonable way to do that yet - thoughts would be appreciated!
@superbobry — I'd love to hear your feedback. Since this is closely based on code you wrote previously, I added you as a co-author on the commit, but I'm happy to remove that if you'd prefer.