Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ bool pass::AdjustBrgemmCopyBLoopPorts::update_loop_info(
* 1) VNNI format is applied: KN4k for I8/U8, or KN2k for BF16
* 2) Zero padding is applied if N4k < 256 or N2k < 64
*/
if (brgemm_utils::with_repacking(brg->get_type()) && precision != element::f32 &&
loop_port.is_incremented()) {
if (brgemm_utils::with_repacking(brg->get_type()) && loop_port.is_incremented()) {
// K blocking loop: account for zero padding
if (loop_port.get_dim_idx() == 1) {
const auto ptr_incr = loop_desc.ptr_increment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ std::tuple<size_t, size_t, size_t> BrgemmCPUBlocking::get_blocking_params(

size_t m_blk, n_blk, k_blk;
std::tie(m_blk, n_blk, k_blk) = BrgemmBlockingBase::get_blocking_params(brgemm_expr);
// Note: K,N blocking is functionally enabled, need to turn it on after blocking heuristic is updated to cover
// the low precision cases (ticket: 156014)
if (with_repacking(brgemm->get_type())) {
// TODO: K,N blocking is functionally supported but must stay disabled until the blocking heuristic is updated
// to cover the low-precision cases (ticket: 156014).
// Note: an FP32 MatMul with `transposed_b=true` also has a `with_repacking` type regardless of its precision,
// so the precision is checked explicitly below.
const auto precision = brgemm_expr->get_node()->get_input_element_type(1);
if (with_repacking(brgemm->get_type()) && precision != element::f32) {
n_blk = get_full_dim_value();
k_blk = get_full_dim_value();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,42 @@ TEST_F(BrgemmCPUBlockingTest, BlockingIsNotNeeded) {
}
}

// FP32 MatMul whose B input ({1, 16, n, k}) goes through BrgemmCopyB with layout {0, 1, 3, 2} and
// BRGEMM_TYPE::REPACKING_ONLY. The reference IR expects the BrgemmCPU expression to be blocked by
// m_blk / k_blk / n_blk, while the BrgemmCopyB descriptors keep full_dim on both ports.
TEST_F(BrgemmCPUBlockingTest, WithTransposeB) {
    const ov::Dimension::value_type dim_m = 384;
    const ov::Dimension::value_type dim_k = 1024;
    const ov::Dimension::value_type dim_n = 384;
    const ov::PartialShape shape_a{1, 16, dim_m, dim_k};
    const ov::PartialShape shape_b{1, 16, dim_n, dim_k};
    const auto type_a = ov::element::f32;
    const auto type_b = ov::element::f32;
    const std::vector<size_t> b_layout{0, 1, 3, 2};

    // Tested IR: Parameter(A), Parameter(B) -> BrgemmCopyB -> BrgemmCPU -> Result, default descriptors.
    {
        auto param_a = linear_ir->push_node<ov::opset10::Parameter>(type_a, shape_a);
        auto param_b = linear_ir->push_node<ov::opset10::Parameter>(type_b, shape_b);

        auto repack = linear_ir->push_node<BrgemmCopyB>(param_b.second,
                                                        type_a,
                                                        BRGEMM_TYPE::REPACKING_ONLY,
                                                        0,
                                                        0,
                                                        0,
                                                        b_layout);
        init_expr_descriptors(*repack.first);

        auto matmul = linear_ir->push_node<BrgemmCPU>(param_a.second, repack.second, BRGEMM_TYPE::REPACKING_ONLY);
        init_expr_descriptors(*matmul.first);

        auto output = linear_ir->push_node<ov::opset10::Result>(matmul.second);
    }

    // Reference IR: same node sequence, with explicit subtensors and brgemm loop infos.
    {
        auto param_a = linear_ir_ref->push_node<ov::opset10::Parameter>(type_a, shape_a);
        auto param_b = linear_ir_ref->push_node<ov::opset10::Parameter>(type_b, shape_b);

        auto repack = linear_ir_ref->push_node<BrgemmCopyB>(param_b.second,
                                                            type_a,
                                                            BRGEMM_TYPE::REPACKING_ONLY,
                                                            0,
                                                            0,
                                                            0,
                                                            b_layout);
        const auto repack_expr = *repack.first;
        // The repacking expression is described with full_dim on its input and output.
        init_expr_descriptors(repack_expr, {{full_dim, full_dim}, {full_dim, full_dim}});

        auto matmul = linear_ir_ref->push_node<BrgemmCPU>(param_a.second, repack.second, BRGEMM_TYPE::REPACKING_ONLY);
        const auto& matmul_expr = *matmul.first;
        init_expr_descriptors(matmul_expr, {{m_blk, k_blk}, {k_blk, n_blk}, {m_blk, n_blk}});
        create_brgemm_loop_infos(linear_ir_ref, matmul_expr, dim_m, m_blk, dim_k, k_blk, dim_n, n_blk);
        matmul_expr->set_loop_ids({2, 1, 0});

        auto output = linear_ir_ref->push_node<ov::opset10::Result>(matmul.second);
    }
}

TEST_F(BrgemmCPUBlockingTest, WithDataRepacking) {
// Skipped because K,N blocking is disabled until heuristic is updated (ticket: 156014)
GTEST_SKIP();
Expand Down
Loading