From ade45387ecc4e707754de9db6fc2be0af186e2ba Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 1 Feb 2024 08:48:18 +0800 Subject: [PATCH] fix tpp linear loop of first/next kernel (#2561) --- csrc/cpu/tpp/kernels/TPPGEMMKrnl.h | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h b/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h index 37e97c615..08ea1c64b 100644 --- a/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h +++ b/csrc/cpu/tpp/kernels/TPPGEMMKrnl.h @@ -16,7 +16,6 @@ namespace torch_ipex { namespace tpp { -static int large_cache_opt = false; static int use_at_vnni = false; // env2int("USE_AT_VNNI"); static int FT_OPT_SIZE = env2int("FT_OPT_SIZE", 256); static int NCB_BLOCK_SIZE = env2int("NCB_BLOCK_SIZE", 64); @@ -58,7 +57,7 @@ inline at::Tensor wt_tensor_for_first_token(at::Tensor& t) { if (dim < 5) return t; auto sizes = t.sizes(); - constexpr long RBS = 2; + constexpr long RBS = 4; auto K1 = sizes[0]; if (K1 % RBS != 0) return t; @@ -66,7 +65,8 @@ inline at::Tensor wt_tensor_for_first_token(at::Tensor& t) { auto C2 = sizes[2]; auto K2 = sizes[3]; auto C3 = sizes[4]; - + if (K2 >= 32) + return t; auto t_new = t.new_empty({K1 / RBS, C1, C2, RBS * K2, C3}); auto in = GetVLAPtr(t, {RBS, C1, C2, K2 * C3}); auto out = GetVLAPtr(t_new, {C1, C2, RBS, K2 * C3}); @@ -96,6 +96,7 @@ inline void tpp_linear_bias( auto in_sizes = t_in.sizes(); auto wt_sizes = t_wt_.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute if (wt_sizes[3] != 100) { t_wt_ = wt_tensor_for_first_token(t_wt_); @@ -183,6 +184,7 @@ inline void tpp_linear_no_bias( auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; auto wt_sizes = t_wt_.sizes(); + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute if (wt_sizes[3] != 100) { t_wt_ = wt_tensor_for_first_token(t_wt_); @@ -190,6 +192,7 @@ inline void tpp_linear_no_bias( } large_cache_opt = true; } + auto C = in_sizes[2]; auto Nc = wt_sizes[1]; @@ -254,10 +257,12 @@ inline void tpp_linear_mul( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; } + auto wt_sizes = t_wt_.sizes(); auto C = in_sizes[2]; @@ -348,6 +353,7 @@ inline void tpp_linear_add_add( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; @@ -444,10 +450,12 @@ inline void tpp_linear_gelu( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; } + auto wt_sizes = t_wt_.sizes(); auto C = in_sizes[2]; @@ -546,6 +554,7 @@ inline void tpp_fused_gate_up_proj( auto t_wt_up_ = t_wt_up; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_gate_ = wt_tensor_for_first_token(t_wt_gate_); t_wt_up_ = wt_tensor_for_first_token(t_wt_up_); @@ -670,10 +679,12 @@ inline void tpp_linear_add( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; } + auto wt_sizes = t_wt_.sizes(); auto C = in_sizes[2]; @@ -761,10 +772,12 @@ inline void tpp_linear_silu( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; } + auto wt_sizes = t_wt_.sizes(); auto C = in_sizes[2]; @@ -851,10 +864,12 @@ inline void tpp_linear_relu( auto t_wt_ = t_wt; auto in_sizes = t_in.sizes(); auto BS = in_sizes[0] * in_sizes[1]; + bool large_cache_opt = false; if (BS > FT_OPT_SIZE) { // first token compute t_wt_ = wt_tensor_for_first_token(t_wt_); large_cache_opt = true; } + auto wt_sizes = t_wt_.sizes(); auto C = in_sizes[2];