Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/hotspot/share/classfile/vmIntrinsics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,8 +942,16 @@ class methodHandle;
\
/* Float16Math API intrinsification support */ \
/* Float16 signatures */ \
do_signature(float16_unary_math_op_sig, "(SLjdk/internal/vm/vector/Float16Math$Float16UnaryMathOp;)S") \
do_signature(float16_ternary_math_op_sig, "(SSSLjdk/internal/vm/vector/Float16Math$Float16TernaryMathOp;)S") \
do_signature(float16_unary_math_op_sig, "(Ljava/lang/Class;" \
"Ljava/lang/Object;" \
"Ljdk/internal/vm/vector/Float16Math$Float16UnaryMathOp;)" \
"Ljava/lang/Object;") \
do_signature(float16_ternary_math_op_sig, "(Ljava/lang/Class;" \
"Ljava/lang/Object;" \
"Ljava/lang/Object;" \
"Ljava/lang/Object;" \
"Ljdk/internal/vm/vector/Float16Math$Float16TernaryMathOp;)" \
"Ljava/lang/Object;") \
do_intrinsic(_sqrt_float16, jdk_internal_vm_vector_Float16Math, sqrt_name, float16_unary_math_op_sig, F_S) \
do_intrinsic(_fma_float16, jdk_internal_vm_vector_Float16Math, fma_name, float16_ternary_math_op_sig, F_S) \
\
Expand Down
1 change: 1 addition & 0 deletions src/hotspot/share/opto/escape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4605,6 +4605,7 @@ void ConnectionGraph::split_unique_types(GrowableArray<Node *> &alloc_worklist,
op == Op_StrEquals || op == Op_VectorizedHashCode ||
op == Op_StrIndexOf || op == Op_StrIndexOfChar ||
op == Op_SubTypeCheck ||
op == Op_ReinterpretS2HF ||
BarrierSet::barrier_set()->barrier_set_c2()->is_gc_barrier_node(use))) {
n->dump();
use->dump();
Expand Down
88 changes: 72 additions & 16 deletions src/hotspot/share/opto/library_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "precompiled.hpp"
#include "asm/macroAssembler.hpp"
#include "ci/ciUtilities.inline.hpp"
#include "ci/ciSymbols.hpp"
#include "classfile/vmIntrinsics.hpp"
#include "compiler/compileBroker.hpp"
#include "compiler/compileLog.hpp"
Expand Down Expand Up @@ -8617,37 +8618,91 @@ bool LibraryCallKit::inline_blackhole() {
return true;
}

Node* LibraryCallKit::unbox_fp16_value(const TypeInstPtr* float16_box_type, ciField* field, Node* box) {
const TypeInstPtr* box_type = _gvn.type(box)->isa_instptr();
if (box_type == nullptr || box_type->instance_klass() != float16_box_type->instance_klass()) {
return nullptr; // box klass is not Float16
}

// Null check; get notnull casted pointer
Node* null_ctl = top();
Node* not_null_box = null_check_oop(box, &null_ctl, true);
// If not_null_box is dead, only null-path is taken
if (stopped()) {
set_control(null_ctl);
return nullptr;
}
assert(not_null_box->bottom_type()->is_instptr()->maybe_null() == false, "");
const TypePtr* adr_type = C->alias_type(field)->adr_type();
Node* adr = basic_plus_adr(not_null_box, field->offset_in_bytes());
return access_load_at(not_null_box, adr, adr_type, TypeInt::SHORT, T_SHORT, IN_HEAP);
}

Node* LibraryCallKit::box_fp16_value(const TypeInstPtr* float16_box_type, ciField* field, Node* value) {
PreserveReexecuteState preexecs(this);
jvms()->set_should_reexecute(true);

const TypeKlassPtr* klass_type = float16_box_type->as_klass_type();
Node* klass_node = makecon(klass_type);
Node* box = new_instance(klass_node);

Node* value_field = basic_plus_adr(box, field->offset_in_bytes());
const TypePtr* value_adr_type = value_field->bottom_type()->is_ptr();

Node* field_store = _gvn.transform(access_store_at(box,
value_field,
value_adr_type,
value,
TypeInt::SHORT,
T_SHORT,
IN_HEAP));
set_memory(field_store, value_adr_type);
return box;
}

bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id, int num_args) {
if (!Matcher::match_rule_supported(Op_ReinterpretS2HF) ||
!Matcher::match_rule_supported(Op_ReinterpretHF2S)) {
return false;
}

const TypeInstPtr* box_type = _gvn.type(argument(0))->isa_instptr();
if (box_type == nullptr || box_type->const_oop() == nullptr) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, this is not a review comment.
Just curious, to continue the following code path why does box_type must have a valid const_oop?

Copy link
Member Author

@jatin-bhateja jatin-bhateja Jan 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Hamlin-Li , Class types are passed as constant oop, this check is added for argument validation.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Seems it could be an assert instead? Or maybe I could have misunderstood your above explanation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Hamlin-Li, We intend to disable intrinsification if constraints are not met.

return false;
}

ciInstanceKlass* float16_klass = box_type->const_oop()->as_instance()->java_lang_Class_klass()->as_instance_klass();
const TypeInstPtr* float16_box_type = TypeInstPtr::make_exact(TypePtr::NotNull, float16_klass);
ciField* field = float16_klass->get_field_by_name(ciSymbols::value_name(),
ciSymbols::short_signature(),
false);
assert(field != nullptr, "");

// Transformed nodes
Node* fld1 = nullptr;
Node* fld2 = nullptr;
Node* fld3 = nullptr;
switch(num_args) {
case 3:
assert((argument(2)->is_ConI() &&
argument(2)->get_int() >= min_jshort &&
argument(2)->get_int() <= max_jshort) ||
(argument(2)->bottom_type()->array_element_basic_type() == T_SHORT), "");
fld3 = _gvn.transform(new ReinterpretS2HFNode(argument(2)));
fld3 = unbox_fp16_value(float16_box_type, field, argument(3));
if (fld3 == nullptr) {
return false;
}
fld3 = _gvn.transform(new ReinterpretS2HFNode(fld3));
// fall-through
case 2:
assert((argument(1)->is_ConI() &&
argument(1)->get_int() >= min_jshort &&
argument(1)->get_int() <= max_jshort) ||
(argument(1)->bottom_type()->array_element_basic_type() == T_SHORT), "");
fld2 = _gvn.transform(new ReinterpretS2HFNode(argument(1)));
fld2 = unbox_fp16_value(float16_box_type, field, argument(2));
if (fld2 == nullptr) {
return false;
}
fld2 = _gvn.transform(new ReinterpretS2HFNode(fld2));
// fall-through
case 1:
assert((argument(0)->is_ConI() &&
argument(0)->get_int() >= min_jshort &&
argument(0)->get_int() <= max_jshort) ||
(argument(0)->bottom_type()->array_element_basic_type() == T_SHORT), "");
fld1 = _gvn.transform(new ReinterpretS2HFNode(argument(0)));
fld1 = unbox_fp16_value(float16_box_type, field, argument(1));
if (fld1 == nullptr) {
return false;
}
fld1 = _gvn.transform(new ReinterpretS2HFNode(fld1));
break;
default: fatal("Unsupported number of arguments %d", num_args);
}
Expand All @@ -8666,7 +8721,8 @@ bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id, int num_args) {
fatal_unexpected_iid(id);
break;
}
set_result(_gvn.transform(new ReinterpretHF2SNode(result)));
result = _gvn.transform(new ReinterpretHF2SNode(result));
set_result(box_fp16_value(float16_box_type, field, result));
return true;
}

2 changes: 2 additions & 0 deletions src/hotspot/share/opto/library_call.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class LibraryCallKit : public GraphKit {
bool inline_fp_conversions(vmIntrinsics::ID id);
bool inline_fp_range_check(vmIntrinsics::ID id);
bool inline_fp16_operations(vmIntrinsics::ID id, int num_args);
Node* unbox_fp16_value(const TypeInstPtr* box_class, ciField* field, Node* box);
Node* box_fp16_value(const TypeInstPtr* box_class, ciField* field, Node* value);
bool inline_number_methods(vmIntrinsics::ID id);
bool inline_bitshuffle_methods(vmIntrinsics::ID id);
bool inline_compare_unsigned(vmIntrinsics::ID id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,34 @@
package jdk.internal.vm.vector;

import jdk.internal.vm.annotation.IntrinsicCandidate;
import jdk.internal.vm.annotation.ForceInline;

import static java.lang.Float.*;

/**
* The class {@code Float16Math} constains intrinsic entry points corresponding
* to scalar numeric operations defined in Float16 class.
* @since 24
* @since 25
*/
public final class Float16Math {
private Float16Math() {
}

public interface Float16UnaryMathOp {
short apply(short a);
Object apply(Object a);
}

public interface Float16TernaryMathOp {
short apply(short a, short b, short c);
Object apply(Object a, Object b, Object c);
}

@IntrinsicCandidate
public static short sqrt(short a, Float16UnaryMathOp defaultImpl) {
public static Object sqrt(Class<?> box_class, Object oa, Float16UnaryMathOp defaultImpl) {
assert isNonCapturingLambda(defaultImpl) : defaultImpl;
return defaultImpl.apply(a);
return defaultImpl.apply(oa);
}

@IntrinsicCandidate
public static short fma(short a, short b, short c, Float16TernaryMathOp defaultImpl) {
public static Object fma(Class<?> box_class, Object oa, Object ob, Object oc, Float16TernaryMathOp defaultImpl) {
assert isNonCapturingLambda(defaultImpl) : defaultImpl;
return defaultImpl.apply(a, b, c);
return defaultImpl.apply(oa, ob, oc);
}

public static boolean isNonCapturingLambda(Object o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1199,22 +1199,16 @@ public static Float16 divide(Float16 dividend, Float16 divisor) {
* @see Math#sqrt(double)
*/
public static Float16 sqrt(Float16 radicand) {
// Explicitly unbox float16 radicand as intrinsic expects
// to receive short type arguments holding IEEE 754 binary16
// value.
short unboxed_radicand = float16ToRawShortBits(radicand);
short retval = Float16Math.sqrt(unboxed_radicand,
(f16) -> {
return (Float16) Float16Math.sqrt(Float16.class, radicand,
(_radicand) -> {
// Rounding path of sqrt(Float16 -> double) -> Float16 is fine
// for preserving the correct final value. The conversion
// Float16 -> double preserves the exact numerical value. The
// conversion of double -> Float16 also benefits from the
// 2p+2 property of IEEE 754 arithmetic.
double res = Math.sqrt(float16ToFloat(f16));
return float16ToRawShortBits(valueOf(res));
return valueOf(Math.sqrt(((Float16)(_radicand)).doubleValue()));
}
);
return shortBitsToFloat16(retval);
}

/**
Expand Down Expand Up @@ -1416,22 +1410,14 @@ public static Float16 fma(Float16 a, Float16 b, Float16 c) {
* harmless.
*/

// Explicitly unbox float16 values as intrinsic expects
// to receive short type arguments holding IEEE 754 binary16
// values.
short unboxed_a = float16ToRawShortBits(a);
short unboxed_b = float16ToRawShortBits(b);
short unboxed_c = float16ToRawShortBits(c);

short res = Float16Math.fma(unboxed_a, unboxed_b, unboxed_c,
(f16a, f16b, f16c) -> {
return (Float16) Float16Math.fma(Float16.class, a, b, c,
(_a, _b, _c) -> {
// product is numerically exact in float before the cast to
// double; not necessary to widen to double before the
// multiply.
double product = (double)(float16ToFloat(f16a) * float16ToFloat(f16b));
return float16ToRawShortBits(valueOf(product + float16ToFloat(f16c)));
double product = (double)(((Float16)_a).floatValue() * ((Float16)_b).floatValue());
return valueOf(product + ((Float16)_c).doubleValue());
});
return shortBitsToFloat16(res);
}

/**
Expand Down