Skip to content

Commit

Permalink
[XLA] Add option to unconditionally simplify reduce(transpose(x)) and…
Browse files Browse the repository at this point in the history
… reduce(reshape(x)).

Previously we were careful about this.  We would do it only if the result would
be that the transpose in question went away entirely.

The problem with this approach is that it blocks other optimizations, like
transpose-mover, from operating on the transpose.  Those optimizations can't
move transpose "through" reduce, so they need algsimp to do this first.

PiperOrigin-RevId: 520451644
  • Loading branch information
Justin Lebar authored and Copybara-Service committed Mar 29, 2023
1 parent 1797b37 commit 4195a90
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 13 deletions.
63 changes: 50 additions & 13 deletions xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5757,10 +5757,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
// field if the output of the reduce is a vector or scalar. Higher ranked
// result may require a transpose of the output.
if (arg->opcode() == HloOpcode::kTranspose &&
(reduce->shape().rank() < 2 || arg->user_count() == 1 ||
absl::c_all_of(arg->users(), [](HloInstruction* use) {
return use->opcode() == HloOpcode::kReduce;
}))) {
(options_.unconditionally_simplify_reduce_of_transpose_or_reshape() ||
(reduce->shape().rank() < 2 || arg->user_count() == 1 ||
absl::c_all_of(arg->users(), [](HloInstruction* use) {
return use->opcode() == HloOpcode::kReduce;
})))) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64_t> new_reduce_dimensions;
new_reduce_dimensions.reserve(dimensions.size());
Expand Down Expand Up @@ -5830,24 +5831,35 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
new_dimensions, function));
}

// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
// Handle two cases of reduce(reshape(x)).
//
// 1. The reshape collapses/expands only dimensions that are being reduced.
// In this case we can just reduce those dimensions and skip the reshape.
// 2. The reshape collapses/expands only dimensions that are *not* being
// reduced. In this case we can do the reshape after the reduce. This is
// beneficial because the reduce will now operate on less data.
if (options_.enable_reduce_of_reshape() &&
arg->opcode() == HloOpcode::kReshape) {
std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
arg->shape());
std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);

// True for those dimensions of the reduce input that are not reduced, false
// for the dims that are reduced.
absl::InlinedVector<bool, 8> arg_dim_in_output(arg->shape().rank(), true);
for (auto dim : dimensions) {
arg_dim_in_output[dim] = false;
}
for (auto dim_pair : unmodified_dims) {
arg_dim_unmodified[dim_pair.second] = true;

// True for those dimensions of the reduce input that are unmodified by the
// reshape.
absl::InlinedVector<bool, 8> arg_dim_unmodified(arg->shape().rank(), false);
for (auto [input_idx, output_idx] : unmodified_dims) {
arg_dim_unmodified[output_idx] = true;
}
// The goal is to verify that all dimensions that are not removed in the
// reduce are unmodified by the reshape. For example:

// Case 1: Check whether all dimensions that are not removed in the reduce
// are unmodified by the reshape. For example:
// reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
bool can_move_reshape_into_reduce = true;
for (int64_t i = 0; i < arg_dim_in_output.size(); ++i) {
Expand All @@ -5874,7 +5886,32 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
reduce_result_shape, arg->mutable_operand(0), init_value,
new_reduce_dimensions, function));
}

// Case 2: Check whether the reshape only modifies non-reduction dimensions.
// Equivalently, the reduction dimensions are all preserved by the reshape.
if ((arg->user_count() == 1 ||
options_.unconditionally_simplify_reduce_of_transpose_or_reshape()) &&
absl::c_all_of(dimensions,
[&](int64_t dim) { return arg_dim_unmodified[dim]; })) {
absl::InlinedVector<int64_t, 8> new_reduce_dims;
for (auto dim : dimensions) {
auto matching_dim_it = absl::c_find_if(
unmodified_dims,
[&](const auto& dim_pair) { return dim_pair.second == dim; });
CHECK(matching_dim_it != unmodified_dims.end());
new_reduce_dims.push_back(matching_dim_it->first);
}

TF_ASSIGN_OR_RETURN(
HloInstruction * new_reduce,
MakeReduceHlo(arg->mutable_operand(0), init_value, new_reduce_dims,
reduce->to_apply(), &reduce->metadata()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_reshape,
MakeReshapeHlo(reduce->shape(), new_reduce));
return ReplaceInstruction(reduce, new_reshape);
}
}

// Convert Reduce(concat({a,b,...})) to
// map(reduce(a),map(reduce(b),...,))
//
Expand Down
12 changes: 12 additions & 0 deletions xla/service/algebraic_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,17 @@ class AlgebraicSimplifierOptions {

bool enable_sink_broadcast() const { return enable_sink_broadcast_; }

// Returns whether reduce(transpose(x)) and reduce(reshape(x)) are always
// simplified, even if the transpose/reshape has multiple users.
//
// Enabling this can be beneficial on platforms where the extra
// transpose/reshape isn't as expensive as the optimization benefits brought
// about by simplifying the graph. The backing field is initialized to false.
bool unconditionally_simplify_reduce_of_transpose_or_reshape() const {
return unconditionally_simplify_reduce_of_transpose_or_reshape_;
}
// Sets whether reduce(transpose(x)) and reduce(reshape(x)) are simplified even
// when the transpose/reshape has multiple users.
void set_unconditionally_simplify_reduce_of_transpose_or_reshape(bool val) {
unconditionally_simplify_reduce_of_transpose_or_reshape_ = val;
}

// If true, min(x, NaN) = NaN. If false, min(x, NaN) = x.
//
// TODO(b/209827141): Remove this and make minmax_propagate_nan unconditionally
Expand Down Expand Up @@ -220,6 +231,7 @@ class AlgebraicSimplifierOptions {
bool enable_reduce_of_reshape_{true};
bool enable_negative_padding_replacement_{true};
bool enable_sink_broadcast_{true};
bool unconditionally_simplify_reduce_of_transpose_or_reshape_{false};
int64_t very_small_gather_size_{4};
bool minmax_propagate_nan_{true};
Metadata metadata_;
Expand Down
81 changes: 81 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,87 @@ TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
EXPECT_EQ(root->dimensions(), std::vector<int64_t>({0, 2, 3}));
}

// reduce(reshape(x)) where the reshape merges two dimensions that are *not*
// reduced: f32[3,5,7] -> f32[15,7], reduced over dimension 1. With the
// unconditional option enabled, the expected result is a reduce over dimension
// 2 of the original operand (shape f32[3,5]) followed by a reshape to f32[15].
TEST_F(AlgebraicSimplifierTest, ReduceOfMergeNoncontractingDims) {
  const char* const hlo_text = R"(
HloModule m
add_f32 {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
ENTRY test {
p = f32[3,5,7] parameter(0)
reshape = f32[15,7] reshape(p)
ROOT reduce = f32[15] reduce(reshape, f32[] constant(0)), dimensions={1}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifierOptions simplifier_options = default_options_;
  simplifier_options
      .set_unconditionally_simplify_reduce_of_transpose_or_reshape(true);
  AlgebraicSimplifier simplifier(simplifier_options);
  // Run() must return true for this module.
  ASSERT_TRUE(simplifier.Run(module.get()).value());

  // The rewritten reduce must operate on dimension 2 only.
  const auto reduces_only_dim_2 = [](const HloInstruction* instr) {
    return instr->dimensions() == std::vector<int64_t>({2});
  };
  const auto expected_reduce =
      m::Reduce().WithShape(F32, {3, 5}).WithPredicate(reduces_only_dim_2);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(expected_reduce).WithShape(F32, {15})));
}

// reduce(reshape(x)) where the reshape splits a dimension that is *not*
// reduced: f32[3,35] -> f32[3,5,7], reduced over dimension 0. With the
// unconditional option enabled, the expected result is a reduce over dimension
// 0 of the original operand (shape f32[35]) followed by a reshape to f32[5,7].
TEST_F(AlgebraicSimplifierTest, ReduceOfSplitNoncontractingDims) {
  const char* const hlo_text = R"(
HloModule m
add_f32 {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
ENTRY test {
p = f32[3,35] parameter(0)
reshape = f32[3,5,7] reshape(p)
ROOT reduce = f32[5,7] reduce(reshape, f32[] constant(0)), dimensions={0}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifierOptions simplifier_options = default_options_;
  simplifier_options
      .set_unconditionally_simplify_reduce_of_transpose_or_reshape(true);
  AlgebraicSimplifier simplifier(simplifier_options);
  // Run() must return true for this module.
  ASSERT_TRUE(simplifier.Run(module.get()).value());

  // The rewritten reduce must operate on dimension 0 only.
  const auto reduces_only_dim_0 = [](const HloInstruction* instr) {
    return instr->dimensions() == std::vector<int64_t>({0});
  };
  const auto expected_reduce =
      m::Reduce().WithShape(F32, {35}).WithPredicate(reduces_only_dim_0);
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reshape(expected_reduce).WithShape(F32, {5, 7})));
}

// reduce(reshape(x)) where the reshape f32[32] -> f32[8,4] produces both the
// reduced dimension (1) and the kept dimension (0) from the same input
// dimension. Even with the unconditional option enabled, the simplifier is
// expected to return false from Run() for this module.
TEST_F(AlgebraicSimplifierTest,
       ReduceOfReshapeOfContractingAndNoncontractingDims) {
  const char* const hlo_text = R"(
HloModule m
add_f32 {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT r = f32[] add(p0, p1)
}
ENTRY test {
ROOT reduce = f32[8] reduce(
f32[8,4] reshape(f32[32] parameter(0)), f32[] constant(0)),
dimensions={1}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  AlgebraicSimplifierOptions simplifier_options = default_options_;
  simplifier_options
      .set_unconditionally_simplify_reduce_of_transpose_or_reshape(true);
  AlgebraicSimplifier simplifier(simplifier_options);
  ASSERT_FALSE(simplifier.Run(module.get()).value());
}

// Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
auto m = CreateNewVerifiedModule();
Expand Down
8 changes: 8 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,14 @@ Status GpuCompiler::OptimizeHloModule(
layout_insensitive_algsimp_opts.set_minmax_propagate_nan(
!debug_options.xla_gpu_enable_fast_min_max());

// Always simplify reduce(transpose(x)) and reduce(reshape(x)), even when
// the transpose/reshape has multiple users. This helps int8 models, which
// tend to have lots of transpose+reshape's (converting between NCHW and
// NCHW_VECT_C). Without this, those reshape+transposes can get materialized
// out, which is really bad for perf.
layout_insensitive_algsimp_opts
.set_unconditionally_simplify_reduce_of_transpose_or_reshape(true);

if (gpu_target_config.platform_name == "ROCM") {
layout_insensitive_algsimp_opts.set_enable_conv_operand_swap(false);
}
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(

AlgebraicSimplifierOptions algsimp_options;
algsimp_options.set_enable_conv_operand_swap(false);
algsimp_options.set_unconditionally_simplify_reduce_of_transpose_or_reshape(
true);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

// CudnnSimplifyPadding gets rid of some padding introduced by
Expand Down

0 comments on commit 4195a90

Please sign in to comment.