Skip to content

Commit

Permalink
Support starcoder MHA fusion (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss committed Oct 12, 2023
1 parent 900ebf4 commit 841b29a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ cd intel-extension-for-transformers
pip install -r requirements.txt
python setup.py install
```

Here is how to install intel-extension-for-pytorch from source.
```shell
# gcc version >= 11
Expand Down Expand Up @@ -65,9 +66,15 @@ python run_generation.py \
```

## 2. Performance

```bash
export KMP_BLOCKTIME=1
export KMP_SETTINGS=1
export KMP_AFFINITY=granularity=fine,compact,1,0
export LD_PRELOAD=${CONDA_PREFIX}/lib/libiomp5.so
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
# --int8 is used for int8 model
python run_generation.py \
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <cpu list> python run_generation.py \
--model bigcode/starcoderbase \
--output_dir "./saved_results" \
--int8 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def calib_func(prepared_model):


if args.benchmark:
from numpy import mean
print("---- Prompt size:", args.prompt_size)

normalized_config = NormalizedConfigManager.get_normalized_config_class(
user_model.config.model_type
)(user_model.config)
Expand All @@ -319,53 +319,47 @@ def calib_func(prepared_model):

num_iter = args.iters
num_warmup = args.num_warmup
total_time = 0.0
first_token_time = []
second_token_time = []
for i in range(num_iter):
print("Interation index:", i)
input_ids = torch.randint(1, tokenizer.vocab_size, size = (args.batch_size, args.prompt_size))

total_latency = 0
for j in range(args.max_new_tokens):
total_time = 0.0
with torch.inference_mode(), torch.no_grad():
for j in range(args.max_new_tokens):

for i in range(num_iter):
tic = time.time()
if j == 0:
if j==0:
#input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = torch.randint(1, tokenizer.vocab_size, size = (args.batch_size , args.prompt_size))
attention_mask = torch.ones(input_ids.shape)
new_shape = [input_ids.shape[0], 0, d_k*2]
dummy_tensor = torch.empty(size=new_shape)
dummy_tensor = torch.ones(size=new_shape)
past_key_values = tuple([dummy_tensor] * num_layers)
input_bs, input_len = input_ids.shape
attention_mask = torch.ones(input_bs, input_len)

inp = {"input_ids": input_ids,
"past_key_values": past_key_values,
"attention_mask": attention_mask}

out = user_model(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
out = user_model(**inp)
gen_id = torch.argmax(out[0][:, -1:, :], axis = -1)
gen_text = tokenizer.batch_decode(gen_id, skip_special_tokens=True)
toc = time.time()
#print(gen_text, flush=True)
if i >= num_warmup:
total_time += toc - tic
if i >= num_warmup and j == 0:
first_token_latency = toc - tic
print("The first token inference latency: %.5f sec." % first_token_latency)
first_token_time.append(first_token_latency)
if i >= num_warmup and j == 1:
second_token_latency = toc - tic
print("The second token inference latency: %.5f sec." % second_token_latency)
second_token_time.append(second_token_latency)

input_ids = gen_id
past_key_values = out[1]
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1] + 1)


print("\n", "-" * 10, "Summary:", "-" * 10)
print("The first token inference average latency: %.3f sec." % mean(first_token_time))
print("The second token inference average latency: %.3f sec." % mean(second_token_time))
latency = total_time / (num_iter - num_warmup)
print("Inference latency: %.3f sec." % latency)
throughput = (num_iter - num_warmup) / total_time
print("Throughput: {} samples/sec".format(throughput))

print("\n", "-" * 10, "Summary:", "-" * 10)
print("Generated token index:", j+1)
latency = total_time / (num_iter - num_warmup)
print("Inference latency: %.5f sec." % latency)
throughput = (num_iter - num_warmup) / total_time
print("Throughput: {} samples/sec".format(throughput))

total_latency += latency
average_latency = total_latency / args.max_new_tokens
print("Average inference latency: %.5f sec." % latency)
average_throughput = args.max_new_tokens / total_latency
print("Average throughput: {} samples/sec".format(throughput))


if args.accuracy:
from intel_extension_for_transformers.llm.evaluation.lm_code_eval import evaluate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
# limitations under the License.

# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -158,14 +157,17 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
query_shape = query.shape
batch_size = query_shape[0]
key_length = key.size(-1)
value_shape = value.shape
if self.multi_query:
# (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
# -> (batch_size, query_length, num_heads, key_length)
query_length = query_shape[1]
attn_shape = (batch_size, query_length, self.num_heads, key_length)
attn_view = (batch_size, query_length * self.num_heads, key_length)
# No copy needed for MQA 2, or when layer_past is provided.
query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
key = key.reshape(batch_size, 1, self.head_dim, key_length)
value = value.reshape(batch_size, 1, value_shape[-2], value_shape[-1])
else:
# (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
# -> (batch_size, num_heads, query_length, key_length)
Expand All @@ -177,7 +179,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# No copy when layer_past is provided.
key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
attn_weights = torch.empty(attn_shape, device=query.device, dtype=query.dtype)
if query.device.type == "cpu":
# This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
# The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
Expand All @@ -186,8 +188,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
beta = 1
else:
beta = 0
attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=scale_factor).view(attn_shape)


attn_weights = scale_factor * torch.matmul(query, key) # + beta * attn_weights (not needed, it is 0)

if upcast:
# Use a fused kernel to prevent a large overhead from casting and scaling.
# Sub-optimal when the key length is not a multiple of 8.
Expand All @@ -214,7 +217,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = attn_weights * head_mask

if self.multi_query:
attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
attn_output = torch.matmul(attn_weights, value).view(query_shape)
else:
attn_output = torch.matmul(attn_weights, value)

Expand Down Expand Up @@ -402,8 +405,8 @@ def _init_weights(self, module):
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth.
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the
# > of residual layers.
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
# > residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
Expand All @@ -425,8 +428,8 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with
# GPT2->GPTBigCode
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing
# with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing = value
Expand Down Expand Up @@ -497,15 +500,15 @@ def _set_gradient_checkpointing(self, module, value=False):
- 0 indicates the head is **masked**.
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand Down Expand Up @@ -737,8 +740,8 @@ def custom_forward(*inputs):

@add_start_docstrings(
"""
The GPT_BIGCODE Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings).
The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
GPT_BIGCODE_START_DOCSTRING,
)
Expand Down Expand Up @@ -888,10 +891,10 @@ def _reorder_cache(
models (e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess
the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value
in each row of the batch).
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
GPT_BIGCODE_START_DOCSTRING,
)
Expand Down

0 comments on commit 841b29a

Please sign in to comment.