diff --git a/src/hotspot/share/opto/compile.cpp b/src/hotspot/share/opto/compile.cpp index c49c6eccc467e..cac9053e3a01d 100644 --- a/src/hotspot/share/opto/compile.cpp +++ b/src/hotspot/share/opto/compile.cpp @@ -2844,12 +2844,7 @@ void Compile::process_logic_cone_root(PhaseIterGVN &igvn, Node *n, VectorSet &vi if (mask == nullptr || Matcher::match_rule_supported_vector_masked(Op_MacroLogicV, vt->length(), vt->element_basic_type())) { Node* macro_logic = xform_to_MacroLogicV(igvn, vt, partition, inputs); -#ifdef ASSERT - if (TraceNewVectors) { - tty->print("new Vector node: "); - macro_logic->dump(); - } -#endif + VectorNode::trace_new_vector(macro_logic, "MacroLogic"); igvn.replace_node(n, macro_logic); } } diff --git a/src/hotspot/share/opto/loopnode.cpp b/src/hotspot/share/opto/loopnode.cpp index 73cb22c6db8c4..7885f5ea1527c 100644 --- a/src/hotspot/share/opto/loopnode.cpp +++ b/src/hotspot/share/opto/loopnode.cpp @@ -4628,6 +4628,16 @@ void PhaseIdealLoop::build_and_optimize() { } } } + + // Move UnorderedReduction out of counted loop. Can be introduced by SuperWord. + if (C->has_loops() && !C->major_progress()) { + for (LoopTreeIterator iter(_ltree_root); !iter.done(); iter.next()) { + IdealLoopTree* lpt = iter.current(); + if (lpt->is_counted() && lpt->is_innermost()) { + move_unordered_reduction_out_of_loop(lpt); + } + } + } } #ifndef PRODUCT diff --git a/src/hotspot/share/opto/loopnode.hpp b/src/hotspot/share/opto/loopnode.hpp index 459021120ed03..06d962fe638a4 100644 --- a/src/hotspot/share/opto/loopnode.hpp +++ b/src/hotspot/share/opto/loopnode.hpp @@ -1484,6 +1484,9 @@ class PhaseIdealLoop : public PhaseTransform { bool partial_peel( IdealLoopTree *loop, Node_List &old_new ); bool duplicate_loop_backedge(IdealLoopTree *loop, Node_List &old_new); + // Move UnorderedReduction out of loop if possible + void move_unordered_reduction_out_of_loop(IdealLoopTree* loop); + // Create a scheduled list of nodes control dependent on ctrl set. void scheduled_nodelist( IdealLoopTree *loop, VectorSet& ctrl, Node_List &sched ); // Has a use in the vector set diff --git a/src/hotspot/share/opto/loopopts.cpp b/src/hotspot/share/opto/loopopts.cpp index 53fce009040b2..34dc31cb3c64b 100644 --- a/src/hotspot/share/opto/loopopts.cpp +++ b/src/hotspot/share/opto/loopopts.cpp @@ -41,6 +41,7 @@ #include "opto/rootnode.hpp" #include "opto/subnode.hpp" #include "opto/subtypenode.hpp" +#include "opto/vectornode.hpp" #include "utilities/macros.hpp" //============================================================================= @@ -4120,3 +4121,188 @@ bool PhaseIdealLoop::duplicate_loop_backedge(IdealLoopTree *loop, Node_List &old return true; } + +// Having ReductionNodes in the loop is expensive. They need to recursively +// fold together the vector values, for every vectorized loop iteration. If +// we encounter the following pattern, we can vector accumulate the values +// inside the loop, and only have a single UnorderedReduction after the loop. +// +// CountedLoop init +// | | +// +------+ | +-----------------------+ +// | | | | +// PhiNode (s) | +// | | +// | Vector | +// | | | +// UnorderedReduction (first_ur) | +// | | +// ... Vector | +// | | | +// UnorderedReduction (last_ur) | +// | | +// +---------------------+ +// +// We patch the graph to look like this: +// +// CountedLoop identity_vector +// | | +// +-------+ | +---------------+ +// | | | | +// PhiNode (v) | +// | | +// | Vector | +// | | | +// VectorAccumulator | +// | | +// ... Vector | +// | | | +// init VectorAccumulator | +// | | | | +// UnorderedReduction +-----------+ +// +// We turned the scalar (s) Phi into a vectorized one (v). In the loop, we +// use vector_accumulators, which do the same reductions, but only element +// wise. This is a single operation per vector_accumulator, rather than many +// for a UnorderedReduction. We can then reduce the last vector_accumulator +// after the loop, and also reduce the init value into it. +// We can not do this with all reductions. Some reductions do not allow the +// reordering of operations (for example float addition). +void PhaseIdealLoop::move_unordered_reduction_out_of_loop(IdealLoopTree* loop) { + assert(!C->major_progress() && loop->is_counted() && loop->is_innermost(), "sanity"); + + // Find all Phi nodes with UnorderedReduction on backedge. + CountedLoopNode* cl = loop->_head->as_CountedLoop(); + for (DUIterator_Fast jmax, j = cl->fast_outs(jmax); j < jmax; j++) { + Node* phi = cl->fast_out(j); + // We have a phi with a single use, and a UnorderedReduction on the backedge. + if (!phi->is_Phi() || phi->outcnt() != 1 || !phi->in(2)->is_UnorderedReduction()) { + continue; + } + + UnorderedReductionNode* last_ur = phi->in(2)->as_UnorderedReduction(); + + // Determine types + const TypeVect* vec_t = last_ur->vect_type(); + uint vector_length = vec_t->length(); + BasicType bt = vec_t->element_basic_type(); + const Type* bt_t = Type::get_const_basic_type(bt); + + // Convert opcode from vector-reduction -> scalar -> normal-vector-op + const int sopc = VectorNode::scalar_opcode(last_ur->Opcode(), bt); + const int vopc = VectorNode::opcode(sopc, bt); + if (!Matcher::match_rule_supported_vector(vopc, vector_length, bt)) { + DEBUG_ONLY( last_ur->dump(); ) + assert(false, "do not have normal vector op for this reduction"); + continue; // not implemented -> fails + } + + // Traverse up the chain of UnorderedReductions, checking that it loops back to + // the phi. Check that all UnorderedReductions only have a single use, except for + // the last (last_ur), which only has phi as a use in the loop, and all other uses + // are outside the loop. + UnorderedReductionNode* current = last_ur; + UnorderedReductionNode* first_ur = nullptr; + while (true) { + assert(current->is_UnorderedReduction(), "sanity"); + + // Expect no ctrl and a vector_input from within the loop. + Node* ctrl = current->in(0); + Node* vector_input = current->in(2); + if (ctrl != nullptr || get_ctrl(vector_input) != cl) { + DEBUG_ONLY( current->dump(1); ) + assert(false, "reduction has ctrl or bad vector_input"); + break; // Chain traversal fails. + } + + // Expect single use of UnorderedReduction, except for last_ur. + if (current == last_ur) { + // Expect all uses to be outside the loop, except phi. + for (DUIterator_Fast kmax, k = current->fast_outs(kmax); k < kmax; k++) { + Node* use = current->fast_out(k); + if (use != phi && ctrl_or_self(use) == cl) { + DEBUG_ONLY( current->dump(-1); ) + assert(false, "reduction has use inside loop"); + break; // Chain traversal fails. + } + } + } else { + if (current->outcnt() != 1) { + break; // Chain traversal fails. + } + } + + // Expect another UnorderedReduction or phi as the scalar input. + Node* scalar_input = current->in(1); + if (scalar_input->is_UnorderedReduction() && + scalar_input->Opcode() == current->Opcode()) { + // Move up the UnorderedReduction chain. + current = scalar_input->as_UnorderedReduction(); + } else if (scalar_input == phi) { + // Chain terminates at phi. + first_ur = current; + current = nullptr; + break; // Success. + } else { + DEBUG_ONLY( current->dump(1); ) + assert(false, "scalar_input is neither phi nor a matchin reduction"); + break; // Chain traversal fails. + } + } + if (current != nullptr) { + // Chain traversal was not successful. + continue; + } + assert(first_ur != nullptr, "must have successfully terminated chain traversal"); + + Node* identity_scalar = ReductionNode::make_identity_con_scalar(_igvn, sopc, bt); + set_ctrl(identity_scalar, C->root()); + VectorNode* identity_vector = VectorNode::scalar2vector(identity_scalar, vector_length, bt_t); + register_new_node(identity_vector, C->root()); + assert(vec_t == identity_vector->vect_type(), "matching vector type"); + VectorNode::trace_new_vector(identity_vector, "UnorderedReduction"); + + // Turn the scalar phi into a vector phi. + _igvn.rehash_node_delayed(phi); + Node* init = phi->in(1); // Remember init before replacing it. + phi->set_req_X(1, identity_vector, &_igvn); + phi->as_Type()->set_type(vec_t); + _igvn.set_type(phi, vec_t); + + // Traverse down the chain of UnorderedReductions, and replace them with vector_accumulators. + current = first_ur; + while (true) { + // Create vector_accumulator to replace current. + Node* last_vector_accumulator = current->in(1); + Node* vector_input = current->in(2); + VectorNode* vector_accumulator = VectorNode::make(vopc, last_vector_accumulator, vector_input, vec_t); + register_new_node(vector_accumulator, cl); + _igvn.replace_node(current, vector_accumulator); + VectorNode::trace_new_vector(vector_accumulator, "UnorderedReduction"); + if (current == last_ur) { + break; + } + current = vector_accumulator->unique_out()->as_UnorderedReduction(); + } + + // Create post-loop reduction. + Node* last_accumulator = phi->in(2); + Node* post_loop_reduction = ReductionNode::make(sopc, nullptr, init, last_accumulator, bt); + + // Take over uses of last_accumulator that are not in the loop. + for (DUIterator i = last_accumulator->outs(); last_accumulator->has_out(i); i++) { + Node* use = last_accumulator->out(i); + if (use != phi && use != post_loop_reduction) { + assert(ctrl_or_self(use) != cl, "use must be outside loop"); + use->replace_edge(last_accumulator, post_loop_reduction, &_igvn); + --i; + } + } + register_new_node(post_loop_reduction, get_late_ctrl(post_loop_reduction, cl)); + VectorNode::trace_new_vector(post_loop_reduction, "UnorderedReduction"); + + assert(last_accumulator->outcnt() == 2, "last_accumulator has 2 uses: phi and post_loop_reduction"); + assert(post_loop_reduction->outcnt() > 0, "should have taken over all non loop uses of last_accumulator"); + assert(phi->outcnt() == 1, "accumulator is the only use of phi"); + } +} diff --git a/src/hotspot/share/opto/node.hpp b/src/hotspot/share/opto/node.hpp index a81359e845953..0fcfd2c73df7a 100644 --- a/src/hotspot/share/opto/node.hpp +++ b/src/hotspot/share/opto/node.hpp @@ -151,6 +151,7 @@ class Pipeline; class PopulateIndexNode; class ProjNode; class RangeCheckNode; +class ReductionNode; class RegMask; class RegionNode; class RootNode; @@ -164,6 +165,7 @@ class SubTypeCheckNode; class Type; class TypeNode; class UnlockNode; +class UnorderedReductionNode; class VectorNode; class LoadVectorNode; class LoadVectorMaskedNode; @@ -718,6 +720,8 @@ class Node { DEFINE_CLASS_ID(CompressV, Vector, 4) DEFINE_CLASS_ID(ExpandV, Vector, 5) DEFINE_CLASS_ID(CompressM, Vector, 6) + DEFINE_CLASS_ID(Reduction, Vector, 7) + DEFINE_CLASS_ID(UnorderedReduction, Reduction, 0) DEFINE_CLASS_ID(Con, Type, 8) DEFINE_CLASS_ID(ConI, Con, 0) @@ -941,6 +945,7 @@ class Node { DEFINE_CLASS_QUERY(PCTable) DEFINE_CLASS_QUERY(Phi) DEFINE_CLASS_QUERY(Proj) + DEFINE_CLASS_QUERY(Reduction) DEFINE_CLASS_QUERY(Region) DEFINE_CLASS_QUERY(Root) DEFINE_CLASS_QUERY(SafePoint) @@ -950,6 +955,7 @@ class Node { DEFINE_CLASS_QUERY(Sub) DEFINE_CLASS_QUERY(SubTypeCheck) DEFINE_CLASS_QUERY(Type) + DEFINE_CLASS_QUERY(UnorderedReduction) DEFINE_CLASS_QUERY(Vector) DEFINE_CLASS_QUERY(VectorMaskCmp) DEFINE_CLASS_QUERY(VectorUnbox) diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index af8f08f722bf4..abbb1625a7733 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -3329,12 +3329,7 @@ bool SuperWord::output() { if (vlen_in_bytes > max_vlen_in_bytes) { max_vlen_in_bytes = vlen_in_bytes; } -#ifdef ASSERT - if (TraceNewVectors) { - tty->print("new Vector node: "); - vn->dump(); - } -#endif + VectorNode::trace_new_vector(vn, "SuperWord"); } }//for (int i = 0; i < _block.length(); i++) @@ -3374,6 +3369,7 @@ bool SuperWord::output() { if (do_reserve_copy()) { make_reversable.use_new(); } + NOT_PRODUCT(if(is_trace_loop_reverse()) {tty->print_cr("\n Final loop after SuperWord"); print_loop(true);}) return true; } @@ -3506,12 +3502,7 @@ Node* SuperWord::vector_opd(Node_List* p, int opd_idx) { assert(VectorNode::is_populate_index_supported(iv_bt), "Should support"); const TypeVect* vt = TypeVect::make(iv_bt, vlen); Node* vn = new PopulateIndexNode(iv(), _igvn.intcon(1), vt); -#ifdef ASSERT - if (TraceNewVectors) { - tty->print("new Vector node: "); - vn->dump(); - } -#endif + VectorNode::trace_new_vector(vn, "SuperWord"); _igvn.register_new_node_with_optimizer(vn); _phase->set_ctrl(vn, _phase->get_ctrl(opd)); return vn; @@ -3584,12 +3575,7 @@ Node* SuperWord::vector_opd(Node_List* p, int opd_idx) { _igvn.register_new_node_with_optimizer(vn); _phase->set_ctrl(vn, _phase->get_ctrl(opd)); -#ifdef ASSERT - if (TraceNewVectors) { - tty->print("new Vector node: "); - vn->dump(); - } -#endif + VectorNode::trace_new_vector(vn, "SuperWord"); return vn; } @@ -3621,12 +3607,7 @@ Node* SuperWord::vector_opd(Node_List* p, int opd_idx) { } _igvn.register_new_node_with_optimizer(pk); _phase->set_ctrl(pk, _phase->get_ctrl(opd)); -#ifdef ASSERT - if (TraceNewVectors) { - tty->print("new Vector node: "); - pk->dump(); - } -#endif + VectorNode::trace_new_vector(pk, "SuperWord"); return pk; } diff --git a/src/hotspot/share/opto/vectorIntrinsics.cpp b/src/hotspot/share/opto/vectorIntrinsics.cpp index 734120104a674..51b531e42c006 100644 --- a/src/hotspot/share/opto/vectorIntrinsics.cpp +++ b/src/hotspot/share/opto/vectorIntrinsics.cpp @@ -1536,7 +1536,7 @@ bool LibraryCallKit::inline_vector_reduction() { } } - Node* init = ReductionNode::make_reduction_input(gvn(), opc, elem_bt); + Node* init = ReductionNode::make_identity_con_scalar(gvn(), opc, elem_bt); Node* value = nullptr; if (mask == nullptr) { assert(!is_masked_op, "Masked op needs the mask value never null"); diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp index 7e4a16c6d3b74..52078b21b4665 100644 --- a/src/hotspot/share/opto/vectornode.cpp +++ b/src/hotspot/share/opto/vectornode.cpp @@ -34,7 +34,7 @@ //------------------------------VectorNode-------------------------------------- // Return the vector operator for the specified scalar operation -// and vector length. +// and basic type. int VectorNode::opcode(int sopc, BasicType bt) { switch (sopc) { case Op_AddI: @@ -274,6 +274,117 @@ int VectorNode::opcode(int sopc, BasicType bt) { } } +// Return the scalar opcode for the specified vector opcode +// and basic type. +int VectorNode::scalar_opcode(int sopc, BasicType bt) { + switch (sopc) { + case Op_AddReductionVI: + case Op_AddVI: + return Op_AddI; + case Op_AddReductionVL: + case Op_AddVL: + return Op_AddL; + case Op_MulReductionVI: + case Op_MulVI: + return Op_MulI; + case Op_MulReductionVL: + case Op_MulVL: + return Op_MulL; + case Op_AndReductionV: + case Op_AndV: + switch (bt) { + case T_BOOLEAN: + case T_CHAR: + case T_BYTE: + case T_SHORT: + case T_INT: + return Op_AndI; + case T_LONG: + return Op_AndL; + default: + assert(false, "basic type not handled"); + return 0; + } + case Op_OrReductionV: + case Op_OrV: + switch (bt) { + case T_BOOLEAN: + case T_CHAR: + case T_BYTE: + case T_SHORT: + case T_INT: + return Op_OrI; + case T_LONG: + return Op_OrL; + default: + assert(false, "basic type not handled"); + return 0; + } + case Op_XorReductionV: + case Op_XorV: + switch (bt) { + case T_BOOLEAN: + case T_CHAR: + case T_BYTE: + case T_SHORT: + case T_INT: + return Op_XorI; + case T_LONG: + return Op_XorL; + default: + assert(false, "basic type not handled"); + return 0; + } + case Op_MinReductionV: + case Op_MinV: + switch (bt) { + case T_BOOLEAN: + case T_CHAR: + assert(false, "boolean and char are signed, not implemented for Min"); + return 0; + case T_BYTE: + case T_SHORT: + case T_INT: + return Op_MinI; + case T_LONG: + return Op_MinL; + case T_FLOAT: + return Op_MinF; + case T_DOUBLE: + return Op_MinD; + default: + assert(false, "basic type not handled"); + return 0; + } + case Op_MaxReductionV: + case Op_MaxV: + switch (bt) { + case T_BOOLEAN: + case T_CHAR: + assert(false, "boolean and char are signed, not implemented for Max"); + return 0; + case T_BYTE: + case T_SHORT: + case T_INT: + return Op_MaxI; + case T_LONG: + return Op_MaxL; + case T_FLOAT: + return Op_MaxF; + case T_DOUBLE: + return Op_MaxD; + default: + assert(false, "basic type not handled"); + return 0; + } + default: + assert(false, + "Vector node %s is not handled in VectorNode::scalar_opcode", + NodeClassNames[sopc]); + return 0; // Unimplemented + } +} + int VectorNode::replicate_opcode(BasicType bt) { switch(bt) { case T_BOOLEAN: @@ -1398,9 +1509,9 @@ Node* VectorCastNode::Identity(PhaseGVN* phase) { return this; } -Node* ReductionNode::make_reduction_input(PhaseGVN& gvn, int opc, BasicType bt) { - int vopc = opcode(opc, bt); - guarantee(vopc != opc, "Vector reduction for '%s' is not implemented", NodeClassNames[opc]); +Node* ReductionNode::make_identity_con_scalar(PhaseGVN& gvn, int sopc, BasicType bt) { + int vopc = opcode(sopc, bt); + guarantee(vopc != sopc, "Vector reduction for '%s' is not implemented", NodeClassNames[sopc]); switch (vopc) { case Op_AndReductionV: diff --git a/src/hotspot/share/opto/vectornode.hpp b/src/hotspot/share/opto/vectornode.hpp index 22fd201d38fd4..8efa5b40dde81 100644 --- a/src/hotspot/share/opto/vectornode.hpp +++ b/src/hotspot/share/opto/vectornode.hpp @@ -25,6 +25,8 @@ #define SHARE_OPTO_VECTORNODE_HPP #include "opto/callnode.hpp" +#include "opto/cfgnode.hpp" +#include "opto/loopnode.hpp" #include "opto/matcher.hpp" #include "opto/memnode.hpp" #include "opto/node.hpp" @@ -90,7 +92,8 @@ class VectorNode : public TypeNode { static bool is_rotate_opcode(int opc); - static int opcode(int opc, BasicType bt); + static int opcode(int sopc, BasicType bt); // scalar_opc -> vector_opc + static int scalar_opcode(int vopc, BasicType bt); // vector_opc -> scalar_opc static int replicate_opcode(BasicType bt); // Limits on vector size (number of elements) for auto-vectorization. @@ -130,6 +133,15 @@ class VectorNode : public TypeNode { static bool is_vector_shift_count(Node* n) { return is_vector_shift_count(n->Opcode()); } + + static void trace_new_vector(Node* n, const char* context) { +#ifdef ASSERT + if (TraceNewVectors) { + tty->print("TraceNewVectors [%s]: ", context); + n->dump(); + } +#endif + } }; //===========================Vector=ALU=Operations============================= @@ -191,12 +203,15 @@ class ReductionNode : public Node { public: ReductionNode(Node *ctrl, Node* in1, Node* in2) : Node(ctrl, in1, in2), _bottom_type(Type::get_const_basic_type(in1->bottom_type()->basic_type())), - _vect_type(in2->bottom_type()->is_vect()) {} + _vect_type(in2->bottom_type()->is_vect()) { + init_class_id(Class_Reduction); + } - static ReductionNode* make(int opc, Node *ctrl, Node* in1, Node* in2, BasicType bt); + static ReductionNode* make(int opc, Node* ctrl, Node* in1, Node* in2, BasicType bt); static int opcode(int opc, BasicType bt); static bool implemented(int opc, uint vlen, BasicType bt); - static Node* make_reduction_input(PhaseGVN& gvn, int opc, BasicType bt); + // Make an identity scalar (zero for add, one for mul, etc) for scalar opc. + static Node* make_identity_con_scalar(PhaseGVN& gvn, int sopc, BasicType bt); virtual const Type* bottom_type() const { return _bottom_type; @@ -216,19 +231,28 @@ class ReductionNode : public Node { virtual uint size_of() const { return sizeof(*this); } }; +//---------------------------UnorderedReductionNode------------------------------------- +// Order of reduction does not matter. Example int add. Not true for float add. +class UnorderedReductionNode : public ReductionNode { +public: + UnorderedReductionNode(Node * ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) { + init_class_id(Class_UnorderedReduction); + } +}; + //------------------------------AddReductionVINode-------------------------------------- // Vector add byte, short and int as a reduction -class AddReductionVINode : public ReductionNode { +class AddReductionVINode : public UnorderedReductionNode { public: - AddReductionVINode(Node * ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + AddReductionVINode(Node * ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; //------------------------------AddReductionVLNode-------------------------------------- // Vector add long as a reduction -class AddReductionVLNode : public ReductionNode { +class AddReductionVLNode : public UnorderedReductionNode { public: - AddReductionVLNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + AddReductionVLNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; @@ -386,17 +410,17 @@ class CMoveVDNode : public VectorNode { //------------------------------MulReductionVINode-------------------------------------- // Vector multiply byte, short and int as a reduction -class MulReductionVINode : public ReductionNode { +class MulReductionVINode : public UnorderedReductionNode { public: - MulReductionVINode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + MulReductionVINode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; //------------------------------MulReductionVLNode-------------------------------------- // Vector multiply int as a reduction -class MulReductionVLNode : public ReductionNode { +class MulReductionVLNode : public UnorderedReductionNode { public: - MulReductionVLNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + MulReductionVLNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; @@ -737,9 +761,9 @@ class AndVNode : public VectorNode { //------------------------------AndReductionVNode-------------------------------------- // Vector and byte, short, int, long as a reduction -class AndReductionVNode : public ReductionNode { +class AndReductionVNode : public UnorderedReductionNode { public: - AndReductionVNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + AndReductionVNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; @@ -754,17 +778,9 @@ class OrVNode : public VectorNode { //------------------------------OrReductionVNode-------------------------------------- // Vector xor byte, short, int, long as a reduction -class OrReductionVNode : public ReductionNode { - public: - OrReductionVNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} - virtual int Opcode() const; -}; - -//------------------------------XorReductionVNode-------------------------------------- -// Vector and int, long as a reduction -class XorReductionVNode : public ReductionNode { +class OrReductionVNode : public UnorderedReductionNode { public: - XorReductionVNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + OrReductionVNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; @@ -777,19 +793,27 @@ class XorVNode : public VectorNode { virtual Node* Ideal(PhaseGVN* phase, bool can_reshape); }; +//------------------------------XorReductionVNode-------------------------------------- +// Vector and int, long as a reduction +class XorReductionVNode : public UnorderedReductionNode { + public: + XorReductionVNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} + virtual int Opcode() const; +}; + //------------------------------MinReductionVNode-------------------------------------- // Vector min byte, short, int, long, float, double as a reduction -class MinReductionVNode : public ReductionNode { +class MinReductionVNode : public UnorderedReductionNode { public: - MinReductionVNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + MinReductionVNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; //------------------------------MaxReductionVNode-------------------------------------- // Vector min byte, short, int, long, float, double as a reduction -class MaxReductionVNode : public ReductionNode { +class MaxReductionVNode : public UnorderedReductionNode { public: - MaxReductionVNode(Node *ctrl, Node* in1, Node* in2) : ReductionNode(ctrl, in1, in2) {} + MaxReductionVNode(Node *ctrl, Node* in1, Node* in2) : UnorderedReductionNode(ctrl, in1, in2) {} virtual int Opcode() const; }; diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/ProdRed_Int.java b/test/hotspot/jtreg/compiler/loopopts/superword/ProdRed_Int.java index ab7f83c18a3ac..17f3a97a8e84d 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/ProdRed_Int.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/ProdRed_Int.java @@ -84,7 +84,7 @@ public static void prodReductionInit(int[] a, int[] b) { failOn = {IRNode.MUL_REDUCTION_VI}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.MUL_REDUCTION_VI, ">= 1"}) + counts = {IRNode.MUL_REDUCTION_VI, ">= 1", IRNode.MUL_REDUCTION_VI, "<= 2"}) // one for main-loop, one for vector-post-loop public static int prodReductionImplement(int[] a, int[] b, int total) { for (int i = 0; i < a.length; i++) { total *= a[i] + b[i]; diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_int.java b/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_int.java index 9beea472adcd5..faece0cbf9f3f 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_int.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_int.java @@ -134,7 +134,7 @@ public static void reductionInit2( failOn = {IRNode.ADD_REDUCTION_VI}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.ADD_REDUCTION_VI, ">= 1"}) + counts = {IRNode.ADD_REDUCTION_VI, ">= 1", IRNode.ADD_REDUCTION_VI, "<= 2"}) // one for main-loop, one for vector-post-loop public static int sumReductionImplement( int[] a, int[] b, @@ -151,7 +151,7 @@ public static int sumReductionImplement( failOn = {IRNode.OR_REDUCTION_V}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.OR_REDUCTION_V, ">= 1"}) + counts = {IRNode.OR_REDUCTION_V, ">= 1", IRNode.OR_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static int orReductionImplement( int[] a, int[] b, @@ -168,7 +168,7 @@ public static int orReductionImplement( failOn = {IRNode.AND_REDUCTION_V}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.AND_REDUCTION_V, ">= 1"}) + counts = {IRNode.AND_REDUCTION_V, ">= 1", IRNode.AND_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static int andReductionImplement( int[] a, int[] b, @@ -185,7 +185,7 @@ public static int andReductionImplement( failOn = {IRNode.XOR_REDUCTION_V}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.XOR_REDUCTION_V, ">= 1"}) + counts = {IRNode.XOR_REDUCTION_V, ">= 1", IRNode.XOR_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static int xorReductionImplement( int[] a, int[] b, @@ -202,7 +202,7 @@ public static int xorReductionImplement( failOn = {IRNode.MUL_REDUCTION_VI}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.MUL_REDUCTION_VI, ">= 1"}) + counts = {IRNode.MUL_REDUCTION_VI, ">= 1", IRNode.MUL_REDUCTION_VI, "<= 2"}) // one for main-loop, one for vector-post-loop public static int mulReductionImplement( int[] a, int[] b, diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_long.java b/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_long.java index 59d612d015977..27bfa8cec0ebb 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_long.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/RedTest_long.java @@ -137,7 +137,7 @@ public static void reductionInit2( failOn = {IRNode.ADD_REDUCTION_VL}) @IR(applyIfCPUFeature = {"avx2", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.ADD_REDUCTION_VL, ">= 1"}) + counts = {IRNode.ADD_REDUCTION_VL, ">= 1", IRNode.ADD_REDUCTION_VL, "<= 2"}) // one for main-loop, one for vector-post-loop public static long sumReductionImplement( long[] a, long[] b, @@ -154,7 +154,7 @@ public static long sumReductionImplement( failOn = {IRNode.OR_REDUCTION_V}) @IR(applyIfCPUFeature = {"avx2", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.OR_REDUCTION_V, ">= 1"}) + counts = {IRNode.OR_REDUCTION_V, ">= 1", IRNode.OR_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static long orReductionImplement( long[] a, long[] b, @@ -171,7 +171,7 @@ public static long orReductionImplement( failOn = {IRNode.AND_REDUCTION_V}) @IR(applyIfCPUFeature = {"avx2", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.AND_REDUCTION_V, ">= 1"}) + counts = {IRNode.AND_REDUCTION_V, ">= 1", IRNode.AND_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static long andReductionImplement( long[] a, long[] b, @@ -188,7 +188,7 @@ public static long andReductionImplement( failOn = {IRNode.XOR_REDUCTION_V}) @IR(applyIfCPUFeature = {"avx2", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.XOR_REDUCTION_V, ">= 1"}) + counts = {IRNode.XOR_REDUCTION_V, ">= 1", IRNode.XOR_REDUCTION_V, "<= 2"}) // one for main-loop, one for vector-post-loop public static long xorReductionImplement( long[] a, long[] b, @@ -205,7 +205,7 @@ public static long xorReductionImplement( failOn = {IRNode.MUL_REDUCTION_VL}) @IR(applyIfCPUFeature = {"avx512dq", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.MUL_REDUCTION_VL, ">= 1"}) + counts = {IRNode.MUL_REDUCTION_VL, ">= 1", IRNode.MUL_REDUCTION_VL, "<= 2"}) // one for main-loop, one for vector-post-loop public static long mulReductionImplement( long[] a, long[] b, diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/ReductionPerf.java b/test/hotspot/jtreg/compiler/loopopts/superword/ReductionPerf.java index d96d5e29c0070..b1495d00548f8 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/ReductionPerf.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/ReductionPerf.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -21,242 +21,570 @@ * questions. */ -/** +/* * @test - * @bug 8074981 - * @summary Add C2 x86 Superword support for scalar product reduction optimizations : int test - * @requires os.arch=="x86" | os.arch=="i386" | os.arch=="amd64" | os.arch=="x86_64" | os.arch=="aarch64" | os.arch=="riscv64" - * - * @run main/othervm -XX:+IgnoreUnrecognizedVMOptions - * -XX:LoopUnrollLimit=250 -XX:CompileThresholdScaling=0.1 - * -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main - * -XX:+SuperWordReductions - * compiler.loopopts.superword.ReductionPerf - * @run main/othervm -XX:+IgnoreUnrecognizedVMOptions - * -XX:LoopUnrollLimit=250 -XX:CompileThresholdScaling=0.1 - * -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main - * -XX:-SuperWordReductions - * compiler.loopopts.superword.ReductionPerf + * @bug 8074981 8302652 + * @summary Test SuperWord Reduction Perf. + * @requires vm.compiler2.enabled + * @requires vm.simpleArch == "x86" | vm.simpleArch == "x64" | vm.simpleArch == "aarch64" | vm.simpleArch == "riscv64" + * @library /test/lib / + * @run main/othervm -Xbatch -XX:LoopUnrollLimit=250 + * -XX:CompileCommand=exclude,compiler.loopopts.superword.ReductionPerf::main + * compiler.loopopts.superword.ReductionPerf */ package compiler.loopopts.superword; +import java.util.Random; +import jdk.test.lib.Utils; public class ReductionPerf { - public static void main(String[] args) throws Exception { - int[] a1 = new int[8 * 1024]; - int[] a2 = new int[8 * 1024]; - int[] a3 = new int[8 * 1024]; - long[] b1 = new long[8 * 1024]; - long[] b2 = new long[8 * 1024]; - long[] b3 = new long[8 * 1024]; - float[] c1 = new float[8 * 1024]; - float[] c2 = new float[8 * 1024]; - float[] c3 = new float[8 * 1024]; - double[] d1 = new double[8 * 1024]; - double[] d2 = new double[8 * 1024]; - double[] d3 = new double[8 * 1024]; - - ReductionInit(a1, a2, a3, b1, b2, b3, c1, c2, c3, d1, d2, d3); - - int sumIv = sumInt(a1, a2, a3); - long sumLv = sumLong(b1, b2, b3); - float sumFv = sumFloat(c1, c2, c3); - double sumDv = sumDouble(d1, d2, d3); - int mulIv = prodInt(a1, a2, a3); - long mulLv = prodLong(b1, b2, b3); - float mulFv = prodFloat(c1, c2, c3); - double mulDv = prodDouble(d1, d2, d3); - - int sumI = 0; - long sumL = 0; - float sumF = 0.f; - double sumD = 0.; - int mulI = 0; - long mulL = 0; - float mulF = 0.f; - double mulD = 0.; - - System.out.println("Warmup ..."); - long start = System.currentTimeMillis(); - - for (int j = 0; j < 2000; j++) { - sumI = sumInt(a1, a2, a3); - sumL = sumLong(b1, b2, b3); - sumF = sumFloat(c1, c2, c3); - sumD = sumDouble(d1, d2, d3); - mulI = prodInt(a1, a2, a3); - mulL = prodLong(b1, b2, b3); - mulF = prodFloat(c1, c2, c3); - mulD = prodDouble(d1, d2, d3); - } - - long stop = System.currentTimeMillis(); - System.out.println(" Warmup is done in " + (stop - start) + " msec"); - - if (sumIv != sumI) { - System.out.println("sum int: " + sumIv + " != " + sumI); - } - if (sumLv != sumL) { - System.out.println("sum long: " + sumLv + " != " + sumL); - } - if (sumFv != sumF) { - System.out.println("sum float: " + sumFv + " != " + sumF); - } - if (sumDv != sumD) { - System.out.println("sum double: " + sumDv + " != " + sumD); - } - if (mulIv != mulI) { - System.out.println("prod int: " + mulIv + " != " + mulI); - } - if (mulLv != mulL) { - System.out.println("prod long: " + mulLv + " != " + mulL); - } - if (mulFv != mulF) { - System.out.println("prod float: " + mulFv + " != " + mulF); - } - if (mulDv != mulD) { - System.out.println("prod double: " + mulDv + " != " + mulD); + static final int RANGE = 8192; + static Random rand = Utils.getRandomInstance(); + + public static void main(String args[]) { + // Please increase iterations for measurement to 2_000 and 100_000. + int iter_warmup = 100; + int iter_perf = 1_000; + + double[] aDouble = new double[RANGE]; + double[] bDouble = new double[RANGE]; + double[] cDouble = new double[RANGE]; + float[] aFloat = new float[RANGE]; + float[] bFloat = new float[RANGE]; + float[] cFloat = new float[RANGE]; + int[] aInt = new int[RANGE]; + int[] bInt = new int[RANGE]; + int[] cInt = new int[RANGE]; + long[] aLong = new long[RANGE]; + long[] bLong = new long[RANGE]; + long[] cLong = new long[RANGE]; + + long start, stop; + + int startIntAdd = init(aInt, bInt, cInt); + int goldIntAdd = testIntAdd(aInt, bInt, cInt, startIntAdd); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntAdd(aInt, bInt, cInt, startIntAdd); + verify("int add", total, goldIntAdd); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntAdd(aInt, bInt, cInt, startIntAdd); + } + stop = System.currentTimeMillis(); + System.out.println("int add " + (stop - start)); + + int startIntMul = init(aInt, bInt, cInt); + int goldIntMul = testIntMul(aInt, bInt, cInt, startIntMul); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntMul(aInt, bInt, cInt, startIntMul); + verify("int mul", total, goldIntMul); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntMul(aInt, bInt, cInt, startIntMul); + } + stop = System.currentTimeMillis(); + System.out.println("int mul " + (stop - start)); + + int startIntMin = init(aInt, bInt, cInt); + int goldIntMin = testIntMin(aInt, bInt, cInt, startIntMin); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntMin(aInt, bInt, cInt, startIntMin); + verify("int min", total, goldIntMin); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntMin(aInt, bInt, cInt, startIntMin); + } + stop = System.currentTimeMillis(); + System.out.println("int min " + (stop - start)); + + int startIntMax = init(aInt, bInt, cInt); + int goldIntMax = testIntMax(aInt, bInt, cInt, startIntMax); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntMax(aInt, bInt, cInt, startIntMax); + verify("int max", total, goldIntMax); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntMax(aInt, bInt, cInt, startIntMax); + } + stop = System.currentTimeMillis(); + System.out.println("int max " + (stop - start)); + + int startIntAnd = init(aInt, bInt, cInt); + int goldIntAnd = testIntAnd(aInt, bInt, cInt, startIntAnd); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntAnd(aInt, bInt, cInt, startIntAnd); + verify("int and", total, goldIntAnd); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntAnd(aInt, bInt, cInt, startIntAnd); + } + stop = System.currentTimeMillis(); + System.out.println("int and " + (stop - start)); + + int startIntOr = init(aInt, bInt, cInt); + int goldIntOr = testIntOr(aInt, bInt, cInt, startIntOr); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntOr(aInt, bInt, cInt, startIntOr); + verify("int or", total, goldIntOr); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntOr(aInt, bInt, cInt, startIntOr); + } + stop = System.currentTimeMillis(); + System.out.println("int or " + (stop - start)); + + int startIntXor = init(aInt, bInt, cInt); + int goldIntXor = testIntXor(aInt, bInt, cInt, startIntXor); + for (int j = 0; j < iter_warmup; j++) { + int total = testIntXor(aInt, bInt, cInt, startIntXor); + verify("int xor", total, goldIntXor); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testIntXor(aInt, bInt, cInt, startIntXor); + } + stop = System.currentTimeMillis(); + System.out.println("int xor " + (stop - start)); + + long startLongAdd = init(aLong, bLong, cLong); + long goldLongAdd = testLongAdd(aLong, bLong, cLong, startLongAdd); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongAdd(aLong, bLong, cLong, startLongAdd); + verify("long add", total, goldLongAdd); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongAdd(aLong, bLong, cLong, startLongAdd); + } + stop = System.currentTimeMillis(); + System.out.println("long add " + (stop - start)); + + long startLongMul = init(aLong, bLong, cLong); + long goldLongMul = testLongMul(aLong, bLong, cLong, startLongMul); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongMul(aLong, bLong, cLong, startLongMul); + verify("long mul", total, goldLongMul); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongMul(aLong, bLong, cLong, startLongMul); + } + stop = System.currentTimeMillis(); + System.out.println("long mul " + (stop - start)); + + long startLongMin = init(aLong, bLong, cLong); + long goldLongMin = testLongMin(aLong, bLong, cLong, startLongMin); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongMin(aLong, bLong, cLong, startLongMin); + verify("long min", total, goldLongMin); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongMin(aLong, bLong, cLong, startLongMin); + } + stop = System.currentTimeMillis(); + System.out.println("long min " + (stop - start)); + + long startLongMax = init(aLong, bLong, cLong); + long goldLongMax = testLongMax(aLong, bLong, cLong, startLongMax); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongMax(aLong, bLong, cLong, startLongMax); + verify("long max", total, goldLongMax); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongMax(aLong, bLong, cLong, startLongMax); + } + stop = System.currentTimeMillis(); + System.out.println("long max " + (stop - start)); + + long startLongAnd = init(aLong, bLong, cLong); + long goldLongAnd = testLongAnd(aLong, bLong, cLong, startLongAnd); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongAnd(aLong, bLong, cLong, startLongAnd); + verify("long and", total, goldLongAnd); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongAnd(aLong, bLong, cLong, startLongAnd); + } + stop = System.currentTimeMillis(); + System.out.println("long and " + (stop - start)); + + long startLongOr = init(aLong, bLong, cLong); + long goldLongOr = testLongOr(aLong, bLong, cLong, startLongOr); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongOr(aLong, bLong, cLong, startLongOr); + verify("long or", total, goldLongOr); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testLongOr(aLong, bLong, cLong, startLongOr); } + stop = System.currentTimeMillis(); + System.out.println("long or " + (stop - start)); + long startLongXor = init(aLong, bLong, cLong); + long goldLongXor = testLongXor(aLong, bLong, cLong, startLongXor); + for (int j = 0; j < iter_warmup; j++) { + long total = testLongXor(aLong, bLong, cLong, startLongXor); + verify("long xor", total, goldLongXor); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - sumI = sumInt(a1, a2, a3); + for (int j = 0; j < iter_perf; j++) { + testLongXor(aLong, bLong, cLong, startLongXor); } stop = System.currentTimeMillis(); - System.out.println("sum int: " + (stop - start)); + System.out.println("long xor " + (stop - start)); + float startFloatAdd = init(aFloat, bFloat, cFloat); + float goldFloatAdd = testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd); + for (int j = 0; j < iter_warmup; j++) { + float total = testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd); + verify("float add", total, goldFloatAdd); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - sumL = sumLong(b1, b2, b3); + for (int j = 0; j < iter_perf; j++) { + testFloatAdd(aFloat, bFloat, cFloat, startFloatAdd); } stop = System.currentTimeMillis(); - System.out.println("sum long: " + (stop - start)); + System.out.println("float add " + (stop - start)); + float startFloatMul = init(aFloat, bFloat, cFloat); + float goldFloatMul = testFloatMul(aFloat, bFloat, cFloat, startFloatMul); + for (int j = 0; j < iter_warmup; j++) { + float total = testFloatMul(aFloat, bFloat, cFloat, startFloatMul); + verify("float mul", total, goldFloatMul); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - sumF = sumFloat(c1, c2, c3); + for (int j = 0; j < iter_perf; j++) { + testFloatMul(aFloat, bFloat, cFloat, startFloatMul); } stop = System.currentTimeMillis(); - System.out.println("sum float: " + (stop - start)); + System.out.println("float mul " + (stop - start)); + float startFloatMin = init(aFloat, bFloat, cFloat); + float goldFloatMin = testFloatMin(aFloat, bFloat, cFloat, startFloatMin); + for (int j = 0; j < iter_warmup; j++) { + float total = testFloatMin(aFloat, bFloat, cFloat, startFloatMin); + verify("float min", total, goldFloatMin); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - sumD = sumDouble(d1, d2, d3); + for (int j = 0; j < iter_perf; j++) { + testFloatMin(aFloat, bFloat, cFloat, startFloatMin); } stop = System.currentTimeMillis(); - System.out.println("sum double: " + (stop - start)); + System.out.println("float min " + (stop - start)); + float startFloatMax = init(aFloat, bFloat, cFloat); + float goldFloatMax = testFloatMax(aFloat, bFloat, cFloat, startFloatMax); + for (int j = 0; j < iter_warmup; j++) { + float total = testFloatMax(aFloat, bFloat, cFloat, startFloatMax); + verify("float max", total, goldFloatMax); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - mulI = prodInt(a1, a2, a3); + for (int j = 0; j < iter_perf; j++) { + testFloatMax(aFloat, bFloat, cFloat, startFloatMax); } stop = System.currentTimeMillis(); - System.out.println("prod int: " + (stop - start)); + System.out.println("float max " + (stop - start)); + double startDoubleAdd = init(aDouble, bDouble, cDouble); + double goldDoubleAdd = testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd); + for (int j = 0; j < iter_warmup; j++) { + double total = testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd); + verify("double add", total, goldDoubleAdd); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - mulL = prodLong(b1, b2, b3); + for (int j = 0; j < iter_perf; j++) { + testDoubleAdd(aDouble, bDouble, cDouble, startDoubleAdd); } stop = System.currentTimeMillis(); - System.out.println("prod long: " + (stop - start)); + System.out.println("double add " + (stop - start)); + double startDoubleMul = init(aDouble, bDouble, cDouble); + double goldDoubleMul = testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul); + for (int j = 0; j < iter_warmup; j++) { + double total = testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul); + verify("double mul", total, goldDoubleMul); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - mulF = prodFloat(c1, c2, c3); + for (int j = 0; j < iter_perf; j++) { + testDoubleMul(aDouble, bDouble, cDouble, startDoubleMul); } stop = System.currentTimeMillis(); - System.out.println("prod float: " + (stop - start)); + System.out.println("double mul " + (stop - start)); + double startDoubleMin = init(aDouble, bDouble, cDouble); + double goldDoubleMin = testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin); + for (int j = 0; j < iter_warmup; j++) { + double total = testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin); + verify("double min", total, goldDoubleMin); + } start = System.currentTimeMillis(); - for (int j = 0; j < 5000; j++) { - mulD = prodDouble(d1, d2, d3); + for (int j = 0; j < iter_perf; j++) { + testDoubleMin(aDouble, bDouble, cDouble, startDoubleMin); } stop = System.currentTimeMillis(); - System.out.println("prod double: " + (stop - start)); + System.out.println("double min " + (stop - start)); + double startDoubleMax = init(aDouble, bDouble, cDouble); + double goldDoubleMax = testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax); + for (int j = 0; j < iter_warmup; j++) { + double total = testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax); + verify("double max", total, goldDoubleMax); + } + start = System.currentTimeMillis(); + for (int j = 0; j < iter_perf; j++) { + testDoubleMax(aDouble, bDouble, cDouble, startDoubleMax); + } + stop = System.currentTimeMillis(); + System.out.println("double max " + (stop - start)); + + } + + // ------------------- Tests ------------------- + + static int testIntAdd(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total += v; + } + return total; + } + + static int testIntMul(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total *= v; + } + return total; + } + + static int testIntMin(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.min(total, v); + } + return total; + } + + static int testIntMax(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.max(total, v); + } + return total; + } + + static int testIntAnd(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total &= v; + } + return total; + } + + static int testIntOr(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total |= v; + } + return total; + } + + static int testIntXor(int[] a, int[] b, int[] c, int total) { + for (int i = 0; i < RANGE; i++) { + int v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total ^= v; + } + return total; + } + + static long testLongAdd(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total += v; + } + return total; } - public static void ReductionInit(int[] a1, int[] a2, int[] a3, - long[] b1, long[] b2, long[] b3, - float[] c1, float[] c2, float[] c3, - double[] d1, double[] d2, double[] d3) { - for(int i = 0; i < a1.length; i++) { - a1[i] = (i + 0); - a2[i] = (i + 1); - a3[i] = (i + 2); - b1[i] = (long) (i + 0); - b2[i] = (long) (i + 1); - b3[i] = (long) (i + 2); - c1[i] = (float) (i + 0); - c2[i] = (float) (i + 1); - c3[i] = (float) (i + 2); - d1[i] = (double) (i + 0); - d2[i] = (double) (i + 1); - d3[i] = (double) (i + 2); + static long testLongMul(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total *= v; } + return total; + } + + static long testLongMin(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.min(total, v); + } + return total; } - public static int sumInt(int[] a1, int[] a2, int[] a3) { - int total = 0; - for (int i = 0; i < a1.length; i++) { - total += (a1[i] * a2[i]) + (a1[i] * a3[i]) + (a2[i] * a3[i]); + static long testLongMax(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.max(total, v); } return total; } - public static long sumLong(long[] b1, long[] b2, long[] b3) { - long total = 0; - for (int i = 0; i < b1.length; i++) { - total += (b1[i] * b2[i]) + (b1[i] * b3[i]) + (b2[i] * b3[i]); + static long testLongAnd(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total &= v; } return total; } - public static float sumFloat(float[] c1, float[] c2, float[] c3) { - float total = 0; - for (int i = 0; i < c1.length; i++) { - total += (c1[i] * c2[i]) + (c1[i] * c3[i]) + (c2[i] * c3[i]); + static long testLongOr(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total |= v; } return total; } - public static double sumDouble(double[] d1, double[] d2, double[] d3) { - double total = 0; - for (int i = 0; i < d1.length; i++) { - total += (d1[i] * d2[i]) + (d1[i] * d3[i]) + (d2[i] * d3[i]); + static long testLongXor(long[] a, long[] b, long[] c, long total) { + for (int i = 0; i < RANGE; i++) { + long v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total ^= v; } return total; } - public static int prodInt(int[] a1, int[] a2, int[] a3) { - int total = 1; - for (int i = 0; i < a1.length; i++) { - total *= (a1[i] * a2[i]) + (a1[i] * a3[i]) + (a2[i] * a3[i]); + static float testFloatAdd(float[] a, float[] b, float[] c, float total) { + for (int i = 0; i < RANGE; i++) { + float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total += v; } return total; } - public static long prodLong(long[] b1, long[] b2, long[] b3) { - long total = 1; - for (int i = 0; i < b1.length; i++) { - total *= (b1[i] * b2[i]) + (b1[i] * b3[i]) + (b2[i] * b3[i]); + static float testFloatMul(float[] a, float[] b, float[] c, float total) { + for (int i = 0; i < RANGE; i++) { + float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total *= v; } return total; } - public static float prodFloat(float[] c1, float[] c2, float[] c3) { - float total = 1; - for (int i = 0; i < c1.length; i++) { - total *= (c1[i] * c2[i]) + (c1[i] * c3[i]) + (c2[i] * c3[i]); + static float testFloatMin(float[] a, float[] b, float[] c, float total) { + for (int i = 0; i < RANGE; i++) { + float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.min(total, v); } return total; } - public static double prodDouble(double[] d1, double[] d2, double[] d3) { - double total = 1; - for (int i = 0; i < d1.length; i++) { - total *= (d1[i] * d2[i]) + (d1[i] * d3[i]) + (d2[i] * d3[i]); + static float testFloatMax(float[] a, float[] b, float[] c, float total) { + for (int i = 0; i < RANGE; i++) { + float v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.max(total, v); } return total; } + + static double testDoubleAdd(double[] a, double[] b, double[] c, double total) { + for (int i = 0; i < RANGE; i++) { + double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total += v; + } + return total; + } + + static double testDoubleMul(double[] a, double[] b, double[] c, double total) { + for (int i = 0; i < RANGE; i++) { + double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total *= v; + } + return total; + } + + static double testDoubleMin(double[] a, double[] b, double[] c, double total) { + for (int i = 0; i < RANGE; i++) { + double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.min(total, v); + } + return total; + } + + static double testDoubleMax(double[] a, double[] b, double[] c, double total) { + for (int i = 0; i < RANGE; i++) { + double v = (a[i] * b[i]) + (a[i] * c[i]) + (b[i] * c[i]); + total = Math.max(total, v); + } + return total; + } + + // ------------------- Initialization ------------------- + + static int init(int[] a, int[] b, int[] c) { + for (int j = 0; j < RANGE; j++) { + a[j] = rand.nextInt(); + b[j] = rand.nextInt(); + c[j] = rand.nextInt(); + } + return rand.nextInt(); + } + + static long init(long[] a, long[] b, long[] c) { + for (int j = 0; j < RANGE; j++) { + a[j] = rand.nextLong(); + b[j] = rand.nextLong(); + c[j] = rand.nextLong(); + } + return rand.nextLong(); + } + + static float init(float[] a, float[] b, float[] c) { + for (int j = 0; j < RANGE; j++) { + a[j] = rand.nextFloat(); + b[j] = rand.nextFloat(); + c[j] = rand.nextFloat(); + } + return rand.nextFloat(); + } + + static double init(double[] a, double[] b, double[] c) { + for (int j = 0; j < RANGE; j++) { + a[j] = rand.nextDouble(); + b[j] = rand.nextDouble(); + c[j] = rand.nextDouble(); + } + return rand.nextDouble(); + } + + // ------------------- Verification ------------------- + + static void verify(String context, double total, double gold) { + if (total != gold) { + throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold); + } + } + static void verify(String context, float total, float gold) { + if (total != gold) { + throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold); + } + } + static void verify(String context, int total, int gold) { + if (total != gold) { + throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold); + } + } + static void verify(String context, long total, long gold) { + if (total != gold) { + throw new RuntimeException("Wrong result for " + context + ": " + total + " != " + gold); + } + } } diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Int.java b/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Int.java index 77eebd0eea30c..ad6d9e45051f2 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Int.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Int.java @@ -91,7 +91,7 @@ public static void sumReductionInit( failOn = {IRNode.ADD_REDUCTION_VI}) @IR(applyIfCPUFeature = {"sse4.1", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.ADD_REDUCTION_VI, ">= 1"}) + counts = {IRNode.ADD_REDUCTION_VI, ">= 1", IRNode.ADD_REDUCTION_VI, "<= 2"}) // one for main-loop, one for vector-post-loop public static int sumReductionImplement( int[] a, int[] b, diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Long.java b/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Long.java index 278c81f707cb2..ea41a652940de 100644 --- a/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Long.java +++ b/test/hotspot/jtreg/compiler/loopopts/superword/SumRed_Long.java @@ -95,7 +95,7 @@ public static void sumReductionInit( failOn = {IRNode.ADD_REDUCTION_VL}) @IR(applyIfCPUFeature = {"avx2", "true"}, applyIfAnd = {"SuperWordReductions", "true", "LoopMaxUnroll", ">= 8"}, - counts = {IRNode.ADD_REDUCTION_VL, ">= 1"}) + counts = {IRNode.ADD_REDUCTION_VL, ">= 1", IRNode.ADD_REDUCTION_VL, "<= 2"}) // one for main-loop, one for vector-post-loop public static long sumReductionImplement( long[] a, long[] b, diff --git a/test/hotspot/jtreg/compiler/loopopts/superword/TestUnorderedReduction.java b/test/hotspot/jtreg/compiler/loopopts/superword/TestUnorderedReduction.java new file mode 100644 index 0000000000000..d50dc96ffb301 --- /dev/null +++ b/test/hotspot/jtreg/compiler/loopopts/superword/TestUnorderedReduction.java @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** + * @test + * @bug 8302652 + * @summary Special test cases for PhaseIdealLoop::move_unordered_reduction_out_of_loop + * @library /test/lib / + * @run driver compiler.loopopts.superword.TestUnorderedReduction + */ + +package compiler.loopopts.superword; + +import compiler.lib.ir_framework.*; + +public class TestUnorderedReduction { + static final int RANGE = 1024; + static final int ITER = 10; + + public static void main(String[] args) { + TestFramework.runWithFlags("-Xbatch", + "-XX:CompileCommand=compileonly,compiler.loopopts.superword.TestUnorderedReduction::test*", + "-XX:MaxVectorSize=16"); + } + + @Run(test = {"test1", "test2"}) + @Warmup(0) + public void runTests() throws Exception { + int[] data = new int[RANGE]; + + init(data); + for (int i = 0; i < ITER; i++) { + int r1 = test1(data, i); + int r2 = ref1(data, i); + if (r1 != r2) { + throw new RuntimeException("Wrong result test1: " + r1 + " != " + r2); + } + } + + for (int i = 0; i < ITER; i++) { + int r1 = test2(data, i); + int r2 = ref2(data, i); + if (r1 != r2) { + throw new RuntimeException("Wrong result test2: " + r1 + " != " + r2); + } + } + } + + @Test + @IR(counts = {IRNode.LOAD_VECTOR, "> 0", + IRNode.ADD_VI, "= 0", + IRNode.ADD_REDUCTION_VI, "> 0"}, // count can be high + applyIfCPUFeatureOr = {"sse4.1", "true", "asimd", "true"}) + static int test1(int[] data, int sum) { + // Vectorizes, but the UnorderedReduction cannot be moved out of the loop, + // because we have a use inside the loop. + int x = 0; + for (int i = 0; i < RANGE; i+=8) { + sum += 11 * data[i+0]; // vec 1 (16 bytes) + sum += 11 * data[i+1]; + sum += 11 * data[i+2]; + sum += 11 * data[i+3]; + x = sum + i; // vec 1 reduction has more than 1 use + sum += 11 * data[i+4]; // vec 2 (next 16 bytes) + sum += 11 * data[i+5]; + sum += 11 * data[i+6]; + sum += 11 * data[i+7]; + } + return sum + x; + } + + static int ref1(int[] data, int sum) { + int x = 0; + for (int i = 0; i < RANGE; i+=8) { + sum += 11 * data[i+0]; + sum += 11 * data[i+1]; + sum += 11 * data[i+2]; + sum += 11 * data[i+3]; + x = sum + i; + sum += 11 * data[i+4]; + sum += 11 * data[i+5]; + sum += 11 * data[i+6]; + sum += 11 * data[i+7]; + } + return sum + x; + } + + @Test + @IR(counts = {IRNode.LOAD_VECTOR, "> 0", + IRNode.ADD_VI, "> 0", + IRNode.ADD_REDUCTION_VI, "> 0", + IRNode.ADD_REDUCTION_VI, "<= 2"}, // count must be low + applyIfCPUFeatureOr = {"sse4.1", "true", "asimd", "true"}) + static int test2(int[] data, int sum) { + for (int i = 0; i < RANGE; i+=8) { + // Vectorized, and UnorderedReduction moved outside loop. + sum += 11 * data[i+0]; // vec 1 + sum += 11 * data[i+1]; + sum += 11 * data[i+2]; + sum += 11 * data[i+3]; + sum += 11 * data[i+4]; // vec 2 + sum += 11 * data[i+5]; + sum += 11 * data[i+6]; + sum += 11 * data[i+7]; + } + return sum; + } + + static int ref2(int[] data, int sum) { + for (int i = 0; i < RANGE; i+=8) { + sum += 11 * data[i+0]; + sum += 11 * data[i+1]; + sum += 11 * data[i+2]; + sum += 11 * data[i+3]; + sum += 11 * data[i+4]; + sum += 11 * data[i+5]; + sum += 11 * data[i+6]; + sum += 11 * data[i+7]; + } + return sum; + } + + + static void init(int[] data) { + for (int i = 0; i < RANGE; i++) { + data[i] = i + 1; + } + } +}