[fix] Device and format and implementation optimization (#1055)
KexinFeng authored Sep 8, 2023
1 parent ef9712a commit ed03dac
Showing 7 changed files with 94 additions and 20 deletions.
7 changes: 5 additions & 2 deletions .gitignore
@@ -22,9 +22,12 @@ libs
 .classpath
 bin/
 
-#node
+# node
 node_modules/
 
-#vscode
+# vscode
 .vscode
 
+# dir
+tests/integration/models/
+engines/python/setup/djl_python/tests/resources*
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/rolling_batch/__init__.py
@@ -11,4 +11,4 @@
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
 
-from .scheduler_rolling_batch import SchedulerRollingBatch
\ No newline at end of file
+from .scheduler_rolling_batch import SchedulerRollingBatch
@@ -193,9 +193,9 @@ def postprocess_results(self, batch_size):
             res = {"data": req.get_next_token(), "last": req.is_last_token()}
             results.append(res)
 
-        for i in range(1, batch_size + 1):
-            if self.pending_requests[batch_size - i].is_last_token():
-                self.pending_requests.pop(batch_size - i)
+        self.pending_requests = [
+            req for req in self.pending_requests if not req.is_last_token()
+        ]
 
         if len(self.pending_requests) == 0:
             self.req_id_counter = 0
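Note: the rewrite above replaces the reverse-index pop loop with a single list comprehension, which avoids the index bookkeeping that popping while iterating requires. A minimal standalone sketch of the filtering behavior (the Request stub here is hypothetical, not the project's class):

    class Request:
        def __init__(self, last):
            self._last = last

        def is_last_token(self):
            return self._last

    pending = [Request(True), Request(False), Request(True)]
    # One pass, no index arithmetic: finished requests are dropped.
    pending = [req for req in pending if not req.is_last_token()]
    assert len(pending) == 1  # only the unfinished request survives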
@@ -35,20 +35,22 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         """
 
         super().__init__(device, **kwargs)
-        self._init_model_and_tokenizer(kwargs, model_id_or_path)
+        self._init_model_and_tokenizer(model_id_or_path,
+                                       device=device,
+                                       **kwargs)
         self._init_scheduler(properties)
 
     @stop_on_any_exception
-    def inference(self, input_data, parameters):
+    def inference(self, input_text, parameters):
         """
         Performs prefill and decode operations for the batch.
-        :param input_data: List of input texts for each request in a batch
+        :param input_text: List of input texts for each request in a batch
         :param parameters: List of kwargs for each request in a batch
         :return: generated batch decoded tokens
         """
-        batch_size = len(input_data)
-        new_requests = self.get_new_requests(input_data, parameters,
+        batch_size = len(input_text)
+        new_requests = self.get_new_requests(input_text, parameters,
                                              batch_size)
 
         preprocessed_new_requests = self.preprocess_requests(new_requests)
@@ -86,18 +88,21 @@ def preprocess_requests(self, requests):
 
         return new_requests
 
-    def _init_model_and_tokenizer(self, kwargs, model_id_or_path):
+    def _init_model_and_tokenizer(self,
+                                  model_id_or_path,
+                                  device=None,
+                                  **kwargs):
         self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
         architectures = self.config.architectures
         if architectures and architectures[0].endswith(
                 "ForConditionalGeneration"):
             raise ValueError('Seq2Seq model is not supported by scheduler')
         else:
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_id_or_path, **kwargs)
-
-        if self.device:
-            self.model.to(self.device)
+                model_id_or_path,
+                device_map="auto"
+                if device and device.type == "cuda" else "cpu",
+                **kwargs)
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_id_or_path,
                                                        padding_side="left")
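Note: loading with device_map replaces the earlier load-then-.to(device) pattern, so weights can be placed (or sharded) directly at load time instead of being moved afterwards. A hedged sketch of the same device-selection logic in isolation; "gpt2" is only an illustrative checkpoint, and device_map="auto" requires the accelerate package to be installed:

    import torch
    from transformers import AutoModelForCausalLM

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        # "auto" lets accelerate place/shard weights across available GPUs;
        # on a CPU-only host everything stays on "cpu".
        device_map="auto" if device.type == "cuda" else "cpu")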
@@ -155,7 +160,7 @@ def _prefill_and_decode(self, new_requests):
         for request_id, generated_token, request in zip(
                 request_ids, generated_tokens, self.pending_requests):
             is_last_token = (request_id in exit_req_ids)
-            request.set_next_token(generated_token,
+            request.set_next_token(f" {generated_token}",
                                    self.output_formatter,
                                    last_token=is_last_token)
 
@@ -0,0 +1,63 @@
from collections import defaultdict
import torch
from djl_python.rolling_batch import SchedulerRollingBatch
import torch.distributed as dist


def print_rank0(content):
    rank = 0
    if dist.is_initialized():
        rank = dist.get_rank()
    if rank == 0:
        print(content)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

properties = {
    "tensor_parallel_degree": 1,
    "dtype": "fp16",
    "max_rolling_batch_size": 8,
    "model_loading_timeout": 7200,
    "max_rolling_batch_prefill_tokens": 10000,
    "paged_attention": "True"
}

model_id = "huggyllama/llama-7b"
"""
{"inputs":"write a program to add two numbers in python","parameters":{"max_new_tokens":1000, "do_sample":true, "temperature":0.7}}
"""
input_str = [
    "write a program to add two numbers in python",
    "write a program to add two numbers in python\n"
]

params = [{
    "max_new_tokens": 50,
    "do_sample": False,
    "temperature": 0.7
}, {
    "max_new_tokens": 50,
    "do_sample": False,
    "temperature": 0.7
}]

# ===================== scheduler ============================
print("=========== scheduler ==========")
rolling_batch = SchedulerRollingBatch(model_id, device, properties)
rolling_batch.output_formatter = None
print("reach here")

output_all = defaultdict(list)
result = rolling_batch.inference(input_str, params)
for i, res in enumerate(result):
    output_all[i].append(res['data'])

for _ in range(50):
    result = rolling_batch.inference([], [])
    for i, res in enumerate(result):
        output_all[i].append(res['data'])

for i, out in enumerate(output_all.values()):
    print_rank0(input_str[i] + ''.join(out))
    print_rank0('\n====')
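Note: the test above shows how rolling batching is driven from the outside: the first inference(input_str, params) call prefills every prompt in the batch, and each subsequent inference([], []) call (no new requests) advances decoding by one token per pending request, appending each step's 'data' field to that request's output.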
@@ -30,8 +30,7 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str):
     raise ValueError(f"Invalid rolling batch type: {rolling_batch_type}")
 
 
-def init_rolling_batch(rolling_batch_type: str,
-                       model_id: str,
+def init_rolling_batch(rolling_batch_type: str, model_id: str,
                        properties: dict):
     rolling_batch_type = rolling_batch_type.lower()
     device = 0
@@ -140,7 +139,11 @@ def simulator(batcher,
 if args.properties:
     properties = json.loads(args.properties)
 else:
-    properties = {"tensor_parallel_degree": 1, "trust_remote_code": True, "engine": "Python"}
+    properties = {
+        "tensor_parallel_degree": 1,
+        "trust_remote_code": True,
+        "engine": "Python"
+    }
 if args.rollingbatch == "lmi-dist":
     dist.init_process_group("nccl")
     properties["engine"] = "MPI"
