Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#43 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
fix compiler complaints
  • Loading branch information
tc20042008 committed Mar 10, 2024
2 parents 8fc1551 + f59d49c commit 97735a1
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 151 deletions.
53 changes: 27 additions & 26 deletions paddle/cinn/api/op_topo_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,36 @@ struct PartialShardablePattern {};
// Reduce base pattern
template <typename T>
struct ReductionPattern {
explicit ReductionPattern(const ReductionPattern& other) = default;

using Nothing = std::monostate;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>> opt_inputs;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern<T>> input;
SingleReductionOpPattern<T> reduction_op_pattern;

bool HasFusedInput() const {
return !std::holds_alternative<Nothing>(this->input);
}
};

// // Stmt := IS | R | PS
// // ops in StmtPattern will be lowered into a inlined cuda code.
// template <typename T>
// using StmtPattern = std::variant<InjectiveSourcePattern<T>, ReductionPattern<T>, PartialShardablePattern<T>>;

// // Stmts := [Stmt]
// template <typename T>
// using StmtsPattern = std::list<StmtPattern<T>>;

// // fuse rules:
// // 1. IS * IS -> IS
// // 2. PS * PS -> PS
// // 3. IS * PS -> PS
// // 4. IS * R -> R
// // 5. PS * R -> R

// // lifting rules:
// // 1. R -> Stmts
// // 2. PS -> Stmts
// // 3. Stmts * Stmts -> Stmts

// // OpTopoPattern := Error | Stmts
// template <typename T>
// using OpTopoPattern = std::variant<ErrorPattern<T>, StmtsPattern<T>>;
// Stmt := IS | R | PS
// ops in StmtPattern will be lowered into a inlined cuda code.
template <typename T>
using StmtPattern = std::variant<InjectiveSourcePattern<T>, ReductionPattern<T>, PartialShardablePattern<T>>;

// Stmts := [Stmt]
template <typename T>
using StmtsPattern = std::vector<StmtPattern<T>>;
// fuse rules:
// 1. IS * IS -> IS
// 2. PS * PS -> PS
// 3. IS * PS -> PS
// 4. IS * R -> R
// 5. PS * R -> R
// lifting rules:
// 1. R -> Stmts
// 2. PS -> Stmts
// 3. Stmts * Stmts -> Stmts
// OpTopoPattern := Error | Stmts
template <typename T>
using OpTopoPattern = std::variant<ErrorPattern<T>, StmtsPattern<T>>;

}
74 changes: 46 additions & 28 deletions paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,37 @@
#include "paddle/cinn/api/op_topo_pattern.h"
#include "paddle/pir/include/core/operation.h"
#include "glog/logging.h"
#include "paddle/cinn/adt/adt.h"

namespace cinn::api {

struct FrontendPattern {};
namespace cinn::frontend {

template<>
struct ErrorPattern<FrontendPattern> {
explicit ErrorPattern(const ErrorPattern<FrontendPattern>& other) = default;
struct OpAndOperandIndex {
const pir::Operation* op;
const int operand_index;

std::vector<const pir::Operation*> ops;
std::string error_string;
bool operator==(const OpAndOperandIndex& other) const {
return this->op == other.op && this->operand_index == other.operand_index;
}
};

template<>
struct InjectiveSourcePattern<FrontendPattern> {
explicit InjectiveSourcePattern(const InjectiveSourcePattern<FrontendPattern>& other) = default;
std::vector<const pir::Operation*> ops;
};
}

namespace std {

template<>
struct SingleReductionOpPattern<FrontendPattern> {
explicit SingleReductionOpPattern(const SingleReductionOpPattern<FrontendPattern>& other) = default;
const pir::Operation* reduce_op;
struct hash<cinn::frontend::OpAndOperandIndex> {

size_t operator()(const cinn::frontend::OpAndOperandIndex& op_operand) const {
return cinn::adt::hash_combine(std::hash<const pir::Operation*>()(op_operand.op), op_operand.operand_index);
}
};

}

namespace cinn::frontend {

struct FrontendPattern {};

struct ShardableAxis {
int axis;
std::string axis_name;
Expand Down Expand Up @@ -100,29 +107,40 @@ struct ShardableAxesUtil {
};

struct ShardableAxesSignature {
using OpOperand = std::pair<const pir::Operation*, /*operand index*/int>;

ShardableAxes output_shardable_axes;
std::unordered_map<OpOperand, ShardableAxes> input_shardable_axes;
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
};

}

namespace cinn::api {

template<>
struct PartialShardablePattern<FrontendPattern> {
explicit PartialShardablePattern(const PartialShardablePattern<FrontendPattern>& other) = default;
struct ErrorPattern<frontend::FrontendPattern> {
std::vector<const pir::Operation*> ops;
std::string error_string;
};

template<>
struct InjectiveSourcePattern<frontend::FrontendPattern> {
std::vector<const pir::Operation*> ops;
};

template<>
struct SingleReductionOpPattern<frontend::FrontendPattern> {
const pir::Operation* reduce_op;
};
template<>
struct PartialShardablePattern<frontend::FrontendPattern> {
std::vector<const pir::Operation*> ops;
ShardableAxesSignature shardable_axes_signature;
frontend::ShardableAxesSignature shardable_axes_signature;
};

}

namespace cinn::frontend {
using IS = api::InjectiveSourcePattern<api::FrontendPattern>;
using R = api::ReductionPattern<api::FrontendPattern>;
using PS = api::PartialShardablePattern<api::FrontendPattern>;

using StmtPattern = std::variant<IS, R, PS>;
using ErrorGroupPattern = api::ErrorPattern<api::FrontendPattern>;
using GroupPattern = std::variant<ErrorGroupPattern, StmtPattern>;
using ErrorGroupPattern = api::ErrorPattern<frontend::FrontendPattern>;
using GroupPattern = api::OpTopoPattern<frontend::FrontendPattern>;

}

0 comments on commit 97735a1

Please sign in to comment.