diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 5f5fab6f12300..2dc71c68f8a94 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3275,6 +3275,39 @@ def CIR_InlineAsmOp : CIR_Op<"asm", [RecursiveMemoryEffects]> { let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +def CIR_SqrtOp : CIR_Op<"sqrt", [Pure]> { + let summary = "Floating-point square root"; + + let description = [{ + The `cir.sqrt` operation computes the element-wise square root of its input. + + The input must be either: + • a floating-point scalar type, or + • a vector whose element type is floating-point. + + The result type must match the input type exactly. + + Examples: + // scalar + %r = cir.sqrt %x : !cir.fp64 + + // vector + %v = cir.sqrt %vec : !cir.vector + }]; + + // input and output types: float or vector-of-float + let arguments = (ins CIR_AnyFloatOrVecOfFloatType:$input); + let results = (outs CIR_AnyFloatOrVecOfFloatType:$result); + + let assemblyFormat = [{ + $input `:` type($input) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // UnreachableOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp index 0e43345bad6f1..f8a139ec7a8e0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp @@ -781,9 +781,21 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, case X86::BI__builtin_ia32_sqrtsh_round_mask: case X86::BI__builtin_ia32_sqrtsd_round_mask: case X86::BI__builtin_ia32_sqrtss_round_mask: + errorNYI("masked round sqrt builtins"); + return {}; + case X86::BI__builtin_ia32_sqrtpd256: + case X86::BI__builtin_ia32_sqrtpd: + case X86::BI__builtin_ia32_sqrtps256: + case X86::BI__builtin_ia32_sqrtps: + case X86::BI__builtin_ia32_sqrtph256: + case X86::BI__builtin_ia32_sqrtph: case X86::BI__builtin_ia32_sqrtph512: case X86::BI__builtin_ia32_sqrtps512: - case X86::BI__builtin_ia32_sqrtpd512: + case X86::BI__builtin_ia32_sqrtpd512: { + mlir::Location loc = getLoc(expr->getExprLoc()); + mlir::Value arg = ops[0]; + return cir::SqrtOp::create(builder, loc, arg.getType(), arg).getResult(); + } case X86::BI__builtin_ia32_pmuludq128: case X86::BI__builtin_ia32_pmuludq256: case X86::BI__builtin_ia32_pmuludq512: diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 0c34d87734c3e..5514a4cd0876d 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1,4 +1,4 @@ -//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// +//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -30,6 +30,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" +#include "clang/Basic/LLVM.h" #include "clang/CIR/Dialect/IR/CIRAttrs.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" @@ -44,6 +45,96 @@ using namespace cir; using namespace llvm; +using namespace mlir; + +static std::string getLLVMIntrinsicNameForType(Type llvmTy) { + std::string s; + { + llvm::raw_string_ostream os(s); + llvm::Type *unused = nullptr; + os << llvmTy; + } + if (auto vecTy = llvmTy.dyn_cast()) { + } + return s; +} + +// Actual lowering +LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite( + cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + + Type cirResTy = op.getResult().getType(); + Type llvmResTy = getTypeConverter()->convertType(cirResTy); + if (!llvmResTy) + return op.emitOpError( + "expected LLVM dialect result type for cir.sqrt lowering"); + + Value operand = adaptor.getInput(); + Value llvmOperand = operand; + if (operand.getType() != llvmResTy) { + llvmOperand = rewriter.create(loc, llvmResTy, operand); + } + + // Build the llvm.sqrt.* intrinsic name depending on scalar vs vector result + std::string intrinsicName = "llvm.sqrt."; + std::string suffix; + + // If the CIR result type is a vector, include the 'vN' part in the suffix. + if (auto vec = cirResTy.dyn_cast()) { + Type elt = vec.getElementType(); + if (auto f = elt.dyn_cast()) { + unsigned width = f.getWidth(); + unsigned n = vec.getNumElements(); + if (width == 32) + suffix = "v" + std::to_string(n) + "f32"; + else if (width == 64) + suffix = "v" + std::to_string(n) + "f64"; + else if (width == 16) + suffix = "v" + std::to_string(n) + "f16"; + else + return op.emitOpError("unsupported float width for sqrt"); + } else { + return op.emitOpError("vector element must be floating point for sqrt"); + } + } else if (auto f = cirResTy.dyn_cast()) { + // Scalar float + unsigned width = f.getWidth(); + if (width == 32) + suffix = "f32"; + else if (width == 64) + suffix = "f64"; + else if (width == 16) + suffix = "f16"; + else + return op.emitOpError("unsupported float width for sqrt"); + } else { + return op.emitOpError("unsupported type for cir.sqrt lowering"); + } + + intrinsicName += suffix; + + // Ensure the llvm intrinsic function exists at module scope. Insert it at + // the start of the module body using an insertion guard. + ModuleOp module = op->getParentOfType(); + if (!module.lookupSymbol(intrinsicName)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmResTy, {llvmResTy}, + /*isVarArg=*/false); + rewriter.create(loc, intrinsicName, llvmFnType); + } + + // Create the call and replace cir.sqrt + auto callee = SymbolRefAttr::get(ctx, intrinsicName); + rewriter.replaceOpWithNewOp(op, llvmResTy, callee, + ArrayRef{llvmOperand}); + + return mlir::success(); +} namespace cir { namespace direct { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 0591de545b81d..be6a380372efe 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -12,11 +12,25 @@ #ifndef CLANG_CIR_LOWERTOLLVM_H #define CLANG_CIR_LOWERTOLLVM_H +#include "mlir/Conversion/PatternRewriter.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "clang/CIR/Dialect/IR/CIRDialect.h" +namespace cir { +class SqrtOp; +} + +class CIRToLLVMSqrtOpLowering : public mlir::OpConversionPattern { +public: + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; +}; + namespace cir { namespace direct { diff --git a/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c b/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c new file mode 100644 index 0000000000000..ef5cb954e3efe --- /dev/null +++ b/clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c @@ -0,0 +1,67 @@ +// Test for x86 sqrt builtins (sqrtps, sqrtpd, sqrtph, etc.) +// RUN: %clang_cc1 -fclangir -triple x86_64-unknown-linux-gnu -target-feature +avx512fp16 -emit-cir %s -o - | FileCheck %s + +#include + +// Test __builtin_ia32_sqrtps - single precision vector sqrt (128-bit) +__m128 test_sqrtps(__m128 x) { + return __builtin_ia32_sqrtps(x); +} +// CHECK-LABEL: cir.func @test_sqrtps +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtps256 - single precision vector sqrt (256-bit) +__m256 test_sqrtps256(__m256 x) { + return __builtin_ia32_sqrtps256(x); +} +// CHECK-LABEL: cir.func @test_sqrtps256 +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtps512 - single precision vector sqrt (512-bit) +__m512 test_sqrtps512(__m512 x) { + return __builtin_ia32_sqrtps512(x); +} +// CHECK-LABEL: cir.func @test_sqrtps512 +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtpd - double precision vector sqrt (128-bit) +__m128d test_sqrtpd(__m128d x) { + return __builtin_ia32_sqrtpd(x); +} +// CHECK-LABEL: cir.func @test_sqrtpd +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtpd256 - double precision vector sqrt (256-bit) +__m256d test_sqrtpd256(__m256d x) { + return __builtin_ia32_sqrtpd256(x); +} +// CHECK-LABEL: cir.func @test_sqrtpd256 +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtpd512 - double precision vector sqrt (512-bit) +__m512d test_sqrtpd512(__m512d x) { + return __builtin_ia32_sqrtpd512(x); +} +// CHECK-LABEL: cir.func @test_sqrtpd512 +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtph - half precision vector sqrt (128-bit) +__m128h test_sqrtph(__m128h x) { + return __builtin_ia32_sqrtph(x); +} +// CHECK-LABEL: cir.func @test_sqrtph +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtph256 - half precision vector sqrt (256-bit) +__m256h test_sqrtph256(__m256h x) { + return __builtin_ia32_sqrtph256(x); +} +// CHECK-LABEL: cir.func @test_sqrtph256 +// CHECK: cir.sqrt + +// Test __builtin_ia32_sqrtph512 - half precision vector sqrt (512-bit) +__m512h test_sqrtph512(__m512h x) { + return __builtin_ia32_sqrtph512(x); +} +// CHECK-LABEL: cir.func @test_sqrtph512 +// CHECK: cir.sqrt \ No newline at end of file