8267368: Add masking support for reduction vector intrinsics #86

@@ -849,14 +849,15 @@ class methodHandle;
"Ljava/lang/Object;ILjdk/internal/vm/vector/VectorSupport$StoreVectorOperation;)V") \
do_name(vector_store_op_name, "store") \
\
do_intrinsic(_VectorStoreMaskedOp, jdk_internal_vm_vector_VectorSupport, vector_store_masked_op_name, vector_store_masked_op_sig, F_S) \
do_signature(vector_store_masked_op_sig, "(Ljava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjava/lang/Object;JLjdk/internal/vm/vector/VectorSupport$Vector;" \
do_intrinsic(_VectorStoreMaskedOp, jdk_internal_vm_vector_VectorSupport, vector_store_masked_op_name, vector_store_masked_op_sig, F_S) \
do_signature(vector_store_masked_op_sig, "(Ljava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjava/lang/Object;JLjdk/internal/vm/vector/VectorSupport$Vector;" \
"Ljdk/internal/vm/vector/VectorSupport$VectorMask;Ljava/lang/Object;I" \
"Ljdk/internal/vm/vector/VectorSupport$StoreVectorMaskedOperation;)V") \
do_name(vector_store_masked_op_name, "storeMasked") \
\
do_intrinsic(_VectorReductionCoerced, jdk_internal_vm_vector_VectorSupport, vector_reduction_coerced_name, vector_reduction_coerced_sig, F_S) \
do_signature(vector_reduction_coerced_sig, "(ILjava/lang/Class;Ljava/lang/Class;ILjdk/internal/vm/vector/VectorSupport$Vector;Ljava/util/function/Function;)J") \
do_intrinsic(_VectorReductionCoerced, jdk_internal_vm_vector_VectorSupport, vector_reduction_coerced_name, vector_reduction_coerced_sig, F_S)\
do_signature(vector_reduction_coerced_sig, "(ILjava/lang/Class;Ljava/lang/Class;Ljava/lang/Class;ILjava/lang/Object;Ljava/lang/Object;" \
"Ljdk/internal/vm/vector/VectorSupport$ReductionOperation;)J") \
do_name(vector_reduction_coerced_name, "reductionCoerced") \
\
do_intrinsic(_VectorTest, jdk_internal_vm_vector_VectorSupport, vector_test_name, vector_test_sig, F_S) \
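For reference, the updated vector_reduction_coerced_sig descriptor decodes to the Java signature below; the generic V and M parameters erase to java.lang.Object in the descriptor, and the string has to stay in sync with the VectorSupport.reductionCoerced declaration shown later in this diff.

    // Decoded form of the new vector_reduction_coerced_sig:
    //   I                     -> int oprId
    //   Ljava/lang/Class;     -> Class<? extends V> vectorClass
    //   Ljava/lang/Class;     -> Class<? extends M> maskClass      (new)
    //   Ljava/lang/Class;     -> Class<?> elementType
    //   I                     -> int length
    //   Ljava/lang/Object;    -> V v                               (erased)
    //   Ljava/lang/Object;    -> M m                               (erased, new)
    //   Ljdk/internal/vm/vector/VectorSupport$ReductionOperation;
    //                         -> ReductionOperation<V, M> defaultImpl
    //   )J                    -> returns long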
@@ -1127,25 +1127,26 @@ bool LibraryCallKit::inline_vector_gather_scatter(bool is_scatter) {
return true;
}

// <V extends Vector<?,?>>
// long reductionCoerced(int oprId, Class<?> vectorClass, Class<?> elementType, int vlen,
// V v,
// Function<V,Long> defaultImpl)

// <V, M>
// long reductionCoerced(int oprId, Class<? extends V> vectorClass, Class<? extends M> maskClass,
// Class<?> elementType, int length, V v, M m,
// ReductionOperation<V, M> defaultImpl) {
//
bool LibraryCallKit::inline_vector_reduction() {
const TypeInt* opr = gvn().type(argument(0))->isa_int();
const TypeInstPtr* vector_klass = gvn().type(argument(1))->isa_instptr();
const TypeInstPtr* elem_klass = gvn().type(argument(2))->isa_instptr();
const TypeInt* vlen = gvn().type(argument(3))->isa_int();
const TypeInstPtr* mask_klass = gvn().type(argument(2))->isa_instptr();
const TypeInstPtr* elem_klass = gvn().type(argument(3))->isa_instptr();
const TypeInt* vlen = gvn().type(argument(4))->isa_int();

if (opr == NULL || vector_klass == NULL || elem_klass == NULL || vlen == NULL ||
!opr->is_con() || vector_klass->const_oop() == NULL || elem_klass->const_oop() == NULL || !vlen->is_con()) {
if (C->print_intrinsics()) {
tty->print_cr(" ** missing constant: opr=%s vclass=%s etype=%s vlen=%s",
NodeClassNames[argument(0)->Opcode()],
NodeClassNames[argument(1)->Opcode()],
NodeClassNames[argument(2)->Opcode()],
NodeClassNames[argument(3)->Opcode()]);
NodeClassNames[argument(3)->Opcode()],
NodeClassNames[argument(4)->Opcode()]);
}
return false; // not enough info for intrinsification
}
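As a reading aid for the index shuffle above (not part of the patch), the intrinsic's argument slots now map onto the new Java signature roughly as follows; the defaultImpl slot is an assumption taken from the declaration and is never consulted once the call is intrinsified.

    // argument(0) -> int oprId                         (constant required)
    // argument(1) -> Class<? extends V> vectorClass    (constant required)
    // argument(2) -> Class<? extends M> maskClass      (new; constant required for masked ops)
    // argument(3) -> Class<?> elementType              (constant required)
    // argument(4) -> int length                        (constant required)
    // argument(5) -> V v                               (vector operand, unboxed below)
    // argument(6) -> M m                               (mask operand; null on the unmasked path)
    // argument(7) -> ReductionOperation<V, M> defaultImpl   (Java fallback; assumed slot)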
@@ -1162,14 +1163,39 @@ bool LibraryCallKit::inline_vector_reduction() {
}
return false; // should be primitive type
}

const Type* vmask_type = gvn().type(argument(6));
bool is_masked_op = vmask_type != TypePtr::NULL_PTR;
if (is_masked_op) {
if (mask_klass == NULL || mask_klass->const_oop() == NULL) {
if (C->print_intrinsics()) {
tty->print_cr(" ** missing constant: maskclass=%s", NodeClassNames[argument(2)->Opcode()]);
}
return false; // not enough info for intrinsification
}

if (!is_klass_initialized(mask_klass)) {
if (C->print_intrinsics()) {
tty->print_cr(" ** mask klass argument not initialized");
}
return false;
}

if (vmask_type->maybe_null()) {
if (C->print_intrinsics()) {
tty->print_cr(" ** null mask values are not allowed for masked op");
}
return false;
}
}

BasicType elem_bt = elem_type->basic_type();
int num_elem = vlen->get_con();

int opc = VectorSupport::vop2ideal(opr->get_con(), elem_bt);
int sopc = ReductionNode::opcode(opc, elem_bt);

// TODO When mask usage is supported, VecMaskNotUsed needs to be VecMaskUseLoad.
if (!arch_supports_vector(sopc, num_elem, elem_bt, VecMaskNotUsed)) {
// When using mask, mask use type needs to be VecMaskUseLoad.
if (!arch_supports_vector(sopc, num_elem, elem_bt, is_masked_op ? VecMaskUseLoad : VecMaskNotUsed)) {
if (C->print_intrinsics()) {
tty->print_cr(" ** not supported: arity=1 op=%d/reduce vlen=%d etype=%s ismask=no",
sopc, num_elem, type2name(elem_bt));
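Both lowering strategies introduced below have to preserve the same semantics: a masked reduction starts from the operation's identity value and combines only the lanes whose mask bit is set. A minimal scalar sketch of that contract, in plain Java rather than JDK code, using ADD as the example operation:

    // Scalar reference semantics for a masked ADD reduction over byte lanes.
    static long maskedAddReduce(byte[] lanes, boolean[] mask) {
        long acc = 0;                      // identity of ADD
        for (int i = 0; i < lanes.length; i++) {
            if (mask[i]) {                 // lanes with an unset mask bit contribute nothing
                acc += lanes[i];
            }
        }
        return acc;
    }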
@@ -1180,33 +1206,72 @@ bool LibraryCallKit::inline_vector_reduction() {
ciKlass* vbox_klass = vector_klass->const_oop()->as_instance()->java_lang_Class_klass();
const TypeInstPtr* vbox_type = TypeInstPtr::make_exact(TypePtr::NotNull, vbox_klass);

Node* opd = unbox_vector(argument(4), vbox_type, elem_bt, num_elem);
Node* opd = unbox_vector(argument(5), vbox_type, elem_bt, num_elem);
if (opd == NULL) {
return false; // operand unboxing failed
}

Node* mask = NULL;
bool use_predicate = false;
if (is_masked_op) {
ciKlass* mbox_klass = mask_klass->const_oop()->as_instance()->java_lang_Class_klass();
assert(is_vector_mask(mbox_klass), "argument(2) should be a mask class");
const TypeInstPtr* mbox_type = TypeInstPtr::make_exact(TypePtr::NotNull, mbox_klass);
mask = unbox_vector(argument(6), mbox_type, elem_bt, num_elem);
if (mask == NULL) {
if (C->print_intrinsics()) {
tty->print_cr(" ** unbox failed mask=%s",
NodeClassNames[argument(6)->Opcode()]);
}
return false;
}

// use_predicate is true if the current platform implements the masked reduction with the predicate feature.
use_predicate = Matcher::has_predicated_vectors() &&
Matcher::match_rule_supported_vector_masked(sopc, num_elem, elem_bt);
if (!use_predicate && !arch_supports_vector(Op_VectorBlend, num_elem, elem_bt, VecMaskUseLoad)) {
return false;
}
}

Node* init = ReductionNode::make_reduction_input(gvn(), opc, elem_bt);
Node* rn = gvn().transform(ReductionNode::make(opc, NULL, init, opd, elem_bt));
Node* value = NULL;
if (mask == NULL) {
assert(!is_masked_op, "Masked op needs the mask value never null");
value = ReductionNode::make(opc, NULL, init, opd, elem_bt);
} else {
if (use_predicate) {
if (C->print_intrinsics()) {
tty->print_cr(" ** predicate feature is not supported on current platform!");
}
return false;
} else {
Node* reduce_identity = gvn().transform(VectorNode::scalar2vector(init, num_elem, Type::get_const_basic_type(elem_bt)));
value = gvn().transform(new VectorBlendNode(reduce_identity, opd, mask));
value = ReductionNode::make(opc, NULL, init, value, elem_bt);
}
}
value = gvn().transform(value);

Node* bits = NULL;
switch (elem_bt) {
case T_BYTE:
case T_SHORT:
case T_INT: {
bits = gvn().transform(new ConvI2LNode(rn));
bits = gvn().transform(new ConvI2LNode(value));
break;
}
case T_FLOAT: {
rn = gvn().transform(new MoveF2INode(rn));
bits = gvn().transform(new ConvI2LNode(rn));
value = gvn().transform(new MoveF2INode(value));
bits = gvn().transform(new ConvI2LNode(value));
break;
}
case T_DOUBLE: {
bits = gvn().transform(new MoveD2LNode(rn));
bits = gvn().transform(new MoveD2LNode(value));
break;
}
case T_LONG: {
bits = rn; // no conversion needed
bits = value; // no conversion needed
break;
}
default: fatal("%s", type2name(elem_bt));
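The blend fallback above relies on a simple identity: blending the input with a broadcast of the reduction's identity value under the mask, then performing an unmasked reduction, yields the same result as a masked reduction, because identity lanes cannot change the accumulated value. A scalar sketch of that rewrite (illustration only, not JDK code), using MIN so it ties into the identity constants added below:

    // reduce(blend(identity, v, mask)) == maskedReduce(v, mask)
    static int maskedMinViaBlend(int[] lanes, boolean[] mask) {
        int[] blended = new int[lanes.length];
        for (int i = 0; i < lanes.length; i++) {
            // VectorBlend: keep the lane where the mask is set, the identity where it is not.
            blended[i] = mask[i] ? lanes[i] : Integer.MAX_VALUE;   // identity of MIN
        }
        int acc = Integer.MAX_VALUE;       // unmasked reduction seeded with the identity
        for (int x : blended) {
            acc = Math.min(acc, x);
        }
        return acc;
    }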
@@ -1072,7 +1072,9 @@ Node* ReductionNode::make_reduction_input(PhaseGVN& gvn, int opc, BasicType bt)
case Op_MinReductionV:
switch (bt) {
case T_BYTE:
return gvn.makecon(TypeInt::make(max_jbyte));
case T_SHORT:
return gvn.makecon(TypeInt::make(max_jshort));
case T_INT:
return gvn.makecon(TypeInt::MAX);
case T_LONG:
@@ -1087,7 +1089,9 @@ Node* ReductionNode::make_reduction_input(PhaseGVN& gvn, int opc, BasicType bt)
case Op_MaxReductionV:
switch (bt) {
case T_BYTE:
return gvn.makecon(TypeInt::make(min_jbyte));
case T_SHORT:
return gvn.makecon(TypeInt::make(min_jshort));
case T_INT:
return gvn.makecon(TypeInt::MIN);
case T_LONG:
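The new T_BYTE and T_SHORT cases follow the same rule as the existing ones: the identity of a MIN reduction is the type's maximum value and vice versa, so identity-filled lanes (as produced by the blend fallback above) can never win the comparison. Summarized as a reading aid:

    // Identity inputs returned by make_reduction_input for the comparison reductions:
    //   MinReductionV:  T_BYTE -> max_jbyte (127),   T_SHORT -> max_jshort (32767)
    //   MaxReductionV:  T_BYTE -> min_jbyte (-128),  T_SHORT -> min_jshort (-32768)
    // They are identities because min(127, x) == x and max(-128, x) == x
    // for every in-range byte x (and likewise for short).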
@@ -202,14 +202,19 @@ V indexVector(Class<? extends V> vClass, Class<E> E, int length,

@IntrinsicCandidate
public static
<V extends Vector<?>>
long reductionCoerced(int oprId, Class<?> vectorClass, Class<?> elementType, int length,
V v,
Function<V,Long> defaultImpl) {
<V, M>
long reductionCoerced(int oprId, Class<? extends V> vectorClass, Class<? extends M> maskClass,
Class<?> elementType, int length, V v, M m,
ReductionOperation<V, M> defaultImpl) {
assert isNonCapturingLambda(defaultImpl) : defaultImpl;
return defaultImpl.apply(v);
return defaultImpl.apply(v, m);
}

public interface ReductionOperation<V, M> {
long apply(V v, M mask);
}
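Whenever the intrinsic bails out (any of the return false paths above), the call falls back to this Java lambda, so a defaultImpl has to compute the full masked result itself. Below is a hypothetical, non-capturing implementation of the new two-argument interface, instantiating V and M with plain arrays purely for illustration; the real JDK passes its Vector and VectorMask objects and cached lambdas.

    // Hypothetical ADD fallback; must stay non-capturing because reductionCoerced
    // asserts isNonCapturingLambda(defaultImpl).
    VectorSupport.ReductionOperation<byte[], boolean[]> addFallback = (v, m) -> {
        long acc = 0;                      // identity of ADD
        for (int i = 0; i < v.length; i++) {
            if (m == null || m[i]) {       // the unmasked entry point passes m == null
                acc += v[i];
            }
        }
        return acc;
    };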


/* ============================================================================ */

public interface VecExtractOp<V> {
@@ -236,8 +236,8 @@ Byte128Vector tOp(Vector<Byte> v1, Vector<Byte> v2,

@ForceInline
final @Override
byte rOp(byte v, FBinOp f) {
return super.rOpTemplate(v, f); // specialize
byte rOp(byte v, VectorMask<Byte> m, FBinOp f) {
return super.rOpTemplate(v, m, f); // specialize
}
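rOpTemplate itself lives in the shared ByteVector superclass and is not part of this hunk; conceptually, the mask-aware fold it performs has the shape sketched below (simplified stand-in types, an assumption about the template rather than a quote of it):

    // Simplified stand-in for the masked rOpTemplate fold.
    interface BinOp { byte apply(byte a, byte b); }     // stand-in for FBinOp

    static byte rOpSketch(byte v, byte[] lanes, boolean[] mask, BinOp f) {
        byte acc = v;                                   // v seeds the accumulation
        for (int i = 0; i < lanes.length; i++) {
            if (mask == null || mask[i]) {              // a null mask keeps the old unmasked behavior
                acc = f.apply(acc, lanes[i]);
            }
        }
        return acc;
    }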

@Override
@@ -334,7 +334,7 @@ public final byte reduceLanes(VectorOperators.Associative op) {
@ForceInline
public final byte reduceLanes(VectorOperators.Associative op,
VectorMask<Byte> m) {
return super.reduceLanesTemplate(op, m); // specialized
return super.reduceLanesTemplate(op, Byte128Mask.class, m); // specialized
}

@Override
@@ -347,7 +347,7 @@ public final long reduceLanesToLong(VectorOperators.Associative op) {
@ForceInline
public final long reduceLanesToLong(VectorOperators.Associative op,
VectorMask<Byte> m) {
return (long) super.reduceLanesTemplate(op, m); // specialized
return (long) super.reduceLanesTemplate(op, Byte128Mask.class, m); // specialized
}
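At the public API level these specializations back the masked reduceLanes and reduceLanesToLong entry points. A typical use, with the species, helper name, and predicate chosen here only for the example (assumes data.length >= SPECIES_128.length()):

    import jdk.incubator.vector.ByteVector;
    import jdk.incubator.vector.VectorMask;
    import jdk.incubator.vector.VectorOperators;

    static long sumOfPositives(byte[] data) {
        ByteVector v = ByteVector.fromArray(ByteVector.SPECIES_128, data, 0);
        VectorMask<Byte> positive = v.compare(VectorOperators.GT, (byte) 0);
        // Only lanes holding a positive value take part in the reduction.
        return v.reduceLanesToLong(VectorOperators.ADD, positive);
    }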

@Override
@@ -236,8 +236,8 @@ Byte256Vector tOp(Vector<Byte> v1, Vector<Byte> v2,

@ForceInline
final @Override
byte rOp(byte v, FBinOp f) {
return super.rOpTemplate(v, f); // specialize
byte rOp(byte v, VectorMask<Byte> m, FBinOp f) {
return super.rOpTemplate(v, m, f); // specialize
}

@Override
@@ -334,7 +334,7 @@ public final byte reduceLanes(VectorOperators.Associative op) {
@ForceInline
public final byte reduceLanes(VectorOperators.Associative op,
VectorMask<Byte> m) {
return super.reduceLanesTemplate(op, m); // specialized
return super.reduceLanesTemplate(op, Byte256Mask.class, m); // specialized
}

@Override
@@ -347,7 +347,7 @@ public final long reduceLanesToLong(VectorOperators.Associative op) {
@ForceInline
public final long reduceLanesToLong(VectorOperators.Associative op,
VectorMask<Byte> m) {
return (long) super.reduceLanesTemplate(op, m); // specialized
return (long) super.reduceLanesTemplate(op, Byte256Mask.class, m); // specialized
}

@Override
@@ -236,8 +236,8 @@ Byte512Vector tOp(Vector<Byte> v1, Vector<Byte> v2,

@ForceInline
final @Override
byte rOp(byte v, FBinOp f) {
return super.rOpTemplate(v, f); // specialize
byte rOp(byte v, VectorMask<Byte> m, FBinOp f) {
return super.rOpTemplate(v, m, f); // specialize
}

@Override
@@ -334,7 +334,7 @@ public final byte reduceLanes(VectorOperators.Associative op) {
@ForceInline
public final byte reduceLanes(VectorOperators.Associative op,
VectorMask<Byte> m) {
return super.reduceLanesTemplate(op, m); // specialized
return super.reduceLanesTemplate(op, Byte512Mask.class, m); // specialized
}

@Override
@@ -347,7 +347,7 @@ public final long reduceLanesToLong(VectorOperators.Associative op) {
@ForceInline
public final long reduceLanesToLong(VectorOperators.Associative op,
VectorMask<Byte> m) {
return (long) super.reduceLanesTemplate(op, m); // specialized
return (long) super.reduceLanesTemplate(op, Byte512Mask.class, m); // specialized
}

@Override
@@ -236,8 +236,8 @@ Byte64Vector tOp(Vector<Byte> v1, Vector<Byte> v2,

@ForceInline
final @Override
byte rOp(byte v, FBinOp f) {
return super.rOpTemplate(v, f); // specialize
byte rOp(byte v, VectorMask<Byte> m, FBinOp f) {
return super.rOpTemplate(v, m, f); // specialize
}

@Override
@@ -334,7 +334,7 @@ public final byte reduceLanes(VectorOperators.Associative op) {
@ForceInline
public final byte reduceLanes(VectorOperators.Associative op,
VectorMask<Byte> m) {
return super.reduceLanesTemplate(op, m); // specialized
return super.reduceLanesTemplate(op, Byte64Mask.class, m); // specialized
}

@Override
@@ -347,7 +347,7 @@ public final long reduceLanesToLong(VectorOperators.Associative op) {
@ForceInline
public final long reduceLanesToLong(VectorOperators.Associative op,
VectorMask<Byte> m) {
return (long) super.reduceLanesTemplate(op, m); // specialized
return (long) super.reduceLanesTemplate(op, Byte64Mask.class, m); // specialized
}

@Override
@@ -236,8 +236,8 @@ ByteMaxVector tOp(Vector<Byte> v1, Vector<Byte> v2,

@ForceInline
final @Override
byte rOp(byte v, FBinOp f) {
return super.rOpTemplate(v, f); // specialize
byte rOp(byte v, VectorMask<Byte> m, FBinOp f) {
return super.rOpTemplate(v, m, f); // specialize
}

@Override
@@ -334,7 +334,7 @@ public final byte reduceLanes(VectorOperators.Associative op) {
@ForceInline
public final byte reduceLanes(VectorOperators.Associative op,
VectorMask<Byte> m) {
return super.reduceLanesTemplate(op, m); // specialized
return super.reduceLanesTemplate(op, ByteMaxMask.class, m); // specialized
}

@Override
@@ -347,7 +347,7 @@ public final long reduceLanesToLong(VectorOperators.Associative op) {
@ForceInline
public final long reduceLanesToLong(VectorOperators.Associative op,
VectorMask<Byte> m) {
return (long) super.reduceLanesTemplate(op, m); // specialized
return (long) super.reduceLanesTemplate(op, ByteMaxMask.class, m); // specialized
}

@Override