[Matrix] Propagate and use shape info for binary operators.
This patch extends the current shape propagation and shape-aware
lowering to also support binary operators. These operators are uniform
with respect to their shape: the shape of the input operands is the
same as the shape of the result.

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70898
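
For illustration, a minimal sketch of the kind of IR this change targets, modeled on the new tests added below. The function name @uniform_shape_example and the choice of fadd are made up for this example; the transpose intrinsic and vector types mirror propagate-forward.ll. The 2 x 4 shape established by the transpose propagates to the fadd, which is then lowered column by column.

define <8 x double> @uniform_shape_example(<8 x double> %a, <8 x double> %b) {
entry:
  ; The transpose gives %t a known matrix shape; the fadd is uniform in shape,
  ; so the shape propagates to it and it is lowered per column rather than as
  ; one flat <8 x double> operation.
  %t = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
  %res = fadd <8 x double> %t, %b
  ret <8 x double> %res
}

declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)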
fhahn committed Dec 27, 2019
1 parent f072233 commit dc2c9b0
Showing 3 changed files with 171 additions and 8 deletions.
76 changes: 74 additions & 2 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -281,6 +281,24 @@ class LowerMatrixIntrinsics {
return true;
}

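/// Returns true if the shape of \p V's result matches the shape of its
/// operands, i.e. the operation does not change the matrix dimensions.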
bool isUniformShape(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I)
return true;

switch (I->getOpcode()) {
case Instruction::FAdd:
case Instruction::FSub:
case Instruction::FMul: // Scalar multiply.
case Instruction::Add:
case Instruction::Mul:
case Instruction::Sub:
return true;
default:
return false;
}
}

/// Returns true if shape information can be used for \p V. The supported
/// instructions must match the instructions that can be lowered by this pass.
bool supportsShapeInfo(Value *V) {
@@ -299,7 +317,7 @@ class LowerMatrixIntrinsics {
default:
return false;
}
return isa<StoreInst>(Inst);
return isUniformShape(V) || isa<StoreInst>(V);
}

/// Propagate the shape information of instructions to their users.
@@ -366,6 +384,15 @@ class LowerMatrixIntrinsics {
if (OpShape != ShapeMap.end())
setShapeInfo(Inst, OpShape->second);
continue;
} else if (isUniformShape(Inst)) {
// Find the first operand that has a known shape and use that.
for (auto &Op : Inst->operands()) {
auto OpShape = ShapeMap.find(Op.get());
if (OpShape != ShapeMap.end()) {
Propagate |= setShapeInfo(Inst, OpShape->second);
break;
}
}
}

if (Propagate)
@@ -390,7 +417,9 @@

Value *Op1;
Value *Op2;
if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
Changed |= VisitBinaryOperator(BinOp);
else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Changed |= VisitStore(&Inst, Op1, Op2, Builder);
}
}
@@ -673,6 +702,49 @@ class LowerMatrixIntrinsics {
LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second);
return true;
}

/// Lower binary operators, if shape information is available.
bool VisitBinaryOperator(BinaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
if (I == ShapeMap.end())
return false;

Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);

IRBuilder<> Builder(Inst);
ShapeInfo &Shape = I->second;

ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder);
ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder);

// Apply the operation to each pair of columns and collect the result columns.
ColumnMatrixTy Result;
auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) {
switch (Inst->getOpcode()) {
case Instruction::Add:
return Builder.CreateAdd(LHS, RHS);
case Instruction::Mul:
return Builder.CreateMul(LHS, RHS);
case Instruction::Sub:
return Builder.CreateSub(LHS, RHS);
case Instruction::FAdd:
return Builder.CreateFAdd(LHS, RHS);
case Instruction::FMul:
return Builder.CreateFMul(LHS, RHS);
case Instruction::FSub:
return Builder.CreateFSub(LHS, RHS);
default:
llvm_unreachable("Unsupported binary operator for matrix");
}
};
for (unsigned C = 0; C < Shape.NumColumns; ++C)
Result.addColumn(
BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));

finalizeLowering(Inst, Result, Builder);
return true;
}
};
} // namespace

@@ -462,15 +462,34 @@ define void @transpose_multiply_add(<9 x double>* %A.Ptr, <9 x double>* %B.Ptr,

; CHECK-NEXT: [[TMP106:%.*]] = shufflevector <1 x double> [[TMP105]], <1 x double> undef, <3 x i32> <i32 0, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP107:%.*]] = shufflevector <3 x double> [[TMP97]], <3 x double> [[TMP106]], <3 x i32> <i32 0, i32 1, i32 3>
; CHECK-NEXT: [[TMP108:%.*]] = shufflevector <3 x double> [[TMP47]], <3 x double> [[TMP77]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
; CHECK-NEXT: [[TMP109:%.*]] = shufflevector <3 x double> [[TMP107]], <3 x double> undef, <6 x i32> <i32 0, i32 1, i32 2, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[TMP110:%.*]] = shufflevector <6 x double> [[TMP108]], <6 x double> [[TMP109]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>

; Load %C and add result of multiply.
; Load %C.

; CHECK-NEXT: [[C:%.*]] = load <9 x double>, <9 x double>* [[C_PTR:%.*]]
; CHECK-NEXT: [[RES:%.*]] = fadd <9 x double> [[C]], [[TMP110]]
; CHECK-NEXT: store <9 x double> [[RES]], <9 x double>* [[C_PTR]]

; Extract columns from %C.

; CHECK-NEXT: [[SPLIT84:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 0, i32 1, i32 2>
; CHECK-NEXT: [[SPLIT85:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 3, i32 4, i32 5>
; CHECK-NEXT: [[SPLIT86:%.*]] = shufflevector <9 x double> [[C]], <9 x double> undef, <3 x i32> <i32 6, i32 7, i32 8>

; Add column vectors.

; CHECK-NEXT: [[TMP108:%.*]] = fadd <3 x double> [[SPLIT84]], [[TMP47]]
; CHECK-NEXT: [[TMP109:%.*]] = fadd <3 x double> [[SPLIT85]], [[TMP77]]
; CHECK-NEXT: [[TMP110:%.*]] = fadd <3 x double> [[SPLIT86]], [[TMP107]]

; Store result columns.

; CHECK-NEXT: [[TMP111:%.*]] = bitcast <9 x double>* [[C_PTR]] to double*
; CHECK-NEXT: [[TMP112:%.*]] = bitcast double* [[TMP111]] to <3 x double>*
; CHECK-NEXT: store <3 x double> [[TMP108]], <3 x double>* [[TMP112]], align 8
; CHECK-NEXT: [[TMP113:%.*]] = getelementptr double, double* [[TMP111]], i32 3
; CHECK-NEXT: [[TMP114:%.*]] = bitcast double* [[TMP113]] to <3 x double>*
; CHECK-NEXT: store <3 x double> [[TMP109]], <3 x double>* [[TMP114]], align 8
; CHECK-NEXT: [[TMP115:%.*]] = getelementptr double, double* [[TMP111]], i32 6
; CHECK-NEXT: [[TMP116:%.*]] = bitcast double* [[TMP115]] to <3 x double>*
; CHECK-NEXT: store <3 x double> [[TMP110]], <3 x double>* [[TMP116]], align 8
; CHECK-NEXT: ret void
;
entry:
72 changes: 72 additions & 0 deletions llvm/test/Transforms/LowerMatrixIntrinsics/propagate-forward.ll
@@ -42,3 +42,75 @@ entry:
}

declare <8 x double> @llvm.matrix.transpose(<8 x double>, i32, i32)

define <8 x double> @transpose_fadd(<8 x double> %a) {
; CHECK-LABEL: @transpose_fadd(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP16:%.*]] = fadd <4 x double> [[TMP7]], [[SPLIT4]]
; CHECK-NEXT: [[TMP17:%.*]] = fadd <4 x double> [[TMP15]], [[SPLIT5]]
; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: ret <8 x double> [[TMP18]]
;
entry:
%c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
%res = fadd <8 x double> %c, %a
ret <8 x double> %res
}

define <8 x double> @transpose_fmul(<8 x double> %a) {
; CHECK-LABEL: @transpose_fmul(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <8 x double> [[A:%.*]], <8 x double> undef, <2 x i32> <i32 0, i32 1>
; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 4, i32 5>
; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <2 x i32> <i32 6, i32 7>
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[SPLIT]], i64 0
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x double> undef, double [[TMP0]], i64 0
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 0
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x double> [[TMP1]], double [[TMP2]], i64 1
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 0
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x double> [[TMP3]], double [[TMP4]], i64 2
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 0
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x double> [[TMP5]], double [[TMP6]], i64 3
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x double> [[SPLIT]], i64 1
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x double> undef, double [[TMP8]], i64 0
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x double> [[SPLIT1]], i64 1
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <4 x double> [[TMP9]], double [[TMP10]], i64 1
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x double> [[SPLIT2]], i64 1
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <4 x double> [[TMP11]], double [[TMP12]], i64 2
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x double> [[SPLIT3]], i64 1
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <4 x double> [[TMP13]], double [[TMP14]], i64 3
; CHECK-NEXT: [[SPLIT4:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <8 x double> [[A]], <8 x double> undef, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP16:%.*]] = fmul <4 x double> [[TMP7]], [[SPLIT4]]
; CHECK-NEXT: [[TMP17:%.*]] = fmul <4 x double> [[TMP15]], [[SPLIT5]]
; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <4 x double> [[TMP16]], <4 x double> [[TMP17]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: ret <8 x double> [[TMP18]]
;
entry:
%c = call <8 x double> @llvm.matrix.transpose(<8 x double> %a, i32 2, i32 4)
%res = fmul <8 x double> %c, %a
ret <8 x double> %res
}
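
As a usage note (an assumption, not shown in this excerpt): tests in this directory are typically driven by an opt RUN line at the top of the file, along the lines of the one sketched below.

; RUN: opt -lower-matrix-intrinsics -S < %s | FileCheck %s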
