Skip to content

Commit

Permalink
8308363: Initial compiler support for FP16 scalar operations.
Browse files Browse the repository at this point in the history
  • Loading branch information
jatin-bhateja committed May 18, 2023
1 parent 131a3ce commit b7b7231
Show file tree
Hide file tree
Showing 22 changed files with 1,410 additions and 9 deletions.
2 changes: 1 addition & 1 deletion make/common/JavaCompilation.gmk
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ define SetupJavaCompilationBody
PARANOIA_FLAGS := -implicit:none -Xprefer:source -XDignore.symbol.file=true -encoding ascii

$1_FLAGS += -g -Xlint:all $$($1_TARGET_RELEASE) $$(PARANOIA_FLAGS) $$(JAVA_WARNINGS_ARE_ERRORS)
$1_FLAGS += $$($1_JAVAC_FLAGS)
$1_FLAGS += $$($1_JAVAC_FLAGS) -XDenablePrimitiveClasses

ifneq ($$($1_DISABLED_WARNINGS), )
$1_FLAGS += -Xlint:$$(call CommaList, $$(addprefix -, $$($1_DISABLED_WARNINGS)))
Expand Down
10 changes: 9 additions & 1 deletion src/hotspot/cpu/x86/assembler_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7217,6 +7217,14 @@ void Assembler::vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector
emit_operand(dst, src, 0);
}

// Emits the register-register form of the AVX512-FP16 scalar half-precision
// add: dst = nds + src. Callers must only use this when the CPU reports
// AVX512-FP16 support (checked by the assert below in debug builds).
void Assembler::evaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src) {
  assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16");
  InstructionAttr attrs(AVX_128bit, /* rex_w */ false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false);
  // This instruction is only emitted in its EVEX form.
  attrs.set_is_evex_instruction();
  // F3 SIMD prefix, opcode map 5.
  const int enc = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(),
                                        VEX_SIMD_F3, VEX_OPCODE_MAP5, &attrs);
  // Opcode byte 0x58 followed by the register-direct operand byte.
  emit_int16(0x58, (0xC0 | enc));
}

void Assembler::psubb(XMMRegister dst, XMMRegister src) {
NOT_LP64(assert(VM_Version::supports_sse2(), ""));
InstructionAttr attributes(AVX_128bit, /* rex_w */ false, /* legacy_mode */ _legacy_mode_bw, /* no_mask_reg */ true, /* uses_vl */ true);
Expand Down Expand Up @@ -11364,7 +11372,7 @@ void Assembler::evex_prefix(bool vex_r, bool vex_b, bool vex_x, bool evex_r, boo
int byte2 = (vex_r ? VEX_R : 0) | (vex_x ? VEX_X : 0) | (vex_b ? VEX_B : 0) | (evex_r ? EVEX_Rb : 0);
byte2 = (~byte2) & 0xF0;
// confine opc opcode extensions in mm bits to lower two bits
// of form {0F, 0F_38, 0F_3A}
// of form {0F, 0F_38, 0F_3A, MAP5}
byte2 |= opc;

// P1: byte 3 as Wvvvv1pp
Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/cpu/x86/assembler_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ class Assembler : public AbstractAssembler {
VEX_OPCODE_0F = 0x1,
VEX_OPCODE_0F_38 = 0x2,
VEX_OPCODE_0F_3A = 0x3,
VEX_OPCODE_MAP5 = 0x5,
VEX_OPCODE_MASK = 0x1F
};

Expand Down Expand Up @@ -2380,6 +2381,7 @@ class Assembler : public AbstractAssembler {
void vpaddw(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
void vpaddd(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
void vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector_len);
void evaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src);

// Leaf level assembler routines for masked operations.
void evpaddb(XMMRegister dst, KRegister mask, XMMRegister nds, XMMRegister src, bool merge, int vector_len);
Expand Down
3 changes: 3 additions & 0 deletions src/hotspot/cpu/x86/vm_version_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3010,6 +3010,9 @@ uint64_t VM_Version::feature_flags() {
}
if (_cpuid_info.sef_cpuid7_edx.bits.serialize != 0)
result |= CPU_SERIALIZE;

if (_cpuid_info.sef_cpuid7_edx.bits.avx512_fp16 != 0)
result |= CPU_AVX512_FP16;
}

// ZX features.
Expand Down
8 changes: 6 additions & 2 deletions src/hotspot/cpu/x86/vm_version_x86.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ class VM_Version : public Abstract_VM_Version {
serialize : 1,
: 5,
cet_ibt : 1,
: 11;
: 2,
avx512_fp16 : 1,
: 8;
} bits;
};

Expand Down Expand Up @@ -390,7 +392,8 @@ class VM_Version : public Abstract_VM_Version {
decl(OSPKE, "ospke", 55) /* OS enables protection keys */ \
decl(CET_IBT, "cet_ibt", 56) /* Control Flow Enforcement - Indirect Branch Tracking */ \
decl(CET_SS, "cet_ss", 57) /* Control Flow Enforcement - Shadow Stack */ \
decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/
decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/ \
decl(AVX512_FP16, "avx512_fp16", 59) /* AVX512 FP16 ISA support*/

#define DECLARE_CPU_FEATURE_FLAG(id, name, bit) CPU_##id = (1ULL << bit),
CPU_FEATURE_FLAGS(DECLARE_CPU_FEATURE_FLAG)
Expand Down Expand Up @@ -696,6 +699,7 @@ class VM_Version : public Abstract_VM_Version {
static bool supports_avx512_bitalg() { return (_features & CPU_AVX512_BITALG) != 0; }
static bool supports_avx512_vbmi() { return (_features & CPU_AVX512_VBMI) != 0; }
static bool supports_avx512_vbmi2() { return (_features & CPU_AVX512_VBMI2) != 0; }
static bool supports_avx512_fp16() { return (_features & CPU_AVX512_FP16) != 0; }
static bool supports_hv() { return (_features & CPU_HV) != 0; }
static bool supports_serialize() { return (_features & CPU_SERIALIZE) != 0; }
static bool supports_f16c() { return (_features & CPU_F16C) != 0; }
Expand Down
37 changes: 37 additions & 0 deletions src/hotspot/cpu/x86/x86.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1455,6 +1455,11 @@ const bool Matcher::match_rule_supported(int opcode) {
return false;
}
break;
case Op_AddHF:
if (!VM_Version::supports_avx512_fp16()) {
return false;
}
break;
case Op_VectorLoadShuffle:
case Op_VectorRearrange:
case Op_MulReductionVI:
Expand Down Expand Up @@ -10224,4 +10229,36 @@ instruct DoubleClassCheck_reg_reg_vfpclass(rRegI dst, regD src, kReg ktmp, rFlag
ins_pipe(pipe_slow);
%}

// Matches ReinterpretS2HF: places the low 16 bits of the integer register $src
// into the XMM register $dst. $tmp is a scratch integer register so that $src
// itself is left unmodified by the masking step.
// Note: the format string below only prints the final movdl; the emitted code
// is the three-instruction sequence in ins_encode.
instruct reinterpretS2H (regF dst, rRegI src, rRegI tmp)
%{
match(Set dst (ReinterpretS2HF src));
effect(TEMP tmp);
format %{ "movdl $dst, $src\t! using $tmp as TEMP" %}
ins_encode %{
// Copy the source so the mask does not clobber it.
__ movl($tmp$$Register, $src$$Register);
// Keep only the low 16 bits (the half-float payload); upper bits become zero.
__ andl($tmp$$Register, 0xFFFF);
// Transfer the masked bits into the XMM destination.
__ movdl($dst$$XMMRegister, $tmp$$Register);
%}
ins_pipe(pipe_slow);
%}

// Matches ReinterpretHF2S: moves the bits held in the XMM register $src into
// the integer register $dst and then sign-extends the low 16 bits to 32 bits.
// Note: the format string below only prints the movdl; the emitted code also
// performs the movswl sign extension.
instruct reinterpretH2S (rRegI dst, regF src)
%{
match(Set dst (ReinterpretHF2S src));
format %{ "movdl $dst, $src" %}
ins_encode %{
// Low 32 bits of the XMM register -> integer register.
__ movdl($dst$$Register, $src$$XMMRegister);
// Sign-extend the 16-bit payload in place so $dst holds a short value.
__ movswl($dst$$Register, $dst$$Register);
%}
ins_pipe(pipe_slow);
%}

// Matches AddHF (scalar half-precision add) and emits it through
// Assembler::evaddsh, which asserts AVX512-FP16 support. All three operands
// live in float (XMM) registers.
instruct addFP16_scalar (regF dst, regF src1, regF src2)
%{
match(Set dst (AddHF src1 src2));
format %{ "vaddsh $dst, $src1, $src2" %}
ins_encode %{
// dst = src1 + src2, three-operand form.
__ evaddsh($dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister);
%}
ins_pipe(pipe_slow);
%}
11 changes: 10 additions & 1 deletion src/hotspot/share/classfile/classFileParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4781,6 +4781,15 @@ static void check_illegal_static_method(const InstanceKlass* this_klass, TRAPS)
}
}

// Returns true when class_name names one of the JDK's own primitive classes
// (currently only java.lang.Float16). Such classes are accepted without the
// user having to pass the EnablePrimitiveClasses JVM flag explicitly.
bool ClassFileParser::is_jdk_internal_class(const Symbol* class_name) const {
  return class_name == vmSymbols::java_lang_Float16();
}

// utility methods for format checking

void ClassFileParser::verify_legal_class_modifiers(jint flags, const char* name, bool is_Object, TRAPS) const {
Expand Down Expand Up @@ -4811,7 +4820,7 @@ void ClassFileParser::verify_legal_class_modifiers(jint flags, const char* name,
return;
}

if (is_primitive_class && !EnablePrimitiveClasses) {
if (is_primitive_class && !is_jdk_internal_class(_class_name) && !EnablePrimitiveClasses) {
ResourceMark rm(THREAD);
Exceptions::fthrow(
THREAD_AND_LOCATION,
Expand Down
2 changes: 2 additions & 0 deletions src/hotspot/share/classfile/classFileParser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class ClassFileParser {
bool _has_vanilla_constructor;
int _max_bootstrap_specifier_index; // detects BSS values

bool is_jdk_internal_class(const Symbol* class_name) const;

void parse_stream(const ClassFileStream* const stream, TRAPS);

void mangle_hidden_class_name(InstanceKlass* const ik);
Expand Down
1 change: 1 addition & 0 deletions src/hotspot/share/classfile/vmClassMacros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
do_klass(Boolean_klass, java_lang_Boolean ) \
do_klass(Character_klass, java_lang_Character ) \
do_klass(Float_klass, java_lang_Float ) \
do_klass(Float16_klass, java_lang_Float16 ) \
do_klass(Double_klass, java_lang_Double ) \
do_klass(Byte_klass, java_lang_Byte ) \
do_klass(Short_klass, java_lang_Short ) \
Expand Down
6 changes: 6 additions & 0 deletions src/hotspot/share/classfile/vmIntrinsics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class methodHandle;
do_intrinsic(_dsignum, java_lang_Math, signum_name, double_double_signature, F_S) \
do_intrinsic(_fsignum, java_lang_Math, signum_name, float_float_signature, F_S) \
\
\
/* Float16 intrinsics, similar to what we have in Math. */ \
do_intrinsic(_add_float16, java_lang_Float16, add_name, floa16_float16_signature, F_R) \
do_name(add_name, "add") \
do_signature(floa16_float16_signature, "(Qjava/lang/Float16;)Qjava/lang/Float16;") \
\
/* StrictMath intrinsics, similar to what we have in Math. */ \
do_intrinsic(_min_strict, java_lang_StrictMath, min_name, int2_int_signature, F_S) \
do_intrinsic(_max_strict, java_lang_StrictMath, max_name, int2_int_signature, F_S) \
Expand Down
1 change: 1 addition & 0 deletions src/hotspot/share/classfile/vmSymbols.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
template(java_lang_Character_CharacterCache, "java/lang/Character$CharacterCache") \
template(java_lang_CharacterDataLatin1, "java/lang/CharacterDataLatin1") \
template(java_lang_Float, "java/lang/Float") \
template(java_lang_Float16, "java/lang/Float16") \
template(java_lang_Double, "java/lang/Double") \
template(java_lang_Byte, "java/lang/Byte") \
template(java_lang_Byte_ByteCache, "java/lang/Byte$ByteCache") \
Expand Down
8 changes: 8 additions & 0 deletions src/hotspot/share/opto/addnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ class AddFNode : public AddNode {
virtual uint ideal_reg() const { return Op_RegF; }
};

//------------------------------AddHFNode---------------------------------------
// Add 2 half-precision (FP16) floats.
// Derives from AddFNode and only supplies its own constructor and Opcode();
// everything else is inherited from AddFNode.
class AddHFNode : public AddFNode {
public:
AddHFNode( Node *in1, Node *in2 ) : AddFNode(in1,in2) {}
virtual int Opcode() const;
};

//------------------------------AddDNode---------------------------------------
// Add 2 doubles
class AddDNode : public AddNode {
Expand Down
4 changes: 3 additions & 1 deletion src/hotspot/share/opto/c2compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ bool C2Compiler::is_intrinsic_supported(const methodHandle& method, bool is_virt
case vmIntrinsics::_Preconditions_checkLongIndex:
case vmIntrinsics::_getObjectSize:
break;

case vmIntrinsics::_add_float16:
if (!Matcher::match_rule_supported(Op_AddHF)) return false;
break;
case vmIntrinsics::_VectorCompressExpand:
case vmIntrinsics::_VectorUnaryOp:
case vmIntrinsics::_VectorBinaryOp:
Expand Down
3 changes: 3 additions & 0 deletions src/hotspot/share/opto/classes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ macro(AddF)
macro(AddI)
macro(AddL)
macro(AddP)
macro(AddHF)
macro(Allocate)
macro(AllocateArray)
macro(AndI)
Expand Down Expand Up @@ -485,6 +486,8 @@ macro(ExtractF)
macro(ExtractD)
macro(Digit)
macro(LowerCase)
macro(ReinterpretS2HF)
macro(ReinterpretHF2S)
macro(UpperCase)
macro(Whitespace)
macro(VectorBox)
Expand Down
24 changes: 24 additions & 0 deletions src/hotspot/share/opto/convertnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,3 +817,27 @@ const Type* RoundDoubleModeNode::Value(PhaseGVN* phase) const {
return Type::DOUBLE;
}
//=============================================================================

// Computes the type of this node from the type of its input (in(1)).
//  - TOP input      -> TOP, so a dead input makes this node dead as well.
//  - int constant   -> float constant holding the value of the FP16 bit pattern.
//  - anything else  -> Type::FLOAT.
const Type* ReinterpretS2HFNode::Value(PhaseGVN* phase) const {
  const Type* type = phase->type(in(1));
  if (type == Type::TOP) {
    return Type::TOP;
  }
  // Convert FP16 constant value to Float constant value, this will allow
  // further constant folding to be done at float granularity by value routines
  // of FP16 IR nodes.
  // NOTE(review): the folded constant is a 32-bit float value, not FP16 bits;
  // confirm that every consumer of this node either folds as well or can
  // accept a float-valued constant input.
  if (type->isa_int() && type->is_int()->is_con()) {
    // Only the low 16 bits carry the half-float encoding; narrow explicitly.
    const jshort hfval = static_cast<jshort>(type->is_int()->get_con());
    const jfloat fval = SharedRuntime::hf2f(hfval);
    return TypeF::make(fval);
  }
  return Type::FLOAT;
}

// Computes the type of this node from the type of its input (in(1)).
//  - TOP input       -> TOP, so a dead input makes this node dead as well.
//  - float constant  -> int constant holding the FP16 encoding of that float.
//  - anything else   -> TypeInt::SHORT.
const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const {
  const Type* type = phase->type(in(1));
  if (type == Type::TOP) {
    return Type::TOP;
  }
  // Convert Float constant value to FP16 constant value.
  if (type->isa_float_constant()) {
    const jfloat fval = type->is_float_constant()->_f;
    const jshort hfval = SharedRuntime::f2hf(fval);
    return TypeInt::make(hfval);
  }
  return TypeInt::SHORT;
}
18 changes: 18 additions & 0 deletions src/hotspot/share/opto/convertnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,24 @@ class ConvI2FNode : public Node {
virtual uint ideal_reg() const { return Op_RegF; }
};

// Takes a short value (input 1) and produces it as a half-precision float.
// The result is typed as Type::FLOAT and allocated to a float register
// (Op_RegF). Value() folds a constant input into a float constant.
class ReinterpretS2HFNode : public Node {
public:
ReinterpretS2HFNode( Node *in1 ) : Node(0,in1) {}
virtual int Opcode() const;
virtual const Type *bottom_type() const { return Type::FLOAT; }
virtual const Type* Value(PhaseGVN* phase) const;
virtual uint ideal_reg() const { return Op_RegF; }
};

// Takes a half-precision float value (input 1) and produces it as a short.
// The result is typed as TypeInt::SHORT and allocated to an integer register
// (Op_RegI). Value() folds a float-constant input into an int constant.
class ReinterpretHF2SNode : public Node {
public:
ReinterpretHF2SNode( Node *in1 ) : Node(0,in1) {}
virtual int Opcode() const;
virtual const Type* Value(PhaseGVN* phase) const;
virtual const Type *bottom_type() const { return TypeInt::SHORT; }
virtual uint ideal_reg() const { return Op_RegI; }
};

class RoundFNode : public Node {
public:
RoundFNode( Node *in1 ) : Node(0,in1) {}
Expand Down
6 changes: 3 additions & 3 deletions src/hotspot/share/opto/inlinetypenode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ class InlineTypeNode : public TypeNode {
// Nodes are connected in increasing order of the index of the field they correspond to.
};

// Get the klass defining the field layout of the inline type
ciInlineKlass* inline_klass() const { return type()->inline_klass(); }

void make_scalar_in_safepoint(PhaseIterGVN* igvn, Unique_Node_List& worklist, SafePointNode* sfpt);

const TypePtr* field_adr_type(Node* base, int offset, ciInstanceKlass* holder, DecoratorSet decorators, PhaseGVN& gvn) const;
Expand All @@ -77,6 +74,9 @@ class InlineTypeNode : public TypeNode {
static InlineTypeNode* make_from_flattened_impl(GraphKit* kit, ciInlineKlass* vk, Node* obj, Node* ptr, ciInstanceKlass* holder, int holder_offset, DecoratorSet decorators, GrowableArray<ciType*>& visited);

public:
// Get the klass defining the field layout of the inline type
ciInlineKlass* inline_klass() const { return type()->inline_klass(); }

// Create with default field values
static InlineTypeNode* make_default(PhaseGVN& gvn, ciInlineKlass* vk);
// Create uninitialized
Expand Down
25 changes: 25 additions & 0 deletions src/hotspot/share/opto/library_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,8 @@ bool LibraryCallKit::try_to_inline(int predicate) {
case vmIntrinsics::_floatToFloat16:
case vmIntrinsics::_float16ToFloat: return inline_fp_conversions(intrinsic_id());

case vmIntrinsics::_add_float16: return inline_fp16_operations(intrinsic_id());

case vmIntrinsics::_floatIsFinite:
case vmIntrinsics::_floatIsInfinite:
case vmIntrinsics::_doubleIsFinite:
Expand Down Expand Up @@ -4788,6 +4790,29 @@ bool LibraryCallKit::inline_native_Reflection_getCallerClass() {
return false; // bail-out; let JVM_GetCallerClass do the work
}

// Intrinsic expansion for binary Float16 operations (currently only add).
//
// Both the receiver and the argument are expected to be inline-type nodes
// whose field 0 holds the short-encoded FP16 value. Each field is
// reinterpreted as a half-float, the FP16 operation node is built on the two
// values, and the result is reinterpreted back to a short and stored into
// field 0 of a fresh inline-type node of the receiver's klass.
//
// Returns true if the intrinsic was expanded, false to fall back to a regular
// call.
bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id) {
  Node* val1 = argument(0); // receiver
  Node* val2 = argument(1); // argument
  assert(val1->is_InlineType() && val2->is_InlineType(), "");
  // The assert above is compiled out of product builds; bail out there instead
  // of calling as_InlineType() on a node that is not an inline type. Nothing
  // has been added to the graph yet, so returning false is safe.
  if (!val1->is_InlineType() || !val2->is_InlineType()) {
    return false;
  }
  InlineTypeNode* lhs = val1->as_InlineType();
  InlineTypeNode* rhs = val2->as_InlineType();

  // Field 0 carries the short-encoded payload of each operand.
  Node* fld1 = _gvn.transform(new ReinterpretS2HFNode(lhs->field_value(0)));
  Node* fld2 = _gvn.transform(new ReinterpretS2HFNode(rhs->field_value(0)));

  Node* result = NULL;
  switch (id) {
  case vmIntrinsics::_add_float16:
    result = _gvn.transform(new AddHFNode(fld1, fld2));
    break;

  default:
    fatal_unexpected_iid(id);
    break;
  }

  // Wrap the short-encoded result in a new inline-type node of the receiver's klass.
  InlineTypeNode* box = InlineTypeNode::make_uninitialized(_gvn, lhs->inline_klass(), true);
  Node* short_result = _gvn.transform(new ReinterpretHF2SNode(result));
  box->set_field_value(0, short_result);
  set_result(_gvn.transform(box));
  return true;
}

bool LibraryCallKit::inline_fp_conversions(vmIntrinsics::ID id) {
Node* arg = argument(0);
Node* result = NULL;
Expand Down
1 change: 1 addition & 0 deletions src/hotspot/share/opto/library_call.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ class LibraryCallKit : public GraphKit {
bool inline_unsafe_load_store(BasicType type, LoadStoreKind kind, AccessKind access_kind);
bool inline_unsafe_fence(vmIntrinsics::ID id);
bool inline_onspinwait();
bool inline_fp16_operations(vmIntrinsics::ID id);
bool inline_fp_conversions(vmIntrinsics::ID id);
bool inline_fp_range_check(vmIntrinsics::ID id);
bool inline_number_methods(vmIntrinsics::ID id);
Expand Down
Loading

0 comments on commit b7b7231

Please sign in to comment.