diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py index ccbf3f1a2..020cdaa4f 100644 --- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py @@ -38,8 +38,8 @@ def reset(self): super().reset() def translate_triton_params(self, parameters): - parameters["request_output_len"] = int( - parameters.get("max_new_tokens", 128)) + parameters["max_new_tokens"] = parameters.get("max_new_tokens", 128) + parameters["request_output_len"] = parameters.pop("max_new_tokens") if "top_k" in parameters.keys(): parameters["runtime_top_k"] = parameters.pop("top_k") if "top_p" in parameters.keys():