Skip to content

[water] Add FX importer handlers for arithmetic, unary, and attention-specific ops#1054

Merged
martin-luecke merged 4 commits intomainfrom
users/martin/attention_op_handlers
Mar 6, 2026
Merged

[water] Add FX importer handlers for arithmetic, unary, and attention-specific ops#1054
martin-luecke merged 4 commits intomainfrom
users/martin/attention_op_handlers

Conversation

@martin-luecke
Copy link
Contributor

This adds MLIR-to-FX import handlers for the ops needed by attention kernels: binary arithmetic (add, sub, mul, max, min, select), unary (exp2, reciprocal), cast, permute, self_index, apply_expr, extract, and reshape.

The apply_expr handler reconstructs the original sympy expression lambda from the MLIR WaveExprListAttr and combinator attribute, folding the affine map results back through the appropriate sympy constructor.
As the reconstructed lambda is a new Python object, the existing trace equivalence checker cannot compare it by identity. A _check_callable_equivalent helper is added that evaluates both lambdas with fresh symbolic inputs and verifies the resulting sympy expressions are equivalent via simplify(a - b) == 0. This allows roundtrip tests to confirm that ApplyExpr nodes carry semantically identical expressions despite being distinct function objects.

@martin-luecke martin-luecke force-pushed the users/martin/attention_op_handlers branch from c60315d to 2bcc7ac Compare March 5, 2026 21:38
@martin-luecke martin-luecke reopened this Mar 5, 2026
@martin-luecke martin-luecke requested a review from ftynse March 5, 2026 22:09
Copy link
Contributor

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing major

Comment on lines +515 to +547
def _assert_pre_decompose_roundtrip(kernel, subs: dict, label: str) -> None:
"""Compile through graph passes (stopping before decompose_reduce_ops),
emit MLIR, roundtrip, and assert equivalence.

Runs the full pipeline up to but not including `decompose_reduce_ops`
(which introduces `wave.shuffle` ops not yet supported by the
importer).
"""
options = WaveCompileOptions(subs=subs, compile_to_mlir=True)
with IndexingContext() as idxc:
idxc.set_subs(options.subs)
kernel.initialize_wave_constraints()
kernel.initialize_symbolic_constraints()
kernel.initialize_workgroup_constraints()
trace = kernel._trace(location_capture_config=options.location_capture_config)
graph_passes = build_graph_passes(kernel, trace, options)
for p in graph_passes:
if p.__name__ == decompose_reduce_ops.__name__:
break
p()

mlir_text, diagnostics, _ = emitter.emit_wave_dialect(
trace, kernel.constraints, options
)
errors = error_diagnostics(diagnostics)
assert errors == [], f"[{label}] unexpected emit errors: {errors}"

fx_trace, fx_constraints, fx_options, fx_diags = emitter.mlir_to_fx(mlir_text)
errors = error_diagnostics(fx_diags)
assert errors == [], f"[{label}] unexpected import errors: {errors}"

assert_traces_equivalent(trace, fx_trace, subs=options.subs)
print(f" {label}: OK")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we generalize this with the function for pre-canonical forms and take the name of the before which to stop as an argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generalized this with the _assert_roundtrip above

Base automatically changed from users/martin/multi_operand_reduce to main March 6, 2026 17:41
…-specific ops

Signed-off-by: Martin Lücke <martin.luecke@amd.com>
Signed-off-by: Martin Lücke <martin.luecke@amd.com>
@martin-luecke martin-luecke force-pushed the users/martin/attention_op_handlers branch from 6444ba1 to 7a93129 Compare March 6, 2026 17:47
Signed-off-by: Martin Lücke <martin.luecke@amd.com>
Signed-off-by: Martin Lücke <martin.luecke@amd.com>
@martin-luecke martin-luecke merged commit 0e1d128 into main Mar 6, 2026
17 checks passed
@martin-luecke martin-luecke deleted the users/martin/attention_op_handlers branch March 6, 2026 18:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants