Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,16 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
callingConv, sideEffect, extraFnAttr);
}

cir::CallOp createCallOp(mlir::Location loc, cir::IFuncOp callee,
mlir::ValueRange operands = mlir::ValueRange(),
cir::CallingConv callingConv = cir::CallingConv::C,
cir::SideEffect sideEffect = cir::SideEffect::All,
cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
callee.getFunctionType().getReturnType(), operands,
callingConv, sideEffect, extraFnAttr);
}

cir::CallOp
createIndirectCallOp(mlir::Location loc, mlir::Value ind_target,
cir::FuncType fn_type,
Expand Down Expand Up @@ -775,6 +785,17 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
callingConv, sideEffect, extraFnAttr);
}

cir::CallOp
createTryCallOp(mlir::Location loc, cir::IFuncOp callee,
mlir::ValueRange operands,
cir::CallingConv callingConv = cir::CallingConv::C,
cir::SideEffect sideEffect = cir::SideEffect::All,
cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createTryCallOp(loc, mlir::SymbolRefAttr::get(callee),
callee.getFunctionType().getReturnType(), operands,
callingConv, sideEffect, extraFnAttr);
}

cir::CallOp
createIndirectTryCallOp(mlir::Location loc, mlir::Value ind_target,
cir::FuncType fn_type, mlir::ValueRange operands,
Expand Down
47 changes: 47 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,53 @@ def CIR_GetGlobalOp
}];
}

def CIR_IFuncOp : CIR_Op<"func.ifunc", [Symbol]> {
let summary = "Indirect function (ifunc) declaration";
let description = [{
The `cir.func.ifunc` operation declares an indirect function, which allows
runtime selection of function implementations based on CPU features or other
runtime conditions. The actual function to call is determined by a resolver
function at runtime.

The resolver function must return a pointer to a function with the same
signature as the ifunc. The resolver typically inspects CPU features or
other runtime conditions to select the appropriate implementation.

This corresponds to the GNU indirect function attribute:
`__attribute__((ifunc("resolver")))`

Example:
```mlir
// Resolver function that returns a function pointer
cir.func internal @resolve_foo() -> !cir.ptr<!cir.func<i32 ()>> {
...
cir.return %impl : !cir.ptr<!cir.func<i32 ()>>
}

// IFunc declaration
cir.func.ifunc @foo resolver(@resolve_foo) : !cir.func<i32 ()>

// Usage
cir.func @use_foo() {
%result = cir.call @foo() : () -> i32
cir.return
}
```
}];

let arguments = (ins SymbolNameAttr:$sym_name,
CIR_VisibilityAttr:$global_visibility,
TypeAttrOf<CIR_FuncType>:$function_type, FlatSymbolRefAttr:$resolver,
DefaultValuedAttr<CIR_GlobalLinkageKind,
"GlobalLinkageKind::ExternalLinkage">:$linkage,
OptionalAttr<StrAttr>:$sym_visibility);

let assemblyFormat = [{
$sym_name `resolver` `(` $resolver `)` attr-dict `:`
$function_type
}];
}

//===----------------------------------------------------------------------===//
// GuardedInitOp
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 19 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void CIRGenModule::constructAttributeList(
static cir::CIRCallOpInterface
emitCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
cir::FuncOp directFuncOp,
mlir::Operation *directCalleeOp,
SmallVectorImpl<mlir::Value> &CIRCallArgs, bool isInvoke,
cir::CallingConv callingConv, cir::SideEffect sideEffect,
cir::ExtraFuncAttributesAttr extraFnAttrs) {
Expand Down Expand Up @@ -378,9 +378,13 @@ emitCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
callOpWithExceptions = builder.createIndirectTryCallOp(
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs, callingConv,
sideEffect);
} else if (auto funcOp = mlir::dyn_cast<cir::FuncOp>(directCalleeOp)) {
callOpWithExceptions = builder.createTryCallOp(
callLoc, funcOp, CIRCallArgs, callingConv, sideEffect);
} else {
auto ifuncOp = mlir::cast<cir::IFuncOp>(directCalleeOp);
callOpWithExceptions = builder.createTryCallOp(
callLoc, directFuncOp, CIRCallArgs, callingConv, sideEffect);
callLoc, ifuncOp, CIRCallArgs, callingConv, sideEffect);
}
callOpWithExceptions->setAttr("extra_attrs", extraFnAttrs);
CGF.mayThrow = true;
Expand All @@ -405,7 +409,12 @@ emitCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs,
cir::CallingConv::C, sideEffect, extraFnAttrs);
}
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, callingConv,
if (auto funcOp = mlir::dyn_cast<cir::FuncOp>(directCalleeOp)) {
return builder.createCallOp(callLoc, funcOp, CIRCallArgs, callingConv,
sideEffect, extraFnAttrs);
}
auto ifuncOp = mlir::cast<cir::IFuncOp>(directCalleeOp);
return builder.createCallOp(callLoc, ifuncOp, CIRCallArgs, callingConv,
sideEffect, extraFnAttrs);
}

Expand Down Expand Up @@ -620,19 +629,21 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
cir::CIRCallOpInterface theCall = [&]() {
cir::FuncType indirectFuncTy;
mlir::Value indirectFuncVal;
cir::FuncOp directFuncOp;
mlir::Operation *directCalleeOp = nullptr;

if (auto fnOp = dyn_cast<cir::FuncOp>(CalleePtr)) {
directFuncOp = fnOp;
directCalleeOp = fnOp;
} else if (auto ifuncOp = dyn_cast<cir::IFuncOp>(CalleePtr)) {
directCalleeOp = ifuncOp;
} else if (auto getGlobalOp = dyn_cast<cir::GetGlobalOp>(CalleePtr)) {
// FIXME(cir): This peephole optimization to avoids indirect calls for
// builtins. This should be fixed in the builting declaration instead by
// not emitting an unecessary get_global in the first place.
auto *globalOp = mlir::SymbolTable::lookupSymbolIn(CGM.getModule(),
getGlobalOp.getName());
assert(getGlobalOp && "undefined global function");
directFuncOp = llvm::dyn_cast<cir::FuncOp>(globalOp);
assert(directFuncOp && "operation is not a function");
directCalleeOp = llvm::dyn_cast<cir::FuncOp>(globalOp);
assert(directCalleeOp && "operation is not a function");
} else {
[[maybe_unused]] auto resultTypes = CalleePtr->getResultTypes();
[[maybe_unused]] auto FuncPtrTy =
Expand All @@ -648,7 +659,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
Attrs.getDictionary(&getMLIRContext()));

cir::CIRCallOpInterface callLikeOp = emitCallLikeOp(
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
*this, callLoc, indirectFuncTy, indirectFuncVal, directCalleeOp,
CIRCallArgs, isInvoke, callingConv, sideEffect, extraFnAttrs);

if (E)
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,15 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {
return CIRGenCallee::forBuiltin(builtinID, FD);
}

// Handle ifunc specially - get the IFuncOp directly
if (FD->hasAttr<IFuncAttr>()) {
llvm::StringRef mangledName = CGM.getMangledName(GD);
mlir::Operation *ifuncOp = CGM.getGlobalValue(mangledName);
assert(ifuncOp && isa<cir::IFuncOp>(ifuncOp) &&
"Expected IFuncOp for ifunc");
return CIRGenCallee::forDirect(ifuncOp, GD);
}

mlir::Operation *CalleePtr = emitFunctionDeclPointer(CGM, GD);

if ((CGM.getLangOpts().HIP || CGM.getLangOpts().CUDA) &&
Expand Down
85 changes: 84 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,10 @@ void CIRGenModule::emitGlobal(GlobalDecl gd) {

const auto *global = cast<ValueDecl>(gd.getDecl());

assert(!global->hasAttr<IFuncAttr>() && "NYI");
// IFunc like an alias whose value is resolved at runtime by calling resolver.
if (global->hasAttr<IFuncAttr>())
return emitIFuncDefinition(gd);

assert(!global->hasAttr<CPUDispatchAttr>() && "NYI");

if (langOpts.CUDA || langOpts.HIP) {
Expand Down Expand Up @@ -722,6 +725,77 @@ void CIRGenModule::emitGlobal(GlobalDecl gd) {
}
}

void CIRGenModule::emitIFuncDefinition(GlobalDecl globalDecl) {
const auto *d = cast<FunctionDecl>(globalDecl.getDecl());
const IFuncAttr *ifa = d->getAttr<IFuncAttr>();
assert(ifa && "Not an ifunc?");

llvm::StringRef mangledName = getMangledName(globalDecl);

if (ifa->getResolver() == mangledName) {
getDiags().Report(ifa->getLocation(), diag::err_cyclic_alias) << 1;
return;
}

// Get function type for the ifunc.
mlir::Type declTy = getTypes().convertTypeForMem(d->getType());
auto funcTy = mlir::dyn_cast<cir::FuncType>(declTy);
assert(funcTy && "IFunc must have function type");

// The resolver might not be visited yet. Create a forward declaration for it.
mlir::Type resolverRetTy = builder.getPointerTo(funcTy);
auto resolverFuncTy =
cir::FuncType::get(llvm::ArrayRef<mlir::Type>{}, resolverRetTy);

// Ensure the resolver function is created.
GetOrCreateCIRFunction(ifa->getResolver(), resolverFuncTy, GlobalDecl(),
/*ForVTable=*/false);

mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(theModule.getBody());

// Report an error if some definition overrides ifunc.
mlir::Operation *entry = getGlobalValue(mangledName);
if (entry) {
// Check if this is a non-declaration (an actual definition).
bool isDeclaration = false;
if (auto func = mlir::dyn_cast<cir::FuncOp>(entry))
isDeclaration = func.isDeclaration();

if (!isDeclaration) {
GlobalDecl otherGd;
if (lookupRepresentativeDecl(mangledName, otherGd) &&
DiagnosedConflictingDefinitions.insert(globalDecl).second) {
getDiags().Report(d->getLocation(), diag::err_duplicate_mangled_name)
<< mangledName;
getDiags().Report(otherGd.getDecl()->getLocation(),
diag::note_previous_definition);
}
return;
}

// This is just a forward declaration, remove it.
if (auto func = mlir::dyn_cast<cir::FuncOp>(entry)) {
func.erase();
}
}

// Get linkage
GVALinkage linkage = astContext.GetGVALinkageForFunction(d);
cir::GlobalLinkageKind cirLinkage =
getCIRLinkageForDeclarator(d, linkage, /*IsConstantVariable=*/false);

// Get visibility
cir::VisibilityAttr visibilityAttr = getGlobalVisibilityAttrFromDecl(d);
cir::VisibilityKind visibilityKind = visibilityAttr.getValue();

auto ifuncOp = builder.create<cir::IFuncOp>(
theModule.getLoc(), mangledName, visibilityKind, funcTy,
ifa->getResolver(), cirLinkage, /*sym_visibility=*/mlir::StringAttr{});

setCommonAttributes(globalDecl, ifuncOp);
}

void CIRGenModule::emitGlobalFunctionDefinition(GlobalDecl gd,
mlir::Operation *op) {
auto const *d = cast<FunctionDecl>(gd.getDecl());
Expand Down Expand Up @@ -2432,6 +2506,10 @@ void CIRGenModule::ReplaceUsesOfNonProtoTypeWithRealFunction(
// Replace type
getGlobalOp.getAddr().setType(
cir::PointerType::get(newFn.getFunctionType()));
} else if (auto ifuncOp = dyn_cast<cir::IFuncOp>(use.getUser())) {
// IFuncOp references the resolver function by symbol.
// The symbol reference doesn't need updating - it's name-based.
// The resolver's signature is validated when the IFuncOp is created.
} else {
llvm_unreachable("NIY");
}
Expand Down Expand Up @@ -3139,6 +3217,11 @@ cir::FuncOp CIRGenModule::GetOrCreateCIRFunction(
// Lookup the entry, lazily creating it if necessary.
mlir::Operation *entry = getGlobalValue(mangledName);
if (entry) {
// If this is an ifunc, we can't create a FuncOp for it. Just return nullptr
// and let the caller handle calling through the ifunc.
if (isa<cir::IFuncOp>(entry))
return nullptr;

assert(isa<cir::FuncOp>(entry) &&
"not implemented, only supports FuncOp for now");

Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ class CIRGenModule : public CIRGenTypeCache {
cir::FuncOp func);

void emitGlobalDefinition(clang::GlobalDecl D, mlir::Operation *Op = nullptr);
void emitIFuncDefinition(clang::GlobalDecl globalDecl);
void emitGlobalFunctionDefinition(clang::GlobalDecl D, mlir::Operation *Op);
void emitGlobalVarDefinition(const clang::VarDecl *D,
bool IsTentative = false);
Expand Down
20 changes: 15 additions & 5 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3134,17 +3134,26 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
if (!fnAttr)
return success();

// Look up the callee - it can be either a FuncOp or an IFuncOp
cir::FuncOp fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
if (!fn)
cir::IFuncOp ifn =
symbolTable.lookupNearestSymbolFrom<cir::IFuncOp>(op, fnAttr);

if (!fn && !ifn)
return op->emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";

auto callIf = dyn_cast<cir::CIRCallOpInterface>(op);
assert(callIf && "expected CIR call interface to be always available");

// Get function type from either FuncOp or IFuncOp
cir::FuncType fnType = fn ? fn.getFunctionType() : ifn.getFunctionType();

// Verify that the operand and result types match the callee. Note that
// argument-checking is disabled for functions without a prototype.
auto fnType = fn.getFunctionType();
if (!fn.getNoProto()) {
// IFuncs are always considered to have a prototype.
bool hasProto = ifn || !fn.getNoProto();
if (hasProto) {
unsigned numCallOperands = callIf.getNumArgOperands();
unsigned numFnOpOperands = fnType.getNumInputs();

Expand All @@ -3161,8 +3170,9 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
<< op->getOperand(i).getType() << " for operand number " << i;
}

// Calling convention must match.
if (callIf.getCallingConv() != fn.getCallingConv())
// Calling convention must match (only check for FuncOp; IFuncOp uses the
// type's convention)
if (fn && callIf.getCallingConv() != fn.getCallingConv())
return op->emitOpError("calling convention mismatch: expected ")
<< stringifyCallingConv(fn.getCallingConv()) << ", but provided "
<< stringifyCallingConv(callIf.getCallingConv());
Expand Down
Loading