Skip to content

Commit

Permalink
fix tpp linear loop of first/next kernel (#2561)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianan-gu committed Feb 1, 2024
1 parent 901b377 commit ade4538
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions csrc/cpu/tpp/kernels/TPPGEMMKrnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
namespace torch_ipex {
namespace tpp {

static int large_cache_opt = false;
static int use_at_vnni = false; // env2int("USE_AT_VNNI");
static int FT_OPT_SIZE = env2int("FT_OPT_SIZE", 256);
static int NCB_BLOCK_SIZE = env2int("NCB_BLOCK_SIZE", 64);
Expand Down Expand Up @@ -58,15 +57,16 @@ inline at::Tensor wt_tensor_for_first_token(at::Tensor& t) {
if (dim < 5)
return t;
auto sizes = t.sizes();
constexpr long RBS = 2;
constexpr long RBS = 4;
auto K1 = sizes[0];
if (K1 % RBS != 0)
return t;
auto C1 = sizes[1];
auto C2 = sizes[2];
auto K2 = sizes[3];
auto C3 = sizes[4];

if (K2 >= 32)
return t;
auto t_new = t.new_empty({K1 / RBS, C1, C2, RBS * K2, C3});
auto in = GetVLAPtr<T>(t, {RBS, C1, C2, K2 * C3});
auto out = GetVLAPtr<T>(t_new, {C1, C2, RBS, K2 * C3});
Expand Down Expand Up @@ -96,6 +96,7 @@ inline void tpp_linear_bias(
auto in_sizes = t_in.sizes();
auto wt_sizes = t_wt_.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
if (wt_sizes[3] != 100) {
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
Expand Down Expand Up @@ -183,13 +184,15 @@ inline void tpp_linear_no_bias(
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
auto wt_sizes = t_wt_.sizes();
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
if (wt_sizes[3] != 100) {
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
wt_sizes = t_wt_.sizes();
}
large_cache_opt = true;
}

auto C = in_sizes[2];

auto Nc = wt_sizes[1];
Expand Down Expand Up @@ -254,10 +257,12 @@ inline void tpp_linear_mul(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_.sizes();
auto C = in_sizes[2];

Expand Down Expand Up @@ -348,6 +353,7 @@ inline void tpp_linear_add_add(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
Expand Down Expand Up @@ -444,10 +450,12 @@ inline void tpp_linear_gelu(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_.sizes();
auto C = in_sizes[2];

Expand Down Expand Up @@ -546,6 +554,7 @@ inline void tpp_fused_gate_up_proj(
auto t_wt_up_ = t_wt_up;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_gate_ = wt_tensor_for_first_token<T>(t_wt_gate_);
t_wt_up_ = wt_tensor_for_first_token<T>(t_wt_up_);
Expand Down Expand Up @@ -670,10 +679,12 @@ inline void tpp_linear_add(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_.sizes();
auto C = in_sizes[2];

Expand Down Expand Up @@ -761,10 +772,12 @@ inline void tpp_linear_silu(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_.sizes();
auto C = in_sizes[2];

Expand Down Expand Up @@ -851,10 +864,12 @@ inline void tpp_linear_relu(
auto t_wt_ = t_wt;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
bool large_cache_opt = false;
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_.sizes();
auto C = in_sizes[2];

Expand Down

0 comments on commit ade4538

Please sign in to comment.