diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 9e54c5fee4b1..9c20125286c1 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -683,9 +683,6 @@ constructKernelName(Sema &S, FunctionDecl *KernelCallerFunc, // anonymous namespace so these don't get linkage. namespace { -QualType getItemType(const FieldDecl *FD) { return FD->getType(); } -QualType getItemType(const CXXBaseSpecifier &BS) { return BS.getType(); } - // These enable handler execution only when previous handlers succeed. template static bool handleField(FieldDecl *FD, QualType FDTy, Tn &&... tn) { @@ -729,11 +726,6 @@ template using bind_param_t = typename bind_param::type; // })...) // Implements the 'for-each-visitor' pattern. -template -static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, - CXXRecordDecl *Wrapper, - Handlers &... handlers); - template static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy, Handlers &... handlers) { @@ -742,7 +734,7 @@ static void VisitField(CXXRecordDecl *Owner, RangeTy &&Item, QualType ItemTy, if (Util::isSyclStreamType(ItemTy)) KF_FOR_EACH(handleSyclStreamType, Item, ItemTy); if (ItemTy->isStructureOrClassType()) - VisitAccessorWrapper(Owner, Item, ItemTy->getAsCXXRecordDecl(), + VisitRecord(Owner, Item, ItemTy->getAsCXXRecordDecl(), handlers...); if (ItemTy->isArrayType()) VisitArrayElements(Item, ItemTy, handlers...); @@ -762,38 +754,68 @@ static void VisitArrayElements(RangeTy Item, QualType FieldTy, (void)std::initializer_list{(handlers.leaveArray(ET, ElemCount), 0)...}; } -template -static void VisitAccessorWrapperHelper(CXXRecordDecl *Owner, RangeTy Range, - Handlers &... handlers) { - for (const auto &Item : Range) { - QualType ItemTy = getItemType(Item); - (void)std::initializer_list{(handlers.enterField(Owner, Item), 0)...}; - VisitField(Owner, Item, ItemTy, handlers...); - (void)std::initializer_list{(handlers.leaveField(Owner, Item), 0)...}; +template +static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, Handlers &... handlers); + +template +static void VisitRecordHelper(CXXRecordDecl *Owner, + clang::CXXRecordDecl::base_class_range Range, + Handlers &... handlers) { + for (const auto &Base : Range) { + QualType BaseTy = Base.getType(); + if (Util::isSyclAccessorType(BaseTy)) + (void)std::initializer_list{ + (handlers.handleSyclAccessorType(Base, BaseTy), 0)...}; + else if (Util::isSyclStreamType(BaseTy)) + (void)std::initializer_list{ + (handlers.handleSyclStreamType(Base, BaseTy), 0)...}; + else + VisitRecord(Owner, Base, BaseTy->getAsCXXRecordDecl(), handlers...); } } +template +static void VisitRecordHelper(CXXRecordDecl *Owner, + clang::RecordDecl::field_range Range, + Handlers &... handlers) { + VisitRecordFields(Owner, handlers...); +} + // Parent contains the FieldDecl or CXXBaseSpecifier that was used to enter // the Wrapper structure that we're currently visiting. Owner is the parent // type (which doesn't exist in cases where it is a FieldDecl in the // 'root'), and Wrapper is the current struct being unwrapped. template -static void VisitAccessorWrapper(CXXRecordDecl *Owner, ParentTy &Parent, - CXXRecordDecl *Wrapper, - Handlers &... handlers) { +static void VisitRecord(CXXRecordDecl *Owner, ParentTy &Parent, + CXXRecordDecl *Wrapper, Handlers &... handlers) { (void)std::initializer_list{(handlers.enterStruct(Owner, Parent), 0)...}; - VisitAccessorWrapperHelper(Wrapper, Wrapper->bases(), handlers...); - VisitAccessorWrapperHelper(Wrapper, Wrapper->fields(), handlers...); + VisitRecordHelper(Wrapper, Wrapper->bases(), handlers...); + VisitRecordHelper(Wrapper, Wrapper->fields(), handlers...); (void)std::initializer_list{(handlers.leaveStruct(Owner, Parent), 0)...}; } +int getFieldNumber(const CXXRecordDecl *BaseDecl) { + int Members = 0; + for (const auto *Field : BaseDecl->fields()) + ++Members; + + return Members; +} + +template +static void VisitFunctorBases(CXXRecordDecl *KernelFunctor, + Handlers &... handlers) { + VisitRecordHelper(KernelFunctor, KernelFunctor->bases(), handlers...); +} + + // A visitor function that dispatches to functions as defined in // SyclKernelFieldHandler for the purposes of kernel generation. template -static void VisitRecordFields(RecordDecl::field_range Fields, - Handlers &... handlers) { +static void VisitRecordFields(CXXRecordDecl *Owner, Handlers &... handlers) { - for (const auto Field : Fields) { + for (const auto Field : Owner->fields()) { (void)std::initializer_list{ (handlers.enterField(nullptr, Field), 0)...}; QualType FieldTy = Field->getType(); @@ -807,12 +829,12 @@ static void VisitRecordFields(RecordDecl::field_range Fields, else if (Util::isSyclStreamType(FieldTy)) { // Stream actually wraps accessors, so do recursion CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); - VisitAccessorWrapper(nullptr, Field, RD, handlers...); + VisitRecord(nullptr, Field, RD, handlers...); KF_FOR_EACH(handleSyclStreamType, Field, FieldTy); } else if (FieldTy->isStructureOrClassType()) { if (KF_FOR_EACH(handleStructType, Field, FieldTy)) { CXXRecordDecl *RD = FieldTy->getAsCXXRecordDecl(); - VisitAccessorWrapper(nullptr, Field, RD, handlers...); + VisitRecord(nullptr, Field, RD, handlers...); } } else if (FieldTy->isReferenceType()) KF_FOR_EACH(handleReferenceType, Field, FieldTy); @@ -821,7 +843,7 @@ static void VisitRecordFields(RecordDecl::field_range Fields, else if (FieldTy->isArrayType()) { if (KF_FOR_EACH(handleArrayType, Field, FieldTy)) VisitArrayElements(Field, FieldTy, handlers...); - } else if (FieldTy->isScalarType()) + } else if (FieldTy->isScalarType() || FieldTy->isVectorType()) KF_FOR_EACH(handleScalarType, Field, FieldTy); else KF_FOR_EACH(handleOtherType, Field, FieldTy); @@ -1131,7 +1153,7 @@ class SyclKernelDeclCreator } bool handleStructType(FieldDecl *FD, QualType FieldTy) final { - addParam(FD, FieldTy); + // addParam(FD, FieldTy); return true; } @@ -1277,7 +1299,10 @@ class SyclKernelBodyCreator VK_LValue, SourceLocation()); } - MemberExpr *SpecialObjME = BuildMemberExpr(Base, Field); + Expr *SpecialObjME = Base; + if (Field) + SpecialObjME = BuildMemberExpr(Base, Field); + MemberExpr *MethodME = BuildMemberExpr(SpecialObjME, Method); QualType ResultTy = Method->getReturnType(); @@ -1312,22 +1337,39 @@ class SyclKernelBodyCreator bool handleSpecialType(FieldDecl *FD, QualType Ty) { const auto *RecordDecl = Ty->getAsCXXRecordDecl(); - // Perform initialization only if it is field of kernel object - if (MemberExprBases.size() == 1) { - InitializedEntity Entity = - InitializedEntity::InitializeMember(FD, &VarEntity); - // Initialize with the default constructor. - InitializationKind InitKind = - InitializationKind::CreateDefault(SourceLocation()); - InitializationSequence InitSeq(SemaRef, Entity, InitKind, None); - ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); - InitExprs.push_back(MemberInit.get()); - } + // TODO: VarEntity is initialized entity for KernelObjClone, I guess we need + // to create new one when enter new struct. + InitializedEntity Entity = + InitializedEntity::InitializeMember(FD, &VarEntity); + // Initialize with the default constructor. + InitializationKind InitKind = + InitializationKind::CreateDefault(SourceLocation()); + InitializationSequence InitSeq(SemaRef, Entity, InitKind, None); + ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); + InitExprs.push_back(MemberInit.get()); createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName, FD); return true; } + bool handleSpecialType(const CXXBaseSpecifier &BS, QualType Ty) { + const auto *RecordDecl = Ty->getAsCXXRecordDecl(); + // TODO: VarEntity is initialized entity for KernelObjClone, I guess we need + // to create new one when enter new struct. + InitializedEntity Entity = InitializedEntity::InitializeBase( + SemaRef.Context, &BS, /*IsInheritedVirtualBase*/ false, &VarEntity); + // Initialize with the default constructor. + InitializationKind InitKind = + InitializationKind::CreateDefault(SourceLocation()); + InitializationSequence InitSeq(SemaRef, Entity, InitKind, None); + ExprResult MemberInit = InitSeq.Perform(SemaRef, Entity, InitKind, None); + InitExprs.push_back(MemberInit.get()); + + createSpecialMethodCall(RecordDecl, MemberExprBases.back(), InitMethodName, + nullptr); + return true; + } + public: SyclKernelBodyCreator(Sema &S, SyclKernelDeclCreator &DC, CXXRecordDecl *KernelObj, @@ -1359,9 +1401,7 @@ class SyclKernelBodyCreator } bool handleSyclAccessorType(const CXXBaseSpecifier &BS, QualType Ty) final { - // FIXME SYCL accessor should be usable as a base type - // See https://github.com/intel/llvm/issues/28. - return true; + return handleSpecialType(BS, Ty); } bool handleSyclSamplerType(FieldDecl *FD, QualType Ty) final { @@ -1390,7 +1430,7 @@ class SyclKernelBodyCreator } bool handleStructType(FieldDecl *FD, QualType FieldTy) final { - createExprForStructOrScalar(FD); + // createExprForStructOrScalar(FD); return true; } @@ -1403,12 +1443,51 @@ class SyclKernelBodyCreator MemberExprBases.push_back(BuildMemberExpr(MemberExprBases.back(), FD)); } - void leaveStruct(const CXXRecordDecl *, FieldDecl *FD) final { + void enterStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { + CXXCastPath BasePath; + QualType DerivedTy(RD->getTypeForDecl(), 0); + QualType BaseTy = BS.getType(); + SemaRef.CheckDerivedToBaseConversion(DerivedTy, BaseTy, SourceLocation(), + SourceRange(), &BasePath, + /*IgnoreBaseAccess*/ true); + auto Cast = ImplicitCastExpr::Create( + SemaRef.Context, BaseTy, CK_DerivedToBase, MemberExprBases.back(), + /* CXXCastPath=*/&BasePath, VK_LValue); + MemberExprBases.push_back(Cast); + } + + void addStructInit(const CXXRecordDecl *RD){ + if (!RD) + return; + + int NumberOfFields = getFieldNumber(RD); + int popOut = NumberOfFields + RD->getNumBases(); + llvm::SmallVector BaseInitExprs; + for (int I = 0; I < popOut; I++) { + BaseInitExprs.push_back(InitExprs.back()); + InitExprs.pop_back(); + } + std::reverse(BaseInitExprs.begin(), BaseInitExprs.end()); + + Expr *ILE = new (SemaRef.getASTContext()) + InitListExpr(SemaRef.getASTContext(), SourceLocation(), BaseInitExprs, + SourceLocation()); + ILE->setType(QualType(RD->getTypeForDecl(), 0)); + InitExprs.push_back(ILE); + MemberExprBases.pop_back(); } - using SyclKernelFieldHandler::enterStruct; - using SyclKernelFieldHandler::leaveStruct; + void leaveStruct(const CXXRecordDecl *, FieldDecl *FD) final { + const CXXRecordDecl *RD = FD->getType()->getAsCXXRecordDecl(); + addStructInit(RD); + } + + void leaveStruct(const CXXRecordDecl *RD, const CXXBaseSpecifier &BS) final { + const CXXRecordDecl *BaseClass = BS.getType()->getAsCXXRecordDecl(); + addStructInit(BaseClass); + } + }; class SyclKernelIntHeaderCreator @@ -1512,7 +1591,7 @@ class SyclKernelIntHeaderCreator return true; } bool handleStructType(FieldDecl *FD, QualType FieldTy) final { - addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); + // addParam(FD, FieldTy, SYCLIntegrationHeader::kind_std_layout); return true; } bool handleScalarType(FieldDecl *FD, QualType FieldTy) final { @@ -1606,7 +1685,9 @@ void Sema::ConstructOpenCLKernel(FunctionDecl *KernelCallerFunc, StableName); ConstructingOpenCLKernel = true; - VisitRecordFields(KernelLambda->fields(), checker, kernel_decl, kernel_body, + VisitFunctorBases(KernelLambda, checker, kernel_decl, kernel_body, + int_header); + VisitRecordFields(KernelLambda, checker, kernel_decl, kernel_body, int_header); ConstructingOpenCLKernel = false; } diff --git a/clang/test/CodeGenSYCL/integration_header.cpp b/clang/test/CodeGenSYCL/integration_header.cpp index 58d0c3addcd8..1c766b2dccd3 100644 --- a/clang/test/CodeGenSYCL/integration_header.cpp +++ b/clang/test/CodeGenSYCL/integration_header.cpp @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -fsyntax-only +// RUN: %clang_cc1 -I %S/Inputs -fsycl -fsycl-is-device -triple spir64-unknown-unknown-sycldevice -fsycl-int-header=%t.h %s -emit-llvm // RUN: FileCheck -input-file=%t.h %s // // CHECK: #include @@ -28,9 +28,11 @@ // CHECK-NEXT: const kernel_param_desc_t kernel_signatures[] = { // CHECK-NEXT: //--- _ZTSZ4mainE12first_kernel // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 4 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 16 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_sampler, 8, 32 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 1, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 12 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 24 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_sampler, 8, 40 }, // CHECK-EMPTY: // CHECK-NEXT: //--- _ZTSN16second_namespace13second_kernelIcEE // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, @@ -46,12 +48,15 @@ // CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, // CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 6112, 4 }, // CHECK-EMPTY: -// CHECK-NEXT: //--- _ZTSZ4mainE16accessor_in_base -// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 64, 0 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 8 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 24 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 40 }, -// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 52 }, +// CHECK-NEXT: //--- _ZTSZ4mainE16accessor_in_base +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 0 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 4 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 8 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 20 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 24 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_std_layout, 4, 36 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 40 }, +// CHECK-NEXT: { kernel_param_kind_t::kind_accessor, 4062, 52 }, // CHECK-EMPTY: // CHECK-NEXT: }; // @@ -116,15 +121,13 @@ int main() { acc2; int i = 13; cl::sycl::sampler smplr; - // TODO: Uncomemnt when structures in kernel arguments are correctly processed - // by SYCL compiler - /* struct { + struct { char c; int i; } test_s; - test_s.c = 14;*/ + test_s.c = 14; kernel_single_task([=]() { - if (i == 13 /*&& test_s.c == 14*/) { + if (i == 13 && test_s.c == 14) { acc1.use(); acc2.use(); @@ -151,10 +154,9 @@ int main() { } }); - // FIXME: We cannot use the member-capture because all the handlers except the - // integration header handler in SemaSYCL don't handle base types right. accessor_in_base::captured c; - kernel_single_task([c]() { + kernel_single_task([=]() { + c.use(); }); return 0;