Skip to content

Commit

Permalink
[PIR] add 3 case to build_cinn_pass_test (PaddlePaddle#58620)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRyanHuang committed Nov 3, 2023
1 parent b43645a commit 1a4fd36
Showing 1 changed file with 169 additions and 0 deletions.
169 changes: 169 additions & 0 deletions test/cpp/pir/cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,172 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
CHECK_EQ(iter->name(), op_names[index++]);
}
}

std::shared_ptr<::pir::Program> BuildNoOpSupportCinnGraph() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

// ones -> hardswish -> square -> unsqueeze
const std::vector<int64_t> shape = {64, 128};
const std::vector<int64_t> axis = {0};
auto ones_op_x = builder.Build<paddle::dialect::OnesOp>(
shape, phi::DataType::FLOAT32, phi::GPUPlace());
auto hardswish_op_y =
builder.Build<paddle::dialect::HardswishOp>(ones_op_x->result(0));
auto square_op_y =
builder.Build<paddle::dialect::SquareOp>(hardswish_op_y->result(0));
auto unsqueeze_op_x =
builder.Build<paddle::dialect::UnsqueezeOp>(square_op_y->result(0), axis);

return program;
}

TEST(BuildCinnPassTest, NoOpSupportCinn) {
auto origin_program = BuildNoOpSupportCinnGraph();
pir::IrContext* ctx = pir::IrContext::Instance();
pir::PassManager pm(ctx);
pm.AddPass(pir::CreateBuildCinnPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();
CHECK_EQ(pm.Run(origin_program.get()), true);
LOG(INFO) << "after pass: " << *origin_program;

CHECK_EQ(origin_program->block()->size(), 5u); // Because of `FullIntArrayOp`

std::vector<std::string> op_names = {
paddle::dialect::OnesOp::name(),
paddle::dialect::HardswishOp::name(),
paddle::dialect::SquareOp::name(),
paddle::dialect::FullIntArrayOp::name(),
paddle::dialect::UnsqueezeOp::name(),
};
int index = 0;
for (auto iter : *origin_program->block()) {
CHECK_EQ(iter->name(), op_names[index++]);
}
}

std::shared_ptr<::pir::Program> BuildOneCinnSubgraph() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

// full -> acosh -> relu -> square -> unsqueeze
const std::vector<int64_t> axis = {0};

const float value_one = 1.0;
const std::vector<int64_t> shape = {64, 128};
auto full_op_x = builder.Build<paddle::dialect::FullOp>(
shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());

auto acosh_op_x =
builder.Build<paddle::dialect::AcoshOp>(full_op_x->result(0));
auto relu_op_y =
builder.Build<paddle::dialect::ReluOp>(acosh_op_x->result(0));
auto square_op_y =
builder.Build<paddle::dialect::SquareOp>(relu_op_y->result(0));
auto unsqueeze_op_x =
builder.Build<paddle::dialect::UnsqueezeOp>(square_op_y->result(0), axis);
return program;
}

TEST(BuildCinnPassTest, OneCinnSubgraph) {
auto origin_program = BuildOneCinnSubgraph();
pir::IrContext* ctx = pir::IrContext::Instance();
pir::PassManager pm(ctx);
pm.AddPass(pir::CreateBuildCinnPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();
CHECK_EQ(pm.Run(origin_program.get()), true);
LOG(INFO) << "after pass: " << *origin_program;

CHECK_EQ(origin_program->block()->size(), 4u);
pir::Operation* group_op = origin_program->block()->front();
pir::Block* group_block =
group_op->dyn_cast<cinn::dialect::GroupOp>().block();
CHECK_EQ(group_block->size(), 4u);

std::vector<std::string> op_names = {
paddle::dialect::FullOp::name(),
paddle::dialect::AcoshOp::name(),
paddle::dialect::ReluOp::name(),
pir::YieldOp::name(),
};
int index = 0;
for (auto iter : *group_block) {
CHECK_EQ(iter->name(), op_names[index++]);
}
}

std::shared_ptr<::pir::Program> BuildMultiCinnSubgraph() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();

auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

// full -> acosh -> hardswish -> square -> unsqueeze -> relu
const std::vector<int64_t> axis = {0};

const float value_one = 1.0;
const std::vector<int64_t> shape = {64, 128};
auto full_op_x = builder.Build<paddle::dialect::FullOp>(
shape, value_one, phi::DataType::FLOAT32, phi::GPUPlace());

auto acosh_op_x =
builder.Build<paddle::dialect::AcoshOp>(full_op_x->result(0));
auto hardswish_op_y =
builder.Build<paddle::dialect::HardswishOp>(acosh_op_x->result(0));
auto square_op_y =
builder.Build<paddle::dialect::SquareOp>(hardswish_op_y->result(0));
auto unsqueeze_op_x =
builder.Build<paddle::dialect::UnsqueezeOp>(square_op_y->result(0), axis);
auto relu_op_y =
builder.Build<paddle::dialect::ReluOp>(unsqueeze_op_x->result(0));
return program;
}

TEST(BuildCinnPassTest, MultiCinnSubgraph) {
auto origin_program = BuildMultiCinnSubgraph();
pir::IrContext* ctx = pir::IrContext::Instance();
pir::PassManager pm(ctx);
pm.AddPass(pir::CreateBuildCinnPass());
pm.EnablePassTiming();
pm.EnableIRPrinting();
CHECK_EQ(pm.Run(origin_program.get()), true);
LOG(INFO) << "after pass: " << *origin_program;

CHECK_EQ(origin_program->block()->size(), 6u);
pir::Operation* group_op = origin_program->block()->front();
pir::Block* group_block =
group_op->dyn_cast<cinn::dialect::GroupOp>().block();
CHECK_EQ(group_block->size(), 3u);

std::vector<std::string> op_names_front = {
paddle::dialect::FullOp::name(),
paddle::dialect::AcoshOp::name(),
pir::YieldOp::name(),
};
int index = 0;
for (auto iter : *group_block) {
CHECK_EQ(iter->name(), op_names_front[index++]);
}

group_op = origin_program->block()->back();
group_block = group_op->dyn_cast<cinn::dialect::GroupOp>().block();
CHECK_EQ(group_block->size(), 2u);

std::vector<std::string> op_names_back = {
paddle::dialect::ReluOp::name(),
pir::YieldOp::name(),
};
index = 0;
for (auto iter : *group_block) {
CHECK_EQ(iter->name(), op_names_back[index++]);
}
}

0 comments on commit 1a4fd36

Please sign in to comment.