From a7b3ee0dacafea127ee38005d01954ea643abd83 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Mon, 24 Apr 2023 23:26:33 +0900 Subject: [PATCH] Make self.inference_config default to empty dictionary (#8308) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: harupy Signed-off-by: Larry O’Brien --- mlflow/transformers.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/mlflow/transformers.py b/mlflow/transformers.py index c6620aa543d6c..1497d231af5b5 100644 --- a/mlflow/transformers.py +++ b/mlflow/transformers.py @@ -1323,7 +1323,7 @@ class _TransformersWrapper: def __init__(self, pipeline, flavor_config=None, inference_config=None): self.pipeline = pipeline self.flavor_config = flavor_config - self.inference_config = inference_config + self.inference_config = inference_config or {} self._conversation = None # NB: Current special-case custom pipeline types that have not been added to # the native-supported transformers package but require custom parsing: @@ -1451,30 +1451,18 @@ def _predict(self, data): # formatting output), but if `include_prompt` is set to False in the `inference_config` # option during model saving, excess newline characters and the fed-in prompt will be # trimmed out from the start of the response. - include_prompt = ( - self.inference_config.pop("include_prompt", True) if self.inference_config else True - ) + include_prompt = self.inference_config.pop("include_prompt", True) # Optional stripping out of `\n` for specific generator pipelines. - collapse_whitespace = ( - self.inference_config.pop("collapse_whitespace", False) - if self.inference_config - else False - ) + collapse_whitespace = self.inference_config.pop("collapse_whitespace", False) # Generate inference data with the pipeline object if isinstance(self.pipeline, transformers.ConversationalPipeline): conversation_output = self.pipeline(self._conversation) return conversation_output.generated_responses[-1] elif isinstance(data, dict): - if self.inference_config: - raw_output = self.pipeline(**data, **self.inference_config) - else: - raw_output = self.pipeline(**data) + raw_output = self.pipeline(**data, **self.inference_config) else: - if self.inference_config: - raw_output = self.pipeline(data, **self.inference_config) - else: - raw_output = self.pipeline(data) + raw_output = self.pipeline(data, **self.inference_config) # Handle the pipeline outputs if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(