Skip to content

Commit

Permalink
[CIR][CIRGen] CIR generation for bitfields. Fixes #13 (#233)
Browse files Browse the repository at this point in the history
This PR introduces bitfelds support.  This now works:

```
#include <stdio.h>

typedef struct {
    int a1 : 4;
    int a2 : 28;
    int a3 : 16;
    int a4 : 3;
    int a5 : 17;
    int a6 : 25;
} A;

void init(A* a) {
    a->a1 = 1;
    a->a2 = 321;
    a->a3 = 15;
    a->a4 = -2;
    a->a5 = -123;
    a->a6 = 1234;
}

void print(A* a) {
    printf("%d %d %d %d %d %d\n",
        a->a1,
        a->a2,
        a->a3,
        a->a4,
        a->a5,
        a->a6
    );
}

int main() {
    A a;
    init(&a);
    print(&a);
    return 0;
}

```
the output is:
`1 321 15 -2 -123 1234`
  • Loading branch information
gitoleg authored and lanza committed Apr 29, 2024
1 parent 865d986 commit b0b7dd8
Show file tree
Hide file tree
Showing 10 changed files with 563 additions and 85 deletions.
65 changes: 64 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
return getConstInt(
loc, t, isSigned ? intVal.getSExtValue() : intVal.getZExtValue());
}
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
return create<mlir::cir::ConstantOp>(loc, typ,
getAttr<mlir::cir::IntAttr>(typ, val));
}
mlir::cir::ConstantOp getBool(bool state, mlir::Location loc) {
return create<mlir::cir::ConstantOp>(loc, getBoolTy(),
getCIRBoolAttr(state));
Expand Down Expand Up @@ -677,6 +682,65 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
mlir::cir::UnaryOpKind::Not, value);
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
const llvm::APInt &rhs) {
return create<mlir::cir::BinOp>(
lhs.getLoc(), lhs.getType(), kind, lhs,
getConstAPInt(lhs.getLoc(), lhs.getType(), rhs));
}

mlir::Value createBinop(mlir::Value lhs, mlir::cir::BinOpKind kind,
mlir::Value rhs) {
return create<mlir::cir::BinOp>(lhs.getLoc(), lhs.getType(), kind, lhs,
rhs);
}

mlir::Value createShift(mlir::Value lhs, const llvm::APInt &rhs,
bool isShiftLeft) {
return create<mlir::cir::ShiftOp>(
lhs.getLoc(), lhs.getType(), lhs,
getConstAPInt(lhs.getLoc(), lhs.getType(), rhs), isShiftLeft);
}

mlir::Value createShift(mlir::Value lhs, unsigned bits, bool isShiftLeft) {
auto width = lhs.getType().dyn_cast<mlir::cir::IntType>().getWidth();
auto shift = llvm::APInt(width, bits);
return createShift(lhs, shift, isShiftLeft);
}

mlir::Value createShiftLeft(mlir::Value lhs, unsigned bits) {
return createShift(lhs, bits, true);
}

mlir::Value createShiftRight(mlir::Value lhs, unsigned bits) {
return createShift(lhs, bits, false);
}

mlir::Value createLowBitsSet(mlir::Location loc, unsigned size,
unsigned bits) {
auto val = llvm::APInt::getLowBitsSet(size, bits);
auto typ = mlir::cir::IntType::get(getContext(), size, false);
return getConstAPInt(loc, typ, val);
}

mlir::Value createAnd(mlir::Value lhs, llvm::APInt rhs) {
auto val = getConstAPInt(lhs.getLoc(), lhs.getType(), rhs);
return createBinop(lhs, mlir::cir::BinOpKind::And, val);
}

mlir::Value createAnd(mlir::Value lhs, mlir::Value rhs) {
return createBinop(lhs, mlir::cir::BinOpKind::And, rhs);
}

mlir::Value createOr(mlir::Value lhs, llvm::APInt rhs) {
auto val = getConstAPInt(lhs.getLoc(), lhs.getType(), rhs);
return createBinop(lhs, mlir::cir::BinOpKind::Or, val);
}

mlir::Value createOr(mlir::Value lhs, mlir::Value rhs) {
return createBinop(lhs, mlir::cir::BinOpKind::Or, rhs);
}

//===--------------------------------------------------------------------===//
// Cast/Conversion Operators
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -727,6 +791,5 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
return createCast(mlir::cir::CastKind::bitcast, src, newTy);
}
};

} // namespace cir
#endif
229 changes: 217 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clang/AST/GlobalDecl.h"
#include "clang/Basic/Builtins.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -128,6 +129,7 @@ static Address buildPointerWithAlignment(const Expr *E,
if (PtrTy->getPointeeType()->isVoidType())
break;
assert(!UnimplementedFeature::tbaa());

LValueBaseInfo InnerBaseInfo;
Address Addr = CGF.buildPointerWithAlignment(
CE->getSubExpr(), &InnerBaseInfo, IsKnownNonNull);
Expand Down Expand Up @@ -211,13 +213,78 @@ static Address buildPointerWithAlignment(const Expr *E,
return Address(CGF.buildScalarExpr(E), Align);
}

/// Helper method to check if the underlying ABI is AAPCS
static bool isAAPCS(const TargetInfo &TargetInfo) {
return TargetInfo.getABI().starts_with("aapcs");
}

Address CIRGenFunction::getAddrOfField(LValue base, const FieldDecl *field,
unsigned index) {
if (index == 0)
return base.getAddress();

auto loc = getLoc(field->getLocation());
auto fieldType = convertType(field->getType());
auto fieldPtr =
mlir::cir::PointerType::get(getBuilder().getContext(), fieldType);
auto sea = getBuilder().createGetMember(
loc, fieldPtr, base.getPointer(), field->getName(), index);

return Address(sea, CharUnits::One());
}

static bool useVolatileForBitField(const CIRGenModule &cgm, LValue base,
const CIRGenBitFieldInfo &info,
const FieldDecl *field) {
return isAAPCS(cgm.getTarget()) && cgm.getCodeGenOpts().AAPCSBitfieldWidth &&
info.VolatileStorageSize != 0 &&
field->getType()
.withCVRQualifiers(base.getVRQualifiers())
.isVolatileQualified();
}

LValue CIRGenFunction::buildLValueForBitField(LValue base,
const FieldDecl *field) {

LValueBaseInfo BaseInfo = base.getBaseInfo();
const RecordDecl *rec = field->getParent();
auto &layout = CGM.getTypes().getCIRGenRecordLayout(field->getParent());
auto &info = layout.getBitFieldInfo(field);
auto useVolatile = useVolatileForBitField(CGM, base, info, field);
unsigned Idx = layout.getCIRFieldNo(field);

if (useVolatile ||
(IsInPreservedAIRegion ||
(getDebugInfo() && rec->hasAttr<BPFPreserveAccessIndexAttr>()))) {
llvm_unreachable("NYI");
}

Address Addr = getAddrOfField(base, field, Idx);

const unsigned SS = useVolatile ? info.VolatileStorageSize : info.StorageSize;

// Get the access type.
mlir::Type FieldIntTy = builder.getUIntNTy(SS);

auto loc = getLoc(field->getLocation());
if (Addr.getElementType() != FieldIntTy)
Addr = builder.createElementBitCast(loc, Addr, FieldIntTy);

QualType fieldType =
field->getType().withCVRQualifiers(base.getVRQualifiers());

assert(!UnimplementedFeature::tbaa() && "NYI TBAA for bit fields");
LValueBaseInfo FieldBaseInfo(BaseInfo.getAlignmentSource());
return LValue::MakeBitfield(Addr, info, fieldType, FieldBaseInfo);
}

LValue CIRGenFunction::buildLValueForField(LValue base,
const FieldDecl *field) {

LValueBaseInfo BaseInfo = base.getBaseInfo();

if (field->isBitField()) {
llvm_unreachable("NYI");
}
if (field->isBitField())
return buildLValueForBitField(base, field);

// Fields of may-alias structures are may-alais themselves.
// FIXME: this hould get propagated down through anonymous structs and unions.
Expand Down Expand Up @@ -520,12 +587,55 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value value, LValue lvalue,
/// method emits the address of the lvalue, then loads the result as an rvalue,
/// returning the rvalue.
RValue CIRGenFunction::buildLoadOfLValue(LValue LV, SourceLocation Loc) {
assert(LV.isSimple() && "not implemented");
assert(!LV.getType()->isFunctionType());
assert(!(LV.getType()->isConstantMatrixType()) && "not implemented");

// Everything needs a load.
return RValue::get(buildLoadOfScalar(LV, Loc));
if (LV.isBitField())
return buildLoadOfBitfieldLValue(LV, Loc);

if (LV.isSimple())
return RValue::get(buildLoadOfScalar(LV, Loc));
llvm_unreachable("NYI");
}

RValue CIRGenFunction::buildLoadOfBitfieldLValue(LValue LV,
SourceLocation Loc) {
const CIRGenBitFieldInfo &Info = LV.getBitFieldInfo();

// Get the output type.
mlir::Type ResLTy = convertType(LV.getType());
Address Ptr = LV.getBitFieldAddress();
mlir::Value Val = builder.createLoad(getLoc(Loc), Ptr);
auto ValWidth = Val.getType().cast<IntType>().getWidth();

bool UseVolatile = LV.isVolatileQualified() &&
Info.VolatileStorageSize != 0 && isAAPCS(CGM.getTarget());
const unsigned Offset = UseVolatile ? Info.VolatileOffset : Info.Offset;
const unsigned StorageSize =
UseVolatile ? Info.VolatileStorageSize : Info.StorageSize;

if (Info.IsSigned) {
assert(static_cast<unsigned>(Offset + Info.Size) <= StorageSize);

mlir::Type typ = builder.getSIntNTy(ValWidth);
Val = builder.createIntCast(Val, typ);

unsigned HighBits = StorageSize - Offset - Info.Size;
if (HighBits)
Val = builder.createShiftLeft(Val, HighBits);
if (Offset + HighBits)
Val = builder.createShiftRight(Val, Offset + HighBits);
} else {
if (Offset)
Val = builder.createShiftRight(Val, Offset);

if (static_cast<unsigned>(Offset) + Info.Size < StorageSize)
Val = builder.createAnd(Val,
llvm::APInt::getLowBitsSet(ValWidth, Info.Size));
}
Val = builder.createIntCast(Val, ResLTy);
assert(!UnimplementedFeature::emitScalarRangeCheck() && "NYI");
return RValue::get(Val);
}

void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst) {
Expand All @@ -548,6 +658,81 @@ void CIRGenFunction::buildStoreThroughLValue(RValue Src, LValue Dst) {
buildStoreOfScalar(Src.getScalarVal(), Dst);
}

void CIRGenFunction::buildStoreThroughBitfieldLValue(RValue Src, LValue Dst,
mlir::Value &Result) {
const CIRGenBitFieldInfo &Info = Dst.getBitFieldInfo();
mlir::Type ResLTy = getTypes().convertTypeForMem(Dst.getType());
Address Ptr = Dst.getBitFieldAddress();

// Get the source value, truncated to the width of the bit-field.
mlir::Value SrcVal = Src.getScalarVal();

// Cast the source to the storage type and shift it into place.
SrcVal = builder.createIntCast(SrcVal, Ptr.getElementType());
auto SrcWidth = SrcVal.getType().cast<IntType>().getWidth();
mlir::Value MaskedVal = SrcVal;

const bool UseVolatile =
CGM.getCodeGenOpts().AAPCSBitfieldWidth && Dst.isVolatileQualified() &&
Info.VolatileStorageSize != 0 && isAAPCS(CGM.getTarget());
const unsigned StorageSize =
UseVolatile ? Info.VolatileStorageSize : Info.StorageSize;
const unsigned Offset = UseVolatile ? Info.VolatileOffset : Info.Offset;
// See if there are other bits in the bitfield's storage we'll need to load
// and mask together with source before storing.
if (StorageSize != Info.Size) {
assert(StorageSize > Info.Size && "Invalid bitfield size.");

mlir::Value Val = buildLoadOfScalar(Dst, Dst.getPointer().getLoc());

// Mask the source value as needed.
if (!hasBooleanRepresentation(Dst.getType()))
SrcVal = builder.createAnd(
SrcVal, llvm::APInt::getLowBitsSet(SrcWidth, Info.Size));

MaskedVal = SrcVal;
if (Offset)
SrcVal = builder.createShiftLeft(SrcVal, Offset);

// Mask out the original value.
Val = builder.createAnd(
Val, ~llvm::APInt::getBitsSet(SrcWidth, Offset, Offset + Info.Size));

// Or together the unchanged values and the source value.
SrcVal = builder.createOr(Val, SrcVal);

} else {
// According to the AACPS:
// When a volatile bit-field is written, and its container does not overlap
// with any non-bit-field member, its container must be read exactly once
// and written exactly once using the access width appropriate to the type
// of the container. The two accesses are not atomic.
llvm_unreachable("volatile bit-field is not implemented for the AACPS");
}

// Write the new value back out.
// TODO: constant matrix type, volatile, no init, non temporal, TBAA
buildStoreOfScalar(SrcVal, Ptr, Dst.isVolatileQualified(), Dst.getType(),
Dst.getBaseInfo(), false, false);

// Return the new value of the bit-field.
mlir::Value ResultVal = MaskedVal;
ResultVal = builder.createIntCast(ResultVal, ResLTy);

// Sign extend the value if needed.
if (Info.IsSigned) {
assert(Info.Size <= StorageSize);
unsigned HighBits = StorageSize - Info.Size;

if (HighBits) {
ResultVal = builder.createShiftLeft(ResultVal, HighBits);
ResultVal = builder.createShiftRight(ResultVal, HighBits);
}
}

Result = buildFromMemory(ResultVal, Dst.getType());
}

static LValue buildGlobalVarDeclLValue(CIRGenFunction &CGF, const Expr *E,
const VarDecl *VD) {
QualType T = E->getType();
Expand Down Expand Up @@ -771,7 +956,13 @@ LValue CIRGenFunction::buildBinaryOperatorLValue(const BinaryOperator *E) {
LValue LV = buildLValue(E->getLHS());

SourceLocRAIIObject Loc{*this, getLoc(E->getSourceRange())};
buildStoreThroughLValue(RV, LV);
if (LV.isBitField()) {
mlir::Value result;
buildStoreThroughBitfieldLValue(RV, LV, result);
} else {
buildStoreThroughLValue(RV, LV);
}

assert(!getContext().getLangOpts().OpenMP &&
"last priv cond not implemented");
return LV;
Expand Down Expand Up @@ -2207,6 +2398,13 @@ mlir::Value CIRGenFunction::buildAlloca(StringRef name, QualType ty,

mlir::Value CIRGenFunction::buildLoadOfScalar(LValue lvalue,
SourceLocation Loc) {
return buildLoadOfScalar(lvalue.getAddress(), lvalue.isVolatile(),
lvalue.getType(), getLoc(Loc), lvalue.getBaseInfo(),
lvalue.isNontemporal());
}

mlir::Value CIRGenFunction::buildLoadOfScalar(LValue lvalue,
mlir::Location Loc) {
return buildLoadOfScalar(lvalue.getAddress(), lvalue.isVolatile(),
lvalue.getType(), Loc, lvalue.getBaseInfo(),
lvalue.isNontemporal());
Expand All @@ -2224,6 +2422,14 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, SourceLocation Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
return buildLoadOfScalar(Addr, Volatile, Ty, getLoc(Loc), BaseInfo,
isNontemporal);
}

mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, mlir::Location Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
if (!CGM.getCodeGenOpts().PreserveVec3Type) {
if (Ty->isVectorType()) {
llvm_unreachable("NYI");
Expand All @@ -2237,15 +2443,14 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
}

mlir::cir::LoadOp Load = builder.create<mlir::cir::LoadOp>(
getLoc(Loc), Addr.getElementType(), Addr.getPointer());
Loc, Addr.getElementType(), Addr.getPointer());

if (isNontemporal) {
llvm_unreachable("NYI");
}

// TODO: TBAA

// TODO: buildScalarRangeCheck

assert(!UnimplementedFeature::tbaa() && "NYI");
assert(!UnimplementedFeature::emitScalarRangeCheck() && "NYI");

return buildFromMemory(Load, Ty);
}
Expand Down
Loading

0 comments on commit b0b7dd8

Please sign in to comment.