MLIR emitters: Vectorize column reductions. #12941

Merged: 1 commit, May 22, 2024
3 changes: 2 additions & 1 deletion xla/service/gpu/fusions/BUILD
@@ -870,7 +870,6 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
@@ -880,7 +879,9 @@ cc_library(
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:VectorDialect",
],
)

3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/mlir/BUILD
@@ -103,6 +103,7 @@ cc_library(
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:VectorDialect",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
@@ -214,6 +215,7 @@ cc_library(
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:ToLLVMIRTranslation",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
@@ -334,6 +336,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorTransforms",
],
)
42 changes: 40 additions & 2 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -43,11 +43,13 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
@@ -1297,10 +1299,25 @@ SmallVector<Value> EmitLoopNest(
mlir::function_ref<SmallVector<Value>(ValueRange /*iter_args*/,
ValueRange /*dim_values*/,
ValueRange /*symbol_values*/)>
create_body) {
create_body,
bool vectorize) {
SmallVector<Value, 4> lbs, ubs, steps;
GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, &steps);

SmallVector<Value, 4> vector_inits;
if (vectorize) {
CHECK_EQ(indexing_map.GetSymbolBounds().back().lower, 0);
int vector_size = indexing_map.GetSymbolBounds().back().upper + 1;
vector_inits = iter_args_inits;
for (auto& init : vector_inits) {
if (!mlir::isa<mlir::ShapedType>(init.getType())) {
auto vector_ty = mlir::VectorType::get({vector_size}, init.getType());
init = b.create<mlir::vector::SplatOp>(vector_ty, init);
}
}
iter_args_inits = vector_inits;
}

scf::LoopNest loop_nest = scf::buildLoopNest(
b, b.getLoc(), lbs, ubs, steps, iter_args_inits,
[&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values,
[&](OpBuilder& then_builder, Location then_loc) -> void {
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(then_builder.getInsertionBlock());
auto results = create_body(iter_args, dim_values, symbol_values);
SmallVector<Value, 4> results;
if (vectorize) {
SmallVector<Value, 4> vector_args;
vector_args = iter_args;
// Extract the vector elements.
for (auto& init : vector_args) {
if (mlir::isa<mlir::VectorType>(init.getType())) {
init = b.create<mlir::vector::ExtractOp>(
init, symbol_values.back());
}
}
results = create_body(vector_args, dim_values, symbol_values);
// Insert the results.
for (auto [index, init] : llvm::enumerate(iter_args)) {
if (mlir::isa<mlir::VectorType>(init.getType())) {
results[index] = b.create<mlir::vector::InsertOp>(
results[index], iter_args[index], symbol_values.back());
}
}
} else {
results = create_body(iter_args, dim_values, symbol_values);
}
b.create<scf::YieldOp>(results);
},
[&](OpBuilder& else_b, Location else_loc) {
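To make the splat/extract/insert pattern above concrete, here is a minimal plain-C++ model of what the vectorized loop nest computes. This is an illustration only, not the emitter or the generated MLIR; the input array, sizes, and the add-reduction body are invented for the example.

#include <array>
#include <cstdio>

// Mirrors the diff's vector_size = last symbol upper bound + 1 (here: 2).
constexpr int kVectorSize = 2;

int main() {
  const float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  const float init = 0.0f;

  // vector.SplatOp: a scalar init becomes one accumulator per lane.
  std::array<float, kVectorSize> acc;
  acc.fill(init);

  // Outer loop: the non-vectorized symbols of the indexing map.
  for (int i = 0; i < 8 / kVectorSize; ++i) {
    // Inner loop: the last symbol, i.e. the vectorized lane.
    for (int lane = 0; lane < kVectorSize; ++lane) {
      float scalar = acc[lane];                               // vector.ExtractOp
      float updated = scalar + data[i * kVectorSize + lane];  // scalar body
      acc[lane] = updated;                                    // vector.InsertOp
    }
  }
  // The loop nest hands back a vector of per-lane results instead of a scalar.
  std::printf("per-lane results: %g %g\n", acc[0], acc[1]);
  return 0;
}

For a column reduction each lane corresponds to a different output element, which is why the per-lane results stay separate rather than being combined.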
15 changes: 14 additions & 1 deletion xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h
@@ -99,13 +99,26 @@ mlir::Value CheckConstraints(const IndexingMap& map, mlir::ValueRange dims,

// Emits a loop nest over the entire domain of the indexing_map at a point
// `dim_values`.
// If `vectorize` is set, the loop essentially turns into multiple independent
// loops, and the results of all the loops are returned as a vector. The last
// symbol dimension is used as the vectorized dimension.
// If `vectorize` is set:
// - the body will still be called with scalars and should return scalars.
// - the loop for the last symbol in `indexing_map` will be vectorized
// - the symbol range should be [0, 2] or [0, 4] for vectorization to work.
// [0, 1] is supported and will have no effect. The lower bound must be 0.
// - all scalar results of `EmitLoopNest` will become vectors instead. Scalar
// inits will be initialized with a vector splat. Passing a vector init is
// supported.
// - Tensor arguments and results are unaffected.
llvm::SmallVector<mlir::Value> EmitLoopNest(
mlir::ImplicitLocOpBuilder& b, mlir::ValueRange dim_values,
mlir::ValueRange iter_args_inits, const IndexingMap& indexing_map,
mlir::function_ref<llvm::SmallVector<mlir::Value>(
mlir::ValueRange iter_args, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>
create_body);
create_body,
bool vectorize = false);

// Same as EmitLoopNest, but the body building function can return an error
// which gets returned from EmitLoopNestWithStatus.
3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/mlir/lower_tensors.cc
@@ -393,6 +393,9 @@ struct RewriteNonScalarConstants
mlir::LogicalResult matchAndRewrite(
mlir::arith::ConstantOp op,
mlir::PatternRewriter& rewriter) const override {
if (mlir::isa<mlir::VectorType>(op.getType())) {
return rewriter.notifyMatchFailure(op, "the op is a vector constant");
}
auto shaped_ty = mlir::dyn_cast<mlir::ShapedType>(op.getValue().getType());
// We only need to rewrite non-scalar constants.
if (!shaped_ty || shaped_ty.getNumElements() < 2) {
2 changes: 2 additions & 0 deletions xla/service/gpu/fusions/mlir/lower_to_llvm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Arith/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
@@ -67,6 +68,7 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
patterns);
mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(type_converter,
patterns);
mlir::populateComplexToLLVMConversionPatterns(type_converter, patterns);
9 changes: 6 additions & 3 deletions xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
@@ -56,6 +56,7 @@ limitations under the License.
#include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -361,7 +362,8 @@ MlirFusionEmitterBase::CreateMLIRModule(
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
mlir::math::MathDialect, mlir::scf::SCFDialect,
mlir::mhlo::MhloDialect, mlir::gpu::GPUDialect,
mlir::NVVM::NVVMDialect, xla::gpu::XlaGpuDialect>();
mlir::vector::VectorDialect, mlir::NVVM::NVVMDialect,
xla::gpu::XlaGpuDialect>();
mlir::DialectRegistry registry;
mlir::func::registerInlinerExtension(registry);
mlir::registerBuiltinDialectTranslation(registry);
@@ -456,9 +458,10 @@ SmallVector<Value> MlirFusionEmitterBase::EmitThreadLoopNest(
const IndexingMap& indexing_map,
const std::function<
SmallVector<Value>(ValueRange outputs_tensors, ValueRange dim_values,
ValueRange symbol_values)>& create_body) const {
ValueRange symbol_values)>& create_body,
bool vectorize) const {
return mlir_converter::EmitLoopNest(b, EmitThreadAndBlockIds(b), outputs,
indexing_map, create_body);
indexing_map, create_body, vectorize);
}

absl::Status MlirFusionEmitterBase::EmitMlir(
5 changes: 4 additions & 1 deletion xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h
@@ -96,12 +96,15 @@ class MlirFusionEmitterBase : public KernelFusionInterface {
// the symbol 0 as the outermost loop. The indices of the map's dimensions and
// symbols are passed to the lambda separately. The return values of the
// function are the updated outputs.
// For the meaning of `vectorize`, see the documentation of `EmitLoopNest` in
// elemental_hlo_to_mlir.h.
llvm::SmallVector<mlir::Value> EmitThreadLoopNest(
mlir::ImplicitLocOpBuilder& b, mlir::ValueRange outputs,
const IndexingMap& indexing_map,
const std::function<llvm::SmallVector<mlir::Value>(
mlir::ValueRange outputs, mlir::ValueRange dim_values,
mlir::ValueRange symbol_values)>& create_body) const;
mlir::ValueRange symbol_values)>& create_body,
bool vectorize = false) const;

mlir::Value EmitBlockId(mlir::ImplicitLocOpBuilder& builder, int dim) const;
mlir::Value EmitThreadId(mlir::ImplicitLocOpBuilder& builder, int dim) const;
13 changes: 13 additions & 0 deletions xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir
@@ -174,6 +174,19 @@ module {

// -----

module {
func.func @vector_constant() -> vector<2xindex> {
%c1 = arith.constant dense<[1, 2]> : vector<2xindex>
func.return %c1 : vector<2xindex>
}
}

// vector constants should not be rewritten.
// CHECK: @vector_constant
// CHECK-NEXT: arith.constant

// -----

module {
func.func @complex_tensor_insert(
%arg0: tensor<10xcomplex<f32>>) -> tensor<10xcomplex<f32>> {
86 changes: 56 additions & 30 deletions xla/service/gpu/fusions/reduction_base.cc
@@ -70,37 +70,43 @@ int RowReductionGetRowsPerWarp(int reduced_dimension_size) {

int GetVectorSize(const HloFusionAnalysis& analysis,
const ReductionDimensions& reduction_dimensions,
int num_threads, Vector3 reduction_tiling) {
if (!reduction_dimensions.is_row_reduction) {
int num_threads, Vector3 reduction_tiling, bool for_mlir) {
// If the minor dimension is not divisible by 2, we can't currently vectorize.
int64_t minor_dim = reduction_dimensions.dimensions.back();
if (minor_dim % 2 != 0) {
return 1;
}

constexpr int kRowMinorReduced =
ReductionDimensions::kRowMinorReducedDimension;
if (reduction_dimensions.dimensions[kRowMinorReduced] % 2 != 0 ||
MayPreventVectorization(analysis.fusion())) {
// Only enable vectorization if all threads will still have work.
if (num_threads * 2 > minor_dim) {
return 1;
}

// Enabling vectorization if (number_threads * vector_size) is <=
// minor_reduced_dimension otherwise exist threads not doing any work.
if (num_threads * 2 > reduction_dimensions.dimensions[kRowMinorReduced]) {
if (MayPreventVectorization(analysis.fusion())) {
return 1;
}

const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
&analysis.device_info().gpu_compute_capability());
if (cuda_cc == nullptr) return 1;
if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return 2;
if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
return analysis.input_output_info().smallest_input_dtype_bits <= 32 &&
reduction_dimensions.dimensions[kRowMinorReduced] %
(reduction_tiling[kRowMinorReduced] * num_threads) ==
0
? 2
: 1;
if (reduction_dimensions.is_row_reduction) {
constexpr int kRowMinorReduced =
ReductionDimensions::kRowMinorReducedDimension;

const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(
&analysis.device_info().gpu_compute_capability());
if (cuda_cc == nullptr) return 1;
if (cuda_cc->IsAtLeast(se::CudaComputeCapability::VOLTA)) return 2;
if (cuda_cc->IsAtLeast(se::CudaComputeCapability::PASCAL_)) {
return analysis.input_output_info().smallest_input_dtype_bits <= 32 &&
reduction_dimensions.dimensions[kRowMinorReduced] %
(reduction_tiling[kRowMinorReduced] *
num_threads) ==
0
? 2
: 1;
}
return 1;
}
return 1;

return for_mlir ? 2 : 1;
}
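Condensed, the restructured vector-size decision reads roughly as below. This is a simplified sketch of how the diff reads, with plain parameters standing in for the XLA analysis types, and with the row-reduction branch abbreviated to the Volta check; it is not the real GetVectorSize signature.

#include <cstdint>

// Sketch only: simplified parameters, abbreviated hardware checks.
int GetVectorSizeSketch(int64_t minor_dim, int64_t num_threads,
                        bool is_row_reduction, bool may_prevent_vectorization,
                        bool is_volta_or_newer, bool for_mlir) {
  if (minor_dim % 2 != 0) return 1;           // Minor dimension must be even.
  if (num_threads * 2 > minor_dim) return 1;  // Otherwise some threads idle.
  if (may_prevent_vectorization) return 1;
  if (is_row_reduction) {
    // Row reductions keep the hardware-dependent heuristic (abbreviated here).
    return is_volta_or_newer ? 2 : 1;
  }
  // Column reductions are only vectorized by the new MLIR emitter.
  return for_mlir ? 2 : 1;
}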

ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis,
@@ -301,14 +307,14 @@ ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis,
}

int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x,
reduction_tiling);
reduction_tiling, for_mlir);

absl::InlinedVector<int64_t, 4> num_threads{1, num_threads_y, num_threads_x};
absl::InlinedVector<int64_t, 4> tiled_shape{shape[0], shape[1],
shape[2] / vector_size};
absl::InlinedVector<int64_t, 4> tile_per_thread{
reduction_tiling[0], reduction_tiling[1],
reduction_tiling[2] / vector_size};
std::max<int64_t>(reduction_tiling[2] / vector_size, 1)};
if (for_mlir) {
// The indexing map simplifier does not currently handle this correctly,
// leading to loop bounds that are too large.
// uses the thread ID as the coordinate.
tile_per_thread[2] = 1;
}
if (vector_size != 1) {
if (vector_size != 1 ||
(for_mlir && !reduction_dimensions.is_row_reduction)) {
num_threads.push_back(1); // The vector dimension is a loop.
tiled_shape.push_back(vector_size);
tile_per_thread.push_back(vector_size);
}

// The MLIR emitter treats the last tiled dimension as the number of parallel
// independent reductions per thread (to use vectorized loads). This is only
// needed for column reductions: row reductions can use vectorized loads for
// the same reduction.
// row reduction: [[a, b], [c, d]] -> [a + b, c + d]
// column reduction: [[a, b], [c, d]] -> [a + c, b + d]
// In both cases [a, b] are loaded together, but only in the column reduction
// do they contribute to different result elements.
if (for_mlir && reduction_dimensions.is_row_reduction) {
num_threads.push_back(1);
tiled_shape.push_back(1);
tile_per_thread.push_back(1);
}

Tiling tiling(tiled_shape, tile_per_thread, num_threads,
/*loops_to_unroll=*/{false, false, true, false});
bool reduction_is_race_free = ReductionIsRaceFree(
@@ -404,13 +425,18 @@ std::optional<IndexingMap> ReductionInfo::ComputeThreadIdToOutputIndexing(
physical_shape, ctx));
}

mlir::SmallVector<mlir::AffineExpr> projected_dims{
block_offsets.getResult(kColMajorKept),
block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]};
std::vector<RangeVar> range_vars;
if (thread_ids.size() == 4) {
int vector_size = tiling_.GetThreadTileSize().back();
range_vars.push_back({0, vector_size - 1});
projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx));
}
IndexingMap projected_index(
mlir::AffineMap::get(
6, 0,
{block_offsets.getResult(kColMajorKept),
block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]},
ctx),
dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{});
mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx),
dimension_ranges, range_vars, /*rt_vars=*/{});

projected_index.AddConstraint(
mlir::getAffineDimExpr(
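The comment in ReductionInfo::Create above about vectorized loads is easiest to see with the 2x2 example it quotes. The snippet below is a standalone illustration (plain C++, not XLA code) of why the pair loaded together feeds one accumulator in a row reduction but two independent accumulators in a column reduction.

#include <cstdio>

int main() {
  const float m[2][2] = {{1, 2}, {3, 4}};  // [[a, b], [c, d]]

  // Row reduction: [a + b, c + d]. The pair loaded together belongs to one
  // row, so both elements fold into the same accumulator.
  const float row[2] = {m[0][0] + m[0][1], m[1][0] + m[1][1]};

  // Column reduction: [a + c, b + d]. The pair loaded together splits across
  // two result elements, i.e. two independent reductions per thread.
  float col[2] = {0, 0};
  for (int r = 0; r < 2; ++r) {
    // One conceptual vectorized load of m[r][0..1]:
    col[0] += m[r][0];
    col[1] += m[r][1];
  }
  std::printf("row: [%g, %g]  col: [%g, %g]\n", row[0], row[1], col[0], col[1]);
  return 0;
}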
11 changes: 6 additions & 5 deletions xla/service/gpu/fusions/reduction_base_test.cc
@@ -556,22 +556,23 @@ TEST_F(ReductionTest, MlirColumnReduction) {
EXPECT_THAT(
fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(),
MatchIndexingString(R"(
(d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> (
d3 floordiv 48,
(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> (
d3 floordiv 24,
d0 floordiv 32 + s1 * 32,
(d3 mod 48) * 32 + d0 mod 32
((d3 mod 24) * 32 + d0 mod 32) * 2 + s3
)
domain:
d0 in [0, 1023]
d1 in [0, 0]
d2 in [0, 0]
d3 in [0, 9215]
d3 in [0, 4607]
d4 in [0, 0]
d5 in [0, 0]
s0 in [0, 0]
s1 in [0, 1]
s2 in [0, 0]
(d3 mod 48) * 32 + d0 mod 32 in [0, 1535]
s3 in [0, 1]
(d3 mod 24) * 32 + d0 mod 32 in [0, 767]
d0 floordiv 32 + s1 * 32 in [0, 63]
)"));
}
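As a sanity check on the expected indexing above, the small standalone program below re-evaluates the minor-dimension expression from the test (the input values are arbitrary; only the arithmetic is taken from the expected map). With the extra symbol s3 in [0, 1], each thread covers two adjacent input columns, which is why d3's range halves from [0, 9215] to [0, 4607] while the covered column range stays 0..1535.

#include <cstdio>

int main() {
  // ((d3 mod 24) * 32 + d0 mod 32) * 2 + s3, i.e. the minor input dimension.
  auto column = [](long d0, long d3, long s3) {
    return ((d3 % 24) * 32 + d0 % 32) * 2 + s3;
  };
  // One thread now covers two adjacent columns via s3:
  std::printf("%ld %ld\n", column(0, 0, 0), column(0, 0, 1));  // prints: 0 1
  // The constraint (d3 mod 24) * 32 + d0 mod 32 <= 767 is tight at
  // d0 mod 32 == 31 and d3 mod 24 == 23, giving the last column:
  std::printf("%ld\n", column(31, 23, 1));                     // prints: 1535
  return 0;
}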