From 8256811fae40384bc016e4a81090aac3118f9871 Mon Sep 17 00:00:00 2001 From: Smita Kamath Date: Thu, 8 Dec 2022 09:46:59 +0000 Subject: [PATCH] 8294588: Auto vectorize half precision floating point conversion APIs Reviewed-by: sviswanathan, kvn, jbhateja, fgao, xgong --- src/hotspot/cpu/x86/assembler_x86.cpp | 27 +++++- src/hotspot/cpu/x86/assembler_x86.hpp | 2 + src/hotspot/cpu/x86/vm_version_x86.cpp | 1 + src/hotspot/cpu/x86/x86.ad | 55 +++++++++++ src/hotspot/share/adlc/formssel.cpp | 2 +- src/hotspot/share/opto/classes.hpp | 2 + src/hotspot/share/opto/superword.cpp | 2 +- src/hotspot/share/opto/vectorIntrinsics.cpp | 4 +- src/hotspot/share/opto/vectornode.cpp | 23 ++++- src/hotspot/share/opto/vectornode.hpp | 18 +++- .../compiler/lib/ir_framework/IRNode.java | 10 ++ .../TestFloatConversionsVector.java | 95 +++++++++++++++++++ 12 files changed, 231 insertions(+), 10 deletions(-) create mode 100644 test/hotspot/jtreg/compiler/vectorization/TestFloatConversionsVector.java diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index ad7543cfb783d..e1a68906d8dff 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -1931,14 +1931,14 @@ void Assembler::vcvtdq2pd(XMMRegister dst, XMMRegister src, int vector_len) { } void Assembler::vcvtps2ph(XMMRegister dst, XMMRegister src, int imm8, int vector_len) { - assert(VM_Version::supports_avx512vl() || VM_Version::supports_f16c(), ""); + assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), ""); InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /*uses_vl */ true); int encode = vex_prefix_and_encode(src->encoding(), 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes); emit_int24(0x1D, (0xC0 | encode), imm8); } void Assembler::evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm8, int vector_len) { - assert(VM_Version::supports_avx512vl(), ""); + assert(VM_Version::supports_evex(), ""); InstructionMark im(this); InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /*uses_vl */ true); attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_64bit); @@ -1951,13 +1951,34 @@ void Assembler::evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm emit_int8(imm8); } +void Assembler::vcvtps2ph(Address dst, XMMRegister src, int imm8, int vector_len) { + assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /*uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(dst, 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_3A, &attributes); + emit_int8(0x1D); + emit_operand(src, dst, 1); + emit_int8(imm8); +} + void Assembler::vcvtph2ps(XMMRegister dst, XMMRegister src, int vector_len) { - assert(VM_Version::supports_avx512vl() || VM_Version::supports_f16c(), ""); + assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), ""); InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */false, /* no_mask_reg */ true, /* uses_vl */ true); int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_38, &attributes); emit_int16(0x13, (0xC0 | encode)); } +void Assembler::vcvtph2ps(XMMRegister dst, Address src, int vector_len) { + assert(VM_Version::supports_evex() || VM_Version::supports_f16c(), ""); + InstructionMark im(this); + InstructionAttr attributes(vector_len, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ false, /*uses_vl */ true); + attributes.set_address_attributes(/* tuple_type */ EVEX_HVM, /* input_size_in_bits */ EVEX_NObit); + vex_prefix(src, 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_0F_38, &attributes); + emit_int8(0x13); + emit_operand(dst, src, 0); +} + void Assembler::cvtdq2ps(XMMRegister dst, XMMRegister src) { NOT_LP64(assert(VM_Version::supports_sse2(), "")); InstructionAttr attributes(AVX_128bit, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 7b4288e775ee5..a5d8f2dabfc1c 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -1160,6 +1160,8 @@ class Assembler : public AbstractAssembler { void vcvtps2ph(XMMRegister dst, XMMRegister src, int imm8, int vector_len); void vcvtph2ps(XMMRegister dst, XMMRegister src, int vector_len); void evcvtps2ph(Address dst, KRegister mask, XMMRegister src, int imm8, int vector_len); + void vcvtps2ph(Address dst, XMMRegister src, int imm8, int vector_len); + void vcvtph2ps(XMMRegister dst, Address src, int vector_len); // Convert Packed Signed Doubleword Integers to Packed Single-Precision Floating-Point Value void cvtdq2ps(XMMRegister dst, XMMRegister src); diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index 5a0a085401b8e..83979a60245d3 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -956,6 +956,7 @@ void VM_Version::get_processor_features() { if (UseAVX < 1) { _features &= ~CPU_AVX; _features &= ~CPU_VZEROUPPER; + _features &= ~CPU_F16C; } if (logical_processors_per_package() == 1) { diff --git a/src/hotspot/cpu/x86/x86.ad b/src/hotspot/cpu/x86/x86.ad index 629ae77567de1..d70389245ad04 100644 --- a/src/hotspot/cpu/x86/x86.ad +++ b/src/hotspot/cpu/x86/x86.ad @@ -1687,6 +1687,12 @@ const bool Matcher::match_rule_supported(int opcode) { return false; } break; + case Op_VectorCastF2HF: + case Op_VectorCastHF2F: + if (!VM_Version::supports_f16c() && !VM_Version::supports_evex()) { + return false; + } + break; } return true; // Match rules are supported by default. } @@ -1901,6 +1907,14 @@ const bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType return false; } break; + case Op_VectorCastF2HF: + case Op_VectorCastHF2F: + if (!VM_Version::supports_f16c() && + ((!VM_Version::supports_evex() || + ((size_in_bits != 512) && !VM_Version::supports_avx512vl())))) { + return false; + } + break; case Op_RoundVD: if (!VM_Version::supports_avx512dq()) { return false; @@ -3673,6 +3687,26 @@ instruct convF2HF_mem_reg(memory mem, regF src, kReg ktmp, rRegI rtmp) %{ ins_pipe( pipe_slow ); %} +instruct vconvF2HF(vec dst, vec src) %{ + match(Set dst (VectorCastF2HF src)); + format %{ "vector_conv_F2HF $dst $src" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this, $src); + __ vcvtps2ph($dst$$XMMRegister, $src$$XMMRegister, 0x04, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + +instruct vconvF2HF_mem_reg(memory mem, vec src) %{ + match(Set mem (StoreVector mem (VectorCastF2HF src))); + format %{ "vcvtps2ph $mem,$src" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this, $src); + __ vcvtps2ph($mem$$Address, $src$$XMMRegister, 0x04, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + instruct convHF2F_reg_reg(regF dst, rRegI src) %{ match(Set dst (ConvHF2F src)); format %{ "vcvtph2ps $dst,$src" %} @@ -3683,6 +3717,27 @@ instruct convHF2F_reg_reg(regF dst, rRegI src) %{ ins_pipe( pipe_slow ); %} +instruct vconvHF2F_reg_mem(vec dst, memory mem) %{ + match(Set dst (VectorCastHF2F (LoadVector mem))); + format %{ "vcvtph2ps $dst,$mem" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ vcvtph2ps($dst$$XMMRegister, $mem$$Address, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + +instruct vconvHF2F(vec dst, vec src) %{ + match(Set dst (VectorCastHF2F src)); + ins_cost(125); + format %{ "vector_conv_HF2F $dst,$src" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ vcvtph2ps($dst$$XMMRegister, $src$$XMMRegister, vlen_enc); + %} + ins_pipe( pipe_slow ); +%} + // ---------------------------------------- VectorReinterpret ------------------------------------ instruct reinterpret_mask(kReg dst) %{ predicate(n->bottom_type()->isa_vectmask() && diff --git a/src/hotspot/share/adlc/formssel.cpp b/src/hotspot/share/adlc/formssel.cpp index d09c6a5ca92f8..21a01306333cf 100644 --- a/src/hotspot/share/adlc/formssel.cpp +++ b/src/hotspot/share/adlc/formssel.cpp @@ -4223,7 +4223,7 @@ bool MatchRule::is_vector() const { "VectorTest", "VectorLoadMask", "VectorStoreMask", "VectorBlend", "VectorInsert", "VectorRearrange","VectorLoadShuffle", "VectorLoadConst", "VectorCastB2X", "VectorCastS2X", "VectorCastI2X", - "VectorCastL2X", "VectorCastF2X", "VectorCastD2X", + "VectorCastL2X", "VectorCastF2X", "VectorCastD2X", "VectorCastF2HF", "VectorCastHF2F", "VectorUCastB2X", "VectorUCastS2X", "VectorUCastI2X", "VectorMaskWrapper","VectorMaskCmp","VectorReinterpret","LoadVectorMasked","StoreVectorMasked", "FmaVD","FmaVF","PopCountVI","PopCountVL","PopulateIndex","VectorLongToMask", diff --git a/src/hotspot/share/opto/classes.hpp b/src/hotspot/share/opto/classes.hpp index 45b7f8f2f858f..112bf860c57eb 100644 --- a/src/hotspot/share/opto/classes.hpp +++ b/src/hotspot/share/opto/classes.hpp @@ -506,6 +506,8 @@ macro(VectorCastI2X) macro(VectorCastL2X) macro(VectorCastF2X) macro(VectorCastD2X) +macro(VectorCastF2HF) +macro(VectorCastHF2F) macro(VectorUCastB2X) macro(VectorUCastS2X) macro(VectorUCastI2X) diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index 35dce686ec573..4df7b4219e8d7 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -2712,7 +2712,7 @@ bool SuperWord::output() { assert(n->req() == 2, "only one input expected"); BasicType bt = velt_basic_type(n); Node* in = vector_opd(p, 1); - int vopc = VectorCastNode::opcode(in->bottom_type()->is_vect()->element_basic_type()); + int vopc = VectorCastNode::opcode(opc, in->bottom_type()->is_vect()->element_basic_type()); vn = VectorCastNode::make(vopc, in, bt, vlen); vlen_in_bytes = vn->as_Vector()->length_in_bytes(); } else if (is_cmov_pack(p)) { diff --git a/src/hotspot/share/opto/vectorIntrinsics.cpp b/src/hotspot/share/opto/vectorIntrinsics.cpp index 4d131760e450f..ec4ed774ada48 100644 --- a/src/hotspot/share/opto/vectorIntrinsics.cpp +++ b/src/hotspot/share/opto/vectorIntrinsics.cpp @@ -775,7 +775,7 @@ bool LibraryCallKit::inline_vector_shuffle_to_vector() { return false; } - int cast_vopc = VectorCastNode::opcode(T_BYTE); // from shuffle of type T_BYTE + int cast_vopc = VectorCastNode::opcode(-1, T_BYTE); // from shuffle of type T_BYTE // Make sure that cast is implemented to particular type/size combination. if (!arch_supports_vector(cast_vopc, num_elem, elem_bt, VecMaskNotUsed)) { if (C->print_intrinsics()) { @@ -2489,7 +2489,7 @@ bool LibraryCallKit::inline_vector_convert() { Node* op = opd1; if (is_cast) { assert(!is_mask || num_elem_from == num_elem_to, "vector mask cast needs the same elem num"); - int cast_vopc = VectorCastNode::opcode(elem_bt_from, !is_ucast); + int cast_vopc = VectorCastNode::opcode(-1, elem_bt_from, !is_ucast); // Make sure that vector cast is implemented to particular type/size combination if it is // not a mask casting. diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp index 92ec6d80cfef4..dfc81402eded8 100644 --- a/src/hotspot/share/opto/vectornode.cpp +++ b/src/hotspot/share/opto/vectornode.cpp @@ -467,6 +467,8 @@ bool VectorNode::is_convert_opcode(int opc) { case Op_ConvD2F: case Op_ConvF2D: case Op_ConvD2I: + case Op_ConvF2HF: + case Op_ConvHF2F: return true; default: return false; @@ -1328,14 +1330,31 @@ VectorCastNode* VectorCastNode::make(int vopc, Node* n1, BasicType bt, uint vlen case Op_VectorUCastB2X: return new VectorUCastB2XNode(n1, vt); case Op_VectorUCastS2X: return new VectorUCastS2XNode(n1, vt); case Op_VectorUCastI2X: return new VectorUCastI2XNode(n1, vt); + case Op_VectorCastHF2F: return new VectorCastHF2FNode(n1, vt); + case Op_VectorCastF2HF: return new VectorCastF2HFNode(n1, vt); default: assert(false, "unknown node: %s", NodeClassNames[vopc]); return NULL; } } -int VectorCastNode::opcode(BasicType bt, bool is_signed) { +int VectorCastNode::opcode(int sopc, BasicType bt, bool is_signed) { assert((is_integral_type(bt) && bt != T_LONG) || is_signed, ""); + + // Handle special case for to/from Half Float conversions + switch (sopc) { + case Op_ConvHF2F: + assert(bt == T_SHORT, ""); + return Op_VectorCastHF2F; + case Op_ConvF2HF: + assert(bt == T_FLOAT, ""); + return Op_VectorCastF2HF; + default: + // Handled normally below + break; + } + + // Handle normal conversions switch (bt) { case T_BYTE: return is_signed ? Op_VectorCastB2X : Op_VectorUCastB2X; case T_SHORT: return is_signed ? Op_VectorCastS2X : Op_VectorUCastS2X; @@ -1354,7 +1373,7 @@ bool VectorCastNode::implemented(int opc, uint vlen, BasicType src_type, BasicTy is_java_primitive(src_type) && (vlen > 1) && is_power_of_2(vlen) && VectorNode::vector_size_supported(dst_type, vlen)) { - int vopc = VectorCastNode::opcode(src_type); + int vopc = VectorCastNode::opcode(opc, src_type); return vopc > 0 && Matcher::match_rule_supported_superword(vopc, vlen, dst_type); } return false; diff --git a/src/hotspot/share/opto/vectornode.hpp b/src/hotspot/share/opto/vectornode.hpp index aa672ca983ed1..21a8a8737a664 100644 --- a/src/hotspot/share/opto/vectornode.hpp +++ b/src/hotspot/share/opto/vectornode.hpp @@ -1542,7 +1542,7 @@ class VectorCastNode : public VectorNode { virtual int Opcode() const; static VectorCastNode* make(int vopc, Node* n1, BasicType bt, uint vlen); - static int opcode(BasicType bt, bool is_signed = true); + static int opcode(int opc, BasicType bt, bool is_signed = true); static bool implemented(int opc, uint vlen, BasicType src_type, BasicType dst_type); virtual Node* Identity(PhaseGVN* phase); @@ -1628,6 +1628,22 @@ class VectorUCastS2XNode : public VectorCastNode { virtual int Opcode() const; }; +class VectorCastHF2FNode : public VectorCastNode { + public: + VectorCastHF2FNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) { + assert(in->bottom_type()->is_vect()->element_basic_type() == T_SHORT, "must be short"); + } + virtual int Opcode() const; +}; + +class VectorCastF2HFNode : public VectorCastNode { + public: + VectorCastF2HFNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) { + assert(in->bottom_type()->is_vect()->element_basic_type() == T_FLOAT, "must be float"); + } + virtual int Opcode() const; +}; + class VectorUCastI2XNode : public VectorCastNode { public: VectorUCastI2XNode(Node* in, const TypeVect* vt) : VectorCastNode(in, vt) { diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index 6f30d9f83befd..7e08baab74cd9 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -1079,6 +1079,16 @@ public class IRNode { beforeMatchingNameRegex(VECTOR_CAST_S2X, "VectorCastS2X"); } + public static final String VECTOR_CAST_F2HF = PREFIX + "VECTOR_CAST_F2HF" + POSTFIX; + static { + beforeMatchingNameRegex(VECTOR_CAST_F2HF, "VectorCastF2HF"); + } + + public static final String VECTOR_CAST_HF2F = PREFIX + "VECTOR_CAST_HF2F" + POSTFIX; + static { + beforeMatchingNameRegex(VECTOR_CAST_HF2F, "VectorCastHF2F"); + } + public static final String VECTOR_MASK_CAST = PREFIX + "VECTOR_MASK_CAST" + POSTFIX; static { beforeMatchingNameRegex(VECTOR_MASK_CAST, "VectorMaskCast"); diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloatConversionsVector.java b/test/hotspot/jtreg/compiler/vectorization/TestFloatConversionsVector.java new file mode 100644 index 0000000000000..e84b734144ee3 --- /dev/null +++ b/test/hotspot/jtreg/compiler/vectorization/TestFloatConversionsVector.java @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** + * @test + * @bug 8294588 + * @summary Auto-vectorize Float.floatToFloat16, Float.float16ToFloat APIs + * @requires vm.compiler2.enabled + * @requires os.simpleArch == "x64" + * @library /test/lib / + * @run driver compiler.vectorization.TestFloatConversionsVector + */ + +package compiler.vectorization; + +import compiler.lib.ir_framework.*; + +public class TestFloatConversionsVector { + private static final int ARRLEN = 1024; + private static final int ITERS = 11000; + private static float [] finp; + private static short [] sout; + private static short [] sinp; + private static float [] fout; + + public static void main(String args[]) { + TestFramework.runWithFlags("-XX:-TieredCompilation", + "-XX:CompileThresholdScaling=0.3"); + System.out.println("PASSED"); + } + + @Test + @IR(counts = {IRNode.VECTOR_CAST_F2HF, "> 0"}, applyIfCPUFeatureOr = {"avx512f", "true", "f16c", "true"}) + public void test_float_float16(short[] sout, float[] finp) { + for (int i = 0; i < finp.length; i++) { + sout[i] = Float.floatToFloat16(finp[i]); + } + } + + @Run(test = {"test_float_float16"}, mode = RunMode.STANDALONE) + public void kernel_test_float_float16() { + finp = new float[ARRLEN]; + sout = new short[ARRLEN]; + + for (int i = 0; i < ARRLEN; i++) { + finp[i] = (float) i * 1.4f; + } + + for (int i = 0; i < ITERS; i++) { + test_float_float16(sout, finp); + } + } + + @Test + @IR(counts = {IRNode.VECTOR_CAST_HF2F, "> 0"}, applyIfCPUFeatureOr = {"avx512f", "true", "f16c", "true"}) + public void test_float16_float(float[] fout, short[] sinp) { + for (int i = 0; i < sinp.length; i++) { + fout[i] = Float.float16ToFloat(sinp[i]); + } + } + + @Run(test = {"test_float16_float"}, mode = RunMode.STANDALONE) + public void kernel_test_float16_float() { + sinp = new short[ARRLEN]; + fout = new float[ARRLEN]; + + for (int i = 0; i < ARRLEN; i++) { + sinp[i] = (short)i; + } + + for (int i = 0; i < ITERS; i++) { + test_float16_float(fout , sinp); + } + } +}