-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for transformers pyfunc #8181
Changes from 13 commits
0451b97
dc12a39
2fad7d0
05a64b0
ddbd04d
036b3eb
8436e16
bf98c94
23614bd
4ef41ee
7ef4498
251b29a
05fdce5
9881baa
6080408
b430212
daeb811
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,3 +106,6 @@ a.py | |
|
||
# Log file created by pre-commit hook for black | ||
.black.log | ||
|
||
# Pytest-monitor load testing DB file | ||
*.pymon |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,18 +26,12 @@ | |
except ImportError: | ||
HAS_SCIPY = False | ||
|
||
ModelInputExample = Union[pd.DataFrame, np.ndarray, dict, list, "csr_matrix", "csc_matrix"] | ||
ModelInputExample = Union[pd.DataFrame, np.ndarray, dict, list, "csr_matrix", "csc_matrix", str] | ||
|
||
PyFuncInput = Union[ | ||
pd.DataFrame, | ||
pd.Series, | ||
np.ndarray, | ||
"csc_matrix", | ||
"csr_matrix", | ||
List[Any], | ||
Dict[str, Any], | ||
pd.DataFrame, pd.Series, np.ndarray, "csc_matrix", "csr_matrix", List[Any], Dict[str, Any], str | ||
] | ||
PyFuncOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list] | ||
PyFuncOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list, str] | ||
|
||
|
||
class _Example: | ||
|
@@ -127,6 +121,13 @@ def _handle_dataframe_input(input_ex): | |
if isinstance(input_ex, dict): | ||
if all(_is_scalar(x) for x in input_ex.values()): | ||
input_ex = pd.DataFrame([input_ex]) | ||
elif all(isinstance(x, (str, list)) for x in input_ex.values()): | ||
for value in input_ex.values(): | ||
if isinstance(value, list) and not all(_is_scalar(x) for x in value): | ||
raise TypeError( | ||
"List values within dictionaries must be of scalar type." | ||
) | ||
input_ex = pd.DataFrame(input_ex) | ||
else: | ||
raise TypeError( | ||
"Data in the dictionary must be scalar or of type numpy.ndarray" | ||
|
@@ -141,6 +142,8 @@ def _handle_dataframe_input(input_ex): | |
input_ex = pd.DataFrame([input_ex], columns=range(len(input_ex))) | ||
else: | ||
input_ex = pd.DataFrame(input_ex) | ||
elif isinstance(input_ex, str): | ||
input_ex = pd.DataFrame([input_ex]) | ||
elif not isinstance(input_ex, pd.DataFrame): | ||
try: | ||
import pyspark.sql.dataframe | ||
|
@@ -609,9 +612,16 @@ def _enforce_schema(pf_input: PyFuncInput, input_schema: Schema): | |
if isinstance(pf_input, pd.Series): | ||
pf_input = pd.DataFrame(pf_input) | ||
if not input_schema.is_tensor_spec(): | ||
if isinstance(pf_input, (list, np.ndarray, dict, pd.Series)): | ||
if isinstance(pf_input, (list, np.ndarray, dict, pd.Series, str)): | ||
try: | ||
pf_input = pd.DataFrame(pf_input) | ||
if isinstance(pf_input, dict) and all( | ||
not isinstance(value, (dict, list)) for value in pf_input.values() | ||
): | ||
pf_input = pd.DataFrame(pf_input, index=[0]) | ||
Comment on lines +617 to +620
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is a behavior change for dataframes with numpy array columns, scalar columns, etc. Can you provide more context here? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've adjusted this logic to be very explicit about matching on dict where the keys and values are all strings. Without this logic, we can't cast a scalar string to a DataFrame for serving or do signature validation. By supplying the index, we're able to create the DataFrame without an Exception being thrown. |
||
elif isinstance(pf_input, str): | ||
pf_input = pd.DataFrame({"inputs": pf_input}, index=[0]) | ||
else: | ||
pf_input = pd.DataFrame(pf_input) | ||
except Exception as e: | ||
raise MlflowException( | ||
"This model contains a column-based signature, which suggests a DataFrame" | ||
|
@@ -678,6 +688,7 @@ def validate_schema(data: PyFuncInput, expected_schema: Schema) -> None: | |
- scipy.sparse.csr_matrix | ||
- List[Any] | ||
- Dict[str, Any] | ||
- str | ||
:param expected_schema: Expected :py:class:`Schema <mlflow.types.Schema>` of the input data. | ||
:raises: A :py:class:`mlflow.exceptions.MlflowException`. when the input data does | ||
not match the schema. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this? The requirements field for transformers contains both `accelerate` and `datasets`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, good point. Removed!