Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][CIRGen] CIR generation for bitfields. Fixes #13 #233

Merged
merged 20 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
return getConstInt(
loc, t, isSigned ? intVal.getSExtValue() : intVal.getZExtValue());
}
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
return create<mlir::cir::ConstantOp>(loc, typ,
getAttr<mlir::cir::IntAttr>(typ, val));
}
gitoleg marked this conversation as resolved.
Show resolved Hide resolved
mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
getCIRBoolAttr(state));
Expand Down Expand Up @@ -677,6 +682,65 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
mlir::cir::UnaryOpKind::Not, value);
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
const llvm::APInt &rhs) {
return create<mlir::cir::BinOp>(
lhs.getLoc(), lhs.getType(), kind, lhs,
getConstAPInt(lhs.getLoc(), lhs.getType(), rhs));
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
mlir::Value rhs) {
return create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(), kind, lhs,
rhs);
}

mlir::Value createShift(mlir::Value lhs, const llvm::APInt &rhs,
bool isShiftLeft) {
return create<mlir::cir::ShiftOp>(
lhs.getLoc(), lhs.getType(), lhs,
getConstAPInt(lhs.getLoc(), lhs.getType(), rhs), isShiftLeft);
}

mlir::Value createShift(mlir::Value lhs, unsigned bits, bool isShiftLeft) {
auto width = lhs.getType().dyn_cast<mlir::cir::IntType>().getWidth();
auto shift = llvm::APInt(width, bits);
return createShift(lhs, shift, isShiftLeft);
}

mlir::Value createShiftLeft(mlir::Value lhs, unsigned bits) {
return createShift(lhs, bits, true);
}

mlir::Value createShiftRight(mlir::Value lhs, unsigned bits) {
return createShift(lhs, bits, false);
}

mlir::Value createLowBitsSet(mlir::Location loc, unsigned size,
unsigned bits) {
auto val = llvm::APInt::getLowBitsSet(size, bits);
auto typ = mlir::cir::IntType::get(getContext(), size, false);
return getConstAPInt(loc, typ, val);
}

mlir::Value createAnd(mlir::Value lhs, llvm::APInt rhs) {
auto val = getConstAPInt(lhs.getLoc(), lhs.getType(), rhs);
return createBinop(lhs, mlir::cir::BinOpKind::And, val);
}

mlir::Value createAnd(mlir::Value lhs, mlir::Value rhs) {
return createBinop(lhs, mlir::cir::BinOpKind::And, rhs);
gitoleg marked this conversation as resolved.
Show resolved Hide resolved
}

mlir::Value createOr(mlir::Value lhs, llvm::APInt rhs) {
auto val = getConstAPInt(lhs.getLoc(), lhs.getType(), rhs);
return createBinop(lhs, mlir::cir::BinOpKind::Or, val);
}

mlir::Value createOr(mlir::Value lhs, mlir::Value rhs) {
return createBinop(lhs, mlir::cir::BinOpKind::Or, rhs);
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -727,6 +791,5 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
return createCast(mlir::cir::CastKind::bitcast, src, newTy);
}
};

} // namespace cir
#endif
229 changes: 217 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clang/AST/GlobalDecl.h"
#include "clang/Basic/Builtins.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -126,6 +127,7 @@ static Address buildPointerWithAlignment(const Expr *E,
if (PtrTy->getPointeeType()->isVoidType())
break;
assert(!UnimplementedFeature::tbaa());

LValueBaseInfo InnerBaseInfo;
Address Addr = CGF.buildPointerWithAlignment(
CE->getSubExpr(), &InnerBaseInfo, IsKnownNonNull);
Expand Down Expand Up @@ -209,13 +211,78 @@ static Address buildPointerWithAlignment(const Expr *E,
return Address(CGF.buildScalarExpr(E), Align);
}

/// Helper method to check if the underlying ABI is AAPCS
static bool isAAPCS(const TargetInfo &TargetInfo) {
return TargetInfo.getABI().startswith("aapcs");
}

Address CIRGenFunction::getAddrOfField(LValue base, const FieldDecl *field,
unsigned index) {
if (index == 0)
return base.getAddress();

auto loc = getLoc(field->getLocation());
auto fieldType = convertType(field->getType());
auto fieldPtr =
mlir::cir::PointerType::get(getBuilder().getContext(), fieldType);
auto sea = getBuilder().createGetMember(
loc, fieldPtr, base.getPointer(), field->getName(), index);

return Address(sea, CharUnits::One());
}

static bool useVolatileForBitField(const CIRGenModule &cgm, LValue base,
const CIRGenBitFieldInfo &info,
const FieldDecl *field) {
return isAAPCS(cgm.getTarget()) && cgm.getCodeGenOpts().AAPCSBitfieldWidth &&
info.VolatileStorageSize != 0 &&
field->getType()
.withCVRQualifiers(base.getVRQualifiers())
.isVolatileQualified();
}

LValue CIRGenFunction::buildLValueForBitField(LValue base,
const FieldDecl *field) {

LValueBaseInfo BaseInfo = base.getBaseInfo();
const RecordDecl *rec = field->getParent();
auto &layout = CGM.getTypes().getCIRGenRecordLayout(field->getParent());
auto &info = layout.getBitFieldInfo(field);
auto useVolatile = useVolatileForBitField(CGM, base, info, field);
unsigned Idx = layout.getCIRFieldNo(field);

if (useVolatile ||
(IsInPreservedAIRegion ||
(getDebugInfo() && rec->hasAttr<BPFPreserveAccessIndexAttr>()))) {
llvm_unreachable("NYI");
}

Address Addr = getAddrOfField(base, field, Idx);

const unsigned SS = useVolatile ? info.VolatileStorageSize : info.StorageSize;

// Get the access type.
mlir::Type FieldIntTy = builder.getUIntNTy(SS);

auto loc = getLoc(field->getLocation());
if (Addr.getElementType() != FieldIntTy)
Addr = builder.createElementBitCast(loc, Addr, FieldIntTy);

QualType fieldType =
field->getType().withCVRQualifiers(base.getVRQualifiers());

assert(!UnimplementedFeature::tbaa() && "NYI TBAA for bit fields");
LValueBaseInfo FieldBaseInfo(BaseInfo.getAlignmentSource());
return LValue::MakeBitfield(Addr, info, fieldType, FieldBaseInfo);
}

LValue CIRGenFunction::buildLValueForField(LValue base,
const FieldDecl *field) {

LValueBaseInfo BaseInfo = base.getBaseInfo();

if (field->isBitField()) {
llvm_unreachable("NYI");
}
if (field->isBitField())
return buildLValueForBitField(base, field);

// Fields of may-alias structures are may-alais themselves.
// FIXME: this hould get propagated down through anonymous structs and unions.
Expand Down Expand Up @@ -518,12 +585,55 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value value, LValue lvalue,
/// method emits the address of the lvalue, then loads the result as an rvalue,
/// returning the rvalue.
RValue CIRGenFunction::buildLoadOfLValue(LValue LV, SourceLocation Loc) {
assert(LV.isSimple() && "not implemented");
assert(!LV.getType()->isFunctionType());
assert(!(LV.getType()->isConstantMatrixType()) && "not implemented");

// Everything needs a load.
return RValue::get(buildLoadOfScalar(LV, Loc));
if (LV.isBitField())
return buildLoadOfBitfieldLValue(LV, Loc);

if (LV.isSimple())
return RValue::get(buildLoadOfScalar(LV, Loc));
llvm_unreachable("NYI");
}

RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
SourceLocation Loc) {
const CIRGenBitFieldInfo &Info = LV.getBitFieldInfo();

// Get the output type.
mlir::Type ResLTy = convertType(LV.getType());
Address Ptr = LV.getBitFieldAddress();
mlir::Value Val = builder.createLoad(getLoc(Loc), Ptr);
auto ValWidth = Val.getType().cast<IntType>().getWidth();

bool UseVolatile = LV.isVolatileQualified() &&
Info.VolatileStorageSize != 0 && isAAPCS(CGM.getTarget());
const unsigned Offset = UseVolatile ? Info.VolatileOffset : Info.Offset;
const unsigned StorageSize =
UseVolatile ? Info.VolatileStorageSize : Info.StorageSize;

if (Info.IsSigned) {
assert(static_cast<unsigned>(Offset + Info.Size) <= StorageSize);

mlir::Type typ = builder.getSIntNTy(ValWidth);
Val = builder.createIntCast(Val, typ);

unsigned HighBits = StorageSize - Offset - Info.Size;
if (HighBits)
Val = builder.createShiftLeft(Val, HighBits);
if (Offset + HighBits)
Val = builder.createShiftRight(Val, Offset + HighBits);
} else {
if (Offset)
Val = builder.createShiftRight(Val, Offset);

if (static_cast<unsigned>(Offset) + Info.Size < StorageSize)
Val = builder.createAnd(Val,
llvm::APInt::getLowBitsSet(ValWidth, Info.Size));
}
Val = builder.createIntCast(Val, ResLTy);
assert(!UnimplementedFeature::emitScalarRangeCheck() && "NYI");
return RValue::get(Val);
}

void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst) {
Expand All @@ -546,6 +656,81 @@ void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst) {
buildStoreOfScalar(Src.getScalarVal(), Dst);
}

void CIRGenFunction::buildStoreThroughBitfieldLValue(RValue Src, LValue Dst,
mlir::Value &Result) {
const CIRGenBitFieldInfo &Info = Dst.getBitFieldInfo();
mlir::Type ResLTy = getTypes().convertTypeForMem(Dst.getType());
Address Ptr = Dst.getBitFieldAddress();

// Get the source value, truncated to the width of the bit-field.
mlir::Value SrcVal = Src.getScalarVal();

// Cast the source to the storage type and shift it into place.
SrcVal = builder.createIntCast(SrcVal, Ptr.getElementType());
auto SrcWidth = SrcVal.getType().cast<IntType>().getWidth();
mlir::Value MaskedVal = SrcVal;

const bool UseVolatile =
CGM.getCodeGenOpts().AAPCSBitfieldWidth && Dst.isVolatileQualified() &&
Info.VolatileStorageSize != 0 && isAAPCS(CGM.getTarget());
const unsigned StorageSize =
UseVolatile ? Info.VolatileStorageSize : Info.StorageSize;
const unsigned Offset = UseVolatile ? Info.VolatileOffset : Info.Offset;
// See if there are other bits in the bitfield's storage we'll need to load
// and mask together with source before storing.
if (StorageSize != Info.Size) {
assert(StorageSize > Info.Size && "Invalid bitfield size.");

mlir::Value Val = buildLoadOfScalar(Dst, Dst.getPointer().getLoc());

// Mask the source value as needed.
if (!hasBooleanRepresentation(Dst.getType()))
SrcVal = builder.createAnd(
SrcVal, llvm::APInt::getLowBitsSet(SrcWidth, Info.Size));

MaskedVal = SrcVal;
if (Offset)
SrcVal = builder.createShiftLeft(SrcVal, Offset);

// Mask out the original value.
Val = builder.createAnd(
Val, ~llvm::APInt::getBitsSet(SrcWidth, Offset, Offset + Info.Size));

// Or together the unchanged values and the source value.
SrcVal = builder.createOr(Val, SrcVal);

} else {
// According to the AACPS:
// When a volatile bit-field is written, and its container does not overlap
// with any non-bit-field member, its container must be read exactly once
// and written exactly once using the access width appropriate to the type
// of the container. The two accesses are not atomic.
llvm_unreachable("volatile bit-field is not implemented for the AACPS");
}

// Write the new value back out.
// TODO: constant matrix type, volatile, no init, non temporal, TBAA
buildStoreOfScalar(SrcVal, Ptr, Dst.isVolatileQualified(), Dst.getType(),
Dst.getBaseInfo(), false, false);

// Return the new value of the bit-field.
mlir::Value ResultVal = MaskedVal;
ResultVal = builder.createIntCast(ResultVal, ResLTy);

// Sign extend the value if needed.
if (Info.IsSigned) {
assert(Info.Size <= StorageSize);
unsigned HighBits = StorageSize - Info.Size;

if (HighBits) {
ResultVal = builder.createShiftLeft(ResultVal, HighBits);
ResultVal = builder.createShiftRight(ResultVal, HighBits);
}
}

Result = buildFromMemory(ResultVal, Dst.getType());
}

static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
const VarDecl *VD) {
QualType T = E->getType();
Expand Down Expand Up @@ -769,7 +954,13 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
LValue LV = buildLValue(E->getLHS());

SourceLocRAIIObject Loc{*this, getLoc(E->getSourceRange())};
buildStoreThroughLValue(RV, LV);
if (LV.isBitField()) {
mlir::Value result;
buildStoreThroughBitfieldLValue(RV, LV, result);
} else {
buildStoreThroughLValue(RV, LV);
}

assert(!getContext().getLangOpts().OpenMP &&
"last priv cond not implemented");
return LV;
Expand Down Expand Up @@ -2203,6 +2394,13 @@ mlir::Value CIRGenFunction::buildAlloca(StringRef name, QualType ty,

mlir::Value CIRGenFunction::buildLoadOfScalar(LValue lvalue,
SourceLocation Loc) {
return buildLoadOfScalar(lvalue.getAddress(), lvalue.isVolatile(),
lvalue.getType(), getLoc(Loc), lvalue.getBaseInfo(),
lvalue.isNontemporal());
}

mlir::Value CIRGenFunction::buildLoadOfScalar(LValue lvalue,
mlir::Location Loc) {
gitoleg marked this conversation as resolved.
Show resolved Hide resolved
return buildLoadOfScalar(lvalue.getAddress(), lvalue.isVolatile(),
lvalue.getType(), Loc, lvalue.getBaseInfo(),
lvalue.isNontemporal());
Expand All @@ -2220,6 +2418,14 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, SourceLocation Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
return buildLoadOfScalar(Addr, Volatile, Ty, getLoc(Loc), BaseInfo,
isNontemporal);
}

mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, mlir::Location Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
gitoleg marked this conversation as resolved.
Show resolved Hide resolved
if (!CGM.getCodeGenOpts().PreserveVec3Type) {
if (Ty->isVectorType()) {
llvm_unreachable("NYI");
Expand All @@ -2233,15 +2439,14 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
}

mlir::cir::LoadOp Load = builder.create<mlir::cir::LoadOp>(
getLoc(Loc), Addr.getElementType(), Addr.getPointer());
Loc, Addr.getElementType(), Addr.getPointer());

if (isNontemporal) {
llvm_unreachable("NYI");
}

// TODO: TBAA

// TODO: buildScalarRangeCheck

assert(!UnimplementedFeature::tbaa() && "NYI");
assert(!UnimplementedFeature::emitScalarRangeCheck() && "NYI");

return buildFromMemory(Load, Ty);
}
Expand Down
Loading