Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/hotspot/cpu/riscv/riscv.ad
Original file line number Diff line number Diff line change
Expand Up @@ -1933,8 +1933,8 @@ bool Matcher::match_rule_supported(int opcode) {
case Op_MaxHF:
case Op_MinHF:
case Op_MulHF:
case Op_SubHF:
case Op_SqrtHF:
case Op_SubHF:
return UseZfh;

case Op_CMoveF:
Expand Down
147 changes: 144 additions & 3 deletions src/hotspot/cpu/riscv/riscv_v.ad
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ source %{
return false;
}
break;
case Op_VectorCastHF2F:
case Op_VectorCastF2HF:
return UseZvfh;
case Op_VectorLoadShuffle:
case Op_VectorRearrange:
// vlen >= 4 is required, because min vector size for byte is 4 on riscv,
Expand All @@ -113,6 +110,18 @@ source %{
if (vlen < 4) {
return false;
}
case Op_VectorCastHF2F:
case Op_VectorCastF2HF:
case Op_AddVHF:
case Op_DivVHF:
case Op_MaxVHF:
case Op_MinVHF:
case Op_MulVHF:
case Op_SqrtVHF:
case Op_SubVHF:
return UseZvfh;
case Op_FmaVHF:
return UseZvfh && UseFMA;
default:
break;
}
Expand Down Expand Up @@ -363,6 +372,21 @@ instruct vadd(vReg dst, vReg src1, vReg src2) %{
ins_pipe(pipe_slow);
%}

instruct vadd_hfp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (AddVHF src1 src2));
ins_cost(VEC_COST);
format %{ "vadd_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfadd_vv(as_VectorRegister($dst$$reg),
as_VectorRegister($src1$$reg),
as_VectorRegister($src2$$reg));
%}
ins_pipe(pipe_slow);
%}

instruct vadd_fp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (AddVF src1 src2));
match(Set dst (AddVD src1 src2));
Expand Down Expand Up @@ -546,6 +570,20 @@ instruct vsub(vReg dst, vReg src1, vReg src2) %{
ins_pipe(pipe_slow);
%}

instruct vsub_hfp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (SubVHF src1 src2));
ins_cost(VEC_COST);
format %{ "vsub_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfsub_vv(as_VectorRegister($dst$$reg), as_VectorRegister($src1$$reg),
as_VectorRegister($src2$$reg));
%}
ins_pipe(pipe_slow);
%}

instruct vsub_fp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (SubVF src1 src2));
match(Set dst (SubVD src1 src2));
Expand Down Expand Up @@ -1542,6 +1580,21 @@ instruct vnotL_masked(vReg dst_src, immI_M1 m1, vRegMask_V0 v0) %{

// vector float div

instruct vdiv_hfp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (DivVHF src1 src2));
ins_cost(VEC_COST);
format %{ "vdiv_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfdiv_vv(as_VectorRegister($dst$$reg),
as_VectorRegister($src1$$reg),
as_VectorRegister($src2$$reg));
%}
ins_pipe(pipe_slow);
%}

instruct vdiv_fp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (DivVF src1 src2));
match(Set dst (DivVD src1 src2));
Expand Down Expand Up @@ -1698,6 +1751,38 @@ instruct vminu_masked(vReg dst_src1, vReg src2, vRegMask_V0 v0) %{
ins_pipe(pipe_slow);
%}

// vector float-point max/min (half precision)

instruct vmax_hfp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
match(Set dst (MaxVHF src1 src2));
effect(TEMP_DEF dst, TEMP v0);
ins_cost(VEC_COST);
format %{ "vmax_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ minmax_fp_v(as_VectorRegister($dst$$reg),
as_VectorRegister($src1$$reg), as_VectorRegister($src2$$reg),
T_SHORT, false /* is_min */, Matcher::vector_length(this));
%}
ins_pipe(pipe_slow);
%}

instruct vmin_hfp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
match(Set dst (MinVHF src1 src2));
effect(TEMP_DEF dst, TEMP v0);
ins_cost(VEC_COST);
format %{ "vmin_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ minmax_fp_v(as_VectorRegister($dst$$reg),
as_VectorRegister($src1$$reg), as_VectorRegister($src2$$reg),
T_SHORT, true /* is_min */, Matcher::vector_length(this));
%}
ins_pipe(pipe_slow);
%}

// vector float-point max/min

instruct vmax_fp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
Expand Down Expand Up @@ -1770,6 +1855,22 @@ instruct vmin_fp_masked(vReg dst_src1, vReg src2, vRegMask vmask, vReg tmp1, vRe

// vector fmla

// dst_src1 = src2 * src3 + dst_src1 (half precision)
instruct vhfmla(vReg dst_src1, vReg src2, vReg src3) %{
match(Set dst_src1 (FmaVHF dst_src1 (Binary src2 src3)));
ins_cost(VEC_COST);
format %{ "vhfmla $dst_src1, $dst_src1, $src2, $src3" %}
ins_encode %{
assert(UseFMA, "Needs FMA instructions support.");
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfmacc_vv(as_VectorRegister($dst_src1$$reg),
as_VectorRegister($src2$$reg), as_VectorRegister($src3$$reg));
%}
ins_pipe(pipe_slow);
%}

// dst_src1 = src2 * src3 + dst_src1
instruct vfmla(vReg dst_src1, vReg src2, vReg src3) %{
match(Set dst_src1 (FmaVF dst_src1 (Binary src2 src3)));
Expand Down Expand Up @@ -2038,6 +2139,20 @@ instruct vmul(vReg dst, vReg src1, vReg src2) %{
ins_pipe(pipe_slow);
%}

instruct vmul_hfp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (MulVHF src1 src2));
ins_cost(VEC_COST);
format %{ "vmul_hfp $dst, $src1, $src2" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfmul_vv(as_VectorRegister($dst$$reg), as_VectorRegister($src1$$reg),
as_VectorRegister($src2$$reg));
%}
ins_pipe(pipe_slow);
%}

instruct vmul_fp(vReg dst, vReg src1, vReg src2) %{
match(Set dst (MulVF src1 src2));
match(Set dst (MulVD src1 src2));
Expand Down Expand Up @@ -2971,6 +3086,19 @@ instruct replicateL_imm5(vReg dst, immL5 con) %{
ins_pipe(pipe_slow);
%}

instruct replicateHF(vReg dst, fRegF src) %{
predicate(Matcher::vector_element_basic_type(n) == T_SHORT);
match(Set dst (Replicate src));
ins_cost(VEC_COST);
format %{ "replicateHF $dst, $src" %}
ins_encode %{
assert(UseZvfh, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfmv_v_f(as_VectorRegister($dst$$reg), $src$$FloatRegister);
%}
ins_pipe(pipe_slow);
%}

instruct replicateF(vReg dst, fRegF src) %{
predicate(Matcher::vector_element_basic_type(n) == T_FLOAT);
match(Set dst (Replicate src));
Expand Down Expand Up @@ -4014,6 +4142,19 @@ instruct vrotate_left_vi_masked(vReg dst_src, immI shift, vRegMask_V0 v0) %{

// vector sqrt

instruct vsqrt_hfp(vReg dst, vReg src) %{
match(Set dst (SqrtVHF src));
ins_cost(VEC_COST);
format %{ "vsqrt_hfp $dst, $src" %}
ins_encode %{
assert(UseZvfh, "must");
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
__ vfsqrt_v(as_VectorRegister($dst$$reg), as_VectorRegister($src$$reg));
%}
ins_pipe(pipe_slow);
%}

instruct vsqrt_fp(vReg dst, vReg src) %{
match(Set dst (SqrtVF src));
match(Set dst (SqrtVD src));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public TestFloat16VectorOperations() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.ADD_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorAddFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(add(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -99,7 +99,7 @@ public void checkResultAdd() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.SUB_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorSubFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -120,7 +120,7 @@ public void checkResultSub() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.MUL_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorMulFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(multiply(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -141,7 +141,7 @@ public void checkResultMul() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.DIV_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorDivFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(divide(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -162,7 +162,7 @@ public void checkResultDiv() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.MIN_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorMinFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(min(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -183,7 +183,7 @@ public void checkResultMin() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.MAX_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorMaxFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(max(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
Expand All @@ -204,7 +204,7 @@ public void checkResultMax() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.SQRT_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorSqrtFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(sqrt(shortBitsToFloat16(input1[i])));
Expand All @@ -225,7 +225,7 @@ public void checkResultSqrt() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorFmaFloat16() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]),
Expand All @@ -248,7 +248,7 @@ public void checkResultFma() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorFmaFloat16ScalarMixedConstants() {
for (int i = 0; i < LEN; ++i) {
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(SCALAR_FP16),
Expand All @@ -272,7 +272,7 @@ public void checkResultFmaScalarMixedConstants() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorFmaFloat16MixedConstants() {
short input3 = floatToFloat16(3.0f);
for (int i = 0; i < LEN; ++i) {
Expand All @@ -295,7 +295,7 @@ public void checkResultFmaMixedConstants() {
@Test
@Warmup(10000)
@IR(counts = {IRNode.FMA_VHF, " 0 "},
applyIfCPUFeature = {"avx512_fp16", "true"})
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
public void vectorFmaFloat16AllConstants() {
short input1 = floatToFloat16(1.0f);
short input2 = floatToFloat16(2.0f);
Expand Down