Skip to content

Commit

Permalink
[Matrix] Factor and distribute transposes across multiplies
Browse files Browse the repository at this point in the history
Now that we can fold some transposes into multiplies (CM: A * B^t and RM:
A^t * B), we want to move them around to create the optimal expressions:

* fold away double transposes while still using them to assert the shape
* sink transposes hoping they cancel out
* lift transposes when both operands are transposed

This also modifies the matrix remarks to include the number of exposed
transposes (i.e. transposes that we couldn't fold into a multiply).

The adjustment to the test remarks-inlining is a bit subtle: I am changing the
double transpose to a single transpose so that we don't remove it completely.
More importantly, this changes some of the total instruction counts, most
notably stores, because we can no longer use a vector store.

Differential Revision: https://reviews.llvm.org/D102733
  • Loading branch information
anemet committed May 25, 2021
1 parent 2ea6e13 commit dfd1bbd
Show file tree
Hide file tree
Showing 4 changed files with 1,105 additions and 22 deletions.
145 changes: 140 additions & 5 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
Expand Down Expand Up @@ -79,6 +80,9 @@ static cl::opt<MatrixLayoutTy> MatrixLayout(
clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
"Use row-major layout")));

/// Debugging aid: dump the function IR after the transpose optimization has
/// run. Hidden, like other internal print flags.
static cl::opt<bool> PrintAfterTransposeOpt(
    "matrix-print-after-transpose-opt", cl::Hidden, cl::init(false),
    cl::desc("Dump the function after moving/folding matrix transposes"));

/// Helper function to either return Scope, if it is a subprogram or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
Expand Down Expand Up @@ -195,11 +199,16 @@ class LowerMatrixIntrinsics {
unsigned NumLoads = 0;
/// Number of compute operations emitted to generate this matrix.
unsigned NumComputeOps = 0;
/// Most of the time transposes can be fused with matrix multiplies or can
/// be folded away via algebraic simplifications. This is the number of
/// transposes that we failed to make "free" via such optimizations.
unsigned NumExposedTransposes = 0;

/// Accumulate every per-kind operation count from \p RHS into this object
/// and return *this so totals can be chained.
OpInfoTy &operator+=(const OpInfoTy &RHS) {
  NumExposedTransposes += RHS.NumExposedTransposes;
  NumComputeOps += RHS.NumComputeOps;
  NumLoads += RHS.NumLoads;
  NumStores += RHS.NumStores;
  return *this;
}
};
Expand Down Expand Up @@ -304,6 +313,11 @@ class LowerMatrixIntrinsics {
return *this;
}

/// Record \p N additional transposes that could not be folded or fused away;
/// returns *this for chaining with the other addNum* helpers.
MatrixTy &addNumExposedTransposes(unsigned N) {
  OpInfo.NumExposedTransposes = OpInfo.NumExposedTransposes + N;
  return *this;
}

MatrixTy &addNumComputeOps(unsigned N) {
OpInfo.NumComputeOps += N;
return *this;
Expand Down Expand Up @@ -379,8 +393,10 @@ class LowerMatrixIntrinsics {
/// the result value of the instruction, with the only exceptions being store
/// instructions and the matrix_column_major_store intrinsics. For those, the
/// shape information indicates that those instructions should be lowered
/// using shape information as well.
DenseMap<Value *, ShapeInfo> ShapeMap;
/// using shape information as well. A ValueMap is used so that when
/// sub-passes like optimizeTransposes performs RAUW the map stays
/// up-to-date.
ValueMap<Value *, ShapeInfo> ShapeMap;

/// List of instructions to remove. While lowering, we are not replacing all
/// users of a lowered instruction, if shape information is available and
Expand All @@ -403,7 +419,11 @@ class LowerMatrixIntrinsics {
cast<FixedVectorType>(VT)->getNumElements());
}

/// Is this the minimal version executed in the backend pipelines?
/// The minimal configuration is identified by the dominator tree not being
/// available (DT is null).
bool isMinimal() const { return !DT; }

/// Return the estimated number of vector ops required for an operation on
/// \p VT * N.
unsigned getNumOps(Type *ST, unsigned N) {
Expand Down Expand Up @@ -654,6 +674,110 @@ class LowerMatrixIntrinsics {
return NewWorkList;
}

/// Try moving transposes in order to fold them away or into multiplies.
///
/// Runs in two phases over the function:
///  1. Sink transposes below multiplies, t(A*B) -> t(B)*t(A), walking the
///     blocks/instructions in reverse so newly created transposes are seen
///     again and may cancel (t(t(X)) -> X) or later fuse into a multiply.
///  2. Lift the transpose out of TT multiplies, t(A)*t(B) -> t(B*A), so the
///     single remaining transpose may fold into a consuming multiply.
void optimizeTransposes() {
  // First sink all transposes inside matmuls, hoping that we end up with NN,
  // NT or TN variants.
  for (BasicBlock &BB : reverse(Func)) {
    for (auto II = BB.rbegin(); II != BB.rend();) {
      Instruction &I = *II;
      // We may remove II. By default continue on the next/prev instruction.
      ++II;
      // Erase \p V if it became dead. If V is the instruction the (already
      // advanced) iterator points at, step over it first so II stays valid.
      // The rend() check avoids dereferencing the end iterator when I was
      // the first instruction of BB.
      auto EraseFromParent = [&II, &BB](Value *V) {
        auto *Inst = cast<Instruction>(V);
        if (Inst->use_empty()) {
          if (II != BB.rend() && Inst == &*II) {
            ++II;
          }
          Inst->eraseFromParent();
        }
      };

      // If we're creating a new instruction, continue from there.
      Instruction *NewInst = nullptr;

      IRBuilder<> IB(&I);
      MatrixBuilder<IRBuilder<>> Builder(IB);

      Value *TA, *TAMA, *TAMB;
      ConstantInt *R, *K, *C;
      if (match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TA)))) {

        // Transpose of a transpose is a nop.
        Value *TATA;
        if (match(TA,
                  m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
          I.replaceAllUsesWith(TATA);
          EraseFromParent(&I);
          EraseFromParent(TA);
        }

        // (A * B)^t -> B^t * A^t
        // RxK KxC CxK KxR
        else if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                           m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                           m_ConstantInt(K), m_ConstantInt(C)))) {
          // TAMB is KxC, so its transpose T0 is CxK. The dimensions passed
          // to CreateMatrixTranspose describe the *input* matrix.
          Value *T0 = Builder.CreateMatrixTranspose(TAMB, K->getZExtValue(),
                                                    C->getZExtValue(),
                                                    TAMB->getName() + "_t");
          // We are being run after shape prop, add shape for newly created
          // instructions so that we lower them later.
          setShapeInfo(T0, {C, K});
          // TAMA is RxK, so its transpose T1 is KxR.
          Value *T1 = Builder.CreateMatrixTranspose(TAMA, R->getZExtValue(),
                                                    K->getZExtValue(),
                                                    TAMA->getName() + "_t");
          setShapeInfo(T1, {K, R});
          // T0 * T1 == (CxK) * (KxR) == CxR, the shape of t(A*B).
          NewInst = Builder.CreateMatrixMultiply(T0, T1, C->getZExtValue(),
                                                 K->getZExtValue(),
                                                 R->getZExtValue(), "mmul");
          setShapeInfo(NewInst, {C, R});
          I.replaceAllUsesWith(NewInst);
          EraseFromParent(&I);
          EraseFromParent(TA);
        }
      }

      // If we replaced I with a new instruction, continue from there.
      if (NewInst)
        II = std::next(BasicBlock::reverse_iterator(NewInst));
    }
  }

  // If we have a TT matmul, lift the transpose. We may be able to fold into
  // consuming multiply.
  for (BasicBlock &BB : Func) {
    for (BasicBlock::iterator II = BB.begin(); II != BB.end();) {
      Instruction *I = &*II;
      // We may remove I.
      ++II;
      Value *A, *B, *AT, *BT;
      ConstantInt *R, *K, *C;
      // Match t(AT) * t(BT) where the product is RxC, i.e. AT is KxR and
      // BT is CxK.
      if (match(&*I, m_Intrinsic<Intrinsic::matrix_multiply>(
                         m_Value(A), m_Value(B), m_ConstantInt(R),
                         m_ConstantInt(K), m_ConstantInt(C))) &&
          match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
          match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
        IRBuilder<> IB(&*I);
        MatrixBuilder<IRBuilder<>> Builder(IB);
        // BT * AT == (CxK) * (KxR) == CxR.
        Value *M = Builder.CreateMatrixMultiply(
            BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
        setShapeInfo(M, {C, R});
        // Transpose the CxR product back to RxC, the shape of the original
        // multiply I. CreateMatrixTranspose takes the dimensions of its
        // input M (CxR), and the replacement must carry I's RxC shape.
        Value *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                                       R->getZExtValue());
        setShapeInfo(NewInst, {R, C});
        I->replaceAllUsesWith(NewInst);
        if (I->use_empty())
          I->eraseFromParent();
        if (A->use_empty())
          cast<Instruction>(A)->eraseFromParent();
        if (B->use_empty())
          cast<Instruction>(B)->eraseFromParent();
      }
    }
  }
}

bool Visit() {
SmallVector<Instruction *, 32> WorkList;

Expand Down Expand Up @@ -687,6 +811,14 @@ class LowerMatrixIntrinsics {
WorkList = propagateShapeBackward(WorkList);
}

if (!isMinimal()) {
optimizeTransposes();
if (PrintAfterTransposeOpt) {
dbgs() << "Dump after matrix transpose optimization:\n";
Func.dump();
}
}

bool Changed = false;
SmallVector<CallInst *, 16> MaybeFusableInsts;
SmallVector<Instruction *, 16> MatrixInsts;
Expand Down Expand Up @@ -1488,7 +1620,8 @@ class LowerMatrixIntrinsics {
// account for later simplifications/combines.
finalizeLowering(
Inst,
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns)
.addNumExposedTransposes(1),
Builder);
}

Expand Down Expand Up @@ -2003,7 +2136,9 @@ class LowerMatrixIntrinsics {
Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
<< ore::NV("NumLoads", Counts.NumLoads) << " loads, "
<< ore::NV("NumComputeOps", Counts.NumComputeOps)
<< " compute ops";
<< " compute ops, "
<< ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
<< " exposed transposes";

if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
SharedCounts.NumComputeOps > 0) {
Expand Down
29 changes: 14 additions & 15 deletions llvm/test/Transforms/LowerMatrixIntrinsics/remarks-inlining.ll
Original file line number Diff line number Diff line change
Expand Up @@ -47,50 +47,50 @@
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
target triple = "aarch64-apple-ios"

; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops
; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: load(addr %A)

; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops
; CHECK-LABEL: remark: load.h:41:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: column.major.load.3x5.double(addr %B, 5)

; CHECK-LABEL: remark: load.h:41:11: Lowered with 0 stores, 1 loads, 0 compute ops
; CHECK-LABEL: remark: load.h:41:11: Lowered with 0 stores, 1 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: load(addr %D)

; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops
; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: load(addr %A)

; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops
; CHECK-LABEL: remark: assign.h:32:43: Lowered with 0 stores, 10 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: column.major.load.3x5.double(addr %B, 5)

; CHECK-LABEL: remark: toplevel.c:410:0: Lowered with 10 stores, 20 loads, 10 compute ops
; CHECK-LABEL: remark: toplevel.c:410:0: Lowered with 10 stores, 20 loads, 10 compute ops, 0 exposed transposes
; CHECK-NEXT: store(
; CHECK-NEXT: fadd(
; CHECK-NEXT: load(addr %A),
; CHECK-NEXT: column.major.load.3x5.double(addr %B, 5)),
; CHECK-NEXT: addr %C)

; CHECK-LABEL: remark: toplevel.c:510:0: Lowered with 1 stores, 1 loads, 8 compute ops
; CHECK-LABEL: remark: toplevel.c:510:0: Lowered with 2 stores, 1 loads, 4 compute ops, 1 exposed transposes
; CHECK-NEXT: store(
; CHECK-NEXT: transpose.1x2.float(transpose.2x1.float(load(addr %D))),
; CHECK-NEXT: transpose.2x1.float(load(addr %D)),
; CHECK-NEXT: addr %D)

; CHECK-LABEL: remark: add.h:66:11: Lowered with 0 stores, 0 loads, 10 compute ops
; CHECK-LABEL: remark: add.h:66:11: Lowered with 0 stores, 0 loads, 10 compute ops, 0 exposed transposes
; CHECK-NEXT: fadd(
; CHECK-NEXT: addr %A,
; CHECK-NEXT: scalar)

; CHECK-LABEL: remark: store.h:10:11: Lowered with 10 stores, 0 loads, 0 compute ops
; CHECK-LABEL: remark: store.h:10:11: Lowered with 10 stores, 0 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: store(
; CHECK-NEXT: scalar,
; CHECK-NEXT: addr %C)

; CHECK-LABEL: remark: store.h:66:11: Lowered with 1 stores, 0 loads, 0 compute ops
; CHECK-LABEL: remark: store.h:66:11: Lowered with 2 stores, 0 loads, 0 compute ops, 0 exposed transposes
; CHECK-NEXT: store(
; CHECK-NEXT: scalar,
; CHECK-NEXT: addr %D)

; CHECK-LABEL: remark: transpose.h:13:11: Lowered with 0 stores, 0 loads, 8 compute ops
; CHECK-NEXT: transpose.1x2.float(transpose.2x1.float(addr %D))
; CHECK-LABEL: remark: transpose.h:13:11: Lowered with 0 stores, 0 loads, 4 compute ops, 1 exposed transposes
; CHECK-NEXT: transpose.2x1.float(addr %D)

define void @toplevel(<15 x double>* %A, double* %B, <15 x double>* %C, <2 x float>* %D) !dbg !16 {
entry:
Expand All @@ -101,8 +101,7 @@ entry:

%load = load <2 x float>, <2 x float>* %D, !dbg !104
%t1 = call <2 x float> @llvm.matrix.transpose(<2 x float> %load, i32 2, i32 1), !dbg !106
%t2 = call <2 x float> @llvm.matrix.transpose(<2 x float> %t1, i32 1, i32 2), !dbg !106
store <2 x float> %t2, <2 x float>* %D, !dbg !108
store <2 x float> %t1, <2 x float>* %D, !dbg !108
ret void
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
; YAML-NEXT: - NumLoads: '0'
; YAML-NEXT: - String: ' loads, '
; YAML-NEXT: - NumComputeOps: '0'
; YAML-NEXT: - String: ' compute ops'
; YAML-NEXT: - String: ' compute ops, '
; YAML-NEXT: - NumExposedTransposes: '0'
; YAML-NEXT: - String: ' exposed transposes'
; YAML-NEXT: - String: ",\nadditionally "
; YAML-NEXT: - NumStores: '0'
; YAML-NEXT: - String: ' stores, '
Expand Down Expand Up @@ -45,7 +47,9 @@
; YAML-NEXT: - NumLoads: '45'
; YAML-NEXT: - String: ' loads, '
; YAML-NEXT: - NumComputeOps: '120'
; YAML-NEXT: - String: ' compute ops'
; YAML-NEXT: - String: ' compute ops, '
; YAML-NEXT: - NumExposedTransposes: '0'
; YAML-NEXT: - String: ' exposed transposes'
; YAML-NEXT: - String: ",\nadditionally "
; YAML-NEXT: - NumStores: '0'
; YAML-NEXT: - String: ' stores, '
Expand Down
Loading

0 comments on commit dfd1bbd

Please sign in to comment.