improve code-generation benchmark (#89)
Signed-off-by: changwangss <chang1.wang@intel.com>
changwangss committed Aug 2, 2023
1 parent ff2074d commit 0b3450b
Showing 3 changed files with 54 additions and 21 deletions.
@@ -71,6 +71,7 @@ python run_generation.py \
--int8 \
--ipex \
--benchmark \
--prompt_size 32 \
--batch_size 1
```
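
The new `--prompt_size` flag sets the length of the dummy prompt used when `--benchmark` is passed: instead of tokenizing a fixed string, the script now draws random token ids of that length (see the `run_generation.py` changes below). A minimal sketch of the idea, with an illustrative vocabulary size (the script takes it from the tokenizer):

```python
import torch

batch_size, prompt_size = 1, 32  # mirrors --batch_size 1 --prompt_size 32 above
vocab_size = 49152               # illustrative; the script uses tokenizer.vocab_size

# Dummy benchmark input: random token ids of shape (batch_size, prompt_size)
input_ids = torch.randint(1, vocab_size, size=(batch_size, prompt_size))
```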

@@ -13,9 +13,9 @@ build:
test:
docker run -v $(CURDIR):$(CURDIR) \
-it $(IMAGE_NAME) python3 run_generation.py --model $(CURDIR)/starcoder-3b --quantize --sq --alpha 0.7 --ipex \
--calib_iters 5 --calib_batch_size 1 --dataset "mbpp" --calib_split "test" --output_dir "$(CURDIR)/saved_results" \
--calib_iters 10 --calib_batch_size 1 --dataset "mbpp" --calib_split "test" --output_dir "$(CURDIR)/saved_results" \
--int8 --accuracy --tasks multiple-lua --batch_size 20 --n_samples 20 --allow_code_execution \
--do_sample --temperature 0.2 --limit 2
--do_sample --temperature 0.2 --limit 10

@echo "If pass@1 is 0.25 then your configuration for standard benchmarks is correct"

@@ -55,6 +55,7 @@
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--iters", default=100, type=int, help="num iter")
parser.add_argument("--num_warmup", default=10, type=int, help="num warmup")
parser.add_argument("--prompt_size", default=32, type=int, help="generate dummy input_ids size")
parser.add_argument("--accuracy", action="store_true")
parser.add_argument("--batch_size", default=1, type=int,
help="batch size num.")
@@ -294,31 +295,62 @@ def calib_func(prepared_model):


if args.benchmark:
prompt = "def print_hello_world():"
from numpy import mean
print("---- Prompt size:", args.prompt_size)
normalized_config = NormalizedConfigManager.get_normalized_config_class(
user_model.config.model_type
)(user_model.config)

input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
print("---- Prompt size:", input_size)
num_layers = normalized_config.num_layers
num_attention_heads = normalized_config.num_attention_heads
hidden_size = normalized_config.hidden_size
d_k = hidden_size // num_attention_heads

# start
total_time = 0.0
num_iter = args.iters
num_warmup = args.num_warmup
prompt = [prompt] * args.batch_size

with torch.inference_mode(), torch.no_grad():
for i in range(num_iter):
tic = time.time()
input_ids = tokenizer.encode(prompt, return_tensors="pt")
output = user_model.generate(
input_ids, min_new_tokens=args.max_new_tokens, max_new_tokens=args.max_new_tokens
)
gen_code = tokenizer.batch_decode(output[0])
toc = time.time()
print(gen_code)
if i >= num_warmup:
total_time += toc - tic
total_time = 0.0
first_token_time = []
second_token_time = []
for i in range(num_iter):
print("Interation index:", i)
input_ids = torch.randint(1, tokenizer.vocab_size, size = (args.batch_size, args.prompt_size))
with torch.inference_mode(), torch.no_grad():
for j in range(args.max_new_tokens):
tic = time.time()
if j == 0:
new_shape = [input_ids.shape[0], 0, d_k*2]
dummy_tensor = torch.empty(size=new_shape)
past_key_values = tuple([dummy_tensor] * num_layers)
input_bs, input_len = input_ids.shape
attention_mask = torch.ones(input_bs, input_len)

inp = {"input_ids": input_ids,
"past_key_values": past_key_values,
"attention_mask": attention_mask}

out = user_model(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask)
gen_id = torch.argmax(out[0][:, -1:, :], axis = -1)
gen_text = tokenizer.batch_decode(gen_id, skip_special_tokens=True)
toc = time.time()
if i >= num_warmup:
total_time += toc - tic
if i >= num_warmup and j == 0:
first_token_latency = toc - tic
print("The first token inference latency: %.5f sec." % first_token_latency)
first_token_time.append(first_token_latency)
if i >= num_warmup and j == 1:
second_token_latency = toc - tic
print("The second token inference latency: %.5f sec." % second_token_latency)
second_token_time.append(second_token_latency)

input_ids = gen_id
past_key_values = out[1]
attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1] + 1)


print("\n", "-" * 10, "Summary:", "-" * 10)
print("The first token inference average latency: %.3f sec." % mean(first_token_time))
print("The second token inference average latency: %.3f sec." % mean(second_token_time))
latency = total_time / (num_iter - num_warmup)
print("Inference latency: %.3f sec." % latency)
throughput = (num_iter - num_warmup) / total_time
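
For reference, the loop above measures the first (prefill) step, which processes the whole dummy prompt, separately from the later decode steps, which feed back only the newest token together with the cached `past_key_values`. Below is a minimal self-contained sketch of that pattern with a generic Hugging Face causal LM; the `gpt2` checkpoint and the loop bounds are illustrative, not this script's actual configuration:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")      # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

batch_size, prompt_size, max_new_tokens = 1, 32, 8
input_ids = torch.randint(1, tokenizer.vocab_size, size=(batch_size, prompt_size))
attention_mask = torch.ones_like(input_ids)
past_key_values = None
latencies = []

with torch.inference_mode():
    for step in range(max_new_tokens):
        tic = time.time()
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True)
        # Greedy choice of the next token from the last position's logits.
        next_id = out.logits[:, -1:, :].argmax(dim=-1)
        latencies.append(time.time() - tic)
        # After the prefill step, feed only the new token and reuse the KV cache.
        input_ids = next_id
        past_key_values = out.past_key_values
        attention_mask = torch.cat(
            [attention_mask, torch.ones(batch_size, 1, dtype=attention_mask.dtype)], dim=-1)

print("first-token latency:  %.5f sec" % latencies[0])
print("second-token latency: %.5f sec" % latencies[1])
```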
