[RFC] StableHLO v1.0 Opset Deprecations & Cleanups (#2283)
A proposal to remove redundant operations from StableHLO before
long-term compatibility guarantees go into place.

High level summary:
- Remove `CreateTokenOp`, `TraceOp`, `BroadcastOp`, `DotOp`,
`UnaryEinsumOp`, `RealDynamicSliceOp`.
- Enhance `DynamicSliceOp`.
- Move `CrossReplicaSumOp` to CHLO.
- Hopefully remove/move to CHLO (need feedback) `MapOp`, `RngOp`,
`EinsumOp`, `TorchIndexSelectOp`, `GetTupleElementOp`, `tuple` and `tuple` type.

OpenXLA Discuss post:
https://groups.google.com/a/openxla.org/g/openxla-discuss/c/sBAkvnd2bcA

Related tickets: #2176, #3
GleasonK committed May 13, 2024
1 parent 9b38205 commit 06bcb0d
Showing 2 changed files with 197 additions and 21 deletions.
105 changes: 84 additions & 21 deletions docs/spec.md
@@ -316,26 +316,6 @@ inputs/outputs and a signature. The name consists of the `stablehlo.` prefix and
a **mnemonic** which uniquely identifies one of the supported ops. See below for
a comprehensive list of all supported ops.

At the moment, StableHLO programs in the wild sometimes contain operations that
are not described in this document. In the future, we are planning to either
absorb these operations into the StableHLO opset or prohibit them from appearing
in StableHLO programs. In the meanwhile, here is the list of these operations:

* `builtin.module`, `func.func`, `func.call` and `func.return`
([#425](https://github.com/openxla/stablehlo/issues/425)).
* `chlo` operations ([#602](https://github.com/openxla/stablehlo/issues/602)).
* "Not in HLO" category of StableHLO operations - they were initially part of
the StableHLO opset but have been later deemed to not fit it well:
`broadcast`, `create_token`, `cross-replica-sum`, `dot`, `einsum`,
`torch_index_select`, `unary_einsum`
([#3](https://github.com/openxla/stablehlo/issues/3)).
* "Dynamism" category of StableHLO operations - they were bootstrapped from
MHLO, and we are in the process of speccing them: `real_dynamic_slice`,
`set_dimension_size`
([#8](https://github.com/openxla/stablehlo/issues/8)).
* Shape computations, including `arith`, `shape` and `tensor` operations
([#8](https://github.com/openxla/stablehlo/issues/8)).

```ebnf
OpInputs ::= OpInputValues OpInputFuncs OpInputAttrs
OpInputValues ::= '(' [OpInputValue {',' OpInputValue}] ')'
@@ -3668,6 +3648,11 @@ component of the type. The element-type could be anything.

### get_tuple_element

> Note: Per [StableHLO v1.0 Cleanup #2283](https://github.com/openxla/stablehlo/pull/2283),
> this op is being explored for deprecation as it appears to be unused by both
> frameworks and compilers. As such, it has limited compatibility guarantees
> (6 months).

#### Semantics

Extracts element at `index` position of the `operand` tuple and produces a
@@ -3695,7 +3680,6 @@ Extracts element at `index` position of the `operand` tuple and produces a

```mlir
// %operand: ([1.0, 2.0], (3))
%result = "stablehlo.get_tuple_element"(%operand) {
  index = 0 : i32
} : (tuple<tensor<2xf32>, tuple<tensor<i32>>>) -> tensor<2xf32>
// %result: [1.0, 2.0]
@@ -4039,6 +4023,11 @@ Performs element-wise logistic operation on `operand` tensor and produces a

### map

> Note: Per [StableHLO v1.0 Cleanup #2283](https://github.com/openxla/stablehlo/pull/2283),
> this op is being explored for deprecation as it appears to be unused by both
> frameworks and compilers. As such, it has limited compatibility guarantees
> (6 months).

#### Semantics

Applies a map function `computation` to `inputs` along the `dimensions` and
@@ -5170,6 +5159,11 @@ and produces a `result` tensor. More formally,

### rng

> Note: Per [StableHLO v1.0 Cleanup #2283](https://github.com/openxla/stablehlo/pull/2283),
> this op is being explored for deprecation as it appears to be unused by both
> frameworks and compilers. As such, it has limited compatibility guarantees
> (6 months).

#### Semantics

Generates random numbers using the `rng_distribution` algorithm and produces a
@@ -6337,6 +6331,11 @@ unit_diagonal, transpose_a), a, b, type(result))`.

### tuple

> Note: Per [StableHLO v1.0 Cleanup #2283](https://github.com/openxla/stablehlo/pull/2283),
> this op is being explored for deprecation as it appears to be unused by both
> frameworks and compilers. As such, it has limited compatibility guarantees
> (6 months).

#### Semantics

Produces a `result` tuple from values `val`.
@@ -6559,6 +6558,70 @@ tensor. Depending on the element type, does the following:

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/xor.mlir)

## Dialect Interop

At the moment, StableHLO programs in the wild sometimes contain operations that
are not defined by StableHLO.

### Module, Function, Call and Return

StableHLO uses upstream MLIR operations for ModuleOp, FuncOp, CallOp, and
ReturnOp. This was done for better interop with existing MLIR machinery, as many
useful passes are written targeting FuncOp and ModuleOp, and many compilation
pipelines expect these ops to be present. Full compatibility guarantees are
applied to these ops. If these ops ever change in an incompatible way
(e.g. removal), StableHLO equivalents will be added to preserve compatibility.
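
For illustration, here is a minimal sketch of a StableHLO program built on
these upstream ops (function names and shapes are hypothetical):

```mlir
module {
  func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    // Call a helper function defined in the same module.
    %0 = func.call @double(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
    func.return %0 : tensor<4xf32>
  }

  func.func @double(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    %0 = stablehlo.add %arg0, %arg0 : tensor<4xf32>
    func.return %0 : tensor<4xf32>
  }
}
```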

### CHLO

The CHLO opset contains higher-level operations that decompose to StableHLO.
Currently there are no compatibility guarantees for CHLO. For compatibility
guarantees, the [chlo-legalize-to-stablehlo pass](https://github.com/openxla/stablehlo/blob/12fd0a9e7b3c6f3dea3defc513870c962e62726d/stablehlo/transforms/Passes.td#L119)
must be used prior to serialization.
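
As a sketch of the workflow (op and value names are illustrative), a program
containing CHLO ops is decomposed before serialization:

```mlir
// Input containing a CHLO op, which carries no compatibility guarantees:
func.func @f(%lhs: tensor<?xf32>, %rhs: tensor<?xf32>) -> tensor<?xf32> {
  %0 = chlo.broadcast_add %lhs, %rhs : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}
// Running the pass, e.g. `stablehlo-opt --chlo-legalize-to-stablehlo`,
// rewrites `chlo.broadcast_add` into StableHLO shape computation, broadcast,
// and add ops, which do carry compatibility guarantees.
```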

### Shape Operations

It is a common use case in the community to use certain operations from core
MLIR dialects in dynamic StableHLO programs to perform shape computations.
Most commonly, these include [`shape` dialect](https://mlir.llvm.org/docs/Dialects/ShapeDialect/)
ops like `shape_of` or `num_elements`, [`tensor` dialect](https://mlir.llvm.org/docs/Dialects/TensorOps/)
ops like `dim` or `from_elements`, and the builtin `index` type.

The [Dynamism RFC > O2](https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md#o2)
denotes these as out of scope; however, some support for `index` types is
included for interop purposes. There are no compatibility guarantees for these
ops or types. The [shape-legalize-to-stablehlo](https://github.com/openxla/stablehlo/blob/12fd0a9e7b3c6f3dea3defc513870c962e62726d/stablehlo/transforms/Passes.td#L136)
pass can be used to convert these operations to fully supported StableHLO ops.
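
For example, a dynamically shaped program might compute an output shape with
`tensor` dialect ops and `index` values before handing it to a StableHLO op
(a minimal sketch; names and the exact attribute syntax may vary by version):

```mlir
func.func @add_scalar(%arg0: tensor<?xf32>, %scalar: tensor<f32>) -> tensor<?xf32> {
  // Compute the dynamic output shape using `tensor` dialect ops.
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %shape = tensor.from_elements %dim : tensor<1xindex>
  // Broadcast the scalar to the dynamic shape, then add.
  %bcast = "stablehlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
    broadcast_dimensions = array<i64>
  } : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
  %result = stablehlo.add %arg0, %bcast : tensor<?xf32>
  func.return %result : tensor<?xf32>
}
```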

## Deprecated Operations

Several StableHLO operations inherited from
[MHLO](https://github.com/openxla/xla/blob/d63deb9250b9c212445290bd08c6effb5b6d0a2b/xla/mlir_hlo/mhlo/IR/hlo_ops.td)
are deprecated and on the way out of StableHLO. The full details on these
removals can be found in the [StableHLO v1.0 Cleanup #2283](https://github.com/openxla/stablehlo/pull/2283).

These operations fall into a few categories:

* "Not in HLO" category of StableHLO operations - they were initially part of
the StableHLO opset but have been later deemed to not fit it well:
`broadcast`, `create_token`, `cross-replica-sum`, `dot`, `einsum`,
`torch_index_select`, `unary_einsum`
([#3](https://github.com/openxla/stablehlo/issues/3)).
* Unused ops - These operations may have been useful at some point, but they
were either underdeveloped, or the pipelines using them have been refactored
to no longer require them. This includes `map`, `tuple`, `get_tuple_element`,
and `rng`.

Some of these ops can be removed easily given that they can be expressed using
existing ops (`broadcast`, `create_token`, `cross-replica-sum`, `dot`,
`unary_einsum`) and will be removed after the existing compatibility window
passes (6 months). Others are still being explored for removal (`einsum`,
`get_tuple_element`, `map`, `rng`, `torch_index_select`, `tuple`). Pending
community feedback, these ops will either be removed, or added to the spec with
full support. Until these ops' futures are known, they are only guaranteed 6
months of compatibility.
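
As one sketch of such a rewrite (value names and attribute syntax are
illustrative), `broadcast`, which prepends the given sizes to the operand
shape, maps onto `broadcast_in_dim`:

```mlir
// Deprecated: prepends a dimension of size 2 to the operand's shape.
%0 = "stablehlo.broadcast"(%operand) {
  broadcast_sizes = array<i64: 2>
} : (tensor<3xf32>) -> tensor<2x3xf32>

// Supported equivalent: operand dimension 0 maps to result dimension 1.
%1 = "stablehlo.broadcast_in_dim"(%operand) {
  broadcast_dimensions = array<i64: 1>
} : (tensor<3xf32>) -> tensor<2x3xf32>
```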

## Execution

### Sequential execution
Expand Down
113 changes: 113 additions & 0 deletions rfcs/20240503-opset-deprecations.md
@@ -0,0 +1,113 @@
# RFC: StableHLO v1.0 Opset Deprecations & Cleanups

Author: gleasonk<br/>
Last Modified: 5/3/24<br/>
Status: In review<br/>

## Background

This doc covers a list of opset cleanups that we want to do for StableHLO v1.0.
Most of these ops were never spec’ed and therefore have no formal compatibility
guarantees, per the [Unspecced Features compatibility exemption][compat-out-of-scope];
however, we can provide some backward/forward compatibility for most of them.

This doc proposes futures for the ops intentionally omitted from the spec,
including:

- [“Not in HLO” ops][not-in-HLO] ([#3](https://github.com/openxla/stablehlo/issues/3)):
`broadcast`, `create_token`, `cross-replica-sum`, `dot`, `einsum`,
`torch_index_select`, `unary_einsum`, `trace` ([#604](https://github.com/openxla/stablehlo/issues/604)).
- [Dynamism RFC P4](https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md#p4)
opset updates: `real_dynamic_slice` vs `dynamic_slice`.
- Potentially unused ops like the stateful `rng` ([#597](https://github.com/openxla/stablehlo/issues/597))
and `map`.
- Tuple ops and type, including the `get_tuple_element` and `tuple` ops, along
with `tuple` type support in `custom_call` ([#598](https://github.com/openxla/stablehlo/issues/598)).

In general (unless the op is unused and can be trivially deleted), the
deprecation steps will be as follows:

1. Migrate framework uses of redundant ops.
1. Block serialization of deprecated ops once frameworks have migrated.
1. Migrate uses of the ops to the supported StableHLO op (add builder methods).
1. Change VHLO legalization to upgrade to the supported op for compatibility.
1. Remove the redundant StableHLO op.
1. Remove redundant op from VHLO after 6 months.

## Proposed Opset Changes

### P0: Delete `CreateTokenOp` and `TraceOp`

These ops are both unused as far as we can tell. They can be trivially deleted.

### P1: Deprecate `BroadcastOp`, `DotOp`, `UnaryEinsumOp`

Each of these ops is a trivial subset of another op's features: `BroadcastOp`
can be represented using `BroadcastInDimOp`, `DotOp` with `DotGeneralOp`, and
`UnaryEinsumOp` with the [`einsum` lowering][einsum-lowering].
These ops will follow the formal deprecation process listed above.

Helper methods can be added to the supported op for compatibility, something
like: `isLhsBroadcast`, `isSimpleDot`, `isUnaryEinsum`.
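
For example (a sketch; `precision_config` omitted for brevity), a 2-D `dot`
corresponds to a `dot_general` with no batching dimensions and the standard
contracting dimensions:

```mlir
// Deprecated: 2-D matrix multiplication via `dot`.
%0 = "stablehlo.dot"(%lhs, %rhs)
  : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>

// Supported equivalent via `dot_general`.
%1 = "stablehlo.dot_general"(%lhs, %rhs) {
  dot_dimension_numbers = #stablehlo.dot<
    lhs_contracting_dimensions = [1],
    rhs_contracting_dimensions = [0]
  >
} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>
```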

### P2: Deprecate `RealDynamicSliceOp`, Enhance `DynamicSliceOp`

In terms of naming, `stablehlo.dynamic_slice` is more in-model than
`real_dynamic_slice`. However, in terms of functionality, per
[Dynamism RFC P4](https://github.com/openxla/stablehlo/blob/main/rfcs/20230704-dynamism-101.md#p4),
the behavior of `real_dynamic_slice` is correct. We propose to enhance
`dynamic_slice` to have a feature set identical to `real_dynamic_slice`, and to
deprecate `real_dynamic_slice`. This change will be done with full
forward and backward compatibility.
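
To make the difference concrete, here is a sketch of the two ops as they stand
today (value names are illustrative): `dynamic_slice` takes 0-dimensional start
indices plus static `slice_sizes`, while `real_dynamic_slice` takes
1-dimensional start/limit/strides tensors and can produce dynamically shaped
results:

```mlir
// `dynamic_slice`: dynamic starts, static slice sizes.
%0 = "stablehlo.dynamic_slice"(%operand, %start) {
  slice_sizes = array<i64: 2>
} : (tensor<4xf32>, tensor<i64>) -> tensor<2xf32>

// `real_dynamic_slice`: fully dynamic starts, limits, and strides.
%1 = "stablehlo.real_dynamic_slice"(%operand, %starts, %limits, %strides)
  : (tensor<4xf32>, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor<?xf32>
```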

One could make the argument that `dot` is a more proper name than `dot_general`,
and I'm happy to go down that route, but it will likely cause a good deal of
code churn in community repos. Interested in feedback here.

### P3: Move `CrossReplicaSumOp` to CHLO

The `cross-replica-sum` op (hyphens not a typo) is just sugar for an
`all_reduce` op. Even in the XlaBuilder's [xla::CrossReplicaSum][CRS]
implementation, this op is decomposed into an all-reduce. We could just remove
this op, and eventually we may, but we propose to move it to CHLO in the short
term since frameworks map to this op, and this will keep the refactoring fairly
trivial.
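
As a sketch of the desugaring (replica groups are illustrative;
`channel_handle` shown for completeness):

```mlir
// Sugar: sums `%operand` across the replicas in each group.
%0 = "stablehlo.cross-replica-sum"(%operand) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<4xf32>) -> tensor<4xf32>

// Desugared form: an `all_reduce` whose reduction body is an addition.
%1 = "stablehlo.all_reduce"(%operand) ({
  ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
    %sum = stablehlo.add %arg0, %arg1 : tensor<f32>
    stablehlo.return %sum : tensor<f32>
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xf32>) -> tensor<4xf32>
```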

### P4: Deprecate `MapOp`, `RngOp`, `EinsumOp`, `TorchIndexSelectOp`, Tuple support

**Feedback Requested:** These opset changes are pending community feedback.

These are all ops that seem to have very limited use in StableHLO. It would be
great to remove them all or move them to CHLO, as opposed to providing long term
compatibility on ops that aren't needed.

In the interim, we only plan to guarantee the existing 6 month compatibility
guarantees until these ops' futures are more clearly known.

- **MapOp** is unused as far as we can tell, including in HLO. Its uses tend to
be just for a region to mimic a composite, which is no longer needed after the
addition of the `stablehlo.composite` op (see the sketch after this list). This
op likely can be removed.
- **RngOp** is stateful, and there is a better alternative in
`RngBitGeneratorOp`. More work needs to be done to determine if all uses of this
op can be safely migrated to the alternative.
- **EinsumOp** can likely be moved to CHLO; the [xla::Einsum][einsum] method is
similarly a decomposition. It is unclear how necessary this abstraction is for
linalg lowerings, though.
- **TorchIndexSelectOp** can also likely be moved to CHLO. There is an existing
[lowering to `gather`][torch-index-select] which can be used for a
decomposition. However, similar to `einsum`, it is unclear how necessary this
abstraction is to the community.
- **Tuple Support** includes the `get_tuple_element` and `tuple` ops, along with
support for the `tuple` type in `custom_call` ([#598](https://github.com/openxla/stablehlo/issues/598)).
The use of tuples in MLIR is limited, and these are mostly kept around for
interop with XLA and other dialects.
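
To illustrate the `map` point above (a sketch; value names and attribute
syntax are illustrative), a `map` whose body is a single elementwise op applied
across all dimensions is just that elementwise op:

```mlir
// A `map` that multiplies two tensors elementwise...
%0 = "stablehlo.map"(%a, %b) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %prod = stablehlo.multiply %arg0, %arg1 : tensor<i64>
    stablehlo.return %prod : tensor<i64>
}) {
  dimensions = array<i64: 0, 1>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>

// ...is equivalent to applying the elementwise op directly:
%1 = stablehlo.multiply %a, %b : tensor<2x2xi64>
```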

Interested in feedback on any of the above proposals, or ideas for how to keep
these changes from being too invasive to community projects!

[compat-out-of-scope]: https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md#out-of-scope
[not-in-HLO]: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#:~:text=%22Not%20in%20HLO%22,-category%20of%20StableHLO
[CRS]: https://github.com/openxla/xla/blob/6cc24d8548094b3fc94dacc569fc6959227ae28b/xla/client/xla_builder.cc#L3619
[einsum]: https://github.com/openxla/xla/blob/8371ea90202d9ca1cb1148237a1a1ef3620b354a/xla/client/lib/matrix.cc#L386
[einsum-lowering]: https://github.com/openxla/xla/blob/6cc24d8548094b3fc94dacc569fc6959227ae28b/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td#L30
[torch-index-select]: https://github.com/openxla/xla/blob/8371ea90202d9ca1cb1148237a1a1ef3620b354a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc#L45
