diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 50fca564b5b64..02b61bd989368 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1520,20 +1520,12 @@ class TanPattern : public SPIRVToLLVMConversion { if (!dstType) return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); - Location loc = tanOp.getLoc(); - Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); - Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); - rewriter.replaceOpWithNewOp(tanOp, dstType, sin, cos); + rewriter.replaceOpWithNewOp(tanOp, dstType, + adaptor.getOperands()); return success(); } }; -/// Convert `spirv.Tanh` to -/// -/// exp(2x) - 1 -/// ----------- -/// exp(2x) + 1 -/// class TanhPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; @@ -1546,18 +1538,8 @@ class TanhPattern : public SPIRVToLLVMConversion { if (!dstType) return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); - Location loc = tanhOp.getLoc(); - Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); - Value multiplied = - LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); - Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); - Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value numerator = - LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); - Value denominator = - LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); - rewriter.replaceOpWithNewOp(tanhOp, dstType, numerator, - denominator); + rewriter.replaceOpWithNewOp(tanhOp, dstType, + adaptor.getOperands()); return success(); } }; diff --git a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir index e1936e2fd8abe..b17e1c40cb9a7 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/gl-ops-to-llvm.mlir @@ -162,9 +162,7 @@ spirv.func @sqrt(%arg0: f32, %arg1: vector<3xf16>) "None" { // CHECK-LABEL: @tan spirv.func @tan(%arg0: f32) "None" { - // CHECK: %[[SIN:.*]] = llvm.intr.sin(%{{.*}}) : (f32) -> f32 - // CHECK: %[[COS:.*]] = llvm.intr.cos(%{{.*}}) : (f32) -> f32 - // CHECK: llvm.fdiv %[[SIN]], %[[COS]] : f32 + // CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tan %arg0 : f32 spirv.Return } @@ -175,13 +173,7 @@ spirv.func @tan(%arg0: f32) "None" { // CHECK-LABEL: @tanh spirv.func @tanh(%arg0: f32) "None" { - // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2.000000e+00 : f32) : f32 - // CHECK: %[[X2:.*]] = llvm.fmul %[[TWO]], %{{.*}} : f32 - // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[X2]]) : (f32) -> f32 - // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[T0:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : f32 - // CHECK: %[[T1:.*]] = llvm.fadd %[[EXP]], %[[ONE]] : f32 - // CHECK: llvm.fdiv %[[T0]], %[[T1]] : f32 + // CHECK: llvm.intr.tanh(%{{.*}}) : (f32) -> f32 %0 = spirv.GL.Tanh %arg0 : f32 spirv.Return }