Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions src/hotspot/cpu/x86/x86.ad
Original file line number Diff line number Diff line change
Expand Up @@ -4514,10 +4514,10 @@ instruct vReplS_reg(vec dst, rRegI src) %{
%}

#ifdef _LP64
instruct ReplH_imm(vec dst, immH con, rRegI rtmp) %{
instruct ReplHF_imm(vec dst, immH con, rRegI rtmp) %{
match(Set dst (Replicate con));
effect(TEMP rtmp);
format %{ "replicateH $dst, $con \t! using $rtmp as TEMP" %}
format %{ "replicateHF $dst, $con \t! using $rtmp as TEMP" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
BasicType bt = Matcher::vector_element_basic_type(this);
Expand All @@ -4528,22 +4528,22 @@ instruct ReplH_imm(vec dst, immH con, rRegI rtmp) %{
ins_pipe( pipe_slow );
%}

instruct ReplH_short_reg(vec dst, rRegI src) %{
instruct ReplHF_short_reg(vec dst, rRegI src) %{
predicate(VM_Version::supports_avx512_fp16() && Matcher::vector_element_basic_type(n) == T_SHORT);
match(Set dst (Replicate (ReinterpretS2HF src)));
format %{ "replicateH $dst, $src" %}
format %{ "replicateHF $dst, $src" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ evpbroadcastw($dst$$XMMRegister, $src$$Register, vlen_enc);
%}
ins_pipe( pipe_slow );
%}

instruct ReplH_reg(vec dst, regF src, rRegI rtmp) %{
instruct ReplHF_reg(vec dst, regF src, rRegI rtmp) %{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jatin-bhateja , there are JTREG tests testing Replicate with immediate FP16 values (Rep1HF_imm backend rule) but I noticed that there are no JTREG tests for testing these rules - Rep1HF_reg and Rep1HF_short_reg. Should those be added as well? Something like -

float f = 10.0f;
for (int i = 0; i < SIZE: ++i) {
    res[i] = Float16.fma(a[i], b[i], Float16.valueOf(f));   // where f is a loop invariant float variable defined out of the loop 
}

predicate(VM_Version::supports_avx512_fp16() && Matcher::vector_element_basic_type(n) == T_SHORT);
match(Set dst (Replicate src));
effect(TEMP rtmp);
format %{ "replicateH $dst, $src \t! using $rtmp as TEMP" %}
format %{ "replicateHF $dst, $src \t! using $rtmp as TEMP" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ vmovw($rtmp$$Register, $src$$XMMRegister);
Expand Down Expand Up @@ -10897,7 +10897,7 @@ instruct vector_selectfrom_twovectors_reg_evex(vec index, vec src1, vec src2)
ins_pipe(pipe_slow);
%}

instruct reinterpretS2H(regF dst, rRegI src)
instruct reinterpretS2HF(regF dst, rRegI src)
%{
match(Set dst (ReinterpretS2HF src));
format %{ "vmovw $dst, $src" %}
Expand All @@ -10917,7 +10917,7 @@ instruct convF2HFAndS2HF(regF dst, regF src)
ins_pipe(pipe_slow);
%}

instruct reinterpretH2S(rRegI dst, regF src)
instruct reinterpretHF2S(rRegI dst, regF src)
%{
match(Set dst (ReinterpretHF2S src));
format %{ "vmovw $dst, $src" %}
Expand All @@ -10927,74 +10927,74 @@ instruct reinterpretH2S(rRegI dst, regF src)
ins_pipe(pipe_slow);
%}

instruct scalar_sqrt_fp16_reg(regF dst, regF src)
instruct scalar_sqrt_HF_reg(regF dst, regF src)
%{
match(Set dst (SqrtHF src));
format %{ "vsqrtsh $dst, $src" %}
format %{ "scalar_sqrt_fp16 $dst, $src" %}
ins_encode %{
__ vsqrtsh($dst$$XMMRegister, $src$$XMMRegister);
%}
ins_pipe(pipe_slow);
%}

instruct scalar_binOps_fp16_reg(regF dst, regF src1, regF src2)
instruct scalar_binOps_HF_reg(regF dst, regF src1, regF src2)
%{
match(Set dst (AddHF src1 src2));
match(Set dst (DivHF src1 src2));
match(Set dst (MaxHF src1 src2));
match(Set dst (MinHF src1 src2));
match(Set dst (MulHF src1 src2));
match(Set dst (SubHF src1 src2));
format %{ "efp16sh $dst, $src1, $src2" %}
format %{ "scalar_binop_fp16 $dst, $src1, $src2" %}
ins_encode %{
int opcode = this->ideal_Opcode();
__ efp16sh(opcode, $dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister);
%}
ins_pipe(pipe_slow);
%}

instruct scalar_fma_fp16_reg(regF dst, regF src1, regF src2)
instruct scalar_fma_HF_reg(regF dst, regF src1, regF src2)
%{
match(Set dst (FmaHF src2 (Binary dst src1)));
effect(DEF dst);
format %{ "evfmash $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
format %{ "scalar_fma_fp16 $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
ins_encode %{
__ vfmadd132sh($dst$$XMMRegister, $src2$$XMMRegister, $src1$$XMMRegister);
%}
ins_pipe( pipe_slow );
%}

instruct vector_sqrt_fp16_reg(vec dst, vec src)
instruct vector_sqrt_HF_reg(vec dst, vec src)
%{
match(Set dst (SqrtVHF src));
format %{ "evsqrtph_reg $dst, $src" %}
format %{ "vector_sqrt_fp16 $dst, $src" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ evsqrtph($dst$$XMMRegister, $src$$XMMRegister, vlen_enc);
%}
ins_pipe(pipe_slow);
%}

instruct vector_sqrt_fp16_mem(vec dst, memory src)
instruct vector_sqrt_HF_mem(vec dst, memory src)
%{
match(Set dst (SqrtVHF (VectorReinterpret (LoadVector src))));
format %{ "evsqrtph_mem $dst, $src" %}
format %{ "vector_sqrt_fp16_mem $dst, $src" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ evsqrtph($dst$$XMMRegister, $src$$Address, vlen_enc);
%}
ins_pipe(pipe_slow);
%}

instruct vector_binOps_fp16_reg(vec dst, vec src1, vec src2)
instruct vector_binOps_HF_reg(vec dst, vec src1, vec src2)
%{
match(Set dst (AddVHF src1 src2));
match(Set dst (DivVHF src1 src2));
match(Set dst (MaxVHF src1 src2));
match(Set dst (MinVHF src1 src2));
match(Set dst (MulVHF src1 src2));
match(Set dst (SubVHF src1 src2));
format %{ "evbinopfp16_reg $dst, $src1, $src2" %}
format %{ "vector_binop_fp16 $dst, $src1, $src2" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
int opcode = this->ideal_Opcode();
Expand All @@ -11003,15 +11003,15 @@ instruct vector_binOps_fp16_reg(vec dst, vec src1, vec src2)
ins_pipe(pipe_slow);
%}

instruct vector_binOps_fp16_mem(vec dst, vec src1, memory src2)
instruct vector_binOps_HF_mem(vec dst, vec src1, memory src2)
%{
match(Set dst (AddVHF src1 (VectorReinterpret (LoadVector src2))));
match(Set dst (DivVHF src1 (VectorReinterpret (LoadVector src2))));
match(Set dst (MaxVHF src1 (VectorReinterpret (LoadVector src2))));
match(Set dst (MinVHF src1 (VectorReinterpret (LoadVector src2))));
match(Set dst (MulVHF src1 (VectorReinterpret (LoadVector src2))));
match(Set dst (SubVHF src1 (VectorReinterpret (LoadVector src2))));
format %{ "evbinopfp16_mem $dst, $src1, $src2" %}
format %{ "vector_binop_fp16_mem $dst, $src1, $src2" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
int opcode = this->ideal_Opcode();
Expand All @@ -11021,23 +11021,23 @@ instruct vector_binOps_fp16_mem(vec dst, vec src1, memory src2)
%}


instruct vector_fma_fp16_reg(vec dst, vec src1, vec src2)
instruct vector_fma_HF_reg(vec dst, vec src1, vec src2)
%{
match(Set dst (FmaVHF src2 (Binary dst src1)));
effect(DEF dst);
format %{ "evfmaph_reg $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
format %{ "vector_fma_fp16 $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ evfmadd132ph($dst$$XMMRegister, $src2$$XMMRegister, $src1$$XMMRegister, vlen_enc);
%}
ins_pipe( pipe_slow );
%}

instruct vector_fma_fp16_mem(vec dst, memory src1, vec src2)
instruct vector_fma_HF_mem(vec dst, memory src1, vec src2)
%{
match(Set dst (FmaVHF src2 (Binary dst (VectorReinterpret (LoadVector src1)))));
effect(DEF dst);
format %{ "evfmaph_mem $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
format %{ "vector_fma_fp16_mem $dst, $src1, $src2\t# $dst = $dst * $src1 + $src2 fma packedH" %}
ins_encode %{
int vlen_enc = vector_length_encoding(this);
__ evfmadd132ph($dst$$XMMRegister, $src2$$XMMRegister, $src1$$Address, vlen_enc);
Expand Down
13 changes: 8 additions & 5 deletions src/hotspot/share/opto/addnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,15 +710,14 @@ Node *AddFNode::Ideal(PhaseGVN *phase, bool can_reshape) {
//=============================================================================
//------------------------------add_of_identity--------------------------------
// Check for addition of the identity
const Type *AddHFNode::add_of_identity(const Type *t1, const Type *t2) const {
const Type *AddHFNode::add_of_identity(const Type* t1, const Type* t2) const {
return nullptr;
}

//------------------------------add_ring---------------------------------------
// Supplied function returns the sum of the inputs.
// This also type-checks the inputs for sanity. Guaranteed never to
// be passed a TOP or BOTTOM type, these are filtered out by pre-check.
const Type *AddHFNode::add_ring(const Type *t0, const Type *t1) const {
const Type* AddHFNode::add_ring(const Type* t0, const Type* t1) const {
if (!t0->isa_half_float_constant() || !t1->isa_half_float_constant()) {
return bottom_type();
}
Expand Down Expand Up @@ -1621,7 +1620,9 @@ const Type* MinHFNode::add_ring(const Type* t0, const Type* t1) const {
return f0 < f1 ? r0 : r1;
}

// handle min of 0.0, -0.0 case.
// As per IEEE 754 specification, floating point comparison consider +ve and -ve
// zeros as equals. Thus, performing signed integral comparison for max value
// detection.
return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1;
}

Expand All @@ -1646,7 +1647,9 @@ const Type* MinFNode::add_ring(const Type* t0, const Type* t1 ) const {
return f0 < f1 ? r0 : r1;
}

// handle min of 0.0, -0.0 case.
// As per IEEE 754 specification, floating point comparison consider +ve and -ve
// zeros as equals. Thus, performing signed integral comparison for min value
// detection.
return (jint_cast(f0) < jint_cast(f1)) ? r0 : r1;
}

Expand Down
4 changes: 2 additions & 2 deletions src/hotspot/share/opto/addnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ class AddDNode : public AddNode {
// Add 2 half-precision floats
class AddHFNode : public AddNode {
public:
AddHFNode(Node *in1, Node *in2) : AddNode(in1,in2) {}
AddHFNode(Node* in1, Node* in2) : AddNode(in1,in2) {}
virtual int Opcode() const;
virtual const Type* add_of_identity(const Type *t1, const Type *t2) const;
virtual const Type* add_of_identity(const Type* t1, const Type* t2) const;
virtual const Type* add_ring(const Type*, const Type*) const;
virtual const Type* add_id() const { return TypeH::ZERO; }
virtual const Type* bottom_type() const { return Type::HALF_FLOAT; }
Expand Down
4 changes: 3 additions & 1 deletion src/hotspot/share/opto/connode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ uint ConNode::hash() const {

//------------------------------make-------------------------------------------
ConNode *ConNode::make(const Type *t) {
if (t->isa_half_float_constant()) {
return new ConHNode( t->is_half_float_constant() );
}
switch( t->basic_type() ) {
case T_INT: return new ConINode( t->is_int() );
case T_SHORT: return new ConHNode( t->is_half_float_constant() );
case T_LONG: return new ConLNode( t->is_long() );
case T_FLOAT: return new ConFNode( t->is_float_constant() );
case T_DOUBLE: return new ConDNode( t->is_double_constant() );
Expand Down
5 changes: 2 additions & 3 deletions src/hotspot/share/opto/connode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,13 @@ class ConLNode : public ConNode {
// Simple half float constants
class ConHNode : public ConNode {
public:
ConHNode( const TypeH *t ) : ConNode(t) {}
ConHNode(const TypeH* t) : ConNode(t) {}
virtual int Opcode() const;

// Factory method:
static ConHNode* make(float con) {
return new ConHNode( TypeH::make(con) );
return new ConHNode(TypeH::make(con));
}

};

//------------------------------ConFNode---------------------------------------
Expand Down
20 changes: 17 additions & 3 deletions src/hotspot/share/opto/convertnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,22 @@ const Type* ConvF2HFNode::Value(PhaseGVN* phase) const {

//------------------------------Ideal------------------------------------------
Node* ConvF2HFNode::Ideal(PhaseGVN* phase, bool can_reshape) {
// Optimize pattern - ConvHF2F (FP32BinOp) ConvF2HF ==> ReinterpretS2HF (FP16BinOp) ReinterpretHF2S.
if (Float16NodeFactory::is_binary_oper(in(1)->Opcode()) &&
// Float16 instance encapsulates a short field holding IEEE 754
// binary16 value. On unboxing, this short field is loaded into a
// GPR register while FP operation operates over floating point
// registers. ConvHF2F converts incoming short value to a FP32 value
// to perform operation at FP32 granularity. However, if target
// support FP16 ISA we can save this redundant up casting and
// optimize the graph pallet using following transformation.
//
// ConvF2HF(FP32BinOp(ConvHF2F(x), ConvHF2F(y))) =>
// ReinterpretHF2S(FP16BinOp(ReinterpretS2HF(x), ReinterpretS2HF(y)))
//
// Please note we need to inject appropriate reinterpretation
// IR to move the values b/w GPR and floating point register
// before and after FP16 operation.

if (Float16NodeFactory::is_float32_binary_oper(in(1)->Opcode()) &&
in(1)->in(1)->Opcode() == Op_ConvHF2F &&
in(1)->in(2)->Opcode() == Op_ConvHF2F) {
if (Matcher::match_rule_supported(Float16NodeFactory::get_float16_binary_oper(in(1)->Opcode())) &&
Expand Down Expand Up @@ -945,7 +959,7 @@ const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const {
return TypeInt::SHORT;
}

bool Float16NodeFactory::is_binary_oper(int opc) {
bool Float16NodeFactory::is_float32_binary_oper(int opc) {
switch(opc) {
case Op_AddF:
case Op_SubF:
Expand Down
4 changes: 2 additions & 2 deletions src/hotspot/share/opto/convertnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class ReinterpretS2HFNode : public Node {
// Reinterpret Half Float to Short
class ReinterpretHF2SNode : public Node {
public:
ReinterpretHF2SNode( Node *in1 ) : Node(0,in1) {}
ReinterpretHF2SNode(Node* in1) : Node(0,in1) {}
virtual int Opcode() const;
virtual const Type* Value(PhaseGVN* phase) const;
virtual const Type* bottom_type() const { return TypeInt::SHORT; }
Expand Down Expand Up @@ -296,7 +296,7 @@ class RoundDoubleModeNode: public Node {

class Float16NodeFactory {
public:
static bool is_binary_oper(int opc);
static bool is_float32_binary_oper(int opc);
static int get_float16_binary_oper(int opc);
static Node* make(int opc, Node* c, Node* in1, Node* in2);
};
Expand Down
Loading