Skip to content

Commit

Permalink
simplify to one trace on predict fn
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Chan <jesse.chan@databricks.com>
  • Loading branch information
jessechancy committed May 6, 2024
1 parent 96d1dce commit 4dbd76b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 38 deletions.
48 changes: 22 additions & 26 deletions mlflow/models/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,8 +1316,7 @@ def _get_model_from_function(fn):

class ModelFromFunction(mlflow.pyfunc.PythonModel):
def predict(self, context, model_input: pd.DataFrame):
traced_fn = mlflow.trace(fn)
return traced_fn(model_input)
return fn(model_input)

python_model = ModelFromFunction()
return _PythonModelPyfuncWrapper(python_model, None, None)
Expand Down Expand Up @@ -1357,30 +1356,27 @@ def predict(self, context, model_input: pd.DataFrame):
)
input_column = model_input.columns[0]

with mlflow.start_span() as span:
span.set_inputs(model_input)
predictions = []
for data in model_input[input_column]:
if isinstance(data, str):
# If the input data is a string, we will construct the request
# payload from it.
prediction = _call_deployments_api(self.endpoint, data, self.params)
elif isinstance(data, dict):
# If the input data is a dictionary, we will directly use it as the request
# payload, with adding the inference parameters if provided.
prediction = _call_deployments_api(
self.endpoint, data, self.params, wrap_payload=False
)
else:
raise MlflowException(
f"Invalid input column type: {type(data)}. The input data "
"must be either a string or a dictionary contains the request "
"payload for evaluating an MLflow Deployments endpoint.",
error_code=INVALID_PARAMETER_VALUE,
)

predictions.append(prediction)
span.set_outputs(predictions)
predictions = []
for data in model_input[input_column]:
if isinstance(data, str):
# If the input data is a string, we will construct the request
# payload from it.
prediction = _call_deployments_api(self.endpoint, data, self.params)
elif isinstance(data, dict):
# If the input data is a dictionary, we will directly use it as the request
# payload, with adding the inference parameters if provided.
prediction = _call_deployments_api(
self.endpoint, data, self.params, wrap_payload=False
)
else:
raise MlflowException(
f"Invalid input column type: {type(data)}. The input data "
"must be either a string or a dictionary contains the request "
"payload for evaluating an MLflow Deployments endpoint.",
error_code=INVALID_PARAMETER_VALUE,
)

predictions.append(prediction)

return pd.Series(predictions)

Expand Down
2 changes: 1 addition & 1 deletion mlflow/models/evaluation/default_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _extract_predict_fn(model, raw_model):
except ImportError:
pass

return predict_fn, predict_proba_fn
return mlflow.trace(predict_fn), predict_proba_fn


def _get_regressor_metrics(y, y_pred, sample_weights):
Expand Down
18 changes: 7 additions & 11 deletions mlflow/pyfunc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,17 +965,13 @@ def predict(self, data, params=None):
Returns:
Model predictions.
"""
with mlflow.start_span() as span:
if inspect.signature(self._client.invoke).parameters.get("params"):
span.set_inputs({"data": data, "params": params})
result = self._client.invoke(data, params=params).get_predictions()
else:
span.set_inputs(data)
_log_warning_if_params_not_in_predict_signature(_logger, params)
result = self._client.invoke(data).get_predictions()
if isinstance(result, pandas.DataFrame):
result = result[result.columns[0]]
span.set_outputs(result)
if inspect.signature(self._client.invoke).parameters.get("params"):
result = self._client.invoke(data, params=params).get_predictions()
else:
_log_warning_if_params_not_in_predict_signature(_logger, params)
result = self._client.invoke(data).get_predictions()
if isinstance(result, pandas.DataFrame):
result = result[result.columns[0]]
return result

@property
Expand Down

0 comments on commit 4dbd76b

Please sign in to comment.