diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp index 9f1616f6960fe..5f18c37ef1125 100644 --- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp +++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp @@ -29,20 +29,6 @@ static const int MaxVecSize = 4; using namespace llvm; -// Recursively creates an array-like version of a given vector type. -static Type *equivalentArrayTypeFromVector(Type *T) { - if (auto *VecTy = dyn_cast(T)) - return ArrayType::get(VecTy->getElementType(), - dyn_cast(VecTy)->getNumElements()); - if (auto *ArrayTy = dyn_cast(T)) { - Type *NewElementType = - equivalentArrayTypeFromVector(ArrayTy->getElementType()); - return ArrayType::get(NewElementType, ArrayTy->getNumElements()); - } - // If it's not a vector or array, return the original type. - return T; -} - class DXILDataScalarizationLegacy : public ModulePass { public: @@ -121,12 +107,25 @@ DataScalarizerVisitor::lookupReplacementGlobal(Value *CurrOperand) { static bool isVectorOrArrayOfVectors(Type *T) { if (isa(T)) return true; - if (ArrayType *ArrType = dyn_cast(T)) - return isa(ArrType->getElementType()) || - isVectorOrArrayOfVectors(ArrType->getElementType()); + if (ArrayType *ArrayTy = dyn_cast(T)) + return isVectorOrArrayOfVectors(ArrayTy->getElementType()); return false; } +// Recursively creates an array-like version of a given vector type. +static Type *equivalentArrayTypeFromVector(Type *T) { + if (auto *VecTy = dyn_cast(T)) + return ArrayType::get(VecTy->getElementType(), + dyn_cast(VecTy)->getNumElements()); + if (auto *ArrayTy = dyn_cast(T)) { + Type *NewElementType = + equivalentArrayTypeFromVector(ArrayTy->getElementType()); + return ArrayType::get(NewElementType, ArrayTy->getNumElements()); + } + // If it's not a vector or array, return the original type. + return T; +} + bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) { Type *AllocatedType = AI.getAllocatedType(); if (!isVectorOrArrayOfVectors(AllocatedType)) @@ -135,7 +134,7 @@ bool DataScalarizerVisitor::visitAllocaInst(AllocaInst &AI) { IRBuilder<> Builder(&AI); Type *NewType = equivalentArrayTypeFromVector(AllocatedType); AllocaInst *ArrAlloca = - Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarize"); + Builder.CreateAlloca(NewType, nullptr, AI.getName() + ".scalarized"); ArrAlloca->setAlignment(AI.getAlign()); AI.replaceAllUsesWith(ArrAlloca); AI.eraseFromParent(); @@ -303,78 +302,44 @@ bool DataScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { GEPOperator *GOp = cast(&GEPI); Value *PtrOperand = GOp->getPointerOperand(); - Type *NewGEPType = GOp->getSourceElementType(); - - // Unwrap GEP ConstantExprs to find the base operand and element type - while (auto *GEPCE = dyn_cast_or_null( - dyn_cast(PtrOperand))) { - GOp = GEPCE; - PtrOperand = GEPCE->getPointerOperand(); - NewGEPType = GEPCE->getSourceElementType(); - } - - Type *const OrigGEPType = NewGEPType; - Value *const OrigOperand = PtrOperand; - - if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) { - NewGEPType = NewGlobal->getValueType(); - PtrOperand = NewGlobal; - } else if (AllocaInst *Alloca = dyn_cast(PtrOperand)) { - Type *AllocatedType = Alloca->getAllocatedType(); - if (isa(AllocatedType) && - AllocatedType != GOp->getResultElementType()) - NewGEPType = AllocatedType; - } else - return false; // Only GEPs into an alloca or global variable are considered - - // Defer changing i8 GEP types until dxil-flatten-arrays - if (OrigGEPType->isIntegerTy(8)) - NewGEPType = OrigGEPType; - - // If the original type is a "sub-type" of the new type, then ensure the gep - // correctly zero-indexes the extra dimensions to keep the offset calculation - // correct. - // Eg: - // i32, [4 x i32] and [8 x [4 x i32]] are sub-types of [8 x [4 x i32]], etc. - // - // So then: - // gep [4 x i32] %idx - // -> gep [8 x [4 x i32]], i32 0, i32 %idx - // gep i32 %idx - // -> gep [8 x [4 x i32]], i32 0, i32 0, i32 %idx - uint32_t MissingDims = 0; - Type *SubType = NewGEPType; - - // The new type will be in its array version; so match accordingly. - Type *const GEPArrType = equivalentArrayTypeFromVector(OrigGEPType); - - while (SubType != GEPArrType) { - MissingDims++; - - ArrayType *ArrType = dyn_cast(SubType); - if (!ArrType) { - assert(SubType == GEPArrType && - "GEP uses an DXIL invalid sub-type of alloca/global variable"); - break; - } - - SubType = ArrType->getElementType(); + Type *GEPType = GOp->getSourceElementType(); + + // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that + // it can be visited + if (auto *PtrOpGEPCE = dyn_cast(PtrOperand); + PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { + GetElementPtrInst *OldGEPI = + cast(PtrOpGEPCE->getAsInstruction()); + OldGEPI->insertBefore(GEPI.getIterator()); + + IRBuilder<> Builder(&GEPI); + SmallVector Indices(GEPI.indices()); + Value *NewGEP = + Builder.CreateGEP(GEPI.getSourceElementType(), OldGEPI, Indices, + GEPI.getName(), GEPI.getNoWrapFlags()); + assert(isa(NewGEP) && + "Expected newly-created GEP to be an instruction"); + GetElementPtrInst *NewGEPI = cast(NewGEP); + + GEPI.replaceAllUsesWith(NewGEPI); + GEPI.eraseFromParent(); + visitGetElementPtrInst(*OldGEPI); + visitGetElementPtrInst(*NewGEPI); + return true; } - bool NeedsTransform = OrigOperand != PtrOperand || - OrigGEPType != NewGEPType || MissingDims != 0; + Type *NewGEPType = equivalentArrayTypeFromVector(GEPType); + Value *NewPtrOperand = PtrOperand; + if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) + NewPtrOperand = NewGlobal; + bool NeedsTransform = NewPtrOperand != PtrOperand || NewGEPType != GEPType; if (!NeedsTransform) return false; IRBuilder<> Builder(&GEPI); - SmallVector Indices; - - for (uint32_t I = 0; I < MissingDims; I++) - Indices.push_back(Builder.getInt32(0)); - llvm::append_range(Indices, GOp->indices()); - - Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices, + SmallVector Indices(GOp->idx_begin(), GOp->idx_end()); + Value *NewGEP = Builder.CreateGEP(NewGEPType, NewPtrOperand, Indices, GOp->getName(), GOp->getNoWrapFlags()); GOp->replaceAllUsesWith(NewGEP); diff --git a/llvm/test/CodeGen/DirectX/bugfix_150050_data_scalarize_const_gep.ll b/llvm/test/CodeGen/DirectX/bugfix_150050_data_scalarize_const_gep.ll index 156a8e7c5c386..def886f933d08 100644 --- a/llvm/test/CodeGen/DirectX/bugfix_150050_data_scalarize_const_gep.ll +++ b/llvm/test/CodeGen/DirectX/bugfix_150050_data_scalarize_const_gep.ll @@ -11,9 +11,10 @@ define void @CSMain() { ; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE:%.*]] = alloca [4 x i32], align 16 ; -; SCHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [10 x <4 x i32>], ptr addrspace(3) getelementptr inbounds ([10 x [10 x [4 x i32]]], ptr addrspace(3) @aTile.scalarized, i32 0, i32 1), i32 0, i32 2 -; SCHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr addrspace(3) [[TMP0]], align 16 -; SCHECK-NEXT: store <4 x i32> [[TMP1]], ptr [[AFRAGPACKED_I_SCALARIZE]], align 16 +; SCHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds [10 x [10 x [4 x i32]]], ptr addrspace(3) @aTile.scalarized, i32 0, i32 1 +; SCHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds [10 x [4 x i32]], ptr addrspace(3) [[GEP0]], i32 0, i32 2 +; SCHECK-NEXT: [[LOAD:%.*]] = load <4 x i32>, ptr addrspace(3) [[GEP1]], align 16 +; SCHECK-NEXT: store <4 x i32> [[LOAD]], ptr [[AFRAGPACKED_I_SCALARIZE]], align 16 ; ; FCHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE_I14:%.*]] = getelementptr [4 x i32], ptr [[AFRAGPACKED_I_SCALARIZE]], i32 0, i32 1 ; FCHECK-NEXT: [[AFRAGPACKED_I_SCALARIZE_I25:%.*]] = getelementptr [4 x i32], ptr [[AFRAGPACKED_I_SCALARIZE]], i32 0, i32 2 @@ -40,12 +41,13 @@ define void @Main() { ; CHECK-NEXT: [[ENTRY:.*:]] ; CHECK-NEXT: [[BFRAGPACKED_I:%.*]] = alloca i32, align 16 ; -; SCHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [10 x i32], ptr addrspace(3) getelementptr inbounds ([10 x [10 x i32]], ptr addrspace(3) @bTile, i32 0, i32 1), i32 0, i32 1 -; SCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr addrspace(3) [[TMP0]], align 16 -; SCHECK-NEXT: store i32 [[TMP1]], ptr [[BFRAGPACKED_I]], align 16 +; SCHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds [10 x [10 x i32]], ptr addrspace(3) @bTile, i32 0, i32 1 +; SCHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds [10 x i32], ptr addrspace(3) [[GEP0]], i32 0, i32 1 +; SCHECK-NEXT: [[LOAD:%.*]] = load i32, ptr addrspace(3) [[GEP1]], align 16 +; SCHECK-NEXT: store i32 [[LOAD]], ptr [[BFRAGPACKED_I]], align 16 ; -; FCHECK-NEXT: [[TMP0:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([100 x i32], ptr addrspace(3) @bTile.1dim, i32 0, i32 11), align 16 -; FCHECK-NEXT: store i32 [[TMP0]], ptr [[BFRAGPACKED_I]], align 16 +; FCHECK-NEXT: [[LOAD:%.*]] = load i32, ptr addrspace(3) getelementptr inbounds ([100 x i32], ptr addrspace(3) @bTile.1dim, i32 0, i32 11), align 16 +; FCHECK-NEXT: store i32 [[LOAD]], ptr [[BFRAGPACKED_I]], align 16 ; ; CHECK-NEXT: ret void entry: @@ -57,10 +59,12 @@ entry: define void @global_nested_geps_3d() { ; CHECK-LABEL: define void @global_nested_geps_3d() { -; SCHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <2 x i32>, ptr getelementptr inbounds ([2 x <2 x i32>], ptr getelementptr inbounds ([2 x [2 x [2 x i32]]], ptr @cTile.scalarized, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1 -; SCHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; SCHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr @cTile.scalarized, i32 0, i32 1 +; SCHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds [2 x [2 x i32]], ptr [[GEP0]], i32 0, i32 1 +; SCHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds [2 x i32], ptr [[GEP1]], i32 0, i32 1 +; SCHECK-NEXT: load i32, ptr [[GEP2]], align 4 ; -; FCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr getelementptr inbounds ([8 x i32], ptr @cTile.scalarized.1dim, i32 0, i32 7), align 4 +; FCHECK-NEXT: load i32, ptr getelementptr inbounds ([8 x i32], ptr @cTile.scalarized.1dim, i32 0, i32 7), align 4 ; ; CHECK-NEXT: ret void %1 = load i32, i32* getelementptr inbounds (<2 x i32>, <2 x i32>* getelementptr inbounds ([2 x <2 x i32>], [2 x <2 x i32>]* getelementptr inbounds ([2 x [2 x <2 x i32>]], [2 x [2 x <2 x i32>]]* @cTile, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), align 4 @@ -69,10 +73,13 @@ define void @global_nested_geps_3d() { define void @global_nested_geps_4d() { ; CHECK-LABEL: define void @global_nested_geps_4d() { -; SCHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds <2 x i32>, ptr getelementptr inbounds ([2 x <2 x i32>], ptr getelementptr inbounds ([2 x [2 x <2 x i32>]], ptr getelementptr inbounds ([2 x [2 x [2 x [2 x i32]]]], ptr @dTile.scalarized, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), i32 0, i32 1 -; SCHECK-NEXT: [[TMP2:%.*]] = load i32, ptr [[TMP1]], align 4 +; SCHECK-NEXT: [[GEP0:%.*]] = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], ptr @dTile.scalarized, i32 0, i32 1 +; SCHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds [2 x [2 x [2 x i32]]], ptr [[GEP0]], i32 0, i32 1 +; SCHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds [2 x [2 x i32]], ptr [[GEP1]], i32 0, i32 1 +; SCHECK-NEXT: [[GEP3:%.*]] = getelementptr inbounds [2 x i32], ptr [[GEP2]], i32 0, i32 1 +; SCHECK-NEXT: load i32, ptr [[GEP3]], align 4 ; -; FCHECK-NEXT: [[TMP1:%.*]] = load i32, ptr getelementptr inbounds ([16 x i32], ptr @dTile.scalarized.1dim, i32 0, i32 15), align 4 +; FCHECK-NEXT: load i32, ptr getelementptr inbounds ([16 x i32], ptr @dTile.scalarized.1dim, i32 0, i32 15), align 4 ; ; CHECK-NEXT: ret void %1 = load i32, i32* getelementptr inbounds (<2 x i32>, <2 x i32>* getelementptr inbounds ([2 x <2 x i32>], [2 x <2 x i32>]* getelementptr inbounds ([2 x [2 x <2 x i32>]], [2 x [2 x <2 x i32>]]* getelementptr inbounds ([2 x [2 x [2 x <2 x i32>]]], [2 x [2 x [2 x <2 x i32>]]]* @dTile, i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), i32 0, i32 1), align 4 diff --git a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll index 475935d2eb135..85e3bb0185e44 100644 --- a/llvm/test/CodeGen/DirectX/scalarize-alloca.ll +++ b/llvm/test/CodeGen/DirectX/scalarize-alloca.ll @@ -48,7 +48,7 @@ define void @subtype_array_test() { ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4 ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4 ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]] @@ -64,7 +64,7 @@ define void @subtype_vector_test() { ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4 ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4 ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr [[alloca_val]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]] @@ -80,7 +80,7 @@ define void @subtype_scalar_test() { ; SCHECK: [[alloca_val:%.*]] = alloca [8 x [4 x i32]], align 4 ; FCHECK: [[alloca_val:%.*]] = alloca [32 x i32], align 4 ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr [[alloca_val]], i32 0, i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i32, ptr [[alloca_val]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr [[alloca_val]], i32 0, i32 [[flatidx]] diff --git a/llvm/test/CodeGen/DirectX/scalarize-global.ll b/llvm/test/CodeGen/DirectX/scalarize-global.ll index ca10f6ece5a85..c27dc4083bfd3 100644 --- a/llvm/test/CodeGen/DirectX/scalarize-global.ll +++ b/llvm/test/CodeGen/DirectX/scalarize-global.ll @@ -11,7 +11,7 @@ ; CHECK-LABEL: subtype_array_test define <4 x i32> @subtype_array_test() { ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr addrspace(3) [[arrayofVecData]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]] @@ -26,7 +26,7 @@ define <4 x i32> @subtype_array_test() { ; CHECK-LABEL: subtype_vector_test define <4 x i32> @subtype_vector_test() { ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [4 x i32], ptr addrspace(3) [[arrayofVecData]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 4 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]] @@ -41,7 +41,7 @@ define <4 x i32> @subtype_vector_test() { ; CHECK-LABEL: subtype_scalar_test define <4 x i32> @subtype_scalar_test() { ; CHECK: [[tid:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) - ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw [8 x [4 x i32]], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 0, i32 [[tid]] + ; SCHECK: [[gep:%.*]] = getelementptr inbounds nuw i32, ptr addrspace(3) [[arrayofVecData]], i32 [[tid]] ; FCHECK: [[flatidx_mul:%.*]] = mul i32 [[tid]], 1 ; FCHECK: [[flatidx:%.*]] = add i32 0, [[flatidx_mul]] ; FCHECK: [[gep:%.*]] = getelementptr inbounds nuw [32 x i32], ptr addrspace(3) [[arrayofVecData]], i32 0, i32 [[flatidx]]