Skip to content

Add optional post_predict method to predictor classes #1237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 53 additions & 20 deletions docs/deployments/predictors.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,36 @@ class PythonPredictor:

class PythonPredictor:
def __init__(self, config):
"""Called once before the API becomes available. Performs setup such as downloading/initializing the model or downloading a vocabulary.
"""(Required) Called once before the API becomes available. Performs setup such as downloading/initializing the model or downloading a vocabulary.

Args:
config: Dictionary passed from API configuration (if specified). This may contain information on where to download the model and/or metadata.
config (required): Dictionary passed from API configuration (if specified). This may contain information on where to download the model and/or metadata.
"""
pass

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference, and postprocesses the inference output (if necessary).
"""(Required) Called once per request. Preprocesses the request payload (if necessary), runs inference, and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.

Returns:
Prediction or a batch of predictions.
"""
pass

def post_predict(self, response, payload, query_params, headers):
"""(Optional) Called in the background after returning a response. Useful for tasks that the client doesn't need to wait on before receiving a response such as recording metrics or storing results.

Args:
response (optional): The response as returned by the predict method.
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.
"""
pass
```

For proper separation of concerns, it is recommended to use the constructor's `config` parameter for information such as from where to download the model and initialization files, or any configurable model parameters. You define `config` in your [API configuration](api-configuration.md), and it is passed through to your Predictor's constructor.
Expand Down Expand Up @@ -204,27 +215,38 @@ If your application requires additional dependencies, you can install additional
```python
class TensorFlowPredictor:
def __init__(self, tensorflow_client, config):
"""Called once before the API becomes available. Performs setup such as downloading/initializing a vocabulary.
"""(Required) Called once before the API becomes available. Performs setup such as downloading/initializing a vocabulary.

Args:
tensorflow_client: TensorFlow client which is used to make predictions. This should be saved for use in predict().
config: Dictionary passed from API configuration (if specified).
tensorflow_client (required): TensorFlow client which is used to make predictions. This should be saved for use in predict().
config (required): Dictionary passed from API configuration (if specified).
"""
self.client = tensorflow_client
# Additional initialization may be done here

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).
"""(Required) Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.

Returns:
Prediction or a batch of predictions.
"""
pass

def post_predict(self, response, payload, query_params, headers):
"""(Optional) Called in the background after returning a response. Useful for tasks that the client doesn't need to wait on before receiving a response such as recording metrics or storing results.

Args:
response (optional): The response as returned by the predict method.
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.
"""
pass
```

<!-- CORTEX_VERSION_MINOR -->
Expand Down Expand Up @@ -289,27 +311,38 @@ If your application requires additional dependencies, you can install additional
```python
class ONNXPredictor:
def __init__(self, onnx_client, config):
"""Called once before the API becomes available. Performs setup such as downloading/initializing a vocabulary.
"""(Required) Called once before the API becomes available. Performs setup such as downloading/initializing a vocabulary.

Args:
onnx_client: ONNX client which is used to make predictions. This should be saved for use in predict().
config: Dictionary passed from API configuration (if specified).
onnx_client (required): ONNX client which is used to make predictions. This should be saved for use in predict().
config (required): Dictionary passed from API configuration (if specified).
"""
self.client = onnx_client
# Additional initialization may be done here

def predict(self, payload, query_params, headers):
"""Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).
"""(Required) Called once per request. Preprocesses the request payload (if necessary), runs inference (e.g. by calling self.client.predict(model_input)), and postprocesses the inference output (if necessary).

Args:
payload: The request payload (see below for the possible payload types) (optional).
query_params: A dictionary of the query parameters used in the request (optional).
headers: A dictionary of the headers sent in the request (optional).
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.

Returns:
Prediction or a batch of predictions.
"""
pass

def post_predict(self, response, payload, query_params, headers):
"""(Optional) Called in the background after returning a response. Useful for tasks that the client doesn't need to wait on before receiving a response such as recording metrics or storing results.

Args:
response (optional): The response as returned by the predict method.
payload (optional): The request payload (see below for the possible payload types).
query_params (optional): A dictionary of the query parameters used in the request.
headers (optional): A dictionary of the headers sent in the request.
"""
pass
```

<!-- CORTEX_VERSION_MINOR -->
Expand Down
37 changes: 33 additions & 4 deletions pkg/workloads/cortex/lib/type/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def initialize_impl(self, project_dir, client=None):
finally:
refresh_logger()

def class_impl(self, project_dir):
def get_target_and_validations(self):
target_class_name = None
validations = None

if self.type == "tensorflow":
target_class_name = "TensorFlowPredictor"
validations = TENSORFLOW_CLASS_VALIDATION
Expand All @@ -108,6 +111,11 @@ def class_impl(self, project_dir):
target_class_name = "PythonPredictor"
validations = PYTHON_CLASS_VALIDATION

return target_class_name, validations

def class_impl(self, project_dir):
target_class_name, validations = self.get_target_and_validations()

try:
impl = self._load_module("cortex_predictor", os.path.join(project_dir, self.path))
except CortexException as e:
Expand Down Expand Up @@ -171,7 +179,14 @@ def _compute_model_basepath(self, model_path, model_name):
"required_args": ["self"],
"optional_args": ["payload", "query_params", "headers"],
},
]
],
"optional": [
{
"name": "post_predict",
"required_args": ["self"],
"optional_args": ["response", "payload", "query_params", "headers"],
}
],
}

TENSORFLOW_CLASS_VALIDATION = {
Expand All @@ -182,7 +197,14 @@ def _compute_model_basepath(self, model_path, model_name):
"required_args": ["self"],
"optional_args": ["payload", "query_params", "headers"],
},
]
],
"optional": [
{
"name": "post_predict",
"required_args": ["self"],
"optional_args": ["response", "payload", "query_params", "headers"],
}
],
}

ONNX_CLASS_VALIDATION = {
Expand All @@ -193,7 +215,14 @@ def _compute_model_basepath(self, model_path, model_name):
"required_args": ["self"],
"optional_args": ["payload", "query_params", "headers"],
},
]
],
"optional": [
{
"name": "post_predict",
"required_args": ["self"],
"optional_args": ["response", "payload", "query_params", "headers"],
}
],
}


Expand Down
4 changes: 4 additions & 0 deletions pkg/workloads/cortex/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from copy import deepcopy


def has_method(obj, method: str) -> bool:
    """Return True if *obj* exposes a callable attribute named *method*.

    Args:
        obj: Any object to inspect (renamed from ``object`` to avoid
            shadowing the builtin).
        method: Name of the attribute to look up.

    Returns:
        True when the attribute exists and is callable, False otherwise.
    """
    # getattr with a None default avoids AttributeError for missing
    # attributes; callable(None) is False, so those report False.
    return callable(getattr(obj, method, None))


def extract_zip(zip_path, dest_dir=None, delete_zip_file=False):
if dest_dir is None:
dest_dir = os.path.dirname(zip_path)
Expand Down
46 changes: 36 additions & 10 deletions pkg/workloads/cortex/serve/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,12 @@ async def parse_payload(request: Request, call_next):


def predict(request: Request):
tasks = BackgroundTasks()
api = local_cache["api"]
predictor_impl = local_cache["predictor_impl"]
args = build_predict_args(request)
kwargs = build_predict_kwargs(request)

prediction = predictor_impl.predict(**args)
prediction = predictor_impl.predict(**kwargs)

if isinstance(prediction, bytes):
response = Response(content=prediction, media_type="application/octet-stream")
Expand All @@ -206,27 +207,47 @@ def predict(request: Request):
api.monitoring.model_type == "classification"
and predicted_value not in local_cache["class_set"]
):
tasks = BackgroundTasks()
tasks.add_task(api.upload_class, class_name=predicted_value)
local_cache["class_set"].add(predicted_value)
response.background = tasks
except:
cx_logger().warn("unable to record prediction metric", exc_info=True)

if util.has_method(predictor_impl, "post_predict"):
kwargs = build_post_predict_kwargs(prediction, request)
tasks.add_task(predictor_impl.post_predict, **kwargs)

if len(tasks.tasks) > 0:
response.background = tasks

return response


def build_predict_args(request: Request):
args = {}
def build_predict_kwargs(request: Request):
    """Assemble the kwargs dict for the user's predict() implementation.

    Each of payload / headers / query_params is included only when the
    user's predict() signature (cached in local_cache["predict_fn_args"])
    declares that parameter.
    """
    available = {
        "payload": lambda: request.state.payload,
        "headers": lambda: request.headers,
        "query_params": lambda: request.query_params,
    }
    declared = local_cache["predict_fn_args"]
    return {arg: fetch() for arg, fetch in available.items() if arg in declared}


return args
def build_post_predict_kwargs(response, request: Request):
    """Assemble the kwargs dict for the user's optional post_predict() hook.

    Same filtering as for predict(), plus the response object produced by
    the predict() call. Only parameters declared in the user's
    post_predict() signature (cached in local_cache["post_predict_fn_args"])
    are included.
    """
    available = {
        "payload": lambda: request.state.payload,
        "headers": lambda: request.headers,
        "query_params": lambda: request.query_params,
        "response": lambda: response,
    }
    declared = local_cache["post_predict_fn_args"]
    return {arg: fetch() for arg, fetch in available.items() if arg in declared}


def get_summary():
Expand Down Expand Up @@ -295,6 +316,11 @@ def start_fn():
local_cache["client"] = client
local_cache["predictor_impl"] = predictor_impl
local_cache["predict_fn_args"] = inspect.getfullargspec(predictor_impl.predict).args
if util.has_method(predictor_impl, "post_predict"):
local_cache["post_predict_fn_args"] = inspect.getfullargspec(
predictor_impl.post_predict
).args

predict_route = "/"
if provider != "local":
predict_route = "/predict"
Expand Down