From 8898b761e1bf21cf61b4fdd6dde4f0a37a20d060 Mon Sep 17 00:00:00 2001 From: Justin Kim Date: Tue, 9 Jul 2024 10:30:17 -0700 Subject: [PATCH] Triton deployment improvements for in-framework models (#9600) * add NemoQueryLLMPyTorch class for triton query of in-framework models * nemo_export.py changes to better support in-framework models * separate out in-framework version of triton deploy script * add generate() function to MegatronLLMDeployable to allow for direct use in export tests * use NemoQueryLLMPyTorch in deploy tests * add warning message for when MegatronLLMDeployable overrides transformer_engine * remove enable_streaming argument from deploy_inframework_triton.py since MegatronLLMDeployable does not support streaming add query_inframework.py since original query.py does not work with in-framework deployments * Apply isort and black reformatting Signed-off-by: jukim-nv * skip trtllm support check if in_framework testing * remove unused imports * run_existing_checkpoints was passing wrong prompts argument for in-framework mode * fix unused import in query_inframework.py --------- Signed-off-by: jukim-nv Co-authored-by: jukim-nv Co-authored-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com> --- nemo/deploy/nlp/__init__.py | 2 +- nemo/deploy/nlp/megatronllm_deployable.py | 20 ++- nemo/deploy/nlp/query_llm.py | 100 +++++++++++-- .../deploy/nlp/deploy_inframework_triton.py | 103 ++++++++++++++ scripts/deploy/nlp/query_inframework.py | 83 +++++++++++ tests/deploy/nemo_deploy.py | 4 +- tests/export/nemo_export.py | 134 +++++++++++------- 7 files changed, 376 insertions(+), 70 deletions(-) create mode 100755 scripts/deploy/nlp/deploy_inframework_triton.py create mode 100644 scripts/deploy/nlp/query_inframework.py diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index a2110931c6df..5ebbe6816664 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -15,7 +15,7 @@ use_query_llm = True try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM + from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch except Exception: use_query_llm = False diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py index c27bbbd0102b..1fe029f9fade 100644 --- a/nemo/deploy/nlp/megatronllm_deployable.py +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -15,6 +15,7 @@ import logging from enum import IntEnum, auto from pathlib import Path +from typing import List import numpy as np import torch @@ -129,6 +130,12 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices: nemo_checkpoint_filepath, trainer=trainer, return_config=True ) # transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled + if not custom_config.transformer_engine: + LOGGER.warning( + "MegatronLLMDeployable expects model config transformer_engine=True, but this model has it =False. " + "Overriding it to =True, but this may break certain checkpoints converted on older Nemo versions. " + "If your model breaks, please try re-converting the checkpoint on the current Nemo version." 
+ ) custom_config.transformer_engine = True # using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination custom_config.tensor_model_parallel_size = num_devices @@ -233,9 +240,7 @@ def _length_params_from_triton_inputs(**inputs: np.ndarray): length_params[length_param_field] = inputs.pop(length_param_field)[0][0] return length_params - @batch - def triton_infer_fn(self, **inputs: np.ndarray): - """Triton server inference function that actually runs the model""" + def generate(self, inputs: List[str], length_params: LengthParam, sampling_params: SamplingParam): if torch.distributed.is_initialized(): distributed_rank = torch.distributed.get_rank() if distributed_rank != 0: @@ -245,13 +250,16 @@ def triton_infer_fn(self, **inputs: np.ndarray): signal_value = ServerSync.SIGNAL.to_long_tensor() torch.distributed.broadcast(signal_value, 0) + return self.model.generate(inputs=inputs, length_params=length_params, sampling_params=sampling_params) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + """Triton server inference function that actually runs the model""" input_strings = str_ndarray2list(inputs.pop("prompts")) sampling_params = self._sampling_params_from_triton_inputs(**inputs) length_params = self._length_params_from_triton_inputs(**inputs) - model_output = self.model.generate( - inputs=input_strings, length_params=length_params, sampling_params=sampling_params - ) + model_output = self.generate(input_strings, length_params, sampling_params) ''' model_output['sentences'] will be a list of strings (one per prompt) other fields will either be a list of lists (tokens, for example) diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py index 940a927c7a54..71492520bf0a 100644 --- a/nemo/deploy/nlp/query_llm.py +++ b/nemo/deploy/nlp/query_llm.py @@ -30,23 +30,99 @@ def __init__(self, url, model_name): self.url = url self.model_name = model_name - @abstractmethod + +class NemoQueryLLMPyTorch(NemoQueryLLMBase): + """ + Sends a query to Triton for LLM inference + + Example: + from nemo.deploy import NemoTritonQueryLLMPyTorch + + nq = NemoTritonQueryLLMPyTorch(url="localhost", model_name="GPT-2B") + + prompts = ["hello, testing GPT inference", "another GPT inference test?"] + output = nq.query_llm( + prompts=prompts, + max_length=100, + top_k=1, + top_p=0.0, + temperature=0.0, + ) + print("prompts: ", prompts) + """ + + def __init__(self, url, model_name): + super().__init__( + url=url, + model_name=model_name, + ) + + # these arguments are explicitly defined in order to make it clear to user what they can pass + # names and optionality should exactly match the get_triton_input() results for MegatronGPTDeployable def query_llm( self, prompts, - stop_words_list=None, - bad_words_list=None, - no_repeat_ngram_size=None, - max_output_len=512, - top_k=1, - top_p=0.0, - temperature=1.0, - random_seed=None, - task_id=None, - lora_uids=None, + use_greedy: bool = None, + temperature: float = None, + top_k: int = None, + top_p: float = None, + repetition_penalty: float = None, + add_BOS: bool = None, + all_probs: bool = None, + compute_logprob: bool = None, + end_strings=None, + min_length: int = None, + max_length: int = None, init_timeout=60.0, ): - pass + """ + Query the Triton server synchronously and return a list of responses. + + Args: + prompts (List(str)): list of sentences. 
+ use_greedy (bool): use greedy sampling, effectively the same as top_k=1 + temperature (float): A parameter of the softmax function, which is the last layer in the network. + top_k (int): limits us to a certain number (K) of the top tokens to consider. + top_p (float): limits us to the top tokens within a certain probability mass (p). + repetition_penalty (float): penalty applied to repeated sequences, 1.0 means no penalty. + add_BOS (bool): whether or not to add a BOS (beginning of sentence) token. + all_probs (bool): when using compute_logprob, returns probabilities for all tokens in vocabulary. + compute_logprob (bool): get back probabilities of all tokens in the sequence. + end_strings (List(str)): list of strings which will terminate generation when they appear in the output. + min_length (int): min generated tokens. + max_length (int): max generated tokens. + init_timeout (flat): timeout for the connection. + """ + prompts = str_list2numpy(prompts) + inputs = { + "prompts": prompts, + } + if use_greedy is not None: + inputs["use_greedy"] = np.full(prompts.shape, use_greedy, dtype=np.bool_) + if temperature is not None: + inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single) + if top_k is not None: + inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_) + if top_p is not None: + inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single) + if repetition_penalty is not None: + inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single) + if add_BOS is not None: + inputs["add_BOS"] = np.full(prompts.shape, add_BOS, dtype=np.bool_) + if all_probs is not None: + inputs["all_probs"] = np.full(prompts.shape, all_probs, dtype=np.bool_) + if compute_logprob is not None: + inputs["compute_logprob"] = np.full(prompts.shape, compute_logprob, dtype=np.bool_) + if end_strings is not None: + inputs["end_strings"] = str_list2numpy(end_strings) + if min_length is not None: + inputs["min_length"] = np.full(prompts.shape, min_length, dtype=np.int_) + if max_length is not None: + inputs["max_length"] = np.full(prompts.shape, max_length, dtype=np.int_) + + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: + result_dict = client.infer_batch(**inputs) + return result_dict class NemoQueryLLM(NemoQueryLLMBase): diff --git a/scripts/deploy/nlp/deploy_inframework_triton.py b/scripts/deploy/nlp/deploy_inframework_triton.py new file mode 100755 index 000000000000..b698e4cbacfd --- /dev/null +++ b/scripts/deploy/nlp/deploy_inframework_triton.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import sys + +from nemo.deploy import DeployPyTriton + +LOGGER = logging.getLogger("NeMo") + +megatron_llm_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable +except Exception as e: + LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. 
{type(e).__name__}: {e}") + megatron_llm_supported = False + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Deploy nemo models to Triton", + ) + parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file") + parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service") + parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service") + parser.add_argument( + "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests" + ) + parser.add_argument( + "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server" + ) + parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment") + parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model") + parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") + args = parser.parse_args(argv) + return args + + +def get_nemo_deployable(args): + if args.nemo_checkpoint is None: + raise ValueError("In-Framework deployment requires a .nemo checkpoint") + + return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus) + + +def nemo_deploy(argv): + args = get_args(argv) + + if args.debug_mode: + loglevel = logging.DEBUG + else: + loglevel = logging.INFO + + LOGGER.setLevel(loglevel) + LOGGER.info("Logging level set to {}".format(loglevel)) + LOGGER.info(args) + + if not megatron_llm_supported: + raise ValueError("MegatronLLMDeployable is not supported in this environment.") + triton_deployable = get_nemo_deployable(args) + + try: + nm = DeployPyTriton( + model=triton_deployable, + triton_model_name=args.triton_model_name, + triton_model_version=args.triton_model_version, + max_batch_size=args.max_batch_size, + port=args.triton_port, + address=args.triton_http_address, + ) + + LOGGER.info("Triton deploy function will be called.") + nm.deploy() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + try: + LOGGER.info("Model serving on Triton is will be started.") + nm.serve() + except Exception as error: + LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) + return + + LOGGER.info("Model serving will be stopped.") + nm.stop() + + +if __name__ == '__main__': + nemo_deploy(sys.argv[1:]) diff --git a/scripts/deploy/nlp/query_inframework.py b/scripts/deploy/nlp/query_inframework.py new file mode 100644 index 000000000000..e77ab72a1f04 --- /dev/null +++ b/scripts/deploy/nlp/query_inframework.py @@ -0,0 +1,83 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import sys + +from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch + + +def get_args(argv): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=f"Queries Triton server running an in-framework Nemo model", + ) + parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server") + parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model") + prompt_group = parser.add_mutually_exclusive_group(required=True) + prompt_group.add_argument("-p", "--prompt", required=False, type=str, help="Prompt") + prompt_group.add_argument("-pf", "--prompt_file", required=False, type=str, help="File to read the prompt from") + parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length") + parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k") + parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p") + parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature") + parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server") + + args = parser.parse_args(argv) + return args + + +def query_llm( + url, + model_name, + prompts, + max_output_len=128, + top_k=1, + top_p=0.0, + temperature=1.0, + init_timeout=60.0, +): + nemo_query = NemoQueryLLMPyTorch(url, model_name) + return nemo_query.query_llm( + prompts=prompts, + max_length=max_output_len, + top_k=top_k, + top_p=top_p, + temperature=temperature, + init_timeout=init_timeout, + ) + + +def query(argv): + args = get_args(argv) + + if args.prompt_file is not None: + with open(args.prompt_file, "r") as f: + args.prompt = f.read() + + outputs = query_llm( + url=args.url, + model_name=args.model_name, + prompts=[args.prompt], + max_output_len=args.max_output_len, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + init_timeout=args.init_timeout, + ) + print(outputs["sentences"][0][0]) + + +if __name__ == '__main__': + query(sys.argv[1:]) diff --git a/tests/deploy/nemo_deploy.py b/tests/deploy/nemo_deploy.py index 5ef350b9c34a..5193fe951138 100644 --- a/tests/deploy/nemo_deploy.py +++ b/tests/deploy/nemo_deploy.py @@ -27,7 +27,7 @@ run_export_tests = True try: from nemo.deploy import DeployPyTriton - from nemo.deploy.nlp import NemoQueryLLM + from nemo.deploy.nlp import NemoQueryLLM, NemoQueryLLMPyTorch from nemo.export import TensorRTLLM except Exception as e: run_export_tests = False @@ -140,7 +140,7 @@ def run_in_framework_inference( ) nm.deploy() nm.run() - nq = NemoQueryLLM(url="localhost:8000", model_name=model_name) + nq = NemoQueryLLMPyTorch(url="localhost:8000", model_name=model_name) output_deployed = nq.query_llm( prompts=prompt, diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 6073cff54423..6a296fdb92eb 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -40,7 +40,7 @@ in_framework_supported = True try: - from nemo.deploy.nlp import MegatronLLMDeployable + from nemo.deploy.nlp import MegatronLLMDeployable, NemoQueryLLMPyTorch except Exception as e: LOGGER.warning( f"Cannot import MegatronLLMDeployable, in-framework inference will not be available. 
{type(e).__name__}: {e}" @@ -101,52 +101,82 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): for record in records: prompt = record["text_before_last_word"] expected_output = record["last_word"].strip().lower() - model_output = model.forward( - input_texts=[prompt], - max_output_len=1, - top_k=1, - top_p=0, - temperature=0.1, - task_ids=task_ids, - lora_uids=lora_uids, - ) - model_output = model_output[0][0].strip().lower() - all_expected_outputs.append(expected_output) - all_actual_outputs.append(model_output) + if model is not None: + if isinstance(model, MegatronLLMDeployable): + model_output = model.generate( + inputs=[prompt], + length_params={"min_length": 1, "max_length": 1}, + sampling_params={ + "use_greedy": True, + "temperature": 0.1, + "top_k": 1, + "top_p": 0, + "repetition_penalty": 1.0, + "add_BOS": True, + "all_probs": False, + "compute_logprob": False, + "end_strings": ["<|endoftext|>", ""], + }, + ) + # MegatronLLMDeployable returns prompt + generated output, so need to slice off prompt + model_output = model_output["sentences"][0][len(prompt) :].strip().lower() + else: + model_output = model.forward( + input_texts=[prompt], + max_output_len=1, + top_k=1, + top_p=0, + temperature=0.1, + task_ids=task_ids, + lora_uids=lora_uids, + ) + model_output = model_output[0][0].strip().lower() + all_actual_outputs.append(model_output) + + if expected_output == model_output: + correct_answers += 1 - if expected_output == model_output: - correct_answers += 1 - - if ( - expected_output == model_output - or model_output.startswith(expected_output) - or expected_output.startswith(model_output) - ): - if len(model_output) == 1 and len(expected_output) > 1: - continue - correct_answers_relaxed += 1 + if ( + expected_output == model_output + or model_output.startswith(expected_output) + or expected_output.startswith(model_output) + ): + if len(model_output) == 1 and len(expected_output) > 1: + continue + correct_answers_relaxed += 1 if nq is not None: - trtllm_deployed_output = nq.query_llm( - prompts=[prompt], - max_output_len=1, - top_k=1, - top_p=0, - temperature=0.1, - task_id=task_ids, - ) - trtllm_deployed_output = trtllm_deployed_output[0][0].strip().lower() - - if expected_output == trtllm_deployed_output: + if isinstance(nq, NemoQueryLLMPyTorch): + deployed_output = nq.query_llm( + prompts=[prompt], + max_length=1, + top_k=1, + top_p=0, + temperature=0.1, + ) + # MegatronLLMDeployable returns prompt + generated output, so need to slice off prompt + deployed_output = deployed_output["sentences"][0][0][len(prompt) :].decode().strip().lower() + else: + deployed_output = nq.query_llm( + prompts=[prompt], + max_output_len=1, + top_k=1, + top_p=0, + temperature=0.1, + task_id=task_ids, + ) + deployed_output = deployed_output[0][0].strip().lower() + + if expected_output == deployed_output: correct_answers_deployed += 1 if ( - expected_output == trtllm_deployed_output - or trtllm_deployed_output.startswith(expected_output) - or expected_output.startswith(trtllm_deployed_output) + expected_output == deployed_output + or deployed_output.startswith(expected_output) + or expected_output.startswith(deployed_output) ): - if len(trtllm_deployed_output) == 1 and len(expected_output) > 1: + if len(deployed_output) == 1 and len(expected_output) > 1: continue correct_answers_deployed_relaxed += 1 eval_end = time.monotonic() @@ -459,7 +489,7 @@ def run_existing_checkpoints( if in_framework: return run_in_framework_inference( model_name=model_name, - 
prompts=model_info["model_type"], + prompts=model_info["prompt_template"], checkpoint_path=model_info["checkpoint"], num_gpus=tp_size, max_output_len=model_info["max_output_len"], @@ -534,14 +564,15 @@ def run_in_framework_inference( ) nm.deploy() nm.run() - nq = NemoQueryLLM(url="localhost:8000", model_name=model_name) + nq = NemoQueryLLMPyTorch(url="localhost:8000", model_name=model_name) output_deployed = nq.query_llm( - prompts=[prompts], - top_k=top_k, - top_p=top_p, - temperature=temperature, + prompts=prompts, top_k=top_k, top_p=top_p, temperature=temperature, max_length=max_output_len ) + output_deployed = output_deployed["sentences"] + # MegatronLLMDeployable will return the prompt + generated output, so cut off the prompt + for i, output in enumerate(output_deployed): + output = output[len(prompts[i]) :] # Unwrap the generator if needed output_deployed = list(output_deployed) @@ -550,7 +581,8 @@ def run_in_framework_inference( accuracy_result = None if run_accuracy: print("Start model accuracy testing ...") - accuracy_result = get_accuracy_with_lambada(None, nq, None, None, test_data_path) + # This script is not written with torch.distributed support in mind, so running non-deployed in-framework models on multiple devices will not work + accuracy_result = get_accuracy_with_lambada(deployed_model, nq, None, None, test_data_path) nm.stop() @@ -736,7 +768,7 @@ def str_to_bool(name: str, s: str) -> bool: def run_inference_tests(args): - if not args.use_vllm and not trt_llm_supported: + if not args.use_vllm and not args.in_framework and not trt_llm_supported: raise UsageError("TensorRT-LLM engine is not supported in this environment.") if args.use_vllm and not vllm_supported: @@ -788,7 +820,7 @@ def run_inference_tests(args): tps = tps * 2 else: - if args.model_dir is None: + if not args.in_framework and args.model_dir is None: raise Exception("When using custom checkpoints, --model_dir is required.") prompts = ["The capital of France is", "Largest animal in the sea is"] @@ -847,6 +879,8 @@ def run_inference_tests(args): accuracy_test_result = "PASS" print_separator = False print("============= Test Summary ============") + # in-framework tests will only return deployed model accuracy results for tps > 1 + deployed_tests_only = args.in_framework and args.max_tps > 1 for num_tps, results in result_dic.items(): functional_result, accuracy_result = results @@ -876,7 +910,9 @@ def optional_bool_to_pass_fail(b: Optional[bool]): print(f"Deployed Model Accuracy: {accuracy_result.deployed_accuracy:.4f}") print(f"Deployed Relaxed Model Accuracy: {accuracy_result.deployed_accuracy_relaxed:.4f}") print(f"Evaluation Time [s]: {accuracy_result.evaluation_time:.2f}") - if accuracy_result.accuracy_relaxed < 0.5: + if (deployed_tests_only and accuracy_result.deployed_accuracy_relaxed < 0.5) or ( + not deployed_tests_only and accuracy_result.accuracy_relaxed < 0.5 + ): accuracy_test_result = "FAIL" print("=======================================")