Skip to content

Commit

Permalink
[Matrix] Try to emit fmuladd for both vector and matrix types
Browse files Browse the repository at this point in the history
For vector * scalar + vector, we emit `fmuladd` directly from clang.

This also enables it for matrix * scalar + matrix.

rdar://113967122

Differential Revision: https://reviews.llvm.org/D158883
  • Loading branch information
francisvm committed Sep 1, 2023
1 parent e7bd436 commit c987f9d
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 12 deletions.
23 changes: 16 additions & 7 deletions clang/lib/CodeGen/CGExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3874,6 +3874,14 @@ Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &op) {
}
}

// For FP scalar, vector and matrix adds, try to fold into an fmuladd.
if (op.LHS->getType()->isFPOrFPVectorTy()) {
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
// Try to form an fmuladd.
if (Value *FMulAdd = tryEmitFMulAdd(op, CGF, Builder))
return FMulAdd;
}

if (op.Ty->isConstantMatrixType()) {
llvm::MatrixBuilder MB(Builder);
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
Expand All @@ -3887,10 +3895,6 @@ Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &op) {

if (op.LHS->getType()->isFPOrFPVectorTy()) {
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
// Try to form an fmuladd.
if (Value *FMulAdd = tryEmitFMulAdd(op, CGF, Builder))
return FMulAdd;

return Builder.CreateFAdd(op.LHS, op.RHS, "add");
}

Expand Down Expand Up @@ -4024,6 +4028,14 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
}
}

// For FP scalar, vector and matrix subs, try to fold into an fmuladd.
if (op.LHS->getType()->isFPOrFPVectorTy()) {
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
// Try to form an fmuladd.
if (Value *FMulAdd = tryEmitFMulAdd(op, CGF, Builder, true))
return FMulAdd;
}

if (op.Ty->isConstantMatrixType()) {
llvm::MatrixBuilder MB(Builder);
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
Expand All @@ -4037,9 +4049,6 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {

if (op.LHS->getType()->isFPOrFPVectorTy()) {
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
// Try to form an fmuladd.
if (Value *FMulAdd = tryEmitFMulAdd(op, CGF, Builder, true))
return FMulAdd;
return Builder.CreateFSub(op.LHS, op.RHS, "sub");
}

Expand Down
112 changes: 107 additions & 5 deletions clang/test/CodeGen/ffp-model.c
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
// REQUIRES: x86-registered-target
// RUN: %clang -S -emit-llvm -ffp-model=fast -emit-llvm %s -o - \
// RUN: %clang -S -emit-llvm -fenable-matrix -ffp-model=fast %s -o - \
// RUN: | FileCheck %s --check-prefixes=CHECK,CHECK-FAST

// RUN: %clang -S -emit-llvm -ffp-model=precise %s -o - \
// RUN: %clang -S -emit-llvm -fenable-matrix -ffp-model=precise %s -o - \
// RUN: | FileCheck %s --check-prefixes=CHECK,CHECK-PRECISE

// RUN: %clang -S -emit-llvm -ffp-model=strict %s -o - \
// RUN: %clang -S -emit-llvm -fenable-matrix -ffp-model=strict %s -o - \
// RUN: -target x86_64 | FileCheck %s --check-prefixes=CHECK,CHECK-STRICT

// RUN: %clang -S -emit-llvm -ffp-model=strict -ffast-math \
// RUN: %clang -S -emit-llvm -fenable-matrix -ffp-model=strict -ffast-math \
// RUN: -target x86_64 %s -o - | FileCheck %s \
// RUN: --check-prefixes CHECK,CHECK-STRICT-FAST

// RUN: %clang -S -emit-llvm -ffp-model=precise -ffast-math \
// RUN: %clang -S -emit-llvm -fenable-matrix -ffp-model=precise -ffast-math \
// RUN: %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-FAST1

float mymuladd(float x, float y, float z) {
Expand Down Expand Up @@ -46,3 +46,105 @@ float mymuladd(float x, float y, float z) {
// CHECK-FAST1: load float, ptr {{.*}}
// CHECK-FAST1: fadd fast float {{.*}}, {{.*}}
}

// Two-element float vector (Clang ext_vector_type) used by the test below.
typedef float __attribute__((ext_vector_type(2))) v2f;

// Computes vector * scalar + vector so the emitted IR can be inspected.
// Expectations that follow, by floating-point configuration of the run lines:
//   - -ffp-model=precise -> a single call to llvm.fmuladd.v2f32.
//   - -ffp-model=fast, or strict/precise combined with -ffast-math ->
//     separate `fmul fast` and `fadd fast` instructions.
//   - -ffp-model=strict -> constrained fmul and fadd intrinsic calls.
v2f my_vec_muladd(v2f x, float y, v2f z) {
// CHECK: define{{.*}} @my_vec_muladd
return x * y + z;

// CHECK-FAST: fmul fast <2 x float>
// CHECK-FAST: load <2 x float>, ptr
// CHECK-FAST: fadd fast <2 x float>

// CHECK-PRECISE: load <2 x float>, ptr
// CHECK-PRECISE: load float, ptr
// CHECK-PRECISE: load <2 x float>, ptr
// CHECK-PRECISE: call <2 x float> @llvm.fmuladd.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> {{.*}})

// CHECK-STRICT: load <2 x float>, ptr
// CHECK-STRICT: load float, ptr
// CHECK-STRICT: call <2 x float> @llvm.experimental.constrained.fmul.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, {{.*}})
// CHECK-STRICT: load <2 x float>, ptr
// CHECK-STRICT: call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, {{.*}})

// CHECK-STRICT-FAST: load <2 x float>, ptr
// CHECK-STRICT-FAST: load float, ptr
// CHECK-STRICT-FAST: fmul fast <2 x float> {{.*}}, {{.*}}
// CHECK-STRICT-FAST: load <2 x float>, ptr
// CHECK-STRICT-FAST: fadd fast <2 x float> {{.*}}, {{.*}}

// CHECK-FAST1: load <2 x float>, ptr
// CHECK-FAST1: load float, ptr
// CHECK-FAST1: fmul fast <2 x float> {{.*}}, {{.*}}
// CHECK-FAST1: load <2 x float>, ptr {{.*}}
// CHECK-FAST1: fadd fast <2 x float> {{.*}}, {{.*}}
}

// 2x1 float matrix (Clang matrix_type); the run lines of this test pass
// -fenable-matrix so the attribute is accepted.
typedef float __attribute__((matrix_type(2, 1))) m21f;

// Computes matrix * scalar + matrix on a 2x1 matrix so the emitted IR can be
// inspected. The define line below expects the return type to appear as
// <2 x float> in IR, and the remaining expectations mirror the two-element
// vector case:
//   - -ffp-model=precise -> a single call to llvm.fmuladd.v2f32.
//   - -ffp-model=fast, or strict/precise combined with -ffast-math ->
//     separate `fmul fast` and `fadd fast` instructions.
//   - -ffp-model=strict -> constrained fmul and fadd intrinsic calls.
m21f my_m21_muladd(m21f x, float y, m21f z) {
// CHECK: define{{.*}} <2 x float> @my_m21_muladd
return x * y + z;

// CHECK-FAST: fmul fast <2 x float>
// CHECK-FAST: load <2 x float>, ptr
// CHECK-FAST: fadd fast <2 x float>

// CHECK-PRECISE: load <2 x float>, ptr
// CHECK-PRECISE: load float, ptr
// CHECK-PRECISE: load <2 x float>, ptr
// CHECK-PRECISE: call <2 x float> @llvm.fmuladd.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, <2 x float> {{.*}})

// CHECK-STRICT: load <2 x float>, ptr
// CHECK-STRICT: load float, ptr
// CHECK-STRICT: call <2 x float> @llvm.experimental.constrained.fmul.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, {{.*}})
// CHECK-STRICT: load <2 x float>, ptr
// CHECK-STRICT: call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> {{.*}}, <2 x float> {{.*}}, {{.*}})

// CHECK-STRICT-FAST: load <2 x float>, ptr
// CHECK-STRICT-FAST: load float, ptr
// CHECK-STRICT-FAST: fmul fast <2 x float> {{.*}}, {{.*}}
// CHECK-STRICT-FAST: load <2 x float>, ptr
// CHECK-STRICT-FAST: fadd fast <2 x float> {{.*}}, {{.*}}

// CHECK-FAST1: load <2 x float>, ptr
// CHECK-FAST1: load float, ptr
// CHECK-FAST1: fmul fast <2 x float> {{.*}}, {{.*}}
// CHECK-FAST1: load <2 x float>, ptr {{.*}}
// CHECK-FAST1: fadd fast <2 x float> {{.*}}, {{.*}}
}

// 2x2 float matrix (Clang matrix_type); the run lines of this test pass
// -fenable-matrix so the attribute is accepted.
typedef float __attribute__((matrix_type(2, 2))) m22f;

// Computes matrix * scalar + matrix on a 2x2 matrix so the emitted IR can be
// inspected. The define line below expects the return type to appear as
// <4 x float> in IR. Expectations by floating-point configuration:
//   - -ffp-model=precise -> a single call to llvm.fmuladd.v4f32.
//   - -ffp-model=fast, or strict/precise combined with -ffast-math ->
//     separate `fmul fast` and `fadd fast` instructions.
//   - -ffp-model=strict -> constrained fmul and fadd intrinsic calls.
m22f my_m22_muladd(m22f x, float y, m22f z) {
// CHECK: define{{.*}} <4 x float> @my_m22_muladd
return x * y + z;

// CHECK-FAST: fmul fast <4 x float>
// CHECK-FAST: load <4 x float>, ptr
// CHECK-FAST: fadd fast <4 x float>

// CHECK-PRECISE: load <4 x float>, ptr
// CHECK-PRECISE: load float, ptr
// CHECK-PRECISE: load <4 x float>, ptr
// CHECK-PRECISE: call <4 x float> @llvm.fmuladd.v4f32(<4 x float> {{.*}}, <4 x float> {{.*}}, <4 x float> {{.*}})

// CHECK-STRICT: load <4 x float>, ptr
// CHECK-STRICT: load float, ptr
// CHECK-STRICT: call <4 x float> @llvm.experimental.constrained.fmul.v4f32(<4 x float> {{.*}}, <4 x float> {{.*}}, {{.*}})
// CHECK-STRICT: load <4 x float>, ptr
// CHECK-STRICT: call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> {{.*}}, <4 x float> {{.*}}, {{.*}})

// CHECK-STRICT-FAST: load <4 x float>, ptr
// CHECK-STRICT-FAST: load float, ptr
// CHECK-STRICT-FAST: fmul fast <4 x float> {{.*}}, {{.*}}
// CHECK-STRICT-FAST: load <4 x float>, ptr
// CHECK-STRICT-FAST: fadd fast <4 x float> {{.*}}, {{.*}}

// CHECK-FAST1: load <4 x float>, ptr
// CHECK-FAST1: load float, ptr
// CHECK-FAST1: fmul fast <4 x float> {{.*}}, {{.*}}
// CHECK-FAST1: load <4 x float>, ptr {{.*}}
// CHECK-FAST1: fadd fast <4 x float> {{.*}}, {{.*}}
}

0 comments on commit c987f9d

Please sign in to comment.