Skip to content

Commit 92fd449

Browse files
author
Hamlin Li
committed
8350960: RISC-V: Add riscv backend for Float16 operations - vectorization
Reviewed-by: fyang, dzhang, luhenry
1 parent 2e26b43 commit 92fd449

File tree

3 files changed

+156
-15
lines changed

3 files changed

+156
-15
lines changed

src/hotspot/cpu/riscv/riscv.ad

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1933,8 +1933,8 @@ bool Matcher::match_rule_supported(int opcode) {
19331933
case Op_MaxHF:
19341934
case Op_MinHF:
19351935
case Op_MulHF:
1936-
case Op_SubHF:
19371936
case Op_SqrtHF:
1937+
case Op_SubHF:
19381938
return UseZfh;
19391939

19401940
case Op_CMoveF:

src/hotspot/cpu/riscv/riscv_v.ad

Lines changed: 144 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,6 @@ source %{
9595
return false;
9696
}
9797
break;
98-
case Op_VectorCastHF2F:
99-
case Op_VectorCastF2HF:
100-
return UseZvfh;
10198
case Op_VectorLoadShuffle:
10299
case Op_VectorRearrange:
103100
// vlen >= 4 is required, because min vector size for byte is 4 on riscv,
@@ -113,6 +110,18 @@ source %{
113110
if (vlen < 4) {
114111
return false;
115112
}
113+
case Op_VectorCastHF2F:
114+
case Op_VectorCastF2HF:
115+
case Op_AddVHF:
116+
case Op_DivVHF:
117+
case Op_MaxVHF:
118+
case Op_MinVHF:
119+
case Op_MulVHF:
120+
case Op_SqrtVHF:
121+
case Op_SubVHF:
122+
return UseZvfh;
123+
case Op_FmaVHF:
124+
return UseZvfh && UseFMA;
116125
default:
117126
break;
118127
}
@@ -363,6 +372,21 @@ instruct vadd(vReg dst, vReg src1, vReg src2) %{
363372
ins_pipe(pipe_slow);
364373
%}
365374

375+
// Vector add for AddVHF nodes: dst = src1 + src2, element-wise.
// Operands are three vector registers (vReg). The encoding asserts that
// UseZvfh is enabled and that the node's vector element basic type is
// T_SHORT, then emits a single vfadd_vv.
// NOTE(review): "hfp" presumably stands for half-precision floating point,
// with each half value occupying one 16-bit (T_SHORT) lane — confirm.
instruct vadd_hfp(vReg dst, vReg src1, vReg src2) %{
376+
match(Set dst (AddVHF src1 src2));
377+
ins_cost(VEC_COST);
378+
format %{ "vadd_hfp $dst, $src1, $src2" %}
379+
ins_encode %{
380+
assert(UseZvfh, "must");
381+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
382+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the arithmetic instruction is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
383+
__ vfadd_vv(as_VectorRegister($dst$$reg),
384+
as_VectorRegister($src1$$reg),
385+
as_VectorRegister($src2$$reg));
386+
%}
387+
ins_pipe(pipe_slow);
388+
%}
389+
366390
instruct vadd_fp(vReg dst, vReg src1, vReg src2) %{
367391
match(Set dst (AddVF src1 src2));
368392
match(Set dst (AddVD src1 src2));
@@ -546,6 +570,20 @@ instruct vsub(vReg dst, vReg src1, vReg src2) %{
546570
ins_pipe(pipe_slow);
547571
%}
548572

573+
// Vector subtract for SubVHF nodes: dst = src1 - src2, element-wise.
// Operands are three vector registers (vReg). The encoding asserts that
// UseZvfh is enabled and that the node's vector element basic type is
// T_SHORT, then emits a single vfsub_vv.
instruct vsub_hfp(vReg dst, vReg src1, vReg src2) %{
574+
match(Set dst (SubVHF src1 src2));
575+
ins_cost(VEC_COST);
576+
format %{ "vsub_hfp $dst, $src1, $src2" %}
577+
ins_encode %{
578+
assert(UseZvfh, "must");
579+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
580+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the arithmetic instruction is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
581+
__ vfsub_vv(as_VectorRegister($dst$$reg), as_VectorRegister($src1$$reg),
582+
as_VectorRegister($src2$$reg));
583+
%}
584+
ins_pipe(pipe_slow);
585+
%}
586+
549587
instruct vsub_fp(vReg dst, vReg src1, vReg src2) %{
550588
match(Set dst (SubVF src1 src2));
551589
match(Set dst (SubVD src1 src2));
@@ -1542,6 +1580,21 @@ instruct vnotL_masked(vReg dst_src, immI_M1 m1, vRegMask_V0 v0) %{
15421580

15431581
// vector float div
15441582

1583+
// Vector divide for DivVHF nodes: dst = src1 / src2, element-wise.
// Operands are three vector registers (vReg). The encoding asserts that
// UseZvfh is enabled and that the node's vector element basic type is
// T_SHORT, then emits a single vfdiv_vv.
instruct vdiv_hfp(vReg dst, vReg src1, vReg src2) %{
1584+
match(Set dst (DivVHF src1 src2));
1585+
ins_cost(VEC_COST);
1586+
format %{ "vdiv_hfp $dst, $src1, $src2" %}
1587+
ins_encode %{
1588+
assert(UseZvfh, "must");
1589+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
1590+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the arithmetic instruction is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
1591+
__ vfdiv_vv(as_VectorRegister($dst$$reg),
1592+
as_VectorRegister($src1$$reg),
1593+
as_VectorRegister($src2$$reg));
1594+
%}
1595+
ins_pipe(pipe_slow);
1596+
%}
1597+
15451598
instruct vdiv_fp(vReg dst, vReg src1, vReg src2) %{
15461599
match(Set dst (DivVF src1 src2));
15471600
match(Set dst (DivVD src1 src2));
@@ -1698,6 +1751,38 @@ instruct vminu_masked(vReg dst_src1, vReg src2, vRegMask_V0 v0) %{
16981751
ins_pipe(pipe_slow);
16991752
%}
17001753

1754+
// vector floating-point max/min (half precision)
1755+
1756+
// Vector max for MaxVHF nodes over src1 and src2, result in dst.
// Unlike the plain arithmetic half-precision rules, this one does not call
// vsetvli_helper itself; it delegates to minmax_fp_v, passing T_SHORT,
// is_min = false and the node's vector length.
// dst is declared TEMP_DEF and v0 is declared TEMP.
// NOTE(review): presumably minmax_fp_v uses v0 as a scratch mask and writes
// dst before it has finished reading the sources — confirm against the
// macro assembler implementation.
instruct vmax_hfp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
1757+
match(Set dst (MaxVHF src1 src2));
1758+
effect(TEMP_DEF dst, TEMP v0);
1759+
ins_cost(VEC_COST);
1760+
format %{ "vmax_hfp $dst, $src1, $src2" %}
1761+
ins_encode %{
1762+
assert(UseZvfh, "must");
1763+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
1764+
__ minmax_fp_v(as_VectorRegister($dst$$reg),
1765+
as_VectorRegister($src1$$reg), as_VectorRegister($src2$$reg),
1766+
T_SHORT, false /* is_min */, Matcher::vector_length(this));
1767+
%}
1768+
ins_pipe(pipe_slow);
1769+
%}
1770+
1771+
// Vector min for MinVHF nodes over src1 and src2, result in dst.
// Same shape as the max rule for MaxVHF, but passes is_min = true to
// minmax_fp_v (together with T_SHORT and the node's vector length).
// dst is declared TEMP_DEF and v0 is declared TEMP.
// NOTE(review): presumably minmax_fp_v uses v0 as a scratch mask — confirm
// against the macro assembler implementation.
instruct vmin_hfp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
1772+
match(Set dst (MinVHF src1 src2));
1773+
effect(TEMP_DEF dst, TEMP v0);
1774+
ins_cost(VEC_COST);
1775+
format %{ "vmin_hfp $dst, $src1, $src2" %}
1776+
ins_encode %{
1777+
assert(UseZvfh, "must");
1778+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
1779+
__ minmax_fp_v(as_VectorRegister($dst$$reg),
1780+
as_VectorRegister($src1$$reg), as_VectorRegister($src2$$reg),
1781+
T_SHORT, true /* is_min */, Matcher::vector_length(this));
1782+
%}
1783+
ins_pipe(pipe_slow);
1784+
%}
1785+
17011786
// vector floating-point max/min
17021787

17031788
instruct vmax_fp(vReg dst, vReg src1, vReg src2, vRegMask_V0 v0) %{
@@ -1770,6 +1855,22 @@ instruct vmin_fp_masked(vReg dst_src1, vReg src2, vRegMask vmask, vReg tmp1, vRe
17701855

17711856
// vector fmla
17721857

1858+
// dst_src1 = src2 * src3 + dst_src1 (half precision)
1859+
// Vector fused multiply-add for FmaVHF nodes.
// dst_src1 is both the addend input and the result register; src2 and src3
// are the two multiplicands (matched as a Binary pair).
// The encoding asserts UseFMA and UseZvfh are enabled and that the node's
// vector element basic type is T_SHORT, then emits a single vfmacc_vv.
instruct vhfmla(vReg dst_src1, vReg src2, vReg src3) %{
1860+
match(Set dst_src1 (FmaVHF dst_src1 (Binary src2 src3)));
1861+
ins_cost(VEC_COST);
1862+
format %{ "vhfmla $dst_src1, $dst_src1, $src2, $src3" %}
1863+
ins_encode %{
1864+
assert(UseFMA, "Needs FMA instructions support.");
1865+
assert(UseZvfh, "must");
1866+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
1867+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the multiply-accumulate is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
1868+
__ vfmacc_vv(as_VectorRegister($dst_src1$$reg),
1869+
as_VectorRegister($src2$$reg), as_VectorRegister($src3$$reg));
1870+
%}
1871+
ins_pipe(pipe_slow);
1872+
%}
1873+
17731874
// dst_src1 = src2 * src3 + dst_src1
17741875
instruct vfmla(vReg dst_src1, vReg src2, vReg src3) %{
17751876
match(Set dst_src1 (FmaVF dst_src1 (Binary src2 src3)));
@@ -2038,6 +2139,20 @@ instruct vmul(vReg dst, vReg src1, vReg src2) %{
20382139
ins_pipe(pipe_slow);
20392140
%}
20402141

2142+
// Vector multiply for MulVHF nodes: dst = src1 * src2, element-wise.
// Operands are three vector registers (vReg). The encoding asserts that
// UseZvfh is enabled and that the node's vector element basic type is
// T_SHORT, then emits a single vfmul_vv.
instruct vmul_hfp(vReg dst, vReg src1, vReg src2) %{
2143+
match(Set dst (MulVHF src1 src2));
2144+
ins_cost(VEC_COST);
2145+
format %{ "vmul_hfp $dst, $src1, $src2" %}
2146+
ins_encode %{
2147+
assert(UseZvfh, "must");
2148+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
2149+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the arithmetic instruction is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
2150+
__ vfmul_vv(as_VectorRegister($dst$$reg), as_VectorRegister($src1$$reg),
2151+
as_VectorRegister($src2$$reg));
2152+
%}
2153+
ins_pipe(pipe_slow);
2154+
%}
2155+
20412156
instruct vmul_fp(vReg dst, vReg src1, vReg src2) %{
20422157
match(Set dst (MulVF src1 src2));
20432158
match(Set dst (MulVD src1 src2));
@@ -2971,6 +3086,19 @@ instruct replicateL_imm5(vReg dst, immL5 con) %{
29713086
ins_pipe(pipe_slow);
29723087
%}
29733088

3089+
// Replicate (broadcast) rule for vectors whose element basic type is T_SHORT
// and whose scalar source lives in a float register (fRegF).
// The element-type restriction is enforced by the predicate rather than by
// an assert in the encoding; the encoding only asserts UseZvfh.
// Emits vfmv_v_f from the source float register into dst.
// NOTE(review): presumably the fRegF source holds a half-precision value
// here — confirm against the ideal-graph shape that produces this Replicate.
instruct replicateHF(vReg dst, fRegF src) %{
3090+
predicate(Matcher::vector_element_basic_type(n) == T_SHORT);
3091+
match(Set dst (Replicate src));
3092+
ins_cost(VEC_COST);
3093+
format %{ "replicateHF $dst, $src" %}
3094+
ins_encode %{
3095+
assert(UseZvfh, "must");
3096+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the move is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
3097+
__ vfmv_v_f(as_VectorRegister($dst$$reg), $src$$FloatRegister);
3098+
%}
3099+
ins_pipe(pipe_slow);
3100+
%}
3101+
29743102
instruct replicateF(vReg dst, fRegF src) %{
29753103
predicate(Matcher::vector_element_basic_type(n) == T_FLOAT);
29763104
match(Set dst (Replicate src));
@@ -4014,6 +4142,19 @@ instruct vrotate_left_vi_masked(vReg dst_src, immI shift, vRegMask_V0 v0) %{
40144142

40154143
// vector sqrt
40164144

4145+
// Vector square root for SqrtVHF nodes: dst = sqrt(src), element-wise.
// Operands are two vector registers (vReg). The encoding asserts that
// UseZvfh is enabled and that the node's vector element basic type is
// T_SHORT, then emits a single vfsqrt_v.
instruct vsqrt_hfp(vReg dst, vReg src) %{
4146+
match(Set dst (SqrtVHF src));
4147+
ins_cost(VEC_COST);
4148+
format %{ "vsqrt_hfp $dst, $src" %}
4149+
ins_encode %{
4150+
assert(UseZvfh, "must");
4151+
assert(Matcher::vector_element_basic_type(this) == T_SHORT, "must");
4152+
// vsetvli_helper is called with T_SHORT and this node's vector length
// before the square-root instruction is emitted.
__ vsetvli_helper(T_SHORT, Matcher::vector_length(this));
4153+
__ vfsqrt_v(as_VectorRegister($dst$$reg), as_VectorRegister($src$$reg));
4154+
%}
4155+
ins_pipe(pipe_slow);
4156+
%}
4157+
40174158
instruct vsqrt_fp(vReg dst, vReg src) %{
40184159
match(Set dst (SqrtVF src));
40194160
match(Set dst (SqrtVD src));

test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorOperations.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public TestFloat16VectorOperations() {
7878
@Test
7979
@Warmup(10000)
8080
@IR(counts = {IRNode.ADD_VHF, ">= 1"},
81-
applyIfCPUFeature = {"avx512_fp16", "true"})
81+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
8282
public void vectorAddFloat16() {
8383
for (int i = 0; i < LEN; ++i) {
8484
output[i] = float16ToRawShortBits(add(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -99,7 +99,7 @@ public void checkResultAdd() {
9999
@Test
100100
@Warmup(10000)
101101
@IR(counts = {IRNode.SUB_VHF, ">= 1"},
102-
applyIfCPUFeature = {"avx512_fp16", "true"})
102+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
103103
public void vectorSubFloat16() {
104104
for (int i = 0; i < LEN; ++i) {
105105
output[i] = float16ToRawShortBits(subtract(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -120,7 +120,7 @@ public void checkResultSub() {
120120
@Test
121121
@Warmup(10000)
122122
@IR(counts = {IRNode.MUL_VHF, ">= 1"},
123-
applyIfCPUFeature = {"avx512_fp16", "true"})
123+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
124124
public void vectorMulFloat16() {
125125
for (int i = 0; i < LEN; ++i) {
126126
output[i] = float16ToRawShortBits(multiply(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -141,7 +141,7 @@ public void checkResultMul() {
141141
@Test
142142
@Warmup(10000)
143143
@IR(counts = {IRNode.DIV_VHF, ">= 1"},
144-
applyIfCPUFeature = {"avx512_fp16", "true"})
144+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
145145
public void vectorDivFloat16() {
146146
for (int i = 0; i < LEN; ++i) {
147147
output[i] = float16ToRawShortBits(divide(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -162,7 +162,7 @@ public void checkResultDiv() {
162162
@Test
163163
@Warmup(10000)
164164
@IR(counts = {IRNode.MIN_VHF, ">= 1"},
165-
applyIfCPUFeature = {"avx512_fp16", "true"})
165+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
166166
public void vectorMinFloat16() {
167167
for (int i = 0; i < LEN; ++i) {
168168
output[i] = float16ToRawShortBits(min(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -183,7 +183,7 @@ public void checkResultMin() {
183183
@Test
184184
@Warmup(10000)
185185
@IR(counts = {IRNode.MAX_VHF, ">= 1"},
186-
applyIfCPUFeature = {"avx512_fp16", "true"})
186+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
187187
public void vectorMaxFloat16() {
188188
for (int i = 0; i < LEN; ++i) {
189189
output[i] = float16ToRawShortBits(max(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i])));
@@ -204,7 +204,7 @@ public void checkResultMax() {
204204
@Test
205205
@Warmup(10000)
206206
@IR(counts = {IRNode.SQRT_VHF, ">= 1"},
207-
applyIfCPUFeature = {"avx512_fp16", "true"})
207+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
208208
public void vectorSqrtFloat16() {
209209
for (int i = 0; i < LEN; ++i) {
210210
output[i] = float16ToRawShortBits(sqrt(shortBitsToFloat16(input1[i])));
@@ -225,7 +225,7 @@ public void checkResultSqrt() {
225225
@Test
226226
@Warmup(10000)
227227
@IR(counts = {IRNode.FMA_VHF, ">= 1"},
228-
applyIfCPUFeature = {"avx512_fp16", "true"})
228+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
229229
public void vectorFmaFloat16() {
230230
for (int i = 0; i < LEN; ++i) {
231231
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(input2[i]),
@@ -248,7 +248,7 @@ public void checkResultFma() {
248248
@Test
249249
@Warmup(10000)
250250
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
251-
applyIfCPUFeature = {"avx512_fp16", "true"})
251+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
252252
public void vectorFmaFloat16ScalarMixedConstants() {
253253
for (int i = 0; i < LEN; ++i) {
254254
output[i] = float16ToRawShortBits(fma(shortBitsToFloat16(input1[i]), shortBitsToFloat16(SCALAR_FP16),
@@ -272,7 +272,7 @@ public void checkResultFmaScalarMixedConstants() {
272272
@Test
273273
@Warmup(10000)
274274
@IR(counts = {IRNode.FMA_VHF, " >= 1"},
275-
applyIfCPUFeature = {"avx512_fp16", "true"})
275+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
276276
public void vectorFmaFloat16MixedConstants() {
277277
short input3 = floatToFloat16(3.0f);
278278
for (int i = 0; i < LEN; ++i) {
@@ -295,7 +295,7 @@ public void checkResultFmaMixedConstants() {
295295
@Test
296296
@Warmup(10000)
297297
@IR(counts = {IRNode.FMA_VHF, " 0 "},
298-
applyIfCPUFeature = {"avx512_fp16", "true"})
298+
applyIfCPUFeatureOr = {"avx512_fp16", "true", "zvfh", "true"})
299299
public void vectorFmaFloat16AllConstants() {
300300
short input1 = floatToFloat16(1.0f);
301301
short input2 = floatToFloat16(2.0f);

0 commit comments

Comments
 (0)