Skip to content

Commit

Permalink
[XLA] Introduce HloShardingV2 (2/N).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 542481139
  • Loading branch information
cezheng authored and Copybara-Service committed Jun 22, 2023
1 parent 1b097b5 commit a781713
Show file tree
Hide file tree
Showing 19 changed files with 697 additions and 439 deletions.
4 changes: 2 additions & 2 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3969,8 +3969,8 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins,
return Undefined();
}

Array<int64_t> tile_assignment = strategy.output_sharding.tile_assignment();
tile_assignment.Reshape({cluster_env.total_devices_});
auto tile_assignment = strategy.output_sharding.tile_assignment().Reshape(
{cluster_env.total_devices_});
return HloSharding::Tile(std::move(tile_assignment));

} else {
Expand Down
10 changes: 5 additions & 5 deletions xla/hlo/experimental/auto_sharding/auto_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ HloSharding BroadcastSharding(const HloSharding& input_spec,
target_tile_assignment_dimensions.push_back(
input_spec.tile_assignment().dimensions().back());
}
Array<int64_t> new_tile_assignment = input_spec.tile_assignment();
new_tile_assignment.Reshape(target_tile_assignment_dimensions);
auto new_tile_assignment =
input_spec.tile_assignment().Reshape(target_tile_assignment_dimensions);

return input_spec.ReplicateOnLastTileDim()
? HloSharding::PartialTile(new_tile_assignment)
Expand Down Expand Up @@ -1182,11 +1182,11 @@ bool IsValidTileAssignment(const HloSharding& spec) {
}

// Check all tile dims
const Array<int64_t>& tile_assignment = spec.tile_assignment();
const auto& tile_assignment = spec.tile_assignment();
for (int i = 0; i < tile_assignment.num_dimensions(); i++) {
if (tile_assignment.dim(i) != 1) {
std::vector<int64_t> device_ids =
GetValuesAlongOneDim(tile_assignment, i).value();
GetValuesAlongOneDim(tile_assignment.array(), i).value();
auto status_or_delta = CheckArithmeticSequence(device_ids);
if (!status_or_delta.ok()) {
return false;
Expand Down Expand Up @@ -1243,7 +1243,7 @@ absl::StatusOr<std::vector<int64_t>> GetTensorDimToMeshDimNoCrash(
do {
auto transposed_mesh = Transpose(mesh, axes);
if (std::equal(transposed_mesh.begin(), transposed_mesh.end(),
spec.tile_assignment().begin())) {
spec.tile_assignment().array().begin())) {
found = true;
break;
}
Expand Down
1 change: 1 addition & 0 deletions xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ cc_library(
"hlo_sharding_metadata.h",
],
deps = [
":tile_assignment",
"//xla:array",
"//xla:comparison_util",
"//xla:literal",
Expand Down

0 comments on commit a781713

Please sign in to comment.