Skip to content

Commit

Permalink
[inference] Fix running time of test_continuous_batching (#5750)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanheng-zhao committed May 24, 2024
1 parent 5f8c0a0 commit b96c639
Showing 1 changed file with 26 additions and 58 deletions.
84 changes: 26 additions & 58 deletions tests/test_infer/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn

Expand All @@ -28,69 +28,37 @@ def generate_inputs(num_sequences, min_length, max_length):
return sequences


@parameterize(
"test_config",
[
{
"max_batch_size": 8,
"max_output_len": 512,
"max_input_len": 64,
"do_sample": False,
}
],
)
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
@parameterize("n_multiple", [10])
@parameterize("max_batch_size", [8])
@parameterize("max_input_len", [128])
@parameterize("max_output_len", [128])
def check_inference_engine(n_multiple, max_batch_size, max_input_len, max_output_len):
setup_seed(20)
max_batch_size = test_config["max_batch_size"]
max_input_len = test_config["max_input_len"]
max_output_len = test_config["max_output_len"]
do_sample = test_config["do_sample"]
top_p = 0.5
top_k = 50
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()
model = model.eval()

inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)

if use_engine:
inference_config = InferenceConfig(
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == max_output_len
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
assert inference_engine.request_handler._has_waiting()
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
outputs = inference_engine.generate(generation_config=generation_config)
else:
if prompt_template:
# apply prompt template
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
inputs = inputs.cuda()
generation_config = GenerationConfig(
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=max_output_len,
)
outputs = model.generate(inputs, generation_config=generation_config)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
assert len(outputs) == 10 * max_batch_size


@parameterize("prompt_template", [None, "llama"])
def check_continuous_batching(prompt_template):
check_inference_engine(use_engine=True, prompt_template=prompt_template)
inputs_token_ids = generate_inputs(
n_multiple * max_batch_size, min_length=max_input_len // 2, max_length=max_input_len
)
inference_config = InferenceConfig(
max_batch_size=max_batch_size, max_input_len=max_input_len, max_output_len=max_output_len
)
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == max_output_len

inference_engine.add_request(prompts_token_ids=inputs_token_ids)
assert inference_engine.request_handler._has_waiting()

outputs = inference_engine.generate()
assert not inference_engine.request_handler._has_waiting()
assert len(outputs) == n_multiple * max_batch_size


def run_dist(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
check_continuous_batching()
check_inference_engine()


@pytest.mark.dist
Expand Down

0 comments on commit b96c639

Please sign in to comment.