
Enable BF16 Concat linear for GPTJ/LLAMA models (#2349)
* Add tpp concat linear

* refinement

* Update linear_fusion.py

* refinement

* fix woq

* enable concat linear with concat rope

* code refine and pass ut

* flake8 format

* code refinement

* minor fix

---------

Co-authored-by: liangan1 <liangang.zhang@intel.com>
jianan-gu and liangan1 committed Dec 18, 2023
1 parent 1364dc5 commit d6d5919
Showing 6 changed files with 194 additions and 73 deletions.
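
For orientation, the "concat linear" enabled by this commit fuses the separate q/k/v projection layers into a single linear layer so that one GEMM produces all three projections, which are then split (or consumed directly by the fused RoPE kernel). The sketch below is a minimal pure-PyTorch illustration of that equivalence only; it is not the IPEX implementation, and all names and sizes are illustrative.

```python
import torch
from torch import nn

# Illustrative size only, not tied to a specific model.
hidden_size = 64
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Stack the three weight matrices along the output dimension, mirroring the
# idea behind the concat-linear modules touched in this diff.
concat = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
concat.weight = nn.Parameter(
    torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
)

x = torch.randn(2, 8, hidden_size)
fused = concat(x)                       # one GEMM instead of three
q, k, v = fused.split(hidden_size, dim=-1)
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)
```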
4 changes: 2 additions & 2 deletions examples/cpu/inference/python/llm/README.md
@@ -209,9 +209,9 @@ OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m meta-llama/Llama-2-7b-hf --ipex-smooth-quant --qconfig-summary-file <path to "llama-2-7b_qconfig.json"> --output-dir "saved_results" --int8-bf16-mixed
```

- We provide the downloading links of tuned static quantization qconfig summary files with good quality: ["meta-llama/Llama-2-7b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b_qconfig.json), ["meta-llama/Llama-2-7b-chat-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b-chat_qconfig.json), ["meta-llama/Llama-2-13b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-13b_qconfig.json) and ["EleutherAI/gpt-j-6b"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/gpt-j-6b_qconfig.json).
- We provide the downloading links of tuned static quantization qconfig summary files with good quality: ["meta-llama/Llama-2-7b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b_qconfig.json), ["meta-llama/Llama-2-7b-chat-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-7b-chat_qconfig.json), ["meta-llama/Llama-2-13b-hf"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/llama-2-13b_qconfig.json) and ["EleutherAI/gpt-j-6b"](https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/gpt-j-6b_qconfig.json). These qconfig files are verified against specific versions of Transformers/PyTorch/IPEX so that they work out of the box, but their compatibility may break as those codebases change. If an existing qconfig file fails in your environment, please refer to the [Intel® Neural Compressor scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex) to regenerate it for your setup.

- For other models' qconfig recipes, you can just try to run your model_id and use IPEX default recipes by removing "--qconfig-summary-file <path to specific model qconfig>". If IPEX default recipes are not good enough for accuracy requirements, please refer to the [Intel® Neural Compressor tutorial](https://github.com/intel/neural-compressor/blob/master/docs/source/smooth_quant.md#validated-models) for more tuned recipes.
- For other models' qconfig recipes, you can simply run with your model_id and the IPEX default recipes by removing "--qconfig-summary-file <path to specific model qconfig>". If the IPEX default recipes do not meet your accuracy requirements, please refer to the [Intel® Neural Compressor tutorial](https://github.com/intel/neural-compressor/blob/master/docs/source/smooth_quant.md#validated-models) and [scripts](https://github.com/intel/neural-compressor/tree/master/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex) for more tuned recipes.

#### Weight-only quantization:

@@ -2,7 +2,6 @@
from torch import nn
import math
import warnings
import copy
from intel_extension_for_pytorch.nn.modules import IpexWoqLinear
from intel_extension_for_pytorch.quantization import (
get_weight_only_quant_qconfig_mapping,
@@ -221,12 +220,15 @@ def __init__(self, module, tpp=False, woq=False):
super().__init__(module.linear_0, tpp=tpp, woq=woq)
assert hasattr(module, "num_concat")
self.num_concat = module.num_concat
self.linear_list = []
for i in range(self.num_concat):
attr_name = f"linear_{i}"
assert hasattr(module, attr_name)
self.linear_list.append(getattr(module, attr_name))
self.concat_linear = None
self.linear_list = []
self.woq = woq
self.tpp = tpp
if woq:
for i in range(self.num_concat):
attr_name = f"linear_{i}"
assert hasattr(module, attr_name)
self.linear_list.append(getattr(module, attr_name))
if woq and all(
isinstance(linear, IpexWoqLinear) for linear in self.linear_list
):
@@ -303,6 +305,7 @@ def __init__(self, module, tpp=False, woq=False):
mod.bias = nn.Parameter(concat_bias) if use_bias else None
mod.qconfig = qconfig
mod._num_concats = len(weights_list)
self._num_concats = mod._num_concats
if w_dtype == torch.quint4x2:
self.concat_linear = IpexWoqLinear.from_float_and_int4_weight(
mod,
@@ -317,20 +320,32 @@ def __init__(self, module, tpp=False, woq=False):
mod, concat_scales, concat_zeros
)
else:
for i in range(self.num_concat):
attr_name = f"linear_{i}"
setattr(self, attr_name, copy.deepcopy(getattr(module, attr_name)))
self._num_concats = module._num_concats
if (
self.tpp
and hasattr(module, "concat_linear")
and module.concat_linear is not None
):
self.concat_linear = module.concat_linear
else:
for i in range(self.num_concat):
attr_name = f"linear_{i}"
setattr(self, attr_name, getattr(module, attr_name))

def forward(self, x):
if self.concat_linear is not None:
num_concats = self.concat_linear._num_concats
concat_output = self.concat_linear(x)
hidden_size = concat_output.shape[-1] // num_concats
concat_output = concat_output.view(num_concats, -1, hidden_size)
expected_shape = list(x.shape)[:-1] + [hidden_size]
return tuple(
[concat_output[i].view(expected_shape) for i in range(num_concats)]
)
if self.woq:
num_concats = self._num_concats
hidden_size = concat_output.shape[-1] // num_concats
concat_output = concat_output.view(num_concats, -1, hidden_size)
expected_shape = list(x.shape)[:-1] + [hidden_size]
return tuple(
[concat_output[i].view(expected_shape) for i in range(num_concats)]
)
else:
return concat_output

output_list = []
for i in range(self.num_concat):
assert hasattr(self, f"linear_{i}")
@@ -26,21 +26,32 @@ def forward(
offset: int,
rotary_ndims: int,
seq_len: Optional[int] = None,
num_concats: Optional[int] = None,
):
position_ids = position_ids.contiguous()
sin_cos, _, _ = self.embed_positions(seq_len)
# ToDo: when the input is concat_qkv, the output will be (query, key, value)
x = x.contiguous()
query, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
x,
sin_cos,
position_ids,
num_head,
head_dim,
offset,
rotary_ndims,
)
return query
if num_concats is None:
x, _, _ = torch.ops.torch_ipex.rotary_position_embedding(
x,
sin_cos,
position_ids,
num_head,
head_dim,
offset,
rotary_ndims,
)
return x
else:
query, key, value = torch.ops.torch_ipex.rotary_position_embedding(
x,
sin_cos,
position_ids,
num_head,
head_dim,
offset,
rotary_ndims,
)
return query, key, value


class _IPEXScaleDotProductCPU(nn.Module):
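
The `num_concats` argument added to the RoPE module above lets the kernel accept the whole concatenated QKV projection and return `(query, key, value)` in one call; the `offset` argument appears to select the rotation layout (1 for GPT-J's neighboring-pair rotation, `head_dim // 2` for LLaMA's rotate-half). As a reference point only, here is a self-contained pure-PyTorch sketch of that concat path using the rotate-half variant; it is not the `torch_ipex` kernel, and the shapes and layout are assumptions for illustration.

```python
import torch


def rotate_half(x):
    # Standard "rotate half" helper used by LLaMA-style RoPE.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_on_concat_qkv(concat_qkv, cos, sin, num_head, head_dim, num_concats=3):
    """Split a concatenated QKV projection and rotate q/k (illustrative only)."""
    bsz, seq, _ = concat_qkv.shape
    hidden = concat_qkv.shape[-1] // num_concats
    query, key, value = concat_qkv.split(hidden, dim=-1)
    query = query.reshape(bsz, seq, num_head, head_dim)
    key = key.reshape(bsz, seq, num_head, head_dim)
    value = value.reshape(bsz, seq, num_head, head_dim)
    # cos/sin: (seq, head_dim), broadcast over batch and heads.
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    query = query * cos + rotate_half(query) * sin
    key = key * cos + rotate_half(key) * sin
    return query, key, value


# Toy usage with illustrative sizes.
bsz, seq, num_head, head_dim = 1, 4, 2, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)            # (seq, head_dim)
q, k, v = rope_on_concat_qkv(
    torch.randn(bsz, seq, 3 * num_head * head_dim),
    emb.cos(), emb.sin(), num_head, head_dim,
)
```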
@@ -2,6 +2,7 @@
from torch import nn
import math
import copy
from intel_extension_for_pytorch.nn.modules import IpexWoqLinear


class _IPEXlinearSiluRef(nn.Module):
@@ -85,6 +86,24 @@ def __init__(self, linear_list: list):
for i in range(self.num_concat):
attr_name = f"linear_{i}"
setattr(self, attr_name, copy.deepcopy(linear_list[i]))
self.concat_linear = None
self._num_concats = None
if all(not isinstance(linear, IpexWoqLinear) for linear in linear_list):
weights_list = []
bias_list = []
for i in range(self.num_concat):
weights_list.append(linear_list[i].weight)
if linear_list[i].bias is not None:
bias_list.append(linear_list[i].bias)
concat_weight = torch.concat(weights_list, 0)
use_bias = len(bias_list) == self.num_concat  # use a bias only if every sub-linear has one
concat_bias = torch.concat(bias_list, 0) if use_bias else None
self.concat_linear = nn.Linear(
concat_weight.shape[1], concat_weight.shape[0], use_bias
)
self.concat_linear.weight = nn.Parameter(concat_weight)
self.concat_linear.bias = nn.Parameter(concat_bias) if use_bias else None
self._num_concats = len(weights_list)

def forward(self, x):
output_list = []
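
As a quick sanity check on the reference concat path above, including the bias handling: concatenating the weights and biases row-wise reproduces the outputs of the separate linears. This is a minimal illustrative sketch, not the IPEX module; sizes are arbitrary.

```python
import torch
from torch import nn

hidden = 32
linears = [nn.Linear(hidden, hidden, bias=True) for _ in range(3)]

concat = nn.Linear(hidden, 3 * hidden, bias=True)
concat.weight = nn.Parameter(torch.cat([lin.weight for lin in linears], dim=0))
concat.bias = nn.Parameter(torch.cat([lin.bias for lin in linears], dim=0))

x = torch.randn(5, hidden)
for out, lin in zip(concat(x).split(hidden, dim=-1), linears):
    assert torch.allclose(out, lin(x), atol=1e-5)
```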
@@ -124,7 +124,7 @@ def apply_rotary_pos_emb_baichuan(self, x, cos, sin, position_ids):
x_embed = (x.float() * cos) + (self.rotate_half(x.float()) * sin)
return x_embed.to(x.dtype)

def forward(
def apply_ref_rope(
self,
x: torch.Tensor,
position_ids: torch.Tensor,
@@ -203,6 +203,46 @@ def forward(
raise AssertionError("Do not support the optimization of your model yet")
return x

def forward(
self,
concat_x: torch.Tensor,
position_ids: torch.Tensor,
num_head: int,
head_dim: int,
offset: int,
rotary_ndims: int,
seq_len: Optional[int] = None,
num_concats: Optional[int] = None,
):
if num_concats is None:
return self.apply_ref_rope(
concat_x,
position_ids,
num_head,
head_dim,
offset,
rotary_ndims,
seq_len,
)
else:
hidden_size = concat_x.shape[-1] // num_concats
query = concat_x[..., :hidden_size]
key = concat_x[..., hidden_size : 2 * hidden_size]
value = concat_x[..., 2 * hidden_size :]
query = self.apply_ref_rope(
query,
position_ids,
num_head,
head_dim,
offset,
rotary_ndims,
seq_len,
)
key = self.apply_ref_rope(
key, position_ids, num_head, head_dim, offset, rotary_ndims, seq_len
)
return query, key, value


class _IPEXScaleDotProductRef(nn.Module):
def __init__(self, module, config):
@@ -25,34 +25,53 @@ def _GPTJAttention_forward(
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
concat_qkv = None
if hasattr(self, "concat_qkv"):
query, key, value = self.concat_qkv(hidden_states)
concat_qkv = self.concat_qkv(hidden_states)
else:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)

query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
if concat_qkv is not None and type(concat_qkv) is not tuple:
query, key, value = self._IPEXROPE(
concat_qkv,
position_ids.contiguous(),
self.num_attention_heads,
self.head_dim,
1, # neighbor elements
64,
None,
self.concat_qkv._num_concats,
)
else:
if concat_qkv is not None:
query, key, value = concat_qkv
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)

key = self._IPEXROPE(
key,
position_ids.contiguous(),
self.num_attention_heads,
self.head_dim,
1, # neighbor elements
64,
)
query = self._IPEXROPE(
query,
position_ids.contiguous(),
self.num_attention_heads,
self.head_dim,
1,
64,
)
if use_cache:
value = self._split_heads(
value, self.num_attention_heads, self.head_dim, True
)

key = self._IPEXROPE(
key,
position_ids.contiguous(),
self.num_attention_heads,
self.head_dim,
1, # neighbor elements
64,
)
query = self._IPEXROPE(
query,
position_ids.contiguous(),
self.num_attention_heads,
self.head_dim,
1,
64,
)
if use_cache:
value = self._split_heads(value, self.num_attention_heads, self.head_dim, True)
(
attn_output,
attn_weights,
@@ -112,36 +131,53 @@ def _LlamaAttention_forward(
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
concat_qkv = None
if hasattr(self, "concat_qkv"):
query, key, value = self.concat_qkv(hidden_states)
concat_qkv = self.concat_qkv(hidden_states)
else:
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

kv_seq_len = (
q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
)
key = self._IPEXROPE(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self._IPEXROPE(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if concat_qkv is not None and type(concat_qkv) is not tuple:
query, key, value = self._IPEXROPE(
concat_qkv,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
self.concat_qkv._num_concats,
)
else:
if concat_qkv is not None:
query, key, value = concat_qkv
query = query.view(bsz, q_len, self.num_heads, self.head_dim)
key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
key = self._IPEXROPE(
key,
position_ids,
self.num_key_value_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)
query = self._IPEXROPE(
query,
position_ids,
self.num_heads,
self.head_dim,
self.head_dim // 2,
self.head_dim,
kv_seq_len,
)

if use_cache:
(attn_output, attn_weights, past_key_value) = self._IPEXScaleDotProduct(
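
To make the shapes concrete, consider GPT-J-6B's published configuration: hidden size 4096, 16 attention heads of dimension 256, and a rotary dimension of 64, which matches the literal `64` passed to the RoPE calls in this diff. The arithmetic below is only a worked illustration, and the final per-head shape assumes the fused RoPE returns tensors laid out like the non-fused path.

```python
# Worked shape example for GPT-J-6B (illustrative only).
hidden_size = 4096
num_heads = 16
head_dim = hidden_size // num_heads           # 256
rotary_ndims = 64                             # matches the literal in the calls above
num_concats = 3                               # q, k, v

bsz, seq = 1, 128
concat_qkv_last_dim = num_concats * hidden_size        # 12288: one GEMM output
per_projection = concat_qkv_last_dim // num_concats    # 4096 each for q/k/v
assert per_projection == num_heads * head_dim

# After the fused RoPE, query/key/value are expected to be
# (bsz, seq, num_heads, head_dim) == (1, 128, 16, 256),
# with rotation applied to the first `rotary_ndims` features of each head.
print((bsz, seq, num_heads, head_dim))
```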
