Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3275,6 +3275,39 @@ def CIR_InlineAsmOp : CIR_Op<"asm", [RecursiveMemoryEffects]> {
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// SqrtOp
//===----------------------------------------------------------------------===//

def CIR_SqrtOp : CIR_Op<"sqrt", [Pure, SameOperandsAndResultType]> {
  let summary = "Floating-point square root";

  let description = [{
    The `cir.sqrt` operation computes the square root of its input. When the
    input is a vector, the square root is computed element-wise.

    The input must be either:
      - a floating-point scalar type, or
      - a vector whose element type is floating-point.

    The result type always matches the input type exactly. This is enforced by
    the `SameOperandsAndResultType` trait, which is also what lets the
    assembly format spell the type only once.

    Examples:

    ```mlir
    // scalar
    %r = cir.sqrt %x : !cir.double

    // vector
    %v = cir.sqrt %vec : !cir.vector<4 x !cir.float>
    ```
  }];

  // Input and result share one type: float or vector-of-float.
  let arguments = (ins CIR_AnyFloatOrVecOfFloatType:$input);
  let results = (outs CIR_AnyFloatOrVecOfFloatType:$result);

  // Only the operand type is printed; the result type is inferred from it
  // through SameOperandsAndResultType.
  let assemblyFormat = [{
    $input `:` type($input) attr-dict
  }];
}

//===----------------------------------------------------------------------===//
// UnreachableOp
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 13 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,21 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_sqrtsh_round_mask:
case X86::BI__builtin_ia32_sqrtsd_round_mask:
case X86::BI__builtin_ia32_sqrtss_round_mask:
errorNYI("masked round sqrt builtins");
return {};
case X86::BI__builtin_ia32_sqrtpd256:
case X86::BI__builtin_ia32_sqrtpd:
case X86::BI__builtin_ia32_sqrtps256:
case X86::BI__builtin_ia32_sqrtps:
case X86::BI__builtin_ia32_sqrtph256:
case X86::BI__builtin_ia32_sqrtph:
case X86::BI__builtin_ia32_sqrtph512:
case X86::BI__builtin_ia32_sqrtps512:
case X86::BI__builtin_ia32_sqrtpd512:
case X86::BI__builtin_ia32_sqrtpd512: {
mlir::Location loc = getLoc(expr->getExprLoc());
mlir::Value arg = ops[0];
return cir::SqrtOp::create(builder, loc, arg.getType(), arg).getResult();
}
case X86::BI__builtin_ia32_pmuludq128:
case X86::BI__builtin_ia32_pmuludq256:
case X86::BI__builtin_ia32_pmuludq512:
Expand Down
93 changes: 92 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===//
//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -30,6 +30,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/Basic/LLVM.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
Expand All @@ -44,6 +45,96 @@

using namespace cir;
using namespace llvm;
using namespace mlir;

static std::string getLLVMIntrinsicNameForType(Type llvmTy) {
std::string s;
{
llvm::raw_string_ostream os(s);
llvm::Type *unused = nullptr;
os << llvmTy;
}
if (auto vecTy = llvmTy.dyn_cast<LLVM::LLVMType>()) {
}
return s;
}

// Actual lowering
LogicalResult CIRToLLVMSqrtOpLowering::matchAndRewrite(
cir::SqrtOp op, typename cir::SqrtOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {

Location loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();

Type cirResTy = op.getResult().getType();
Type llvmResTy = getTypeConverter()->convertType(cirResTy);
if (!llvmResTy)
return op.emitOpError(
"expected LLVM dialect result type for cir.sqrt lowering");

Value operand = adaptor.getInput();
Value llvmOperand = operand;
if (operand.getType() != llvmResTy) {
llvmOperand = rewriter.create<LLVM::BitcastOp>(loc, llvmResTy, operand);
}

// Build the llvm.sqrt.* intrinsic name depending on scalar vs vector result
std::string intrinsicName = "llvm.sqrt.";
std::string suffix;

// If the CIR result type is a vector, include the 'vN' part in the suffix.
if (auto vec = cirResTy.dyn_cast<cir::VectorType>()) {
Type elt = vec.getElementType();
if (auto f = elt.dyn_cast<cir::FloatType>()) {
unsigned width = f.getWidth();
unsigned n = vec.getNumElements();
if (width == 32)
suffix = "v" + std::to_string(n) + "f32";
else if (width == 64)
suffix = "v" + std::to_string(n) + "f64";
else if (width == 16)
suffix = "v" + std::to_string(n) + "f16";
else
return op.emitOpError("unsupported float width for sqrt");
} else {
return op.emitOpError("vector element must be floating point for sqrt");
}
} else if (auto f = cirResTy.dyn_cast<cir::FloatType>()) {
// Scalar float
unsigned width = f.getWidth();
if (width == 32)
suffix = "f32";
else if (width == 64)
suffix = "f64";
else if (width == 16)
suffix = "f16";
else
return op.emitOpError("unsupported float width for sqrt");
} else {
return op.emitOpError("unsupported type for cir.sqrt lowering");
}

intrinsicName += suffix;

// Ensure the llvm intrinsic function exists at module scope. Insert it at
// the start of the module body using an insertion guard.
ModuleOp module = op->getParentOfType<ModuleOp>();
if (!module.lookupSymbol<LLVM::LLVMFuncOp>(intrinsicName)) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmResTy, {llvmResTy},
/*isVarArg=*/false);
rewriter.create<LLVM::LLVMFuncOp>(loc, intrinsicName, llvmFnType);
}

// Create the call and replace cir.sqrt
auto callee = SymbolRefAttr::get(ctx, intrinsicName);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, llvmResTy, callee,
ArrayRef<Value>{llvmOperand});

return mlir::success();
}

namespace cir {
namespace direct {
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,25 @@
#ifndef CLANG_CIR_LOWERTOLLVM_H
#define CLANG_CIR_LOWERTOLLVM_H

#include "mlir/Conversion/PatternRewriter.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

namespace cir {
class SqrtOp;
}

/// Conversion pattern that rewrites a `cir.sqrt` operation into the LLVM
/// dialect. The rewrite itself is implemented out of line in LowerToLLVM.cpp.
class CIRToLLVMSqrtOpLowering
    : public mlir::OpConversionPattern<cir::SqrtOp> {
public:
  // Reuse every constructor of the base pattern unchanged.
  using OpConversionPattern::OpConversionPattern;

  /// Replaces \p op with its LLVM dialect equivalent.
  ///
  /// \param op       the `cir.sqrt` operation to convert.
  /// \param adaptor  accessor for the operation's converted operands
  ///                 (`OpAdaptor` is the base-class alias for
  ///                 `cir::SqrtOp::Adaptor`).
  /// \param rewriter the rewriter through which all IR changes are made.
  mlir::LogicalResult
  matchAndRewrite(cir::SqrtOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override;
};

namespace cir {

namespace direct {
Expand Down
67 changes: 67 additions & 0 deletions clang/test/CIR/CodeGen/X86/cir-sqrt-builtins.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Test for x86 sqrt builtins (sqrtps, sqrtpd, sqrtph, etc.)
// RUN: %clang_cc1 -fclangir -triple x86_64-unknown-linux-gnu -target-feature +avx512fp16 -emit-cir %s -o - | FileCheck %s

#include <immintrin.h>

// Test __builtin_ia32_sqrtps - single precision vector sqrt (128-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m128 test_sqrtps(__m128 x) {
return __builtin_ia32_sqrtps(x);
}
// CHECK-LABEL: cir.func @test_sqrtps
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtps256 - single precision vector sqrt (256-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m256 test_sqrtps256(__m256 x) {
return __builtin_ia32_sqrtps256(x);
}
// CHECK-LABEL: cir.func @test_sqrtps256
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtps512 - single precision vector sqrt (512-bit).
// Unlike the 128/256-bit forms, the 512-bit builtin takes a rounding-mode
// operand after the vector. _MM_FROUND_CUR_DIRECTION (use the current
// rounding mode) is what _mm512_sqrt_ps passes and is the mode a plain
// square root corresponds to.
__m512 test_sqrtps512(__m512 x) {
  return __builtin_ia32_sqrtps512(x, _MM_FROUND_CUR_DIRECTION);
}
// CHECK-LABEL: cir.func @test_sqrtps512
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtpd - double precision vector sqrt (128-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m128d test_sqrtpd(__m128d x) {
return __builtin_ia32_sqrtpd(x);
}
// CHECK-LABEL: cir.func @test_sqrtpd
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtpd256 - double precision vector sqrt (256-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m256d test_sqrtpd256(__m256d x) {
return __builtin_ia32_sqrtpd256(x);
}
// CHECK-LABEL: cir.func @test_sqrtpd256
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtpd512 - double precision vector sqrt (512-bit).
// Unlike the 128/256-bit forms, the 512-bit builtin takes a rounding-mode
// operand after the vector. _MM_FROUND_CUR_DIRECTION (use the current
// rounding mode) is what _mm512_sqrt_pd passes and is the mode a plain
// square root corresponds to.
__m512d test_sqrtpd512(__m512d x) {
  return __builtin_ia32_sqrtpd512(x, _MM_FROUND_CUR_DIRECTION);
}
// CHECK-LABEL: cir.func @test_sqrtpd512
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtph - half precision vector sqrt (128-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m128h test_sqrtph(__m128h x) {
return __builtin_ia32_sqrtph(x);
}
// CHECK-LABEL: cir.func @test_sqrtph
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtph256 - half precision vector sqrt (256-bit).
// The builtin is called with the vector operand only; the CHECK lines that
// follow expect a cir.sqrt inside the emitted cir.func for this function.
__m256h test_sqrtph256(__m256h x) {
return __builtin_ia32_sqrtph256(x);
}
// CHECK-LABEL: cir.func @test_sqrtph256
// CHECK: cir.sqrt

// Test __builtin_ia32_sqrtph512 - half precision vector sqrt (512-bit).
// Unlike the 128/256-bit forms, the 512-bit builtin takes a rounding-mode
// operand after the vector. _MM_FROUND_CUR_DIRECTION (use the current
// rounding mode) is what _mm512_sqrt_ph passes and is the mode a plain
// square root corresponds to.
__m512h test_sqrtph512(__m512h x) {
  return __builtin_ia32_sqrtph512(x, _MM_FROUND_CUR_DIRECTION);
}
// CHECK-LABEL: cir.func @test_sqrtph512
// CHECK: cir.sqrt