graph: backend: compiler: ir: graph: fix infer binding axis logic for transpose semantic tv
yifeizh2 authored and vpirogov committed Oct 17, 2023
1 parent 57e14b5 commit 4207105
Showing 4 changed files with 159 additions and 5 deletions.
@@ -225,6 +225,9 @@ void excess_tensor_view_elimination(sc_graph_t &graph, const context_ptr &ctx) {
del_node->remove();
}
if (pre_node->isa<tensor_view_op_t>()) { pre_node->remove(); }
+if (node->attrs_.has_key("order")) {
+    node->attrs_.remove("order");
+}
}
vis->update_state_for_visited(node);
}
@@ -246,7 +246,10 @@ void convert_to_tensor_view(sc_graph_t &graph, const context_ptr &ctx) {
tensor_view_out->producer_owner_ = nullptr;
auto view = graph.make("tensor_view", node->get_inputs(),
{tensor_view_out},
{{"shape", tensor_view_out->details_.get_blocking_dims()}});
{{"shape", tensor_view_out->details_.get_blocking_dims()},
{"order",
node->attrs_.get<std::vector<int>>(
"order")}});
view->copy_dispatch_key_set_from_op(node);
node->replace_uses_with_and_remove(view);
vis->update_state_for_visited(view);
@@ -1005,11 +1005,31 @@ void tensor_view_op_t::pre_slice_ranges(
}
}

+// transpose_axis_map stores the transpose relation of src_axis --> dst_axis
bound_axis infer_tensor_view_binding_axis(const bound_axis &src_axis,
const sc_dims &src_dims, const sc_dims &dst_dims,
-const std::vector<int> &expand_dims = {}) {
+const std::vector<int> &expand_dims = {},
+const std::vector<int> &transpose_axis_map = {}) {
bound_axis dst_axis, tv_axis_map;

+if (!transpose_axis_map.empty()) {
+    bound_axis real_src_axis;
+    COMPILE_ASSERT(src_dims.size() == dst_dims.size()
+                    && src_dims.size() == transpose_axis_map.size(),
+            "src dims, dst dims, and transpose_axis_map shall have the "
+            "same length.")
+    for (auto &bd_ax : src_axis) {
+        std::vector<int> ret;
+        for (auto &ax : bd_ax) {
+            COMPILE_ASSERT(ax < static_cast<int>(transpose_axis_map.size()),
+                    "ax should be less then transpose_axis_map size")
+            ret.emplace_back(transpose_axis_map[ax]);
+        }
+        real_src_axis.emplace_back(ret);
+    }
+    return real_src_axis;
+}

sc_dims acc_src_dims(src_dims.size()), acc_dst_dims(dst_dims.size());
sc_dim tmp_acc = 1;
std::transform(src_dims.begin(), src_dims.end(), acc_src_dims.begin(),
@@ -1071,8 +1091,13 @@ void tensor_view_op_t::infer_binding_axis(bound_axis_map &bdax_map) {
// dst
auto dst_plain_dims = info_.outputs_[0]->details_.get_plain_dims();
auto ths = this;
-auto plain_bd_axis = infer_tensor_view_binding_axis(
-        known_axis_map[0], src_plain_dims, dst_plain_dims);
+auto order = attrs_.get_or_else("order", std::vector<int> {});
+std::vector<int> axis_mapping(order.size(), 0);
+for (size_t i = 0; i < order.size(); ++i) {
+    axis_mapping[order[i]] = i;
+}
+auto plain_bd_axis = infer_tensor_view_binding_axis(known_axis_map[0],
+        src_plain_dims, dst_plain_dims, std::vector<int> {}, axis_mapping);
bdax_map.get(get_outputs()[0]) = plain_bd_axis;
set_unknown_axis_binding(this, known_axis_map, bdax_map);
}
@@ -1092,7 +1117,8 @@ void tensor_view_op_t::pre_binding_axis(bound_axis_map &bdax_map) {
auto ths = this;
auto plain_bd_axis = infer_tensor_view_binding_axis(outaxis,
dst_plain_dims, src_plain_dims,
attrs_.get_or_else("expand_dim", std::vector<int> {}));
attrs_.get_or_else("expand_dim", std::vector<int> {}),
attrs_.get_or_else("order", std::vector<int> {}));
inpaxis = plain_bd_axis;
if (auto bd_op
= input->producer_owner_
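
The hunks above are the heart of the fix: infer_binding_axis inverts the transpose "order" attribute into a src-axis to dst-axis map (axis_mapping) before calling infer_tensor_view_binding_axis, while pre_binding_axis passes "order" unchanged because it walks the opposite direction, from dst axes back to src axes. The minimal sketch below, in plain standard C++, illustrates that remapping; it deliberately avoids the compiler's own graph types, and invert_order / remap_axes are illustrative names rather than library functions.

#include <cassert>
#include <cstddef>
#include <vector>

// Axes bound to each tensor, as lists of plain axis indices (simplified
// stand-in for the compiler's bound_axis).
using bound_axis = std::vector<std::vector<int>>;

// Invert a transpose "order" (dst axis i reads src axis order[i]) into a
// src-axis -> dst-axis map, mirroring what infer_binding_axis now does
// before calling infer_tensor_view_binding_axis.
std::vector<int> invert_order(const std::vector<int> &order) {
    std::vector<int> axis_mapping(order.size(), 0);
    for (std::size_t i = 0; i < order.size(); ++i)
        axis_mapping[order[i]] = static_cast<int>(i);
    return axis_mapping;
}

// Translate every known axis through the map, mirroring the new
// transpose_axis_map branch of infer_tensor_view_binding_axis.
bound_axis remap_axes(const bound_axis &src_axis, const std::vector<int> &map) {
    bound_axis real_src_axis;
    for (const auto &bd_ax : src_axis) {
        std::vector<int> ret;
        for (int ax : bd_ax)
            ret.push_back(map[ax]);
        real_src_axis.push_back(ret);
    }
    return real_src_axis;
}

int main() {
    // Transpose with order {1, 2, 0}: dst shape (b, c, a) from src shape (a, b, c).
    const std::vector<int> order {1, 2, 0};

    // Forward direction (infer_binding_axis): src axes map to dst axes via the
    // inverse of "order".
    const auto src_to_dst = invert_order(order); // {2, 0, 1}
    assert((remap_axes({{0}, {1, 2}}, src_to_dst) == bound_axis {{2}, {0, 1}}));

    // Backward direction (pre_binding_axis): dst axes map to src axes via
    // "order" itself, so the round trip restores the original binding.
    assert((remap_axes({{2}, {0, 1}}, order) == bound_axis {{0}, {1, 2}}));
    return 0;
}

Compiled as-is, both assertions hold for the 3-D order {1, 2, 0}; for a plain 2-D transpose the order {1, 0} is its own inverse, so both directions use the same map.
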
@@ -2380,3 +2380,125 @@ TEST(GCCore_CPU_graph_mixed_partition_cpp,
)";
EXPECT_EQ(ss.str(), expected_str);
}

TEST(GCCore_CPU_graph_mixed_partition_cpp,
TransposeSemanticTensorViewBindAxis) {
SET_THREADS_OR_SKIP(64);
REQUIRE_AMX();

auto ctx = std::make_shared<context_t>(*get_test_ctx());
ctx->flags_.mixed_fusion_ = true;
ctx->flags_.use_cost_model_ = true;

sc_graph_t mlp_graph;
auto sigmoid_backprop_in0 = mlp_graph.make_input(
{graph_tensor::make({128, 1}, sc_data_format_t())});
auto sigmoid_backprop_in1 = mlp_graph.make_input(
{graph_tensor::make({128, 1}, sc_data_format_t())});
auto transpose_in0 = mlp_graph.make_input(
{graph_tensor::make({256, 1}, sc_data_format_t())});
auto relu_backprop_in0 = mlp_graph.make_input(
{graph_tensor::make({128, 256}, sc_data_format_t())});
auto transpose_in1 = mlp_graph.make_input(
{graph_tensor::make({512, 256}, sc_data_format_t())});
auto transpose_in2 = mlp_graph.make_input(
{graph_tensor::make({1024, 512}, sc_data_format_t())});
auto relu_backprop_in2 = mlp_graph.make_input(
{graph_tensor::make({128, 512}, sc_data_format_t())});
auto matmul_in1 = mlp_graph.make_input(
{graph_tensor::make({1024, 1024}, sc_data_format_t())});
auto matmul_in2 = mlp_graph.make_input(
{graph_tensor::make({128, 1024}, sc_data_format_t())});

auto sigmoid_backprop = mlp_graph.make("sigmoid_backprop",
{sigmoid_backprop_in0->get_outputs()[0],
sigmoid_backprop_in1->get_outputs()[0]},
{}, {{"use_dst", true}});
auto static_transpose0
= mlp_graph.make("transpose", {transpose_in0->get_outputs()[0]}, {},
{{"order", std::vector<int> {1, 0}}});
auto matmul0 = mlp_graph.make("matmul",
{sigmoid_backprop->get_outputs()[0],
static_transpose0->get_outputs()[0]},
{}, {});
auto static_transpose0_0
= mlp_graph.make("transpose", {matmul0->get_outputs()[0]}, {},
{{"order", std::vector<int> {1, 0}}});
auto matmul0_0 = mlp_graph.make("matmul",
{static_transpose0_0->get_outputs()[0],
relu_backprop_in0->get_outputs()[0]},
{}, {});
auto relu_backprop1 = mlp_graph.make("relu_backprop",
{relu_backprop_in0->get_outputs()[0], matmul0->get_outputs()[0]},
{}, {{"use_dst", true}});
auto static_transpose1
= mlp_graph.make("transpose", {transpose_in1->get_outputs()[0]}, {},
{{"order", std::vector<int> {1, 0}}});
auto static_transpose2
= mlp_graph.make("transpose", {relu_backprop1->get_outputs()[0]},
{}, {{"order", std::vector<int> {1, 0}}});
auto matmul1 = mlp_graph.make("matmul",
{relu_backprop1->get_outputs()[0],
static_transpose1->get_outputs()[0]},
{}, {});
auto matmul2 = mlp_graph.make("matmul",
{static_transpose2->get_outputs()[0],
relu_backprop_in2->get_outputs()[0]},
{}, {});
auto relu_backprop2 = mlp_graph.make("relu_backprop",
{relu_backprop_in2->get_outputs()[0], matmul1->get_outputs()[0]},
{}, {{"use_dst", true}});
auto static_transpose3
= mlp_graph.make("transpose", {transpose_in2->get_outputs()[0]}, {},
{{"order", std::vector<int> {1, 0}}});
auto static_transpose4
= mlp_graph.make("transpose", {relu_backprop2->get_outputs()[0]},
{}, {{"order", std::vector<int> {1, 0}}});
auto matmul3 = mlp_graph.make("matmul",
{relu_backprop2->get_outputs()[0],
static_transpose3->get_outputs()[0]},
{}, {});
auto matmul4 = mlp_graph.make("matmul",
{static_transpose4->get_outputs()[0], matmul_in2->get_outputs()[0]},
{}, {});
auto matmul5 = mlp_graph.make("matmul",
{matmul3->get_outputs()[0], matmul_in1->get_outputs()[0]}, {}, {});
auto static_transpose5
= mlp_graph.make("transpose", {matmul3->get_outputs()[0]}, {},
{{"order", std::vector<int> {1, 0}}});
auto matmul6 = mlp_graph.make("matmul",
{static_transpose5->get_outputs()[0], matmul_in2->get_outputs()[0]},
{}, {});
auto matmul7 = mlp_graph.make("matmul",
{matmul5->get_outputs()[0], matmul_in2->get_outputs()[0]}, {}, {});

mlp_graph.make_output({matmul0_0->get_outputs()[0]});
mlp_graph.make_output({matmul2->get_outputs()[0]});
mlp_graph.make_output({matmul4->get_outputs()[0]});
mlp_graph.make_output({matmul6->get_outputs()[0]});
mlp_graph.make_output({matmul7->get_outputs()[0]});

graph_driver(mlp_graph, ctx);
std::stringstream ss;
print_graph(mlp_graph, ss, true);
// The reduce op could not be split
std::string expected_str_spr
= R"(graph(v0: f32[128, 1], v1: f32[128, 1], v2: f32[256, 1], v3: f32[128, 256], v4: f32[512, 256], v5: f32[1024, 512], v6: f32[128, 512], v7: f32[1024, 1024], v8: f32[128, 1024]) -> [v9: f32[256, 256], v10: f32[256, 512], v11: f32[512, 1024], v12: f32[1024, 1024], v13: f32[128, 1024]] {
[v14: f32[32, 32, 4, 16]] = outerloop_32X32X4_partition_reorder_select_one(v6)
[v15: f32[1024, 512]] = tensor_view(v5)
[v16: f32[512, 1024]] = reorder(v15)
[v17: f32[512, 256]] = tensor_view(v4)
[v18: f32[256, 512]] = reorder(v17)
[v19: f32[32, 16, 4, 16]] = outerloop_32X16X4_partition_reorder_select_one(v3)
[v20: f32[1, 256]] = tensor_view(v2)
[v21: f32[128, 1]] = outerloop_128_partition_mul_sub_mul(v0, v1)
[v22: f32[64, 8, 4, 16], v23: f32[64, 8, 4, 16], v24: f32[32, 32, 4, 16]] = outerloop_32_partition_managed_matmul_core_tensor_view_reorder_mul_tensor_view_reorder_managed_matmul_core_mul(v21, v20, v19, v18, v14)
[v25: f32[64, 8, 16, 16], v13: f32[128, 1024]] = outerloop_8_partition_reorder_managed_matmul_core_tensor_view_reorder_managed_matmul_core_managed_matmul_core_reorder(v24, v16, v7, v8)
[v26: f32[32, 32, 4, 16]] = tensor_view(v24)
[v11: f32[512, 1024], v12: f32[1024, 1024]] = outerloop_32X2X1_partition_reorder_managed_matmul_core_reorder_managed_matmul_core_reorder(v26, v8, v25)
[v10: f32[256, 512]] = outerloop_64X1X1X1X1_partition_managed_matmul_core_reorder(v23, v6)
[v9: f32[256, 256]] = outerloop_64X1X1X1X1_partition_managed_matmul_core_reorder(v22, v3)
}
)";
EXPECT_EQ(ss.str(), expected_str_spr);
}
