[XLA:GPU] Improve memory bandwidth utilization of column reduction #11018

Closed

Changes from all commits (47 commits, all authored by lingzhi98):
cd60cb7  optimize column reduction (Mar 28, 2024)
b8f3770  missing file (Mar 28, 2024)
d67859b  typo (Mar 28, 2024)
e364b0c  disable tiling adjustment for mlir reduction codegen and gpu performa… (Mar 28, 2024)
b33507c  format (Mar 28, 2024)
37474bf  fix reduction mlir test build error (Mar 28, 2024)
ad8fe37  refine (Mar 29, 2024)
90cbfa6  comments (Mar 30, 2024)
1a725c6  fix filecheck error (Mar 30, 2024)
a510ce7  reduce the number of syncs (Mar 30, 2024)
08169b7  remove unused variable (Mar 30, 2024)
1453d73  add tests (Apr 12, 2024)
0caba39  modify tests (Apr 12, 2024)
7b465eb  estimate compute time of column reduction fusion correctly (Apr 16, 2024)
4ac609d  modify comment (Apr 16, 2024)
b75bbba  estimate compute time correctly for kernels which use shared cache (Apr 18, 2024)
8c0e306  make block counts larger than core counts (Apr 18, 2024)
2605755  update reduction mlir emitter (Apr 23, 2024)
87bb413  modify comment (Apr 23, 2024)
6976ee6  correct the condition of using shared cache (Apr 23, 2024)
8c63413  update compute time calculation in indexing performance model (Apr 23, 2024)
45904ee  fix merge conflict (Apr 23, 2024)
4caf6d4  remove unused variable (Apr 23, 2024)
5e68018  conflict (Apr 24, 2024)
24f71c8  fix reduction base test (Apr 24, 2024)
c7e8ca2  format (Apr 24, 2024)
f4d7713  use alloca (Apr 26, 2024)
4fe58f2  rotate iter args (May 1, 2024)
6d08bc5  conflict (May 1, 2024)
a9a48d6  hlo change (May 1, 2024)
444f36c  fix (May 1, 2024)
9ac2a26  fix build error (May 1, 2024)
7f4200c  restore comments (May 1, 2024)
496be7f  remove unused variable (May 1, 2024)
954cbf1  format (May 1, 2024)
a166784  remove (May 1, 2024)
77ca02e  keep usage consistent (May 1, 2024)
5e71fba  conflict (May 6, 2024)
5a481e6  fix build error (May 6, 2024)
5463bbe  remove redundant comments (May 6, 2024)
bdba6a6  remove legacy emitter related change (May 6, 2024)
b1c53f2  Merge remote-tracking branch 'origin' (May 6, 2024)
a35ef76  move cost model change to another pull request (May 7, 2024)
c0cf49c  check thread indexing of vectorize column reduction for mlir reductio… (May 7, 2024)
b248c32  fix hlo filecheck error (May 7, 2024)
2554006  remove unnecessary constraint check (May 8, 2024)
d3d0679  add comments (May 8, 2024)
138 changes: 124 additions & 14 deletions xla/service/gpu/fusions/reduction_base.cc
@@ -66,23 +66,57 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size) {
return WarpSize() / reduced_dimension_size;
}

int64_t ComputeColReductionActiveCore(Vector3 reduction_dimensions,
int64_t tile_y, int64_t num_threads_y,
int64_t num_threads_x, int vector_size) {
constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension;
constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
// The number of blocks is strongly correlated with the number of active
// cores. Making the block count larger than the SM count achieves better
// memory bandwidth. We don't model how many threads are scheduled on each
// SM, since that matters less than the active-core ratio for memory-bound
// kernels, and we want to relax the vectorization restrictions of column
// reduction.
int64_t blocks_x = CeilOfRatio(reduction_dimensions[kColMinorKept],
num_threads_x * vector_size);
int64_t block_tile_y = num_threads_y * tile_y;
int64_t blocks_y =
CeilOfRatio(reduction_dimensions[kColReduced], block_tile_y);
int64_t blocks = reduction_dimensions[kColMajorKept] * blocks_x * blocks_y;
return blocks;
}
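
For intuition, here is a minimal standalone sketch of the estimate this helper computes; it is not part of the PR, and the shape, thread counts, and SM count are hypothetical numbers chosen to make the arithmetic easy to follow.

#include <cstdint>
#include <iostream>

std::int64_t CeilOfRatio(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Column reduction with [major kept, reduced, minor kept] = [128, 8192, 64].
  const std::int64_t kept_major = 128, reduced = 8192, kept_minor = 64;
  const std::int64_t num_threads_x = 32, num_threads_y = 32;
  const std::int64_t tile_y = 128, vector_size = 2;

  const std::int64_t blocks_x =
      CeilOfRatio(kept_minor, num_threads_x * vector_size);      // 1
  const std::int64_t blocks_y =
      CeilOfRatio(reduced, num_threads_y * tile_y);              // 2
  const std::int64_t blocks = kept_major * blocks_x * blocks_y;  // 256
  // 256 blocks keep every SM busy on, e.g., a 132-SM H100, so the caller's
  // early-return path (blocks >= core_count) would be taken.
  std::cout << "blocks = " << blocks << "\n";
  return 0;
}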

 int GetVectorSize(const HloFusionAnalysis& analysis,
                   const ReductionDimensions& reduction_dimensions,
-                  int num_threads, Vector3 reduction_tiling) {
+                  int64_t num_threads_y, int64_t num_threads_x,
+                  Vector3 reduction_tiling, bool for_mlir) {
if (MayPreventVectorization(analysis.fusion())) {
return 1;
}

constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension;
if (!reduction_dimensions.is_row_reduction) {
if (for_mlir) {
int vector_size = 2;
auto num_kept_minor = reduction_dimensions.dimensions[kColMinorKept];
// Check if the last dimension is divisible by (vector_size *
// num_threads_x).
if (num_kept_minor % (vector_size * num_threads_x) == 0) {
return vector_size;
}
}
return 1;
}

constexpr int kRowMinorReduced =
ReductionDimensions::kRowMinorReducedDimension;
-  if (reduction_dimensions.dimensions[kRowMinorReduced] % 2 != 0 ||
-      MayPreventVectorization(analysis.fusion())) {
+  if (reduction_dimensions.dimensions[kRowMinorReduced] % 2 != 0) {
return 1;
}

// Enable vectorization only if (num_threads * vector_size) is <= the minor
// reduced dimension; otherwise some threads would be left with no work.
-  if (num_threads * 2 > reduction_dimensions.dimensions[kRowMinorReduced]) {
+  if (num_threads_x * 2 > reduction_dimensions.dimensions[kRowMinorReduced]) {
return 1;
}

@@ -93,14 +127,62 @@ int GetVectorSize(const HloFusionAnalysis& analysis,
if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
return analysis.input_output_info().smallest_input_dtype_bits <= 32 &&
reduction_dimensions.dimensions[kRowMinorReduced] %
-             (reduction_tiling[kRowMinorReduced] * num_threads) ==
+             (reduction_tiling[kRowMinorReduced] *
+              num_threads_x) ==
0
? 2
: 1;
}
return 1;
}
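
The row-reduction branch above reduces to a small predicate. A condensed sketch, with hypothetical scalar parameters standing in for the analysis and tiling structs:

#include <cstdint>

int RowReductionVectorSize(std::int64_t minor_reduced,
                           std::int64_t num_threads_x, std::int64_t tile_x,
                           int smallest_input_dtype_bits,
                           bool is_at_least_pascal) {
  // An odd reduced dimension can never be split into pairs.
  if (minor_reduced % 2 != 0) return 1;
  // With vector_size 2, more than minor_reduced / 2 threads would leave
  // some threads with no work.
  if (num_threads_x * 2 > minor_reduced) return 1;
  if (!is_at_least_pascal) return 1;
  // Vectorize narrow dtypes when the reduced dimension tiles evenly.
  return (smallest_input_dtype_bits <= 32 &&
          minor_reduced % (tile_x * num_threads_x) == 0)
             ? 2
             : 1;
}

int main() {
  // f32 row reduction over 4096 elements, 256 threads, tile of 8 -> vec2.
  return RowReductionVectorSize(4096, 256, 8, 32, true) == 2 ? 0 : 1;
}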

std::tuple<Vector3, int, bool> AdjustColReductionTilingConfig(
const HloFusionAnalysis& analysis, Vector3 reduction_dimensions,
Vector3 reduction_tiling, int64_t num_threads_y, int64_t num_threads_x,
int vector_size) {
constexpr int kColReduced = ReductionDimensions::kColReducedDimension;
auto core_count = analysis.device_info().core_count();
constexpr int minimum_tile_size = 8;

auto actual_tile_size =
CeilOfRatio(reduction_dimensions[kColReduced], num_threads_y);
reduction_tiling[kColReduced] = actual_tile_size;
// Early return if all of the SM cores are active.
if (ComputeColReductionActiveCore(
reduction_dimensions, reduction_tiling[kColReduced], num_threads_y,
num_threads_x, vector_size) >= core_count) {
return {reduction_tiling, vector_size, false};
}

auto roots = analysis.fusion().GetRoots();
for (auto [root, hero] : llvm::zip(roots, analysis.fusion_heroes())) {
// Only adjust tile_y if the hero is a reduction and the output element
// type is F32. F32 atomics are fast, so we can ignore the extra atomic
// overhead incurred by adjusting tile_y to increase the parallelism of the
// kernel.
if (hero->opcode() == HloOpcode::kReduce) {
if (hero != (&root.instruction()) ||
hero->shape().element_type() != F32) {
// If we cannot adjust tile_y but the SM active ratio is low, reset the
// vector size to 1.
return {reduction_tiling, 1, false};
}
}
}

auto current_tile_size = actual_tile_size;
while (current_tile_size >= minimum_tile_size * 2) {
if (ComputeColReductionActiveCore(reduction_dimensions, current_tile_size,
num_threads_y, num_threads_x,
vector_size) > core_count)
break;
current_tile_size = current_tile_size / 2;
}
bool tile_size_decreased = current_tile_size != actual_tile_size;
reduction_tiling[kColReduced] = current_tile_size;
return {reduction_tiling, tile_size_decreased ? vector_size : 1,
tile_size_decreased};
}
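
The halving loop trades tile depth for block-level parallelism. A self-contained sketch of how tile_y shrinks on a low-occupancy shape; the shape and the 108-SM (A100-like) core count are assumptions for illustration:

#include <cstdint>
#include <iostream>

std::int64_t CeilOfRatio(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  // [major kept, reduced, minor kept] = [8, 16384, 32], 32x32 threads,
  // vector_size 1, core_count 108.
  const std::int64_t kept_major = 8, reduced = 16384, kept_minor = 32;
  const std::int64_t num_threads_x = 32, num_threads_y = 32;
  const std::int64_t vector_size = 1, core_count = 108;
  constexpr std::int64_t kMinimumTileSize = 8;

  const std::int64_t blocks_x =
      CeilOfRatio(kept_minor, num_threads_x * vector_size);  // 1
  auto blocks = [&](std::int64_t tile_y) {
    return kept_major * blocks_x * CeilOfRatio(reduced, num_threads_y * tile_y);
  };

  std::int64_t tile_y = CeilOfRatio(reduced, num_threads_y);  // 512 -> 8 blocks
  // Halve tile_y until the block count exceeds the core count (floor of 8).
  while (tile_y >= kMinimumTileSize * 2 && blocks(tile_y) <= core_count) {
    tile_y /= 2;
  }
  // 512 -> 256 -> 128 -> 64 -> 32; blocks(32) = 128 > 108, so the loop stops.
  // tile_y was decreased, so the reduction is no longer race-free and the
  // emitter must combine partial results with atomics.
  std::cout << "tile_y = " << tile_y << ", blocks = " << blocks(tile_y) << "\n";
  return 0;
}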

ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis,
bool for_mlir) {
const int num_fusion_outputs = analysis.fusion_root_count();
Expand Down Expand Up @@ -230,6 +312,15 @@ int ReductionInfo::GetRowsPerWarp() const {
tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]);
}

int ReductionInfo::ElemsWritePerThread() const {
const auto& shape = tiling_.GetShape();
if (!IsRowReduction() &&
ReductionDimensions::kVectorizedDimension < shape.size()) {
return shape[ReductionDimensions::kVectorizedDimension];
}
return 1;
}
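
A minimal sketch of what the accessor reads. The tile shape of a vectorized column reduction is assumed here to be [major kept, reduced, minor kept, vector] with the vector dimension at index 3; treat that layout as an illustration, not a guarantee:

#include <cstdint>
#include <vector>

// Assumed position of the vector dimension in the tile shape.
constexpr int kVectorizedDimension = 3;

int ElemsWritePerThread(const std::vector<std::int64_t>& tile_shape,
                        bool is_row_reduction) {
  if (!is_row_reduction &&
      kVectorizedDimension < static_cast<int>(tile_shape.size())) {
    return static_cast<int>(tile_shape[kVectorizedDimension]);
  }
  return 1;
}

int main() {
  // A tile shape like {256, 64, 32, 2} means each thread writes two output
  // elements, which is what introduces the extra symbol in the output
  // indexing map below.
  return ElemsWritePerThread({256, 64, 32, 2}, /*is_row_reduction=*/false) == 2
             ? 0
             : 1;
}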

LaunchDimensions ReductionInfo::launch_dimensions() const {
size_t blocks_y = groups_.grouped_roots.size();
return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(),
Expand Down Expand Up @@ -298,15 +389,24 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis,
}
}

-  int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
-                                  reduction_tiling);
+  int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_y,
+                                  num_threads_x, reduction_tiling, for_mlir);
bool tile_size_decreased = false;
if (!reduction_dimensions.is_row_reduction) {
// Adjust tile_y and vector size for column reduction.
std::tie(reduction_tiling, vector_size, tile_size_decreased) =
AdjustColReductionTilingConfig(analysis, shape, reduction_tiling,
num_threads_y, num_threads_x,
vector_size);
}

absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
shape[2] / vector_size};
absl::InlinedVector<int64_t, 4> tile_per_thread{
reduction_tiling[0], reduction_tiling[1],
-      reduction_tiling[2] / vector_size};
+      reduction_dimensions.is_row_reduction ? reduction_tiling[2] / vector_size
+                                            : reduction_tiling[2]};
if (rows_per_warp > 1) {
// If we produce more than one element per thread, that means the reduced
// dimension is small and it can't be tiled - we already have more threads
@@ -325,6 +425,8 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis,
/*loops_to_unroll=*/{false, false, true, false});
bool reduction_is_race_free = ReductionIsRaceFree(
hero_reduction->GetModule()->config(), reduction_dimensions);
// If tile_y was decreased, multiple blocks contribute to the same output
// element, so the reduction is no longer race-free.
reduction_is_race_free = reduction_is_race_free && !tile_size_decreased;
return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction,
reduction_is_race_free,
GroupDisjointReductions(analysis, for_mlir),
Expand Down Expand Up @@ -388,13 +490,21 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
physical_shape, ctx));
}

+  llvm::SmallVector<mlir::AffineExpr> exprs(
+      {block_offsets.getResult(kColMajorKept),
+       block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]});
+  std::vector<RangeVar> symbol_ranges;
+  int symbols_count = 0;
+  int elems_write_per_thread = ElemsWritePerThread();
+  if (elems_write_per_thread > 1) {
+    exprs.push_back(getAffineSymbolExpr(0, ctx));
+    symbol_ranges.push_back({0, elems_write_per_thread});
+    symbols_count = 1;
+  }
   IndexingMap projected_index(
-      mlir::AffineMap::get(
-          6, 0,
-          {block_offsets.getResult(kColMajorKept),
-           block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]},
-          ctx),
-      dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
+      mlir::AffineMap::get(6, symbols_count, exprs, ctx), dimension_ranges,
+      /*range_vars=*/symbol_ranges,
+      /*rt_vars=*/{});

projected_index.AddConstraint(
mlir::getAffineDimExpr(
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/reduction_base.h
@@ -55,6 +55,7 @@ class ReductionInfo {
bool IsRowReduction() const { return is_row_reduction_; }
bool IsRaceFree() const { return is_race_free_; }
int GetRowsPerWarp() const;
int ElemsWritePerThread() const;

std::optional<IndexingMap> ComputeThreadIdToOutputIndexing(
int64_t root_index, mlir::MLIRContext* ctx) const;
59 changes: 58 additions & 1 deletion xla/service/gpu/fusions/reduction_base_test.cc
@@ -236,7 +236,7 @@ TEST_F(ReductionTest, ThreadIndexingColumnReduction) {
domain:
d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0]
d3 in [0, 99] d4 in [0, 0] d5 in [0, 0]
-  s0 in [0, 0] s1 in [0, 127] s2 in [0, 0]
+  s0 in [0, 0] s1 in [0, 1] s2 in [0, 0]
d0 floordiv 32 + s1 * 32 in [0, 63]
d0 mod 32 in [0, 31]
)"));
@@ -254,6 +254,63 @@
)"));
}

TEST_F(ReductionTest, ThreadIndexingVectorizeColumnReduction) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule module

add {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}

fusion {
%input = f32[256,64,64] parameter(0)
%c0 = f32[] constant(0)
ROOT reduce = f32[256,64] reduce(%input, %c0), dimensions={1}, to_apply=add
}

ENTRY entry {
%input = f32[256,64,64] parameter(0)
ROOT %fusion = f32[256,64] fusion(%input), kind=kInput, calls=fusion
})")
.value();

auto* root = module->entry_computation()->root_instruction();
auto analysis = AnalyzeFusion(*root, device_info_);
FakeMlirReductionFusion fusion(analysis);
mlir::MLIRContext mlir_context;

EXPECT_THAT(
fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
d3,
d0 floordiv 32 + s1 * 32,
(d0 mod 32) * 2 + s3
)
domain:
d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0]
d3 in [0, 255] d4 in [0, 0] d5 in [0, 0]
s0 in [0, 0] s1 in [0, 1] s2 in [0, 0] s3 in [0, 1]
d0 floordiv 32 + s1 * 32 in [0, 63]
d0 mod 32 in [0, 31]
)"));
EXPECT_THAT(
fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0] -> (
d3,
(d0 floordiv 32) * 2 + s0
)
domain:
d0 in [0, 1023] d1 in [0, 0] d2 in [0, 0]
d3 in [0, 255] d4 in [0, 0] d5 in [0, 0]
s0 in [0, 1]
d0 mod 32 in [0, 0]
)"));
}
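
As a quick cross-check of the vectorized input map above, the minor-most expression (d0 mod 32) * 2 + s3 assigns each of the 64 minor elements to exactly one (lane, vector element) pair. A standalone sketch of that coverage check, not part of the test:

#include <array>
#include <iostream>

int main() {
  std::array<int, 64> hits{};  // one slot per element of the minor dimension
  for (int lane = 0; lane < 32; ++lane) {  // d0 mod 32
    for (int s3 = 0; s3 < 2; ++s3) {       // vector element, s3 in [0, 1]
      ++hits[lane * 2 + s3];
    }
  }
  for (int h : hits) {
    if (h != 1) {
      std::cout << "gap or overlap\n";
      return 1;
    }
  }
  std::cout << "minor dimension covered exactly once\n";
  return 0;
}
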

TEST_F(ReductionTest, ThreadIndexingOutputLayout) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule module