Skip to content

Commit

Permalink
[ESIMD] ESIMDOptimizeVecArgCallConv: allow more IR patterns. (#6919)
Browse files Browse the repository at this point in the history
Allow all-zero GEPs in optimized ptr param use-def chains.

Signed-off-by: Konstantin S Bobrovsky <konstantin.s.bobrovsky@intel.com>
  • Loading branch information
kbobrovs committed Sep 30, 2022
1 parent 53d9c7b commit 4926454
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 6 deletions.
18 changes: 18 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,34 @@ inline void assert_and_diag(bool Condition, StringRef Msg,
/// Tells if this value is a bit cast or address space cast.
bool isCast(const Value *V);

/// Tells if this value is a GEP instruction with all zero indices.
bool isZeroGEP(const Value *V);

/// Climbs up the use-def chain of given value until a value which is not a
/// bit cast or address space cast is met.
const Value *stripCasts(const Value *V);
Value *stripCasts(Value *V);

/// Climbs up the use-def chain of given value until a value is met which is
/// neither of:
/// - bit cast
/// - address space cast
/// - GEP instruction with all zero indices
const Value *stripCastsAndZeroGEPs(const Value *V);
Value *stripCastsAndZeroGEPs(Value *V);

/// Collects uses of given value "looking through" casts. I.e. if a use is a
/// cast (chain), then uses of the result of the cast (chain) are collected.
void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Collects uses of given pointer-typed value "looking through" casts and GEPs
/// with all zero indices - those pointer transformation instructions which
/// don't change pointed-to value. E.g. if a use is a cast (chain), then uses of
/// the result of the cast (chain) are collected.
void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Unwraps a presumably simd* type to extract the native vector type encoded
/// in it. Returns nullptr if failed to do so.
Type *getVectorTyOrNull(StructType *STy);
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDOptimizeVecArgCallConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
if (Uses.size() == 0) {
return nullptr;
}
Value *Addr = esimd::stripCasts((*Uses.begin())->get());
Value *Addr = esimd::stripCastsAndZeroGEPs((*Uses.begin())->get());

for (const auto *UU : Uses) {
const User *U = UU->getUser();
Expand All @@ -92,7 +92,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
}

if (const auto *SI = dyn_cast<StoreInst>(U)) {
if (esimd::stripCasts(SI->getPointerOperand()) != Addr) {
if (esimd::stripCastsAndZeroGEPs(SI->getPointerOperand()) != Addr) {
// the pointer escapes into memory
return nullptr;
}
Expand Down Expand Up @@ -167,7 +167,7 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
// }
{
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCasts(&FormalParam, Uses);
esimd::collectUsesLookThroughCastsAndZeroGEPs(&FormalParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;
ContentT = getMemTypeIfSameAddressLoadsStores(Uses, LoadMet, StoreMet);
Expand Down Expand Up @@ -225,14 +225,14 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
if (!Call || (Call->getCalledFunction() != F)) {
return nullptr;
}
auto ArgNo = FormalParam.getArgNo();
Value *ActualParam = esimd::stripCasts(Call->getArgOperand(ArgNo));
Value *ActualParam = esimd::stripCastsAndZeroGEPs(
Call->getArgOperand(FormalParam.getArgNo()));

if (!IsSret && !isa<AllocaInst>(ActualParam)) {
return nullptr;
}
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCasts(ActualParam, Uses);
esimd::collectUsesLookThroughCastsAndZeroGEPs(ActualParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;

Expand Down
44 changes: 44 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ bool isCast(const Value *V) {
return (Opc == Instruction::BitCast) || (Opc == Instruction::AddrSpaceCast);
}

/// Returns true only when \p V is a GetElementPtrInst for which
/// hasAllZeroIndices() holds; any other value yields false.
bool isZeroGEP(const Value *V) {
  if (const auto *GEP = dyn_cast<GetElementPtrInst>(V))
    return GEP->hasAllZeroIndices();
  return false;
}

const Value *stripCasts(const Value *V) {
if (!V->getType()->isPtrOrPtrVectorTy())
return V;
Expand All @@ -110,6 +115,30 @@ Value *stripCasts(Value *V) {
return const_cast<Value *>(stripCasts(const_cast<const Value *>(V)));
}

/// Walks up the use-def chain of \p V, stepping over every value recognized by
/// isCast() or isZeroGEP(), and returns the first value that is neither.
/// A value whose type is not a pointer (or vector of pointers) is returned
/// unchanged.
const Value *stripCastsAndZeroGEPs(const Value *V) {
  if (!V->getType()->isPtrOrPtrVectorTy())
    return V;

  // PHI nodes are never looked through, yet the chain can still be cyclic when
  // the starting instruction lives in an unreachable block. Remember every
  // value seen so the walk is guaranteed to terminate.
  SmallPtrSet<const Value *, 4> Seen;
  Seen.insert(V);

  for (;;) {
    const Value *Next = V;

    if (isCast(V))
      Next = cast<Operator>(V)->getOperand(0);
    else if (isZeroGEP(V))
      Next = cast<GetElementPtrInst>(V)->getOperand(0);

    assert(Next->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!");
    V = Next;

    // Insertion fails either when nothing was stripped this round (V is
    // already in the set) or when a cycle closed - in both cases we are done.
    if (!Seen.insert(V).second)
      return V;
  }
}

/// Non-const convenience overload: forwards to the const version and removes
/// the constness from the result again.
Value *stripCastsAndZeroGEPs(Value *V) {
  const Value *Stripped = stripCastsAndZeroGEPs(static_cast<const Value *>(V));
  return const_cast<Value *>(Stripped);
}

void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses) {
for (const Use &U : V->uses()) {
Expand All @@ -123,6 +152,21 @@ void collectUsesLookThroughCasts(const Value *V,
}
}

/// Gathers into \p Uses every use of the pointer-typed value \p V. When the
/// user of a use is itself recognized by isCast() or isZeroGEP(), that use is
/// not recorded; instead the uses of the user are gathered recursively.
void collectUsesLookThroughCastsAndZeroGEPs(
    const Value *V, SmallPtrSetImpl<const Use *> &Uses) {
  assert(V->getType()->isPtrOrPtrVectorTy() && "pointer type expected");

  for (const Use &TheUse : V->uses()) {
    Value *Consumer = TheUse.getUser();
    const bool IsTransparent = isCast(Consumer) || isZeroGEP(Consumer);

    if (!IsTransparent) {
      // A "real" use - record it.
      Uses.insert(&TheUse);
      continue;
    }
    // Cast or all-zero GEP: descend into the uses of its result.
    collectUsesLookThroughCastsAndZeroGEPs(Consumer, Uses);
  }
}

Type *getVectorTyOrNull(StructType *STy) {
Type *Res = nullptr;
while (STy && (STy->getStructNumElements() == 1)) {
Expand Down
32 changes: 32 additions & 0 deletions llvm/test/SYCLLowerIR/ESIMD/vec_arg_call_conv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,38 @@ entry:
ret void
}

;----- Test4: IR contains all-zero GEP instructions in parameter use-def chains
; Based on Test2.
; The callee gets its result slot through the sret pointer %agg.result and a
; vector operand through the pointer %x. %x is read via an addrspacecast
; followed by a GEP whose indices are all zero. The expected output below has
; the function returning <8 x i32> and taking the vector operand by value.
define dso_local spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef %i, ptr noundef %x, i32 noundef %j) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
; CHECK: define dso_local spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 noundef %{{.*}}, <8 x i32> %{{.*}}, i32 noundef %{{.*}})
entry:
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
; Broadcast %i + %j into all 8 lanes of a vector.
%add = add nsw i32 %i, %j
%splat.splatinsert.i.i.i = insertelement <8 x i32> poison, i32 %add, i64 0
%splat.splat.i.i.i = shufflevector <8 x i32> %splat.splatinsert.i.i.i, <8 x i32> poison, <8 x i32> zeroinitializer
; All-zero GEP on top of the address space cast - the pattern under test.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
%add.i.i.i.i.i = add <8 x i32> %call.i.i.i1, %splat.splat.i.i.i
; The sum is written to the sret slot.
store <8 x i32> %add.i.i.i.i.i, ptr addrspace(4) %agg.result, align 32
ret void
}

;----- Test4 caller.
; Copies the vector pointed to by %x (read through an addrspacecast and an
; all-zero GEP) into the local %agg.tmp, then calls the Test4 callee with
; %agg.result as the sret slot and %agg.tmp as the vector operand. The expected
; output below has both this function and the call using <8 x i32> values
; instead of pointers.
; Function Attrs: convergent noinline norecurse
define dso_local spir_func void @_Z21test__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, ptr noundef %x) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
; CHECK: define dso_local spir_func <8 x i32> @_Z21test__sret__x_param_x1(<8 x i32> %{{.*}})
entry:
; Temporary that holds the vector argument passed to the callee by pointer.
%agg.tmp = alloca %"class.sycl::_V1::ext::intel::esimd::simd.2", align 32
%agg.tmp.ascast = addrspacecast ptr %agg.tmp to ptr addrspace(4)
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
; All-zero GEP in the use-def chain of the incoming parameter %x.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
store <8 x i32> %call.i.i.i1, ptr addrspace(4) %agg.tmp.ascast, align 32
call spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef 2, ptr noundef nonnull %agg.tmp, i32 noundef 1) #7
; CHECK: %{{.*}} = call spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 2, <8 x i32> %{{.*}}, i32 1)
ret void
}

attributes #0 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }
attributes #1 = { alwaysinline convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="12288" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }
Expand Down

0 comments on commit 4926454

Please sign in to comment.