[AArch64][SVE] Predicate bfloat16 load patterns with HasBF16
Reviewers: sdesmalen, c-rhodes, efriedma, fpetrogalli

Reviewed By: fpetrogalli

Subscribers: tschuett, kristof.beyls, hiraditya, rkruppe, psnobl, danielkiss, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D82464
kmclaughlin-arm committed Jun 26, 2020
1 parent c65d4eb commit 0ccfe1b
Showing 8 changed files with 47 additions and 12 deletions.
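In short, the bfloat16 variants of the SVE contiguous load patterns (LD1H, LDFF1H, LDNF1H and the masked-load patterns) become selectable only when both HasSVE and HasBF16 hold, and each bfloat test gains a "+bf16" target-feature attribute. As an illustration (a hypothetical reproducer, not a file in this commit), IR along these lines exercises the newly gated masked-load pattern:

; Hypothetical reproducer: with +bf16 this selects ld1h; without +bf16 the
; bfloat pattern is no longer visible to instruction selection.
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve,+bf16 < %s | FileCheck %s
define <vscale x 8 x bfloat> @masked_load_bf16(<vscale x 8 x bfloat>* %a, <vscale x 8 x i1> %pg) {
; CHECK: ld1h { z0.h }, p0/z, [x0]
  %load = call <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16.p0nxv8bf16(<vscale x 8 x bfloat>* %a, i32 2, <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> undef)
  ret <vscale x 8 x bfloat> %load
}
declare <vscale x 8 x bfloat> @llvm.masked.load.nxv8bf16.p0nxv8bf16(<vscale x 8 x bfloat>*, i32, <vscale x 8 x i1>, <vscale x 8 x bfloat>)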
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12074,6 +12074,11 @@ static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   EVT PtrTy = N->getOperand(3).getValueType();
 
+  if (VT == MVT::nxv8bf16)
+    assert(
+        static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16() &&
+        "Unsupported type (BF16)");
+
   EVT LoadVT = VT;
   if (VT.isFloatingPoint())
     LoadVT = VT.changeTypeToInteger();
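The assert encodes an internal invariant rather than a user-facing diagnostic: in an assertions-enabled build, lowering a non-temporal nxv8bf16 load on a subtarget without BF16 now fails immediately instead of silently continuing with a type the subtarget does not legalize. A minimal sketch (assuming the era's typed-pointer intrinsic signature) of IR that reaches this combine:

; Hypothetical example of a non-temporal bf16 load handled by
; performLDNT1Combine; only valid when +bf16 is available.
define <vscale x 8 x bfloat> @ldnt1_bf16(<vscale x 8 x i1> %pg, bfloat* %a) #0 {
  %load = call <vscale x 8 x bfloat> @llvm.aarch64.sve.ldnt1.nxv8bf16(<vscale x 8 x i1> %pg, bfloat* %a)
  ret <vscale x 8 x bfloat> %load
}
declare <vscale x 8 x bfloat> @llvm.aarch64.sve.ldnt1.nxv8bf16(<vscale x 8 x i1>, bfloat*)
attributes #0 = { "target-features"="+sve,+bf16" }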
20 changes: 16 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1550,7 +1550,10 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8, LD1SB_H, LD1SB_H_IMM, am_sve_regreg_lsl0>;
   defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load, LD1H, LD1H_IMM, am_sve_regreg_lsl1>;
   defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load, LD1H, LD1H_IMM, am_sve_regreg_lsl1>;
-  defm : pred_load<nxv8bf16, nxv8i1, nonext_masked_load, LD1H, LD1H_IMM, am_sve_regreg_lsl1>;
+
+  let Predicates = [HasBF16, HasSVE] in {
+    defm : pred_load<nxv8bf16, nxv8i1, nonext_masked_load, LD1H, LD1H_IMM, am_sve_regreg_lsl1>;
+  }
 
   // 16-element contiguous loads
   defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B, LD1B_IMM, am_sve_regreg_lsl0>;
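For reference, this `let Predicates = [...]` wrapper is TableGen's standard feature-gating mechanism: every pattern defined inside the block is emitted with a guard and considered by instruction selection only when the subtarget satisfies all listed predicates. A minimal sketch with invented names (the real `pred_load` multiclass expands to several such patterns):

// Hypothetical sketch: my_load and MY_LD1H are invented stand-ins, not
// definitions from the backend. The pattern participates in ISel only
// when HasBF16 and HasSVE both hold for the current subtarget.
let Predicates = [HasBF16, HasSVE] in {
  def : Pat<(nxv8bf16 (my_load (nxv8i1 PPR:$pg), GPR64:$base)),
            (MY_LD1H PPR:$pg, GPR64:$base)>;
}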
@@ -1737,7 +1740,10 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   defm : ld1<LD1SB_H, LD1SB_H_IMM, nxv8i16, AArch64ld1s, nxv8i1, nxv8i8, am_sve_regreg_lsl0>;
   defm : ld1<LD1H, LD1H_IMM, nxv8i16, AArch64ld1, nxv8i1, nxv8i16, am_sve_regreg_lsl1>;
   defm : ld1<LD1H, LD1H_IMM, nxv8f16, AArch64ld1, nxv8i1, nxv8f16, am_sve_regreg_lsl1>;
-  defm : ld1<LD1H, LD1H_IMM, nxv8bf16, AArch64ld1, nxv8i1, nxv8bf16, am_sve_regreg_lsl1>;
+
+  let Predicates = [HasBF16, HasSVE] in {
+    defm : ld1<LD1H, LD1H_IMM, nxv8bf16, AArch64ld1, nxv8i1, nxv8bf16, am_sve_regreg_lsl1>;
+  }
 
   // 16-element contiguous loads
   defm : ld1<LD1B, LD1B_IMM, nxv16i8, AArch64ld1, nxv16i1, nxv16i8, am_sve_regreg_lsl0>;
@@ -1777,7 +1783,10 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   defm : ldnf1<LDNF1SB_H_IMM, nxv8i16, AArch64ldnf1s, nxv8i1, nxv8i8>;
   defm : ldnf1<LDNF1H_IMM, nxv8i16, AArch64ldnf1, nxv8i1, nxv8i16>;
   defm : ldnf1<LDNF1H_IMM, nxv8f16, AArch64ldnf1, nxv8i1, nxv8f16>;
-  defm : ldnf1<LDNF1H_IMM, nxv8bf16, AArch64ldnf1, nxv8i1, nxv8bf16>;
+
+  let Predicates = [HasBF16, HasSVE] in {
+    defm : ldnf1<LDNF1H_IMM, nxv8bf16, AArch64ldnf1, nxv8i1, nxv8bf16>;
+  }
 
   // 16-element contiguous non-faulting loads
   defm : ldnf1<LDNF1B_IMM, nxv16i8, AArch64ldnf1, nxv16i1, nxv16i8>;
@@ -1818,7 +1827,10 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio
   defm : ldff1<LDFF1SB_H, nxv8i16, AArch64ldff1s, nxv8i1, nxv8i8, am_sve_regreg_lsl0>;
   defm : ldff1<LDFF1H, nxv8i16, AArch64ldff1, nxv8i1, nxv8i16, am_sve_regreg_lsl1>;
   defm : ldff1<LDFF1H, nxv8f16, AArch64ldff1, nxv8i1, nxv8f16, am_sve_regreg_lsl1>;
-  defm : ldff1<LDFF1H, nxv8bf16, AArch64ldff1, nxv8i1, nxv8bf16, am_sve_regreg_lsl1>;
+
+  let Predicates = [HasBF16, HasSVE] in {
+    defm : ldff1<LDFF1H, nxv8bf16, AArch64ldff1, nxv8i1, nxv8bf16, am_sve_regreg_lsl1>;
+  }
 
   // 16-element contiguous first faulting loads
   defm : ldff1<LDFF1B, nxv16i8, AArch64ldff1, nxv16i1, nxv16i8, am_sve_regreg_lsl0>;
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-intrinsics-ld1-addressing-mode-reg-imm.ll
@@ -207,7 +207,7 @@ define <vscale x 8 x half> @ld1h_f16_inbound(<vscale x 8 x i1> %pg, half* %a) {
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ld1h_bf16_inbound(<vscale x 8 x i1> %pg, bfloat* %a) {
+define <vscale x 8 x bfloat> @ld1h_bf16_inbound(<vscale x 8 x i1> %pg, bfloat* %a) #0 {
 ; CHECK-LABEL: ld1h_bf16_inbound:
 ; CHECK: ld1h { z0.h }, p0/z, [x0, #1, mul vl]
 ; CHECK-NEXT: ret
@@ -311,3 +311,6 @@ declare <vscale x 2 x i16> @llvm.aarch64.sve.ld1.nxv2i16(<vscale x 2 x i1>, i16*
 declare <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1>, i32*)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.ld1.nxv2i64(<vscale x 2 x i1>, i64*)
 declare <vscale x 2 x double> @llvm.aarch64.sve.ld1.nxv2f64(<vscale x 2 x i1>, double*)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-intrinsics-ld1-addressing-mode-reg-reg.ll
@@ -95,7 +95,7 @@ define <vscale x 8 x half> @ld1h_f16(<vscale x 8 x i1> %pg, half* %a, i64 %index
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ld1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a, i64 %index) {
+define <vscale x 8 x bfloat> @ld1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a, i64 %index) #0 {
 ; CHECK-LABEL: ld1h_bf16
 ; CHECK: ld1h { z0.h }, p0/z, [x0, x1, lsl #1]
 ; CHECK-NEXT: ret
@@ -225,3 +225,6 @@ declare <vscale x 2 x i16> @llvm.aarch64.sve.ld1.nxv2i16(<vscale x 2 x i1>, i16*
 declare <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1>, i32*)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.ld1.nxv2i64(<vscale x 2 x i1>, i64*)
 declare <vscale x 2 x double> @llvm.aarch64.sve.ld1.nxv2f64(<vscale x 2 x i1>, double*)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-intrinsics-ld1.ll
@@ -87,7 +87,7 @@ define <vscale x 8 x half> @ld1h_f16(<vscale x 8 x i1> %pred, half* %addr) {
   ret <vscale x 8 x half> %res
 }
 
-define <vscale x 8 x bfloat> @ld1h_bf16(<vscale x 8 x i1> %pred, bfloat* %addr) {
+define <vscale x 8 x bfloat> @ld1h_bf16(<vscale x 8 x i1> %pred, bfloat* %addr) #0 {
 ; CHECK-LABEL: ld1h_bf16:
 ; CHECK: ld1h { z0.h }, p0/z, [x0]
 ; CHECK-NEXT: ret
@@ -208,3 +208,6 @@ declare <vscale x 2 x i16> @llvm.aarch64.sve.ld1.nxv2i16(<vscale x 2 x i1>, i16*
 declare <vscale x 2 x i32> @llvm.aarch64.sve.ld1.nxv2i32(<vscale x 2 x i1>, i32*)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.ld1.nxv2i64(<vscale x 2 x i1>, i64*)
 declare <vscale x 2 x double> @llvm.aarch64.sve.ld1.nxv2f64(<vscale x 2 x i1>, double*)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
7 changes: 5 additions & 2 deletions llvm/test/CodeGen/AArch64/sve-intrinsics-loads-ff.ll
@@ -206,7 +206,7 @@ define <vscale x 8 x half> @ldff1h_f16(<vscale x 8 x i1> %pg, half* %a) {
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ldff1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a) {
+define <vscale x 8 x bfloat> @ldff1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a) #0 {
 ; CHECK-LABEL: ldff1h_bf16:
 ; CHECK: ldff1h { z0.h }, p0/z, [x0]
 ; CHECK-NEXT: ret
@@ -223,7 +223,7 @@ define <vscale x 8 x half> @ldff1h_f16_reg(<vscale x 8 x i1> %pg, half* %a, i64
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ldff1h_bf16_reg(<vscale x 8 x i1> %pg, bfloat* %a, i64 %offset) {
+define <vscale x 8 x bfloat> @ldff1h_bf16_reg(<vscale x 8 x i1> %pg, bfloat* %a, i64 %offset) #0 {
 ; CHECK-LABEL: ldff1h_bf16_reg:
 ; CHECK: ldff1h { z0.h }, p0/z, [x0, x1, lsl #1]
 ; CHECK-NEXT: ret
@@ -428,3 +428,6 @@ declare <vscale x 2 x i16> @llvm.aarch64.sve.ldff1.nxv2i16(<vscale x 2 x i1>, i1
 declare <vscale x 2 x i32> @llvm.aarch64.sve.ldff1.nxv2i32(<vscale x 2 x i1>, i32*)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.ldff1.nxv2i64(<vscale x 2 x i1>, i64*)
 declare <vscale x 2 x double> @llvm.aarch64.sve.ldff1.nxv2f64(<vscale x 2 x i1>, double*)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
7 changes: 5 additions & 2 deletions llvm/test/CodeGen/AArch64/sve-intrinsics-loads-nf.ll
@@ -140,7 +140,7 @@ define <vscale x 8 x half> @ldnf1h_f16(<vscale x 8 x i1> %pg, half* %a) {
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ldnf1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a) {
+define <vscale x 8 x bfloat> @ldnf1h_bf16(<vscale x 8 x i1> %pg, bfloat* %a) #0 {
 ; CHECK-LABEL: ldnf1h_bf16:
 ; CHECK: ldnf1h { z0.h }, p0/z, [x0]
 ; CHECK-NEXT: ret
@@ -159,7 +159,7 @@ define <vscale x 8 x half> @ldnf1h_f16_inbound(<vscale x 8 x i1> %pg, half* %a)
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @ldnf1h_bf16_inbound(<vscale x 8 x i1> %pg, bfloat* %a) {
+define <vscale x 8 x bfloat> @ldnf1h_bf16_inbound(<vscale x 8 x i1> %pg, bfloat* %a) #0 {
 ; CHECK-LABEL: ldnf1h_bf16_inbound:
 ; CHECK: ldnf1h { z0.h }, p0/z, [x0, #1, mul vl]
 ; CHECK-NEXT: ret
@@ -473,3 +473,6 @@ declare <vscale x 2 x i16> @llvm.aarch64.sve.ldnf1.nxv2i16(<vscale x 2 x i1>, i1
 declare <vscale x 2 x i32> @llvm.aarch64.sve.ldnf1.nxv2i32(<vscale x 2 x i1>, i32*)
 declare <vscale x 2 x i64> @llvm.aarch64.sve.ldnf1.nxv2i64(<vscale x 2 x i1>, i64*)
 declare <vscale x 2 x double> @llvm.aarch64.sve.ldnf1.nxv2f64(<vscale x 2 x i1>, double*)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
5 changes: 4 additions & 1 deletion llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -87,7 +87,7 @@ define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale
   ret <vscale x 8 x half> %load
 }
 
-define <vscale x 8 x bfloat> @masked_load_nxv8bf16(<vscale x 8 x bfloat> *%a, <vscale x 8 x i1> %mask) nounwind {
+define <vscale x 8 x bfloat> @masked_load_nxv8bf16(<vscale x 8 x bfloat> *%a, <vscale x 8 x i1> %mask) nounwind #0 {
 ; CHECK-LABEL: masked_load_nxv8bf16:
 ; CHECK-NEXT: ld1h { z0.h }, p0/z, [x0]
 ; CHECK-NEXT: ret
@@ -203,3 +203,6 @@ declare void @llvm.masked.store.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half>
 declare void @llvm.masked.store.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>*, i32, <vscale x 4 x i1>)
 declare void @llvm.masked.store.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half>*, i32, <vscale x 4 x i1>)
 declare void @llvm.masked.store.nxv8f16(<vscale x 8 x half>, <vscale x 8 x half>*, i32, <vscale x 8 x i1>)
+
+; +bf16 is required for the bfloat version.
+attributes #0 = { "target-features"="+sve,+bf16" }
