Skip to content

Commit

Permalink
[AMX] Support AMX-FP16 new intrinsic interface
Browse files Browse the repository at this point in the history
We support AMX-FP16 isa in https://reviews.llvm.org/D135941 now.
The old  intrinsic interface need to manually write tile registers.
So we support its new intrinsic interface to let it be able to do register allocation.

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D138987
  • Loading branch information
xiangzh1 authored and phoebewang committed Dec 1, 2022
1 parent 6244016 commit 94c5df8
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 3 deletions.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/BuiltinsX86_64.def
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ TARGET_BUILTIN(__builtin_ia32_tdpbuud_internal, "V256iUsUsUsV256iV256iV256i", "n
TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
TARGET_BUILTIN(__builtin_ia32_tdpfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp16")
// AMX
TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")
Expand Down
32 changes: 32 additions & 0 deletions clang/lib/Headers/amxintrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
__attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
#define __DEFAULT_FN_ATTRS_BF16 \
__attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
#define __DEFAULT_FN_ATTRS_FP16 \
__attribute__((__always_inline__, __nodebug__, __target__("amx-fp16")))

/// Load tile configuration from a 64-byte memory location specified by
/// "mem_addr". The tile configuration includes the tile type palette, the
Expand Down Expand Up @@ -290,6 +292,13 @@ _tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
}

/// This is internal intrinsic. C/C++ user should avoid calling it directly.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP16
_tile_dpfp16ps_internal(unsigned short m, unsigned short n, unsigned short k,
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
return __builtin_ia32_tdpfp16ps_internal(m, n, k, dst, src1, src2);
}

/// This struct pack the shape and tile data together for user. We suggest
/// initializing the struct as early as possible, because compiler depends
/// on the shape information to do configure. The constant value is preferred
Expand Down Expand Up @@ -484,9 +493,32 @@ static __inline__ void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src0,
src0.tile, src1.tile);
}

/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles src0 and
/// src1, accumulating the intermediate single-precision (32-bit) floating-point
/// elements with elements in "dst", and store the 32-bit result back to tile
/// "dst".
///
/// \headerfile <immintrin.h>
///
/// This intrinsic corresponds to the <c> TDPFP16PS </c> instruction.
///
/// \param dst
/// The destination tile. Max size is 1024 Bytes.
/// \param src0
/// The 1st source tile. Max size is 1024 Bytes.
/// \param src1
/// The 2nd source tile. Max size is 1024 Bytes.
__DEFAULT_FN_ATTRS_FP16
static __inline__ void __tile_dpfp16ps(__tile1024i *dst, __tile1024i src0,
__tile1024i src1) {
dst->tile = _tile_dpfp16ps_internal(src0.row, src1.col, src0.col, dst->tile,
src0.tile, src1.tile);
}

#undef __DEFAULT_FN_ATTRS_TILE
#undef __DEFAULT_FN_ATTRS_INT8
#undef __DEFAULT_FN_ATTRS_BF16
#undef __DEFAULT_FN_ATTRS_FP16

#endif /* __x86_64__ */
#endif /* __AMXINTRIN_H */
10 changes: 9 additions & 1 deletion clang/test/CodeGen/X86/amx_api.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \
// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
// RUN: -target-feature +amx-bf16 -target-feature +amx-fp16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK

#include <immintrin.h>

Expand Down Expand Up @@ -102,3 +102,11 @@ void test_tile_dpbf16ps(__tile1024i a, __tile1024i b, __tile1024i c) {
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dpbf16ps(&a, b, c);
}

void test_tile_dpfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) {
//CHECK-LABEL: @test_tile_dpfp16ps
//CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}})
//CHECK-DAG: call x86_amx @llvm.x86.tdpfp16ps.internal
//CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}})
__tile_dpfp16ps(&a, b, c);
}
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsX86.td
Original file line number Diff line number Diff line change
Expand Up @@ -5396,6 +5396,12 @@ let TargetPrefix = "x86" in {
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty,
llvm_x86amx_ty], []>;
def int_x86_tdpfp16ps_internal :
ClangBuiltin<"__builtin_ia32_tdpfp16ps_internal">,
Intrinsic<[llvm_x86amx_ty],
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
llvm_x86amx_ty, llvm_x86amx_ty,
llvm_x86amx_ty], []>;
def int_x86_cast_vector_to_tile:
DefaultAttrsIntrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>;
def int_x86_cast_tile_to_vector:
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/X86/X86ExpandPseudo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,8 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
case X86::PTDPBSUDV:
case X86::PTDPBUSDV:
case X86::PTDPBUUDV:
case X86::PTDPBF16PSV: {
case X86::PTDPBF16PSV:
case X86::PTDPFP16PSV: {
MI.untieRegOperand(4);
for (unsigned i = 3; i > 0; --i)
MI.removeOperand(i);
Expand All @@ -577,6 +578,7 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break;
case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break;
case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break;
case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break;
default: llvm_unreachable("Impossible Opcode!");
}
MI.setDesc(TII->get(Opc));
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86InstrAMX.td
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,18 @@ let Predicates = [HasAMXFP16, In64BitMode] in {
"tdpfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}",
[]>, VEX_4V, T8XD;
}

// Pseduo instruction for RA.
let isPseudo = true, Constraints = "$src4 = $dst" in {
def PTDPFP16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6),
[(set TILE: $dst,
(int_x86_tdpfp16ps_internal GR16:$src1,
GR16:$src2, GR16:$src3, TILE:$src4,
TILE:$src5, TILE:$src6))]>;
}

let usesCustomInserter = 1 in {
def PTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1,
u8imm:$src2, u8imm:$src3),
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86LowerAMXType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
case Intrinsic::x86_tdpbsud_internal:
case Intrinsic::x86_tdpbusd_internal:
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tdpbf16ps_internal: {
case Intrinsic::x86_tdpbf16ps_internal:
case Intrinsic::x86_tdpfp16ps_internal: {
switch (OpNo) {
case 3:
Row = II->getArgOperand(0);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/X86/X86RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
case X86::PTDPBUUDV:
case X86::PTILEZEROV:
case X86::PTDPBF16PSV:
case X86::PTDPFP16PSV:
MachineOperand &MO1 = MI->getOperand(1);
MachineOperand &MO2 = MI->getOperand(2);
ShapeT Shape(&MO1, &MO2, MRI);
Expand Down
41 changes: 41 additions & 0 deletions llvm/test/CodeGen/X86/AMX/amx-fp16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-int8,+amx-fp16,+avx512f -verify-machineinstrs | FileCheck %s

define void @test_amx(ptr %pointer, ptr %base, i64 %stride) {
; CHECK-LABEL: test_amx:
; CHECK: # %bb.0:
; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0
; CHECK-NEXT: vmovups %zmm0, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp)
; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp)
; CHECK-NEXT: movw $8, %ax
; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm0
; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm1
; CHECK-NEXT: tilezero %tmm2
; CHECK-NEXT: tdpfp16ps %tmm1, %tmm0, %tmm2
; CHECK-NEXT: tileloaddt1 (%rsi,%rdx), %tmm0
; CHECK-NEXT: tilestored %tmm2, (%rdi,%rdx)
; CHECK-NEXT: tilerelease
; CHECK-NEXT: vzeroupper
; CHECK-NEXT: retq
%a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, ptr %base, i64 %stride)
%b = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, ptr %base, i64 %stride)
%c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8)
%d = call x86_amx @llvm.x86.tdpfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b)
%e = call x86_amx @llvm.x86.tileloaddt164.internal(i16 8, i16 8, ptr %base, i64 %stride)
call void @llvm.x86.tilestored64.internal(i16 8, i16 8, ptr %pointer, i64 %stride, x86_amx %d)

ret void
}

declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
declare x86_amx @llvm.x86.tileloaddt164.internal(i16, i16, ptr, i64)
declare x86_amx @llvm.x86.tdpfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)

0 comments on commit 94c5df8

Please sign in to comment.