Skip to content

Commit

Permalink
8292898: [vectorapi] Unify vector mask cast operation
Browse files Browse the repository at this point in the history
Co-authored-by: Quan Anh Mai <qamai@openjdk.org>
Reviewed-by: jbhateja, eliu
  • Loading branch information
Xiaohong Gong and Quan Anh Mai committed Oct 12, 2022
1 parent 2ceb80c commit ab8c136
Show file tree
Hide file tree
Showing 10 changed files with 455 additions and 307 deletions.
50 changes: 46 additions & 4 deletions src/hotspot/cpu/aarch64/aarch64_vector.ad
Original file line number Diff line number Diff line change
Expand Up @@ -5087,6 +5087,48 @@ instruct vmaskcast_same_esize_neon(vReg dst_src) %{
ins_pipe(pipe_class_empty);
%}

// Vector mask cast for NEON, matched only when UseSVE == 0 and the result
// vector is larger in bytes than the input vector.
instruct vmaskcast_extend_neon(vReg dst, vReg src) %{
predicate(UseSVE == 0 &&
Matcher::vector_length_in_bytes(n) > Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_extend_neon $dst, $src" %}
ins_encode %{
// Floating-point element types are replaced by integral ones before the
// helper is called: T_FLOAT becomes T_INT, any other floating-point type
// becomes T_LONG.
BasicType dst_bt = Matcher::vector_element_basic_type(this);
if (is_floating_point_type(dst_bt)) {
dst_bt = (dst_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Size in bytes of the result vector, passed to the extend helper.
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
// Same floating-point to integral mapping for the source element type.
BasicType src_bt = Matcher::vector_element_basic_type(this, $src);
if (is_floating_point_type(src_bt)) {
src_bt = (src_bt == T_FLOAT) ? T_INT : T_LONG;
}
__ neon_vector_extend($dst$$FloatRegister, dst_bt, length_in_bytes_dst,
$src$$FloatRegister, src_bt);
%}
ins_pipe(pipe_slow);
%}

// Vector mask cast for NEON, matched only when UseSVE == 0 and the result
// vector is smaller in bytes than the input vector.
instruct vmaskcast_narrow_neon(vReg dst, vReg src) %{
predicate(UseSVE == 0 &&
Matcher::vector_length_in_bytes(n) < Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_narrow_neon $dst, $src" %}
ins_encode %{
// Floating-point element types are replaced by integral ones before the
// helper is called: T_FLOAT becomes T_INT, any other floating-point type
// becomes T_LONG.
BasicType dst_bt = Matcher::vector_element_basic_type(this);
if (is_floating_point_type(dst_bt)) {
dst_bt = (dst_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Same floating-point to integral mapping for the source element type.
BasicType src_bt = Matcher::vector_element_basic_type(this, $src);
if (is_floating_point_type(src_bt)) {
src_bt = (src_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Size in bytes of the input vector, passed to the narrow helper.
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
__ neon_vector_narrow($dst$$FloatRegister, dst_bt,
$src$$FloatRegister, src_bt, length_in_bytes_src);
%}
ins_pipe(pipe_slow);
%}

instruct vmaskcast_same_esize_sve(pReg dst_src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) == Matcher::vector_length_in_bytes(n->in(1)));
Expand All @@ -5097,11 +5139,11 @@ instruct vmaskcast_same_esize_sve(pReg dst_src) %{
ins_pipe(pipe_class_empty);
%}

instruct vmaskcast_extend(pReg dst, pReg src) %{
instruct vmaskcast_extend_sve(pReg dst, pReg src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) > Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_extend $dst, $src" %}
format %{ "vmaskcast_extend_sve $dst, $src" %}
ins_encode %{
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
Expand All @@ -5114,11 +5156,11 @@ instruct vmaskcast_extend(pReg dst, pReg src) %{
ins_pipe(pipe_slow);
%}

instruct vmaskcast_narrow(pReg dst, pReg src) %{
instruct vmaskcast_narrow_sve(pReg dst, pReg src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) < Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_narrow $dst, $src" %}
format %{ "vmaskcast_narrow_sve $dst, $src" %}
ins_encode %{
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
Expand Down
50 changes: 46 additions & 4 deletions src/hotspot/cpu/aarch64/aarch64_vector_ad.m4
Original file line number Diff line number Diff line change
Expand Up @@ -3503,6 +3503,48 @@ instruct vmaskcast_same_esize_neon(vReg dst_src) %{
ins_pipe(pipe_class_empty);
%}

// Vector mask cast for NEON, matched only when UseSVE == 0 and the result
// vector is larger in bytes than the input vector.
instruct vmaskcast_extend_neon(vReg dst, vReg src) %{
predicate(UseSVE == 0 &&
Matcher::vector_length_in_bytes(n) > Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_extend_neon $dst, $src" %}
ins_encode %{
// Floating-point element types are replaced by integral ones before the
// helper is called: T_FLOAT becomes T_INT, any other floating-point type
// becomes T_LONG.
BasicType dst_bt = Matcher::vector_element_basic_type(this);
if (is_floating_point_type(dst_bt)) {
dst_bt = (dst_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Size in bytes of the result vector, passed to the extend helper.
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
// Same floating-point to integral mapping for the source element type.
BasicType src_bt = Matcher::vector_element_basic_type(this, $src);
if (is_floating_point_type(src_bt)) {
src_bt = (src_bt == T_FLOAT) ? T_INT : T_LONG;
}
__ neon_vector_extend($dst$$FloatRegister, dst_bt, length_in_bytes_dst,
$src$$FloatRegister, src_bt);
%}
ins_pipe(pipe_slow);
%}

// Vector mask cast for NEON, matched only when UseSVE == 0 and the result
// vector is smaller in bytes than the input vector.
instruct vmaskcast_narrow_neon(vReg dst, vReg src) %{
predicate(UseSVE == 0 &&
Matcher::vector_length_in_bytes(n) < Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_narrow_neon $dst, $src" %}
ins_encode %{
// Floating-point element types are replaced by integral ones before the
// helper is called: T_FLOAT becomes T_INT, any other floating-point type
// becomes T_LONG.
BasicType dst_bt = Matcher::vector_element_basic_type(this);
if (is_floating_point_type(dst_bt)) {
dst_bt = (dst_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Same floating-point to integral mapping for the source element type.
BasicType src_bt = Matcher::vector_element_basic_type(this, $src);
if (is_floating_point_type(src_bt)) {
src_bt = (src_bt == T_FLOAT) ? T_INT : T_LONG;
}
// Size in bytes of the input vector, passed to the narrow helper.
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
__ neon_vector_narrow($dst$$FloatRegister, dst_bt,
$src$$FloatRegister, src_bt, length_in_bytes_src);
%}
ins_pipe(pipe_slow);
%}

instruct vmaskcast_same_esize_sve(pReg dst_src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) == Matcher::vector_length_in_bytes(n->in(1)));
Expand All @@ -3513,11 +3555,11 @@ instruct vmaskcast_same_esize_sve(pReg dst_src) %{
ins_pipe(pipe_class_empty);
%}

instruct vmaskcast_extend(pReg dst, pReg src) %{
instruct vmaskcast_extend_sve(pReg dst, pReg src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) > Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_extend $dst, $src" %}
format %{ "vmaskcast_extend_sve $dst, $src" %}
ins_encode %{
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
Expand All @@ -3530,11 +3572,11 @@ instruct vmaskcast_extend(pReg dst, pReg src) %{
ins_pipe(pipe_slow);
%}

instruct vmaskcast_narrow(pReg dst, pReg src) %{
instruct vmaskcast_narrow_sve(pReg dst, pReg src) %{
predicate(UseSVE > 0 &&
Matcher::vector_length_in_bytes(n) < Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vmaskcast_narrow $dst, $src" %}
format %{ "vmaskcast_narrow_sve $dst, $src" %}
ins_encode %{
uint length_in_bytes_dst = Matcher::vector_length_in_bytes(this);
uint length_in_bytes_src = Matcher::vector_length_in_bytes(this, $src);
Expand Down
55 changes: 55 additions & 0 deletions src/hotspot/cpu/x86/c2_MacroAssembler_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4750,6 +4750,61 @@ void C2_MacroAssembler::vector_unsigned_cast(XMMRegister dst, XMMRegister src, i
}
}

// Emits the instruction sequence that casts a vector mask held in src into a
// mask in dst whose element type has a different size.
//
//   dst, src        destination and source XMM registers
//   dst_bt, src_bt  element types of the result and of the input
//   vlen            multiplied by the larger of the two element sizes to pick
//                   the vector length encoding (presumably the lane count;
//                   confirm at the call sites)
//
// Preconditions (asserted): the two element sizes differ, and the selected
// encoding is not AVX_512bit.
//
// NOTE(review): the sequences below look like they rely on every mask lane
// being all-zeros or all-ones, so that byte-wise sign extension and signed
// saturating packs preserve the lane values; confirm against the x86 ISA
// reference and the mask representation used by the callers.
void C2_MacroAssembler::vector_mask_cast(XMMRegister dst, XMMRegister src,
                                         BasicType dst_bt, BasicType src_bt, int vlen) {
  const int dst_esize = type2aelembytes(dst_bt);
  const int src_esize = type2aelembytes(src_bt);

  // The encoding is derived from the wider of the two vectors.
  const int vlen_enc = vector_length_encoding(MAX2(src_esize, dst_esize) * vlen);
  assert(vlen_enc != AVX_512bit, "");

  if (dst_esize > src_esize) {
    // Widening cast: one instruction, selected by the size ratio alone.
    const int widen_ratio = dst_esize / src_esize;
    if (widen_ratio == 2) {
      vpmovsxbw(dst, src, vlen_enc);
    } else if (widen_ratio == 4) {
      vpmovsxbd(dst, src, vlen_enc);
    } else if (widen_ratio == 8) {
      vpmovsxbq(dst, src, vlen_enc);
    } else {
      ShouldNotReachHere();
    }
    return;
  }

  // Narrowing cast from here on.
  assert(dst_esize < src_esize, "");

  // Any encoding other than 128-bit needs an extra vpermq step, after which
  // the remaining packs are issued with the 128-bit encoding.
  const bool beyond_128bit = (vlen_enc != AVX_128bit);

  switch (src_esize / dst_esize) {
    case 2:
      vpacksswb(dst, src, src, vlen_enc);
      if (beyond_128bit) {
        vpermq(dst, dst, 0x08, vlen_enc);
      }
      break;

    case 4:
      vpackssdw(dst, src, src, vlen_enc);
      if (beyond_128bit) {
        vpermq(dst, dst, 0x08, vlen_enc);
      }
      // In the 128-bit case vlen_enc is AVX_128bit, so one call covers both.
      vpacksswb(dst, dst, dst, AVX_128bit);
      break;

    case 8:
      vpshufd(dst, src, 0x08, vlen_enc);
      if (beyond_128bit) {
        vpermq(dst, dst, 0x08, vlen_enc);
      }
      // In the 128-bit case vlen_enc is AVX_128bit, so these cover both.
      vpackssdw(dst, dst, dst, AVX_128bit);
      vpacksswb(dst, dst, dst, AVX_128bit);
      break;

    default:
      ShouldNotReachHere();
  }
}

void C2_MacroAssembler::evpternlog(XMMRegister dst, int func, KRegister mask, XMMRegister src2, XMMRegister src3,
bool merge, BasicType bt, int vlen_enc) {
if (bt == T_INT) {
Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/cpu/x86/c2_MacroAssembler_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@
void vector_crosslane_doubleword_pack_avx(XMMRegister dst, XMMRegister src, XMMRegister zero,
XMMRegister xtmp, int index, int vec_enc);

// Casts a vector mask in src to a mask in dst whose element type (dst_bt) has
// a different size than the source element type (src_bt). vlen is used with
// the larger element size to select the vector length encoding; 512-bit
// encodings are rejected by an assert in the implementation.
void vector_mask_cast(XMMRegister dst, XMMRegister src, BasicType dst_bt, BasicType src_bt, int vlen);

#ifdef _LP64
void vector_round_double_evex(XMMRegister dst, XMMRegister src, AddressLiteral double_sign_flip, AddressLiteral new_mxcsr, int vec_enc,
Register tmp, XMMRegister xtmp1, XMMRegister xtmp2, KRegister ktmp1, KRegister ktmp2);
Expand Down
19 changes: 16 additions & 3 deletions src/hotspot/cpu/x86/x86.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,7 @@ const bool Matcher::match_rule_supported(int opcode) {
case Op_VectorUCastB2X:
case Op_VectorUCastS2X:
case Op_VectorUCastI2X:
case Op_VectorMaskCast:
if (UseAVX < 1) { // enabled for AVX only
return false;
}
Expand Down Expand Up @@ -1857,6 +1858,7 @@ const bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType
}
break;
case Op_VectorLoadMask:
case Op_VectorMaskCast:
if (size_in_bits == 256 && UseAVX < 2) {
return false; // Implementation limitation
}
Expand Down Expand Up @@ -8413,7 +8415,6 @@ instruct vstoreMask_evex(vec dst, kReg mask, immI size) %{
%}

instruct vmaskcast_evex(kReg dst) %{
predicate(Matcher::vector_length(n) == Matcher::vector_length(n->in(1)));
match(Set dst (VectorMaskCast dst));
ins_cost(0);
format %{ "vector_mask_cast $dst" %}
Expand All @@ -8424,8 +8425,7 @@ instruct vmaskcast_evex(kReg dst) %{
%}

instruct vmaskcast(vec dst) %{
predicate((Matcher::vector_length(n) == Matcher::vector_length(n->in(1))) &&
(Matcher::vector_length_in_bytes(n) == Matcher::vector_length_in_bytes(n->in(1))));
predicate(Matcher::vector_length_in_bytes(n) == Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast dst));
ins_cost(0);
format %{ "vector_mask_cast $dst" %}
Expand All @@ -8435,6 +8435,19 @@ instruct vmaskcast(vec dst) %{
ins_pipe(empty);
%}

// Vector mask cast for masks held in vec registers, matched when the result
// and the input differ in size in bytes. The instruction sequence is produced
// by the vector_mask_cast macro-assembler routine.
instruct vmaskcast_avx(vec dst, vec src) %{
predicate(Matcher::vector_length_in_bytes(n) != Matcher::vector_length_in_bytes(n->in(1)));
match(Set dst (VectorMaskCast src));
format %{ "vector_mask_cast $dst, $src" %}
ins_encode %{
// Vector length of this node, as reported by Matcher::vector_length.
int vlen = Matcher::vector_length(this);
// Element types of the input mask and of the result mask.
BasicType src_bt = Matcher::vector_element_basic_type(this, $src);
BasicType dst_bt = Matcher::vector_element_basic_type(this);
__ vector_mask_cast($dst$$XMMRegister, $src$$XMMRegister, dst_bt, src_bt, vlen);
%}
ins_pipe(pipe_slow);
%}

//-------------------------------- Load Iota Indices ----------------------------------

instruct loadIotaIndices(vec dst, immI_0 src) %{
Expand Down
34 changes: 14 additions & 20 deletions src/hotspot/share/opto/vectorIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2488,24 +2488,15 @@ bool LibraryCallKit::inline_vector_convert() {

Node* op = opd1;
if (is_cast) {
BasicType new_elem_bt_to = elem_bt_to;
BasicType new_elem_bt_from = elem_bt_from;
if (is_mask && is_floating_point_type(elem_bt_from)) {
new_elem_bt_from = elem_bt_from == T_FLOAT ? T_INT : T_LONG;
}
if (is_mask && is_floating_point_type(elem_bt_to)) {
new_elem_bt_to = elem_bt_to == T_FLOAT ? T_INT : T_LONG;
}
int cast_vopc = VectorCastNode::opcode(new_elem_bt_from, !is_ucast);
assert(!is_mask || num_elem_from == num_elem_to, "vector mask cast needs the same elem num");
int cast_vopc = VectorCastNode::opcode(elem_bt_from, !is_ucast);

// Make sure that vector cast is implemented to particular type/size combination.
bool no_vec_cast_check = is_mask &&
((src_type->isa_vectmask() && dst_type->isa_vectmask()) ||
type2aelembytes(elem_bt_from) == type2aelembytes(elem_bt_to));
if (!no_vec_cast_check && !arch_supports_vector(cast_vopc, num_elem_to, new_elem_bt_to, VecMaskNotUsed)) {
// Make sure that vector cast is implemented to particular type/size combination if it is
// not a mask casting.
if (!is_mask && !arch_supports_vector(cast_vopc, num_elem_to, elem_bt_to, VecMaskNotUsed)) {
if (C->print_intrinsics()) {
tty->print_cr(" ** not supported: arity=1 op=cast#%d/3 vlen2=%d etype2=%s ismask=%d",
cast_vopc, num_elem_to, type2name(new_elem_bt_to), is_mask);
cast_vopc, num_elem_to, type2name(elem_bt_to), is_mask);
}
return false;
}
Expand Down Expand Up @@ -2552,12 +2543,15 @@ bool LibraryCallKit::inline_vector_convert() {
op = gvn().transform(VectorCastNode::make(cast_vopc, op, elem_bt_to, num_elem_to));
} else { // num_elem_from == num_elem_to
if (is_mask) {
if ((dst_type->isa_vectmask() && src_type->isa_vectmask()) ||
(type2aelembytes(elem_bt_from) == type2aelembytes(elem_bt_to))) {
op = gvn().transform(new VectorMaskCastNode(op, dst_type));
} else {
op = VectorMaskCastNode::makeCastNode(&gvn(), op, dst_type);
// Make sure that cast for vector mask is implemented to particular type/size combination.
if (!arch_supports_vector(Op_VectorMaskCast, num_elem_to, elem_bt_to, VecMaskNotUsed)) {
if (C->print_intrinsics()) {
tty->print_cr(" ** not supported: arity=1 op=maskcast vlen2=%d etype2=%s ismask=%d",
num_elem_to, type2name(elem_bt_to), is_mask);
}
return false;
}
op = gvn().transform(new VectorMaskCastNode(op, dst_type));
} else {
// Since input and output number of elements match, and since we know this vector size is
// supported, simply do a cast with no resize needed.
Expand Down

1 comment on commit ab8c136

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.