Enable specifying use_custom_all_reduce for export
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl committed Jun 6, 2024
1 parent 4f925d3 commit 9e419e3
Showing 3 changed files with 8 additions and 1 deletion.
nemo/export/tensorrt_llm.py (4 additions, 0 deletions)

@@ -120,6 +120,7 @@ def export(
     use_inflight_batching: bool = False,
     enable_context_fmha: bool = True,
     paged_kv_cache: bool = False,
+    use_custom_all_reduce: bool = True,
     dtype: str = "bfloat16",
     load_model: bool = True,
     enable_multi_block_mode: bool = False,
@@ -145,6 +146,7 @@ def export(
     use_inflight_batching (bool): if True, enables inflight batching for TensorRT-LLM Triton backend.
     enable_context_fmha (bool): if True, use fused Context MultiHeadedAttention.
     paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM.
+    use_custom_all_reduce (bool): if True, uses latency-optimized AR plugin instead of native NCCL operator.
     dtype (str): Floating point type for model weights (Supports BFloat16/Float16).
     load_model (bool): load TensorRT-LLM model after the export.
     enable_multi_block_mode (bool): enable faster decoding in multihead attention. Required for long context.
@@ -208,6 +210,7 @@ def export(
     max_input_len=max_input_token,
     max_output_len=max_output_token,
     max_batch_size=max_batch_size,
+    use_custom_all_reduce=use_custom_all_reduce,
     max_prompt_embedding_table_size=max_prompt_embedding_table_size,
     lora_target_modules=lora_target_modules,
 )
@@ -238,6 +241,7 @@ def export(
     lora_target_modules=lora_target_modules,
     max_prompt_embedding_table_size=max_prompt_embedding_table_size,
     enable_multi_block_mode=enable_multi_block_mode,
+    use_custom_all_reduce=use_custom_all_reduce,
 )

 tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
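For context, a minimal usage sketch of the new export option. Only use_custom_all_reduce and dtype are taken from the signature above; the paths, checkpoint argument, and model type are illustrative assumptions, not part of this commit.

    from nemo.export import TensorRTLLM

    # Hypothetical paths and model type, for illustration only.
    exporter = TensorRTLLM(model_dir="/tmp/trt_llm_engine_dir")
    exporter.export(
        nemo_checkpoint_path="/models/llama2-7b.nemo",  # assumed argument name
        model_type="llama",                             # assumed argument name
        use_custom_all_reduce=False,  # fall back to the native NCCL all-reduce
        dtype="bfloat16",
    )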
nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py (2 additions, 1 deletion)

@@ -27,6 +27,7 @@ def qnemo_to_tensorrt_llm(
     max_input_len: int,
     max_output_len: int,
     max_batch_size: int,
+    use_custom_all_reduce: bool,
     max_prompt_embedding_table_size: int,
     lora_target_modules: Optional[List[str]] = None,
 ):
@@ -60,7 +61,7 @@ def qnemo_to_tensorrt_llm(
     model_config["dtype"],
     "--strongly_typed",
     "--use_custom_all_reduce",
-    "disable",
+    "enable" if use_custom_all_reduce else "disable",
     "--workers",
     str(model_config["mapping"]["world_size"]),
 ]
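In the quantized-checkpoint path the boolean is rendered into the builder's enable/disable flag shown in the hunk above. A small standalone sketch of that mapping (the surrounding command assembly is simplified away):

    # Sketch: how the boolean maps onto the enable/disable flag value.
    # The flag name is taken from the diff above; the helper is hypothetical.
    def all_reduce_flag(use_custom_all_reduce: bool) -> list:
        return ["--use_custom_all_reduce", "enable" if use_custom_all_reduce else "disable"]

    print(all_reduce_flag(True))   # ['--use_custom_all_reduce', 'enable']
    print(all_reduce_flag(False))  # ['--use_custom_all_reduce', 'disable']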
nemo/export/trt_llm/tensorrt_llm_build.py (2 additions, 0 deletions)

@@ -371,6 +371,7 @@ def build_and_save_engine(
     lora_target_modules=None,
     max_prompt_embedding_table_size=0,
     enable_multi_block_mode: bool = False,
+    use_custom_all_reduce: bool = True,
 ):
     try:
         model_cls = getattr(tensorrt_llm.models, model_config.architecture)
@@ -397,6 +398,7 @@ def build_and_save_engine(
     'gather_generation_logits': False,
     'strongly_typed': False,
     'builder_opt': None,
+    'use_custom_all_reduce': use_custom_all_reduce,
 }
 build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
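In the non-quantized path the option ends up as a key of the dictionary handed to BuildConfig.from_dict. A minimal sketch of that step (other entries elided; the helper function is hypothetical, the keys are from the hunk above):

    # Sketch: the new argument becomes an entry of the build configuration dict
    # that build_and_save_engine passes to BuildConfig.from_dict.
    def make_build_dict(use_custom_all_reduce: bool = True) -> dict:
        return {
            'gather_generation_logits': False,
            'strongly_typed': False,
            'builder_opt': None,
            'use_custom_all_reduce': use_custom_all_reduce,
        }

    assert make_build_dict(use_custom_all_reduce=False)['use_custom_all_reduce'] is False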
