Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable RVV GEMM/IGEMM 7 x m4 in operator config #6411

Closed
wants to merge 3 commits into from

Conversation

bhbruce
Copy link
Contributor

@bhbruce bhbruce commented May 14, 2024

This PR aims to enable RVV GEMM/IGEMM/X32-PACKW in GEMM config.
It leads to enabling RVV implementation in operator API.

@bhbruce
Copy link
Contributor Author

bhbruce commented May 14, 2024

@alankelly @fbarchard Could you help to review it?
Also, I would like to ask about what's the appropriate way to enable RVV-only nr2 selection logic in following files.

src/operators/convolution-nhwc.c
1653:  const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config();
src/operators/dynamic-fully-connected-nc.c
215:  const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config();
src/operators/fully-connected-nc.c
754:  const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config();
src/operators/deconvolution-nhwc.c
898:  const struct xnn_gemm_config* gemm_nr2_config = xnn_init_f32_gemm_nr2_config();

The current logic determines to use nr2_config(half nr) if gemm_config->nr > output_channels.
For the RISC-V vector, I would like to specialize in either

gemm_nr2_config->nr <= output_channels
or
gemm_config->nr / 2 <= output_channels

However, there is no arch-specifc definition macro used in src/operators/.

@fbarchard
Copy link
Contributor

nr 2 is an MRx2 GEMM - 2 floats wide.
On SSE and NEON that normally use 4 floats per vector it allows a faster GEMM.
But it is optional... any gemm can output NC of less than a full vector, and on RVV is shouldnt make a difference.

@bhbruce bhbruce force-pushed the rv-gemm-config branch 3 times, most recently from 9804699 to 603cff1 Compare May 23, 2024 10:29
@bhbruce
Copy link
Contributor Author

bhbruce commented May 23, 2024

Hi @fbarchard @alankelly
Could you help to merge this PR?

Signed-off-by: Bruce Lai <bruce.lai@sifive.com>
Signed-off-by: Bruce Lai <bruce.lai@sifive.com>
Signed-off-by: Bruce Lai <bruce.lai@sifive.com>
@fbarchard
Copy link
Contributor

Re nr2 - if you didnt have such huge vectors you wouldnt have this problem :-)

nr2 doesnt come up much, and you dont have to specialize for it, especially on rvv.
a regular gemm can do nr=2... its just handled as a remainder case.

Add an entry to
static void init_f32_gemm_nr2_config(void) {
with a pack function that can do nr=2 e.g. xnn_x32_packw_gemm_goi_ukernel_x2__scalar_float_u4
normally it would be
f32_gemm_nr2_config.nr = 2;
meaning 2 floats. hmmm... I see your issue. You want something like
// nr is set to vlen * 4 / sizeof(float) = 4 * VLENB * 8 / 32 = VLENB
f32_gemm_config.nr = hardware_config->vlenb;

what if you break from convention and fill in nr=2, meaning 2 floats = 8 bytes.
and implement the gemm using u1v. which will work most of the time.
you could check, in the gemm-config, that vlenb >= 8.
you could also check if vlenb >= 16, and configure an nr2 gemm
but considering how rarely these come up, I'd just do the basic u1v and add a todo to revisit it.

Its also possible to implement nr=2 gemm's more efficiently than the obvious. I did some for neon, using 4 floats per vector. and for nr=1 you can do 4 floats at a time. If thats possible on rvv, it would likely be faster.
I forget the exact method, but look at the 4x2-aarch64-neonfma-ld128.S.in
which does a trick to load 2 blocks at a time (4 floats) and then a paired add outside the loop.
# Main loop - 4 floats of A (16 bytes)
1:
LDR q0, [x3], 16
LD2 {v20.4s, v21.4s}, [x5], 32
LDR q1, [x11], 16
LDR q2, [x12], 16
LDR q3, [x4], 16
SUBS x0, x0, 16
FMLA v24.4s, v20.4s, v0.4s
FMLA v25.4s, v21.4s, v0.4s
FMLA v26.4s, v20.4s, v1.4s
FMLA v27.4s, v21.4s, v1.4s
FMLA v28.4s, v20.4s, v2.4s
FMLA v29.4s, v21.4s, v2.4s
FMLA v30.4s, v20.4s, v3.4s
FMLA v31.4s, v21.4s, v3.4s
B.HS 1b

    FADDP       v24.4s, v24.4s, v25.4s
    FADDP       v26.4s, v26.4s, v27.4s
    FADDP       v28.4s, v28.4s, v29.4s
    FADDP       v30.4s, v30.4s, v31.4s

@fbarchard
Copy link
Contributor

Enable RVV GEMM/IGEMM 7 x m4 is landed in #7035
you can close this PR and if add an nr2 enable as followup

@bhbruce
Copy link
Contributor Author

bhbruce commented Sep 6, 2024

@fbarchard Thanks for your help.

@bhbruce bhbruce closed this Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants