diff --git a/clang/lib/DPCT/ExprAnalysis.cpp b/clang/lib/DPCT/ExprAnalysis.cpp index 48721647c44e..fb96835bbf36 100644 --- a/clang/lib/DPCT/ExprAnalysis.cpp +++ b/clang/lib/DPCT/ExprAnalysis.cpp @@ -505,9 +505,33 @@ void ExprAnalysis::analyzeExpr(const MemberExpr *ME) { auto ItFieldRule = MapNames::ClassFieldMap.find(BaseType + "." + FieldName); if (ItFieldRule != MapNames::ClassFieldMap.end()) { - addReplacement(ME->getMemberLoc(), ME->getMemberLoc(), - ItFieldRule->second->NewName); - return; + if (ItFieldRule->second->GetterName == "") { + addReplacement(ME->getMemberLoc(), ME->getMemberLoc(), + ItFieldRule->second->NewName); + return; + } else { + if (auto BO = DpctGlobalInfo::findAncestor(ME)) { + if (BO->getOpcode() == BinaryOperatorKind::BO_Assign && + ME == BO->getLHS()) { + ExprAnalysis EA; + EA.analyze(BO->getRHS()); + std::string RHSStr = EA.getReplacedString(); + addReplacement(ME->getMemberLoc(), ME->getMemberLoc(), + ItFieldRule->second->SetterName + "(" + RHSStr + ")"); + auto SpellingLocInfo = getSpellingOffsetAndLength( + BO->getOperatorLoc(), BO->getOperatorLoc()); + addExtReplacement(std::make_shared( + SM, SpellingLocInfo.first, SpellingLocInfo.second, "", nullptr)); + SpellingLocInfo = getSpellingOffsetAndLength( + BO->getRHS()->getBeginLoc(), BO->getRHS()->getEndLoc()); + addExtReplacement(std::make_shared( + SM, SpellingLocInfo.first, SpellingLocInfo.second, "", nullptr)); + } + } else { + addReplacement(ME->getMemberLoc(), ME->getMemberLoc(), + ItFieldRule->second->GetterName + "()"); + } + } } static MapNames::MapTy NdItemMemberMap{{"__fetch_builtin_x", "2"}, diff --git a/clang/lib/DPCT/Rules.cpp b/clang/lib/DPCT/Rules.cpp index a08b1cfd00bb..35e67bef1e87 100644 --- a/clang/lib/DPCT/Rules.cpp +++ b/clang/lib/DPCT/Rules.cpp @@ -104,7 +104,15 @@ void registerClassRule(MetaRuleObject &R) { auto ItFieldRule = MapNames::ClassFieldMap.find(BaseAndFieldName); if (ItFieldRule != MapNames::ClassFieldMap.end()) { if (ItFieldRule->second->Priority > R.Priority) { - ItFieldRule->second->NewName = (*ItField)->Out; + if((*ItField)->OutGetter != ""){ + ItFieldRule->second->SetterName = (*ItField)->OutSetter; + ItFieldRule->second->GetterName = (*ItField)->OutGetter; + ItFieldRule->second->NewName = ""; + } else { + ItFieldRule->second->SetterName = ""; + ItFieldRule->second->GetterName = ""; + ItFieldRule->second->NewName = (*ItField)->Out; + } ItFieldRule->second->Priority = R.Priority; ItFieldRule->second->RequestFeature = clang::dpct::HelperFeatureEnum::no_feature_helper; @@ -119,9 +127,16 @@ void registerClassRule(MetaRuleObject &R) { return new clang::dpct::UserDefinedClassFieldRule(R.In, (*ItField)->In); }); - auto RulePtr = std::make_shared( - (*ItField)->Out, clang::dpct::HelperFeatureEnum::no_feature_helper, - R.Priority); + std::shared_ptr RulePtr; + if ((*ItField)->OutGetter != "") { + RulePtr = std::make_shared( + (*ItField)->OutSetter, (*ItField)->OutGetter, + clang::dpct::HelperFeatureEnum::no_feature_helper, R.Priority); + } else { + RulePtr = std::make_shared( + (*ItField)->Out, clang::dpct::HelperFeatureEnum::no_feature_helper, + R.Priority); + } RulePtr->Includes.insert(RulePtr->Includes.end(), R.Includes.begin(), R.Includes.end()); MapNames::ClassFieldMap.emplace(BaseAndFieldName, RulePtr); diff --git a/clang/lib/DPCT/Rules.h b/clang/lib/DPCT/Rules.h index 75a4b0bef116..de46d1aec5a1 100644 --- a/clang/lib/DPCT/Rules.h +++ b/clang/lib/DPCT/Rules.h @@ -32,10 +32,17 @@ struct TypeNameRule { }; struct ClassFieldRule : public TypeNameRule { + std::string SetterName; + std::string GetterName; ClassFieldRule(std::string Name) : TypeNameRule(Name) {} ClassFieldRule(std::string Name, clang::dpct::HelperFeatureEnum Feature, RulePriority Priority = RulePriority::Fallback) : TypeNameRule(Name, Feature) {} + ClassFieldRule(std::string SetterName, std::string GetterName, + clang::dpct::HelperFeatureEnum Feature, + RulePriority Priority = RulePriority::Fallback) + : TypeNameRule(SetterName, Feature), SetterName(SetterName), + GetterName(GetterName) {} }; // Record all information of imported rules diff --git a/clang/test/dpct/user_defined_rule.cu b/clang/test/dpct/user_defined_rule.cu index 38e0739475cf..f78b149d1a0b 100644 --- a/clang/test/dpct/user_defined_rule.cu +++ b/clang/test/dpct/user_defined_rule.cu @@ -41,6 +41,7 @@ __forceinline__ __global__ void foo(){ class ClassA{ public: int fieldA; + int fieldC; int methodA(int i, int j){return 0;}; }; class ClassB{ @@ -64,9 +65,13 @@ void foo2(){ CUstream_st *cu_st; //CHECK: ClassB a; - //CHECK-NEXT: a.fieldB = 3; + //CHECK-NEXT: a.fieldD = 3; //CHECK-NEXT: a.methodB(2); + //CHECK-NEXT: a.set_a(3); + //CHECK-NEXT: int k = a.get_a(); ClassA a; - a.fieldA = 3; + a.fieldC = 3; a.methodA(1,2); + a.fieldA = 3; + int k = a.fieldA; } \ No newline at end of file diff --git a/clang/test/dpct/user_defined_rule.yaml b/clang/test/dpct/user_defined_rule.yaml index e5fa834add09..a4eb26f6ee68 100644 --- a/clang/test/dpct/user_defined_rule.yaml +++ b/clang/test/dpct/user_defined_rule.yaml @@ -49,7 +49,8 @@ Includes: [] Fields: - In: fieldA - Out: fieldB + OutGetter: get_a + OutSetter: set_a - In: fieldC Out: fieldD Methods: