Add ShardingAttr #619

Open

burmako opened this issue Nov 25, 2022 · 5 comments


burmako commented Nov 25, 2022

Sharding on StableHLO/MHLO ops is represented as a serialized proto string, e.g. mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01". We should spec it and model it with a dedicated data structure.


atondwal commented Mar 25, 2023

I've been working on a prototype in MHLO, which I thought would be nice to have
to try this out before I send a PR to the spec, but it's trickier than
expected. I've already committed a lot of engineering time to the prototype, so
I'd like to get some early feedback on the design before I commit even more
time to it. cc @mikedelorimier, who was asking about this yesterday.

My proposal is that I implement this initially in MHLO, so that op sharding can take
either this sharding attr or a string (as before). Then I'll add it to StableHLO in a form that only takes the attr.

Right now we don't have sections in the spec for attributes. Here's what
I've been working off of:

Attributes

sharding

Semantics

The stablehlo.sharding attribute specifies the sharding strategy for an
operation, expressed as a ShardingAttr, which determines how the operation is
distributed across multiple devices.

Attribute Types

ShardingAttr

  • REPLICATED: The operation is replicated across all devices and
    executed independently on each device. No other
    fields are used in this case, except metadata.
  • MAXIMAL: The operation is executed entirely on a single device,
    the first device in the tile_assignment_devices list.
  • TUPLE: The tuple_shardings
    field contains a list of ShardingAttr instances, one per leaf node
    in the tuple shape. No other fields are used in this case, except
    metadata.
  • OTHER: The operation is sharded using the specified tile shape and
    device assignment. The tile_shape field defines the shape of the
    sharded tile, and tile_assignment_dimensions specifies the shape
    of the tile assignment tensor. The tile_assignment_devices field
    contains a flattened list of device IDs for the tile assignment
    (see the sketch after this list).
    • If replicate_on_last_tile_dim is true, the data is sharded
      according to other dimensions of the tile_assignment, but
      replicated across devices along the last dimension.
    • The last_tile_dims field is a list of ShardingType values
      representing the sharding type of each subgroup in the
      tile_assignment_dimensions. If this is shorter than the rank
      of tile_shape, it applies to the trailing dimensions.
  • MANUAL: The operation is manually sharded, which means that the
    shapes are already partitioned, and the partitioner should not
    change this operation. This type is used when custom sharding is
    desired and cannot be achieved using the other sharding types.
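
To make the OTHER case concrete, here is a minimal sketch (Python/NumPy, purely
illustrative; it assumes a row-major flattening of tile_assignment_devices) of
how tile_assignment_dimensions and tile_assignment_devices together say which
device owns which tile:

```python
import numpy as np

# Illustrative sketch only (assumed semantics, not an implementation): for
# type == OTHER, reshape the flat device list into the tile assignment tensor
# and map each tile coordinate to the device that owns it.
def tile_to_device(tile_assignment_dimensions, tile_assignment_devices):
    devices = np.asarray(tile_assignment_devices).reshape(tile_assignment_dimensions)
    return {coord: int(devices[coord]) for coord in np.ndindex(devices.shape)}

# A 2x2 tiling over devices 0..3: tile (i, j) lives on device 2*i + j.
print(tile_to_device([2, 2], [0, 1, 2, 3]))
# {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}
```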

Properties

| Label | Property | C++ Type | Description | Constraints |
|-------|----------|----------|-------------|-------------|
| (P1) | type | ShardingType | The type of sharding. | |
| (P2) | tile_shape | ArrayRef<int64_t> | The shape of the sharded tile. | (C1) |
| (P3) | tuple_shardings | ArrayRef | A list of sub-shardings for a tuple type. | (C2) |
| (P4) | replicate_on_last_tile_dim | Optional | Indicates whether data is sharded according to other dimensions of tile_assignment, but replicated across devices along the last dimension. | (C1) |
| (P5) | metadata | Optional | The source location of this sharding. | |
| (P6) | tile_assignment_dimensions | ArrayRef<int64_t> | The shape of the tile assignment tensor. | (C1), (C3) |
| (P7) | tile_assignment_devices | ArrayRef<int64_t> | Flattened list of device IDs for tile assignment. | (C1), (C3) |
| (P8) | last_tile_dims | ArrayRef | A list of sharding types representing each subgroup. | |

Constraints

  • (C1) These values are only used when type == OTHER. Otherwise they are equal to {}.
  • (C2) tuple_shardings is only used for type == TUPLE. Otherwise it is equal to {}.
  • (C3) The product of tile_assignment_dimensions must equal the size of tile_assignment_devices.
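
As a sanity check, here's a small sketch of these constraints as runnable
assertions (Python; the field names mirror the property table above and are not
an existing API):

```python
import math

def check_sharding(sharding_type, tile_shape, tuple_shardings,
                   tile_assignment_dimensions, tile_assignment_devices):
    if sharding_type != "OTHER":
        # (C1) OTHER-only fields must be empty for every other type.
        assert tile_shape == [] and tile_assignment_dimensions == [] \
            and tile_assignment_devices == []
    if sharding_type != "TUPLE":
        # (C2) tuple_shardings is only meaningful for TUPLE.
        assert tuple_shardings == []
    if sharding_type == "OTHER":
        # (C3) the device list must fill the tile assignment tensor exactly.
        assert math.prod(tile_assignment_dimensions) == len(tile_assignment_devices)

check_sharding("OTHER", [1, 3], [], [3, 1], [0, 1, 2])  # OK: 3 * 1 == 3
```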

Enums

ShardingType Enum

ShardingType is an enumeration attribute that represents the type of sharding for an operation. The possible values are:

  • REPLICATED: This sharding is replicated across all devices (all other fields are unused).
  • MAXIMAL: This sharding is maximal - one device runs the entire operation.
  • TUPLE: This sharding is a tuple - only the tuple_shardings field is valid.
  • OTHER: None of the above; tile_shape and tile_assignment are both used.
  • MANUAL: This op is manually sharded: the shapes are already partitioned, and the partitioner should not change this op.


burmako commented Mar 25, 2023

Thank you for your work! First, I'll provide some high-level feedback, and in subsequent comments I'll follow up with some low-level comments on the prose.

High-level feedback

Overall, as far as prioritization goes, my recommendation would be to prioritize: 1) speccing ShardingAttr, 2) implementing it in StableHLO, 3) migrating StableHLO users, 4) porting this to MHLO.

The rationale for deprioritizing MHLO work is that our main focus right now is on shipping StableHLO v1.0, and MHLO is not on a critical path for this. MHLO will benefit from the work on StableHLO (via having a comprehensive spec and eventually via code sharing), but that is not a P0. This is consistent with how we've been doing development of many features over the last few quarters, including e.g. shape inference and verification - first in StableHLO, then in MHLO.

The rationale for prioritizing speccing ShardingAttr over implementing it is that prototyping ShardingAttr is taking a lot of time, so the implementation doesn't look like the biggest bang for the buck. The spec is direly needed, but as the JAX experience shows, a readable string attribute has already been a huge improvement, so the added value of having something more structured is lower than the added value of having a spec.

As far as speccing goes, in StableHLO we're typically looking for something more comprehensive than how we documented MHLO, with the goal that the specification alone is sufficient to develop a conformant implementation.

From that perspective, prose like "The sharding strategy determines how the computation for the operation is distributed across multiple devices" is alright, but it needs to be accompanied with a description of what it means for an operation to be distributed across multiple devices.

As a foundation for that, we already have the "Parallel execution" section which sketches a formalism for programs distributed across multiple devices. It looks like a spec for ShardingAttr would benefit from a "Sharded ops" subsection in that section which explains how ops with a sharding attribute are executed.

Speccing this will raise many questions, e.g. what it means for an operation to be executed on a single device - does it mean that the (0, 0) process receives all instances of the inputs and then runs the op on a somehow concatenated value? Figuring these things out has been one of the critical contributions of the StableHLO spec.


atondwal commented Mar 25, 2023

(icymi, I just made some minor changes/improvements to the prose, I think while you were writing the above comment)

@atondwal

> From that perspective, prose like "The sharding strategy determines how the computation for the operation is distributed across multiple devices" is alright, but it needs to be accompanied with a description of what it means for an operation to be distributed across multiple devices.

Agreed. But I think this description will belong in the section on ShardingAttr under Attribute Types rather than the section on sharding under Attributes. I'm also planning to include examples that make it clear how an operation is sharded.

> Speccing this will raise many questions, e.g. what it means for an operation to be executed on a single device - does it mean that the (0, 0) process receives all instances of the inputs and then runs the op on a somehow concatenated value? Figuring these things out has been one of the critical contributions of the StableHLO spec.

These are actually far less ambiguous than they appear at first blush! I have more to write for this initial design to be a spec, but note that (and this sentence should be included above in the final version) the only case in which we have multiple "instances" of a tensor is the REPLICATED case. In all other cases, we have a single tensor that is sharded (say, along the batch dim) across many devices. So for example, add : <3x3xf32>, <3x3xf32> -> <3x3xf32> sharded with maximal, tile_assignment_devices = [4], runs everything on device ID 4. Sharded with (XLA syntax) [3,1]0,1,2, the op is sharded across the first dimension into 3 shards of shape 1x3, each of which runs on one of the first 3 devices by ID.

I'll need to update this draft to spell that out, and explain how the tile_assignment_dimensions in brackets shard a tile.
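
To illustrate the [3,1]0,1,2 example above, here's a small sketch (Python/NumPy,
illustrative only; it assumes dimension d is split into
tile_assignment_dimensions[d] equal pieces) of the slice each device would own:

```python
import numpy as np

operand = np.arange(9, dtype=np.float32).reshape(3, 3)  # stands in for tensor<3x3xf32>
tile_assignment_dimensions = [3, 1]
tile_assignment_devices = [0, 1, 2]

devices = np.asarray(tile_assignment_devices).reshape(tile_assignment_dimensions)
for coord in np.ndindex(devices.shape):
    # Dimension d is cut into tile_assignment_dimensions[d] equal shards.
    shard = operand[tuple(
        slice(i * operand.shape[d] // devices.shape[d],
              (i + 1) * operand.shape[d] // devices.shape[d])
        for d, i in enumerate(coord))]
    print(f"device {devices[coord]}: shard shape {shard.shape}")
# device 0: shard shape (1, 3)
# device 1: shard shape (1, 3)
# device 2: shard shape (1, 3)
```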

Great point about the "parallel execution" section --- I'll need to align the language here, which refers to device IDs, with the language earlier in the spec about (0,0)-style replica/partition devices. I think what XLA does here makes sense for us as well --- treat it as one might describe something like union { device_by_part[replicas][partitions]; devices_by_id[replicas*partitions] }
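
A tiny sketch of that equivalence (illustrative; it assumes a row-major layout
of the (replica, partition) grid, which is how I read the union above):

```python
import numpy as np

num_replicas, num_partitions = 2, 3
# devices_by_id: flat device IDs 0 .. replicas * partitions - 1.
devices_by_id = np.arange(num_replicas * num_partitions)
# device_by_part: the same IDs viewed as a [replicas][partitions] grid.
device_by_part = devices_by_id.reshape(num_replicas, num_partitions)

replica, partition = 1, 2
assert device_by_part[replica, partition] == devices_by_id[replica * num_partitions + partition]
```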


sogartar commented May 5, 2023

It would be more natural to merge tile_assignment_dimensions and tile_assignment_devices into a single value: a tensor whose shape is tile_assignment_dimensions.
Then you could have something like

tile_assignment_devices = [[0, 1], [2, 3]]
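
For example, a quick sketch of the idea (illustrative only, not a concrete
attribute syntax): the nested value carries both pieces of information, so the
two existing fields fall out of it:

```python
import numpy as np

# Using the merged value from the example above; illustrative only.
tile_assignment_devices = np.array([[0, 1], [2, 3]])
# Its shape plays the role of tile_assignment_dimensions ...
assert list(tile_assignment_devices.shape) == [2, 2]
# ... and its row-major flattening is the old flat device list.
assert tile_assignment_devices.flatten().tolist() == [0, 1, 2, 3]
```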
