Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions clang/lib/DPCT/APINamesTemplateType.inc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,18 @@ TYPE_REWRITE_ENTRY("cub::ArgMin",
HEADER_INSERTION_FACTORY(
HeaderType::HT_DPCT_DPL_Utils,
TYPE_FACTORY(STR(MapNames::getDpctNamespace() + "argmin"))))
// Type rewrite for cub::BlockRadixSort: the replacement type name is the dpct
// namespace followed by "group::group_radix_sort", and the entry is wrapped in
// a header-insertion factory for HeaderType::HT_DPCT_GROUP_Utils.
// NOTE(review): TEMPLATE_ARG(0) / TEMPLATE_ARG(2) presumably forward the first
// and third template arguments of the original type (key type and
// items-per-thread in CUB's BlockRadixSort) -- confirm against the
// TEMPLATE_ARG definition.
TYPE_REWRITE_ENTRY(
"cub::BlockRadixSort",
HEADER_INSERTION_FACTORY(HeaderType::HT_DPCT_GROUP_Utils,
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"group::group_radix_sort"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))
// Type rewrite for cub::BlockExchange: the replacement type name is the dpct
// namespace followed by "group::exchange", and the entry is wrapped in a
// header-insertion factory for HeaderType::HT_DPCT_GROUP_Utils.
// NOTE(review): TEMPLATE_ARG(0) / TEMPLATE_ARG(2) presumably forward the first
// and third template arguments of the original type (element type and
// items-per-thread in CUB's BlockExchange) -- confirm against the
// TEMPLATE_ARG definition.
TYPE_REWRITE_ENTRY(
"cub::BlockExchange",
HEADER_INSERTION_FACTORY(HeaderType::HT_DPCT_GROUP_Utils,
TYPE_FACTORY(STR(MapNames::getDpctNamespace() +
"group::exchange"),
TEMPLATE_ARG(0), TEMPLATE_ARG(2))))
FEATURE_REQUEST_FACTORY(
HelperFeatureEnum::device_ext,
TYPE_REWRITE_ENTRY("thrust::system::cuda::experimental::pinned_allocator",
Expand Down
16 changes: 8 additions & 8 deletions clang/lib/DPCT/APINames_CUB.inc
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,23 @@ ENTRY_MEMBER_FUNCTION(cub::BlockAdjacentDifference, cub::BlockAdjacentDifference
// Block-level CUB member-function entries. Each method is listed exactly once;
// the rows for StripedToBlocked, BlockedToStriped, ScatterToBlocked,
// ScatterToStriped and the four BlockRadixSort sorts previously appeared twice
// with conflicting flags, and only the rows matching the methods handled by
// CubMemberCallRule are kept.
// NOTE(review): the boolean column and the trailing status string are assumed
// to describe migration support -- confirm against the ENTRY_MEMBER_FUNCTION
// definition.
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagHeads, FlagHeads, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagTails, FlagTails, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockDiscontinuity, cub::BlockDiscontinuity, FlagHeadsAndTails, FlagHeadsAndTails, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, StripedToBlocked, StripedToBlocked, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, BlockedToStriped, BlockedToStriped, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, WarpStripedToBlocked, WarpStripedToBlocked, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, BlockedToWarpStriped, BlockedToWarpStriped, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToBlocked, ScatterToBlocked, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToStriped, ScatterToStriped, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToStripedGuarded, ScatterToStripedGuarded, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockExchange, cub::BlockExchange, ScatterToStripedFlagged, ScatterToStripedFlagged, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockHistogram, cub::BlockHistogram, InitHistogram, InitHistogram, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockHistogram, cub::BlockHistogram, Histogram, Histogram, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockHistogram, cub::BlockHistogram, Composite, Composite, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockLoad, cub::BlockLoad, Load, Load, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockStore, cub::BlockStore, Store, Store, false, NO_FLAG, P4, "Comment")
ENTRY_MEMBER_FUNCTION(cub::BlockRadixSort, cub::BlockRadixSort, Sort, Sort, true, NO_FLAG, P4, "Partial")
ENTRY_MEMBER_FUNCTION(cub::BlockRadixSort, cub::BlockRadixSort, SortDescending, SortDescending, true, NO_FLAG, P4, "Partial")
ENTRY_MEMBER_FUNCTION(cub::BlockRadixSort, cub::BlockRadixSort, SortBlockedToStriped, SortBlockedToStriped, true, NO_FLAG, P4, "Partial")
ENTRY_MEMBER_FUNCTION(cub::BlockRadixSort, cub::BlockRadixSort, SortDescendingBlockedToStriped, SortDescendingBlockedToStriped, true, NO_FLAG, P4, "Partial")
ENTRY_MEMBER_FUNCTION(cub::BlockReduce, cub::BlockReduce, Reduce, Reduce, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockReduce, cub::BlockReduce, Sum, Sum, true, NO_FLAG, P4, "Successful")
ENTRY_MEMBER_FUNCTION(cub::BlockScan, cub::BlockScan, ExclusiveSum, ExclusiveSum, true, NO_FLAG, P4, "Successful")
Expand Down
38 changes: 29 additions & 9 deletions clang/lib/DPCT/AnalysisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3549,26 +3549,46 @@ void TempStorageVarInfo::addAccessorDecl(StmtList &AccessorList,
StringRef LocalSize) const {
std::string Accessor;
llvm::raw_string_ostream OS(Accessor);
OS << MapNames::getClNamespace() << "local_accessor<std::byte, 1> " << Name
<< "_acc(";
DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range", 1);
OS << '(' << LocalSize << ".size() * sizeof(" << Type->getSourceString()
<< ")), cgh);";
switch (Kind) {
case BlockReduce:
OS << MapNames::getClNamespace() << "local_accessor<std::byte, 1> " << Name
<< "_acc(";
DpctGlobalInfo::printCtadClass(OS, MapNames::getClNamespace() + "range", 1);
OS << '(' << LocalSize << ".size() * sizeof("
<< ValueType->getSourceString() << ')' << ')';
break;
case BlockRadixSort:
OS << MapNames::getClNamespace() << "local_accessor<uint8_t, 1> " << Name
<< "_acc(";
OS << TmpMemSizeCalFn << '(' << LocalSize << ".size()" << ')';
break;
}

OS << ", cgh);";
AccessorList.emplace_back(Accessor);
}
/// Re-resolves the stored template-dependent value type against \p TAList and
/// keeps the result.
///
/// Only ValueType is updated: the constructor taking an APIKind never
/// initializes the legacy `Type` member, so dereferencing it here would go
/// through a null shared_ptr.
///
/// \param TAList template arguments to substitute into the value type.
void TempStorageVarInfo::applyTemplateArguments(
    const std::vector<TemplateArgumentInfo> &TAList) {
  ValueType = ValueType->applyTemplateArguments(TAList);
}
/// Streams the parameter declaration used for this temp-storage variable in a
/// device function signature: a Kind-dependent type prefix followed by the
/// variable name.
///
/// The previous unconditional early return always emitted the
/// local_accessor form and left the Kind dispatch below unreachable.
///
/// \param PS stream receiving the declaration text.
/// \returns the result of streaming the variable name into \p PS.
ParameterStream &TempStorageVarInfo::getFuncDecl(ParameterStream &PS) {
  switch (Kind) {
  case BlockReduce:
    // Declared as a byte-typed local accessor.
    PS << MapNames::getClNamespace() << "local_accessor<std::byte, 1> ";
    break;
  case BlockRadixSort:
    // Declared as a raw byte pointer.
    PS << "uint8_t *";
    break;
  }
  return PS << Name;
}
/// Writes the bare variable name into \p PS as the argument text for this
/// temp-storage variable, then hands the stream back to the caller.
ParameterStream &TempStorageVarInfo::getFuncArg(ParameterStream &PS) {
  auto &Out = PS << Name;
  return Out;
}
/// Streams the expression passed for this temp-storage variable at the kernel
/// call site.
///
/// BlockReduce passes the accessor object itself ("<Name>_acc"); every other
/// kind passes the address of the accessor's first element
/// ("&<Name>_acc[0]"), matching the pointer parameter emitted by getFuncDecl.
/// The previous unconditional early return made the Kind check unreachable.
///
/// \param PS stream receiving the argument text.
/// \returns the result of the final stream insertion into \p PS.
ParameterStream &TempStorageVarInfo::getKernelArg(ParameterStream &PS) {
  if (Kind == BlockReduce)
    return PS << Name << "_acc";
  return PS << "&" << Name << "_acc[0]";
}
///// class CudaLaunchTextureObjectInfo /////
std::string
Expand Down
19 changes: 15 additions & 4 deletions clang/lib/DPCT/AnalysisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -2144,14 +2144,25 @@ class TemplateArgumentInfo {
};

class TempStorageVarInfo {
public:
enum APIKind {
BlockReduce,
BlockRadixSort,
};

private:
unsigned Offset;
APIKind Kind;
std::string Name;
std::shared_ptr<TemplateDependentStringInfo> Type;
std::string TmpMemSizeCalFn;
std::shared_ptr<TemplateDependentStringInfo> ValueType;

public:
TempStorageVarInfo(unsigned Off, StringRef Name,
std::shared_ptr<TemplateDependentStringInfo> T)
: Offset(Off), Name(Name.str()), Type(T) {}
TempStorageVarInfo(unsigned Off, APIKind Kind, StringRef Name,
std::string TmpMemSizeCalFn,
std::shared_ptr<TemplateDependentStringInfo> ValT)
: Offset(Off), Kind(Kind), Name(Name.str()),
TmpMemSizeCalFn(TmpMemSizeCalFn), ValueType(ValT) {}
const std::string &getName() const { return Name; }
unsigned getOffset() const { return Offset; }
void addAccessorDecl(StmtList &AccessorList, StringRef LocalSize) const;
Expand Down
185 changes: 115 additions & 70 deletions clang/lib/DPCT/CUBAPIMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ void CubTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
"cub::KeyValuePair", "cub::CountingInputIterator",
"cub::TransformInputIterator", "cub::ConstantInputIterator",
"cub::ArgIndexInputIterator", "cub::DiscardOutputIterator",
"cub::DoubleBuffer", "cub::NullType", "cub::ArgMax", "cub::ArgMin");
"cub::DoubleBuffer", "cub::NullType", "cub::ArgMax", "cub::ArgMin",
"cub::BlockRadixSort", "cub::BlockExchange");
};

MF.addMatcher(
Expand Down Expand Up @@ -157,9 +158,14 @@ void CubDeviceLevelRule::runRule(
void CubMemberCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
MF.addMatcher(
cxxMemberCallExpr(
allOf(on(hasType(hasCanonicalType(qualType(hasDeclaration(
namedDecl(hasName("cub::ArgIndexInputIterator"))))))),
callee(cxxMethodDecl(hasName("normalize")))))
allOf(
on(hasType(hasCanonicalType(qualType(hasDeclaration(namedDecl(
hasAnyName("cub::ArgIndexInputIterator",
"cub::BlockRadixSort", "cub::BlockExchange"))))))),
callee(cxxMethodDecl(hasAnyName(
"normalize", "Sort", "SortDescending", "BlockedToStriped",
"StripedToBlocked", "ScatterToBlocked", "ScatterToStriped",
"SortBlockedToStriped", "SortDescendingBlockedToStriped")))))
.bind("memberCall"),
this);

Expand All @@ -172,11 +178,110 @@ void CubMemberCallRule::registerMatcher(ast_matchers::MatchFinder &MF) {
this);
}

/// Finds, for a block-level CUB member call, the temp-storage variable handed
/// to the collective's constructor and the TypeLoc of the collective's first
/// template argument.
///
/// Two forms of the implicit object are recognised:
///   - a materialized temporary (CXXTemporaryObjectExpr or
///     CXXFunctionalCastExpr), and
///   - a DeclRefExpr to a variable whose type satisfies
///     isCubCollectiveRecordType and which has an initializer.
///
/// \param MC member call whose implicit object argument is inspected.
/// \returns {temp-storage VarDecl, first-template-argument TypeLoc}. The
///          VarDecl is null and/or the TypeLoc is null when the corresponding
///          piece cannot be found; callers must check both.
static std::pair<const VarDecl *, TypeLoc>
getTempstorageVarAndValueTypeLoc(const CXXMemberCallExpr *MC) {
  Expr *Obj = MC->getImplicitObjectArgument();
  const VarDecl *TempStorage = nullptr;

  // Returns the variable passed as first constructor argument when it carries
  // CUDASharedAttr and isCubVar accepts it; nullptr otherwise.
  auto FindTempStorageVarInCtor = [](const Expr *E) -> const VarDecl * {
    const auto *Ctor = dyn_cast<CXXConstructExpr>(E);
    // A default-constructed collective has no arguments, so getArg(0) would
    // be out of range.
    if (!Ctor || Ctor->getNumArgs() == 0)
      return nullptr;
    const auto *DRE = dyn_cast<DeclRefExpr>(Ctor->getArg(0));
    if (!DRE)
      return nullptr;
    const auto *VD = dyn_cast<VarDecl>(DRE->getDecl());
    if (VD && VD->hasAttr<CUDASharedAttr>() && isCubVar(VD))
      return VD;
    return nullptr;
  };

  // Peels elaborated types and typedefs until a template specialization is
  // reached, then yields the TypeLoc of its first template argument. Any other
  // kind of TypeLoc is returned unchanged.
  auto FindDataTypeLoc = [](TypeLoc Loc) -> TypeLoc {
    if (Loc.isNull())
      return Loc;
    while (true) {
      switch (Loc.getTypeLocClass()) {
      case TypeLoc::Elaborated:
        Loc = Loc.getNextTypeLoc();
        break;
      case TypeLoc::Typedef: {
        auto NewLoc = Loc.castAs<TypedefTypeLoc>();
        Loc = NewLoc.getTypedefNameDecl()->getTypeSourceInfo()->getTypeLoc();
        break;
      }
      case TypeLoc::TemplateSpecialization: {
        auto NewLoc = Loc.getAs<TemplateSpecializationTypeLoc>();
        // Only a leading *type* argument has a TypeSourceInfo; report "not
        // found" (null TypeLoc) for an empty or non-type argument list.
        if (NewLoc.getNumArgs() == 0 ||
            NewLoc.getArgLoc(0).getArgument().getKind() !=
                TemplateArgument::Type)
          return TypeLoc();
        TypeSourceInfo *TSI = NewLoc.getArgLocInfo(0).getAsTypeSourceInfo();
        return TSI ? TSI->getTypeLoc() : TypeLoc();
      }
      default:
        return Loc;
      }
    }
  };

  TypeLoc DataTypeLoc;
  if (const auto *MTE = dyn_cast<MaterializeTemporaryExpr>(Obj)) {
    // Temporary collective: the type is written at the construction site.
    if (auto *TOE = dyn_cast<CXXTemporaryObjectExpr>(MTE->getSubExpr())) {
      DataTypeLoc = FindDataTypeLoc(TOE->getTypeSourceInfo()->getTypeLoc());
    } else if (auto *FC = dyn_cast<CXXFunctionalCastExpr>(MTE->getSubExpr())) {
      DataTypeLoc = FindDataTypeLoc(FC->getTypeInfoAsWritten()->getTypeLoc());
    }
    TempStorage = FindTempStorageVarInCtor(MTE->getSubExpr()->IgnoreCasts());
  } else if (const auto *DRE = dyn_cast<DeclRefExpr>(Obj)) {
    // Named collective: the type comes from the variable's declaration and
    // the temp storage from its initializer.
    if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
      DataTypeLoc = FindDataTypeLoc(VD->getTypeSourceInfo()->getTypeLoc());
      if (isCubCollectiveRecordType(VD->getType()) && VD->hasInit())
        TempStorage = FindTempStorageVarInCtor(VD->getInit());
    }
  }
  return {TempStorage, DataTypeLoc};
}

void CubMemberCallRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
ExprAnalysis EA;
if (const auto *E1 = getNodeAsType<CXXMemberCallExpr>(Result, "memberCall")) {
EA.analyze(E1);
if (const auto *BlockMC =
getNodeAsType<CXXMemberCallExpr>(Result, "memberCall")) {
EA.analyze(BlockMC);
StringRef Name = BlockMC->getMethodDecl()->getName();
bool isBlockRadixSort = Name == "Sort" || Name == "SortDescending" ||
Name == "SortBlockedToStriped" ||
Name == "SortDescendingBlockedToStriped";
bool isBlockExchange =
Name == "BlockedToStriped" || Name == "StripedToBlocked" ||
Name == "StripedToBlocked" || Name == "ScatterToBlocked" ||
Name == "ScatterToStriped";
if (isBlockRadixSort || isBlockExchange) {
std::string HelpFuncName =
isBlockRadixSort ? "group_radix_sort" : "exchange";
auto [TempStorage, DataTypeLoc] =
getTempstorageVarAndValueTypeLoc(BlockMC);
auto *FD = DpctGlobalInfo::findAncestor<FunctionDecl>(TempStorage);
if (!FD || !TempStorage || DataTypeLoc.isNull())
return;
QualType CanTy = BlockMC->getObjectType().getCanonicalType();
auto *ClassSpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(
CanTy->getAs<RecordType>()->getDecl());
const auto &ValueTyArg = ClassSpecDecl->getTemplateArgs()[0];
const auto &ItemsPreThreadArg = ClassSpecDecl->getTemplateArgs()[2];
ValueTyArg.getAsType().getAsString();
std::string Fn;
llvm::raw_string_ostream OS(Fn);
OS << MapNames::getDpctNamespace() << "group::" << HelpFuncName << "<"
<< ValueTyArg.getAsType().getAsString() << ", "
<< ItemsPreThreadArg.getAsIntegral() << ">::get_local_memory_size";
if (auto FuncInfo = DeviceFunctionDecl::LinkRedecls(FD)) {
auto LocInfo = DpctGlobalInfo::getLocInfo(TempStorage);
ExprAnalysis EA;
EA.analyze(DataTypeLoc);
FuncInfo->getVarMap().addCUBTempStorage(
std::make_shared<TempStorageVarInfo>(
LocInfo.second, TempStorageVarInfo::BlockRadixSort,
TempStorage->getName(), Fn,
EA.getTemplateDependentStringInfo()));
}
}
} else if (const auto *E2 = getNodeAsType<MemberExpr>(Result, "memberExpr")) {
EA.analyze(E2);
}
Expand Down Expand Up @@ -1211,69 +1316,8 @@ void CubRule::processBlockLevelMemberCall(const CXXMemberCallExpr *BlockMC) {

NewFuncName = MapNames::getClNamespace() +
"ext::oneapi::experimental::reduce_over_group";
Expr *Obj = BlockMC->getImplicitObjectArgument();
const VarDecl *TempStorage = nullptr;

auto FindTempStorageVarInCtor = [&](const Expr *E) -> const VarDecl * {
if (auto *Ctor = dyn_cast<CXXConstructExpr>(E)) {
if (auto *DRE = dyn_cast<DeclRefExpr>(Ctor->getArg(0))) {
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
if (VD->hasAttr<CUDASharedAttr>() && isCubVar(VD)) {
return VD;
}
}
}
}
return nullptr;
};

auto HandleTypeLoc = [&](TypeLoc Loc) -> TypeLoc {
if (Loc.isNull())
return Loc;
while (true) {
switch (Loc.getTypeLocClass()) {
case TypeLoc::Elaborated:
Loc = Loc.getNextTypeLoc();
break;
case TypeLoc::Typedef: {
auto NewLoc = Loc.castAs<TypedefTypeLoc>();
Loc = NewLoc.getTypedefNameDecl()
->getTypeSourceInfo()
->getTypeLoc();
break;
}
case TypeLoc::TemplateSpecialization: {
auto NewLoc = Loc.getAs<TemplateSpecializationTypeLoc>();
return NewLoc.getArgLocInfo(0).getAsTypeSourceInfo()->getTypeLoc();
break;
}
default:
return Loc;
}
}
};

TypeLoc DataTypeLoc;
if (const auto *MTE = dyn_cast<MaterializeTemporaryExpr>(Obj)) {
if (auto *TOE = dyn_cast<CXXTemporaryObjectExpr>(MTE->getSubExpr())) {
DataTypeLoc = HandleTypeLoc(TOE->getTypeSourceInfo()->getTypeLoc());
} else if (auto *FC =
dyn_cast<CXXFunctionalCastExpr>(MTE->getSubExpr())) {
DataTypeLoc =
HandleTypeLoc(FC->getTypeInfoAsWritten()->getTypeLoc());
}
TempStorage =
FindTempStorageVarInCtor(MTE->getSubExpr()->IgnoreCasts());
} else if (const auto *DRE = dyn_cast<DeclRefExpr>(Obj)) {
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
DataTypeLoc = HandleTypeLoc(VD->getTypeSourceInfo()->getTypeLoc());
if (isCubCollectiveRecordType(VD->getType()) && VD->hasInit()) {
emplaceTransformation(new ReplaceVarDecl(VD, ""));
TempStorage = FindTempStorageVarInCtor(VD->getInit());
}
}
}

auto [TempStorage, DataTypeLoc] =
getTempstorageVarAndValueTypeLoc(BlockMC);
auto *FD = DpctGlobalInfo::findAncestor<FunctionDecl>(TempStorage);
if (!FD || !TempStorage || DataTypeLoc.isNull())
return;
Expand All @@ -1283,7 +1327,8 @@ void CubRule::processBlockLevelMemberCall(const CXXMemberCallExpr *BlockMC) {
EA.analyze(DataTypeLoc);
FuncInfo->getVarMap().addCUBTempStorage(
std::make_shared<TempStorageVarInfo>(
LocInfo.second, TempStorage->getName(),
LocInfo.second, TempStorageVarInfo::BlockReduce,
TempStorage->getName(), "",
EA.getTemplateDependentStringInfo()));
}
std::string Span = MapNames::getClNamespace() + "span<std::byte, 1>" +
Expand Down
7 changes: 5 additions & 2 deletions clang/lib/DPCT/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1758,14 +1758,17 @@ createUserDefinedMethodRewriterFactory(
/// Predicate functor that checks the type string of a call's parameter at a
/// fixed index against an expected type name.
///
/// In strict mode the two strings must be equal; otherwise the expected name
/// only has to occur somewhere inside the parameter's type string. An empty
/// type string from getParamTypeStr is treated as a match.
class CheckParamType {
  unsigned Idx;         // Index of the parameter to inspect.
  std::string TypeName; // Expected type name (or substring of it).
  bool isStrict;        // true: exact equality; false: substring match.

public:
  /// Single constructor with a defaulted \p isStrict, so two-argument uses
  /// keep working and isStrict is always initialized. Keeping a separate
  /// two-argument overload alongside it would make such calls ambiguous.
  CheckParamType(unsigned I, std::string Name, bool isStrict = false)
      : Idx(I), TypeName(Name), isStrict(isStrict) {}

  /// \returns true when the parameter type is unknown (empty string) or
  ///          matches TypeName under the configured mode.
  bool operator()(const CallExpr *C) {
    std::string ParamType = getParamTypeStr(C, Idx);
    if (ParamType.empty())
      return true;
    return isStrict ? ParamType == TypeName
                    : ParamType.find(TypeName) != std::string::npos;
  }
};

Expand Down
Loading