diff --git a/mlflow/transformers.py b/mlflow/transformers.py
index c6620aa543d6c..1497d231af5b5 100644
--- a/mlflow/transformers.py
+++ b/mlflow/transformers.py
@@ -1323,7 +1323,7 @@ class _TransformersWrapper:
     def __init__(self, pipeline, flavor_config=None, inference_config=None):
         self.pipeline = pipeline
         self.flavor_config = flavor_config
-        self.inference_config = inference_config
+        self.inference_config = inference_config or {}
         self._conversation = None
         # NB: Current special-case custom pipeline types that have not been added to
         # the native-supported transformers package but require custom parsing:
@@ -1451,30 +1451,18 @@ def _predict(self, data):
         # formatting output), but if `include_prompt` is set to False in the `inference_config`
         # option during model saving, excess newline characters and the fed-in prompt will be
         # trimmed out from the start of the response.
-        include_prompt = (
-            self.inference_config.pop("include_prompt", True) if self.inference_config else True
-        )
+        include_prompt = self.inference_config.pop("include_prompt", True)
         # Optional stripping out of `\n` for specific generator pipelines.
-        collapse_whitespace = (
-            self.inference_config.pop("collapse_whitespace", False)
-            if self.inference_config
-            else False
-        )
+        collapse_whitespace = self.inference_config.pop("collapse_whitespace", False)

         # Generate inference data with the pipeline object
         if isinstance(self.pipeline, transformers.ConversationalPipeline):
             conversation_output = self.pipeline(self._conversation)
             return conversation_output.generated_responses[-1]
         elif isinstance(data, dict):
-            if self.inference_config:
-                raw_output = self.pipeline(**data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(**data)
+            raw_output = self.pipeline(**data, **self.inference_config)
         else:
-            if self.inference_config:
-                raw_output = self.pipeline(data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(data)
+            raw_output = self.pipeline(data, **self.inference_config)

         # Handle the pipeline outputs
         if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
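
For context, here is a minimal, self-contained sketch of the pattern this diff adopts. The `ConfigDemo` and `fake_pipeline` names are hypothetical stand-ins, not MLflow code: the point is only that coercing `inference_config` to `{}` in `__init__` makes `.pop` with a default and `**`-unpacking degrade gracefully, which is what lets every `if self.inference_config` branch in `_predict` be removed.

# Hypothetical demo (not MLflow code) of normalizing an optional
# config dict so downstream code needs no None checks.

def fake_pipeline(data, **kwargs):
    """Stand-in for a transformers pipeline call."""
    return {"data": data, "kwargs": kwargs}

class ConfigDemo:
    def __init__(self, inference_config=None):
        # None (or any falsy value) is coerced to an empty dict, so
        # downstream code can assume dict semantics unconditionally.
        self.inference_config = inference_config or {}

    def predict(self, data):
        # .pop with a default works on an empty dict and returns the default.
        include_prompt = self.inference_config.pop("include_prompt", True)
        # **{} unpacks to no keyword arguments, so this call is identical
        # to fake_pipeline(data) when no config was supplied.
        return include_prompt, fake_pipeline(data, **self.inference_config)

# With no config: include_prompt defaults to True, no extra kwargs are passed.
assert ConfigDemo().predict("x") == (True, {"data": "x", "kwargs": {}})
# With a config: include_prompt is consumed, the rest flows to the pipeline.
assert ConfigDemo({"include_prompt": False, "max_length": 10}).predict("x") == (
    False,
    {"data": "x", "kwargs": {"max_length": 10}},
)

One behavioral note: `.pop` mutates the stored config, so a key consumed this way is gone on later `predict` calls. That was already true of the pre-change code; the diff simplifies the branching without altering that behavior.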