[release/2.2] Fuse gate_proj and up_proj in MLP of LLaMA (#2469)
* Fuse gate_proj and up_proj in MLP of LLaMA (#2430)

* Fuse gate_proj and up_proj in MLP of LLaMA

* fix clang-format

* Update run_quantization.py (#2471)

---------

Co-authored-by: jianan-gu <jianan.gu@intel.com>
chunyuan-w and jianan-gu committed Jan 16, 2024
1 parent de99dd7 commit aeaeba4
Showing 7 changed files with 265 additions and 8 deletions.
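For context, this change collapses the two TPP ops previously used on the LLaMA MLP path (tpp_linear_silu followed by tpp_linear_mul) into a single tpp_fused_gate_up_proj op. Below is a minimal sketch of how the new op is exercised, condensed from the unit test added at the end of this commit; the _enable_tpp/_disable_tpp import path is assumed from the existing TPP tests, and the shapes and tolerance are illustrative only.

    import torch
    import intel_extension_for_pytorch as ipex
    from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp, _disable_tpp

    class GateUp(torch.nn.Module):
        def __init__(self, in_feature, out_feature):
            super().__init__()
            self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=True)
            self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=True)

        def forward(self, x):
            # eager reference: act_fn(gate_proj(x)) * up_proj(x), with act_fn = SiLU
            return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)

    x = torch.randn(1, 4, 64)
    model = GateUp(64, 32).eval()
    with torch.no_grad():
        ref = model(x)
        _enable_tpp()  # prepack the Linear weights for the TPP path
        model = ipex.optimize(model, dtype=torch.float)
        out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
            x,
            model.gate_proj.weight,
            model.gate_proj.bias,
            model.up_proj.weight,
            model.up_proj.bias,
        )
        _disable_tpp()
        print(torch.allclose(out, ref, atol=1e-5))
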
21 changes: 21 additions & 0 deletions csrc/cpu/aten/TPPGEMM.cpp
@@ -9,6 +9,7 @@ namespace cpu {
IPEX_DEFINE_DISPATCH(tpp_linear_nobias_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_bias_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_gelu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_fused_gate_up_proj_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_silu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_relu_kernel_stub);
IPEX_DEFINE_DISPATCH(tpp_linear_add_kernel_stub);
@@ -38,6 +39,17 @@ at::Tensor tpp_linear_gelu_forward_cpu(
return tpp_linear_gelu_kernel_stub(kCPU, t_in, t_wt, t_bias);
}

at::Tensor tpp_fused_gate_up_proj_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt_gate,
const at::Tensor& t_bias_gate,
const at::Tensor& t_wt_up,
const at::Tensor& t_bias_up,
c10::optional<int64_t> out_features) {
return tpp_fused_gate_up_proj_kernel_stub(
kCPU, t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up);
}

at::Tensor tpp_linear_silu_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt,
@@ -117,6 +129,15 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
torch_ipex::cpu::tpp_linear_gelu_forward_cpu);
}

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"tpp_fused_gate_up_proj(Tensor t_in, Tensor t_wt_gate, Tensor t_bias_gate, Tensor t_wt_up, Tensor t_bias_up,int? out_features=None)-> Tensor out");
m.impl(
"tpp_fused_gate_up_proj",
c10::DispatchKey::CPU,
torch_ipex::cpu::tpp_fused_gate_up_proj_forward_cpu);
}

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"tpp_linear_add_add(Tensor t_in, Tensor t_in1, Tensor t_in2, Tensor t_wt, Tensor t_bias, float scale, int? out_features=None)-> Tensor out");
18 changes: 18 additions & 0 deletions csrc/cpu/aten/TPPGEMM.h
@@ -24,6 +24,14 @@ at::Tensor tpp_linear_gelu_forward_cpu(
const at::Tensor& t_bias,
c10::optional<int64_t> out_features);

at::Tensor tpp_fused_gate_up_proj_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt_gate,
const at::Tensor& t_bias_gate,
const at::Tensor& t_wt_up,
const at::Tensor& t_bias_up,
c10::optional<int64_t> out_features);

at::Tensor tpp_linear_silu_forward_cpu(
const at::Tensor& t_in,
const at::Tensor& t_wt,
@@ -71,6 +79,13 @@ using tpp_linear_bias_kernel_impl_fn =
using tpp_linear_gelu_kernel_impl_fn =
at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);

using tpp_fused_gate_up_proj_kernel_impl_fn = at::Tensor (*)(
const at::Tensor&,
const at::Tensor&,
const at::Tensor&,
const at::Tensor&,
const at::Tensor&);

using tpp_linear_silu_kernel_impl_fn =
at::Tensor (*)(const at::Tensor&, const at::Tensor&, const at::Tensor&);

@@ -105,6 +120,9 @@ IPEX_DECLARE_DISPATCH(
IPEX_DECLARE_DISPATCH(
tpp_linear_gelu_kernel_impl_fn,
tpp_linear_gelu_kernel_stub);
IPEX_DECLARE_DISPATCH(
tpp_fused_gate_up_proj_kernel_impl_fn,
tpp_fused_gate_up_proj_kernel_stub);
IPEX_DECLARE_DISPATCH(
tpp_linear_silu_kernel_impl_fn,
tpp_linear_silu_kernel_stub);
35 changes: 35 additions & 0 deletions csrc/cpu/aten/kernels/TPPGEMMKrnl.cpp
@@ -87,6 +87,38 @@ at::Tensor tpp_linear_gelu_kernel_impl(
return t_out;
}

at::Tensor tpp_fused_gate_up_proj_kernel_impl(
const at::Tensor& t_in,
const at::Tensor& t_wt_gate,
const at::Tensor& t_bias_gate,
const at::Tensor& t_wt_up,
const at::Tensor& t_bias_up) {
auto sizes = t_in.sizes().vec();
AT_ASSERT(
t_wt_gate.sizes() == t_wt_up.sizes(),
"Expect t_wt_gate.sizes() == t_wt_up.sizes()");
auto wt_sizes = t_wt_gate.sizes();
sizes[2] = wt_sizes[0] * wt_sizes[3];

auto t_out = t_in.new_empty(sizes);

auto dt = t_wt_gate.dtype();
if (dt == at::kFloat) {
torch_ipex::tpp::tpp_fused_gate_up_proj<float>(
t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out);
} else if (dt == at::kBFloat16) {
torch_ipex::tpp::tpp_fused_gate_up_proj<at::BFloat16>(
t_in, t_wt_gate, t_bias_gate, t_wt_up, t_bias_up, t_out);
} else {
AT_ASSERT(
0,
"TPP does not support current weight dtype %s:%d\n",
__FILE__,
__LINE__);
}
return t_out;
}

at::Tensor tpp_linear_silu_kernel_impl(
const at::Tensor& t_in,
const at::Tensor& t_wt,
@@ -219,6 +251,9 @@ IPEX_REGISTER_DISPATCH(
IPEX_REGISTER_DISPATCH(
tpp_linear_gelu_kernel_stub,
&tpp_linear_gelu_kernel_impl);
IPEX_REGISTER_DISPATCH(
tpp_fused_gate_up_proj_kernel_stub,
&tpp_fused_gate_up_proj_kernel_impl);
IPEX_REGISTER_DISPATCH(
tpp_linear_relu_kernel_stub,
&tpp_linear_relu_kernel_impl);
138 changes: 138 additions & 0 deletions csrc/cpu/tpp/kernels/TPPGEMMKrnl.h
@@ -42,6 +42,9 @@ REGISTER_LOCAL_SCOPE(
REGISTER_LOCAL_SCOPE(
tpp_linear_silu_krnl,
"tpp_linear_silu_krnl"); // linear bias + silu
REGISTER_LOCAL_SCOPE(
tpp_fused_gate_up_proj_krnl,
"tpp_fused_gate_up_proj_krnl"); // fused gate_proj and up_proj
REGISTER_LOCAL_SCOPE(
tpp_linear_relu_krnl,
"tpp_linear_relu_krnl"); // linear bias + relu
@@ -521,6 +524,141 @@ inline void tpp_linear_gelu(
}
}

// Fused kernel for the gate_proj and the up_proj related computation in the MLP
// of LLaMA. The ref computation of the kernel is:
// act_fn(gate_proj(x)) * up_proj(x) where act_fn is silu, gate_proj and
// up_proj are two nn.Linear with the same weight shapes and bias = False.
// t_in is the input activation
// t_wt_gate is the prepacked weight of the gate_proj
// t_wt_up is the prepacked weight of the up_proj
// t_bias_gate is the bias of the gate_proj
// t_bias_up is the bias of the up_proj
// t_out is the output result of the kernel
template <typename T>
inline void tpp_fused_gate_up_proj(
const at::Tensor& t_in,
const at::Tensor& t_wt_gate,
const at::Tensor& t_bias_gate,
const at::Tensor& t_wt_up,
const at::Tensor& t_bias_up,
at::Tensor& t_out) {
auto t_wt_gate_ = t_wt_gate;
auto t_wt_up_ = t_wt_up;
auto in_sizes = t_in.sizes();
auto BS = in_sizes[0] * in_sizes[1];
if (BS > FT_OPT_SIZE) { // first token compute
t_wt_gate_ = wt_tensor_for_first_token<T>(t_wt_gate_);
t_wt_up_ = wt_tensor_for_first_token<T>(t_wt_up_);
large_cache_opt = true;
}

auto wt_sizes = t_wt_gate_.sizes();
auto C = in_sizes[2];

auto Nc = wt_sizes[1];
auto Hc = C / Nc;
auto Nk = wt_sizes[0];
auto Hk = wt_sizes[3];
auto K = Nk * Hk;

auto t_wt_gate_V =
torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_gate_);
auto t_wt_up_V = torch_ipex::tpp::wt_tensor_for_fwd(Nk, Hk, Nc, Hc, t_wt_up_);

// This is used to store the intermediate result of the up_proj layer
auto t_out_tmp = at::empty_like(t_out);

auto in = GetVLAPtr<T>(t_in, {Nc, Hc});
auto wt_gate_V = GetVLAPtr<T>(t_wt_gate_V, {Nc, Hc * Hk});
auto wt_up_V = GetVLAPtr<T>(t_wt_up_V, {Nc, Hc * Hk});
auto bias_gate = GetVLAPtr<T>(t_bias_gate, {Hk});
auto bias_up = GetVLAPtr<T>(t_bias_up, {Hk});
auto out = GetVLAPtr<T>(t_out, {Nk, Hk});
auto out_tmp = GetVLAPtr<T>(t_out_tmp, {Nk, Hk});

auto Ncb = Nc;
auto BSb = 64L;
auto rem = BS % 64;
if (large_cache_opt)
Ncb = NCB_BLOCK_SIZE;

bool with_bias_gate = (t_bias_gate.numel() > 0);
bool with_bias_up = (t_bias_up.numel() > 0);
auto copy_bias_tpp = SCOPEIT(CpyBiasTPP<T>(BSb, Hk, K), BIAS);
auto copy_bias_tpp_rem = SCOPEIT(CpyBiasTPP<T>(rem, Hk, K), BIAS);
auto zero_tpp = SCOPEIT(SetZeroTPP<T>(BSb, Hk, K), EW_ZERO);
auto zero_tpp_rem = SCOPEIT(SetZeroTPP<T>(rem, Hk, K), EW_ZERO);
auto brgemm_tpp = SCOPEITGEMM(
(BrgemmTPP<T, T>(BSb, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb)));
auto brgemm_tpp_rem = SCOPEITGEMM(
(BrgemmTPP<T, T>(rem, Hk, Hc, Hc, Hk * Hc, C, Hk, K, 1.0, 0, Ncb)));
auto silu_fwd_tpp = SCOPEIT(SiLUFwdTPP<T>(BSb, Hk, K, K), ACT);
auto silu_fwd_tpp_rem = SCOPEIT(SiLUFwdTPP<T>(rem, Hk, K, K), ACT);
auto mul_tpp = SCOPEIT((MulTPP<T, T>(BSb, Hk, K, K)), EW_MUL);
auto mul_tpp_rem = SCOPEIT((MulTPP<T, T>(rem, Hk, K, K)), EW_MUL);

{
RECORD_SCOPE(tpp_fused_gate_up_proj_krnl, {t_in, t_wt_gate_V});

auto loop_scheme = large_cache_opt ? GEMM_LOOP_SCHEME : "aCb";
auto igemm_loop = torch_ipex::tpp::ThreadedLoop<3>(
{{0, Nc, Ncb, false}, {0, BS, BSb}, {Nk}}, loop_scheme);
igemm_loop(
[&](int* ind) {
int nc = ind[0], s1 = ind[1], nk = ind[2];
auto count = nc + Ncb < Nc ? Ncb : Nc - nc;
bool is_rem = (s1 + BSb > BS);
if (!is_rem) {
if (nc == 0) {
if (with_bias_gate) {
copy_bias_tpp(bias_gate[nk], out[s1][nk]);
} else {
zero_tpp(out[s1][nk]);
}

if (with_bias_up) {
copy_bias_tpp(bias_up[nk], out_tmp[s1][nk]);
} else {
zero_tpp(out_tmp[s1][nk]);
}
}
brgemm_tpp(in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, true);
brgemm_tpp(
in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, true);
if (!(nc + Ncb < Nc)) { // last nc iter
silu_fwd_tpp(out[s1][nk], out[s1][nk]);
mul_tpp(out[s1][nk], out_tmp[s1][nk], out[s1][nk]);
}
} else {
if (nc == 0) {
if (with_bias_gate) {
copy_bias_tpp_rem(bias_gate[nk], out[s1][nk]);
} else {
zero_tpp_rem(out[s1][nk]);
}

if (with_bias_up) {
copy_bias_tpp_rem(bias_up[nk], out_tmp[s1][nk]);
} else {
zero_tpp_rem(out_tmp[s1][nk]);
}
}
brgemm_tpp_rem(
in[s1][nc], wt_gate_V[nk][nc], out[s1][nk], count, false);
brgemm_tpp_rem(
in[s1][nc], wt_up_V[nk][nc], out_tmp[s1][nk], count, false);
brgemm_tpp.config();
if (!(nc + Ncb < Nc)) { // last nc iter
silu_fwd_tpp_rem(out[s1][nk], out[s1][nk]);
mul_tpp_rem(out[s1][nk], out_tmp[s1][nk], out[s1][nk]);
}
}
},
[&]() { brgemm_tpp.config(); },
[&]() { brgemm_tpp.release(); });
}
}

template <typename T>
inline void tpp_linear_add(
const at::Tensor t_in,
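As the comment in the new kernel notes, the fused path computes the same result as the eager expression act_fn(gate_proj(x)) * up_proj(x). A minimal reference sketch of that math in plain PyTorch, assuming ordinary 2D weights (the TPP kernel itself consumes prepacked blocked weights and folds the SiLU and elementwise multiply into the final Nc block iteration):

    import torch

    def fused_gate_up_proj_ref(x, wt_gate, bias_gate, wt_up, bias_up):
        # gate branch: silu(x @ wt_gate.T + bias_gate)
        gate = torch.nn.functional.silu(
            torch.nn.functional.linear(x, wt_gate, bias_gate)
        )
        # up branch: x @ wt_up.T + bias_up, then the elementwise product
        return gate * torch.nn.functional.linear(x, wt_up, bias_up)
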
@@ -536,10 +536,11 @@ def calib_func(prepared_model):
op_type_dict=op_type_dict,
smoothquant_args=smoothquant_args
)
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json")

else:
qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha)
qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=float(args.alpha))
user_model = ipex.llm.optimize(
user_model.eval(),
dtype=amp_dtype,
@@ -379,22 +379,16 @@ def forward(self, x):
and not self.linear_m.tpp_fallback
):
x = x.to(self.dtype).contiguous()
x1 = torch.ops.torch_ipex.tpp_linear_silu(
return torch.ops.torch_ipex.tpp_fused_gate_up_proj(
x,
self.linear_s.weight.detach(),
self.linear_s.bias.detach()
if self.linear_s.bias is not None
else x.new_empty(0),
self.linear_s.out_features,
)
return torch.ops.torch_ipex.tpp_linear_mul(
x,
x1,
self.linear_m.weight.detach(),
self.linear_m.bias.detach()
if self.linear_m.bias is not None
else x.new_empty(0),
self.linear_m.out_features,
)
else: # fallback path
return nn.functional.silu(self.linear_s(x)) * self.linear_m(x)
50 changes: 50 additions & 0 deletions tests/cpu/test_tpp_linear.py
@@ -46,6 +46,16 @@ def forward(self, x):
return torch.nn.functional.silu(self.mlp(x))


class Linear_Gate_Up(torch.nn.Module):
def __init__(self, in_feature, out_feature, bias_gate, bias_up):
super(Linear_Gate_Up, self).__init__()
self.gate_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_gate)
self.up_proj = torch.nn.Linear(in_feature, out_feature, bias=bias_up)

def forward(self, x):
return torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)


class Linear_relu(torch.nn.Module):
def __init__(self):
super(Linear_relu, self).__init__()
@@ -172,6 +182,46 @@ def test_tpp_linear_torchcompile(self):
self.assertTrue(out.dtype == dtype)
_disable_tpp()

def test_tpp_fused_gate_up_proj(self):
in_feature = 64
out_feature = 32

x = torch.randn(1, 4, in_feature)
x_tpp = copy.deepcopy(x)

with torch.no_grad():
for dtype, bias_gate, bias_up in itertools.product(
[torch.float, torch.bfloat16], [False, True], [False, True]
):
model = Linear_Gate_Up(
in_feature, out_feature, bias_gate, bias_up
).eval()
if dtype == torch.bfloat16:
x = x.to(torch.bfloat16)
x_tpp = x_tpp.to(torch.bfloat16)
model = model.to(torch.bfloat16)
ref_out = model(x)

_enable_tpp()
model = ipex.optimize(model, dtype=dtype)
out = torch.ops.torch_ipex.tpp_fused_gate_up_proj(
x_tpp,
model.gate_proj.weight,
model.gate_proj.bias,
model.up_proj.weight,
model.up_proj.bias,
)

out_linear_silu = torch.ops.torch_ipex.tpp_linear_silu(
x_tpp, model.gate_proj.weight, model.gate_proj.bias
)
out_tpp_ref = torch.ops.torch_ipex.tpp_linear_mul(
x_tpp, out_linear_silu, model.up_proj.weight, model.up_proj.bias
)
self.assertEqual(out, out_tpp_ref)
self.assertEqual(out, ref_out)
_disable_tpp()

def test_tpp_linear_gelu(self):
x1 = torch.rand(1, 4, 4096)
x2 = copy.deepcopy(x1)
