diff --git a/examples/cpu/inference/python/llm/README.md b/examples/cpu/inference/python/llm/README.md
index 0d8055f0c..938d44d0f 100644
--- a/examples/cpu/inference/python/llm/README.md
+++ b/examples/cpu/inference/python/llm/README.md
@@ -209,9 +209,9 @@ OMP_NUM_THREADS= numactl -m -C
 --output-dir "saved_results" --int8-bf16-mixed
 ```
 
-- We provide the downloading links of tuned static quantization qconfig summary files with good quality: ["meta-llama/Llama-2-7b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b_qconfig.json), ["meta-llama/Llama-2-7b-chat-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b-chat_qconfig.json), ["meta-llama/Llama-2-13b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-13b_qconfig.json) and ["EleutherAI/gpt-j-6b"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/gpt-j-6b_qconfig.json).
+- We provide download links for well-tuned static quantization qconfig summary files: ["meta-llama/Llama-2-7b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b_qconfig.json), ["meta-llama/Llama-2-7b-chat-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b-chat_qconfig.json), ["meta-llama/Llama-2-13b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-13b_qconfig.json) and ["EleutherAI/gpt-j-6b"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/gpt-j-6b_qconfig.json). These qconfig files are verified against specific versions of the Transformers/PyTorch/IPEX codebases, so they work out of the box there, but their compatibility may break as those codebases change. If you hit failures with these pre-tuned qconfig files, please refer to the [Intel® Neural Compressor scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex) to regenerate them for your own environment.
 
-- For other models' qconfig recipes, you can just try to run your model_id and use IPEX default recipes by removing "--qconfig-summary-file ". If IPEX default recipes are not good enough for accuracy requirements, please refer to the [Intel® Neural Compressor tutorial](https://github.com/intel/neural-compressor/blob/master/docs/source/smooth_quant.md#validated-models) for more tuned recipes.
+- For other models' qconfig recipes, you can simply run with your model_id and the IPEX default recipes by removing "--qconfig-summary-file". If the IPEX default recipes do not meet your accuracy requirements, please refer to the [Intel® Neural Compressor tutorial](https://github.com/intel/neural-compressor/blob/master/docs/source/smooth_quant.md#validated-models) and [scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex) for more tuned recipes.
 
 #### Weight-only quantization:
diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
index 9e8c24202..c015ed178 100644
--- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
@@ -2,7 +2,6 @@
 from torch import nn
 import math
 import warnings
-import copy
 from intel_extension_for_pytorch.nn.modules import IpexWoqLinear
 from intel_extension_for_pytorch.quantization import (
     get_weight_only_quant_qconfig_mapping,
@@ -221,12 +220,15 @@ def __init__(self, module, tpp=False, woq=False):
         super().__init__(module.linear_0, tpp=tpp, woq=woq)
         assert hasattr(module, "num_concat")
         self.num_concat = module.num_concat
-        self.linear_list = []
-        for i in range(self.num_concat):
-            attr_name = f"linear_{i}"
-            assert hasattr(module, attr_name)
-            self.linear_list.append(getattr(module, attr_name))
         self.concat_linear = None
+        self.linear_list = []
+        self.woq = woq
+        self.tpp = tpp
+        if woq:
+            for i in range(self.num_concat):
+                attr_name = f"linear_{i}"
+                assert hasattr(module, attr_name)
+                self.linear_list.append(getattr(module, attr_name))
         if woq and all(
             isinstance(linear, IpexWoqLinear) for linear in self.linear_list
         ):
@@ -303,6 +305,7 @@ def __init__(self, module, tpp=False, woq=False):
             mod.bias = nn.Parameter(concat_bias) if use_bias else None
             mod.qconfig = qconfig
             mod._num_concats = len(weights_list)
+            self._num_concats = mod._num_concats
             if w_dtype == torch.quint4x2:
                 self.concat_linear = IpexWoqLinear.from_float_and_int4_weight(
                     mod,
@@ -317,20 +320,32 @@ def __init__(self, module, tpp=False, woq=False):
                     mod, concat_scales, concat_zeros
                 )
         else:
-            for i in range(self.num_concat):
-                attr_name = f"linear_{i}"
-                setattr(self, attr_name, copy.deepcopy(getattr(module, attr_name)))
+            self._num_concats = module._num_concats
+            if (
+                self.tpp
+                and hasattr(module, "concat_linear")
+                and module.concat_linear is not None
+            ):
+                self.concat_linear = module.concat_linear
+            else:
+                for i in range(self.num_concat):
+                    attr_name = f"linear_{i}"
+                    setattr(self, attr_name, getattr(module, attr_name))
 
     def forward(self, x):
         if self.concat_linear is not None:
-            num_concats = self.concat_linear._num_concats
             concat_output = self.concat_linear(x)
-            hidden_size = concat_output.shape[-1] // num_concats
-            concat_output = concat_output.view(num_concats, -1, hidden_size)
-            expected_shape = list(x.shape)[:-1] + [hidden_size]
-            return tuple(
-                [concat_output[i].view(expected_shape) for i in range(num_concats)]
-            )
+            if self.woq:
+                num_concats = self._num_concats
+                hidden_size = concat_output.shape[-1] // num_concats
+                concat_output = concat_output.view(num_concats, -1, hidden_size)
+                expected_shape = list(x.shape)[:-1] + [hidden_size]
+                return tuple(
+                    [concat_output[i].view(expected_shape) for i in range(num_concats)]
+                )
+            else:
+                return concat_output
+
         output_list = []
         for i in range(self.num_concat):
             assert hasattr(self, f"linear_{i}")
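The `_IPEXConcatLinearCPU` changes above make the fused projection return either a `(query, key, value)` tuple (WOQ path) or the raw packed tensor that a later fused RoPE step splits (TPP path). As a reference point, here is a minimal plain-PyTorch sketch of the packed-QKV convention itself, not of the IPEX TPP/WOQ kernels (whose output layout may differ); all names and sizes are illustrative only.

```python
import torch
from torch import nn

hidden_size, num_concat, batch, seq = 16, 3, 2, 4

# One fused projection standing in for separate q/k/v linears.
concat_linear = nn.Linear(hidden_size, num_concat * hidden_size)

x = torch.randn(batch, seq, hidden_size)
packed = concat_linear(x)  # [batch, seq, num_concat * hidden_size]

# Consumer-side split: slice the packed output back into per-projection tensors,
# each shaped as a standalone q/k/v projection would have produced.
query, key, value = packed.split(hidden_size, dim=-1)
assert query.shape == (batch, seq, hidden_size)

# Alternatively, the packed tensor can be handed downstream untouched and split
# later (e.g. inside a fused RoPE step), which is what the non-WOQ branch does.
```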
diff --git a/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py b/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py
index d1fc77490..7272c1b6e 100644
--- a/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/cpu/fusions/mha_fusion.py
@@ -26,21 +26,32 @@ def forward(
         offset: int,
         rotary_ndims: int,
         seq_len: Optional[int] = None,
+        num_concats: Optional[int] = None,
     ):
         position_ids = position_ids.contiguous()
         sin_cos, _, _ = self.embed_positions(seq_len)
-        # ToDo: when the input is concat_qkv, the output will be (query, key, value)
-        x = x.contiguous()
-        query, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
-            x,
-            sin_cos,
-            position_ids,
-            num_head,
-            head_dim,
-            offset,
-            rotary_ndims,
-        )
-        return query
+        if num_concats is None:
+            x, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
+                x,
+                sin_cos,
+                position_ids,
+                num_head,
+                head_dim,
+                offset,
+                rotary_ndims,
+            )
+            return x
+        else:
+            query, key, value = torch.ops.torch_ipex.rotary_position_embedding(
+                x,
+                sin_cos,
+                position_ids,
+                num_head,
+                head_dim,
+                offset,
+                rotary_ndims,
+            )
+            return query, key, value
 
 
 class _IPEXScaleDotProductCPU(nn.Module):
diff --git a/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py b/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
index cc198628a..f39e3e2f5 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/fusions/linear_fusion.py
@@ -2,6 +2,7 @@
 from torch import nn
 import math
 import copy
+from intel_extension_for_pytorch.nn.modules import IpexWoqLinear
 
 
 class _IPEXlinearSiluRef(nn.Module):
@@ -85,6 +86,24 @@ def __init__(self, linear_list: list):
         for i in range(self.num_concat):
             attr_name = f"linear_{i}"
             setattr(self, attr_name, copy.deepcopy(linear_list[i]))
+        self.concat_linear = None
+        self._num_concats = None
+        if all(not isinstance(linear, IpexWoqLinear) for linear in linear_list):
+            weights_list = []
+            bias_list = []
+            for i in range(self.num_concat):
+                weights_list.append(linear_list[i].weight)
+                if linear_list[i].bias is not None:
+                    bias_list.append(linear_list[i].bias)
+            concat_weight = torch.concat(weights_list, 0)
+            use_bias = True if bias_list is None else False
+            concat_bias = torch.concat(bias_list, 0) if use_bias else None
+            self.concat_linear = nn.Linear(
+                concat_weight.shape[1], concat_weight.shape[0], use_bias
+            )
+            self.concat_linear.weight = nn.Parameter(concat_weight)
+            self.concat_linear.bias = nn.Parameter(concat_bias) if use_bias else None
+            self._num_concats = len(weights_list)
 
     def forward(self, x):
         output_list = []
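The `_IPEXConcatLinearRef.__init__` addition above folds the separate projection weights into a single `nn.Linear`. Below is a minimal sketch of that construction with the bias handled explicitly (bias is kept only when every constituent linear has one, which is my reading of the intent), plus an equivalence check against running the projections separately; `build_concat_linear` is a hypothetical helper, not IPEX API.

```python
import torch
from torch import nn


def build_concat_linear(linear_list):
    """Fold nn.Linear layers with the same in_features into one nn.Linear whose
    output equals their outputs concatenated along the last dimension."""
    concat_weight = torch.cat([lin.weight for lin in linear_list], dim=0)
    # Assumption: enable the fused bias only if every constituent layer has one.
    use_bias = all(lin.bias is not None for lin in linear_list)
    fused = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias)
    with torch.no_grad():
        fused.weight.copy_(concat_weight)
        if use_bias:
            fused.bias.copy_(torch.cat([lin.bias for lin in linear_list], dim=0))
    return fused


# Equivalence check against the unfused projections.
q_proj, k_proj, v_proj = nn.Linear(16, 16), nn.Linear(16, 16), nn.Linear(16, 16)
fused = build_concat_linear([q_proj, k_proj, v_proj])
x = torch.randn(2, 4, 16)
torch.testing.assert_close(
    fused(x), torch.cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1)
)
```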
diff --git a/intel_extension_for_pytorch/transformers/models/reference/fusions/mha_fusion.py b/intel_extension_for_pytorch/transformers/models/reference/fusions/mha_fusion.py
index 5a6d32972..7a45b3d42 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/fusions/mha_fusion.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/fusions/mha_fusion.py
@@ -124,7 +124,7 @@ def apply_rotary_pos_emb_baichuan(self, x, cos, sin, position_ids):
         x_embed = (x.float() * cos) + (self.rotate_half(x.float()) * sin)
         return x_embed.to(x.dtype)
 
-    def forward(
+    def apply_ref_rope(
         self,
         x: torch.Tensor,
         position_ids: torch.Tensor,
@@ -203,6 +203,46 @@ def forward(
             AssertionError(False, "Do not support the optimization of your model yet")
         return x
 
+    def forward(
+        self,
+        concat_x: torch.Tensor,
+        position_ids: torch.Tensor,
+        num_head: int,
+        head_dim: int,
+        offset: int,
+        rotary_ndims: int,
+        seq_len: Optional[int] = None,
+        num_concats: Optional[int] = None,
+    ):
+        if num_concats is None:
+            return self.apply_ref_rope(
+                concat_x,
+                position_ids,
+                num_head,
+                head_dim,
+                offset,
+                rotary_ndims,
+                seq_len,
+            )
+        else:
+            hidden_size = concat_x.shape[-1] // num_concats
+            query = concat_x[..., :hidden_size]
+            key = concat_x[..., hidden_size : 2 * hidden_size]
+            value = concat_x[..., 2 * hidden_size :]
+            query = self.apply_ref_rope(
+                query,
+                position_ids,
+                num_head,
+                head_dim,
+                offset,
+                rotary_ndims,
+                seq_len,
+            )
+            key = self.apply_ref_rope(
+                key, position_ids, num_head, head_dim, offset, rotary_ndims, seq_len
+            )
+            return query, key, value
+
 
 class _IPEXScaleDotProductRef(nn.Module):
     def __init__(self, module, config):
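The new reference `forward` above slices a packed `[query | key | value]` tensor and applies RoPE to the query and key slices only. Here is a self-contained toy version of that flow using a plain rotate-half RoPE (not the `torch_ipex.rotary_position_embedding` kernel); the shapes and the cos/sin construction are illustrative assumptions.

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def toy_rope(x, cos, sin):
    # x: [batch, seq, num_head, head_dim]; cos/sin broadcast over batch and heads.
    return (x * cos) + (rotate_half(x) * sin)


batch, seq, num_head, head_dim = 2, 5, 4, 8
hidden_size = num_head * head_dim

# Standard rotate-half cos/sin tables.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # [seq, head_dim]
cos = emb.cos()[None, :, None, :]  # [1, seq, 1, head_dim]
sin = emb.sin()[None, :, None, :]

# A packed [query | key | value] tensor, as a concatenated QKV projection emits.
packed = torch.randn(batch, seq, 3 * hidden_size)
query = packed[..., :hidden_size]
key = packed[..., hidden_size : 2 * hidden_size]
value = packed[..., 2 * hidden_size :]

# Rotate query and key only; value passes through unchanged.
query = toy_rope(query.reshape(batch, seq, num_head, head_dim), cos, sin)
key = toy_rope(key.reshape(batch, seq, num_head, head_dim), cos, sin)
value = value.reshape(batch, seq, num_head, head_dim)
```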
diff --git a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
index 35957665e..1005128ec 100644
--- a/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
+++ b/intel_extension_for_pytorch/transformers/models/reference/modules/attentions.py
@@ -25,34 +25,53 @@ def _GPTJAttention_forward(
     Tuple[torch.Tensor, Tuple[torch.Tensor]],
     Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
 ]:
+    concat_qkv = None
     if hasattr(self, "concat_qkv"):
-        query, key, value = self.concat_qkv(hidden_states)
+        concat_qkv = self.concat_qkv(hidden_states)
     else:
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)
 
-    query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
-    key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+    if concat_qkv is not None and type(concat_qkv) is not tuple:
+        query, key, value = self._IPEXROPE(
+            concat_qkv,
+            position_ids.contiguous(),
+            self.num_attention_heads,
+            self.head_dim,
+            1,  # neighbor elements
+            64,
+            None,
+            self.concat_qkv._num_concats,
+        )
+    else:
+        if concat_qkv is not None:
+            query, key, value = concat_qkv
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+
+        key = self._IPEXROPE(
+            key,
+            position_ids.contiguous(),
+            self.num_attention_heads,
+            self.head_dim,
+            1,  # neighbor elements
+            64,
+        )
+        query = self._IPEXROPE(
+            query,
+            position_ids.contiguous(),
+            self.num_attention_heads,
+            self.head_dim,
+            1,
+            64,
+        )
+    if use_cache:
+        value = self._split_heads(
+            value, self.num_attention_heads, self.head_dim, True
+        )
 
-    key = self._IPEXROPE(
-        key,
-        position_ids.contiguous(),
-        self.num_attention_heads,
-        self.head_dim,
-        1,  # neighbor elements
-        64,
-    )
-    query = self._IPEXROPE(
-        query,
-        position_ids.contiguous(),
-        self.num_attention_heads,
-        self.head_dim,
-        1,
-        64,
-    )
     if use_cache:
-        value = self._split_heads(value, self.num_attention_heads, self.head_dim, True)
         (
             attn_output,
             attn_weights,
@@ -112,36 +131,53 @@ def _LlamaAttention_forward(
     use_cache: bool = False,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
+    concat_qkv = None
     if hasattr(self, "concat_qkv"):
-        query, key, value = self.concat_qkv(hidden_states)
+        concat_qkv = self.concat_qkv(hidden_states)
     else:
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)
-    query = query.view(bsz, q_len, self.num_heads, self.head_dim)
-    key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+
     kv_seq_len = (
         q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
     )
-    key = self._IPEXROPE(
-        key,
-        position_ids,
-        self.num_key_value_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
-    query = self._IPEXROPE(
-        query,
-        position_ids,
-        self.num_heads,
-        self.head_dim,
-        self.head_dim // 2,
-        self.head_dim,
-        kv_seq_len,
-    )
+
+    if concat_qkv is not None and type(concat_qkv) is not tuple:
+        query, key, value = self._IPEXROPE(
+            concat_qkv,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+            self.concat_qkv._num_concats,
+        )
+    else:
+        if concat_qkv is not None:
+            query, key, value = concat_qkv
+        query = query.view(bsz, q_len, self.num_heads, self.head_dim)
+        key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key = self._IPEXROPE(
+            key,
+            position_ids,
+            self.num_key_value_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
+        query = self._IPEXROPE(
+            query,
+            position_ids,
+            self.num_heads,
+            self.head_dim,
+            self.head_dim // 2,
+            self.head_dim,
+            kv_seq_len,
+        )
 
     if use_cache:
         (attn_output, attn_weights, past_key_value) = self._IPEXScaleDotProduct(
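For the attention-level dispatch above (GPT-J and Llama), the concat module's output is either a single packed tensor, which goes through `_IPEXROPE` once with `_num_concats`, or a `(query, key, value)` tuple from the reference path. A minimal sketch of that tuple-vs-packed normalization in plain PyTorch follows; `unpack_qkv` is a hypothetical helper for illustration, not part of IPEX.

```python
import torch


def unpack_qkv(concat_qkv, hidden_size):
    """Normalize a concatenated-QKV output to a (query, key, value) tuple."""
    if isinstance(concat_qkv, tuple):
        # Reference path: the module already returned separate tensors.
        return concat_qkv
    # Fused path: a single packed tensor, sliced along the last dimension.
    return (
        concat_qkv[..., :hidden_size],
        concat_qkv[..., hidden_size : 2 * hidden_size],
        concat_qkv[..., 2 * hidden_size :],
    )


packed = torch.randn(2, 4, 3 * 16)
q, k, v = unpack_qkv(packed, 16)
q2, k2, v2 = unpack_qkv((q, k, v), 16)
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
```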