Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ add_mlir_doc(X86Vector X86Vector Dialects/ -gen-dialect-doc -dialect=x86vector)

add_mlir_interface(X86VectorInterfaces)
add_dependencies(MLIRX86VectorIncGen MLIRX86VectorInterfacesIncGen)

add_subdirectory(TransformOps)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# TableGen rules for the X86Vector transform ops: X86VectorTransformOps.td is
# the input, from which the op declarations (.h.inc) and op definitions
# (.cpp.inc) are generated.
set(LLVM_TARGET_DEFINITIONS X86VectorTransformOps.td)
mlir_tablegen(X86VectorTransformOps.h.inc -gen-op-decls)
mlir_tablegen(X86VectorTransformOps.cpp.inc -gen-op-defs)
# Target that other build rules can depend on to have the generated files
# available.
add_mlir_dialect_tablegen_target(MLIRX86VectorTransformOpsIncGen)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- X86VectorTransformOps.h - X86Vector transform ops --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"

//===----------------------------------------------------------------------===//
// X86Vector Transform Operations
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h.inc"

namespace mlir {
class DialectRegistry;

namespace x86vector {
/// Adds the X86Vector transform dialect extension to `registry`.
void registerTransformDialectExtension(DialectRegistry &registry);

} // namespace x86vector
} // namespace mlir

#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMOPS_X86VECTORTRANSFORMOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//===- X86VectorTransformOps.td - X86Vector transform ops --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef X86VECTOR_TRANSFORM_OPS
#define X86VECTOR_TRANSFORM_OPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/IR/RegionKindInterface.td"

// Transform dialect op `apply_patterns.x86vector.vector_contract_to_fma`.
// It declares the methods of PatternDescriptorOpInterface and its assembly
// consists of an optional attribute dictionary only.
def ApplyVectorContractToFMAPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_fma",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to lower a F32 type vector.contract operation to a FMA.
}];

let assemblyFormat = "attr-dict";
}

// Transform dialect op
// `apply_patterns.x86vector.vector_contract_to_packed_type_dot_product`.
// It declares the methods of PatternDescriptorOpInterface and its assembly
// consists of an optional attribute dictionary only.
def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.x86vector.vector_contract_to_packed_type_dot_product",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect patterns to lower a BF16/Int8 type vector.contract operation
to a BF16/Int8 dot-product.
}];

let assemblyFormat = "attr-dict";
}


#endif // X86VECTOR_TRANSFORM_OPS

12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ struct MaskHelper {
}
};

//===----------------------------------------------------------------------===//

/// Adds to `patterns` a set of patterns for specialized lowering of vector
/// contraction operations to vector fused multiply-add (FMA) operations.
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);

/// Adds to `patterns` a set of patterns for lowering 32-bit packed vector
/// contraction operations to their corresponding packed-type dot-product
/// operations, ultimately targeting the relevant x86 LLVM intrinsics
/// (e.g., BF16 and Int8).
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/X86Vector/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Sub-directories that make up the X86Vector dialect libraries.
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(TransformOps)
17 changes: 17 additions & 0 deletions mlir/lib/Dialect/X86Vector/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Library holding the X86Vector transform dialect ops, built from
# X86VectorTransformOps.cpp.
add_mlir_dialect_library(MLIRX86VectorTransformOps
X86VectorTransformOps.cpp

# The source includes the TableGen-generated X86VectorTransformOps .inc files,
# so the generation target is listed as a dependency.
DEPENDS
MLIRX86VectorTransformOpsIncGen

# Libraries linked publicly into this one.
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMCommonConversion
MLIRLLVMDialect
MLIRVectorDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
MLIRTransformDialectUtils
MLIRX86VectorDialect
MLIRX86VectorTransforms
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- X86VectorTransformOps.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"

using namespace mlir;
using namespace mlir::x86vector;
using namespace mlir::transform;

void mlir::transform::ApplyVectorContractToFMAPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
x86vector::populateVectorContractToFMAPatterns(patterns);
}

void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
populatePatterns(RewritePatternSet &patterns) {
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
// Transform dialect extension that registers the X86Vector transform ops
// generated from X86VectorTransformOps.td.
class X86VectorTransformDialectExtension
: public transform::TransformDialectExtension<
X86VectorTransformDialectExtension> {
public:
// Defines the TypeID for this file-local extension class.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
X86VectorTransformDialectExtension)

X86VectorTransformDialectExtension() {
// Declare the X86Vector and LLVM dialects as "generated" dialects.
// NOTE(review): presumably these are the dialects whose ops may be created
// when the registered transform ops are applied - confirm against the
// TransformDialectExtension documentation.
declareGeneratedDialect<x86vector::X86VectorDialect>();
declareGeneratedDialect<LLVM::LLVMDialect>();
// Register every op in the TableGen-generated op list (GET_OP_LIST section
// of the .cpp.inc file).
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"
>();
}
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/X86Vector/TransformOps/X86VectorTransformOps.cpp.inc"

void mlir::x86vector::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<X86VectorTransformDialectExtension>();
}
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_mlir_dialect_library(MLIRX86VectorTransforms
AVXTranspose.cpp
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
VectorContractToPackedTypeDotProduct.cpp

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
143 changes: 143 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/VectorContractToFMA.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//===- VectorContractToFMA.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

namespace {

// Implements outer product contraction as a sequence of broadcast and
// FMA operations.
//
// The pattern matches a `vector.contract` when:
//   - the combining kind is `add`;
//   - the LHS and accumulator element types are F32;
//   - one of LHS/RHS has exactly one non-unit dimension while the other one
//     has only unit dimensions;
//   - the accumulator is a vector with exactly one non-unit dimension;
//   - none of the operand types is scalable.
//
// All three operands are flattened to 1-D with `vector.shape_cast`, the
// all-unit operand is broadcast to the length of the other operand, and the
// FMA result is cast back to the accumulator type.
//
// For example - for F32 type:
// ```
//   vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
// ```
// to
// ```
//   vector.shape_cast %acc to <16xf32>
//   vector.shape_cast %lhs to <1xf32>
//   vector.shape_cast %rhs to <16xf32>
//   vector.broadcast <1xf32> to <16xf32>
//   vector.fma vector<16xf32>
//   vector.shape_cast <16xf32> to <1x16xf32>
// ```
struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind.");

    VectorType lhsTy = contractOp.getLhsType();
    if (!lhsTy.getElementType().isF32())
      return rewriter.notifyMatchFailure(contractOp,
                                         "Only F32 lowering is supported.");

    VectorType rhsTy = contractOp.getRhsType();
    const llvm::SmallVector<int64_t> nonUnitDimLhs =
        getNonUnitDims(lhsTy.getShape());
    const llvm::SmallVector<int64_t> nonUnitDimRhs =
        getNonUnitDims(rhsTy.getShape());

    // At least one of the two operands must consist of unit dimensions only.
    if (!nonUnitDimLhs.empty() && !nonUnitDimRhs.empty())
      return rewriter.notifyMatchFailure(
          contractOp, "Expects unit dimensions for either LHS or RHS shape.");

    // Together with the check above, this leaves exactly one operand with a
    // single non-unit dimension and the other one with none.
    if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
      return rewriter.notifyMatchFailure(
          contractOp, "Expects exactly one non-unit A/B dimension for either "
                      "LHS or RHS shape.");

    VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
    if (!accTy)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Accumulator is not a vector type");

    if (!accTy.getElementType().isF32())
      return rewriter.notifyMatchFailure(contractOp,
                                         "Accumulator should be F32 type.");

    const llvm::SmallVector<int64_t> nonUnitDimAcc =
        getNonUnitDims(accTy.getShape());
    if (nonUnitDimAcc.size() != 1)
      return rewriter.notifyMatchFailure(
          contractOp, "A or B dimension should be non-unit.");

    // The flattened 1-D types below are built from static sizes only, so
    // scalable operands cannot be expressed by this rewrite.
    if (lhsTy.isScalable() || rhsTy.isScalable() || accTy.isScalable())
      return rewriter.notifyMatchFailure(contractOp,
                                         "Scalable vectors are not supported.");

    // Lowers vector.contract into a broadcast+FMA sequence.
    Location loc = contractOp.getLoc();

    // The operand holding the non-unit dimension keeps its length; the other
    // operand is flattened to a single element and broadcast afterwards.
    const bool rhsHasNonUnitDim = !nonUnitDimRhs.empty();
    const int64_t flatLhsLen = rhsHasNonUnitDim ? 1 : nonUnitDimLhs.front();
    const int64_t flatRhsLen = rhsHasNonUnitDim ? nonUnitDimRhs.front() : 1;
    const int64_t flatAccLen = nonUnitDimAcc.front();

    VectorType flatLhsTy =
        VectorType::get({flatLhsLen}, lhsTy.getElementType());
    VectorType flatRhsTy =
        VectorType::get({flatRhsLen}, rhsTy.getElementType());
    VectorType flatAccTy =
        VectorType::get({flatAccLen}, accTy.getElementType());

    auto castAcc = vector::ShapeCastOp::create(rewriter, loc, flatAccTy,
                                               contractOp.getAcc());
    auto castLhs = vector::ShapeCastOp::create(rewriter, loc, flatLhsTy,
                                               contractOp.getLhs());
    auto castRhs = vector::ShapeCastOp::create(rewriter, loc, flatRhsTy,
                                               contractOp.getRhs());

    // Broadcast the single-element operand to the 1-D type of the other
    // operand. For example, if LHS has type vector<1x1xf32> and RHS has type
    // vector<1x16xf32>, the flattened LHS (vector<1xf32>) is broadcast to
    // vector<16xf32>. In the opposite case (non-unit dimension on the LHS),
    // the RHS is broadcast instead.
    Value fmaLhs = castLhs.getResult();
    Value fmaRhs = castRhs.getResult();
    if (rhsHasNonUnitDim)
      fmaLhs = vector::BroadcastOp::create(rewriter, loc, flatRhsTy, fmaLhs)
                   .getResult();
    else
      fmaRhs = vector::BroadcastOp::create(rewriter, loc, flatLhsTy, fmaRhs)
                   .getResult();

    auto fma = vector::FMAOp::create(rewriter, loc, fmaLhs, fmaRhs,
                                     castAcc.getResult());

    // Restore the original accumulator shape for the users of the contraction.
    auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
    rewriter.replaceOp(contractOp, castFma);

    return success();
  }

private:
  // Returns the sizes of all dimensions of `shape` that differ from 1, in
  // their original order.
  static llvm::SmallVector<int64_t> getNonUnitDims(ArrayRef<int64_t> shape) {
    llvm::SmallVector<int64_t> nonUnitDims;
    llvm::copy_if(shape, std::back_inserter(nonUnitDims),
                  [](int64_t dim) { return dim != 1; });
    return nonUnitDims;
  }
};

} // namespace

void x86vector::populateVectorContractToFMAPatterns(
RewritePatternSet &patterns) {
patterns.add<VectorContractToFMA>(patterns.getContext());
}
Loading