riscv64: Improve pattern matching rules for FMA #8596

Merged (2 commits) on May 12, 2024
99 changes: 49 additions & 50 deletions cranelift/codegen/src/isa/riscv64/lower.isle
@@ -1509,56 +1509,55 @@

;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; fmadd: rs1 * rs2 + rs3
(rule 0 (lower (has_type (ty_scalar_float ty) (fma x y z)))
(rv_fmadd ty (FRM.RNE) x y z))

;; fmsub: rs1 * rs2 - rs3
(rule 1 (lower (has_type (ty_scalar_float ty) (fma x y (fneg z))))
(rv_fmsub ty (FRM.RNE) x y z))

;; fnmsub: -rs1 * rs2 + rs3
(rule 2 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y z)))
(rv_fnmsub ty (FRM.RNE) x y z))

;; fnmadd: -rs1 * rs2 - rs3
(rule 3 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y (fneg z))))
(rv_fnmadd ty (FRM.RNE) x y z))

;; (fma x y z) computes x * y + z
;; vfmacc computes vd[i] = +(vs1[i] * vs2[i]) + vd[i]
;; We need to reverse the order of the arguments

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (fma x y z)))
(rv_vfmacc_vv z y x (unmasked) ty))

(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y z)))
(rv_vfmacc_vf z y x (unmasked) ty))

;; vfmsac computes vd[i] = +(vs1[i] * vs2[i]) - vd[i]

(rule 6 (lower (has_type (ty_vec_fits_in_register ty) (fma x y (fneg z))))
(rv_vfmsac_vv z y x (unmasked) ty))

(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y (fneg z))))
(rv_vfmsac_vf z y x (unmasked) ty))

;; vfnmacc computes vd[i] = -(vs1[i] * vs2[i]) - vd[i]

(rule 7 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y (fneg z))))
(rv_vfnmacc_vv z y x (unmasked) ty))

(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y (fneg z))))
(rv_vfnmacc_vf z y x (unmasked) ty))

;; vfnmsac computes vd[i] = -(vs1[i] * vs2[i]) + vd[i]

(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y z)))
(rv_vfnmsac_vv z y x (unmasked) ty))

(rule 8 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y z)))
(rv_vfnmsac_vf z y x (unmasked) ty))

;; RISC-V has 4 FMA instructions that do a slightly different computation.
;;
;; fmadd: (rs1 * rs2) + rs3
;; fmsub: (rs1 * rs2) - rs3
;; fnmadd: -(rs1 * rs2) - rs3
;; fnmsub: -(rs1 * rs2) + rs3
;;
;; Additionally, there are vector versions of these instructions with slightly different names.
;; The vector instructions also have two variants each, `.vv` and `.vf`: the `.vv` variants
;; take two vector operands, while the `.vf` variants take a vector operand and a scalar operand.
;;
;; Due to this variation, they receive the arguments in a different order, so we need to swap
;; the arguments below.
;;
;; vfmacc: vd[i] = +(vs1[i] * vs2[i]) + vd[i]
;; vfmsac: vd[i] = +(vs1[i] * vs2[i]) - vd[i]
;; vfnmacc: vd[i] = -(vs1[i] * vs2[i]) - vd[i]
;; vfnmsac: vd[i] = -(vs1[i] * vs2[i]) + vd[i]
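;;
;; For example, `(fma (fneg x) y z)` negates the product term but not the
;; addend, so it lowers to `fnmsub` for scalars and to `vfnmsac` for vectors,
;; with the addend `z` passed first so that it lands in `vd`.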

(type IsFneg (enum (Result (negate u64) (value Value))))

(decl pure is_fneg (Value) IsFneg)
(rule 1 (is_fneg (fneg x)) (IsFneg.Result 1 x))
(rule 0 (is_fneg x) (IsFneg.Result 0 x))
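
;; Negating either multiplicand negates the product, and negating both
;; cancels, so the x*y term's parity is the XOR of the two multiplicand
;; parities, while the z term's parity is passed through unchanged.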

(rule (lower (has_type ty (fma x_src y_src z_src)))
(if-let (IsFneg.Result neg_x x) (is_fneg x_src))
(if-let (IsFneg.Result neg_y y) (is_fneg y_src))
(if-let (IsFneg.Result neg_z z) (is_fneg z_src))
(rv_fma ty (u64_xor neg_x neg_y) neg_z x y z))

;; The two parity arguments indicate whether to negate the x*y term and the z term, respectively.
(decl rv_fma (Type u64 u64 Value Value Value) InstOutput)
(rule 0 (rv_fma (ty_scalar_float ty) 0 0 x y z) (rv_fmadd ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 0 1 x y z) (rv_fmsub ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 1 0 x y z) (rv_fnmsub ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 1 1 x y z) (rv_fnmadd ty (FRM.RNE) x y z))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 0 x y z) (rv_vfmacc_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 1 x y z) (rv_vfmsac_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 0 x y z) (rv_vfnmsac_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 1 x y z) (rv_vfnmacc_vv z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 0 (splat x) y z) (rv_vfmacc_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 1 (splat x) y z) (rv_vfmsac_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 0 (splat x) y z) (rv_vfnmsac_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 1 (splat x) y z) (rv_vfnmacc_vf z y x (unmasked) ty))
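;; `fma` is commutative in its multiplicands, so when the `splat` is the
;; second operand we swap `x` and `y` to feed the scalar to the `.vf` form.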
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 0 x (splat y) z) (rv_vfmacc_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 1 x (splat y) z) (rv_vfmsac_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 0 x (splat y) z) (rv_vfnmsac_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 1 x (splat y) z) (rv_vfnmacc_vf z x y (unmasked) ty))

;;;; Rules for `sqrt` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule 0 (lower (has_type (ty_scalar_float ty) (sqrt x)))
87 changes: 87 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-fma.clif
@@ -362,3 +362,90 @@ block0(v0: f64, v1: f64x2, v2: f64x2):
; addi sp, sp, 0x10
; ret


function %fma_splat_y_f32x4(f32x4, f32, f32x4) -> f32x4 {
block0(v0: f32x4, v1: f32, v2: f32x4):
v3 = splat.f32x4 v1
v4 = fma v0, v3, v2
return v4
}

; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vfmacc.vf v14,v9,fa0 #avl=4, #vtype=(e32, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, sp, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; addi t6, sp, 0x20
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x02, 0xcd
; .byte 0x57, 0x57, 0x95, 0xb2
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

function %fma_splat_y_f64x2(f64x2, f64, f64x2) -> f64x2 {
block0(v0: f64x2, v1: f64, v2: f64x2):
v3 = splat.f64x2 v1
v4 = fma v0, v3, v2
return v4
}

; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vfmacc.vf v14,v9,fa0 #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, sp, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; addi t6, sp, 0x20
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x57, 0x95, 0xb2
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret