Use ModelOpt build_tensorrt_llm for building engines for qnemo checkpoints (NVIDIA#9452)

* Enable specifying alpha for SQ

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Enable specifying use_custom_all_reduce for export

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Use native TRT-LLM param names in export (partial)

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Detect TRT-LLM checkpoint programmatically

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Pass use_custom_all_reduce in test_nemo_export.py

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Parameter parsing bugfix

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Revert "Paramter parsing bugfix"

This reverts commit b0a4dd3.

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Revert "Enable specifying use_custom_all_reduce for export"

This reverts commit 9e419e3.

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Revert "Pass use_custom_all_reduce in test_nemo_export.py"

This reverts commit be70812.

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Rename checkpoint detection function

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Use ModelOpt build_tensorrt_llm utility for qnemo for performance alignment

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Import fix

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* Apply isort and black reformatting

Signed-off-by: janekl <janekl@users.noreply.github.com>

---------

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
Signed-off-by: janekl <janekl@users.noreply.github.com>
Co-authored-by: janekl <janekl@users.noreply.github.com>
2 people authored and galv committed Jun 13, 2024
1 parent 3156bea commit c891208
Showing 3 changed files with 76 additions and 47 deletions.
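For context, a minimal, hypothetical sketch of driving the updated export path from user code. The TensorRTLLM wrapper comes from nemo/export/tensorrt_llm.py (the file changed below); the model_dir argument, checkpoint path, and all limits here are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch (not part of this commit); paths and limits are placeholders.
from nemo.export import TensorRTLLM

exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine")  # directory where the engine is written
exporter.export(
    nemo_checkpoint_path="/path/to/model-qnemo",  # with this commit, detected via is_qnemo_checkpoint()
    max_input_len=2048,
    max_output_len=512,
    max_batch_size=8,
)

With this change, export() no longer relies on the path ending in "qnemo"; it inspects the checkpoint contents and, for quantized checkpoints, builds the engine through ModelOpt's build_tensorrt_llm instead of the trtllm-build CLI.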
13 changes: 12 additions & 1 deletion nemo/export/tensorrt_llm.py
@@ -33,6 +33,7 @@
 from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import get_tokenzier, is_nemo_file, load_nemo_model
 from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm
 from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer
+from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint
 from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine
 from nemo.export.trt_llm.tensorrt_llm_run import generate, generate_streaming, load

@@ -229,7 +230,7 @@ def export(
         tmp_dir = tempfile.TemporaryDirectory()
         nemo_export_dir = Path(tmp_dir.name)
 
-        if nemo_checkpoint_path.endswith("qnemo"):
+        if is_qnemo_checkpoint(nemo_checkpoint_path):
             if os.path.isdir(nemo_checkpoint_path):
                 nemo_export_dir = nemo_checkpoint_path
             else:
@@ -244,7 +245,17 @@
                 max_output_len=max_output_len,
                 max_batch_size=max_batch_size,
                 max_prompt_embedding_table_size=max_prompt_embedding_table_size,
+                tensor_parallel_size=tensor_parallel_size,
+                pipeline_parallel_size=pipeline_parallel_size,
+                use_parallel_embedding=use_parallel_embedding,
+                paged_kv_cache=paged_kv_cache,
+                remove_input_padding=remove_input_padding,
+                enable_multi_block_mode=enable_multi_block_mode,
+                use_lora_plugin=use_lora_plugin,
                 lora_target_modules=lora_target_modules,
+                max_lora_rank=max_lora_rank,
+                max_num_tokens=max_num_tokens,
+                opt_num_tokens=opt_num_tokens,
             )
         else:
             model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir)
92 changes: 46 additions & 46 deletions nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import os
-import subprocess
-
+import glob
+import os
+import warnings
 from typing import List, Optional
 
-CONFIG_NAME = "config.json"
+from modelopt.deploy.llm import build_tensorrt_llm
+
+from nemo.export.trt_llm.qnemo.utils import CONFIG_NAME, WEIGHTS_NAME
 
 
 def qnemo_to_tensorrt_llm(
@@ -28,50 +30,48 @@ def qnemo_to_tensorrt_llm(
     max_output_len: int,
     max_batch_size: int,
     max_prompt_embedding_table_size: int,
+    tensor_parallel_size: int = None,
+    pipeline_parallel_size: int = None,
+    use_parallel_embedding: bool = False,
+    paged_kv_cache: bool = True,
+    remove_input_padding: bool = True,
+    enable_multi_block_mode: bool = False,
+    use_lora_plugin: str = None,
     lora_target_modules: Optional[List[str]] = None,
+    max_lora_rank: int = 64,
+    max_num_tokens: int = None,
+    opt_num_tokens: int = None,
 ):
-    """Build TRT-LLM engine via trtllm-build CLI API in a subprocess."""
+    """Build TensorRT-LLM engine with ModelOpt build_tensorrt_llm function."""
     assert not lora_target_modules, f"LoRA is not supported for quantized checkpoints, got {lora_target_modules}"
-    print(
-        "Note that setting n_gpus, tensor_parallel_size and pipeline_parallel_size parameters"
-        " for quantized models is possible only on export step via nemo.export.quantize module."
-        " These parameters are ignored when building and running TensorRT-LLM engine below."
+
+    warnings.warn(
+        "Note that setting tensor_parallel_size and pipeline_parallel_size parameters"
+        " for quantized models should be done on calibration step with nemo.export.quantize module."
+        " These parameters are ignored when building and running TensorRT-LLM engine below.",
+        UserWarning,
+        stacklevel=3,
     )
-    # Load config to explicitly pass selected parameters to trtllm-build command:
-    with open(os.path.join(nemo_checkpoint_path, CONFIG_NAME), "r") as f:
-        model_config = json.load(f)
-    command = [
-        "trtllm-build",
-        "--checkpoint_dir",
-        nemo_checkpoint_path,
-        "--output_dir",
-        engine_dir,
-        "--max_batch_size",
-        str(max_batch_size),
-        "--max_input_len",
-        str(max_input_len),
-        "--max_output_len",
-        str(max_output_len),
-        "--max_prompt_embedding_table_size",
-        str(max_prompt_embedding_table_size),
-        "--gemm_plugin",
-        model_config["dtype"],
-        "--gpt_attention_plugin",
-        model_config["dtype"],
-        "--strongly_typed",
-        "--use_custom_all_reduce",
-        "disable",
-        "--workers",
-        str(model_config["mapping"]["world_size"]),
-    ]
-    command_str = " ".join(command)
-    print(f"Build command is:\n{command_str}")
-    print("Running trtllm-build, this may take a while...")
-    result = subprocess.run(command, capture_output=True)  # TODO: consider streaming logs
-    if result.returncode != 0:
-        print(result.stdout.decode())
-        print(result.stderr.decode())
-        raise RuntimeError("Error encountered for trtllm-build command, please check logs.")
 
-    print("Building engine done. Full logs are:")
-    print(result.stdout.decode())
+    warnings.warn(
+        "Also use_parallel_embedding, paged_kv_cache, remove_input_padding, enable_multi_block_mode, max_num_tokens"
+        " and opt_num_tokens parameters are set by ModelOpt build_tensorrt_llm function in the optimal way and are"
+        " ignored on engine build step.",
+        UserWarning,
+        stacklevel=3,
+    )
+
+    num_build_workers = len(glob.glob(os.path.join(nemo_checkpoint_path, WEIGHTS_NAME.format("*"))))
+    assert num_build_workers, f"No TensorRT-LLM weight files found in {nemo_checkpoint_path}"
+
+    build_tensorrt_llm(
+        pretrained_config=os.path.join(nemo_checkpoint_path, CONFIG_NAME),
+        engine_dir=engine_dir,
+        max_input_len=max_input_len,
+        max_output_len=max_output_len,
+        max_batch_size=max_batch_size,
+        max_beam_width=1,
+        num_build_workers=num_build_workers,
+        enable_sparsity=False,
+        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
+    )
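A direct, hypothetical call to the rewritten function, using the required arguments from the signature above (paths and limits are placeholders; the checkpoint directory is assumed to contain config.json and rank*.safetensors files, as checked by the assert above):

# Hypothetical direct call (not part of this commit); paths and limits are placeholders.
from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm

qnemo_to_tensorrt_llm(
    nemo_checkpoint_path="/path/to/qnemo_checkpoint",  # holds config.json + rank*.safetensors
    engine_dir="/path/to/engines",  # output directory for the built engine
    max_input_len=2048,
    max_output_len=512,
    max_batch_size=8,
    max_prompt_embedding_table_size=0,
)

Note that the number of build workers is now inferred from the count of per-rank weight files, so the parallelism chosen at calibration time is reproduced automatically at build time.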
18 changes: 18 additions & 0 deletions nemo/export/trt_llm/qnemo/utils.py
@@ -0,0 +1,18 @@
+import os
+from pathlib import Path
+
+from nemo.export.tarutils import TarPath
+
+CONFIG_NAME = "config.json"
+WEIGHTS_NAME = "rank{}.safetensors"
+
+
+def is_qnemo_checkpoint(path: str) -> bool:
+    """Detect if a given path is a TensorRT-LLM a.k.a. "qnemo" checkpoint based on config & tensor data presence."""
+    if os.path.isdir(path):
+        path = Path(path)
+    else:
+        path = TarPath(path)
+    config_path = path / CONFIG_NAME
+    tensor_path = path / WEIGHTS_NAME.format(0)
+    return config_path.exists() and tensor_path.exists()
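A small, hypothetical usage sketch of the new helper, mirroring the dispatch added in tensorrt_llm.py above (the checkpoint path is a placeholder; the helper accepts both plain directories and tar-based checkpoint archives via TarPath):

# Hypothetical dispatch sketch (not part of this commit); the path is a placeholder.
from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint

checkpoint = "/path/to/checkpoint"  # plain directory or tar-based checkpoint archive
if is_qnemo_checkpoint(checkpoint):
    print("Quantized (qnemo) checkpoint: build engine with ModelOpt build_tensorrt_llm")
else:
    print("Regular NeMo checkpoint: use the standard TensorRT-LLM export path")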
