diff --git a/clang/include/clang/CIR/ABIArgInfo.h b/clang/include/clang/CIR/ABIArgInfo.h index cd679961eb49f..28d831f9a0353 100644 --- a/clang/include/clang/CIR/ABIArgInfo.h +++ b/clang/include/clang/CIR/ABIArgInfo.h @@ -53,10 +53,10 @@ class ABIArgInfo { public: ABIArgInfo(Kind k = Direct) : directAttr{0, 0}, theKind(k) {} - static ABIArgInfo getDirect(mlir::Type ty = nullptr) { + static ABIArgInfo getDirect(mlir::Type ty = nullptr, unsigned offset = 0) { ABIArgInfo info(Direct); info.setCoerceToType(ty); - assert(!cir::MissingFeatures::abiArgInfo()); + info.directAttr.offset = offset; return info; } @@ -82,13 +82,10 @@ class ABIArgInfo { return false; } - bool canHaveCoerceToType() const { - assert(!cir::MissingFeatures::abiArgInfo()); - return isDirect(); - } + bool canHaveCoerceToType() const { return isDirect(); } unsigned getDirectOffset() const { - assert(!cir::MissingFeatures::abiArgInfo()); + assert(isDirect() && "ABIArgInfo offset is only valid for Direct"); return directAttr.offset; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index f648eff375a77..36ce0a05390d8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -15,6 +15,7 @@ #include "CIRGenCXXABI.h" #include "CIRGenFunction.h" #include "CIRGenFunctionInfo.h" +#include "TargetInfo.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" #include "clang/CIR/ABIArgInfo.h" @@ -42,6 +43,7 @@ CIRGenFunctionInfo *CIRGenFunctionInfo::create( fi->required = required; fi->numArgs = argTypes.size(); + fi->returnInfo = cir::ABIArgInfo::getDirect(); fi->getArgTypes()[0] = resultType; std::copy(argTypes.begin(), argTypes.end(), fi->argTypesBegin()); @@ -56,7 +58,10 @@ cir::FuncType CIRGenTypes::getFunctionType(GlobalDecl gd) { } cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &info) { + const cir::ABIArgInfo &retInfo = info.getReturnInfo(); mlir::Type resultType = convertType(info.getReturnType()); + if (retInfo.getDirectOffset() != 0 && retInfo.getCoerceToType()) + resultType = retInfo.getCoerceToType(); SmallVector argTypes; argTypes.reserve(info.getNumRequiredArgs()); @@ -1156,6 +1161,35 @@ CIRGenTypes::arrangeFreeFunctionType(CanQual fnpt) { fnpt->getExtInfo(), RequiredArgs(0)); } +static Address emitAddressAtOffset(CIRGenFunction &cgf, Address addr, + const cir::ABIArgInfo &info, + mlir::Location loc) { + if (!info.getDirectOffset()) + return addr; + mlir::Value offsetVal = + cgf.getBuilder().getUnsignedInt(loc, info.getDirectOffset(), 64); + Address byteAddr = + addr.withElementType(cgf.getBuilder(), cgf.getBuilder().getUInt8Ty()); + return byteAddr.withPointer( + cgf.getBuilder().createPtrStride(loc, byteAddr.getPointer(), offsetVal)); +} + +static void emitCoercedReturnStore(CIRGenFunction &cgf, mlir::Location loc, + mlir::Value retVal, QualType retTy, + Address destPtr, + const cir::ABIArgInfo &retInfo) { + if (isEmptyRecordForLayout(cgf.getContext(), retTy)) + return; + Address storePtr = emitAddressAtOffset(cgf, destPtr, retInfo, loc); + mlir::Type storeTy = retInfo.getCoerceToType(); + if (!storeTy) + storeTy = retVal.getType(); + if (retVal.getType() != storeTy) + cgf.cgm.errorNYI(loc, "coerced return value type mismatch"); + storePtr = storePtr.withElementType(cgf.getBuilder(), storeTy); + cgf.getBuilder().createStore(loc, retVal, storePtr); +} + RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, const CIRGenCallee &callee, ReturnValueSlot returnValue, @@ -1354,8 +1388,13 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo, mlir::ResultRange results = theCall->getOpResults(); assert(results.size() <= 1 && "multiple returns from a call"); - SourceLocRAIIObject loc{*this, callLoc}; - emitAggregateStore(results[0], destPtr); + SourceLocRAIIObject locRAII{*this, callLoc}; + const cir::ABIArgInfo retInfo = funcInfo.getReturnInfo(); + if (retInfo.getDirectOffset() != 0) + emitCoercedReturnStore(*this, callLoc, results[0], retTy, destPtr, + retInfo); + else + emitAggregateStore(results[0], destPtr); return RValue::getAggregate(destPtr); } case cir::TEK_Scalar: { diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp index 52e7a9d3de412..5d788fe0cb69d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp @@ -329,6 +329,19 @@ void CIRGenFunction::LexicalScope::cleanup() { } } +static Address emitReturnSlotAtOffset(CIRGenFunction &cgf, Address addr, + const cir::ABIArgInfo &info, + mlir::Location loc) { + if (!info.getDirectOffset()) + return addr; + mlir::Value offsetVal = + cgf.getBuilder().getUnsignedInt(loc, info.getDirectOffset(), 64); + Address byteAddr = + addr.withElementType(cgf.getBuilder(), cgf.getBuilder().getUInt8Ty()); + return byteAddr.withPointer( + cgf.getBuilder().createPtrStride(loc, byteAddr.getPointer(), offsetVal)); +} + cir::ReturnOp CIRGenFunction::LexicalScope::emitReturn(mlir::Location loc) { CIRGenBuilderTy &builder = cgf.getBuilder(); @@ -336,9 +349,20 @@ cir::ReturnOp CIRGenFunction::LexicalScope::emitReturn(mlir::Location loc) { assert(fn && "emitReturn from non-function"); if (!fn.getFunctionType().hasVoidReturn()) { - // Load the value from `__retval` and return it via the `cir.return` op. - auto value = cir::LoadOp::create( - builder, loc, fn.getFunctionType().getReturnType(), *cgf.fnRetAlloca); + const CIRGenFunctionInfo &fnInfo = + cgf.cgm.getTypes().arrangeGlobalDeclaration(cgf.curGD); + const cir::ABIArgInfo &retInfo = fnInfo.getReturnInfo(); + mlir::Type loadTy = fn.getFunctionType().getReturnType(); + Address loadAddr = cgf.returnValue; + if (retInfo.getDirectOffset() != 0) { + loadAddr = emitReturnSlotAtOffset(cgf, loadAddr, retInfo, loc); + if (retInfo.getCoerceToType()) { + loadTy = retInfo.getCoerceToType(); + loadAddr = loadAddr.withElementType(cgf.getBuilder(), loadTy); + } + } + auto value = + cir::LoadOp::create(builder, loc, loadTy, loadAddr.getPointer()); return cir::ReturnOp::create(builder, loc, llvm::ArrayRef(value.getResult())); } diff --git a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h index d37a6149bcafa..4c1497aaa36be 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h @@ -95,6 +95,8 @@ class CIRGenFunctionInfo final RequiredArgs required; + cir::ABIArgInfo returnInfo; + unsigned numArgs; CanQualType *getArgTypes() { return getTrailingObjects(); } @@ -123,6 +125,7 @@ class CIRGenFunctionInfo final // Friending class TrailingObjects is apparantly not good enough for MSVC, so // these have to be public. friend class TrailingObjects; + friend class CIRGenTypes; using const_arg_iterator = const CanQualType *; using arg_iterator = CanQualType *; @@ -160,14 +163,7 @@ class CIRGenFunctionInfo final CanQualType getReturnType() const { return getArgTypes()[0]; } - cir::ABIArgInfo getReturnInfo() const { - assert(!cir::MissingFeatures::abiArgInfo()); - // TODO(cir): we currently just 'fake' this, but should calculate - // this/figure out what it means when we get our ABI info set correctly. - // For now, we leave this as a direct return. - - return cir::ABIArgInfo::getDirect(); - } + cir::ABIArgInfo getReturnInfo() const { return returnInfo; } const_arg_iterator argTypesBegin() const { return getArgTypes() + 1; } const_arg_iterator argTypesEnd() const { return getArgTypes() + 1 + numArgs; } diff --git a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp index 4777d8e429e34..2a897002ac207 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmt.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmt.cpp @@ -667,11 +667,26 @@ mlir::LogicalResult CIRGenFunction::emitReturnStmt(const ReturnStmt &s) { // directly. // TODO(cir): Eliminate this redundant load and the store above when we can. if (fnRetAlloca) { - // Load the value from `__retval` and return it via the `cir.return` op. - cir::AllocaOp retAlloca = - mlir::cast(fnRetAlloca->getDefiningOp()); - auto value = cir::LoadOp::create(builder, loc, retAlloca.getAllocaType(), - *fnRetAlloca); + const CIRGenFunctionInfo &fnInfo = + cgm.getTypes().arrangeGlobalDeclaration(curGD); + const cir::ABIArgInfo &retInfo = fnInfo.getReturnInfo(); + mlir::Type loadTy = + cast(curFn).getFunctionType().getReturnType(); + Address loadAddr = returnValue; + if (retInfo.getDirectOffset() != 0) { + mlir::Value offsetVal = + builder.getUnsignedInt(loc, retInfo.getDirectOffset(), 64); + Address byteAddr = + loadAddr.withElementType(builder, builder.getUInt8Ty()); + loadAddr = byteAddr.withPointer( + builder.createPtrStride(loc, byteAddr.getPointer(), offsetVal)); + if (retInfo.getCoerceToType()) { + loadTy = retInfo.getCoerceToType(); + loadAddr = loadAddr.withElementType(builder, loadTy); + } + } + auto value = + cir::LoadOp::create(builder, loc, loadTy, loadAddr.getPointer()); cir::ReturnOp::create(builder, loc, {value}); } else { @@ -1274,10 +1289,24 @@ void CIRGenFunction::emitReturnOfRValue(mlir::Location loc, RValue rv, // directly. // TODO(cir): Eliminate this redundant load and the store above when we can. // Load the value from `__retval` and return it via the `cir.return` op. - cir::AllocaOp retAlloca = - mlir::cast(fnRetAlloca->getDefiningOp()); - auto value = cir::LoadOp::create(builder, loc, retAlloca.getAllocaType(), - *fnRetAlloca); + const CIRGenFunctionInfo &fnInfo = + cgm.getTypes().arrangeGlobalDeclaration(curGD); + const cir::ABIArgInfo &retInfo = fnInfo.getReturnInfo(); + mlir::Type loadTy = + cast(curFn).getFunctionType().getReturnType(); + Address loadAddr = returnValue; + if (retInfo.getDirectOffset() != 0) { + mlir::Value offsetVal = + builder.getUnsignedInt(loc, retInfo.getDirectOffset(), 64); + Address byteAddr = loadAddr.withElementType(builder, builder.getUInt8Ty()); + loadAddr = byteAddr.withPointer( + builder.createPtrStride(loc, byteAddr.getPointer(), offsetVal)); + if (retInfo.getCoerceToType()) { + loadTy = retInfo.getCoerceToType(); + loadAddr = loadAddr.withElementType(builder, loadTy); + } + } + auto value = cir::LoadOp::create(builder, loc, loadTy, loadAddr.getPointer()); cir::ReturnOp::create(builder, loc, {value}); } diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index 85b7e854abb7f..cae37eaff6465 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -11,7 +11,10 @@ #include "clang/Basic/TargetInfo.h" #include "clang/CIR/Dialect/IR/CIRTypes.h" +#include "llvm/TargetParser/Triple.h" + #include +#include using namespace clang; using namespace clang::CIRGen; @@ -711,6 +714,66 @@ bool CIRGenTypes::isZeroInitializable(const RecordDecl *rd) { return getCIRGenRecordLayout(rd).isZeroInitializable(); } +cir::ABIArgInfo CIRGenTypes::classifyCIRReturnType(CanQualType retTy) { + mlir::Type fullTy = convertType(retTy); + const llvm::Triple &triple = cgm.getTriple(); + if (triple.getArch() != llvm::Triple::x86_64 || + triple.getOS() != llvm::Triple::Linux) + return cir::ABIArgInfo::getDirect(fullTy); + + if (!isa(retTy)) + return cir::ABIArgInfo::getDirect(fullTy); + + const auto *recTy = cast(retTy); + const RecordDecl *rd = cast(recTy->getDecl()); + if (!rd->isCompleteDefinition()) + return cir::ABIArgInfo::getDirect(fullTy); + + CharUnits size = astContext.getTypeSizeInChars(retTy); + if (size != CharUnits::fromQuantity(16)) + return cir::ABIArgInfo::getDirect(fullTy); + + const ASTRecordLayout &layout = astContext.getASTRecordLayout(rd); + if (layout.getFieldCount() < 2) + return cir::ABIArgInfo::getDirect(fullTy); + + if (layout.getFieldOffset(1) != + static_cast(astContext.toBits(CharUnits::fromQuantity(8)))) + return cir::ABIArgInfo::getDirect(fullTy); + + auto hiFieldIt = rd->field_begin(); + std::advance(hiFieldIt, 1); + QualType hiTy = hiFieldIt->getType(); + if (astContext.getTypeSizeInChars(hiTy) != CharUnits::fromQuantity(8)) + return cir::ABIArgInfo::getDirect(fullTy); + + mlir::Type coerceTy; + if (hiTy->isIntegralOrEnumerationType() || hiTy->isPointerType()) + coerceTy = convertType(hiTy); + else if (const RecordDecl *hiRd = hiTy->getAsRecordDecl()) { + if (!hiRd->isCompleteDefinition()) + return cir::ABIArgInfo::getDirect(fullTy); + const ASTRecordLayout &hiLayout = astContext.getASTRecordLayout(hiRd); + if (hiLayout.getFieldCount() != 1) + return cir::ABIArgInfo::getDirect(fullTy); + QualType innerTy = hiRd->field_begin()->getType(); + if (!innerTy->isIntegralOrEnumerationType()) + return cir::ABIArgInfo::getDirect(fullTy); + coerceTy = convertType(innerTy); + } else { + return cir::ABIArgInfo::getDirect(fullTy); + } + + if (!mlir::isa(coerceTy)) + return cir::ABIArgInfo::getDirect(fullTy); + + auto intTy = mlir::cast(coerceTy); + if (intTy.getWidth() > 64) + return cir::ABIArgInfo::getDirect(fullTy); + + return cir::ABIArgInfo::getDirect(coerceTy, 8); +} + const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo( CanQualType returnType, bool isInstanceMethod, llvm::ArrayRef argTypes, FunctionType::ExtInfo info, @@ -739,6 +802,7 @@ const CIRGenFunctionInfo &CIRGenTypes::arrangeCIRFunctionInfo( // Construction the function info. We co-allocate the ArgInfos. fi = CIRGenFunctionInfo::create(info, isInstanceMethod, returnType, argTypes, required); + fi->returnInfo = classifyCIRReturnType(returnType); functionInfos.InsertNode(fi, insertPos); return *fi; diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.h b/clang/lib/CIR/CodeGen/CIRGenTypes.h index 15955c517f1f3..7aa785505ee1d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.h +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.h @@ -82,6 +82,8 @@ class CIRGenTypes { /// Heper for convertType. mlir::Type convertFunctionTypeInternal(clang::QualType ft); + cir::ABIArgInfo classifyCIRReturnType(clang::CanQualType retTy); + public: CIRGenTypes(CIRGenModule &cgm); ~CIRGenTypes(); diff --git a/clang/test/CIR/CodeGen/coerced-return-highpart.cpp b/clang/test/CIR/CodeGen/coerced-return-highpart.cpp new file mode 100644 index 0000000000000..6e1883d09559c --- /dev/null +++ b/clang/test/CIR/CodeGen/coerced-return-highpart.cpp @@ -0,0 +1,56 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll +// RUN: FileCheck --check-prefix=LLVM --input-file=%t-cir.ll %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll +// RUN: FileCheck --check-prefix=OGCG --input-file=%t.ll %s + +struct Pad8 { + char c[8]; +}; + +struct Val { + long v; +}; + +struct Ret { + Pad8 pad; + Val val; +}; + +Ret make(long x) { + Ret r{{0}, {x}}; + return r; +} + +long take(Ret r) { return r.val.v; } + +long caller() { + Ret tmp = make(99); + return take(tmp); +} + +// Coerced 16-byte struct return: only the high eightbyte (field at +8) is +// returned in a register; CIR stores the call result at offset 8. +// CIR: cir.func {{.*}} @_Z4makel(%{{.*}}: !s64i +// CIR-SAME: -> !s64i +// CIR: cir.const #cir.int<8> : !u64i +// CIR: cir.ptr_stride +// CIR: cir.load {{.*}} : !cir.ptr, !s64i +// CIR: cir.return {{.*}} : !s64i + +// CIR: cir.func {{.*}} @_Z6callerv() +// CIR: cir.call @_Z4makel +// CIR-SAME: -> !s64i +// CIR: cir.const #cir.int<8> : !u64i +// CIR: cir.ptr_stride +// CIR: cir.store {{.*}} : !s64i + +// LLVM: define {{.*}} @_Z4makel(i64 +// LLVM: define {{.*}} @_Z6callerv() +// LLVM: call i64 @_Z4makel(i64 +// LLVM: getelementptr i8, ptr %{{.*}}, i64 8 + +// OGCG: define {{.*}} @_Z6callerv() +// OGCG: call { i64, i64 } @_Z4makel(i64 +// OGCG: call noundef i64 @_Z4take3Ret(i64