Skip to content

Commit

Permalink
8286941: Add mask IR for partial vector operations for ARM SVE
Browse files Browse the repository at this point in the history
Reviewed-by: kvn, jbhateja, njian, ngasson
  • Loading branch information
Xiaohong Gong committed Jul 7, 2022
1 parent 569de45 commit a79ce4e
Show file tree
Hide file tree
Showing 19 changed files with 1,213 additions and 1,251 deletions.
6 changes: 6 additions & 0 deletions src/hotspot/cpu/aarch64/aarch64.ad
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,7 @@ const bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType
return false;
}
break;
case Op_VectorMaskGen:
case Op_LoadVectorGather:
case Op_StoreVectorScatter:
case Op_CompressV:
Expand All @@ -2502,6 +2503,11 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
masked_op_sve_supported(opcode, vlen, bt);
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
// Only SVE has partial vector operations
return (UseSVE > 0) && partial_op_sve_needed(node, vt);
}

const RegMask* Matcher::predicate_reg_mask(void) {
return &_PR_REG_mask;
}
Expand Down
999 changes: 403 additions & 596 deletions src/hotspot/cpu/aarch64/aarch64_sve.ad

Large diffs are not rendered by default.

746 changes: 357 additions & 389 deletions src/hotspot/cpu/aarch64/aarch64_sve_ad.m4

Large diffs are not rendered by default.

27 changes: 17 additions & 10 deletions src/hotspot/cpu/aarch64/assembler_aarch64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3635,20 +3635,27 @@ void sve_fcm(Condition cond, PRegister Pd, SIMD_RegVariant T,
INSN(sve_uzp2, 0b1); // Concatenate odd elements from two predicates
#undef INSN

// Predicate counted loop (SVE) (32-bit variants are not included)
#define INSN(NAME, decode) \
// SVE integer compare scalar count and limit
#define INSN(NAME, sf, op) \
void NAME(PRegister Pd, SIMD_RegVariant T, Register Rn, Register Rm) { \
starti; \
assert(T != Q, "invalid register variant"); \
f(0b00100101, 31, 24), f(T, 23, 22), f(1, 21), \
zrf(Rm, 16), f(0, 15, 13), f(1, 12), f(decode >> 1, 11, 10), \
zrf(Rn, 5), f(decode & 1, 4), prf(Pd, 0); \
}

INSN(sve_whilelt, 0b010); // While incrementing signed scalar less than scalar
INSN(sve_whilele, 0b011); // While incrementing signed scalar less than or equal to scalar
INSN(sve_whilelo, 0b110); // While incrementing unsigned scalar lower than scalar
INSN(sve_whilels, 0b111); // While incrementing unsigned scalar lower than or the same as scalar
zrf(Rm, 16), f(0, 15, 13), f(sf, 12), f(op >> 1, 11, 10), \
zrf(Rn, 5), f(op & 1, 4), prf(Pd, 0); \
}
// While incrementing signed scalar less than scalar
INSN(sve_whileltw, 0b0, 0b010);
INSN(sve_whilelt, 0b1, 0b010);
// While incrementing signed scalar less than or equal to scalar
INSN(sve_whilelew, 0b0, 0b011);
INSN(sve_whilele, 0b1, 0b011);
// While incrementing unsigned scalar lower than scalar
INSN(sve_whilelow, 0b0, 0b110);
INSN(sve_whilelo, 0b1, 0b110);
// While incrementing unsigned scalar lower than or the same as scalar
INSN(sve_whilelsw, 0b0, 0b111);
INSN(sve_whilels, 0b1, 0b111);
#undef INSN

// SVE predicate reverse
Expand Down
94 changes: 63 additions & 31 deletions src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "asm/assembler.inline.hpp"
#include "opto/c2_MacroAssembler.hpp"
#include "opto/intrinsicnode.hpp"
#include "opto/matcher.hpp"
#include "opto/subnode.hpp"
#include "runtime/stubRoutines.hpp"

Expand Down Expand Up @@ -1338,39 +1339,70 @@ void C2_MacroAssembler::sve_reduce_integral(int opc, Register dst, BasicType bt,
}
}

// Set elements of the dst predicate to true if the element number is
// in the range of [0, lane_cnt), or to false otherwise.
void C2_MacroAssembler::sve_ptrue_lanecnt(PRegister dst, SIMD_RegVariant size, int lane_cnt) {
// Set elements of the dst predicate to true for lanes in the range of [0, lane_cnt), or
// to false otherwise. The input "lane_cnt" should be smaller than or equal to the supported
// max vector length of the basic type. Clobbers: rscratch1 and the rFlagsReg.
void C2_MacroAssembler::sve_gen_mask_imm(PRegister dst, BasicType bt, uint32_t lane_cnt) {
uint32_t max_vector_length = Matcher::max_vector_size(bt);
assert(lane_cnt <= max_vector_length, "unsupported input lane_cnt");

// Set all elements to false if the input "lane_cnt" is zero.
if (lane_cnt == 0) {
sve_pfalse(dst);
return;
}

SIMD_RegVariant size = elemType_to_regVariant(bt);
assert(size != Q, "invalid size");

// Set all true if "lane_cnt" equals to the max lane count.
if (lane_cnt == max_vector_length) {
sve_ptrue(dst, size, /* ALL */ 0b11111);
return;
}

// Fixed numbers for "ptrue".
switch(lane_cnt) {
case 1: /* VL1 */
case 2: /* VL2 */
case 3: /* VL3 */
case 4: /* VL4 */
case 5: /* VL5 */
case 6: /* VL6 */
case 7: /* VL7 */
case 8: /* VL8 */
sve_ptrue(dst, size, lane_cnt);
break;
case 16:
sve_ptrue(dst, size, /* VL16 */ 0b01001);
break;
case 32:
sve_ptrue(dst, size, /* VL32 */ 0b01010);
break;
case 64:
sve_ptrue(dst, size, /* VL64 */ 0b01011);
break;
case 128:
sve_ptrue(dst, size, /* VL128 */ 0b01100);
break;
case 256:
sve_ptrue(dst, size, /* VL256 */ 0b01101);
break;
default:
assert(false, "unsupported");
ShouldNotReachHere();
case 1: /* VL1 */
case 2: /* VL2 */
case 3: /* VL3 */
case 4: /* VL4 */
case 5: /* VL5 */
case 6: /* VL6 */
case 7: /* VL7 */
case 8: /* VL8 */
sve_ptrue(dst, size, lane_cnt);
return;
case 16:
sve_ptrue(dst, size, /* VL16 */ 0b01001);
return;
case 32:
sve_ptrue(dst, size, /* VL32 */ 0b01010);
return;
case 64:
sve_ptrue(dst, size, /* VL64 */ 0b01011);
return;
case 128:
sve_ptrue(dst, size, /* VL128 */ 0b01100);
return;
case 256:
sve_ptrue(dst, size, /* VL256 */ 0b01101);
return;
default:
break;
}

// Special patterns for "ptrue".
if (lane_cnt == round_down_power_of_2(max_vector_length)) {
sve_ptrue(dst, size, /* POW2 */ 0b00000);
} else if (lane_cnt == max_vector_length - (max_vector_length % 4)) {
sve_ptrue(dst, size, /* MUL4 */ 0b11101);
} else if (lane_cnt == max_vector_length - (max_vector_length % 3)) {
sve_ptrue(dst, size, /* MUL3 */ 0b11110);
} else {
// Encode to "whilelow" for the remaining cases.
mov(rscratch1, lane_cnt);
sve_whilelow(dst, size, zr, rscratch1);
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/hotspot/cpu/aarch64/c2_MacroAssembler_aarch64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@
void sve_reduce_integral(int opc, Register dst, BasicType bt, Register src1,
FloatRegister src2, PRegister pg, FloatRegister tmp);

// Set elements of the dst predicate to true if the element number is
// in the range of [0, lane_cnt), or to false otherwise.
void sve_ptrue_lanecnt(PRegister dst, SIMD_RegVariant size, int lane_cnt);
// Set elements of the dst predicate to true for lanes in the range of
// [0, lane_cnt), or to false otherwise. The input "lane_cnt" should be
// smaller than or equal to the supported max vector length of the basic
// type. Clobbers: rscratch1 and the rFlagsReg.
void sve_gen_mask_imm(PRegister dst, BasicType bt, uint32_t lane_cnt);

// Extract a scalar element from an sve vector at position 'idx'.
// The input elements in src are expected to be of integral type.
Expand Down
4 changes: 4 additions & 0 deletions src/hotspot/cpu/arm/arm.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,10 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
return false;
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
return false;
}

const RegMask* Matcher::predicate_reg_mask(void) {
return NULL;
}
Expand Down
4 changes: 4 additions & 0 deletions src/hotspot/cpu/ppc/ppc.ad
Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,10 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
return false;
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
return false;
}

const RegMask* Matcher::predicate_reg_mask(void) {
return NULL;
}
Expand Down
4 changes: 4 additions & 0 deletions src/hotspot/cpu/riscv/riscv.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,10 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
return false;
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
return false;
}

const RegMask* Matcher::predicate_reg_mask(void) {
return NULL;
}
Expand Down
4 changes: 4 additions & 0 deletions src/hotspot/cpu/s390/s390.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,10 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
return false;
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
return false;
}

const RegMask* Matcher::predicate_reg_mask(void) {
return NULL;
}
Expand Down
4 changes: 4 additions & 0 deletions src/hotspot/cpu/x86/x86.ad
Original file line number Diff line number Diff line change
Expand Up @@ -2164,6 +2164,10 @@ const bool Matcher::match_rule_supported_vector_masked(int opcode, int vlen, Bas
}
}

const bool Matcher::vector_needs_partial_operations(Node* node, const TypeVect* vt) {
return false;
}

MachOper* Matcher::pd_specialize_generic_vector_operand(MachOper* generic_opnd, uint ideal_reg, bool is_temp) {
assert(Matcher::is_generic_vector(generic_opnd), "not generic");
bool legacy = (generic_opnd->opcode() == LEGVEC);
Expand Down
1 change: 0 additions & 1 deletion src/hotspot/share/opto/matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2252,7 +2252,6 @@ bool Matcher::find_shared_visit(MStack& mstack, Node* n, uint opcode, bool& mem_
case Op_FmaVD:
case Op_FmaVF:
case Op_MacroLogicV:
case Op_LoadVectorMasked:
case Op_VectorCmpMasked:
case Op_CompressV:
case Op_CompressM:
Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/share/opto/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ class Matcher : public PhaseTransform {

static const bool match_rule_supported_vector_masked(int opcode, int vlen, BasicType bt);

static const bool vector_needs_partial_operations(Node* node, const TypeVect* vt);

static const RegMask* predicate_reg_mask(void);
static const TypeVectMask* predicate_reg_type(const Type* elemTy, int length);

Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/share/opto/memnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ class LoadNode : public MemNode {
bool has_reinterpret_variant(const Type* rt);
Node* convert_to_reinterpret_load(PhaseGVN& gvn, const Type* rt);

ControlDependency control_dependency() {return _control_dependency; }

bool has_unknown_control_dependency() const { return _control_dependency == UnknownControl; }

#ifndef PRODUCT
Expand Down
6 changes: 4 additions & 2 deletions src/hotspot/share/opto/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,13 @@ class Node {
DEFINE_CLASS_ID(Load, Mem, 0)
DEFINE_CLASS_ID(LoadVector, Load, 0)
DEFINE_CLASS_ID(LoadVectorGather, LoadVector, 0)
DEFINE_CLASS_ID(LoadVectorMasked, LoadVector, 1)
DEFINE_CLASS_ID(LoadVectorGatherMasked, LoadVector, 1)
DEFINE_CLASS_ID(LoadVectorMasked, LoadVector, 2)
DEFINE_CLASS_ID(Store, Mem, 1)
DEFINE_CLASS_ID(StoreVector, Store, 0)
DEFINE_CLASS_ID(StoreVectorScatter, StoreVector, 0)
DEFINE_CLASS_ID(StoreVectorMasked, StoreVector, 1)
DEFINE_CLASS_ID(StoreVectorScatterMasked, StoreVector, 1)
DEFINE_CLASS_ID(StoreVectorMasked, StoreVector, 2)
DEFINE_CLASS_ID(LoadStore, Mem, 2)
DEFINE_CLASS_ID(LoadStoreConditional, LoadStore, 0)
DEFINE_CLASS_ID(CompareAndSwap, LoadStoreConditional, 0)
Expand Down

1 comment on commit a79ce4e

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.