Skip to content

Commit

Permalink
[CIR][Lowering] Lower vtable and type info (#264)
Browse files Browse the repository at this point in the history
Lowering Vtable and RTTI globals. Also lowering AddressPoint.

based on #259
  • Loading branch information
htyu authored and lanza committed Jan 29, 2024
1 parent 1bdbb9a commit 1363edd
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 25 deletions.
136 changes: 120 additions & 16 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
Expand Down Expand Up @@ -148,7 +149,41 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
return result;
}

// ArrayAttr visitor.
// VTableAttr visitor.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
mlir::cir::VTableAttr vtableArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto llvmTy = converter->convertType(vtableArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) {
mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// TypeInfoAttr visitor.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
mlir::cir::TypeInfoAttr typeinfoArr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto llvmTy = converter->convertType(typeinfoArr.getType());
auto loc = parentOp->getLoc();
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);

for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) {
mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter);
result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
}

return result;
}

// ConstArrayAttr visitor
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
mlir::cir::ConstArrayAttr constArr,
mlir::ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -191,27 +226,47 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
auto sourceSymbol = dyn_cast<mlir::LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol()));
assert(sourceSymbol && "Unlowered GlobalOp");
auto loc = parentOp->getLoc();

auto addressOfOp = rewriter.create<mlir::LLVM::AddressOfOp>(
loc, mlir::LLVM::LLVMPointerType::get(parentOp->getContext()),
sourceSymbol.getSymName());
mlir::Type sourceType;
llvm::StringRef symName;
auto sourceSymbol =
mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol());
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
sourceType = llvmSymbol.getType();
symName = llvmSymbol.getSymName();
} else if (auto cirSymbol = dyn_cast<mlir::cir::GlobalOp>(sourceSymbol)) {
sourceType = converter->convertType(cirSymbol.getSymType());
symName = cirSymbol.getSymName();
} else {
llvm_unreachable("Unexpected GlobalOp type");
}

assert(!globalAttr.getIndices() && "TODO");
auto loc = parentOp->getLoc();
auto srcPtrType = mlir::LLVM::LLVMPointerType::get(parentOp->getContext());
mlir::Value addrOp =
rewriter.create<mlir::LLVM::AddressOfOp>(loc, srcPtrType, symName);

if (globalAttr.getIndices()) {
llvm::SmallVector<mlir::LLVM::GEPArg> Indices;
for (auto idx : globalAttr.getIndices()) {
auto intAttr = dyn_cast<mlir::cir::IntAttr>(idx);
assert(intAttr && "index must be integers");
Indices.push_back(intAttr.getSInt());
}
auto eltTy = converter->convertType(sourceType);
addrOp = rewriter.create<mlir::LLVM::GEPOp>(loc, srcPtrType, eltTy, addrOp,
Indices, true);
}

auto ptrTy = globalAttr.getType().dyn_cast<mlir::cir::PointerType>();
assert(ptrTy && "Expecting pointer type in GlobalViewAttr");
auto llvmEltTy = converter->convertType(ptrTy.getPointee());

if (llvmEltTy == sourceSymbol.getType())
return addressOfOp;
if (llvmEltTy == sourceType)
return addrOp;

auto llvmDstTy = converter->convertType(globalAttr.getType());
return rewriter.create<mlir::LLVM::BitcastOp>(parentOp->getLoc(), llvmDstTy,
addressOfOp.getResult());
addrOp);
}

/// Switches on the type of attribute and calls the appropriate conversion.
Expand All @@ -235,6 +290,10 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto globalAttr = attr.dyn_cast<mlir::cir::GlobalViewAttr>())
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
if (const auto vtableAttr = attr.dyn_cast<mlir::cir::VTableAttr>())
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter);
if (const auto typeinfoAttr = attr.dyn_cast<mlir::cir::TypeInfoAttr>())
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter);

llvm_unreachable("unhandled attribute type");
}
Expand Down Expand Up @@ -1324,8 +1383,8 @@ class CIRGlobalOpLowering

// Check for missing funcionalities.
if (!init.has_value()) {
rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(op, llvmType, isConst,
linkage, symbol, mlir::Attribute());
rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
op, llvmType, isConst, linkage, symbol, mlir::Attribute());
return mlir::success();
}

Expand Down Expand Up @@ -1378,6 +1437,20 @@ class CIRGlobalOpLowering
rewriter.create<mlir::LLVM::ReturnOp>(
loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter));
return mlir::success();
} else if (const auto vtableAttr =
init.value().dyn_cast<mlir::cir::VTableAttr>()) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
rewriter.create<mlir::LLVM::ReturnOp>(
op->getLoc(),
lowerCirAttrAsValue(op, vtableAttr, rewriter, typeConverter));
return mlir::success();
} else if (const auto typeinfoAttr =
init.value().dyn_cast<mlir::cir::TypeInfoAttr>()) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
rewriter.create<mlir::LLVM::ReturnOp>(
op->getLoc(),
lowerCirAttrAsValue(op, typeinfoAttr, rewriter, typeConverter));
return mlir::success();
} else {
op.emitError() << "usupported initializer '" << init.value() << "'";
return mlir::failure();
Expand Down Expand Up @@ -1844,6 +1917,37 @@ class CIRFAbsOpLowering : public mlir::OpConversionPattern<mlir::cir::FAbsOp> {
}
};

class CIRVTableAddrPointOpLowering
: public mlir::OpConversionPattern<mlir::cir::VTableAddrPointOp> {
public:
using OpConversionPattern<mlir::cir::VTableAddrPointOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VTableAddrPointOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
const auto *converter = getTypeConverter();
auto targetType = converter->convertType(op.getType());
mlir::Value symAddr = op.getSymAddr();

mlir::Type eltType;
if (!symAddr) {
auto module = op->getParentOfType<mlir::ModuleOp>();
auto symbol = dyn_cast<mlir::LLVM::GlobalOp>(
mlir::SymbolTable::lookupSymbolIn(module, op.getNameAttr()));
symAddr = rewriter.create<mlir::LLVM::AddressOfOp>(
op.getLoc(), mlir::LLVM::LLVMPointerType::get(getContext()),
*op.getName());
eltType = converter->convertType(symbol.getType());
}

auto offsets = llvm::SmallVector<mlir::LLVM::GEPArg>{
0, op.getVtableIndex(), op.getAddressPointIndex()};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, targetType, eltType,
symAddr, offsets, true);
return mlir::success();
}
};

void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering>(patterns.getContext());
Expand All @@ -1857,7 +1961,7 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering,
CIRGetMemberOpLowering, CIRSwitchOpLowering,
CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering,
CIRFAbsOpLowering>(
CIRFAbsOpLowering, CIRVTableAddrPointOpLowering>(
converter, patterns.getContext());
}

Expand Down
30 changes: 21 additions & 9 deletions clang/test/CIR/CodeGen/vbase.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM

struct A {
int a;
Expand All @@ -13,23 +15,33 @@ void ppp() { B b; }


// Vtable definition for B
// CHECK: cir.global linkonce_odr @_ZTV1B = #cir.vtable<{#cir.const_array<[#cir.ptr<12> : !cir.ptr<!u8i>, #cir.ptr<null> : !cir.ptr<!u8i>, #cir.global_view<@_ZTI1B> : !cir.ptr<!u8i>]> : !cir.array<!cir.ptr<!u8i> x 3>}>
// CIR: cir.global linkonce_odr @_ZTV1B = #cir.vtable<{#cir.const_array<[#cir.ptr<12> : !cir.ptr<!u8i>, #cir.ptr<null> : !cir.ptr<!u8i>, #cir.global_view<@_ZTI1B> : !cir.ptr<!u8i>]> : !cir.array<!cir.ptr<!u8i> x 3>}>

// VTT for B.
// CHECK: cir.global linkonce_odr @_ZTT1B = #cir.const_array<[#cir.global_view<@_ZTV1B, [#cir.int<0> : !s32i, #cir.int<0> : !s32i, #cir.int<3> : !s32i]> : !cir.ptr<!u8i>]> : !cir.array<!cir.ptr<!u8i> x 1>
// CIR: cir.global linkonce_odr @_ZTT1B = #cir.const_array<[#cir.global_view<@_ZTV1B, [#cir.int<0> : !s32i, #cir.int<0> : !s32i, #cir.int<3> : !s32i]> : !cir.ptr<!u8i>]> : !cir.array<!cir.ptr<!u8i> x 1>

// CHECK: cir.global "private" external @_ZTVN10__cxxabiv121__vmi_class_type_infoE
// CIR: cir.global "private" external @_ZTVN10__cxxabiv121__vmi_class_type_infoE

// Type info name for B
// CHECK: cir.global linkonce_odr @_ZTS1B = #cir.const_array<"1B" : !cir.array<!s8i x 2>> : !cir.array<!s8i x 2>
// CIR: cir.global linkonce_odr @_ZTS1B = #cir.const_array<"1B" : !cir.array<!s8i x 2>> : !cir.array<!s8i x 2>

// CHECK: cir.global "private" external @_ZTVN10__cxxabiv117__class_type_infoE : !cir.ptr<!cir.ptr<!u8i>>
// CIR: cir.global "private" external @_ZTVN10__cxxabiv117__class_type_infoE : !cir.ptr<!cir.ptr<!u8i>>

// Type info name for A
// CHECK: cir.global linkonce_odr @_ZTS1A = #cir.const_array<"1A" : !cir.array<!s8i x 2>> : !cir.array<!s8i x 2>
// CIR: cir.global linkonce_odr @_ZTS1A = #cir.const_array<"1A" : !cir.array<!s8i x 2>> : !cir.array<!s8i x 2>

// Type info A.
// CHECK: cir.global constant external @_ZTI1A = #cir.typeinfo<{#cir.global_view<@_ZTVN10__cxxabiv117__class_type_infoE, [#cir.int<2> : !s64i]> : !cir.ptr<!u8i>, #cir.global_view<@_ZTS1A> : !cir.ptr<!u8i>}>
// CIR: cir.global constant external @_ZTI1A = #cir.typeinfo<{#cir.global_view<@_ZTVN10__cxxabiv117__class_type_infoE, [#cir.int<2> : !s64i]> : !cir.ptr<!u8i>, #cir.global_view<@_ZTS1A> : !cir.ptr<!u8i>}>

// Type info B.
// CHECK: cir.global constant external @_ZTI1B = #cir.typeinfo<{#cir.global_view<@_ZTVN10__cxxabiv121__vmi_class_type_infoE, [#cir.int<2> : !s64i]> : !cir.ptr<!u8i>, #cir.global_view<@_ZTS1B> : !cir.ptr<!u8i>, #cir.int<0> : !u32i, #cir.int<1> : !u32i, #cir.global_view<@_ZTI1A> : !cir.ptr<!u8i>, #cir.int<-6141> : !s64i}>
// CIR: cir.global constant external @_ZTI1B = #cir.typeinfo<{#cir.global_view<@_ZTVN10__cxxabiv121__vmi_class_type_infoE, [#cir.int<2> : !s64i]> : !cir.ptr<!u8i>, #cir.global_view<@_ZTS1B> : !cir.ptr<!u8i>, #cir.int<0> : !u32i, #cir.int<1> : !u32i, #cir.global_view<@_ZTI1A> : !cir.ptr<!u8i>, #cir.int<-6141> : !s64i}>


// LLVM: @_ZTV1B = linkonce_odr global { [3 x ptr] } { [3 x ptr] [ptr inttoptr (i64 12 to ptr), ptr null, ptr @_ZTI1B] }
// LLVM: @_ZTT1B = linkonce_odr global [1 x ptr] [ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV1B, i32 0, i32 0, i32 3)]
// LLVM: @_ZTVN10__cxxabiv121__vmi_class_type_infoE = external global ptr
// LLVM: @_ZTS1B = linkonce_odr global [2 x i8] c"1B"
// LLVM: @_ZTVN10__cxxabiv117__class_type_infoE = external global ptr
// LLVM: @_ZTS1A = linkonce_odr global [2 x i8] c"1A"
// LLVM: @_ZTI1A = constant { ptr, ptr } { ptr getelementptr inbounds (ptr, ptr @_ZTVN10__cxxabiv117__class_type_infoE, i32 2), ptr @_ZTS1A }
// LLVM: @_ZTI1B = constant { ptr, ptr, i32, i32, ptr, i64 } { ptr getelementptr inbounds (ptr, ptr @_ZTVN10__cxxabiv121__vmi_class_type_infoE, i32 2), ptr @_ZTS1B, i32 0, i32 1, ptr @_ZTI1A, i64 -6141 }

0 comments on commit 1363edd

Please sign in to comment.