8266287: Basic mask IR implementation for the Vector API masking feature support #78

Closed
@@ -271,6 +271,7 @@ Form::DataType Form::is_load_from_memory(const char *opType) const {
if( strcmp(opType,"LoadVectorGather")==0 ) return Form::idealV;
if( strcmp(opType,"LoadVectorGatherMasked")==0 ) return Form::idealV;
if( strcmp(opType,"LoadVectorMasked")==0 ) return Form::idealV;
if( strcmp(opType,"LoadVectorMask")==0 ) return Form::idealV;
assert( strcmp(opType,"Load") != 0, "Must type Loads" );
return Form::none;
}
@@ -290,6 +291,7 @@ Form::DataType Form::is_store_to_memory(const char *opType) const {
if( strcmp(opType,"StoreVectorScatter")==0 ) return Form::idealV;
if( strcmp(opType,"StoreVectorScatterMasked")==0 ) return Form::idealV;
if( strcmp(opType,"StoreVectorMasked")==0 ) return Form::idealV;
if( strcmp(opType,"StoreVectorMask")==0 ) return Form::idealV;
assert( strcmp(opType,"Store") != 0, "Must type Stores" );
return Form::none;
}
@@ -3517,6 +3517,7 @@ int MatchNode::needs_ideal_memory_edge(FormDict &globals) const {
"LoadB" , "LoadUB", "LoadUS" ,"LoadS" ,"Load" ,
"StoreVector", "LoadVector", "LoadVectorMasked", "StoreVectorMasked",
"LoadVectorGather", "StoreVectorScatter", "LoadVectorGatherMasked", "StoreVectorScatterMasked",
"LoadVectorMask", "StoreVectorMask",
"LoadRange", "LoadKlass", "LoadNKlass", "LoadL_unaligned", "LoadD_unaligned",
"LoadPLocked",
"StorePConditional", "StoreIConditional", "StoreLConditional",
@@ -3820,51 +3821,77 @@ bool MatchNode::equivalent(FormDict &globals, MatchNode *mNode2) {
return true;
}

//-------------------------- has_commutative_op -------------------------------
//-------------------------- count_commutative_op -------------------------------
// Recursively check for commutative operations with subtree operands
// which could be swapped.
void MatchNode::count_commutative_op(int& count) {
static const char *commut_op_list[] = {
"AddI","AddL","AddF","AddD",
"AddVB","AddVS","AddVI","AddVL","AddVF","AddVD",
"AndI","AndL",
"AndV",
"MaxI","MinI","MaxF","MinF","MaxD","MinD",
"MaxV", "MinV",
"MulI","MulL","MulF","MulD",
"MulVB","MulVS","MulVI","MulVL","MulVF","MulVD",
"OrI","OrL",
"OrV",
"XorI","XorL",
"XorV"
"XorI","XorL"
};
int cnt = sizeof(commut_op_list)/sizeof(char*);

if( _lChild && _rChild && (_lChild->_lChild || _rChild->_lChild) ) {
static const char *commut_vector_op_list[] = {
"AddVB", "AddVS", "AddVI", "AddVL", "AddVF", "AddVD",
"MulVB", "MulVS", "MulVI", "MulVL", "MulVF", "MulVD",
"AndV", "OrV", "XorV",
"MaxV", "MinV"
};

if (_lChild && _rChild && (_lChild->_lChild || _rChild->_lChild)) {
// Don't swap if right operand is an immediate constant.
bool is_const = false;
if( _rChild->_lChild == NULL && _rChild->_rChild == NULL ) {
if (_rChild->_lChild == NULL && _rChild->_rChild == NULL) {
FormDict &globals = _AD.globalNames();
const Form *form = globals[_rChild->_opType];
if ( form ) {
OperandForm *oper = form->is_operand();
if( oper && oper->interface_type(globals) == Form::constant_interface )
if (form) {
OperandForm *oper = form->is_operand();
if (oper && oper->interface_type(globals) == Form::constant_interface)
is_const = true;
}
}
if( !is_const ) {
for( int i=0; i<cnt; i++ ) {
if( strcmp(_opType, commut_op_list[i]) == 0 ) {
count++;
_commutative_id = count; // id should be > 0

if (!is_const) {
int scalar_cnt = sizeof(commut_op_list)/sizeof(char*);
int vector_cnt = sizeof(commut_vector_op_list)/sizeof(char*);
bool matched = false;

// Check the commutative vector op first. It's noncommutative if
// the current node is a masked vector op, since a mask value
// is added to the original vector node's input list and the original
// first two inputs are packed into one BinaryNode. So don't swap
// if one of the operands is a BinaryNode.
for (int i = 0; i < vector_cnt; i++) {
if (strcmp(_opType, commut_vector_op_list[i]) == 0) {
if (strcmp(_lChild->_opType, "Binary") != 0 &&
strcmp(_rChild->_opType, "Binary") != 0) {
count++;
_commutative_id = count; // id should be > 0
}
matched = true;
break;
}
}

// Then check the scalar op if the current op is not in
// the commut_vector_op_list.
if (!matched) {
for (int i = 0; i < scalar_cnt; i++) {
if (strcmp(_opType, commut_op_list[i]) == 0) {
count++;
_commutative_id = count; // id should be > 0
break;
}
}
}
}
}
if( _lChild )
if (_lChild)
_lChild->count_commutative_op(count);
if( _rChild )
if (_rChild)
_rChild->count_commutative_op(count);
}
@jatin-bhateja (Member) commented on May 12, 2021:

Matcher rules based on the proposed masked IR shall look like the following:
match(Set dst AddVI (Binary vsrc1 vsrc2) mask)

AddVI is still commutative, and the generated DFA will check for the appropriate child states, i.e.
kid[0]->state == _Binary_vec_vec && kid[1]->state == _mask

So even if the matcher generates an additional check by swapping the states of the two child nodes, it should still be OK. Can you kindly elaborate on the need for this change?

@XiaohongGong (Collaborator, Author) commented on May 13, 2021:

Yes, actually it is OK if the matcher swaps the two operands. Avoiding the swap just prevents generating unneeded rules, which reduces the overall code size.


@@ -4090,6 +4117,7 @@ int MatchRule::is_expensive() const {
strcmp(opType,"AndReductionV")==0 ||
strcmp(opType,"OrReductionV")==0 ||
strcmp(opType,"XorReductionV")==0 ||
strcmp(opType,"MaskAll")==0 ||
0 /* 0 to line up columns nicely */ )
return 1;
}
@@ -4208,10 +4236,12 @@ bool MatchRule::is_vector() const {
"VectorCastL2X", "VectorCastF2X", "VectorCastD2X",
"VectorMaskWrapper", "VectorMaskCmp", "VectorReinterpret","LoadVectorMasked","StoreVectorMasked",
"FmaVD", "FmaVF","PopCountVI",
"LoadVectorMask", "StoreVectorMask",
// Next are vector mask ops.
"MaskAll", "AndVMask", "OrVMask", "XorVMask", "VectorMaskCast",
// Next are not supported currently.
"PackB","PackS","PackI","PackL","PackF","PackD","Pack2L","Pack2D",
"ExtractB","ExtractUB","ExtractC","ExtractS","ExtractI","ExtractL","ExtractF","ExtractD",
"VectorMaskCast"
"ExtractB","ExtractUB","ExtractC","ExtractS","ExtractI","ExtractL","ExtractF","ExtractD"
};
int cnt = sizeof(vector_list)/sizeof(char*);
if (_rChild) {
@@ -417,7 +417,9 @@ macro(LoadVectorGatherMasked)
macro(StoreVector)
macro(StoreVectorScatter)
macro(StoreVectorScatterMasked)
macro(LoadVectorMask)
macro(LoadVectorMasked)
macro(StoreVectorMask)
macro(StoreVectorMasked)
macro(VectorCmpMasked)
macro(VectorMaskGen)
@@ -475,3 +477,7 @@ macro(VectorCastL2X)
macro(VectorCastF2X)
macro(VectorCastD2X)
macro(VectorInsert)
macro(MaskAll)
macro(AndVMask)
macro(OrVMask)
macro(XorVMask)
@jatin-bhateja (Member) commented on May 12, 2021:

Since you have used the existing IR nodes for other commutative vector operations like Add/Sub/Mul, why is specialized masked IR needed for OrV/AndV/XorV?

@XiaohongGong (Collaborator, Author) commented on May 13, 2021:

These three nodes are used for the vector mask logical operations such as mask.and()/mask.or()/mask.xor(), not for the masked vector logic operations.
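
A minimal sketch, assuming the Op_AndVMask/Op_OrVMask/Op_XorVMask opcode constants generated from the classes.hpp entries above, of how a plain vector logic opcode could be mapped to its mask variant; the helper itself is hypothetical and not part of this patch:

#include "opto/opcodes.hpp"

// Illustrative helper only: pick the dedicated mask-logic opcode when the
// operands are vector masks rather than data vectors.
static int vector_mask_logic_opcode(int vopc) {
  switch (vopc) {
    case Op_AndV: return Op_AndVMask;
    case Op_OrV:  return Op_OrVMask;
    case Op_XorV: return Op_XorVMask;
    default:      return vopc;   // not a vector logic op; leave unchanged
  }
}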

@@ -3415,6 +3415,8 @@ void Compile::final_graph_reshaping_main_switch(Node* n, Final_Reshape_Counts& f
case Op_StoreVectorScatterMasked:
case Op_VectorCmpMasked:
case Op_VectorMaskGen:
case Op_LoadVectorMask:
case Op_StoreVectorMask:
case Op_LoadVectorMasked:
case Op_StoreVectorMasked:
break;
@@ -702,6 +702,7 @@ void PhaseCFG::adjust_register_pressure(Node* n, Block* block, intptr_t* recalc_
case Op_StoreP:
case Op_StoreN:
case Op_StoreVector:
case Op_StoreVectorMask:
case Op_StoreVectorMasked:
case Op_StoreVectorScatter:
case Op_StoreVectorScatterMasked:
@@ -2305,6 +2305,21 @@ bool Matcher::find_shared_visit(MStack& mstack, Node* n, uint opcode, bool& mem_
}

void Matcher::find_shared_post_visit(Node* n, uint opcode) {
if (n->is_predicated_vector()) {
// Restructure into binary trees for Matching.
if (n->req() == 4) {
n->set_req(1, new BinaryNode(n->in(1), n->in(2)));
n->set_req(2, n->in(3));
n->del_req(3);
} else if (n->req() == 5) {
n->set_req(1, new BinaryNode(n->in(1), n->in(2)));
n->set_req(2, new BinaryNode(n->in(3), n->in(4)));
n->del_req(4);
n->del_req(3);
}
return;
}

XiaohongGong marked this conversation as resolved.
switch(opcode) { // Handle some opcodes special
case Op_StorePConditional:
case Op_StoreIConditional:
@@ -1134,7 +1134,7 @@ Node* MemNode::can_see_stored_value(Node* st, PhaseTransform* phase) const {
return NULL;
}
// LoadVector/StoreVector needs additional check to ensure the types match.
if (store_Opcode() == Op_StoreVector) {
if (st->is_StoreVector()) {
const TypeVect* in_vt = st->as_StoreVector()->vect_type();
const TypeVect* out_vt = as_LoadVector()->vect_type();
if (in_vt != out_vt) {
@@ -777,7 +777,8 @@ class Node {
Flag_is_scheduled = 1 << 12,
Flag_has_vector_mask_set = 1 << 13,
Flag_is_expensive = 1 << 14,
Flag_for_post_loop_opts_igvn = 1 << 15,
Flag_is_predicated_vector = 1 << 15,
Flag_for_post_loop_opts_igvn = 1 << 16,
_last_flag = Flag_for_post_loop_opts_igvn
};

@@ -986,6 +987,8 @@ class Node {
// It must have the loop's phi as input and provide a def to the phi.
bool is_reduction() const { return (_flags & Flag_is_reduction) != 0; }

bool is_predicated_vector() const { return (_flags & Flag_is_predicated_vector) != 0; }

// The node is a CountedLoopEnd with a mask annotation so as to emit a restore context
bool has_vector_mask_set() const { return (_flags & Flag_has_vector_mask_set) != 0; }
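
A minimal sketch, following the usual HotSpot pattern of setting node flags in the constructor, of how a masked vector node could opt into the new flag; the node class below is hypothetical and not taken from this patch:

#include "opto/vectornode.hpp"

// Illustrative only: a hypothetical masked vector node that marks itself as
// predicated, so is_predicated_vector() returns true and the matcher's
// find_shared_post_visit() repacks its inputs into a Binary subtree.
class ExampleMaskedAddVNode : public VectorNode {
 public:
  ExampleMaskedAddVNode(Node* src1, Node* src2, Node* mask, const TypeVect* vt)
      : VectorNode(src1, src2, mask, vt) {
    init_flags(Flag_is_predicated_vector);
  }
};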

@@ -2344,7 +2344,10 @@ const TypeVect *TypeVect::VECTZ = NULL; // 512-bit vectors
const TypeVect *TypeVect::VECTMASK = NULL; // predicate/mask vector

//------------------------------make-------------------------------------------
const TypeVect* TypeVect::make(const Type *elem, uint length) {
const TypeVect* TypeVect::make(const Type *elem, uint length, bool is_mask) {
if (is_mask) {
return makemask(elem, length);
}
BasicType elem_bt = elem->array_element_basic_type();
assert(is_java_primitive(elem_bt), "only primitive types in vector");
assert(Matcher::vector_size_supported(elem_bt, length), "length in range");
@@ -2370,7 +2373,11 @@ const TypeVect* TypeVect::make(const Type *elem, uint length) {
}

const TypeVect *TypeVect::makemask(const Type* elem, uint length) {
if (Matcher::has_predicated_vectors()) {
if (Matcher::has_predicated_vectors() &&
// TODO: remove this condition once the backend is supported.
// Workaround to make tests pass on AVX-512/SVE when predicate is not supported.
// Could be removed once the backend is supported.
Matcher::match_rule_supported_vector_masked(Op_StoreVectorMasked, MaxVectorSize, T_BOOLEAN)) {
const TypeVect* mtype = Matcher::predicate_reg_type(elem, length);
return (TypeVect*)(const_cast<TypeVect*>(mtype))->hashcons();
} else {
@@ -796,12 +796,12 @@ class TypeVect : public Type {
virtual bool singleton(void) const; // TRUE if type is a singleton
virtual bool empty(void) const; // TRUE if type is vacuous

static const TypeVect *make(const BasicType elem_bt, uint length) {
static const TypeVect *make(const BasicType elem_bt, uint length, bool is_mask = false) {
// Use bottom primitive type.
return make(get_const_basic_type(elem_bt), length);
return make(get_const_basic_type(elem_bt), length, is_mask);
}
// Used directly by Replicate nodes to construct singleton vector.
static const TypeVect *make(const Type* elem, uint length);
static const TypeVect *make(const Type* elem, uint length, bool is_mask = false);

static const TypeVect *makemask(const BasicType elem_bt, uint length) {
// Use bottom primitive type.
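
A minimal usage sketch of the extended factory (num_elem below is a placeholder lane count): the new is_mask parameter lets callers request either a regular vector type or a predicate/mask type through the same entry point, with the mask case routed to makemask():

// Illustrative usage only.
uint num_elem = 8;
const TypeVect* data_vt = TypeVect::make(T_INT, num_elem);                        // regular vector type
const TypeVect* mask_vt = TypeVect::make(T_BOOLEAN, num_elem, /*is_mask=*/true);  // dispatches to makemask()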
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -454,7 +454,7 @@ void PhaseVector::expand_vunbox_node(VectorUnboxNode* vec_unbox) {
C->set_max_vector_size(MAX2(C->max_vector_size(), vt->length_in_bytes()));

if (is_vector_mask(from_kls)) {
vec_val_load = gvn.transform(new VectorLoadMaskNode(vec_val_load, TypeVect::make(masktype, num_elem)));
vec_val_load = gvn.transform(new VectorLoadMaskNode(vec_val_load, TypeVect::makemask(masktype, num_elem)));
} else if (is_vector_shuffle(from_kls) && !vec_unbox->is_shuffle_to_vector()) {
assert(vec_unbox->bottom_type()->is_vect()->element_basic_type() == masktype, "expect shuffle type consistency");
vec_val_load = gvn.transform(new VectorLoadShuffleNode(vec_val_load, TypeVect::make(masktype, num_elem)));