Skip to content

Commit

Permalink
[Matrix] Move multiply-add code generation into separate function (NFC).
Browse files Browse the repository at this point in the history
This logic can be shared with the tiled code generation.

Reviewers: anemet, Gerolf, hfinkel, andrew.w.kaylor, LuoYuanke

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D75565
  • Loading branch information
fhahn committed Mar 19, 2020
1 parent e23d786 commit 796fb2e
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Expand Up @@ -181,6 +181,10 @@ class LowerMatrixIntrinsics {

/// Replace the vector stored for column \p i with \p V.
void setColumn(unsigned i, Value *V) {
  Columns[i] = V;
}

/// Returns the element type of this matrix, read from the vector type of the
/// first column. At least one column must be present.
///
/// Const-qualified so it can be called through const references, matching
/// the other read-only accessors (getNumColumns, getNumRows).
Type *getElementType() const {
  assert(Columns.size() > 0 &&
         "Cannot call getElementType without columns");
  return cast<VectorType>(Columns[0]->getType())->getElementType();
}

/// Returns how many column vectors this matrix currently holds.
unsigned getNumColumns() const {
  const unsigned NumCols = Columns.size();
  return NumCols;
}
unsigned getNumRows() const {
assert(Columns.size() > 0 && "Cannot call getNumRows without columns");
Expand Down Expand Up @@ -848,6 +852,49 @@ class LowerMatrixIntrinsics {
}
}

/// Emit the multiply-add chain computing the product of \p A and \p B into
/// \p Result, one column of \p Result at a time, with left-associating
/// addition.
///
/// When \p isTiled is true, each block of \p Result is first extracted and
/// used as the initial accumulator (Result += A * B); otherwise accumulation
/// starts from nothing (Result = A * B). Rows are processed in blocks whose
/// width starts at the vectorization factor derived from the register bit
/// width and is halved until the block fits into the remaining rows.
///
/// \p AllowContraction is forwarded to createMulAdd. The number of compute
/// ops reported by createMulAdd is added to \p Result once per column.
void emitChainedMatrixMultiply(ColumnMatrixTy &Result,
                               const ColumnMatrixTy &A,
                               const ColumnMatrixTy &B, bool AllowContraction,
                               IRBuilder<> &Builder, bool isTiled) {
  // The element type does not change while columns are being rewritten, so
  // look it up once.
  Type *const EltTy = Result.getElementType();
  const bool IsFloatingPoint = EltTy->isFloatingPointTy();

  // Number of elements per vector register, but never less than one.
  const unsigned VF = std::max<unsigned>(
      TTI.getRegisterBitWidth(true) /
          EltTy->getPrimitiveSizeInBits().getFixedSize(),
      1U);

  const unsigned NumRows = Result.getNumRows();
  const unsigned NumCols = Result.getNumColumns();
  const unsigned NumInner = A.getNumColumns();

  for (unsigned Col = 0; Col != NumCols; ++Col) {
    // A column that is all zeros before any block is written needs no add in
    // the first step of each chain.
    const bool ColumnStartsZero =
        isa<ConstantAggregateZero>(Result.getColumn(Col));

    unsigned NumOps = 0;
    unsigned Width = VF;
    for (unsigned Row = 0; Row < NumRows; Row += Width) {
      // Shrink the block until it no longer runs past the last row.
      while (Row + Width > NumRows)
        Width /= 2;

      // Tiled lowering continues from the values already in Result.
      Value *Acc = nullptr;
      if (isTiled)
        Acc = extractVector(Result, Row, Col, Width, Builder);

      for (unsigned Idx = 0; Idx < NumInner; ++Idx) {
        Value *LhsBlock = extractVector(A, Row, Idx, Width, Builder);
        Value *RhsScalar = Builder.CreateExtractElement(B.getColumn(Col), Idx);
        Value *RhsSplat = Builder.CreateVectorSplat(Width, RhsScalar, "splat");

        // Drop the accumulator for the first product of a zero column.
        Value *Addend = Acc;
        if (ColumnStartsZero && Idx == 0)
          Addend = nullptr;

        Acc = createMulAdd(Addend, LhsBlock, RhsSplat, IsFloatingPoint,
                           Builder, AllowContraction, NumOps);
      }

      Value *UpdatedColumn =
          insertVector(Result.getColumn(Col), Row, Acc, Builder);
      Result.setColumn(Col, UpdatedColumn);
    }

    Result.addNumComputeOps(NumOps);
  }
}

/// Lowers llvm.matrix.multiply.
void LowerMultiply(CallInst *MatMul) {
IRBuilder<> Builder(MatMul);
Expand All @@ -870,35 +917,9 @@ class LowerMatrixIntrinsics {
for (unsigned J = 0; J < C; ++J)
Result.addColumn(UndefValue::get(VectorType::get(EltType, R)));

const unsigned VF = std::max(TTI.getRegisterBitWidth(true) /
EltType->getPrimitiveSizeInBits(),
uint64_t(1));

bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
MatMul->hasAllowContract());
unsigned NumComputeOps = 0;
// Multiply columns from the first operand with scalars from the second
// operand. Then move along the K axes and accumulate the columns. With
// this the adds can be vectorized without reassociation.
for (unsigned J = 0; J < C; ++J) {
unsigned BlockSize = VF;
for (unsigned I = 0; I < R; I += BlockSize) {
// Gradually lower the vectorization factor to cover the remainder.
while (I + BlockSize > R)
BlockSize /= 2;

Value *Sum = nullptr;
for (unsigned K = 0; K < M; ++K) {
Value *L = extractVector(Lhs, I, K, BlockSize, Builder);
Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K);
Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(),
Builder, AllowContract, NumComputeOps);
}
Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder));
}
}
Result.addNumComputeOps(NumComputeOps);
emitChainedMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
finalizeLowering(MatMul, Result, Builder);
}

Expand Down

0 comments on commit 796fb2e

Please sign in to comment.