Skip to content

Commit

Permalink
graph: backend: compiler: ops: fix managed_matmul dispatch_avx conditions and code path
Browse files Browse the repository at this point in the history
  • Loading branch information
huanghaixin008 authored and vpirogov committed Apr 4, 2024
1 parent 5587f08 commit 0ca5bc5
Showing 1 changed file with 20 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,6 @@ managed_matmul_core_op_t::managed_matmul_core_op_t(
}
// record padded_K of input A for matmul_core
attrs_["temp.padded_A_K"] = std::make_shared<VConst>();

auto num_threads = runtime_config_t::get().get_num_threads();
sc_dim M = A_dims.front(); // A is always 2D
sc_dim K = A_dims.back(); // A is always 2D
sc_dim N = B_dims.back(); // B is always 2D
bool is_int8 = info_.inputs_[0]->details_.dtype_ == datatypes::u8
|| info_.inputs_[0]->details_.dtype_ == datatypes::s8;
if (!is_dynamic() && M <= 2 && K >= 4096
&& N >= 4096 // TODO(niuxiaoguang): K, N shapes are from gpt-j-6B
// and llama on SPR. Change them when necessary.
&& num_threads <= 32 && is_int8) {
attrs_["dispatch_avx"] = true;
}
}

std::vector<int> managed_matmul_core_op_t::query_prefetch(
Expand Down Expand Up @@ -119,6 +106,26 @@ void managed_matmul_core_op_t::generate_prefetcher_body_for_tensor(
}

body_generator_ptr managed_matmul_core_op_t::create_generator() {
COMPILE_ASSERT(
info_.inputs_.size() == 2, "managed_matmul_core expects 2 inputs");
auto &A_dims = info_.inputs_[0]->details_.get_plain_dims();
auto &B_dims = info_.inputs_[1]->details_.get_plain_dims();
COMPILE_ASSERT(A_dims.size() == 2 && B_dims.size() == 2,
"managed_matmul_core only supports 2d cases yet");

auto num_threads = runtime_config_t::get().get_num_threads();
sc_dim M = A_dims.front(); // A is always 2D
sc_dim K = A_dims.back(); // A is always 2D
sc_dim N = B_dims.back(); // B is always 2D
bool is_int8 = info_.inputs_[0]->details_.dtype_ == datatypes::u8
|| info_.inputs_[0]->details_.dtype_ == datatypes::s8;
if (!is_dynamic() && M <= 5 && K >= 4096
&& N >= 4096 // TODO(niuxiaoguang): K, N shapes are from gpt-j-6B
// and llama on SPR. Change them when necessary.
&& num_threads <= 32 && is_int8) {
attrs_["dispatch_avx"] = true;
}

auto mat_gen = utils::make_unique<gen_managed_matmul_core_t>(this,
graph::extract_detail_from_tensors(get_inputs()),
graph::extract_detail_from_tensors(get_outputs()));
Expand Down

0 comments on commit 0ca5bc5

Please sign in to comment.