Skip to content

Commit

Permalink
[MatMul] Make MatMul detection independent of internal isl representa…
Browse files Browse the repository at this point in the history
…tions.

The pattern recognition for MatMul is restrictive.

The number of "disjuncts" in the isl_map containing constraint
information was previously required to be 1
(as per isl_*_coalesce - which should ideally produce a domain map with
a single disjunct, but does not under some circumstances).

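This requirement has been dropped: an access relation is now recognized by
comparing it, restricted to the statement domain, against each candidate
MatMul access pattern, independent of how many disjuncts isl uses to
represent it.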

Contributed-by: Annanay Agarwal <cs14btech11001@iith.ac.in>

Differential Revision: https://reviews.llvm.org/D36460

llvm-svn: 311302
  • Loading branch information
Meinersbur committed Aug 20, 2017
1 parent d6491f2 commit d091bf8
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 86 deletions.
133 changes: 47 additions & 86 deletions polly/lib/Transform/ScheduleOptimizer.cpp
Expand Up @@ -483,61 +483,6 @@ ScheduleTreeOptimizer::standardBandOpts(isl::schedule_node Node, void *User) {
return Node;
}

/// Find the single dimension of a given type that a constraint refers to.
///
/// Walk over all dimensions of type @p DimType in the local space of
/// @p Constraint and inspect their coefficients. The search succeeds only if
/// exactly one coefficient is non-zero and that coefficient has the expected
/// value: one for isl::dim::out and negative one for isl::dim::in.
///
/// @param Constraint The isl constraint to be inspected.
/// @param DimType    The type of the dimensions to look at.
/// @return The position of the only dimension with a non-zero coefficient,
///         or a negative value if there is no such dimension, more than one
///         such dimension, or the coefficient has an unexpected value.
static int getMatMulConstraintDim(isl::constraint Constraint,
                                  isl::dim DimType) {
  const int NumDims = Constraint.get_local_space().dim(DimType);
  int FoundPos = -1;
  for (int Pos = 0; Pos < NumDims; Pos++) {
    auto Coeff = Constraint.get_coefficient_val(DimType, Pos);
    if (Coeff.is_zero())
      continue;
    // A second non-zero coefficient disqualifies the constraint.
    if (FoundPos >= 0)
      return -1;
    // Output dimensions must appear with coefficient one ...
    if (DimType == isl::dim::out && !Coeff.is_one())
      return -1;
    // ... and input dimensions with coefficient negative one.
    if (DimType == isl::dim::in && !Coeff.is_negone())
      return -1;
    FoundPos = Pos;
  }
  return FoundPos;
}

/// Check whether an isl constraint equates one input and one output dimension.
///
/// The constraint @p Constraint is accepted if it is an equality with a zero
/// constant term in which exactly one input dimension has a coefficient of
/// negative one, exactly one output dimension has a coefficient of one, and
/// all other input and output coefficients are zero.
///
/// @param Constraint The isl constraint to be checked.
/// @param DimInPos   Set to the position of the input dimension found (or to
///                   a negative value if the input dimensions do not have the
///                   required form).
/// @param DimOutPos  Set to the position of the output dimension found (or to
///                   a negative value if the output dimensions do not have
///                   the required form). Left untouched if the check fails
///                   before the output dimensions are inspected.
/// @return isl_stat_ok in case the isl constraint satisfies the
///         requirements, isl_stat_error otherwise.
static isl_stat isMatMulOperandConstraint(isl::constraint Constraint,
                                          int &DimInPos, int &DimOutPos) {
  auto ConstantTerm = Constraint.get_constant_val();
  // Only equalities without a constant offset are of interest.
  if (!isl_constraint_is_equality(Constraint.get()) || !ConstantTerm.is_zero())
    return isl_stat_error;

  DimInPos = getMatMulConstraintDim(Constraint, isl::dim::in);
  if (DimInPos < 0)
    return isl_stat_error;

  DimOutPos = getMatMulConstraintDim(Constraint, isl::dim::out);
  return DimOutPos < 0 ? isl_stat_error : isl_stat_ok;
}

/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
Expand Down Expand Up @@ -585,30 +530,49 @@ isl::map permuteDimensions(isl::map Map, isl::dim DimType, unsigned DstPos,
/// second output dimension.
/// @return True in case @p AccMap has the expected form and false,
/// otherwise.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
                               int &SecondPos) {
  // Purpose: decide whether AccMap, restricted to Domain, accesses a
  // two-dimensional array with subscripts that are two distinct dimensions
  // out of the first three input dimensions.
  //
  // Domain     The set the access relation is restricted to before comparing.
  // AccMap     The access relation to be checked.
  // FirstPos   In/out: -1 if the input dimension used as first subscript is
  //            still unknown, otherwise the position it is required to have.
  // SecondPos  In/out: same as FirstPos, for the second subscript.
  //
  // Returns true and stores the matched positions in FirstPos and SecondPos
  // if a candidate compatible with the incoming positions matches. Returns
  // false and leaves both FirstPos and SecondPos untouched otherwise.
  isl::space Space = AccMap.get_space();

  // Operands of the matrix multiplication are two-dimensional arrays.
  if (Space.dim(isl::dim::out) != 2)
    return false;

  // The candidate patterns below refer to input dimensions 0, 1 and 2, so an
  // access relation with fewer input dimensions cannot match any of them.
  if (Space.dim(isl::dim::in) < 3)
    return false;

  isl::map Universe = isl::map::universe(Space);

  // Restricting the access relation to the domain does not depend on the
  // candidate, so it is done once up front.
  AccMap = AccMap.intersect_domain(Domain);

  // MatMul has the form:
  // for (i = 0; i < N; i++)
  //   for (j = 0; j < M; j++)
  //     for (k = 0; k < P; k++)
  //       C[i, j] += A[i, k] * B[k, j]
  //
  // Permutation of three outer loops: 3! = 6 possibilities.
  const int FirstDims[] = {0, 0, 1, 1, 2, 2};
  const int SecondDims[] = {1, 2, 2, 0, 0, 1};
  for (int i = 0; i < 6; i += 1) {
    // Candidates that contradict an already known position can never be
    // accepted. Checking this before touching FirstPos or SecondPos
    // guarantees that a failed match does not modify either of them.
    if (FirstPos != -1 && FirstPos != FirstDims[i])
      continue;
    if (SecondPos != -1 && SecondPos != SecondDims[i])
      continue;

    auto PossibleMatMul =
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1)
            .intersect_domain(Domain);

    // Both maps have been intersected with the domain at this point. If they
    // differ, the access either has a different form or does not cover the
    // entire domain (a partial access); in both cases it is rejected.
    if (!AccMap.is_equal(PossibleMatMul))
      continue;

    FirstPos = FirstDims[i];
    SecondPos = SecondDims[i];
    return true;
  }

  return false;
}

/// Does the memory access represent a non-scalar operand of the matrix
Expand All @@ -627,18 +591,16 @@ static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
return false;
auto AccMap = MemAccess->getLatestAccessRelation();
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.j) && !MMI.ReadFromC &&
isl_map_n_basic_map(AccMap.get()) == 1) {
isl::set StmtDomain = MemAccess->getStatement()->getDomain();
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
MMI.ReadFromC = MemAccess;
return true;
}
if (isMatMulOperandAcc(AccMap, MMI.i, MMI.k) && !MMI.A &&
isl_map_n_basic_map(AccMap.get()) == 1) {
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
MMI.A = MemAccess;
return true;
}
if (isMatMulOperandAcc(AccMap, MMI.k, MMI.j) && !MMI.B &&
isl_map_n_basic_map(AccMap.get()) == 1) {
if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
MMI.B = MemAccess;
return true;
}
Expand Down Expand Up @@ -758,8 +720,7 @@ static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
if (!MemAccessPtr->isWrite())
return false;
auto AccMap = MemAccessPtr->getLatestAccessRelation();
if (isl_map_n_basic_map(AccMap.get()) != 1 ||
!isMatMulOperandAcc(AccMap, MMI.i, MMI.j))
if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
return false;
MMI.WriteToC = MemAccessPtr;
break;
Expand Down
@@ -0,0 +1,59 @@
; RUN: opt %loadPolly -polly-import-jscop -polly-import-jscop-postfix=transformed -polly-opt-isl -debug-only=polly-opt-isl -disable-output < %s 2>&1 | FileCheck %s
; REQUIRES: asserts
;
; void pattern_matching_based_opts_splitmap(double C[static const restrict 2][2], double A[static const restrict 2][784], double B[static const restrict 784][2]) {
; for (int i = 0; i < 2; i+=1)
; for (int j = 0; j < 2; j+=1)
; for (int k = 0; k < 784; k+=1)
; C[i][j] += A[i][k] * B[k][j];
;}
;
; Check that the pattern matching detects the matrix multiplication pattern
; when the AccMap cannot be reduced to a single disjunct.
;
; CHECK: The matrix multiplication pattern was detected
;
; ModuleID = 'pattern_matching_based_opts_splitmap.ll'
;
; Function Attrs: noinline nounwind uwtable
; Triple loop nest over %i, %j and %k that loads one element each from %A, %B
; and %C, multiplies the first two, adds the product to the third and stores
; the sum back to the element of %C it was loaded from.
define void @pattern_matching_based_opts_splitmap([2 x double]* noalias dereferenceable(32) %C, [784 x double]* noalias dereferenceable(12544) %A, [2 x double]* noalias dereferenceable(12544) %B) {
entry:
br label %for.body

; Header of the outermost loop; %i starts at 0 and is incremented in
; %for.inc21.
for.body: ; preds = %entry, %for.inc21
%i = phi i64 [ 0, %entry ], [ %add22, %for.inc21 ]
br label %for.body3

; Header of the middle loop; %j starts at 0 and is incremented in %for.inc18.
for.body3: ; preds = %for.body, %for.inc18
%j = phi i64 [ 0, %for.body ], [ %add19, %for.inc18 ]
br label %for.body6

; Innermost loop; header, body and latch are the same block.
for.body6: ; preds = %for.body3, %for.body6
%k = phi i64 [ 0, %for.body3 ], [ %add17, %for.body6 ]
; Load the element of %A indexed by (%i, %k).
%arrayidx8 = getelementptr inbounds [784 x double], [784 x double]* %A, i64 %i, i64 %k
%tmp6 = load double, double* %arrayidx8, align 8
; Load the element of %B indexed by (%k, %j).
%arrayidx12 = getelementptr inbounds [2 x double], [2 x double]* %B, i64 %k, i64 %j
%tmp10 = load double, double* %arrayidx12, align 8
%mul = fmul double %tmp6, %tmp10
; Read-modify-write of the element of %C indexed by (%i, %j).
%arrayidx16 = getelementptr inbounds [2 x double], [2 x double]* %C, i64 %i, i64 %j
%tmp14 = load double, double* %arrayidx16, align 8
%add = fadd double %tmp14, %mul
store double %add, double* %arrayidx16, align 8
; Stay in the innermost loop while %k + 1 < 784.
%add17 = add nsw i64 %k, 1
%cmp5 = icmp slt i64 %add17, 784
br i1 %cmp5, label %for.body6, label %for.inc18

; Latch of the middle loop: continue while %j + 1 < 2.
for.inc18: ; preds = %for.body6
%add19 = add nsw i64 %j, 1
%cmp2 = icmp slt i64 %add19, 2
br i1 %cmp2, label %for.body3, label %for.inc21

; Latch of the outermost loop: continue while %i + 1 < 2.
for.inc21: ; preds = %for.inc18
%add22 = add nsw i64 %i, 1
%cmp = icmp slt i64 %add22, 2
br i1 %cmp, label %for.body, label %for.end23

for.end23: ; preds = %for.inc21
ret void
}

@@ -0,0 +1,46 @@
{
"arrays" : [
{
"name" : "MemRef_A",
"sizes" : [ "*", "784" ],
"type" : "double"
},
{
"name" : "MemRef_B",
"sizes" : [ "*", "2" ],
"type" : "double"
},
{
"name" : "MemRef_C",
"sizes" : [ "*", "2" ],
"type" : "double"
}
],
"context" : "{ : }",
"name" : "%for.body---%for.end23",
"statements" : [
{
"accesses" : [
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
},
{
"kind" : "write",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
}
],
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
"name" : "Stmt_for_body6",
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
}
]
}
@@ -0,0 +1,46 @@
{
"arrays" : [
{
"name" : "MemRef_A",
"sizes" : [ "*", "784" ],
"type" : "double"
},
{
"name" : "MemRef_B",
"sizes" : [ "*", "2" ],
"type" : "double"
},
{
"name" : "MemRef_C",
"sizes" : [ "*", "2" ],
"type" : "double"
}
],
"context" : "{ : }",
"name" : "%for.body---%for.end23",
"statements" : [
{
"accesses" : [
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_A[i0, i2] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_B[i2, i1] }"
},
{
"kind" : "read",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] }"
},
{
"kind" : "write",
"relation" : "{ Stmt_for_body6[i0, i1, i2] -> MemRef_C[i0, i1] : i2 <= 784 - i0 - i1; Stmt_for_body6[1, 1, 783] -> MemRef_C[1, 1] }"
}
],
"domain" : "{ Stmt_for_body6[i0, i1, i2] : 0 <= i0 <= 1 and 0 <= i1 <= 1 and 0 <= i2 <= 783 }",
"name" : "Stmt_for_body6",
"schedule" : "{ Stmt_for_body6[i0, i1, i2] -> [i0, i1, i2] }"
}
]
}

0 comments on commit d091bf8

Please sign in to comment.