37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -510,6 +510,43 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// SinCosOp
//===----------------------------------------------------------------------===//

def Math_SincosOp : Math_Op<"sincos",
[SameOperandsAndResultShape,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
AllTypesMatch<["operand", "sin", "cos"]>]> {
let summary = "sine and cosine of the specified value";
let description = [{
The `sincos` operation computes the sine and cosine of a given value
simultaneously. It takes one operand of floating point type (i.e., scalar,
tensor or vector) and returns two results of the same type. Computing both
values in one operation can be more efficient than computing sine and
cosine separately when both are needed.

Example:

```mlir
// Scalar sine and cosine values.
%sin, %cos = math.sincos %input : f64
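
// Vector form with fast-math flags (operand name illustrative).
%s, %c = math.sincos %vec fastmath<afn> : vector<4xf32>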
```
}];

let arguments = (ins FloatLike:$operand,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$sin, FloatLike:$cos);

let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
attr-dict `:` type($operand) }];

let extraClassDeclaration = [{
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
}

//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
99 changes: 98 additions & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
LLVM::SqrtOp>();
LLVM::SincosOp, LLVM::SqrtOp>();

// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}

struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = adaptor.getOperand();
Type inputType = input.getType();
auto convertedInput = maybeExt(input, rewriter);
auto computeType = convertedInput.getType();

StringRef sincosFunc;
if (isa<Float32Type>(computeType)) {
const arith::FastMathFlags flag = op.getFastmath();
const bool useApprox =
mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
} else if (isa<Float64Type>(computeType)) {
sincosFunc = "__nv_sincos";
} else {
return rewriter.notifyMatchFailure(op,
"unsupported operand type for sincos");
}

auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

Value sinPtr, cosPtr;
{
OpBuilder::InsertionGuard guard(rewriter);
auto *scope =
op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
auto one = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
sinPtr =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
cosPtr =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
}

createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
op);

auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);

rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
maybeTrunc(cosResult, inputType, rewriter)});
return success();
}

private:
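// Promote f16/bf16 operands to f32, since libdevice only provides f32 and
// f64 sincos implementations.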
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
if (isa<Float16Type, BFloat16Type>(operand.getType()))
return rewriter.create<LLVM::FPExtOp>(
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
return operand;
}

Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
if (operand.getType() != type)
return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
return operand;
}

void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
StringRef funcName, Value input, Value sinPtr,
Value cosPtr, Operation *op) const {
auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
auto ptrType = sinPtr.getType();

SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);

auto funcAttr = StringAttr::get(op->getContext(), funcName);
auto funcOp =
SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);

if (!funcOp) {
auto parentFunc = op->getParentOfType<FunctionOpInterface>();
assert(parentFunc && "expected there to be a parent function");
OpBuilder b(parentFunc);

auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
}

SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
}
};

template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
"__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
"__nv_tanh");

// `sincos` needs a custom lowering pattern because it produces two results.
patterns.add<SincosOpLowering>(converter, benefit);
}

void mlir::populateGpuToNVVMConversionPatterns(
33 changes: 33 additions & 0 deletions mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;

// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
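// For f32 this produces, roughly:
//   %r = llvm.intr.sincos(%x) : (f32) -> !llvm.struct<(f32, f32)>
//   %sin = llvm.extractvalue %r[0] : !llvm.struct<(f32, f32)>
//   %cos = llvm.extractvalue %r[1] : !llvm.struct<(f32, f32)>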
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
mlir::Location loc = op.getLoc();
mlir::Type operandType = adaptor.getOperand().getType();
mlir::Type llvmOperandType = typeConverter.convertType(operandType);
mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
if (!llvmOperandType || !sinType || !cosType)
return failure();

ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);

auto structType = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {llvmOperandType, llvmOperandType});

auto sincosOp = rewriter.create<LLVM::SincosOp>(
loc, structType, adaptor.getOperand(), attrs.getAttrs());

auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);

rewriter.replaceOp(op, {sinValue, cosValue});
return success();
}
};

// An `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
SincosOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -284,6 +284,16 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
});
}

//===----------------------------------------------------------------------===//
// SinCosOp getShapeForUnroll
//===----------------------------------------------------------------------===//

std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}

//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
39 changes: 39 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
func.return %result32, %result64 : f32, f64
}
}

gpu.module @test_module_56 {
// CHECK: gpu.module @test_module_56

// CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
// CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)

// CHECK-LABEL: func @gpu_sincos
// CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
// CHECK-COUNT-6: llvm.alloca
// CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
// CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.fptrunc
// CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
%sin16, %cos16 = math.sincos %arg_f16 : f16
%sin32, %cos32 = math.sincos %arg_f32 : f32
%sin64, %cos64 = math.sincos %arg_f64 : f64
func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
}

// CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)

// CHECK-LABEL: func @gpu_sincos_fastmath
// CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
// CHECK-COUNT-6: llvm.alloca
// CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
// CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.fptrunc
// CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
%sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16
%sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32
%sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64
func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
}
}
10 changes: 10 additions & 0 deletions mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {

// -----

// CHECK-LABEL: func @sincos
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @sincos(%arg0: f32) {
// CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
%0:2 = math.sincos %arg0 : f32
func.return
}

// -----

// CHECK-LABEL: func @inverse_trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics(%arg0: f32) {
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Math/ops.mlir
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}

// CHECK-LABEL: func @sincos(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = math.sincos %[[F]] : f32
%0:2 = math.sincos %f : f32
// CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
%1:2 = math.sincos %v : vector<4xf32>
// CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
%2:2 = math.sincos %t : tensor<4x4x?xf32>
return
}

// CHECK-LABEL: func @erf(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {