Skip to content

Commit

Permalink
[TableGen] Refactor the implementation of arguments to introduce Argu…
Browse files Browse the repository at this point in the history
…mentInit [nfc]

A new Init type ArgumentInit is added to represent arguments.  We currently only support positional arguments; an upcoming change will add named argument support.

The index of argument in error message is removed.

Differential Revision: https://reviews.llvm.org/D154066
  • Loading branch information
wangpc-pp authored and preames committed Jul 11, 2023
1 parent 2858475 commit 6251adc
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 41 deletions.
57 changes: 48 additions & 9 deletions llvm/include/llvm/TableGen/Record.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ class Init {
IK_VarBitInit,
IK_VarDefInit,
IK_LastTypedInit,
IK_UnsetInit
IK_UnsetInit,
IK_ArgumentInit,
};

private:
Expand Down Expand Up @@ -480,6 +481,39 @@ class UnsetInit : public Init {
std::string getAsString() const override { return "?"; }
};

// Represent an argument.
class ArgumentInit : public Init, public FoldingSetNode {
Init *Value;

protected:
explicit ArgumentInit(Init *Value) : Init(IK_ArgumentInit), Value(Value) {}

public:
ArgumentInit(const ArgumentInit &) = delete;
ArgumentInit &operator=(const ArgumentInit &) = delete;

static bool classof(const Init *I) { return I->getKind() == IK_ArgumentInit; }

RecordKeeper &getRecordKeeper() const { return Value->getRecordKeeper(); }

static ArgumentInit *get(Init *Value);

Init *getValue() const { return Value; }

void Profile(FoldingSetNodeID &ID) const;

Init *resolveReferences(Resolver &R) const override;
std::string getAsString() const override { return Value->getAsString(); }

bool isComplete() const override { return false; }
bool isConcrete() const override { return false; }
Init *getBit(unsigned Bit) const override { return Value->getBit(Bit); }
Init *getCastTo(RecTy *Ty) const override { return Value->getCastTo(Ty); }
Init *convertInitializerTo(RecTy *Ty) const override {
return Value->convertInitializerTo(Ty);
}
};

/// 'true'/'false' - Represent a concrete initializer for a bit.
class BitInit final : public TypedInit {
friend detail::RecordKeeperImpl;
Expand Down Expand Up @@ -1278,8 +1312,9 @@ class DefInit : public TypedInit {

/// classname<targs...> - Represent an uninstantiated anonymous class
/// instantiation.
class VarDefInit final : public TypedInit, public FoldingSetNode,
public TrailingObjects<VarDefInit, Init *> {
class VarDefInit final : public TypedInit,
public FoldingSetNode,
public TrailingObjects<VarDefInit, ArgumentInit *> {
Record *Class;
DefInit *Def = nullptr; // after instantiation
unsigned NumArgs;
Expand All @@ -1298,7 +1333,7 @@ class VarDefInit final : public TypedInit, public FoldingSetNode,
static bool classof(const Init *I) {
return I->getKind() == IK_VarDefInit;
}
static VarDefInit *get(Record *Class, ArrayRef<Init *> Args);
static VarDefInit *get(Record *Class, ArrayRef<ArgumentInit *> Args);

void Profile(FoldingSetNodeID &ID) const;

Expand All @@ -1307,20 +1342,24 @@ class VarDefInit final : public TypedInit, public FoldingSetNode,

std::string getAsString() const override;

Init *getArg(unsigned i) const {
ArgumentInit *getArg(unsigned i) const {
assert(i < NumArgs && "Argument index out of range!");
return getTrailingObjects<Init *>()[i];
return getTrailingObjects<ArgumentInit *>()[i];
}

using const_iterator = Init *const *;
using const_iterator = ArgumentInit *const *;

const_iterator args_begin() const { return getTrailingObjects<Init *>(); }
const_iterator args_begin() const {
return getTrailingObjects<ArgumentInit *>();
}
const_iterator args_end () const { return args_begin() + NumArgs; }

size_t args_size () const { return NumArgs; }
bool args_empty() const { return NumArgs == 0; }

ArrayRef<Init *> args() const { return ArrayRef(args_begin(), NumArgs); }
ArrayRef<ArgumentInit *> args() const {
return ArrayRef(args_begin(), NumArgs);
}

Init *getBit(unsigned Bit) const override {
llvm_unreachable("Illegal bit reference off anonymous def");
Expand Down
55 changes: 44 additions & 11 deletions llvm/lib/TableGen/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct RecordKeeperImpl {
BitInit TrueBitInit;
BitInit FalseBitInit;

FoldingSet<ArgumentInit> TheArgumentInitPool;
FoldingSet<BitsInit> TheBitsInitPool;
std::map<int64_t, IntInit *> TheIntInitPool;
StringMap<StringInit *, BumpPtrAllocator &> StringInitStringPool;
Expand Down Expand Up @@ -349,6 +350,8 @@ LLVM_DUMP_METHOD void Init::dump() const { return print(errs()); }
RecordKeeper &Init::getRecordKeeper() const {
if (auto *TyInit = dyn_cast<TypedInit>(this))
return TyInit->getType()->getRecordKeeper();
if (auto *ArgInit = dyn_cast<ArgumentInit>(this))
return ArgInit->getRecordKeeper();
return cast<UnsetInit>(this)->getRecordKeeper();
}

Expand All @@ -364,6 +367,37 @@ Init *UnsetInit::convertInitializerTo(RecTy *Ty) const {
return const_cast<UnsetInit *>(this);
}

static void ProfileArgumentInit(FoldingSetNodeID &ID, Init *Value) {
ID.AddPointer(Value);
}

void ArgumentInit::Profile(FoldingSetNodeID &ID) const {
ProfileArgumentInit(ID, Value);
}

ArgumentInit *ArgumentInit::get(Init *Value) {
FoldingSetNodeID ID;
ProfileArgumentInit(ID, Value);

RecordKeeper &RK = Value->getRecordKeeper();
detail::RecordKeeperImpl &RKImpl = RK.getImpl();
void *IP = nullptr;
if (ArgumentInit *I = RKImpl.TheArgumentInitPool.FindNodeOrInsertPos(ID, IP))
return I;

ArgumentInit *I = new (RKImpl.Allocator) ArgumentInit(Value);
RKImpl.TheArgumentInitPool.InsertNode(I, IP);
return I;
}

Init *ArgumentInit::resolveReferences(Resolver &R) const {
Init *NewValue = Value->resolveReferences(R);
if (NewValue != Value)
return ArgumentInit::get(NewValue);

return const_cast<ArgumentInit *>(this);
}

BitInit *BitInit::get(RecordKeeper &RK, bool V) {
return V ? &RK.getImpl().TrueBitInit : &RK.getImpl().FalseBitInit;
}
Expand Down Expand Up @@ -2131,9 +2165,8 @@ RecTy *DefInit::getFieldType(StringInit *FieldName) const {

std::string DefInit::getAsString() const { return std::string(Def->getName()); }

static void ProfileVarDefInit(FoldingSetNodeID &ID,
Record *Class,
ArrayRef<Init *> Args) {
static void ProfileVarDefInit(FoldingSetNodeID &ID, Record *Class,
ArrayRef<ArgumentInit *> Args) {
ID.AddInteger(Args.size());
ID.AddPointer(Class);

Expand All @@ -2145,7 +2178,7 @@ VarDefInit::VarDefInit(Record *Class, unsigned N)
: TypedInit(IK_VarDefInit, RecordRecTy::get(Class)), Class(Class),
NumArgs(N) {}

VarDefInit *VarDefInit::get(Record *Class, ArrayRef<Init *> Args) {
VarDefInit *VarDefInit::get(Record *Class, ArrayRef<ArgumentInit *> Args) {
FoldingSetNodeID ID;
ProfileVarDefInit(ID, Class, Args);

Expand All @@ -2154,11 +2187,11 @@ VarDefInit *VarDefInit::get(Record *Class, ArrayRef<Init *> Args) {
if (VarDefInit *I = RK.TheVarDefInitPool.FindNodeOrInsertPos(ID, IP))
return I;

void *Mem = RK.Allocator.Allocate(totalSizeToAlloc<Init *>(Args.size()),
alignof(VarDefInit));
void *Mem = RK.Allocator.Allocate(
totalSizeToAlloc<ArgumentInit *>(Args.size()), alignof(VarDefInit));
VarDefInit *I = new (Mem) VarDefInit(Class, Args.size());
std::uninitialized_copy(Args.begin(), Args.end(),
I->getTrailingObjects<Init *>());
I->getTrailingObjects<ArgumentInit *>());
RK.TheVarDefInitPool.InsertNode(I, IP);
return I;
}
Expand Down Expand Up @@ -2188,7 +2221,7 @@ DefInit *VarDefInit::instantiate() {

for (unsigned i = 0, e = TArgs.size(); i != e; ++i) {
if (i < args_size())
R.set(TArgs[i], getArg(i));
R.set(TArgs[i], getArg(i)->getValue());
else
R.set(TArgs[i], NewRec->getValue(TArgs[i])->getValue());

Expand Down Expand Up @@ -2222,11 +2255,11 @@ DefInit *VarDefInit::instantiate() {
Init *VarDefInit::resolveReferences(Resolver &R) const {
TrackUnresolvedResolver UR(&R);
bool Changed = false;
SmallVector<Init *, 8> NewArgs;
SmallVector<ArgumentInit *, 8> NewArgs;
NewArgs.reserve(args_size());

for (Init *Arg : args()) {
Init *NewArg = Arg->resolveReferences(UR);
for (ArgumentInit *Arg : args()) {
auto *NewArg = cast<ArgumentInit>(Arg->resolveReferences(UR));
NewArgs.push_back(NewArg);
Changed |= NewArg != Arg;
}
Expand Down
30 changes: 15 additions & 15 deletions llvm/lib/TableGen/TGParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace llvm {
struct SubClassReference {
SMRange RefRange;
Record *Rec;
SmallVector<Init*, 4> TemplateArgs;
SmallVector<ArgumentInit *, 4> TemplateArgs;

SubClassReference() : Rec(nullptr) {}

Expand All @@ -46,7 +46,7 @@ struct SubClassReference {
struct SubMultiClassReference {
SMRange RefRange;
MultiClass *MC;
SmallVector<Init*, 4> TemplateArgs;
SmallVector<ArgumentInit *, 4> TemplateArgs;

SubMultiClassReference() : MC(nullptr) {}

Expand Down Expand Up @@ -569,7 +569,7 @@ bool TGParser::addDefOne(std::unique_ptr<Record> Rec) {
return false;
}

bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
bool TGParser::resolveArguments(Record *Rec, ArrayRef<ArgumentInit *> ArgValues,
SMLoc Loc, ArgValueHandler ArgValueHandler) {
ArrayRef<Init *> ArgNames = Rec->getTemplateArgs();
assert(ArgValues.size() <= ArgNames.size() &&
Expand All @@ -579,7 +579,7 @@ bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
// handle the (name, value) pair. If not and there was no default, complain.
for (unsigned I = 0, E = ArgNames.size(); I != E; ++I) {
if (I < ArgValues.size())
ArgValueHandler(ArgNames[I], ArgValues[I]);
ArgValueHandler(ArgNames[I], ArgValues[I]->getValue());
else {
Init *Default = Rec->getValue(ArgNames[I])->getValue();
if (!Default->isComplete())
Expand All @@ -597,15 +597,16 @@ bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
/// Resolve the arguments of class and set them to MapResolver.
/// Returns true if failed.
bool TGParser::resolveArgumentsOfClass(MapResolver &R, Record *Rec,
ArrayRef<Init *> ArgValues, SMLoc Loc) {
ArrayRef<ArgumentInit *> ArgValues,
SMLoc Loc) {
return resolveArguments(Rec, ArgValues, Loc,
[&](Init *Name, Init *Value) { R.set(Name, Value); });
}

/// Resolve the arguments of multiclass and store them into SubstStack.
/// Returns true if failed.
bool TGParser::resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
ArrayRef<Init *> ArgValues,
ArrayRef<ArgumentInit *> ArgValues,
Init *DefmName, SMLoc Loc) {
// Add an implicit argument NAME.
Substs.emplace_back(QualifiedNameOfImplicitName(MC), DefmName);
Expand Down Expand Up @@ -2596,7 +2597,7 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
return nullptr;
}

SmallVector<Init *, 8> Args;
SmallVector<ArgumentInit *, 8> Args;
Lex.Lex(); // consume the <
if (ParseTemplateArgValueList(Args, CurRec, Class))
return nullptr; // Error parsing value list.
Expand Down Expand Up @@ -3121,8 +3122,8 @@ void TGParser::ParseValueList(SmallVectorImpl<Init *> &Result, Record *CurRec,
// error was detected.
//
// TemplateArgList ::= '<' [Value {',' Value}*] '>'
bool TGParser::ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
Record *CurRec, Record *ArgsRec) {
bool TGParser::ParseTemplateArgValueList(
SmallVectorImpl<ArgumentInit *> &Result, Record *CurRec, Record *ArgsRec) {

assert(Result.empty() && "Result vector is not empty");
ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();
Expand All @@ -3144,7 +3145,7 @@ bool TGParser::ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
Init *Value = ParseValue(CurRec, ItemType);
if (!Value)
return true;
Result.push_back(Value);
Result.push_back(ArgumentInit::get(Value));

if (consume(tgtok::greater)) // end of argument list?
return false;
Expand Down Expand Up @@ -4247,23 +4248,22 @@ bool TGParser::ParseFile() {
// inheritance, multiclass invocation, or anonymous class invocation.
// If necessary, replace an argument with a cast to the required type.
// The argument count has already been checked.
bool TGParser::CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
SMLoc Loc, Record *ArgsRec) {

bool TGParser::CheckTemplateArgValues(
SmallVectorImpl<llvm::ArgumentInit *> &Values, SMLoc Loc, Record *ArgsRec) {
ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();

for (unsigned I = 0, E = Values.size(); I < E; ++I) {
RecordVal *Arg = ArgsRec->getValue(TArgs[I]);
RecTy *ArgType = Arg->getType();
auto *Value = Values[I];

if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value)) {
if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value->getValue())) {
auto *CastValue = ArgValue->getCastTo(ArgType);
if (CastValue) {
assert((!isa<TypedInit>(CastValue) ||
cast<TypedInit>(CastValue)->getType()->typeIsA(ArgType)) &&
"result of template arg value cast has wrong type");
Values[I] = CastValue;
Values[I] = ArgumentInit::get(CastValue);
} else {
PrintFatalError(Loc,
"Value specified for template argument '" +
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/TableGen/TGParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ class TGParser {

using ArgValueHandler = std::function<void(Init *, Init *)>;
bool resolveArguments(
Record *Rec, ArrayRef<Init *> ArgValues, SMLoc Loc,
Record *Rec, ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc,
ArgValueHandler ArgValueHandler = [](Init *, Init *) {});
bool resolveArgumentsOfClass(MapResolver &R, Record *Rec,
ArrayRef<Init *> ArgValues, SMLoc Loc);
ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc);
bool resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
ArrayRef<Init *> ArgValues, Init *DefmName,
SMLoc Loc);
ArrayRef<ArgumentInit *> ArgValues,
Init *DefmName, SMLoc Loc);

private: // Parser methods.
bool consume(tgtok::TokKind K);
Expand Down Expand Up @@ -288,7 +288,7 @@ class TGParser {
IDParseMode Mode = ParseValueMode);
void ParseValueList(SmallVectorImpl<llvm::Init*> &Result,
Record *CurRec, RecTy *ItemType = nullptr);
bool ParseTemplateArgValueList(SmallVectorImpl<llvm::Init *> &Result,
bool ParseTemplateArgValueList(SmallVectorImpl<llvm::ArgumentInit *> &Result,
Record *CurRec, Record *ArgsRec);
void ParseDagArgList(
SmallVectorImpl<std::pair<llvm::Init*, StringInit*>> &Result,
Expand All @@ -312,7 +312,7 @@ class TGParser {
MultiClass *ParseMultiClassID();
bool ApplyLetStack(Record *CurRec);
bool ApplyLetStack(RecordsEntry &Entry);
bool CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
bool CheckTemplateArgValues(SmallVectorImpl<llvm::ArgumentInit *> &Values,
SMLoc Loc, Record *ArgsRec);
};

Expand Down

0 comments on commit 6251adc

Please sign in to comment.