From 3fcbaee738982ab6df9df883dbf0fdc63302862e Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 08:48:11 -0700 Subject: [PATCH 1/7] [MLIR] Add sincos operation to math dialect Now that `sincos` is a supported intrinsic in the LLVM dialect (https://github.com/llvm/llvm-project/pull/160561), we are able to add the corresponding operation in the math dialect. We have several benchmarks that use sine and cosine in hot loops, and saving some calculations by performing sine and cosine together can benefit performance. We would like to have a way to represent sincos in the math dialect. Two parts I'm unsure about: * What do we think of the assembly format? `math.sincos %floatlike : f32 -> f32, f32` With a custom assembly format we could omit the `->` and everything after, but I couldn't get the ODS to do that. Open to suggestions. * I implement `getShapeForUnroll()` here, but where is the best place to test the unroller interfaces? I'll keep poking around after sending this out for review. 
--- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 38 ++++++++ .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 93 ++++++++++++++++++- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 33 +++++++ mlir/lib/Dialect/Math/IR/MathOps.cpp | 22 +++++ .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 39 ++++++++ .../Conversion/MathToLLVM/math-to-llvm.mlir | 10 ++ mlir/test/Dialect/Math/ops.mlir | 12 +++ 7 files changed, 246 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index cfd8c4b8f11f7..a7e79f2efd4c5 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -510,6 +510,44 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// SinCosOp +//===----------------------------------------------------------------------===// + +def Math_SincosOp : Math_Op<"sincos", + [SameOperandsAndResultShape, + DeclareOpInterfaceMethods]> { + let summary = "sine and cosine of the specified value"; + let description = [{ + The `sincos` operation computes both the sine and cosine of a given value + simultaneously. It takes one operand of floating point type (i.e., scalar, + tensor or vector) and returns two results of the same type. This operation + can be more efficient than computing sine and cosine separately when both + values are needed. + + Example: + + ```mlir + // Scalar sine and cosine values. + %sin, %cos = math.sincos %input : f64 `->` f64, f64 + ``` + }]; + + let arguments = (ins FloatLike:$operand, + DefaultValuedAttr:$fastmath); + let results = (outs FloatLike:$sin, FloatLike:$cos); + + let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? 
+ attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }]; + + let extraClassDeclaration = [{ + std::optional> getShapeForUnroll(); + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // CountLeadingZerosOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index a95263bb55f69..16d765f2b2561 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp, - LLVM::SqrtOp>(); + LLVM::SincosOp, LLVM::SqrtOp>(); // TODO: Remove once we support replacing non-root ops. target.addLegalOp(); @@ -466,6 +466,94 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) { }); } +// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls +struct SincosOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getOperand(); + Type inputType = input.getType(); + auto convertedInput = maybeExt(input, rewriter); + auto computeType = convertedInput.getType(); + + StringRef sincosFunc; + if (isa(computeType)) { + const arith::FastMathFlags flag = op.getFastmath(); + const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag); + sincosFunc = useApprox ? 
"__nv_fast_sincosf" : "__nv_sincosf"; + } else if (isa(computeType)) { + sincosFunc = "__nv_sincos"; + } else { + return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos"); + } + + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + + Value sinPtr, cosPtr; + { + OpBuilder::InsertionGuard guard(rewriter); + auto *scope = op->getParentWithTrait(); + assert(scope && "Expected op to be inside automatic allocation scope"); + rewriter.setInsertionPointToStart(&scope->getRegion(0).front()); + auto one = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + sinPtr = rewriter.create(loc, ptrType, computeType, one, 0); + cosPtr = rewriter.create(loc, ptrType, computeType, one, 0); + } + + createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op); + + auto sinResult = rewriter.create(loc, computeType, sinPtr); + auto cosResult = rewriter.create(loc, computeType, cosPtr); + + rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter), + maybeTrunc(cosResult, inputType, rewriter)}); + return success(); + } + +private: + Value maybeExt(Value operand, PatternRewriter &rewriter) const { + if (isa(operand.getType())) { + return rewriter.create(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + } + return operand; + } + + Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const { + if (operand.getType() != type) + return rewriter.create(operand.getLoc(), type, operand); + return operand; + } + + void createSincosCall(ConversionPatternRewriter &rewriter, Location loc, + StringRef funcName, Value input, Value sinPtr, Value cosPtr, + Operation *op) const { + auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext()); + auto ptrType = sinPtr.getType(); + + SmallVector operandTypes = {input.getType(), ptrType, ptrType}; + auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes); + + auto funcAttr = StringAttr::get(op->getContext(), funcName); 
+ auto funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + + if (!funcOp) { + auto parentFunc = op->getParentOfType(); + assert(parentFunc && "expected there to be a parent function"); + OpBuilder b(parentFunc); + + auto globalloc = loc->findInstanceOfOrUnknown(); + funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType); + } + + SmallVector callOperands = {input, sinPtr, cosPtr}; + rewriter.create(loc, funcOp, callOperands); + } +}; + template static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, @@ -589,6 +677,9 @@ void mlir::populateLibDeviceConversionPatterns( "__nv_tan", "__nv_fast_tanf"); populateOpPatterns(converter, patterns, benefit, "__nv_tanhf", "__nv_tanh"); + + // Custom pattern for sincos since it returns two values + patterns.add(converter, benefit); } void mlir::populateGpuToNVVMConversionPatterns( diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 853f45498ac52..73a003ef4e6c1 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering = LLVM::CountTrailingZerosOp>; using AbsIOpLowering = IntOpWithFlagLowering; +// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops. 
+struct SincosOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *this->getTypeConverter(); + auto loc = op.getLoc(); + auto operandType = adaptor.getOperand().getType(); + auto llvmOperandType = typeConverter.convertType(operandType); + auto sinType = typeConverter.convertType(op.getSin().getType()); + auto cosType = typeConverter.convertType(op.getCos().getType()); + if (!llvmOperandType || !sinType || !cosType) + return failure(); + + ConvertFastMath attrs(op); + + auto structType = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {llvmOperandType, llvmOperandType}); + + auto sincosOp = rewriter.create( + loc, structType, adaptor.getOperand(), attrs.getAttrs()); + + auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0); + auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1); + + rewriter.replaceOp(op, {sinValue, cosValue}); + return success(); + } +}; + // A `expm1` is converted into `exp - 1`. 
struct ExpM1OpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns( RoundEvenOpLowering, RoundOpLowering, RsqrtOpLowering, + SincosOpLowering, SinOpLowering, SinhOpLowering, ASinOpLowering, diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index a21631cbf8510..f0bf62770d4cc 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -284,6 +284,28 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { }); } +//===----------------------------------------------------------------------===// +// SinCosOp verifier and getShapeForUnroll +//===----------------------------------------------------------------------===// + +LogicalResult math::SincosOp::verify() { + Type operandType = getOperand().getType(); + Type sinType = getSin().getType(); + Type cosType = getCos().getType(); + + if (operandType != sinType || operandType != cosType) { + return emitOpError("result types must match operand type"); + } + + return success(); +} + +std::optional> math::SincosOp::getShapeForUnroll() { + if (auto vt = mlir::dyn_cast_or_null(getOperand().getType())) + return llvm::to_vector<4>(vt.getShape()); + return std::nullopt; +} + //===----------------------------------------------------------------------===// // CountLeadingZerosOp folder //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index ef06af3ad3163..cdefc4d6098c7 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1109,3 +1109,42 @@ gpu.module @test_module_55 { func.return %result32, %result64 : f32, f64 } } + +gpu.module @test_module_56 { + // CHECK: gpu.module @test_module_56 + + // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, 
!llvm.ptr) + // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr) + + // CHECK-LABEL: func @gpu_sincos + // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64 + func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) { + // CHECK-COUNT-6: llvm.alloca + // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32 + // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK-COUNT-2: llvm.fptrunc + // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () + %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16 + %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32 + %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64 + func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 + } + + // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr) + + // CHECK-LABEL: func @gpu_sincos_fastmath + // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64 + func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) { + // CHECK-COUNT-6: llvm.alloca + // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32 + // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK-COUNT-2: llvm.fptrunc + // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () + %sin16, %cos16 = math.sincos %arg_f16 fastmath : f16 -> f16, f16 + %sin32, %cos32 = math.sincos %arg_f32 fastmath : f32 -> f32, f32 + %sin64, %cos64 = math.sincos %arg_f64 fastmath : f64 -> f64, 
f64 + func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 + } +} diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index f4541220fe4d2..9030ba9c93e55 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) { // ----- +// CHECK-LABEL: func @sincos +// CHECK-SAME: [[ARG0:%.+]]: f32 +func.func @sincos(%arg0: f32) { + // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)> + %0:2 = math.sincos %arg0 : f32 -> f32, f32 + func.return +} + +// ----- + // CHECK-LABEL: func @inverse_trigonometrics // CHECK-SAME: [[ARG0:%.+]]: f32 func.func @inverse_trigonometrics(%arg0: f32) { diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index cb10fc4397ffc..5d3a8a6d87bed 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { return } +// CHECK-LABEL: func @sincos( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.sincos %[[F]] : f32 + %0:2 = math.sincos %f : f32 -> f32, f32 + // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32> + %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32> + // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32> + %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32> + return +} + // CHECK-LABEL: func @erf( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { From 6e2b34c11fd688312bcab520d4e74b5e12e10ae9 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 13:24:24 -0700 Subject: 
[PATCH 2/7] Formatting --- .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 16d765f2b2561..2c0a3305518e1 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -482,29 +482,35 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { StringRef sincosFunc; if (isa(computeType)) { const arith::FastMathFlags flag = op.getFastmath(); - const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag); + const bool useApprox = + ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag); sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf"; } else if (isa(computeType)) { sincosFunc = "__nv_sincos"; } else { - return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos"); + return rewriter.notifyMatchFailure(op, + "unsupported operand type for sincos"); } auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - + Value sinPtr, cosPtr; { OpBuilder::InsertionGuard guard(rewriter); - auto *scope = op->getParentWithTrait(); + auto *scope = + op->getParentWithTrait(); assert(scope && "Expected op to be inside automatic allocation scope"); rewriter.setInsertionPointToStart(&scope->getRegion(0).front()); auto one = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); - sinPtr = rewriter.create(loc, ptrType, computeType, one, 0); - cosPtr = rewriter.create(loc, ptrType, computeType, one, 0); + sinPtr = + rewriter.create(loc, ptrType, computeType, one, 0); + cosPtr = + rewriter.create(loc, ptrType, computeType, one, 0); } - createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op); + createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, + op); auto sinResult = rewriter.create(loc, computeType, 
sinPtr); auto cosResult = rewriter.create(loc, computeType, cosPtr); @@ -517,7 +523,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { private: Value maybeExt(Value operand, PatternRewriter &rewriter) const { if (isa(operand.getType())) { - return rewriter.create(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + return rewriter.create( + operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); } return operand; } @@ -529,26 +536,27 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { } void createSincosCall(ConversionPatternRewriter &rewriter, Location loc, - StringRef funcName, Value input, Value sinPtr, Value cosPtr, - Operation *op) const { + StringRef funcName, Value input, Value sinPtr, + Value cosPtr, Operation *op) const { auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext()); auto ptrType = sinPtr.getType(); - + SmallVector operandTypes = {input.getType(), ptrType, ptrType}; auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes); - + auto funcAttr = StringAttr::get(op->getContext(), funcName); - auto funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); - + auto funcOp = + SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + if (!funcOp) { auto parentFunc = op->getParentOfType(); assert(parentFunc && "expected there to be a parent function"); OpBuilder b(parentFunc); - + auto globalloc = loc->findInstanceOfOrUnknown(); funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType); } - + SmallVector callOperands = {input, sinPtr, cosPtr}; rewriter.create(loc, funcOp, callOperands); } From dfea012c6c10386620c341eba82a690af926e969 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 13:26:32 -0700 Subject: [PATCH 3/7] Remove needless comment --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp 
b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 2c0a3305518e1..2b46a01c3b0e5 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -466,7 +466,6 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) { }); } -// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls struct SincosOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; From 1a24ecca6e955b70c09f11bccc4ee1bc6b41a1fc Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 13:43:30 -0700 Subject: [PATCH 4/7] Remove braces on single-line if --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +-- mlir/lib/Dialect/Math/IR/MathOps.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 2b46a01c3b0e5..f8f2104d2bd6a 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -521,10 +521,9 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { private: Value maybeExt(Value operand, PatternRewriter &rewriter) const { - if (isa(operand.getType())) { + if (isa(operand.getType())) return rewriter.create( operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); - } return operand; } diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index f0bf62770d4cc..0de5636c27c3f 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -293,9 +293,8 @@ LogicalResult math::SincosOp::verify() { Type sinType = getSin().getType(); Type cosType = getCos().getType(); - if (operandType != sinType || operandType != cosType) { + if (operandType != sinType || operandType != cosType) return emitOpError("result types must match operand 
type"); - } return success(); } From 145610739e99193970accde6f4e9596eb7fe4f3b Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 14:18:40 -0700 Subject: [PATCH 5/7] Refine assembly format --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 7 ++++--- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 12 ++++++------ mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir | 2 +- mlir/test/Dialect/Math/ops.mlir | 6 +++--- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index a7e79f2efd4c5..b4212056694e9 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -516,7 +516,8 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> { def Math_SincosOp : Math_Op<"sincos", [SameOperandsAndResultShape, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + AllTypesMatch<["operand", "sin", "cos"]>]> { let summary = "sine and cosine of the specified value"; let description = [{ The `sincos` operation computes both the sine and cosine of a given value @@ -529,7 +530,7 @@ def Math_SincosOp : Math_Op<"sincos", ```mlir // Scalar sine and cosine values. - %sin, %cos = math.sincos %input : f64 `->` f64, f64 + %sin, %cos = math.sincos %input : f64 ``` }]; @@ -539,7 +540,7 @@ def Math_SincosOp : Math_Op<"sincos", let results = (outs FloatLike:$sin, FloatLike:$cos); let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? 
- attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }]; + attr-dict `:` type($operand) }]; let extraClassDeclaration = [{ std::optional> getShapeForUnroll(); diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index cdefc4d6098c7..a4b5dde8a2187 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1125,9 +1125,9 @@ gpu.module @test_module_56 { // CHECK-COUNT-2: llvm.fptrunc // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () - %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16 - %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32 - %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64 + %sin16, %cos16 = math.sincos %arg_f16 : f16 + %sin32, %cos32 = math.sincos %arg_f32 : f32 + %sin64, %cos64 = math.sincos %arg_f64 : f64 func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 } @@ -1142,9 +1142,9 @@ gpu.module @test_module_56 { // CHECK-COUNT-2: llvm.fptrunc // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () - %sin16, %cos16 = math.sincos %arg_f16 fastmath : f16 -> f16, f16 - %sin32, %cos32 = math.sincos %arg_f32 fastmath : f32 -> f32, f32 - %sin64, %cos64 = math.sincos %arg_f64 fastmath : f64 -> f64, f64 + %sin16, %cos16 = math.sincos %arg_f16 fastmath : f16 + %sin32, %cos32 = math.sincos %arg_f32 fastmath : f32 + %sin64, %cos64 = math.sincos %arg_f64 fastmath : f64 func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 } } diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index 
9030ba9c93e55..f7d27120d4207 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -234,7 +234,7 @@ func.func @trigonometrics(%arg0: f32) { // CHECK-SAME: [[ARG0:%.+]]: f32 func.func @sincos(%arg0: f32) { // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)> - %0:2 = math.sincos %arg0 : f32 -> f32, f32 + %0:2 = math.sincos %arg0 : f32 func.return } diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index 5d3a8a6d87bed..f085d1c62ea86 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -66,11 +66,11 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { // CHECK: %{{.*}} = math.sincos %[[F]] : f32 - %0:2 = math.sincos %f : f32 -> f32, f32 + %0:2 = math.sincos %f : f32 // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32> - %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32> + %1:2 = math.sincos %v : vector<4xf32> // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32> - %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32> + %2:2 = math.sincos %t : tensor<4x4x?xf32> return } From faa84a92b9baa9a1ed4a184abc84b4db44b329c4 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Thu, 25 Sep 2025 17:22:51 -0700 Subject: [PATCH 6/7] Remove custom verifier; clean up FMF handling --- mlir/include/mlir/Dialect/Math/IR/MathOps.td | 2 -- .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- mlir/lib/Dialect/Math/IR/MathOps.cpp | 13 +------------ 3 files changed, 2 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index b4212056694e9..af65af6fedec6 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ 
b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -545,8 +545,6 @@ def Math_SincosOp : Math_Op<"sincos", let extraClassDeclaration = [{ std::optional> getShapeForUnroll(); }]; - - let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index f8f2104d2bd6a..852c50c965f11 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -482,7 +482,7 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { if (isa(computeType)) { const arith::FastMathFlags flag = op.getFastmath(); const bool useApprox = - ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag); + mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn); sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf"; } else if (isa(computeType)) { sincosFunc = "__nv_sincos"; diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 0de5636c27c3f..ca2792dd177e5 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -285,20 +285,9 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// SinCosOp verifier and getShapeForUnroll +// SinCosOp getShapeForUnroll //===----------------------------------------------------------------------===// -LogicalResult math::SincosOp::verify() { - Type operandType = getOperand().getType(); - Type sinType = getSin().getType(); - Type cosType = getCos().getType(); - - if (operandType != sinType || operandType != cosType) - return emitOpError("result types must match operand type"); - - return success(); -} - std::optional> math::SincosOp::getShapeForUnroll() { if (auto vt = mlir::dyn_cast_or_null(getOperand().getType())) return llvm::to_vector<4>(vt.getShape()); From 
90ef640e7dd16880d700220c2ce41666325281b8 Mon Sep 17 00:00:00 2001 From: Asher Mancinelli Date: Tue, 30 Sep 2025 07:23:09 -0700 Subject: [PATCH 7/7] Spell out types; use dyn_cast on non-nullable --- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 12 ++++++------ mlir/lib/Dialect/Math/IR/MathOps.cpp | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 73a003ef4e6c1..229e40e2061cb 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -128,12 +128,12 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - const auto &typeConverter = *this->getTypeConverter(); - auto loc = op.getLoc(); - auto operandType = adaptor.getOperand().getType(); - auto llvmOperandType = typeConverter.convertType(operandType); - auto sinType = typeConverter.convertType(op.getSin().getType()); - auto cosType = typeConverter.convertType(op.getCos().getType()); + const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); + mlir::Location loc = op.getLoc(); + mlir::Type operandType = adaptor.getOperand().getType(); + mlir::Type llvmOperandType = typeConverter.convertType(operandType); + mlir::Type sinType = typeConverter.convertType(op.getSin().getType()); + mlir::Type cosType = typeConverter.convertType(op.getCos().getType()); if (!llvmOperandType || !sinType || !cosType) return failure(); diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index ca2792dd177e5..bbeef0f6ee9e5 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -289,7 +289,7 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// std::optional> 
math::SincosOp::getShapeForUnroll() { - if (auto vt = mlir::dyn_cast_or_null(getOperand().getType())) + if (auto vt = mlir::dyn_cast(getOperand().getType())) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; }