[MLIR] Add sincos op to math dialect #160772
Conversation
Now that `sincos` is a supported intrinsic in the LLVM dialect (llvm#160561), we are able to add the corresponding operation in the math dialect. We have several benchmarks that use sine and cosine in hot loops, and saving some calculations by performing sine and cosine together can benefit performance. We would like to have a way to represent sincos in the math dialect.

Two parts I'm unsure about:

* What do we think of the assembly format? `math.sincos %floatlike : f32 -> f32, f32` With a custom assembly format we could omit the `->` and everything after, but I couldn't get the ODS to do that. Open to suggestions.
* I implement `getShapeForUnroll()` here, but where is the best place to test the unroller interfaces? I'll keep poking around after sending this out for review.

I will add more reviewers once you both think this looks okay. Thanks in advance!
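As a sketch of the intended use (hypothetical IR, using only the syntax introduced in this patch): two trig calls on the same operand can be replaced by one fused op.

```mlir
// Before: sine and cosine computed independently on the same operand.
%s = math.sin %x : f32
%c = math.cos %x : f32

// After: one op that can lower to a single sincos library call or intrinsic.
%sin, %cos = math.sincos %x : f32 -> f32, f32
```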
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Asher Mancinelli (ashermancinelli)

Full diff: https://github.com/llvm/llvm-project/pull/160772.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index cfd8c4b8f11f7..a7e79f2efd4c5 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -510,6 +510,44 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SinCosOp
+//===----------------------------------------------------------------------===//
+
+def Math_SincosOp : Math_Op<"sincos",
+ [SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let summary = "sine and cosine of the specified value";
+ let description = [{
+ The `sincos` operation computes both the sine and cosine of a given value
+ simultaneously. It takes one operand of floating point type (i.e., scalar,
+ tensor or vector) and returns two results of the same type. This operation
+ can be more efficient than computing sine and cosine separately when both
+ values are needed.
+
+ Example:
+
+ ```mlir
+ // Scalar sine and cosine values.
+    %sin, %cos = math.sincos %input : f64 -> f64, f64
+ ```
+ }];
+
+ let arguments = (ins FloatLike:$operand,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
+ let results = (outs FloatLike:$sin, FloatLike:$cos);
+
+ let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }];
+
+ let extraClassDeclaration = [{
+ std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a95263bb55f69..2b46a01c3b0e5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ LLVM::SincosOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,101 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value input = adaptor.getOperand();
+ Type inputType = input.getType();
+ auto convertedInput = maybeExt(input, rewriter);
+ auto computeType = convertedInput.getType();
+
+ StringRef sincosFunc;
+ if (isa<Float32Type>(computeType)) {
+ const arith::FastMathFlags flag = op.getFastmath();
+ const bool useApprox =
+ ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+ sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+ } else if (isa<Float64Type>(computeType)) {
+ sincosFunc = "__nv_sincos";
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported operand type for sincos");
+ }
+
+ auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+ Value sinPtr, cosPtr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto *scope =
+ op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+ assert(scope && "Expected op to be inside automatic allocation scope");
+ rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+ auto one = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ sinPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ cosPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ }
+
+ createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+ op);
+
+ auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
+ auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+
+ rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+ maybeTrunc(cosResult, inputType, rewriter)});
+ return success();
+ }
+
+private:
+ Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+ if (isa<Float16Type, BFloat16Type>(operand.getType())) {
+ return rewriter.create<LLVM::FPExtOp>(
+ operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ }
+ return operand;
+ }
+
+ Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+ if (operand.getType() != type)
+ return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return operand;
+ }
+
+ void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+ StringRef funcName, Value input, Value sinPtr,
+ Value cosPtr, Operation *op) const {
+ auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ auto ptrType = sinPtr.getType();
+
+ SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+ auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+ auto funcAttr = StringAttr::get(op->getContext(), funcName);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+ if (!funcOp) {
+ auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+ funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ }
+
+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+ rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ }
+};
+
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
@@ -589,6 +684,9 @@ void mlir::populateLibDeviceConversionPatterns(
"__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
"__nv_tanh");
+
+  // Custom pattern for sincos since it returns two values.
+ patterns.add<SincosOpLowering>(converter, benefit);
}
void mlir::populateGpuToNVVMConversionPatterns(
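Roughly, the pattern above rewrites a scalar `math.sincos` into a pair of stack slots, a libdevice call, and two loads. A sketch of the result for an `f32` operand (value names are illustrative; the allocas are hoisted to the surrounding allocation scope):

```mlir
%one = llvm.mlir.constant(1 : i32) : i32
%sinp = llvm.alloca %one x f32 : (i32) -> !llvm.ptr
%cosp = llvm.alloca %one x f32 : (i32) -> !llvm.ptr
// __nv_sincosf returns both results through the output pointers.
llvm.call @__nv_sincosf(%x, %sinp, %cosp) : (f32, !llvm.ptr, !llvm.ptr) -> ()
%sin = llvm.load %sinp : !llvm.ptr -> f32
%cos = llvm.load %cosp : !llvm.ptr -> f32
```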
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 853f45498ac52..73a003ef4e6c1 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
+ auto loc = op.getLoc();
+ auto operandType = adaptor.getOperand().getType();
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ auto sinType = typeConverter.convertType(op.getSin().getType());
+ auto cosType = typeConverter.convertType(op.getCos().getType());
+ if (!llvmOperandType || !sinType || !cosType)
+ return failure();
+
+ ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+ auto structType = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+ auto sincosOp = rewriter.create<LLVM::SincosOp>(
+ loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+ auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+ auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+ rewriter.replaceOp(op, {sinValue, cosValue});
+ return success();
+ }
+};
+
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
+ SincosOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,
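For an `f32` operand, the pattern registered above produces IR along these lines (a sketch that mirrors the `math-to-llvm` test added later in this diff):

```mlir
%res = llvm.intr.sincos(%x) : (f32) -> !llvm.struct<(f32, f32)>
%sin = llvm.extractvalue %res[0] : !llvm.struct<(f32, f32)>
%cos = llvm.extractvalue %res[1] : !llvm.struct<(f32, f32)>
```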
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index a21631cbf8510..f0bf62770d4cc 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -284,6 +284,28 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// SinCosOp verifier and getShapeForUnroll
+//===----------------------------------------------------------------------===//
+
+LogicalResult math::SincosOp::verify() {
+ Type operandType = getOperand().getType();
+ Type sinType = getSin().getType();
+ Type cosType = getCos().getType();
+
+ if (operandType != sinType || operandType != cosType) {
+ return emitOpError("result types must match operand type");
+ }
+
+ return success();
+}
+
+std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
+ if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
+ return llvm::to_vector<4>(vt.getShape());
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
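As an illustration of what the verifier enforces, a hypothetical invalid example (the results parse but fail verification):

```mlir
// error: 'math.sincos' op result types must match operand type
%sin, %cos = math.sincos %f : f32 -> f32, f64
```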
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ef06af3ad3163..cdefc4d6098c7 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
func.return %result32, %result64 : f32, f64
}
}
+
+gpu.module @test_module_56 {
+ // CHECK: gpu.module @test_module_56
+
+ // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
+ // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16
+ %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32
+ %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+
+ // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos_fastmath
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16 -> f16, f16
+ %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32 -> f32, f32
+ %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64 -> f64, f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f4541220fe4d2..9030ba9c93e55 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {
// -----
+// CHECK-LABEL: func @sincos
+// CHECK-SAME: [[ARG0:%.+]]: f32
+func.func @sincos(%arg0: f32) {
+ // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
+ %0:2 = math.sincos %arg0 : f32 -> f32, f32
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @inverse_trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics(%arg0: f32) {
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index cb10fc4397ffc..5d3a8a6d87bed 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
+// CHECK-LABEL: func @sincos(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.sincos %[[F]] : f32
+ %0:2 = math.sincos %f : f32 -> f32, f32
+ // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
+ %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32>
+ // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
+ %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32>
+ return
+}
+
// CHECK-LABEL: func @erf(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
I also have some computations where computing sincos together is faster, so +1 on adding this to the math dialect. A common example of this is RoPE embeddings: https://arxiv.org/abs/2104.09864
Overall it looks good to me. I think the assembly format looks ok.
I think you can do |
That works, thank you for the suggestion! |
Thank you, Asher! LGTM
Don't know much about NVVM, everything else LGTM.
We see performance improvements from using sincos to reuse calculations in hot loops that compute sin() and cos() of the same operand. Add a pass to identify sin() and cos() calls in the same block with the same operand and fast-math flags, and fuse them into a sincos op. Follow-up to: * #160561 * #160772
Now that `sincos` is a supported intrinsic in the LLVM dialect (#160561), we are able to add the corresponding operation in the math dialect and add conversion patterns for LLVM and NVVM.

We have several benchmarks that use sine and cosine in hot loops, and saving some calculations by performing them together can benefit performance. We would like to have a way to represent sincos in the math dialect.

Parts I'm unsure about:

* The NVVM lowering follows `allocBuffers()` in `VectorToSCF.cpp` by assuming the op exists in an AutomaticAllocationScope, and I don't check that the region's first block exists. We could assert more eagerly, but I'm not sure it matters.
* What do we think of the assembly format `math.sincos %f : f32 -> f32, f32`? I know we could omit the `->` and everything after with a custom assembly format, but I couldn't get the ODS to do that. Open to suggestions.
* I implement `getShapeForUnroll()` here, but where is the best place to test the unroller interfaces? I'll keep poking around after sending this out for review.

I will add more reviewers once you both think this looks okay. Thanks in advance!