diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py index 020cdaa4f..91b1a1fdb 100644 --- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py @@ -12,7 +12,7 @@ # the specific language governing permissions and limitations under the License. import logging import tensorrt_llm_toolkit -from djl_python.rolling_batch.rolling_batch import RollingBatch +from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception class TRTLLMRollingBatch(RollingBatch): @@ -55,6 +55,7 @@ def translate_triton_params(self, parameters): parameters["streaming"] = parameters.get("streaming", True) return parameters + @stop_on_any_exception def inference(self, input_data, parameters): batch_size = len(input_data) # add pending requests to active requests list