Skip to content

Commit

Permalink
[InstCombine] Fold for masked scatters to a uniform address
Browse files Browse the repository at this point in the history
When a masked scatter intrinsic stores a source vector through a uniform
(splat) destination address and the mask is all ones, only the last lane's
value is observable after the store. This patch replaces such a masked
scatter with an extract of the last lane of the source vector followed by a
scalar store to the destination pointer.
This patch also folds the case where the stored value is a splat: provided
the mask is known not to be all zero, the scatter folds to a scalar store
of the splat value to the destination pointer.

Differential Revision: https://reviews.llvm.org/D115724
  • Loading branch information
CarolineConcatto committed Jan 14, 2022
1 parent 20d9c51 commit 8e5a5b6
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 1 deletion.
29 changes: 28 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Expand Up @@ -362,7 +362,6 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
// * Single constant active lane -> store
// * Adjacent vector addresses -> masked.store
// * Narrow store width by halfs excluding zero/undef lanes
// * Vector splat address w/known mask -> scalar store
// * Vector incrementing address -> vector masked store
Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
Expand All @@ -373,6 +372,34 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);

// Vector splat address -> scalar store
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
// lastlane), ptr
if (ConstMask->isAllOnesValue()) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
ElementCount VF = WideLoadTy->getElementCount();
Constant *EC =
ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
Value *Extract =
Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
StoreInst *S =
new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}
}
if (isa<ScalableVectorType>(ConstMask->getType()))
return nullptr;

Expand Down
107 changes: 107 additions & 0 deletions llvm/test/Transforms/InstCombine/masked_intrinsics.ll
Expand Up @@ -269,3 +269,110 @@ define void @scatter_demandedelts(double* %ptr, double %val) {
call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
ret void
}


; Test scatters that can be simplified to scalar stores.

;; Value splat (mask is not used)
define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
%broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
%broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
%broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
%broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
ret void
}

define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
%broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
%broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
%broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
%broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer , i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
ret void
}

;; The pointer is splat and mask is all active, but value is not a splat
define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
; CHECK-NEXT: store i16 [[TMP0]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
%broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
%broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
%wide.load = load <4 x i16>, <4 x i16>* %src, align 2
call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1>)
ret void
}

define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[TMP0]], 2
; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP1]], -1
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
; CHECK-NEXT: store i16 [[TMP3]], i16* [[DST:%.*]], align 2
; CHECK-NEXT: ret void
;
entry:
%broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
%broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
%wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
ret void
}

; Negative scatter tests

;; Pointer is splat, but mask is not all active and value is not a splat
define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
; CHECK-NEXT: [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> <i32 undef, i32 undef, i32 0, i32 0>
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
%insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
%broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
%wide.load = load <4 x i16>, <4 x i16>* %src, align 2
call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 0, i1 0, i1 1, i1 1>)
ret void
}

;; The pointer in NOT a splat
define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
; CHECK-NEXT: [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
; CHECK-NEXT: call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
; CHECK-NEXT: ret void
;
%broadcast= shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
%wide.load = load <4 x i16>, <4 x i16>* %src, align 2
call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 1, i1 1, i1 1, i1 1> )
ret void
}


; Function Attrs:
declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)

0 comments on commit 8e5a5b6

Please sign in to comment.