Skip to content

Commit

Permalink
First cl to implement a pass to unstack loop operands.
Browse files Browse the repository at this point in the history
This pass implements unstacking for loop operands. Generally speaking, unstacking is the act of breaking a rank n tensor into n smaller n-1 rank tensors without changing the semantics of the program. There are different patterns that can benefit from unstacking. This pass aims to implement such patterns. The patterns implemented are not exhaustive by any means. There are more patterns to be added.
The pass is not added to the compiler yet.

PiperOrigin-RevId: 638785310
  • Loading branch information
fhoushmand authored and copybara-github committed May 30, 2024
1 parent 8e0cd17 commit 8ffe362
Show file tree
Hide file tree
Showing 5 changed files with 1,239 additions and 1 deletion.
38 changes: 38 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3248,6 +3248,44 @@ xla_cc_test(
],
)

cc_library(
name = "hlo_unstacker",
srcs = ["hlo_unstacker.cc"],
hdrs = ["hlo_unstacker.h"],
deps = [
":hlo_creation_utils",
":hlo_pass",
":pattern_matcher",
":tuple_util",
":while_loop_unroller",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

xla_cc_test(
name = "hlo_unstacker_test",
srcs = ["hlo_unstacker_test.cc"],
tags = ["requires-net:external"],
deps = [
":hlo_unstacker",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "while_loop_unroller",
srcs = ["while_loop_unroller.cc"],
Expand Down
Loading

0 comments on commit 8ffe362

Please sign in to comment.