320 changes: 196 additions & 124 deletions clang/include/clang/Basic/riscv_vector.td

Large diffs are not rendered by default.

173 changes: 121 additions & 52 deletions clang/include/clang/Basic/riscv_vector_common.td
Original file line number Diff line number Diff line change
Expand Up @@ -458,52 +458,91 @@ let HasMaskedOffOperand = false in {
["vx", "Uv", "UvUvUeUv"]]>;
}
multiclass RVVFloatingTerBuiltinSet {
defm "" : RVVOutOp1BuiltinSet<NAME, "xfd",
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vv", "v", "vvvv"],
["vf", "v", "vvev"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vv", "v", "vvvv"],
["vf", "v", "vvev"]]>;
}
multiclass RVVFloatingTerBuiltinSetRoundingMode {
defm "" : RVVOutOp1BuiltinSet<NAME, "xfd",
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vv", "v", "vvvvu"],
["vf", "v", "vvevu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vv", "v", "vvvvu"],
["vf", "v", "vvevu"]]>;
}
}

let HasMaskedOffOperand = false, Log2LMUL = [-2, -1, 0, 1, 2] in {
multiclass RVVFloatingWidenTerBuiltinSet {
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "xf",
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "f",
[["vv", "w", "wwvv"],
["vf", "w", "wwev"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "x",
[["vv", "w", "wwvv"],
["vf", "w", "wwev"]]>;
}
multiclass RVVFloatingWidenTerBuiltinSetRoundingMode {
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "xf",
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "f",
[["vv", "w", "wwvvu"],
["vf", "w", "wwevu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1Op2BuiltinSet<NAME, "x",
[["vv", "w", "wwvvu"],
["vf", "w", "wwevu"]]>;
}
}

multiclass RVVFloatingBinBuiltinSet
: RVVOutOp1BuiltinSet<NAME, "xfd",
[["vv", "v", "vvv"],
["vf", "v", "vve"]]>;
multiclass RVVFloatingBinBuiltinSet {
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vv", "v", "vvv"],
["vf", "v", "vve"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vv", "v", "vvv"],
["vf", "v", "vve"]]>;
}

multiclass RVVFloatingBinBuiltinSetRoundingMode
: RVVOutOp1BuiltinSet<NAME, "xfd",
[["vv", "v", "vvvu"],
["vf", "v", "vveu"]]>;
multiclass RVVFloatingBinBuiltinSetRoundingMode {
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vv", "v", "vvvu"],
["vf", "v", "vveu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vv", "v", "vvvu"],
["vf", "v", "vveu"]]>;
}

multiclass RVVFloatingBinVFBuiltinSet
: RVVOutOp1BuiltinSet<NAME, "xfd",
[["vf", "v", "vve"]]>;
multiclass RVVFloatingBinVFBuiltinSet {
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vf", "v", "vve"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vf", "v", "vve"]]>;
}

multiclass RVVFloatingBinVFBuiltinSetRoundingMode
: RVVOutOp1BuiltinSet<NAME, "xfd",
[["vf", "v", "vveu"]]>;
multiclass RVVFloatingBinVFBuiltinSetRoundingMode {
defm "" : RVVOutOp1BuiltinSet<NAME, "fd",
[["vf", "v", "vveu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp1BuiltinSet<NAME, "x",
[["vf", "v", "vveu"]]>;
}

multiclass RVVFloatingMaskOutBuiltinSet
: RVVOp0Op1BuiltinSet<NAME, "xfd",
[["vv", "vm", "mvv"],
["vf", "vm", "mve"]]>;
multiclass RVVFloatingMaskOutBuiltinSet {
defm "" : RVVOp0Op1BuiltinSet<NAME, "fd",
[["vv", "vm", "mvv"],
["vf", "vm", "mve"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOp0Op1BuiltinSet<NAME, "x",
[["vv", "vm", "mvv"],
["vf", "vm", "mve"]]>;
}

multiclass RVVFloatingMaskOutVFBuiltinSet
: RVVOp0Op1BuiltinSet<NAME, "fd",
Expand Down Expand Up @@ -547,8 +586,11 @@ class RVVMaskOp0Builtin<string prototype> : RVVOp0Builtin<"m", prototype, "c"> {
let UnMaskedPolicyScheme = HasPolicyOperand,
HasMaskedOffOperand = false in {
multiclass RVVSlideUpBuiltinSet {
defm "" : RVVOutBuiltinSet<NAME, "csilxfd",
defm "" : RVVOutBuiltinSet<NAME, "csilfd",
[["vx","v", "vvvz"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutBuiltinSet<NAME, "x",
[["vx","v", "vvvz"]]>;
defm "" : RVVOutBuiltinSet<NAME, "csil",
[["vx","Uv", "UvUvUvz"]]>;
}
Expand All @@ -569,21 +611,16 @@ let UnMaskedPolicyScheme = HasPassthruOperand,
IntrinsicTypes = {ResultType, Ops.back()->getType()};
}] in {
multiclass RVVSlideDownBuiltinSet {
defm "" : RVVOutBuiltinSet<NAME, "csilxfd",
defm "" : RVVOutBuiltinSet<NAME, "csilfd",
[["vx","v", "vvz"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutBuiltinSet<NAME, "x",
[["vx","v", "vvz"]]>;
defm "" : RVVOutBuiltinSet<NAME, "csil",
[["vx","Uv", "UvUvz"]]>;
}
}

class RVVFloatingUnaryBuiltin<string builtin_suffix, string ir_suffix,
string prototype>
: RVVOutBuiltin<ir_suffix, prototype, "xfd"> {
let Name = NAME # "_" # builtin_suffix;
}

class RVVFloatingUnaryVVBuiltin : RVVFloatingUnaryBuiltin<"v", "v", "vv">;

class RVVConvBuiltin<string suffix, string prototype, string type_range,
string overloaded_name>
: RVVBuiltin<suffix, prototype, type_range> {
Expand Down Expand Up @@ -619,20 +656,32 @@ let HasMaskedOffOperand = true in {
[["vs", "UvUSv", "USvUvUSv"]]>;
}
multiclass RVVFloatingReductionBuiltin {
defm "" : RVVOutOp0BuiltinSet<NAME, "xfd",
defm "" : RVVOutOp0BuiltinSet<NAME, "fd",
[["vs", "vSv", "SvvSv"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp0BuiltinSet<NAME, "x",
[["vs", "vSv", "SvvSv"]]>;
}
multiclass RVVFloatingReductionBuiltinRoundingMode {
defm "" : RVVOutOp0BuiltinSet<NAME, "xfd",
defm "" : RVVOutOp0BuiltinSet<NAME, "fd",
[["vs", "vSv", "SvvSvu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp0BuiltinSet<NAME, "x",
[["vs", "vSv", "SvvSvu"]]>;
}
multiclass RVVFloatingWidenReductionBuiltin {
defm "" : RVVOutOp0BuiltinSet<NAME, "xf",
defm "" : RVVOutOp0BuiltinSet<NAME, "f",
[["vs", "vSw", "SwvSw"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp0BuiltinSet<NAME, "x",
[["vs", "vSw", "SwvSw"]]>;
}
multiclass RVVFloatingWidenReductionBuiltinRoundingMode {
defm "" : RVVOutOp0BuiltinSet<NAME, "xf",
defm "" : RVVOutOp0BuiltinSet<NAME, "f",
[["vs", "vSw", "SwvSwu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVOutOp0BuiltinSet<NAME, "x",
[["vs", "vSw", "SwvSwu"]]>;
}
}

Expand Down Expand Up @@ -692,22 +741,42 @@ multiclass RVVUnsignedWidenOp0BinBuiltinSet
[["wv", "Uw", "UwUwUv"],
["wx", "Uw", "UwUwUe"]]>;

multiclass RVVFloatingWidenBinBuiltinSet
: RVVWidenBuiltinSet<NAME, "xf",
[["vv", "w", "wvv"],
["vf", "w", "wve"]]>;
multiclass RVVFloatingWidenBinBuiltinSet {
defm "" : RVVWidenBuiltinSet<NAME, "f",
[["vv", "w", "wvv"],
["vf", "w", "wve"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVWidenBuiltinSet<NAME, "x",
[["vv", "w", "wvv"],
["vf", "w", "wve"]]>;
}

multiclass RVVFloatingWidenBinBuiltinSetRoundingMode
: RVVWidenBuiltinSet<NAME, "xf",
[["vv", "w", "wvvu"],
["vf", "w", "wveu"]]>;
multiclass RVVFloatingWidenBinBuiltinSetRoundingMode {
defm "" : RVVWidenBuiltinSet<NAME, "f",
[["vv", "w", "wvvu"],
["vf", "w", "wveu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVWidenBuiltinSet<NAME, "x",
[["vv", "w", "wvvu"],
["vf", "w", "wveu"]]>;
}

multiclass RVVFloatingWidenOp0BinBuiltinSet
: RVVWidenWOp0BuiltinSet<NAME # "_w", "xf",
[["wv", "w", "wwv"],
["wf", "w", "wwe"]]>;
multiclass RVVFloatingWidenOp0BinBuiltinSet {
defm "" : RVVWidenWOp0BuiltinSet<NAME # "_w", "f",
[["wv", "w", "wwv"],
["wf", "w", "wwe"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVWidenWOp0BuiltinSet<NAME # "_w", "x",
[["wv", "w", "wwv"],
["wf", "w", "wwe"]]>;
}

multiclass RVVFloatingWidenOp0BinBuiltinSetRoundingMode
: RVVWidenWOp0BuiltinSet<NAME # "_w", "xf",
[["wv", "w", "wwvu"],
["wf", "w", "wweu"]]>;
multiclass RVVFloatingWidenOp0BinBuiltinSetRoundingMode {
defm "" : RVVWidenWOp0BuiltinSet<NAME # "_w", "f",
[["wv", "w", "wwvu"],
["wf", "w", "wweu"]]>;
let RequiredFeatures = ["Zvfh"] in
defm "" : RVVWidenWOp0BuiltinSet<NAME # "_w", "x",
[["wv", "w", "wwvu"],
["wf", "w", "wweu"]]>;
}
8 changes: 5 additions & 3 deletions clang/include/clang/CIR/CIRGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
namespace clang {
class DeclGroupRef;
class DiagnosticsEngine;
namespace CIRGen {
class CIRGenModule;
} // namespace CIRGen
} // namespace clang

namespace mlir {
class MLIRContext;
} // namespace mlir
namespace cir {
class CIRGenModule;

class CIRGenerator : public clang::ASTConsumer {
virtual void anchor();
clang::DiagnosticsEngine &diags;
Expand All @@ -44,7 +45,7 @@ class CIRGenerator : public clang::ASTConsumer {

protected:
std::unique_ptr<mlir::MLIRContext> mlirCtx;
std::unique_ptr<CIRGenModule> cgm;
std::unique_ptr<clang::CIRGen::CIRGenModule> cgm;

public:
CIRGenerator(clang::DiagnosticsEngine &diags,
Expand All @@ -53,6 +54,7 @@ class CIRGenerator : public clang::ASTConsumer {
~CIRGenerator() override;
void Initialize(clang::ASTContext &astCtx) override;
bool HandleTopLevelDecl(clang::DeclGroupRef group) override;
mlir::ModuleOp getModule() const;
};

} // namespace cir
Expand Down
21 changes: 21 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,25 @@
#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRDIALECT_H
#define LLVM_CLANG_CIR_DIALECT_IR_CIRDIALECT_H

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "clang/CIR/Dialect/IR/CIROpsDialect.h.inc"

// TableGen'erated files for MLIR dialects require that a macro be defined when
// they are included. GET_OP_CLASSES tells the file to define the classes for
// the operations of that dialect.
#define GET_OP_CLASSES
#include "clang/CIR/Dialect/IR/CIROps.h.inc"

#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRDIALECT_H
14 changes: 8 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def CIR_Dialect : Dialect {
let summary = "A high-level dialect for analyzing and optimizing Clang "
"supported languages";

let cppNamespace = "::mlir::cir";
let cppNamespace = "::cir";

let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
Expand All @@ -31,13 +31,15 @@ def CIR_Dialect : Dialect {
void registerAttributes();
void registerTypes();

Type parseType(DialectAsmParser &parser) const override;
void printType(Type type, DialectAsmPrinter &printer) const override;
mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
void printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const override;

Attribute parseAttribute(DialectAsmParser &parser,
Type type) const override;
mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser,
mlir::Type type) const override;

void printAttribute(Attribute attr, DialectAsmPrinter &os) const override;
void printAttribute(mlir::Attribute attr,
mlir::DialectAsmPrinter &os) const override;
}];
}

Expand Down
82 changes: 82 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,86 @@

include "clang/CIR/Dialect/IR/CIRDialect.td"

include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// CIR Ops
//===----------------------------------------------------------------------===//

// LLVMLoweringInfo is used by cir-tablegen to generate LLVM lowering logic
// automatically for CIR operations. The `llvmOp` field gives the name of the
// LLVM IR dialect operation that the CIR operation will be lowered to. The
// input arguments of the CIR operation will be passed in the same order to the
// lowered LLVM IR operation.
//
// Example:
//
// For the following CIR operation definition:
//
// def FooOp : CIR_Op<"foo"> {
// // ...
// let arguments = (ins CIR_AnyType:$arg1, CIR_AnyType:$arg2);
// let llvmOp = "BarOp";
// }
//
// cir-tablegen will generate LLVM lowering code for the FooOp similar to the
// following:
//
// class CIRFooOpLowering
// : public mlir::OpConversionPattern<cir::FooOp> {
// public:
// using OpConversionPattern<cir::FooOp>::OpConversionPattern;
//
// mlir::LogicalResult matchAndRewrite(
// cir::FooOp op,
// OpAdaptor adaptor,
// mlir::ConversionPatternRewriter &rewriter) const override {
// rewriter.replaceOpWithNewOp<mlir::LLVM::BarOp>(
// op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
// return mlir::success();
// }
// }
//
// If you want fully customized LLVM IR lowering logic, simply exclude the
// `llvmOp` field from your CIR operation definition.
class LLVMLoweringInfo {
string llvmOp = "";
}

class CIR_Op<string mnemonic, list<Trait> traits = []> :
Op<CIR_Dialect, mnemonic, traits>, LLVMLoweringInfo;

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

// TODO(CIR): For starters, cir.func has only name, nothing else. The other
// properties of a function will be added over time as more of ClangIR is
// upstreamed.

def FuncOp : CIR_Op<"func"> {
let summary = "Declare or define a function";
let description = [{
... lots of text to be added later ...
}];

let arguments = (ins SymbolNameAttr:$sym_name);

let skipDefaultBuilders = 1;

let builders = [OpBuilder<(ins "llvm::StringRef":$name)>];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // LLVM_CLANG_CIR_DIALECT_IR_CIROPS
6 changes: 2 additions & 4 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -4930,10 +4930,6 @@ def msave_restore : Flag<["-"], "msave-restore">, Group<m_riscv_Features_Group>,
HelpText<"Enable using library calls for save and restore">;
def mno_save_restore : Flag<["-"], "mno-save-restore">, Group<m_riscv_Features_Group>,
HelpText<"Disable using library calls for save and restore">;
def mforced_sw_shadow_stack : Flag<["-"], "mforced-sw-shadow-stack">, Group<m_riscv_Features_Group>,
HelpText<"Force using software shadow stack when shadow-stack enabled">;
def mno_forced_sw_shadow_stack : Flag<["-"], "mno-forced-sw-shadow-stack">, Group<m_riscv_Features_Group>,
HelpText<"Not force using software shadow stack when shadow-stack enabled">;
} // let Flags = [TargetSpecific]
let Flags = [TargetSpecific] in {
def menable_experimental_extensions : Flag<["-"], "menable-experimental-extensions">, Group<m_Group>,
Expand Down Expand Up @@ -6289,6 +6285,8 @@ def mno_80387 : Flag<["-"], "mno-80387">, Alias<mno_x87>;
def mno_fp_ret_in_387 : Flag<["-"], "mno-fp-ret-in-387">, Alias<mno_x87>;
def mmmx : Flag<["-"], "mmmx">, Group<m_x86_Features_Group>;
def mno_mmx : Flag<["-"], "mno-mmx">, Group<m_x86_Features_Group>;
def mamx_avx512 : Flag<["-"], "mamx-avx512">, Group<m_x86_Features_Group>;
def mno_amx_avx512 : Flag<["-"], "mno-amx-avx512">, Group<m_x86_Features_Group>;
def mamx_bf16 : Flag<["-"], "mamx-bf16">, Group<m_x86_Features_Group>;
def mno_amx_bf16 : Flag<["-"], "mno-amx-bf16">, Group<m_x86_Features_Group>;
def mamx_complex : Flag<["-"], "mamx-complex">, Group<m_x86_Features_Group>;
Expand Down
27 changes: 21 additions & 6 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,8 @@ class Sema final : public SemaBase {

bool BuiltinNonDeterministicValue(CallExpr *TheCall);

bool BuiltinCountedByRef(CallExpr *TheCall);

// Matrix builtin handling.
ExprResult BuiltinMatrixTranspose(CallExpr *TheCall, ExprResult CallResult);
ExprResult BuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
Expand Down Expand Up @@ -11328,9 +11330,9 @@ class Sema final : public SemaBase {
CXXScopeSpec &SS, IdentifierInfo *Name, SourceLocation NameLoc,
const ParsedAttributesView &Attr, TemplateParameterList *TemplateParams,
AccessSpecifier AS, SourceLocation ModulePrivateLoc,
SourceLocation FriendLoc,
ArrayRef<TemplateParameterList *> OuterTemplateParamLists,
bool IsMemberSpecialization, SkipBodyInfo *SkipBody = nullptr);
SourceLocation FriendLoc, unsigned NumOuterTemplateParamLists,
TemplateParameterList **OuterTemplateParamLists,
SkipBodyInfo *SkipBody = nullptr);

/// Translates template arguments as provided by the parser
/// into template arguments used by semantic analysis.
Expand Down Expand Up @@ -11369,8 +11371,7 @@ class Sema final : public SemaBase {
DeclResult ActOnVarTemplateSpecialization(
Scope *S, Declarator &D, TypeSourceInfo *DI, LookupResult &Previous,
SourceLocation TemplateKWLoc, TemplateParameterList *TemplateParams,
StorageClass SC, bool IsPartialSpecialization,
bool IsMemberSpecialization);
StorageClass SC, bool IsPartialSpecialization);

/// Get the specialization of the given variable template corresponding to
/// the specified argument list, or a null-but-valid result if the arguments
Expand Down Expand Up @@ -13012,14 +13013,28 @@ class Sema final : public SemaBase {
/// dealing with a specialization. This is only relevant for function
/// template specializations.
///
/// \param Pattern If non-NULL, indicates the pattern from which we will be
/// instantiating the definition of the given declaration, \p ND. This is
/// used to determine the proper set of template instantiation arguments for
/// friend function template specializations.
///
/// \param ForConstraintInstantiation when collecting arguments,
/// ForConstraintInstantiation indicates we should continue looking when
/// encountering a lambda generic call operator, and continue looking for
/// arguments on an enclosing class template.
///
/// \param SkipForSpecialization when specified, any template specializations
/// in a traversal would be ignored.
/// \param ForDefaultArgumentSubstitution indicates we should continue looking
/// when encountering a specialized member function template, rather than
/// returning immediately.
MultiLevelTemplateArgumentList getTemplateInstantiationArgs(
const NamedDecl *D, const DeclContext *DC = nullptr, bool Final = false,
std::optional<ArrayRef<TemplateArgument>> Innermost = std::nullopt,
bool RelativeToPrimary = false, bool ForConstraintInstantiation = false);
bool RelativeToPrimary = false, const FunctionDecl *Pattern = nullptr,
bool ForConstraintInstantiation = false,
bool SkipForSpecialization = false,
bool ForDefaultArgumentSubstitution = false);

/// RAII object to handle the state changes required to synthesize
/// a function body.
Expand Down
68 changes: 59 additions & 9 deletions clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ class OpenACCClause;

class SemaOpenACC : public SemaBase {
private:
/// A collection of loop constructs in the compute construct scope that
/// haven't had their 'parent' compute construct set yet. Entires will only be
/// made to this list in the case where we know the loop isn't an orphan.
llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;

struct ComputeConstructInfo {
/// Which type of compute construct we are inside of, which we can use to
/// determine whether we should add loops to the above collection. We can
Expand Down Expand Up @@ -118,6 +113,43 @@ class SemaOpenACC : public SemaBase {
/// 'loop' clause enforcement, where this is 'blocked' by a compute construct.
llvm::SmallVector<OpenACCReductionClause *> ActiveReductionClauses;

// Type to check the info about the 'for stmt'.
struct ForStmtBeginChecker {
SemaOpenACC &SemaRef;
SourceLocation ForLoc;
bool IsRangeFor = false;
std::optional<const CXXForRangeStmt *> RangeFor = nullptr;
const Stmt *Init = nullptr;
bool InitChanged = false;
std::optional<const Stmt *> Cond = nullptr;
std::optional<const Stmt *> Inc = nullptr;
// Prevent us from checking 2x, which can happen with collapse & tile.
bool AlreadyChecked = false;

ForStmtBeginChecker(SemaOpenACC &SemaRef, SourceLocation ForLoc,
std::optional<const CXXForRangeStmt *> S)
: SemaRef(SemaRef), ForLoc(ForLoc), IsRangeFor(true), RangeFor(S) {}

ForStmtBeginChecker(SemaOpenACC &SemaRef, SourceLocation ForLoc,
const Stmt *I, bool InitChanged,
std::optional<const Stmt *> C,
std::optional<const Stmt *> Inc)
: SemaRef(SemaRef), ForLoc(ForLoc), IsRangeFor(false), Init(I),
InitChanged(InitChanged), Cond(C), Inc(Inc) {}
// Do the checking for the For/Range-For. Currently this implements the 'not
// seq' restrictions only, and should be called either if we know we are a
// top-level 'for' (the one associated via associated-stmt), or extended via
// 'collapse'.
void check();

const ValueDecl *checkInit();
void checkCond();
void checkInc(const ValueDecl *Init);
};

/// Helper function for checking the 'for' and 'range for' stmts.
void ForStmtBeginHelper(SourceLocation ForLoc, ForStmtBeginChecker &C);

public:
ComputeConstructInfo &getActiveComputeConstructInfo() {
return ActiveComputeConstructInfo;
Expand All @@ -137,6 +169,11 @@ class SemaOpenACC : public SemaBase {
/// permits us to implement the restriction of no further 'gang', 'vector', or
/// 'worker' clauses.
SourceLocation LoopVectorClauseLoc;
/// If there is a current 'active' loop construct that does NOT have a 'seq'
/// clause on it, this has that source location. This permits us to implement
/// the 'loop' restrictions on the loop variable. This can be extended via
/// 'collapse', so we need to keep this around for a while.
SourceLocation LoopWithoutSeqLoc;

// Redeclaration of the version in OpenACCClause.h.
using DeviceTypeArgument = std::pair<IdentifierInfo *, SourceLocation>;
Expand Down Expand Up @@ -568,8 +605,19 @@ class SemaOpenACC : public SemaBase {
void ActOnWhileStmt(SourceLocation WhileLoc);
// Called when we encounter a 'do' statement, before looking at its 'body'.
void ActOnDoStmt(SourceLocation DoLoc);
// Called when we encounter a 'for' statement, before looking at its 'body',
// for the 'range-for'. 'ActOnForStmtEnd' is used after the body.
void ActOnRangeForStmtBegin(SourceLocation ForLoc, const Stmt *OldRangeFor,
const Stmt *RangeFor);
void ActOnRangeForStmtBegin(SourceLocation ForLoc, const Stmt *RangeFor);
// Called when we encounter a 'for' statement, before looking at its 'body'.
void ActOnForStmtBegin(SourceLocation ForLoc);
// 'ActOnForStmtEnd' is used after the body.
void ActOnForStmtBegin(SourceLocation ForLoc, const Stmt *First,
const Stmt *Second, const Stmt *Third);
void ActOnForStmtBegin(SourceLocation ForLoc, const Stmt *OldFirst,
const Stmt *First, const Stmt *OldSecond,
const Stmt *Second, const Stmt *OldThird,
const Stmt *Third);
// Called when we encounter a 'for' statement, after we've consumed/checked
// the body. This is necessary for a number of checks on the contents of the
// 'for' statement.
Expand Down Expand Up @@ -598,7 +646,9 @@ class SemaOpenACC : public SemaBase {
/// Called when we encounter an associated statement for our construct, this
/// should check legality of the statement as it appertains to this Construct.
StmtResult ActOnAssociatedStmt(SourceLocation DirectiveLoc,
OpenACCDirectiveKind K, StmtResult AssocStmt);
OpenACCDirectiveKind K,
ArrayRef<const OpenACCClause *> Clauses,
StmtResult AssocStmt);

/// Called after the directive has been completely parsed, including the
/// declaration group or associated statement.
Expand Down Expand Up @@ -712,12 +762,12 @@ class SemaOpenACC : public SemaBase {
SourceLocation OldLoopGangClauseOnKernelLoc;
SourceLocation OldLoopWorkerClauseLoc;
SourceLocation OldLoopVectorClauseLoc;
llvm::SmallVector<OpenACCLoopConstruct *> ParentlessLoopConstructs;
SourceLocation OldLoopWithoutSeqLoc;
llvm::SmallVector<OpenACCReductionClause *> ActiveReductionClauses;
LoopInConstructRAII LoopRAII;

public:
AssociatedStmtRAII(SemaOpenACC &, OpenACCDirectiveKind,
AssociatedStmtRAII(SemaOpenACC &, OpenACCDirectiveKind, SourceLocation,
ArrayRef<const OpenACCClause *>,
ArrayRef<OpenACCClause *>);
void SetCollapseInfoBeforeAssociatedStmt(
Expand Down
9 changes: 5 additions & 4 deletions clang/include/clang/Sema/SemaOpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,7 @@ class SemaOpenMP : public SemaBase {
SourceLocation OmpAllMemoryLoc;
SourceLocation
StepModifierLoc; /// 'step' modifier location for linear clause
OpenMPAllocateClauseModifier AllocClauseModifier = OMPC_ALLOCATE_unknown;
};

OMPClause *ActOnOpenMPVarListClause(OpenMPClauseKind Kind,
Expand All @@ -1165,10 +1166,10 @@ class SemaOpenMP : public SemaBase {
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'allocate' clause.
OMPClause *
ActOnOpenMPAllocateClause(Expr *Allocator, ArrayRef<Expr *> VarList,
SourceLocation StartLoc, SourceLocation ColonLoc,
SourceLocation LParenLoc, SourceLocation EndLoc);
OMPClause *ActOnOpenMPAllocateClause(
Expr *Allocator, OpenMPAllocateClauseModifier ACModifier,
ArrayRef<Expr *> VarList, SourceLocation StartLoc,
SourceLocation ColonLoc, SourceLocation LParenLoc, SourceLocation EndLoc);
/// Called on well-formed 'private' clause.
OMPClause *ActOnOpenMPPrivateClause(ArrayRef<Expr *> VarList,
SourceLocation StartLoc,
Expand Down
7 changes: 4 additions & 3 deletions clang/include/clang/Serialization/ASTRecordWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class ASTRecordWriter

public:
/// Construct a ASTRecordWriter that uses the default encoding scheme.
ASTRecordWriter(ASTWriter &W, ASTWriter::RecordDataImpl &Record)
: DataStreamBasicWriter(W.getASTContext()), Writer(&W), Record(&Record) {}
ASTRecordWriter(ASTContext &Context, ASTWriter &W,
ASTWriter::RecordDataImpl &Record)
: DataStreamBasicWriter(Context), Writer(&W), Record(&Record) {}

/// Construct a ASTRecordWriter that uses the same encoding scheme as another
/// ASTRecordWriter.
Expand Down Expand Up @@ -208,7 +209,7 @@ class ASTRecordWriter

/// Emit a reference to a type.
void AddTypeRef(QualType T) {
return Writer->AddTypeRef(T, *Record);
return Writer->AddTypeRef(getASTContext(), T, *Record);
}
void writeQualType(QualType T) {
AddTypeRef(T);
Expand Down
39 changes: 17 additions & 22 deletions clang/include/clang/Serialization/ASTWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@ class ASTWriter : public ASTDeserializationListener,
/// The PCM manager which manages memory buffers for pcm files.
InMemoryModuleCache &ModuleCache;

/// The ASTContext we're writing.
ASTContext *Context = nullptr;

/// The preprocessor we're writing.
Preprocessor *PP = nullptr;

Expand Down Expand Up @@ -499,6 +496,9 @@ class ASTWriter : public ASTDeserializationListener,

/// Mapping from a source location entry to whether it is affecting or not.
llvm::BitVector IsSLocAffecting;
/// Mapping from a source location entry to whether it must be included as
/// input file.
llvm::BitVector IsSLocFileEntryAffecting;

/// Mapping from \c FileID to an index into the FileID adjustment table.
std::vector<FileID> NonAffectingFileIDs;
Expand Down Expand Up @@ -545,55 +545,55 @@ class ASTWriter : public ASTDeserializationListener,
unsigned getSubmoduleID(Module *Mod);

/// Write the given subexpression to the bitstream.
void WriteSubStmt(Stmt *S);
void WriteSubStmt(ASTContext &Context, Stmt *S);

void WriteBlockInfoBlock();
void WriteControlBlock(Preprocessor &PP, ASTContext &Context,
StringRef isysroot);
void WriteControlBlock(Preprocessor &PP, StringRef isysroot);

/// Write out the signature and diagnostic options, and return the signature.
void writeUnhashedControlBlock(Preprocessor &PP, ASTContext &Context);
void writeUnhashedControlBlock(Preprocessor &PP);
ASTFileSignature backpatchSignature();

/// Calculate hash of the pcm content.
std::pair<ASTFileSignature, ASTFileSignature> createSignature() const;
ASTFileSignature createSignatureForNamedModule() const;

void WriteInputFiles(SourceManager &SourceMgr, HeaderSearchOptions &HSOpts);
void WriteSourceManagerBlock(SourceManager &SourceMgr,
const Preprocessor &PP);
void WriteSourceManagerBlock(SourceManager &SourceMgr);
void WritePreprocessor(const Preprocessor &PP, bool IsModule);
void WriteHeaderSearch(const HeaderSearch &HS);
void WritePreprocessorDetail(PreprocessingRecord &PPRec,
uint64_t MacroOffsetsBase);
void WriteSubmodules(Module *WritingModule);
void WriteSubmodules(Module *WritingModule, ASTContext &Context);

void WritePragmaDiagnosticMappings(const DiagnosticsEngine &Diag,
bool isModule);

unsigned TypeExtQualAbbrev = 0;
void WriteTypeAbbrevs();
void WriteType(QualType T);
void WriteType(ASTContext &Context, QualType T);

bool isLookupResultExternal(StoredDeclsList &Result, DeclContext *DC);

void GenerateNameLookupTable(const DeclContext *DC,
void GenerateNameLookupTable(ASTContext &Context, const DeclContext *DC,
llvm::SmallVectorImpl<char> &LookupTable);
uint64_t WriteDeclContextLexicalBlock(ASTContext &Context,
const DeclContext *DC);
uint64_t WriteDeclContextVisibleBlock(ASTContext &Context, DeclContext *DC);
void WriteTypeDeclOffsets();
void WriteFileDeclIDsMap();
void WriteComments();
void WriteComments(ASTContext &Context);
void WriteSelectors(Sema &SemaRef);
void WriteReferencedSelectorsPool(Sema &SemaRef);
void WriteIdentifierTable(Preprocessor &PP, IdentifierResolver &IdResolver,
bool IsModule);
void WriteDeclAndTypes(ASTContext &Context);
void PrepareWritingSpecialDecls(Sema &SemaRef);
void WriteSpecialDeclRecords(Sema &SemaRef);
void WriteDeclUpdatesBlocks(RecordDataImpl &OffsetsRecord);
void WriteDeclContextVisibleUpdate(const DeclContext *DC);
void WriteDeclUpdatesBlocks(ASTContext &Context,
RecordDataImpl &OffsetsRecord);
void WriteDeclContextVisibleUpdate(ASTContext &Context,
const DeclContext *DC);
void WriteFPPragmaOptions(const FPOptionsOverride &Opts);
void WriteOpenCLExtensions(Sema &SemaRef);
void WriteCUDAPragmas(Sema &SemaRef);
Expand Down Expand Up @@ -655,11 +655,6 @@ class ASTWriter : public ASTDeserializationListener,
bool GeneratingReducedBMI = false);
~ASTWriter() override;

ASTContext &getASTContext() const {
assert(Context && "requested AST context when not writing AST");
return *Context;
}

const LangOptions &getLangOpts() const;

/// Get a timestamp for output into the AST file. The actual timestamp
Expand Down Expand Up @@ -725,10 +720,10 @@ class ASTWriter : public ASTDeserializationListener,
uint32_t getMacroDirectivesOffset(const IdentifierInfo *Name);

/// Emit a reference to a type.
void AddTypeRef(QualType T, RecordDataImpl &Record);
void AddTypeRef(ASTContext &Context, QualType T, RecordDataImpl &Record);

/// Force a type to be emitted and get its ID.
serialization::TypeID GetOrCreateTypeID(QualType T);
serialization::TypeID GetOrCreateTypeID(ASTContext &Context, QualType T);

/// Find the first local declaration of a given local redeclarable
/// decl.
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,10 @@ def UncountedCallArgsChecker : Checker<"UncountedCallArgsChecker">,
HelpText<"Check uncounted call arguments.">,
Documentation<HasDocumentation>;

def UncheckedCallArgsChecker : Checker<"UncheckedCallArgsChecker">,
HelpText<"Check unchecked call arguments.">,
Documentation<HasDocumentation>;

def UncountedLocalVarsChecker : Checker<"UncountedLocalVarsChecker">,
HelpText<"Check uncounted local variables.">,
Documentation<HasDocumentation>;
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/APINotes/APINotesFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const uint16_t VERSION_MAJOR = 0;
/// API notes file minor version number.
///
/// When the format changes IN ANY WAY, this number should be incremented.
const uint16_t VERSION_MINOR = 31; // lifetimebound
const uint16_t VERSION_MINOR =
32; // implicit parameter support (at position -1)

const uint8_t kSwiftCopyable = 1;
const uint8_t kSwiftNonCopyable = 2;
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/APINotes/APINotesReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
//===----------------------------------------------------------------------===//
#include "clang/APINotes/APINotesReader.h"
#include "APINotesFormat.h"
#include "clang/APINotes/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Bitstream/BitstreamReader.h"
Expand Down Expand Up @@ -396,12 +397,19 @@ class ObjCMethodTableInfo
const uint8_t *&Data) {
ObjCMethodInfo Info;
uint8_t Payload = *Data++;
bool HasSelf = Payload & 0x01;
Payload >>= 1;
Info.RequiredInit = Payload & 0x01;
Payload >>= 1;
Info.DesignatedInit = Payload & 0x01;
Payload >>= 1;
assert(Payload == 0 && "Unable to fully decode 'Payload'.");

ReadFunctionInfo(Data, Info);
if (HasSelf) {
Info.Self = ParamInfo{};
ReadParamInfo(Data, *Info.Self);
}
return Info;
}
};
Expand Down Expand Up @@ -516,7 +524,17 @@ class CXXMethodTableInfo
static CXXMethodInfo readUnversioned(internal_key_type Key,
const uint8_t *&Data) {
CXXMethodInfo Info;

uint8_t Payload = *Data++;
bool HasThis = Payload & 0x01;
Payload >>= 1;
assert(Payload == 0 && "Unable to fully decode 'Payload'.");

ReadFunctionInfo(Data, Info);
if (HasThis) {
Info.This = ParamInfo{};
ReadParamInfo(Data, *Info.This);
}
return Info;
}
};
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/APINotes/APINotesTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@ LLVM_DUMP_METHOD void FunctionInfo::dump(llvm::raw_ostream &OS) const {

LLVM_DUMP_METHOD void ObjCMethodInfo::dump(llvm::raw_ostream &OS) {
static_cast<FunctionInfo &>(*this).dump(OS);
if (Self)
Self->dump(OS);
OS << (DesignatedInit ? "[DesignatedInit] " : "")
<< (RequiredInit ? "[RequiredInit] " : "") << '\n';
}

LLVM_DUMP_METHOD void CXXMethodInfo::dump(llvm::raw_ostream &OS) {
static_cast<FunctionInfo &>(*this).dump(OS);
if (This)
This->dump(OS);
}

LLVM_DUMP_METHOD void TagInfo::dump(llvm::raw_ostream &OS) {
static_cast<CommonTypeInfo &>(*this).dump(OS);
if (HasFlagEnum)
Expand Down
29 changes: 24 additions & 5 deletions clang/lib/APINotes/APINotesWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ namespace {
unsigned getVariableInfoSize(const VariableInfo &VI) {
return 2 + getCommonEntityInfoSize(VI) + 2 + VI.getType().size();
}
unsigned getParamInfoSize(const ParamInfo &PI);

/// Emit a serialized representation of the variable information.
void emitVariableInfo(raw_ostream &OS, const VariableInfo &VI) {
Expand Down Expand Up @@ -737,6 +738,7 @@ void APINotesWriter::Implementation::writeObjCPropertyBlock(
namespace {
unsigned getFunctionInfoSize(const FunctionInfo &);
void emitFunctionInfo(llvm::raw_ostream &, const FunctionInfo &);
void emitParamInfo(raw_ostream &OS, const ParamInfo &PI);

/// Used to serialize the on-disk Objective-C method table.
class ObjCMethodTableInfo
Expand All @@ -760,17 +762,24 @@ class ObjCMethodTableInfo
}

unsigned getUnversionedInfoSize(const ObjCMethodInfo &OMI) {
return getFunctionInfoSize(OMI) + 1;
auto size = getFunctionInfoSize(OMI) + 1;
if (OMI.Self)
size += getParamInfoSize(*OMI.Self);
return size;
}

void emitUnversionedInfo(raw_ostream &OS, const ObjCMethodInfo &OMI) {
uint8_t flags = 0;
llvm::support::endian::Writer writer(OS, llvm::endianness::little);
flags = (flags << 1) | OMI.DesignatedInit;
flags = (flags << 1) | OMI.RequiredInit;
flags = (flags << 1) | static_cast<bool>(OMI.Self);
writer.write<uint8_t>(flags);

emitFunctionInfo(OS, OMI);

if (OMI.Self)
emitParamInfo(OS, *OMI.Self);
}
};

Expand All @@ -793,12 +802,22 @@ class CXXMethodTableInfo
return static_cast<size_t>(key.hashValue());
}

unsigned getUnversionedInfoSize(const CXXMethodInfo &OMI) {
return getFunctionInfoSize(OMI);
unsigned getUnversionedInfoSize(const CXXMethodInfo &MI) {
auto size = getFunctionInfoSize(MI) + 1;
if (MI.This)
size += getParamInfoSize(*MI.This);
return size;
}

void emitUnversionedInfo(raw_ostream &OS, const CXXMethodInfo &OMI) {
emitFunctionInfo(OS, OMI);
void emitUnversionedInfo(raw_ostream &OS, const CXXMethodInfo &MI) {
uint8_t flags = 0;
llvm::support::endian::Writer writer(OS, llvm::endianness::little);
flags = (flags << 1) | static_cast<bool>(MI.This);
writer.write<uint8_t>(flags);

emitFunctionInfo(OS, MI);
if (MI.This)
emitParamInfo(OS, *MI.This);
}
};
} // namespace
Expand Down
28 changes: 21 additions & 7 deletions clang/lib/APINotes/APINotesYAMLCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/Support/VersionTuple.h"
#include "llvm/Support/YAMLTraits.h"
#include <optional>
#include <type_traits>
#include <vector>

using namespace clang;
Expand Down Expand Up @@ -68,7 +69,7 @@ template <> struct ScalarEnumerationTraits<MethodKind> {

namespace {
struct Param {
unsigned Position;
int Position;
std::optional<bool> NoEscape = false;
std::optional<bool> Lifetimebound = false;
std::optional<NullabilityKind> Nullability;
Expand Down Expand Up @@ -730,7 +731,8 @@ class YAMLConverter {
}
}

void convertParams(const ParamsSeq &Params, FunctionInfo &OutInfo) {
void convertParams(const ParamsSeq &Params, FunctionInfo &OutInfo,
std::optional<ParamInfo> &thisOrSelf) {
for (const auto &P : Params) {
ParamInfo PI;
if (P.Nullability)
Expand All @@ -739,9 +741,14 @@ class YAMLConverter {
PI.setLifetimebound(P.Lifetimebound);
PI.setType(std::string(P.Type));
PI.setRetainCountConvention(P.RetainCountConvention);
if (OutInfo.Params.size() <= P.Position)
if (static_cast<int>(OutInfo.Params.size()) <= P.Position)
OutInfo.Params.resize(P.Position + 1);
OutInfo.Params[P.Position] |= PI;
if (P.Position == -1)
thisOrSelf = PI;
else if (P.Position >= 0)
OutInfo.Params[P.Position] |= PI;
else
emitError("invalid parameter position " + llvm::itostr(P.Position));
}
}

Expand Down Expand Up @@ -818,7 +825,7 @@ class YAMLConverter {
MI.ResultType = std::string(M.ResultType);

// Translate parameter information.
convertParams(M.Params, MI);
convertParams(M.Params, MI, MI.Self);

// Translate nullability info.
convertNullability(M.Nullability, M.NullabilityOfRet, MI, M.Selector);
Expand Down Expand Up @@ -926,11 +933,18 @@ class YAMLConverter {
TheNamespace.Items, SwiftVersion);
}

void convertFunction(const Function &Function, FunctionInfo &FI) {
template <typename FuncOrMethodInfo>
void convertFunction(const Function &Function, FuncOrMethodInfo &FI) {
convertAvailability(Function.Availability, FI, Function.Name);
FI.setSwiftPrivate(Function.SwiftPrivate);
FI.SwiftName = std::string(Function.SwiftName);
convertParams(Function.Params, FI);
std::optional<ParamInfo> This;
convertParams(Function.Params, FI, This);
if constexpr (std::is_same_v<FuncOrMethodInfo, CXXMethodInfo>)
FI.This = This;
else if (This)
emitError("implicit instance parameter is only permitted on C++ and "
"Objective-C methods");
convertNullability(Function.Nullability, Function.NullabilityOfRet, FI,
Function.Name);
FI.ResultType = std::string(Function.ResultType);
Expand Down
16 changes: 10 additions & 6 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1592,14 +1592,17 @@ ASTContext::setInstantiatedFromUsingShadowDecl(UsingShadowDecl *Inst,
InstantiatedFromUsingShadowDecl[Inst] = Pattern;
}

FieldDecl *ASTContext::getInstantiatedFromUnnamedFieldDecl(FieldDecl *Field) {
FieldDecl *
ASTContext::getInstantiatedFromUnnamedFieldDecl(FieldDecl *Field) const {
return InstantiatedFromUnnamedFieldDecl.lookup(Field);
}

void ASTContext::setInstantiatedFromUnnamedFieldDecl(FieldDecl *Inst,
FieldDecl *Tmpl) {
assert(!Inst->getDeclName() && "Instantiated field decl is not unnamed");
assert(!Tmpl->getDeclName() && "Template field decl is not unnamed");
assert((!Inst->getDeclName() || Inst->isPlaceholderVar(getLangOpts())) &&
"Instantiated field decl is not unnamed");
assert((!Inst->getDeclName() || Inst->isPlaceholderVar(getLangOpts())) &&
"Template field decl is not unnamed");
assert(!InstantiatedFromUnnamedFieldDecl[Inst] &&
"Already noted what unnamed field was instantiated from");

Expand Down Expand Up @@ -5300,10 +5303,11 @@ QualType ASTContext::getHLSLAttributedResourceType(
/// Retrieve a substitution-result type.
QualType ASTContext::getSubstTemplateTypeParmType(
QualType Replacement, Decl *AssociatedDecl, unsigned Index,
std::optional<unsigned> PackIndex) const {
std::optional<unsigned> PackIndex,
SubstTemplateTypeParmTypeFlag Flag) const {
llvm::FoldingSetNodeID ID;
SubstTemplateTypeParmType::Profile(ID, Replacement, AssociatedDecl, Index,
PackIndex);
PackIndex, Flag);
void *InsertPos = nullptr;
SubstTemplateTypeParmType *SubstParm =
SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
Expand All @@ -5313,7 +5317,7 @@ QualType ASTContext::getSubstTemplateTypeParmType(
!Replacement.isCanonical()),
alignof(SubstTemplateTypeParmType));
SubstParm = new (Mem) SubstTemplateTypeParmType(Replacement, AssociatedDecl,
Index, PackIndex);
Index, PackIndex, Flag);
Types.push_back(SubstParm);
SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
}
Expand Down
7 changes: 3 additions & 4 deletions clang/lib/AST/ASTImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1628,8 +1628,8 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
return ToReplacementTypeOrErr.takeError();

return Importer.getToContext().getSubstTemplateTypeParmType(
*ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(),
T->getPackIndex());
*ToReplacementTypeOrErr, *ReplacedOrErr, T->getIndex(), T->getPackIndex(),
T->getSubstitutionFlag());
}

ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
Expand Down Expand Up @@ -6190,8 +6190,7 @@ ExpectedDecl ASTNodeImporter::VisitClassTemplateDecl(ClassTemplateDecl *D) {
ExpectedDecl ASTNodeImporter::VisitClassTemplateSpecializationDecl(
ClassTemplateSpecializationDecl *D) {
ClassTemplateDecl *ClassTemplate;
if (Error Err = importInto(ClassTemplate,
D->getSpecializedTemplate()->getCanonicalDecl()))
if (Error Err = importInto(ClassTemplate, D->getSpecializedTemplate()))
return std::move(Err);

// Import the context of this declaration.
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/ByteCode/DynamicAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class DynamicAllocator final {
private:
llvm::DenseMap<const Expr *, AllocationSite> AllocationSites;

using PoolAllocTy = llvm::BumpPtrAllocatorImpl<llvm::MallocAllocator>;
using PoolAllocTy = llvm::BumpPtrAllocator;
PoolAllocTy DescAllocator;

/// Allocates a new descriptor.
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/AST/ByteCode/Program.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class Program final {
llvm::DenseMap<const void *, unsigned> NativePointerIndices;

/// Custom allocator for global storage.
using PoolAllocTy = llvm::BumpPtrAllocatorImpl<llvm::MallocAllocator>;
using PoolAllocTy = llvm::BumpPtrAllocator;

/// Descriptor + storage for a global object.
///
Expand Down
42 changes: 23 additions & 19 deletions clang/lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1738,13 +1738,12 @@ void NamedDecl::printNestedNameSpecifier(raw_ostream &OS,

// Suppress inline namespace if it doesn't make the result ambiguous.
if (Ctx->isInlineNamespace() && NameInScope) {
bool isRedundant =
cast<NamespaceDecl>(Ctx)->isRedundantInlineQualifierFor(NameInScope);
if (P.SuppressInlineNamespace ==
PrintingPolicy::SuppressInlineNamespaceMode::All ||
(P.SuppressInlineNamespace ==
PrintingPolicy::SuppressInlineNamespaceMode::Redundant &&
isRedundant)) {
cast<NamespaceDecl>(Ctx)->isRedundantInlineQualifierFor(
NameInScope))) {
continue;
}
}
Expand Down Expand Up @@ -2709,20 +2708,20 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const {
auto From = VDTemplSpec->getInstantiatedFrom();
if (auto *VTD = From.dyn_cast<VarTemplateDecl *>()) {
while (!VTD->isMemberSpecialization()) {
if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate())
VTD = NewVTD;
else
auto *NewVTD = VTD->getInstantiatedFromMemberTemplate();
if (!NewVTD)
break;
VTD = NewVTD;
}
return getDefinitionOrSelf(VTD->getTemplatedDecl());
}
if (auto *VTPSD =
From.dyn_cast<VarTemplatePartialSpecializationDecl *>()) {
while (!VTPSD->isMemberSpecialization()) {
if (auto *NewVTPSD = VTPSD->getInstantiatedFromMember())
VTPSD = NewVTPSD;
else
auto *NewVTPSD = VTPSD->getInstantiatedFromMember();
if (!NewVTPSD)
break;
VTPSD = NewVTPSD;
}
return getDefinitionOrSelf<VarDecl>(VTPSD);
}
Expand All @@ -2731,14 +2730,15 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const {

// If this is the pattern of a variable template, find where it was
// instantiated from. FIXME: Is this necessary?
if (VarTemplateDecl *VTD = VD->getDescribedVarTemplate()) {
while (!VTD->isMemberSpecialization()) {
if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate())
VTD = NewVTD;
else
if (VarTemplateDecl *VarTemplate = VD->getDescribedVarTemplate()) {
while (!VarTemplate->isMemberSpecialization()) {
auto *NewVT = VarTemplate->getInstantiatedFromMemberTemplate();
if (!NewVT)
break;
VarTemplate = NewVT;
}
return getDefinitionOrSelf(VTD->getTemplatedDecl());

return getDefinitionOrSelf(VarTemplate->getTemplatedDecl());
}

if (VD == this)
Expand Down Expand Up @@ -3656,6 +3656,10 @@ unsigned FunctionDecl::getBuiltinID(bool ConsiderWrapperFunctions) const {
(!hasAttr<ArmBuiltinAliasAttr>() && !hasAttr<BuiltinAliasAttr>()))
return 0;

if (getASTContext().getLangOpts().CPlusPlus &&
BuiltinID == Builtin::BI__builtin_counted_by_ref)
return 0;

const ASTContext &Context = getASTContext();
if (!Context.BuiltinInfo.isPredefinedLibFunction(BuiltinID))
return BuiltinID;
Expand Down Expand Up @@ -4154,10 +4158,10 @@ FunctionDecl::getTemplateInstantiationPattern(bool ForDefinition) const {
// If we hit a point where the user provided a specialization of this
// template, we're done looking.
while (!ForDefinition || !Primary->isMemberSpecialization()) {
if (auto *NewPrimary = Primary->getInstantiatedFromMemberTemplate())
Primary = NewPrimary;
else
auto *NewPrimary = Primary->getInstantiatedFromMemberTemplate();
if (!NewPrimary)
break;
Primary = NewPrimary;
}

return getDefinitionOrSelf(Primary->getTemplatedDecl());
Expand All @@ -4170,7 +4174,7 @@ FunctionTemplateDecl *FunctionDecl::getPrimaryTemplate() const {
if (FunctionTemplateSpecializationInfo *Info
= TemplateOrSpecialization
.dyn_cast<FunctionTemplateSpecializationInfo*>()) {
return Info->getTemplate()->getMostRecentDecl();
return Info->getTemplate();
}
return nullptr;
}
Expand Down
14 changes: 6 additions & 8 deletions clang/lib/AST/DeclCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2030,21 +2030,19 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const {
if (auto *TD = dyn_cast<ClassTemplateSpecializationDecl>(this)) {
auto From = TD->getInstantiatedFrom();
if (auto *CTD = From.dyn_cast<ClassTemplateDecl *>()) {
while (!CTD->isMemberSpecialization()) {
if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate())
CTD = NewCTD;
else
while (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) {
if (NewCTD->isMemberSpecialization())
break;
CTD = NewCTD;
}
return GetDefinitionOrSelf(CTD->getTemplatedDecl());
}
if (auto *CTPSD =
From.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) {
while (!CTPSD->isMemberSpecialization()) {
if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate())
CTPSD = NewCTPSD;
else
while (auto *NewCTPSD = CTPSD->getInstantiatedFromMember()) {
if (NewCTPSD->isMemberSpecialization())
break;
CTPSD = NewCTPSD;
}
return GetDefinitionOrSelf(CTPSD);
}
Expand Down
86 changes: 18 additions & 68 deletions clang/lib/AST/DeclTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,35 +320,35 @@ bool TemplateDecl::isTypeAlias() const {
void RedeclarableTemplateDecl::anchor() {}

RedeclarableTemplateDecl::CommonBase *RedeclarableTemplateDecl::getCommonPtr() const {
if (CommonBase *C = getCommonPtrInternal())
return C;
if (Common)
return Common;

// Walk the previous-declaration chain until we either find a declaration
// with a common pointer or we run out of previous declarations.
SmallVector<const RedeclarableTemplateDecl *, 2> PrevDecls;
for (const RedeclarableTemplateDecl *Prev = getPreviousDecl(); Prev;
Prev = Prev->getPreviousDecl()) {
if (CommonBase *C = Prev->getCommonPtrInternal()) {
setCommonPtr(C);
if (Prev->Common) {
Common = Prev->Common;
break;
}

PrevDecls.push_back(Prev);
}

// If we never found a common pointer, allocate one now.
if (!getCommonPtrInternal()) {
if (!Common) {
// FIXME: If any of the declarations is from an AST file, we probably
// need an update record to add the common data.

setCommonPtr(newCommon(getASTContext()));
Common = newCommon(getASTContext());
}

// Update any previous declarations we saw with the common pointer.
for (const RedeclarableTemplateDecl *Prev : PrevDecls)
Prev->setCommonPtr(getCommonPtrInternal());
Prev->Common = Common;

return getCommonPtrInternal();
return Common;
}

void RedeclarableTemplateDecl::loadLazySpecializationsImpl() const {
Expand Down Expand Up @@ -458,17 +458,19 @@ void FunctionTemplateDecl::addSpecialization(
}

void FunctionTemplateDecl::mergePrevDecl(FunctionTemplateDecl *Prev) {
using Base = RedeclarableTemplateDecl;

// If we haven't created a common pointer yet, then it can just be created
// with the usual method.
if (!getCommonPtrInternal())
if (!Base::Common)
return;

Common *ThisCommon = static_cast<Common *>(getCommonPtrInternal());
Common *ThisCommon = static_cast<Common *>(Base::Common);
Common *PrevCommon = nullptr;
SmallVector<FunctionTemplateDecl *, 8> PreviousDecls;
for (; Prev; Prev = Prev->getPreviousDecl()) {
if (CommonBase *C = Prev->getCommonPtrInternal()) {
PrevCommon = static_cast<Common *>(C);
if (Prev->Base::Common) {
PrevCommon = static_cast<Common *>(Prev->Base::Common);
break;
}
PreviousDecls.push_back(Prev);
Expand All @@ -478,15 +480,15 @@ void FunctionTemplateDecl::mergePrevDecl(FunctionTemplateDecl *Prev) {
// use this common pointer.
if (!PrevCommon) {
for (auto *D : PreviousDecls)
D->setCommonPtr(ThisCommon);
D->Base::Common = ThisCommon;
return;
}

// Ensure we don't leak any important state.
assert(ThisCommon->Specializations.size() == 0 &&
"Can't merge incompatible declarations!");

setCommonPtr(PrevCommon);
Base::Common = PrevCommon;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -993,17 +995,7 @@ ClassTemplateSpecializationDecl::getSpecializedTemplate() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization*>())
return PartialSpec->PartialSpecialization->getSpecializedTemplate();
return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl();
}

llvm::PointerUnion<ClassTemplateDecl *,
ClassTemplatePartialSpecializationDecl *>
ClassTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getMostRecentDecl();

return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl();
return SpecializedTemplate.get<ClassTemplateDecl*>();
}

SourceRange
Expand Down Expand Up @@ -1293,39 +1285,6 @@ VarTemplateDecl::newCommon(ASTContext &C) const {
return CommonPtr;
}

void VarTemplateDecl::mergePrevDecl(VarTemplateDecl *Prev) {
// If we haven't created a common pointer yet, then it can just be created
// with the usual method.
if (!getCommonPtrInternal())
return;

Common *ThisCommon = static_cast<Common *>(getCommonPtrInternal());
Common *PrevCommon = nullptr;
SmallVector<VarTemplateDecl *, 8> PreviousDecls;
for (; Prev; Prev = Prev->getPreviousDecl()) {
if (CommonBase *C = Prev->getCommonPtrInternal()) {
PrevCommon = static_cast<Common *>(C);
break;
}
PreviousDecls.push_back(Prev);
}

// If the previous redecl chain hasn't created a common pointer yet, then just
// use this common pointer.
if (!PrevCommon) {
for (auto *D : PreviousDecls)
D->setCommonPtr(ThisCommon);
return;
}

// Ensure we don't leak any important state.
assert(ThisCommon->Specializations.empty() &&
ThisCommon->PartialSpecializations.empty() &&
"Can't merge incompatible declarations!");

setCommonPtr(PrevCommon);
}

VarTemplateSpecializationDecl *
VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args,
void *&InsertPos) {
Expand Down Expand Up @@ -1448,16 +1407,7 @@ VarTemplateDecl *VarTemplateSpecializationDecl::getSpecializedTemplate() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getSpecializedTemplate();
return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl();
}

llvm::PointerUnion<VarTemplateDecl *, VarTemplatePartialSpecializationDecl *>
VarTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const {
if (const auto *PartialSpec =
SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>())
return PartialSpec->PartialSpecialization->getMostRecentDecl();

return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl();
return SpecializedTemplate.get<VarTemplateDecl *>();
}

SourceRange VarTemplateSpecializationDecl::getSourceRange() const {
Expand Down
23 changes: 18 additions & 5 deletions clang/lib/AST/OpenMPClause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,12 +1023,17 @@ OMPPartialClause *OMPPartialClause::CreateEmpty(const ASTContext &C) {
OMPAllocateClause *
OMPAllocateClause::Create(const ASTContext &C, SourceLocation StartLoc,
SourceLocation LParenLoc, Expr *Allocator,
SourceLocation ColonLoc, SourceLocation EndLoc,
ArrayRef<Expr *> VL) {
SourceLocation ColonLoc,
OpenMPAllocateClauseModifier AllocatorModifier,
SourceLocation AllocatorModifierLoc,
SourceLocation EndLoc, ArrayRef<Expr *> VL) {

// Allocate space for private variables and initializer expressions.
void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size()));
auto *Clause = new (Mem) OMPAllocateClause(StartLoc, LParenLoc, Allocator,
ColonLoc, EndLoc, VL.size());
auto *Clause = new (Mem) OMPAllocateClause(
StartLoc, LParenLoc, Allocator, ColonLoc, AllocatorModifier,
AllocatorModifierLoc, EndLoc, VL.size());

Clause->setVarRefs(VL);
return Clause;
}
Expand Down Expand Up @@ -2242,9 +2247,17 @@ void OMPClausePrinter::VisitOMPAllocateClause(OMPAllocateClause *Node) {
if (Node->varlist_empty())
return;
OS << "allocate";
OpenMPAllocateClauseModifier Modifier = Node->getAllocatorModifier();
if (Expr *Allocator = Node->getAllocator()) {
OS << "(";
Allocator->printPretty(OS, nullptr, Policy, 0);
if (Modifier == OMPC_ALLOCATE_allocator) {
OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(), Modifier);
OS << "(";
Allocator->printPretty(OS, nullptr, Policy, 0);
OS << ")";
} else {
Allocator->printPretty(OS, nullptr, Policy, 0);
}
OS << ":";
VisitOMPClauseList(Node, ' ');
} else {
Expand Down
56 changes: 11 additions & 45 deletions clang/lib/AST/StmtOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,44 +28,15 @@ OpenACCComputeConstruct::CreateEmpty(const ASTContext &C, unsigned NumClauses) {
OpenACCComputeConstruct *OpenACCComputeConstruct::Create(
const ASTContext &C, OpenACCDirectiveKind K, SourceLocation BeginLoc,
SourceLocation DirLoc, SourceLocation EndLoc,
ArrayRef<const OpenACCClause *> Clauses, Stmt *StructuredBlock,
ArrayRef<OpenACCLoopConstruct *> AssociatedLoopConstructs) {
ArrayRef<const OpenACCClause *> Clauses, Stmt *StructuredBlock) {
void *Mem = C.Allocate(
OpenACCComputeConstruct::totalSizeToAlloc<const OpenACCClause *>(
Clauses.size()));
auto *Inst = new (Mem) OpenACCComputeConstruct(K, BeginLoc, DirLoc, EndLoc,
Clauses, StructuredBlock);

llvm::for_each(AssociatedLoopConstructs, [&](OpenACCLoopConstruct *C) {
C->setParentComputeConstruct(Inst);
});

return Inst;
}

void OpenACCComputeConstruct::findAndSetChildLoops() {
struct LoopConstructFinder : RecursiveASTVisitor<LoopConstructFinder> {
OpenACCComputeConstruct *Construct = nullptr;

LoopConstructFinder(OpenACCComputeConstruct *Construct)
: Construct(Construct) {}

bool TraverseOpenACCComputeConstruct(OpenACCComputeConstruct *C) {
// Stop searching if we find a compute construct.
return true;
}
bool TraverseOpenACCLoopConstruct(OpenACCLoopConstruct *C) {
// Stop searching if we find a loop construct, after taking ownership of
// it.
C->setParentComputeConstruct(Construct);
return true;
}
};

LoopConstructFinder f(this);
f.TraverseStmt(getAssociatedStmt());
}

OpenACCLoopConstruct::OpenACCLoopConstruct(unsigned NumClauses)
: OpenACCAssociatedStmtConstruct(
OpenACCLoopConstructClass, OpenACCDirectiveKind::Loop,
Expand All @@ -79,11 +50,13 @@ OpenACCLoopConstruct::OpenACCLoopConstruct(unsigned NumClauses)
}

OpenACCLoopConstruct::OpenACCLoopConstruct(
SourceLocation Start, SourceLocation DirLoc, SourceLocation End,
OpenACCDirectiveKind ParentKind, SourceLocation Start,
SourceLocation DirLoc, SourceLocation End,
ArrayRef<const OpenACCClause *> Clauses, Stmt *Loop)
: OpenACCAssociatedStmtConstruct(OpenACCLoopConstructClass,
OpenACCDirectiveKind::Loop, Start, DirLoc,
End, Loop) {
End, Loop),
ParentComputeConstructKind(ParentKind) {
// accept 'nullptr' for the loop. This is diagnosed somewhere, but this gives
// us some level of AST fidelity in the error case.
assert((Loop == nullptr || isa<ForStmt, CXXForRangeStmt>(Loop)) &&
Expand All @@ -96,12 +69,6 @@ OpenACCLoopConstruct::OpenACCLoopConstruct(
Clauses.size()));
}

void OpenACCLoopConstruct::setLoop(Stmt *Loop) {
assert((isa<ForStmt, CXXForRangeStmt>(Loop)) &&
"Associated Loop not a for loop?");
setAssociatedStmt(Loop);
}

OpenACCLoopConstruct *OpenACCLoopConstruct::CreateEmpty(const ASTContext &C,
unsigned NumClauses) {
void *Mem =
Expand All @@ -111,15 +78,14 @@ OpenACCLoopConstruct *OpenACCLoopConstruct::CreateEmpty(const ASTContext &C,
return Inst;
}

OpenACCLoopConstruct *
OpenACCLoopConstruct::Create(const ASTContext &C, SourceLocation BeginLoc,
SourceLocation DirLoc, SourceLocation EndLoc,
ArrayRef<const OpenACCClause *> Clauses,
Stmt *Loop) {
OpenACCLoopConstruct *OpenACCLoopConstruct::Create(
const ASTContext &C, OpenACCDirectiveKind ParentKind,
SourceLocation BeginLoc, SourceLocation DirLoc, SourceLocation EndLoc,
ArrayRef<const OpenACCClause *> Clauses, Stmt *Loop) {
void *Mem =
C.Allocate(OpenACCLoopConstruct::totalSizeToAlloc<const OpenACCClause *>(
Clauses.size()));
auto *Inst =
new (Mem) OpenACCLoopConstruct(BeginLoc, DirLoc, EndLoc, Clauses, Loop);
auto *Inst = new (Mem)
OpenACCLoopConstruct(ParentKind, BeginLoc, DirLoc, EndLoc, Clauses, Loop);
return Inst;
}
2 changes: 1 addition & 1 deletion clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2928,7 +2928,7 @@ void TextNodeDumper::VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S) {
if (S->isOrphanedLoopConstruct())
OS << " <orphan>";
else
OS << " parent: " << S->getParentComputeConstruct();
OS << " parent: " << S->getParentComputeConstructKind();
}

void TextNodeDumper::VisitEmbedExpr(const EmbedExpr *S) {
Expand Down
6 changes: 5 additions & 1 deletion clang/lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4219,7 +4219,7 @@ static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,

SubstTemplateTypeParmType::SubstTemplateTypeParmType(
QualType Replacement, Decl *AssociatedDecl, unsigned Index,
std::optional<unsigned> PackIndex)
std::optional<unsigned> PackIndex, SubstTemplateTypeParmTypeFlag Flag)
: Type(SubstTemplateTypeParm, Replacement.getCanonicalType(),
Replacement->getDependence()),
AssociatedDecl(AssociatedDecl) {
Expand All @@ -4230,6 +4230,10 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(

SubstTemplateTypeParmTypeBits.Index = Index;
SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
SubstTemplateTypeParmTypeBits.SubstitutionFlag = llvm::to_underlying(Flag);
assert((Flag != SubstTemplateTypeParmTypeFlag::ExpandPacksInPlace ||
PackIndex) &&
"ExpandPacksInPlace needs a valid PackIndex");
assert(AssociatedDecl != nullptr);
}

Expand Down
23 changes: 21 additions & 2 deletions clang/lib/Basic/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "clang/Basic/ParsedAttrInfo.h"
#include "clang/Basic/TargetInfo.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"

using namespace clang;

static int hasAttributeImpl(AttributeCommonInfo::Syntax Syntax, StringRef Name,
Expand Down Expand Up @@ -153,12 +156,28 @@ std::string AttributeCommonInfo::getNormalizedFullName() const {
normalizeName(getAttrName(), getScopeName(), getSyntax()));
}

AttributeCommonInfo::Scope
getScopeFromNormalizedScopeName(StringRef ScopeName) {
return llvm::StringSwitch<AttributeCommonInfo::Scope>(ScopeName)
.Case("", AttributeCommonInfo::Scope::NONE)
.Case("clang", AttributeCommonInfo::Scope::CLANG)
.Case("gnu", AttributeCommonInfo::Scope::GNU)
.Case("gsl", AttributeCommonInfo::Scope::GSL)
.Case("hlsl", AttributeCommonInfo::Scope::HLSL)
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
.Case("omp", AttributeCommonInfo::Scope::OMP)
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
}

unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
// Both variables will be used in tablegen generated
// attribute spell list index matching code.
auto Syntax = static_cast<AttributeCommonInfo::Syntax>(getSyntax());
StringRef Scope = normalizeAttrScopeName(getScopeName(), Syntax);
StringRef Name = normalizeAttrName(getAttrName(), Scope, Syntax);
StringRef ScopeName = normalizeAttrScopeName(getScopeName(), Syntax);
StringRef Name = normalizeAttrName(getAttrName(), ScopeName, Syntax);

AttributeCommonInfo::Scope ComputedScope =
getScopeFromNormalizedScopeName(ScopeName);

#include "clang/Sema/AttrSpellingListIndex.inc"
}
17 changes: 15 additions & 2 deletions clang/lib/Basic/OpenMPKinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
return OMPC_NUMTASKS_unknown;
return Type;
}
case OMPC_allocate:
return llvm::StringSwitch<OpenMPAllocateClauseModifier>(Str)
#define OPENMP_ALLOCATE_MODIFIER(Name) .Case(#Name, OMPC_ALLOCATE_##Name)
#include "clang/Basic/OpenMPKinds.def"
.Default(OMPC_ALLOCATE_unknown);
case OMPC_unknown:
case OMPC_threadprivate:
case OMPC_if:
Expand All @@ -190,7 +195,6 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, StringRef Str,
case OMPC_sizes:
case OMPC_permutation:
case OMPC_allocator:
case OMPC_allocate:
case OMPC_collapse:
case OMPC_private:
case OMPC_firstprivate:
Expand Down Expand Up @@ -505,6 +509,16 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
#include "clang/Basic/OpenMPKinds.def"
}
llvm_unreachable("Invalid OpenMP 'num_tasks' clause modifier");
case OMPC_allocate:
switch (Type) {
case OMPC_ALLOCATE_unknown:
return "unknown";
#define OPENMP_ALLOCATE_MODIFIER(Name) \
case OMPC_ALLOCATE_##Name: \
return #Name;
#include "clang/Basic/OpenMPKinds.def"
}
llvm_unreachable("Invalid OpenMP 'allocate' clause modifier");
case OMPC_unknown:
case OMPC_threadprivate:
case OMPC_if:
Expand All @@ -515,7 +529,6 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
case OMPC_sizes:
case OMPC_permutation:
case OMPC_allocator:
case OMPC_allocate:
case OMPC_collapse:
case OMPC_private:
case OMPC_firstprivate:
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/Basic/SourceManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
Expand Down Expand Up @@ -2227,6 +2228,28 @@ LLVM_DUMP_METHOD void SourceManager::dump() const {
}
}

// 123 -> "123".
// 1234 -> "1.23k".
// 123456 -> "123.46k".
// 1234567 -> "1.23M".
// 1234567890 -> "1.23G".
// 1234567890123 -> "1.23T".
static std::string humanizeNumber(uint64_t Number) {
static constexpr std::array<std::pair<uint64_t, char>, 4> Units = {
{{1'000'000'000'000UL, 'T'},
{1'000'000'000UL, 'G'},
{1'000'000UL, 'M'},
{1'000UL, 'k'}}};

for (const auto &[UnitSize, UnitSign] : Units) {
if (Number >= UnitSize) {
return llvm::formatv("{0:F}{1}", Number / static_cast<double>(UnitSize),
UnitSign);
}
}
return std::to_string(Number);
}

void SourceManager::noteSLocAddressSpaceUsage(
DiagnosticsEngine &Diag, std::optional<unsigned> MaxNotes) const {
struct Info {
Expand Down Expand Up @@ -2296,22 +2319,27 @@ void SourceManager::noteSLocAddressSpaceUsage(
int UsagePercent = static_cast<int>(100.0 * double(LocalUsage + LoadedUsage) /
MaxLoadedOffset);
Diag.Report(SourceLocation(), diag::note_total_sloc_usage)
<< LocalUsage << LoadedUsage << (LocalUsage + LoadedUsage) << UsagePercent;
<< LocalUsage << humanizeNumber(LocalUsage) << LoadedUsage
<< humanizeNumber(LoadedUsage) << (LocalUsage + LoadedUsage)
<< humanizeNumber(LocalUsage + LoadedUsage) << UsagePercent;

// Produce notes on sloc address space usage for each file with a high usage.
uint64_t ReportedSize = 0;
for (auto &[Entry, FileInfo] :
llvm::make_range(SortedUsage.begin(), SortedEnd)) {
Diag.Report(FileInfo.Loc, diag::note_file_sloc_usage)
<< FileInfo.Inclusions << FileInfo.DirectSize
<< (FileInfo.TotalSize - FileInfo.DirectSize);
<< humanizeNumber(FileInfo.DirectSize)
<< (FileInfo.TotalSize - FileInfo.DirectSize)
<< humanizeNumber(FileInfo.TotalSize - FileInfo.DirectSize);
ReportedSize += FileInfo.TotalSize;
}

// Describe any remaining usage not reported in the per-file usage.
if (ReportedSize != CountedSize) {
Diag.Report(SourceLocation(), diag::note_file_misc_sloc_usage)
<< (SortedUsage.end() - SortedEnd) << CountedSize - ReportedSize;
<< (SortedUsage.end() - SortedEnd) << CountedSize - ReportedSize
<< humanizeNumber(CountedSize - ReportedSize);
}
}

Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Basic/Targets/AMDGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,14 @@ class LLVM_LIBRARY_VISIBILITY AMDGPUTargetInfo final : public TargetInfo {
}

bool hasHIPImageSupport() const override { return HasImage; }

std::pair<unsigned, unsigned> hardwareInterferenceSizes() const override {
// This is imprecise as the value can vary between 64, 128 (even 256!) bytes
// depending on the level of cache and the target architecture. We select
// the size that corresponds to the largest L1 cache line for all
// architectures.
return std::make_pair(128, 128);
}
};

} // namespace targets
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/Basic/Targets/Mips.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class LLVM_LIBRARY_VISIBILITY MipsTargetInfo : public TargetInfo {
if (ABI == "o32")
Layout = "m:m-p:32:32-i8:8:32-i16:16:32-i64:64-n32-S64";
else if (ABI == "n32")
Layout = "m:e-p:32:32-i8:8:32-i16:16:32-i64:64-n32:64-S128";
Layout = "m:e-p:32:32-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128";
else if (ABI == "n64")
Layout = "m:e-i8:8:32-i16:16:32-i64:64-n32:64-S128";
Layout = "m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128";
else
llvm_unreachable("Invalid ABI");

Expand Down
6 changes: 6 additions & 0 deletions clang/lib/Basic/Targets/X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ bool X86TargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasAMXFP8 = true;
} else if (Feature == "+amx-transpose") {
HasAMXTRANSPOSE = true;
} else if (Feature == "+amx-avx512") {
HasAMXAVX512 = true;
} else if (Feature == "+cmpccxadd") {
HasCMPCCXADD = true;
} else if (Feature == "+raoint") {
Expand Down Expand Up @@ -955,6 +957,8 @@ void X86TargetInfo::getTargetDefines(const LangOptions &Opts,
Builder.defineMacro("__AMX_FP8__");
if (HasAMXTRANSPOSE)
Builder.defineMacro("__AMX_TRANSPOSE__");
if (HasAMXAVX512)
Builder.defineMacro("__AMX_AVX512__");
if (HasCMPCCXADD)
Builder.defineMacro("__CMPCCXADD__");
if (HasRAOINT)
Expand Down Expand Up @@ -1080,6 +1084,7 @@ bool X86TargetInfo::isValidFeatureName(StringRef Name) const {
return llvm::StringSwitch<bool>(Name)
.Case("adx", true)
.Case("aes", true)
.Case("amx-avx512", true)
.Case("amx-bf16", true)
.Case("amx-complex", true)
.Case("amx-fp16", true)
Expand Down Expand Up @@ -1200,6 +1205,7 @@ bool X86TargetInfo::hasFeature(StringRef Feature) const {
return llvm::StringSwitch<bool>(Feature)
.Case("adx", HasADX)
.Case("aes", HasAES)
.Case("amx-avx512", HasAMXAVX512)
.Case("amx-bf16", HasAMXBF16)
.Case("amx-complex", HasAMXCOMPLEX)
.Case("amx-fp16", HasAMXFP16)
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Basic/Targets/X86.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class LLVM_LIBRARY_VISIBILITY X86TargetInfo : public TargetInfo {
bool HasAMXCOMPLEX = false;
bool HasAMXFP8 = false;
bool HasAMXTRANSPOSE = false;
bool HasAMXAVX512 = false;
bool HasSERIALIZE = false;
bool HasTSXLDTRK = false;
bool HasUSERMSR = false;
Expand Down
140 changes: 135 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,149 @@

#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclBase.h"
#include "clang/AST/GlobalDecl.h"
#include "clang/Basic/SourceManager.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

using namespace cir;
using namespace clang;
using namespace clang::CIRGen;

CIRGenModule::CIRGenModule(mlir::MLIRContext &context,
clang::ASTContext &astctx,
const clang::CodeGenOptions &cgo,
DiagnosticsEngine &diags)
: astCtx(astctx), langOpts(astctx.getLangOpts()),
theModule{mlir::ModuleOp::create(mlir::UnknownLoc())},
target(astCtx.getTargetInfo()) {}
: builder(&context), astCtx(astctx), langOpts(astctx.getLangOpts()),
theModule{mlir::ModuleOp::create(mlir::UnknownLoc::get(&context))},
diags(diags), target(astCtx.getTargetInfo()) {}

mlir::Location CIRGenModule::getLoc(SourceLocation cLoc) {
assert(cLoc.isValid() && "expected valid source location");
const SourceManager &sm = astCtx.getSourceManager();
PresumedLoc pLoc = sm.getPresumedLoc(cLoc);
StringRef filename = pLoc.getFilename();
return mlir::FileLineColLoc::get(builder.getStringAttr(filename),
pLoc.getLine(), pLoc.getColumn());
}

mlir::Location CIRGenModule::getLoc(SourceRange cRange) {
assert(cRange.isValid() && "expected a valid source range");
mlir::Location begin = getLoc(cRange.getBegin());
mlir::Location end = getLoc(cRange.getEnd());
mlir::Attribute metadata;
return mlir::FusedLoc::get({begin, end}, metadata, builder.getContext());
}

void CIRGenModule::buildGlobal(clang::GlobalDecl gd) {
const auto *global = cast<ValueDecl>(gd.getDecl());

if (const auto *fd = dyn_cast<FunctionDecl>(global)) {
// Update deferred annotations with the latest declaration if the function
// was already used or defined.
if (fd->hasAttr<AnnotateAttr>())
errorNYI(fd->getSourceRange(), "deferredAnnotations");
if (!fd->doesThisDeclarationHaveABody()) {
if (!fd->doesDeclarationForceExternallyVisibleDefinition())
return;

errorNYI(fd->getSourceRange(),
"function declaration that forces code gen");
return;
}
} else {
errorNYI(global->getSourceRange(), "global variable declaration");
}

// TODO(CIR): Defer emitting some global definitions until later
buildGlobalDefinition(gd);
}

void CIRGenModule::buildGlobalFunctionDefinition(clang::GlobalDecl gd,
mlir::Operation *op) {
auto const *funcDecl = cast<FunctionDecl>(gd.getDecl());
auto funcOp = builder.create<cir::FuncOp>(
getLoc(funcDecl->getSourceRange()), funcDecl->getIdentifier()->getName());
theModule.push_back(funcOp);
}

void CIRGenModule::buildGlobalDefinition(clang::GlobalDecl gd,
mlir::Operation *op) {
const auto *decl = cast<ValueDecl>(gd.getDecl());
if (const auto *fd = dyn_cast<FunctionDecl>(decl)) {
// TODO(CIR): Skip generation of CIR for functions with available_externally
// linkage at -O0.

if (const auto *method = dyn_cast<CXXMethodDecl>(decl)) {
// Make sure to emit the definition(s) before we emit the thunks. This is
// necessary for the generation of certain thunks.
(void)method;
errorNYI(method->getSourceRange(), "member function");
return;
}

if (fd->isMultiVersion())
errorNYI(fd->getSourceRange(), "multiversion functions");
buildGlobalFunctionDefinition(gd, op);
return;
}

llvm_unreachable("Invalid argument to CIRGenModule::buildGlobalDefinition");
}

// Emit code for a single top level declaration.
void CIRGenModule::buildTopLevelDecl(Decl *decl) {}
void CIRGenModule::buildTopLevelDecl(Decl *decl) {

// Ignore dependent declarations.
if (decl->isTemplated())
return;

switch (decl->getKind()) {
default:
errorNYI(decl->getBeginLoc(), "declaration of kind",
decl->getDeclKindName());
break;

case Decl::Function: {
auto *fd = cast<FunctionDecl>(decl);
// Consteval functions shouldn't be emitted.
if (!fd->isConsteval())
buildGlobal(fd);
break;
}
}
}

DiagnosticBuilder CIRGenModule::errorNYI(llvm::StringRef feature) {
unsigned diagID = diags.getCustomDiagID(
DiagnosticsEngine::Error, "ClangIR code gen Not Yet Implemented: %0");
return diags.Report(diagID) << feature;
}

DiagnosticBuilder CIRGenModule::errorNYI(SourceLocation loc,
llvm::StringRef feature) {
unsigned diagID = diags.getCustomDiagID(
DiagnosticsEngine::Error, "ClangIR code gen Not Yet Implemented: %0");
return diags.Report(loc, diagID) << feature;
}

DiagnosticBuilder CIRGenModule::errorNYI(SourceLocation loc,
llvm::StringRef feature,
llvm::StringRef name) {
unsigned diagID = diags.getCustomDiagID(
DiagnosticsEngine::Error, "ClangIR code gen Not Yet Implemented: %0: %1");
return diags.Report(loc, diagID) << feature << name;
}

DiagnosticBuilder CIRGenModule::errorNYI(SourceRange loc,
llvm::StringRef feature) {
return errorNYI(loc.getBegin(), feature) << loc;
}

DiagnosticBuilder CIRGenModule::errorNYI(SourceRange loc,
llvm::StringRef feature,
llvm::StringRef name) {
return errorNYI(loc.getBegin(), feature, name) << loc;
}
42 changes: 38 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@

#include "CIRGenTypeCache.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/StringRef.h"

namespace clang {
class ASTContext;
class CodeGenOptions;
class Decl;
class DiagnosticBuilder;
class DiagnosticsEngine;
class GlobalDecl;
class LangOptions;
class SourceLocation;
class SourceRange;
class TargetInfo;
} // namespace clang

using namespace clang;
namespace cir {
namespace CIRGen {

/// This class organizes the cross-function state that is used while generating
/// CIR code.
Expand All @@ -44,6 +48,10 @@ class CIRGenModule : public CIRGenTypeCache {
~CIRGenModule() = default;

private:
// TODO(CIR) 'builder' will change to CIRGenBuilderTy once that type is
// defined
mlir::OpBuilder builder;

/// Hold Clang AST information.
clang::ASTContext &astCtx;

Expand All @@ -52,11 +60,37 @@ class CIRGenModule : public CIRGenTypeCache {
/// A "module" matches a c/cpp source file: containing a list of functions.
mlir::ModuleOp theModule;

clang::DiagnosticsEngine &diags;

const clang::TargetInfo &target;

public:
mlir::ModuleOp getModule() const { return theModule; }

/// Helpers to convert the presumed location of Clang's SourceLocation to an
/// MLIR Location.
mlir::Location getLoc(clang::SourceLocation cLoc);
mlir::Location getLoc(clang::SourceRange cRange);

void buildTopLevelDecl(clang::Decl *decl);

/// Emit code for a single global function or variable declaration. Forward
/// declarations are emitted lazily.
void buildGlobal(clang::GlobalDecl gd);

void buildGlobalDefinition(clang::GlobalDecl gd,
mlir::Operation *op = nullptr);
void buildGlobalFunctionDefinition(clang::GlobalDecl gd, mlir::Operation *op);

/// Helpers to emit "not yet implemented" error diagnostics
DiagnosticBuilder errorNYI(llvm::StringRef);
DiagnosticBuilder errorNYI(SourceLocation, llvm::StringRef);
DiagnosticBuilder errorNYI(SourceLocation, llvm::StringRef, llvm::StringRef);
DiagnosticBuilder errorNYI(SourceRange, llvm::StringRef);
DiagnosticBuilder errorNYI(SourceRange, llvm::StringRef, llvm::StringRef);
};
} // namespace cir
} // namespace CIRGen

} // namespace clang

#endif // LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENMODULE_H
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenTypeCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef LLVM_CLANG_LIB_CIR_CIRGENTYPECACHE_H
#define LLVM_CLANG_LIB_CIR_CIRGENTYPECACHE_H

namespace cir {
namespace clang::CIRGen {

/// This structure provides a set of types that are commonly used
/// during IR emission. It's initialized once in CodeGenModule's
Expand All @@ -22,6 +22,6 @@ struct CIRGenTypeCache {
CIRGenTypeCache() = default;
};

} // namespace cir
} // namespace clang::CIRGen

#endif // LLVM_CLANG_LIB_CIR_CODEGEN_CIRGENTYPECACHE_H
10 changes: 9 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@

#include "CIRGenModule.h"

#include "mlir/IR/MLIRContext.h"

#include "clang/AST/DeclGroup.h"
#include "clang/CIR/CIRGenerator.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

using namespace cir;
using namespace clang;
Expand All @@ -31,9 +34,14 @@ void CIRGenerator::Initialize(ASTContext &astCtx) {

this->astCtx = &astCtx;

cgm = std::make_unique<CIRGenModule>(*mlirCtx, astCtx, codeGenOpts, diags);
mlirCtx = std::make_unique<mlir::MLIRContext>();
mlirCtx->loadDialect<cir::CIRDialect>();
cgm = std::make_unique<clang::CIRGen::CIRGenModule>(*mlirCtx.get(), astCtx,
codeGenOpts, diags);
}

mlir::ModuleOp CIRGenerator::getModule() const { return cgm->getModule(); }

bool CIRGenerator::HandleTopLevelDecl(DeclGroupRef group) {

for (Decl *decl : group)
Expand Down
38 changes: 38 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the attributes in the CIR dialect.
//
//===----------------------------------------------------------------------===//

#include "clang/CIR/Dialect/IR/CIRDialect.h"

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// General CIR parsing / printing
//===----------------------------------------------------------------------===//

Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
// No attributes yet to parse
return Attribute{};
}

void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
// No attributes yet to print
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//

void CIRDialect::registerAttributes() {
// No attributes yet to register
}
55 changes: 54 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,57 @@
//
//===----------------------------------------------------------------------===//

#include <clang/CIR/Dialect/IR/CIRDialect.h>
#include "clang/CIR/Dialect/IR/CIRDialect.h"

#include "mlir/Support/LogicalResult.h"

#include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc"

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//

void cir::CIRDialect::initialize() {
registerTypes();
registerAttributes();
addOperations<
#define GET_OP_LIST
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
>();
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
StringRef name) {
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
}

ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
state.attributes))
return failure();
return success();
}

void cir::FuncOp::print(OpAsmPrinter &p) {
p << ' ';
// For now the only property a function has is its name
p.printSymbolName(getSymName());
}

mlir::LogicalResult cir::FuncOp::verify() { return success(); }

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "clang/CIR/Dialect/IR/CIROps.cpp.inc"
37 changes: 37 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- CIRTypes.cpp - MLIR CIR Types --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types in the CIR dialect.
//
//===----------------------------------------------------------------------===//

#include "clang/CIR/Dialect/IR/CIRDialect.h"

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// General CIR parsing / printing
//===----------------------------------------------------------------------===//

Type CIRDialect::parseType(DialectAsmParser &parser) const {
// No types yet to parse
return Type{};
}

void CIRDialect::printType(Type type, DialectAsmPrinter &os) const {
// No types yet to print
}

//===----------------------------------------------------------------------===//
// CIR Dialect
//===----------------------------------------------------------------------===//

void CIRDialect::registerTypes() {
// No types yet to register
}
5 changes: 5 additions & 0 deletions clang/lib/CIR/Dialect/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
add_clang_library(MLIRCIR
CIRAttrs.cpp
CIRDialect.cpp
CIRTypes.cpp

LINK_LIBS PUBLIC
MLIRIR
)
41 changes: 40 additions & 1 deletion clang/lib/CIR/FrontendAction/CIRGenAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ class CIRGenConsumer : public clang::ASTConsumer {

virtual void anchor();

CIRGenAction::OutputType Action;

std::unique_ptr<raw_pwrite_stream> OutputStream;

ASTContext *Context{nullptr};
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;

Expand All @@ -37,14 +40,37 @@ class CIRGenConsumer : public clang::ASTConsumer {
const LangOptions &LangOptions,
const FrontendOptions &FEOptions,
std::unique_ptr<raw_pwrite_stream> OS)
: OutputStream(std::move(OS)), FS(VFS),
: Action(Action), OutputStream(std::move(OS)), FS(VFS),
Gen(std::make_unique<CIRGenerator>(DiagnosticsEngine, std::move(VFS),
CodeGenOptions)) {}

void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
Context = &Ctx;
Gen->Initialize(Ctx);
}

bool HandleTopLevelDecl(DeclGroupRef D) override {
Gen->HandleTopLevelDecl(D);
return true;
}

void HandleTranslationUnit(ASTContext &C) override {
Gen->HandleTranslationUnit(C);
mlir::ModuleOp MlirModule = Gen->getModule();
switch (Action) {
case CIRGenAction::OutputType::EmitCIR:
if (OutputStream && MlirModule) {
mlir::OpPrintingFlags Flags;
Flags.enableDebugInfo(/*enable=*/true, /*prettyForm=*/false);
MlirModule->print(*OutputStream, Flags);
}
break;
default:
llvm_unreachable("NYI: CIRGenAction other than EmitCIR");
break;
}
}
};
} // namespace cir

Expand All @@ -55,10 +81,23 @@ CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)

CIRGenAction::~CIRGenAction() { MLIRMod.release(); }

static std::unique_ptr<raw_pwrite_stream>
getOutputStream(CompilerInstance &CI, StringRef InFile,
CIRGenAction::OutputType Action) {
switch (Action) {
case CIRGenAction::OutputType::EmitCIR:
return CI.createDefaultOutputFile(false, InFile, "cir");
}
llvm_unreachable("Invalid CIRGenAction::OutputType");
}

std::unique_ptr<ASTConsumer>
CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();

if (!Out)
Out = getOutputStream(CI, InFile, Action);

auto Result = std::make_unique<cir::CIRGenConsumer>(
Action, CI.getDiagnostics(), &CI.getVirtualFileSystem(),
CI.getHeaderSearchOpts(), CI.getCodeGenOpts(), CI.getTargetOpts(),
Expand Down
87 changes: 82 additions & 5 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3691,6 +3691,35 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
return RValue::get(emitBuiltinObjectSize(E->getArg(0), Type, ResType,
/*EmittedE=*/nullptr, IsDynamic));
}
case Builtin::BI__builtin_counted_by_ref: {
// Default to returning '(void *) 0'.
llvm::Value *Result = llvm::ConstantPointerNull::get(
llvm::PointerType::getUnqual(getLLVMContext()));

const Expr *Arg = E->getArg(0)->IgnoreParenImpCasts();

if (auto *UO = dyn_cast<UnaryOperator>(Arg);
UO && UO->getOpcode() == UO_AddrOf) {
Arg = UO->getSubExpr()->IgnoreParenImpCasts();

if (auto *ASE = dyn_cast<ArraySubscriptExpr>(Arg))
Arg = ASE->getBase()->IgnoreParenImpCasts();
}

if (const MemberExpr *ME = dyn_cast_if_present<MemberExpr>(Arg)) {
if (auto *CATy =
ME->getMemberDecl()->getType()->getAs<CountAttributedType>();
CATy && CATy->getKind() == CountAttributedType::CountedBy) {
const auto *FAMDecl = cast<FieldDecl>(ME->getMemberDecl());
if (const FieldDecl *CountFD = FAMDecl->findCountedByField())
Result = GetCountedByFieldExprGEP(Arg, FAMDecl, CountFD);
else
llvm::report_fatal_error("Cannot find the counted_by 'count' field");
}
}

return RValue::get(Result);
}
case Builtin::BI__builtin_prefetch: {
Value *Locality, *RW, *Address = EmitScalarExpr(E->getArg(0));
// FIXME: Technically these constants should of type 'int', yes?
Expand Down Expand Up @@ -18671,6 +18700,12 @@ Value *EmitAMDGPUGridSize(CodeGenFunction &CGF, unsigned Index) {
auto *GEP = CGF.Builder.CreateGEP(CGF.Int8Ty, DP, Offset);
auto *LD = CGF.Builder.CreateLoad(
Address(GEP, CGF.Int32Ty, CharUnits::fromQuantity(4)));

llvm::MDBuilder MDB(CGF.getLLVMContext());

// Known non-zero.
LD->setMetadata(llvm::LLVMContext::MD_range,
MDB.createRange(APInt(32, 1), APInt::getZero(32)));
LD->setMetadata(llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(CGF.getLLVMContext(), {}));
return LD;
Expand Down Expand Up @@ -18767,6 +18802,15 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
return RT.getUDotIntrinsic();
}

Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
if (QT->hasSignedIntegerRepresentation()) {
return RT.getFirstBitSHighIntrinsic();
}

assert(QT->hasUnsignedIntegerRepresentation());
return RT.getFirstBitUHighIntrinsic();
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
Expand Down Expand Up @@ -18794,14 +18838,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
Value *OpMax = EmitScalarExpr(E->getArg(2));

QualType Ty = E->getArg(0)->getType();
bool IsUnsigned = false;
if (auto *VecTy = Ty->getAs<VectorType>())
Ty = VecTy->getElementType();
IsUnsigned = Ty->isUnsignedIntegerType();

Intrinsic::ID Intr;
if (Ty->isFloatingType()) {
Intr = CGM.getHLSLRuntime().getNClampIntrinsic();
} else if (Ty->isUnsignedIntegerType()) {
Intr = CGM.getHLSLRuntime().getUClampIntrinsic();
} else {
assert(Ty->isSignedIntegerType());
Intr = CGM.getHLSLRuntime().getSClampIntrinsic();
}
return Builder.CreateIntrinsic(
/*ReturnType=*/OpX->getType(),
IsUnsigned ? Intrinsic::dx_uclamp : Intrinsic::dx_clamp,
ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "dx.clamp");
/*ReturnType=*/OpX->getType(), Intr,
ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp");
}
case Builtin::BI__builtin_hlsl_cross: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
Expand Down Expand Up @@ -18866,6 +18917,25 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"hlsl.dot4add.i8packed");
}
case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
Value *C = EmitScalarExpr(E->getArg(2));

Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
return Builder.CreateIntrinsic(
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
"hlsl.dot4add.u8packed");
}
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {

Value *X = EmitScalarExpr(E->getArg(0));

return Builder.CreateIntrinsic(
/*ReturnType=*/ConvertType(E->getType()),
getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
}
case Builtin::BI__builtin_hlsl_lerp: {
Value *X = EmitScalarExpr(E->getArg(0));
Value *Y = EmitScalarExpr(E->getArg(1));
Expand Down Expand Up @@ -19022,6 +19092,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
ArrayRef{OpExpr});
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
Expand Down
29 changes: 17 additions & 12 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,15 +1145,7 @@ static bool getGEPIndicesToField(CodeGenFunction &CGF, const RecordDecl *RD,
return false;
}

/// This method is typically called in contexts where we can't generate
/// side-effects, like in __builtin_dynamic_object_size. When finding
/// expressions, only choose those that have either already been emitted or can
/// be loaded without side-effects.
///
/// - \p FAMDecl: the \p Decl for the flexible array member. It may not be
/// within the top-level struct.
/// - \p CountDecl: must be within the same non-anonymous struct as \p FAMDecl.
llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
llvm::Value *CodeGenFunction::GetCountedByFieldExprGEP(
const Expr *Base, const FieldDecl *FAMDecl, const FieldDecl *CountDecl) {
const RecordDecl *RD = CountDecl->getParent()->getOuterLexicalRecordContext();

Expand Down Expand Up @@ -1182,12 +1174,25 @@ llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
return nullptr;

Indices.push_back(Builder.getInt32(0));
Res = Builder.CreateInBoundsGEP(
return Builder.CreateInBoundsGEP(
ConvertType(QualType(RD->getTypeForDecl(), 0)), Res,
RecIndicesTy(llvm::reverse(Indices)), "..counted_by.gep");
}

return Builder.CreateAlignedLoad(ConvertType(CountDecl->getType()), Res,
getIntAlign(), "..counted_by.load");
/// This method is typically called in contexts where we can't generate
/// side-effects, like in __builtin_dynamic_object_size. When finding
/// expressions, only choose those that have either already been emitted or can
/// be loaded without side-effects.
///
/// - \p FAMDecl: the \p Decl for the flexible array member. It may not be
/// within the top-level struct.
/// - \p CountDecl: must be within the same non-anonymous struct as \p FAMDecl.
llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
const Expr *Base, const FieldDecl *FAMDecl, const FieldDecl *CountDecl) {
if (llvm::Value *GEP = GetCountedByFieldExprGEP(Base, FAMDecl, CountDecl))
return Builder.CreateAlignedLoad(ConvertType(CountDecl->getType()), GEP,
getIntAlign(), "..counted_by.load");
return nullptr;
}

void CodeGenFunction::EmitBoundsCheck(const Expr *E, const Expr *Base,
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,15 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitSHigh, firstbitshigh)
GENERATE_HLSL_INTRINSIC_FUNCTION(NClamp, nclamp)
GENERATE_HLSL_INTRINSIC_FUNCTION(SClamp, sclamp)
GENERATE_HLSL_INTRINSIC_FUNCTION(UClamp, uclamp)

GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)

Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3305,6 +3305,10 @@ class CodeGenFunction : public CodeGenTypeCache {
const FieldDecl *FAMDecl,
uint64_t &Offset);

llvm::Value *GetCountedByFieldExprGEP(const Expr *Base,
const FieldDecl *FAMDecl,
const FieldDecl *CountDecl);

/// Build an expression accessing the "counted_by" field.
llvm::Value *EmitLoadOfCountedByField(const Expr *Base,
const FieldDecl *FAMDecl,
Expand Down
Loading