Add ShardingAttr #619
Comments
I've been working on a prototype in MHLO, which I thought would be nice to have […] My proposal is that I implement this initially in MHLO so that op […] Right now we don't have sections in the spec for attributes. Here's what […]
Attributes
sharding
Semantics
The […]
Attribute Types
ShardingAttr
Properties
Constraints
Enums
ShardingType Enum
Thank you for your work! First, I'll provide some high-level feedback, and in subsequent comments I'll follow up with some low-level comments on the prose.
High-level feedback
Overall, as far as prioritization goes, my recommendation would be to prioritize: 1) speccing ShardingAttr, 2) implementing it in StableHLO, 3) migrating StableHLO users, 4) porting this to MHLO.
The rationale for deprioritizing MHLO work is that our main focus right now is on shipping StableHLO v1.0, and MHLO is not on a critical path for this. MHLO will benefit from the work on StableHLO (via having a comprehensive spec and eventually via code sharing), but that is not a P0. This is consistent with how we've been doing development of many features over the last few quarters, including e.g. shape inference and verification - first in StableHLO, then in MHLO.
The rationale for prioritizing speccing ShardingAttr is that prototyping ShardingAttr is taking a lot of time, so it doesn't look like the biggest bang for the buck. The spec is direly needed, but as JAX experience shows, a readable string attribute has already been a huge improvement, so the added value of having something more structured is lower than the added value of having a spec.
As far as speccing goes, in StableHLO, we're typically looking for something more comprehensive than how we documented MHLO, with the goal for the specification alone being sufficient to develop a conformant implementation. From that perspective, prose like "The sharding strategy determines how the computation for the operation is distributed across multiple devices" is alright, but it needs to be accompanied with a description of what it means for an operation to be distributed across multiple devices. As a foundation for that, we already have the "Parallel execution" section which sketches a formalism for programs distributed across multiple devices.
It looks like a spec for ShardingAttr would benefit from a "Sharded ops" subsection in that section which explains how ops with a […] Speccing this will raise many questions, e.g. what it means for an operation to be executed on a single device - does it mean that the […]
(icymi, I just made some minor changes/improvements to the prose, I think while you were writing the above comment)
Agreed. But I think this description will belong in the section on ShardingAttr under Attribute Types rather than the section on sharding under Attributes. I'm also planning to include examples that make it clear how an operation is sharded.
These are actually far less ambiguous than they appear at first blush! I have more to write for this initial design to be a spec, but note that (and this sentence should be included above in the final version) the only case in which we have multiple "instances" of a tensor is the REPLICATED case. In all other cases, we have a single tensor that is sharded (say, along the batch dim) across many devices. So for example, I'll need to update this draft to spell that out, and explain how the tile_assignment_dimensions in brackets shard a tile. Great point about the "parallel execution" section --- I'll need to align the language here, which makes reference to device IDs, to be compatible with the language earlier in the spec about […]
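To make the REPLICATED-versus-sharded distinction concrete, here is a minimal Python sketch. It is not taken from the draft spec: the helper names `shard` and `replicate` are hypothetical, tensors are plain nested lists, and it assumes that devices are assigned to tiles in row-major order and that each dimension divides evenly by its tile count.

```python
from itertools import product


def shard(tensor, tile_dims, devices):
    """Split a 2-D nested list into tiles and map device id -> tile.

    Assumptions (not stated in the thread): devices are listed in
    row-major order over the tile grid, and each dimension is evenly
    divisible by the corresponding entry of tile_dims.
    """
    rows, cols = len(tensor), len(tensor[0])
    tile_rows = rows // tile_dims[0]
    tile_cols = cols // tile_dims[1]
    placement = {}
    grid = product(range(tile_dims[0]), range(tile_dims[1]))
    for index, (i, j) in enumerate(grid):
        # Slice out the (i, j) tile of the single logical tensor.
        tile = [
            row[j * tile_cols:(j + 1) * tile_cols]
            for row in tensor[i * tile_rows:(i + 1) * tile_rows]
        ]
        placement[devices[index]] = tile
    return placement


def replicate(tensor, devices):
    """Give every device its own full copy (the only multi-instance case)."""
    return {device: [row[:] for row in tensor] for device in devices}


# A 4x2 tensor whose first (batch) dimension is split across two devices.
x = [[1, 2], [3, 4], [5, 6], [7, 8]]
sharded = shard(x, tile_dims=[2, 1], devices=[0, 1])
replicated = replicate(x, devices=[0, 1])
```

Under these assumptions, `sharded` holds one half of the batch per device, while `replicated` holds the whole tensor on each device.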
It will be more natural to merge
Sharding on StableHLO/MHLO ops is represented as a serialized proto string, e.g.
mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"
We should spec it and model it using a dedicated data structure.
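For illustration, the example string above can be unpacked with a few lines of Python that walk the protobuf wire format by hand. This is only a sketch: the decoder handles just varint and packed-varint fields, and the field and enum names in `FIELD_NAMES` and `TYPE_NAMES` are assumptions based on my reading of XLA's `OpSharding` message, not something stated in this issue.

```python
def read_varint(buf, pos):
    """Decode one base-128 varint at pos; return (value, next_pos)."""
    result = 0
    shift = 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def decode_fields(buf):
    """Return {field_number: [values]} for varint and packed-varint fields."""
    fields = {}
    pos = 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x7
        if wire_type == 0:  # single varint
            value, pos = read_varint(buf, pos)
            fields.setdefault(field_number, []).append(value)
        elif wire_type == 2:  # length-delimited, treated as packed varints
            length, pos = read_varint(buf, pos)
            end = pos + length
            while pos < end:
                value, pos = read_varint(buf, pos)
                fields.setdefault(field_number, []).append(value)
        else:
            raise ValueError(f"unsupported wire type {wire_type}")
    return fields


# Assumed names, recalled from XLA's OpSharding message; verify before relying on them.
FIELD_NAMES = {1: "type", 3: "tile_assignment_dimensions", 4: "tile_assignment_devices"}
TYPE_NAMES = {0: "REPLICATED", 1: "MAXIMAL", 2: "TUPLE", 3: "OTHER"}

# The \XX escapes in the attribute string are hex bytes.
raw = bytes.fromhex("08031A02020122020001")
fields = decode_fields(raw)
named = {FIELD_NAMES.get(number, number): values for number, values in fields.items()}
```

If the assumed names are right, the string encodes a tiled sharding (type 3) with tile dimensions `[2, 1]` over devices `[0, 1]`, which is the kind of information a dedicated data structure would expose directly.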