Make self.inference_config default to empty dictionary #8308

Merged 1 commit on Apr 24, 2023.
22 changes: 5 additions & 17 deletions mlflow/transformers.py
@@ -1323,7 +1323,7 @@ class _TransformersWrapper:
     def __init__(self, pipeline, flavor_config=None, inference_config=None):
         self.pipeline = pipeline
         self.flavor_config = flavor_config
-        self.inference_config = inference_config
+        self.inference_config = inference_config or {}
         self._conversation = None
         # NB: Current special-case custom pipeline types that have not been added to
         # the native-supported transformers package but require custom parsing:
@@ -1451,30 +1451,18 @@ def _predict(self, data):
         # formatting output), but if `include_prompt` is set to False in the `inference_config`
         # option during model saving, excess newline characters and the fed-in prompt will be
         # trimmed out from the start of the response.
-        include_prompt = (
-            self.inference_config.pop("include_prompt", True) if self.inference_config else True
-        )
+        include_prompt = self.inference_config.pop("include_prompt", True)
         # Optional stripping out of `\n` for specific generator pipelines.
-        collapse_whitespace = (
-            self.inference_config.pop("collapse_whitespace", False)
-            if self.inference_config
-            else False
-        )
+        collapse_whitespace = self.inference_config.pop("collapse_whitespace", False)

         # Generate inference data with the pipeline object
         if isinstance(self.pipeline, transformers.ConversationalPipeline):
             conversation_output = self.pipeline(self._conversation)
             return conversation_output.generated_responses[-1]
         elif isinstance(data, dict):
-            if self.inference_config:
-                raw_output = self.pipeline(**data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(**data)
+            raw_output = self.pipeline(**data, **self.inference_config)
         else:
-            if self.inference_config:
-                raw_output = self.pipeline(data, **self.inference_config)
-            else:
-                raw_output = self.pipeline(data)
+            raw_output = self.pipeline(data, **self.inference_config)

         # Handle the pipeline outputs
         if type(self.pipeline).__name__ in self._supported_custom_generator_types or isinstance(
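For context, the whole simplification rests on two Python behaviors: unpacking an empty dict into a call is a no-op, and `dict.pop` with a default works fine on an empty dict. A minimal, self-contained sketch of both (`fake_pipeline` is an illustrative stand-in for a transformers pipeline call, not a real API):

```python
# Stand-in for a transformers pipeline invocation; illustrative only.
def fake_pipeline(data, **kwargs):
    return data, kwargs

# What the constructor now does: None (or any falsy config) becomes {}.
inference_config = None or {}

# 1. Unpacking an empty dict is a no-op, so the
#    `if self.inference_config: ... else: ...` branching at the call sites collapses:
assert fake_pipeline("x", **inference_config) == fake_pipeline("x")

# 2. `dict.pop` with a default is safe on an empty dict, so the guarded
#    ternaries collapse to single lines:
assert inference_config.pop("include_prompt", True) is True
assert inference_config.pop("collapse_whitespace", False) is False
```

One caveat, present in the original code as well: `pop` mutates `self.inference_config`, so these keys are consumed on the first `_predict` call rather than re-read on every invocation.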