Commit
Make self.inference_config default to empty dictionary (mlflow#8308)
Signed-off-by: harupy <hkawamura0130@gmail.com>
Signed-off-by: Larry O’Brien <larry.obrien@databricks.com>
harupy authored and Larry O’Brien committed May 10, 2023
1 parent 4a13941 commit a7b3ee0
Showing 1 changed file with 5 additions and 17 deletions.
mlflow/transformers.py
@@ -1323,7 +1323,7 @@ class _TransformersWrapper:
     def __init__(self, pipeline, flavor_config=None, inference_config=None):
         self.pipeline = pipeline
         self.flavor_config = flavor_config
-        self.inference_config = inference_config
+        self.inference_config = inference_config or {}
         self._conversation = None
         # NB: Current special-case custom pipeline types that have not been added to
         # the native-supported transformers package but require custom parsing:
@@ -1451,30 +1451,18 @@ def _predict(self, data):
         # formatting output), but if `include_prompt` is set to False in the `inference_config`
         # option during model saving, excess newline characters and the fed-in prompt will be
         # trimmed out from the start of the response.
-        include_prompt = (
-            self.inference_config.pop("include_prompt", True) if self.inference_config else True
-        )
+        include_prompt = self.inference_config.pop("include_prompt", True)
         # Optional stripping out of `\n` for specific generator pipelines.
-        collapse_whitespace = (
-            self.inference_config.pop("collapse_whitespace", False)
-            if self.inference_config
-            else False
-        )
+        collapse_whitespace = self.inference_config.pop("collapse_whitespace", False)

         # Generate inference data with the pipeline object
         if isinstance(self.pipeline, transformers.ConversationalPipeline):
             conversation_output = self.pipeline(self._conversation)
             return conversation_output.generated_responses[-1]
         elif isinstance(data, dict):
-            if self.inference_config:
-                raw_output = self.pipeline(**data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(**data)
+            raw_output = self.pipeline(**data, **self.inference_config)
         else:
-            if self.inference_config:
-                raw_output = self.pipeline(data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(data)
+            raw_output = self.pipeline(data, **self.inference_config)

         # Handle the pipeline outputs
         if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
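
Why the simplification is safe: once self.inference_config is guaranteed to be a dict, dict.pop(key, default) returns the fallback when the key is absent, and unpacking an empty dict with ** passes no keyword arguments, so every "if self.inference_config" branch collapses into a single call. A minimal standalone sketch of that behavior follows; it is not mlflow code, and fake_pipeline is a hypothetical stand-in for a transformers pipeline call.

    def fake_pipeline(data, **kwargs):
        # Hypothetical stand-in for a transformers pipeline call; echoes its inputs.
        return data, kwargs

    inference_config = None
    inference_config = inference_config or {}  # the new default from this commit

    # pop() with a fallback is safe on an empty dict -- no existence check needed.
    include_prompt = inference_config.pop("include_prompt", True)
    collapse_whitespace = inference_config.pop("collapse_whitespace", False)

    # **{} unpacks to zero keyword arguments, so one call covers both old branches.
    assert fake_pipeline("hello", **inference_config) == ("hello", {})
    assert include_prompt is True and collapse_whitespace is False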
