Add num_replicas and num_partitions to ModuleOp #650

Open
sdasgup3 opened this issue Nov 29, 2022 · 1 comment

Comments

@sdasgup3
Member

sdasgup3 commented Nov 29, 2022

Request description

The statically known values num_replicas and num_partitions, provided by HLOModuleConfig, help determine the size of each process group used in parallel execution. Currently these values are not exposed in StableHLO.

Having these values in StableHLO will enable:

  1. Type inference of CollectiveOps. For example, using num_replicas and num_partitions we can determine shard_count using GetSubgroupSize. With that we can say: dim(result, all_gather_dim) = shard_count * dim(operand, all_gather_dim).

  2. Populating empty replica_groups (ref).
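
To make the two points above concrete, here is a minimal Python sketch. It is loosely modeled on the behavior of GetSubgroupSize in XLA's HLO verifier, not a copy of it. The function names, the string-valued `group_mode`, and the rule that empty replica_groups means "one group containing every replica" are assumptions made for illustration only.

```python
from typing import List, Sequence


def subgroup_size(group_mode: str,
                  replica_groups: Sequence[Sequence[int]],
                  num_replicas: int,
                  num_partitions: int) -> int:
    """Hypothetical helper: size of one process group for a collective op."""
    # Size of the first explicit group, or None if replica_groups is empty.
    first = len(replica_groups[0]) if replica_groups else None
    if group_mode == "cross_replica":
        return first if first is not None else num_replicas
    if group_mode == "cross_partition":
        return first if first is not None else num_partitions
    if group_mode == "cross_replica_and_partition":
        replicas = first if first is not None else num_replicas
        return replicas * num_partitions
    if group_mode == "flattened_ids":
        if first is None:
            raise ValueError("flattened_ids requires explicit replica_groups")
        return first
    raise ValueError(f"unknown group mode: {group_mode}")


def infer_all_gather_shape(operand_shape: Sequence[int],
                           all_gather_dim: int,
                           shard_count: int) -> List[int]:
    """dim(result, all_gather_dim) = shard_count * dim(operand, all_gather_dim)."""
    result = list(operand_shape)
    result[all_gather_dim] *= shard_count
    return result


def default_replica_groups(num_replicas: int) -> List[List[int]]:
    """Assumed default: empty replica_groups means a single group of all replicas."""
    return [list(range(num_replicas))]
```

With `num_replicas = 4` and `num_partitions = 2`, an all_gather over dimension 1 of a `2x3` operand in `cross_replica` mode with empty replica_groups would get `shard_count = 4` and a `2x12` result under these assumptions.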

Originated in #503 (comment) and
#503 (comment)

@sdasgup3 sdasgup3 added the Spec label Nov 29, 2022
@burmako burmako changed the title Expose num_replicas and num_partitions in StableHLO. Add num_replicas and num_partitions to ModuleOp Nov 29, 2022
burmako pushed a commit that referenced this issue Nov 30, 2022
fixes #462 

Addresses the following:
1. Adds verification checks for AllGather w.r.t.
#498
2. fixes #491

A few points:
- Type Inference is marked `infeasible` as the return type of the op
depends upon the
[shard_count](https://github.com/tensorflow/tensorflow/blob/20c6943d3cd7e07da162f7778a0af5d3776274b4/tensorflow/compiler/xla/service/hlo_verifier.cc#L452)
which [depends on the result
type](https://github.com/tensorflow/tensorflow/blob/20c6943d3cd7e07da162f7778a0af5d3776274b4/tensorflow/compiler/xla/service/hlo_verifier.cc#L426).
Note that the `shard_count` is a parameter in the HLO spec, whereas in MHLO
it is
[derived](https://github.com/tensorflow/tensorflow/blob/a1acd6a6466f58ed4197b7beaa2f3e0b6fcfc32a/tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc#L766)
using result type before exporting to HLO.
- With (1) and (2), the Verifier is a `yes`. Note that we still do not
have the
[check](https://github.com/tensorflow/tensorflow/blob/20c6943d3cd7e07da162f7778a0af5d3776274b4/tensorflow/compiler/xla/service/hlo_verifier.cc#L436)
for `shard_count == subgroup_size`. The `subgroup_size` depends on
[module
configuration](https://github.com/tensorflow/tensorflow/blob/20c6943d3cd7e07da162f7778a0af5d3776274b4/tensorflow/compiler/xla/service/hlo_verifier.cc#L95)
which manages the settings and values which affect the compiled
executable outside of the HLO code itself and I am not sure if that info
is available in StableHLO IR.
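
A rough Python sketch of the situation described in the points above: `shard_count` can be recovered from the operand and result types, but the `shard_count == subgroup_size` check has to be skipped when the module configuration is not available. Both function names are hypothetical, and returning `None` for "cannot be checked" is an illustrative choice, not how the verifier reports it.

```python
from typing import Optional, Sequence


def derive_shard_count(operand_shape: Sequence[int],
                       result_shape: Sequence[int],
                       all_gather_dim: int) -> int:
    """Recover shard_count from the types, as done before exporting to HLO."""
    operand_dim = operand_shape[all_gather_dim]
    result_dim = result_shape[all_gather_dim]
    if operand_dim == 0 or result_dim % operand_dim != 0:
        raise ValueError("result dim is not a multiple of operand dim")
    return result_dim // operand_dim


def check_shard_count(shard_count: int,
                      subgroup_size: Optional[int]) -> Optional[bool]:
    """Return None when subgroup_size is unknown, so the check is skipped."""
    if subgroup_size is None:
        return None
    return shard_count == subgroup_size
```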

upd:
- With #650, the type
inference should be made feasible. Marked it `revisit` until that is
fixed.
- The verifier should be `revisit` based on
#652
GleasonK pushed a commit to GleasonK/stablehlo that referenced this issue Dec 6, 2022
@julianwa julianwa moved this from Inbox to Needs Scheduling in (Deprecated) IREE Apr 11, 2023
@julianwa julianwa moved this from Needs Scheduling to Not Started in (Deprecated) IREE Apr 11, 2023
@burmako
Contributor

burmako commented Apr 13, 2023

JAX has just added mhlo.num_replicas and mhlo.num_partitions to their lowering: jax-ml/jax#15586. It's great to know that this information is available during lowering 🎉
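
As a small illustration of how a consumer might read these attributes back, here is a Python sketch that extracts them from the textual form of a module. The sample module text and the `= N : i32` spelling are assumptions about how the attributes are printed, and `read_module_counts` is a hypothetical helper, not a JAX or StableHLO API.

```python
import re
from typing import Dict

# Matches e.g. "mhlo.num_replicas = 2 : i32" (assumed printed form).
_ATTR_RE = re.compile(
    r"mhlo\.(num_replicas|num_partitions)\s*=\s*(\d+)\s*:\s*i32")


def read_module_counts(module_text: str) -> Dict[str, int]:
    """Collect num_replicas / num_partitions found in the module text."""
    return {name: int(value) for name, value in _ATTR_RE.findall(module_text)}


# Illustrative module text, written by hand for this example.
SAMPLE = """module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 2 : i32} {
  func.func public @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
    return %arg0 : tensor<4xf32>
  }
}"""

counts = read_module_counts(SAMPLE)
```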

@burmako burmako moved this to Todo in Frontend contract Apr 23, 2023
@allieculp allieculp moved this from Not Started to Backlog in (Deprecated) IREE May 16, 2023
@burmako burmako assigned sdasgup3 and unassigned atondwal Aug 2, 2023
burmako pushed a commit that referenced this issue Aug 8, 2023
We have the following constraints in the spec:

```
(I1) `operand`: tensor.
(I2) `source_target_pairs`: 2-dimensional tensor constant of type `si64`.
(I3) `channel_id`: constant of type `si64`.
(C1) `dim(source_target_pairs, 1) = 2`.
(C2) `is_unique(source_target_pairs[:, 0])`.
(C3) `is_unique(source_target_pairs[:, 1])`.
(C4) `0 <= source_target_pairs < N`, where N is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_partitions` if `cross_partition` is used.
(C5) `type(result) = type(operand)`.
```

These constraints will be comprehensively covered by the following
tests:

```
I1: a) `operand` is not a tensor. (Covered by ODS).
I2: a) `source_target_pairs` is not a 2-dimensional tensor constant of type `si64`.
I3: a) `channel_id` is not a constant of type `si64`. (Covered by ODS).
C1: a) `dim(source_target_pairs, 1) != 2`.
C2: a) `is_unique(source_target_pairs[:, 0]) = false`.
C3: a) `is_unique(source_target_pairs[:, 1]) = false`.
C4: a) `source_target_pairs < 0`.
C4: b) `source_target_pairs >= N`, where `N` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_partitions` if `cross_partition` is used.
C5: a) `type(result) != type(operand)`.
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
I2a: `source_target_pairs` is not a 2-dimensional tensor constant of type `si64`.
C1a: `dim(source_target_pairs, 1) != 2`.
C2a: `is_unique(source_target_pairs[:, 0]) = false`.
C3a: `is_unique(source_target_pairs[:, 1]) = false`.
C4a: `source_target_pairs < 0`.
C4b: `source_target_pairs >= N`, where `N` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_partitions` if `cross_partition` is used.
C5a: `type(result) != type(operand)`.
```

Notes:
* C4b verification is infeasible since `num_replicas` and
`num_partitions` are not known statically at the moment (see #650).
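
The constraint checks above can be sketched in Python as follows. This is only an illustration of the logic, not the StableHLO verifier: `verify_source_target_pairs` is a hypothetical name, the pairs are plain nested lists, and passing `n=None` stands in for "`num_replicas` / `num_partitions` are not known statically", in which case C4b is skipped.

```python
from typing import List, Optional, Sequence


def verify_source_target_pairs(pairs: Sequence[Sequence[int]],
                               n: Optional[int] = None) -> List[str]:
    """Return the labels of the violated constraints (empty list if none)."""
    errors: List[str] = []
    # C1: every row must hold exactly one source and one target.
    if any(len(pair) != 2 for pair in pairs):
        return ["C1"]
    sources = [pair[0] for pair in pairs]
    targets = [pair[1] for pair in pairs]
    # C2 / C3: sources and targets must each be unique.
    if len(set(sources)) != len(sources):
        errors.append("C2")
    if len(set(targets)) != len(targets):
        errors.append("C3")
    # C4a: lower bound can always be checked.
    if any(value < 0 for value in sources + targets):
        errors.append("C4a")
    # C4b: upper bound needs N, which may be unavailable.
    if n is not None and any(value >= n for value in sources + targets):
        errors.append("C4b")
    return errors
```

For example, `[[0, 5]]` passes when `n` is unknown but is flagged as a C4b violation once `n=4` is supplied, which is the gap the note refers to.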

closes #1124