diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index cfd8c4b8f11f7..af65af6fedec6 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -510,6 +510,43 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// SinCosOp +//===----------------------------------------------------------------------===// + +def Math_SincosOp : Math_Op<"sincos", + [SameOperandsAndResultShape, + DeclareOpInterfaceMethods, + AllTypesMatch<["operand", "sin", "cos"]>]> { + let summary = "sine and cosine of the specified value"; + let description = [{ + The `sincos` operation computes both the sine and cosine of a given value + simultaneously. It takes one operand of floating point type (i.e., scalar, + tensor or vector) and returns two results of the same type. This operation + can be more efficient than computing sine and cosine separately when both + values are needed. + + Example: + + ```mlir + // Scalar sine and cosine values. + %sin, %cos = math.sincos %input : f64 + ``` + }]; + + let arguments = (ins FloatLike:$operand, + DefaultValuedAttr:$fastmath); + let results = (outs FloatLike:$sin, FloatLike:$cos); + + let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? + attr-dict `:` type($operand) }]; + + let extraClassDeclaration = [{ + std::optional> getShapeForUnroll(); + }]; +} + //===----------------------------------------------------------------------===// // CountLeadingZerosOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index a95263bb55f69..852c50c965f11 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp, - LLVM::SqrtOp>(); + LLVM::SincosOp, LLVM::SqrtOp>(); // TODO: Remove once we support replacing non-root ops. target.addLegalOp(); @@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) { }); } +struct SincosOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = adaptor.getOperand(); + Type inputType = input.getType(); + auto convertedInput = maybeExt(input, rewriter); + auto computeType = convertedInput.getType(); + + StringRef sincosFunc; + if (isa(computeType)) { + const arith::FastMathFlags flag = op.getFastmath(); + const bool useApprox = + mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn); + sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf"; + } else if (isa(computeType)) { + sincosFunc = "__nv_sincos"; + } else { + return rewriter.notifyMatchFailure(op, + "unsupported operand type for sincos"); + } + + auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); + + Value sinPtr, cosPtr; + { + OpBuilder::InsertionGuard guard(rewriter); + auto *scope = + op->getParentWithTrait(); + assert(scope && "Expected op to be inside automatic allocation scope"); + rewriter.setInsertionPointToStart(&scope->getRegion(0).front()); + auto one = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + sinPtr = + rewriter.create(loc, ptrType, computeType, one, 0); + cosPtr = + rewriter.create(loc, ptrType, computeType, one, 0); + } + + createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, + op); + + auto sinResult = rewriter.create(loc, computeType, sinPtr); + auto cosResult = rewriter.create(loc, computeType, cosPtr); + + rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter), + maybeTrunc(cosResult, inputType, rewriter)}); + return success(); + } + +private: + Value maybeExt(Value operand, PatternRewriter &rewriter) const { + if (isa(operand.getType())) + return rewriter.create( + operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + return operand; + } + + Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const { + if (operand.getType() != type) + return rewriter.create(operand.getLoc(), type, operand); + return operand; + } + + void createSincosCall(ConversionPatternRewriter &rewriter, Location loc, + StringRef funcName, Value input, Value sinPtr, + Value cosPtr, Operation *op) const { + auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext()); + auto ptrType = sinPtr.getType(); + + SmallVector operandTypes = {input.getType(), ptrType, ptrType}; + auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes); + + auto funcAttr = StringAttr::get(op->getContext(), funcName); + auto funcOp = + SymbolTable::lookupNearestSymbolFrom(op, funcAttr); + + if (!funcOp) { + auto parentFunc = op->getParentOfType(); + assert(parentFunc && "expected there to be a parent function"); + OpBuilder b(parentFunc); + + auto globalloc = loc->findInstanceOfOrUnknown(); + funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType); + } + + SmallVector callOperands = {input, sinPtr, cosPtr}; + rewriter.create(loc, funcOp, callOperands); + } +}; + template static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, @@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns( "__nv_tan", "__nv_fast_tanf"); populateOpPatterns(converter, patterns, benefit, "__nv_tanhf", "__nv_tanh"); + + // Custom pattern for sincos since it returns two values + patterns.add(converter, benefit); } void mlir::populateGpuToNVVMConversionPatterns( diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 853f45498ac52..229e40e2061cb 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering = LLVM::CountTrailingZerosOp>; using AbsIOpLowering = IntOpWithFlagLowering; +// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops. +struct SincosOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const LLVMTypeConverter &typeConverter = *this->getTypeConverter(); + mlir::Location loc = op.getLoc(); + mlir::Type operandType = adaptor.getOperand().getType(); + mlir::Type llvmOperandType = typeConverter.convertType(operandType); + mlir::Type sinType = typeConverter.convertType(op.getSin().getType()); + mlir::Type cosType = typeConverter.convertType(op.getCos().getType()); + if (!llvmOperandType || !sinType || !cosType) + return failure(); + + ConvertFastMath attrs(op); + + auto structType = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {llvmOperandType, llvmOperandType}); + + auto sincosOp = rewriter.create( + loc, structType, adaptor.getOperand(), attrs.getAttrs()); + + auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0); + auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1); + + rewriter.replaceOp(op, {sinValue, cosValue}); + return success(); + } +}; + // A `expm1` is converted into `exp - 1`. struct ExpM1OpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns( RoundEvenOpLowering, RoundOpLowering, RsqrtOpLowering, + SincosOpLowering, SinOpLowering, SinhOpLowering, ASinOpLowering, diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index a21631cbf8510..bbeef0f6ee9e5 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -284,6 +284,16 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) { }); } +//===----------------------------------------------------------------------===// +// SinCosOp getShapeForUnroll +//===----------------------------------------------------------------------===// + +std::optional> math::SincosOp::getShapeForUnroll() { + if (auto vt = mlir::dyn_cast(getOperand().getType())) + return llvm::to_vector<4>(vt.getShape()); + return std::nullopt; +} + //===----------------------------------------------------------------------===// // CountLeadingZerosOp folder //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index ef06af3ad3163..a4b5dde8a2187 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1109,3 +1109,42 @@ gpu.module @test_module_55 { func.return %result32, %result64 : f32, f64 } } + +gpu.module @test_module_56 { + // CHECK: gpu.module @test_module_56 + + // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr) + // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr) + + // CHECK-LABEL: func @gpu_sincos + // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64 + func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) { + // CHECK-COUNT-6: llvm.alloca + // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32 + // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK-COUNT-2: llvm.fptrunc + // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () + %sin16, %cos16 = math.sincos %arg_f16 : f16 + %sin32, %cos32 = math.sincos %arg_f32 : f32 + %sin64, %cos64 = math.sincos %arg_f64 : f64 + func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 + } + + // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr) + + // CHECK-LABEL: func @gpu_sincos_fastmath + // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64 + func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) { + // CHECK-COUNT-6: llvm.alloca + // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32 + // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK-COUNT-2: llvm.fptrunc + // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> () + // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> () + %sin16, %cos16 = math.sincos %arg_f16 fastmath : f16 + %sin32, %cos32 = math.sincos %arg_f32 fastmath : f32 + %sin64, %cos64 = math.sincos %arg_f64 fastmath : f64 + func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64 + } +} diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index f4541220fe4d2..f7d27120d4207 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) { // ----- +// CHECK-LABEL: func @sincos +// CHECK-SAME: [[ARG0:%.+]]: f32 +func.func @sincos(%arg0: f32) { + // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)> + %0:2 = math.sincos %arg0 : f32 + func.return +} + +// ----- + // CHECK-LABEL: func @inverse_trigonometrics // CHECK-SAME: [[ARG0:%.+]]: f32 func.func @inverse_trigonometrics(%arg0: f32) { diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index cb10fc4397ffc..f085d1c62ea86 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { return } +// CHECK-LABEL: func @sincos( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.sincos %[[F]] : f32 + %0:2 = math.sincos %f : f32 + // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32> + %1:2 = math.sincos %v : vector<4xf32> + // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32> + %2:2 = math.sincos %t : tensor<4x4x?xf32> + return +} + // CHECK-LABEL: func @erf( // CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {