Skip to content

Commit

Permalink
[WebAssembly] Skip implied bitmask operation in LowerShift
Browse files Browse the repository at this point in the history
This patch skips redundant explicit masking of the shift count, since
the mask is already implied by the WebAssembly shift instruction.

Differential Revision: https://reviews.llvm.org/D144619
  • Loading branch information
junparser committed Mar 2, 2023
1 parent af2969f commit 403926a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 25 deletions.
35 changes: 34 additions & 1 deletion llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
Expand Up @@ -2287,10 +2287,43 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
// Only manually lower vector shifts
assert(Op.getSimpleValueType().isVector());

auto ShiftVal = DAG.getSplatValue(Op.getOperand(1));
uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
auto ShiftVal = Op.getOperand(1);

// Try to skip bitmask operation since it is implied inside shift instruction
auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
if (MaskOp.getOpcode() != ISD::AND)
return MaskOp;
SDValue LHS = MaskOp.getOperand(0);
SDValue RHS = MaskOp.getOperand(1);
if (MaskOp.getValueType().isVector()) {
APInt MaskVal;
if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
std::swap(LHS, RHS);

if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
MaskVal == MaskBits)
MaskOp = LHS;
} else {
if (!isa<ConstantSDNode>(RHS.getNode()))
std::swap(LHS, RHS);

auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
MaskOp = LHS;
}

return MaskOp;
};

// Skip vector and operation
ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
ShiftVal = DAG.getSplatValue(ShiftVal);
if (!ShiftVal)
return unrollVectorShift(Op, DAG);

// Skip scalar and operation
ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
// Use anyext because none of the high bits can affect the shift
ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);

Expand Down
71 changes: 47 additions & 24 deletions llvm/test/CodeGen/WebAssembly/masked-shifts.ll
Expand Up @@ -106,10 +106,6 @@ define <16 x i8> @shl_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i8x16.splat
; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
; CHECK-NEXT: v128.and
; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shl
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
Expand Down Expand Up @@ -145,10 +141,6 @@ define <16 x i8> @ashr_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i8x16.splat
; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
; CHECK-NEXT: v128.and
; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shr_s
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
Expand Down Expand Up @@ -184,10 +176,6 @@ define <16 x i8> @lshr_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i8x16.splat
; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
; CHECK-NEXT: v128.and
; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shr_u
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
Expand Down Expand Up @@ -222,10 +210,6 @@ define <8 x i16> @shl_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i16x8.splat
; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
; CHECK-NEXT: v128.and
; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shl
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
Expand Down Expand Up @@ -259,10 +243,6 @@ define <8 x i16> @ashr_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i16x8.splat
; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
; CHECK-NEXT: v128.and
; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shr_s
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
Expand Down Expand Up @@ -296,10 +276,6 @@ define <8 x i16> @lshr_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i16x8.splat
; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
; CHECK-NEXT: v128.and
; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shr_u
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
Expand Down Expand Up @@ -519,6 +495,22 @@ define <2 x i64> @shl_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}

define <2 x i64> @shl_v2i64_i32_late(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: shl_v2i64_i32_late:
; CHECK: .functype shl_v2i64_i32_late (v128, i32) -> (v128)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64x2.shl
; CHECK-NEXT: # fallthrough-return
%z = zext i32 %x to i64
%t = insertelement <2 x i64> undef, i64 %z, i32 0
%s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
%m = and <2 x i64> %s, <i64 63, i64 63>
%a = shl <2 x i64> %v, %m
ret <2 x i64> %a
}

define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: ashr_v2i64_i32:
; CHECK: .functype ashr_v2i64_i32 (v128, i32) -> (v128)
Expand All @@ -535,6 +527,22 @@ define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}

define <2 x i64> @ashr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: ashr_v2i64_i32_late:
; CHECK: .functype ashr_v2i64_i32_late (v128, i32) -> (v128)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64x2.shr_s
; CHECK-NEXT: # fallthrough-return
%z = zext i32 %x to i64
%t = insertelement <2 x i64> undef, i64 %z, i32 0
%s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
%m = and <2 x i64> %s, <i64 63, i64 63>
%a = ashr <2 x i64> %v, %m
ret <2 x i64> %a
}

define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: lshr_v2i64_i32:
; CHECK: .functype lshr_v2i64_i32 (v128, i32) -> (v128)
Expand All @@ -551,3 +559,18 @@ define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}

define <2 x i64> @lshr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: lshr_v2i64_i32_late:
; CHECK: .functype lshr_v2i64_i32_late (v128, i32) -> (v128)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64x2.shr_u
; CHECK-NEXT: # fallthrough-return
%z = zext i32 %x to i64
%t = insertelement <2 x i64> undef, i64 %z, i32 0
%s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
%m = and <2 x i64> %s, <i64 63, i64 63>
%a = lshr <2 x i64> %v, %m
ret <2 x i64> %a
}

0 comments on commit 403926a

Please sign in to comment.