Env variable support for batch inference (aws#106)
* Support env variables to configure batchSize, maxBatchDelay, etc. for the single model in TorchServe

* Add modified version

* fix flake8

* Edit version

* Correct type

* Add condition to including env variables in model config

* Add version

* Update version and remove env support

* Try converting config to string

* Reverse str and update version

* Fix true

* Experiment with default config

* Complete

* Include load models

* Set max workers to 1

* Set default response timeout to 60, and improve docstring

* Fix flake8

* Add a warning log for single model

* Fix extra spacing in log

* Use string instead of a dict

* Print config

* Fix string

* Fix f-string

* Remove newline

* Adjust f string

* Fix flake8

* Trigger build

* Trigger build

* Trigger build

* Trigger build

Co-authored-by: Nikhil Kulkarni <nikhilsk@amazon.com>
nskool and Nikhil Kulkarni committed Oct 2, 2021
1 parent 6610a41 commit 27b667f
Showing 3 changed files with 136 additions and 1 deletion.
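For context, the variables introduced here are plain container environment variables, so a caller can pass them through the SageMaker Python SDK's `env` argument when creating the model. A minimal, hypothetical deployment sketch (the bucket, role, entry point, and versions are placeholders, not part of this commit):

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data="s3://my-bucket/model.tar.gz",               # placeholder artifact
    role="arn:aws:iam::123456789012:role/MySageMakerRole",  # placeholder role
    entry_point="inference.py",
    framework_version="1.9.0",
    py_version="py38",
    env={
        "SAGEMAKER_TS_BATCH_SIZE": "4",         # batch up to 4 requests per inference call
        "SAGEMAKER_TS_MAX_BATCH_DELAY": "100",  # wait at most 100 ms to fill a batch
        "SAGEMAKER_TS_MIN_WORKERS": "1",
        "SAGEMAKER_TS_MAX_WORKERS": "1",
        "SAGEMAKER_TS_RESPONSE_TIMEOUT": "60",
    },
)
predictor = model.deploy(initial_instance_count=1, instance_type="ml.m5.xlarge")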
22 changes: 21 additions & 1 deletion src/sagemaker_pytorch_serving_container/torchserve.py
@@ -24,6 +24,7 @@
from retrying import retry

import sagemaker_pytorch_serving_container
from sagemaker_pytorch_serving_container import ts_environment
from sagemaker_inference import default_handler_service, environment, utils
from sagemaker_inference.environment import code_dir

@@ -149,14 +150,33 @@ def _create_torchserve_config_file():

def _generate_ts_config_properties():
    env = environment.Environment()

    user_defined_configuration = {
        "default_response_timeout": env.model_server_timeout,
        "default_workers_per_model": env.model_server_workers,
        "inference_address": "http://0.0.0.0:{}".format(env.inference_http_port),
        "management_address": "http://0.0.0.0:{}".format(env.management_http_port),
    }

    ts_env = ts_environment.TorchServeEnvironment()

    if ts_env.is_env_set() and not ENABLE_MULTI_MODEL:
        models_string = f'''{{\\
"{DEFAULT_TS_MODEL_NAME}": {{\\
"1.0": {{\\
"defaultVersion": true,\\
"marName": "{DEFAULT_TS_MODEL_NAME}.mar",\\
"minWorkers": {ts_env._min_workers},\\
"maxWorkers": {ts_env._max_workers},\\
"batchSize": {ts_env._batch_size},\\
"maxBatchDelay": {ts_env._max_batch_delay},\\
"responseTimeout": {ts_env._response_timeout}\\
}}\\
}}\\
}}'''
        user_defined_configuration["models"] = models_string
        logger.warning("SageMaker TS environment variables have been set and will be used "
                       "for the single model endpoint.")

    custom_configuration = str()

    for key in user_defined_configuration:
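For reference, a standalone sketch of what the `models` template above renders to, using the module defaults and assuming `DEFAULT_TS_MODEL_NAME` resolves to "model" (an illustrative value, not confirmed by this diff):

model_name = "model"  # assumed stand-in for DEFAULT_TS_MODEL_NAME
min_workers = max_workers = 1
batch_size, max_batch_delay, response_timeout = 1, 100, 60

models_string = f'''{{\\
"{model_name}": {{\\
"1.0": {{\\
"defaultVersion": true,\\
"marName": "{model_name}.mar",\\
"minWorkers": {min_workers},\\
"maxWorkers": {max_workers},\\
"batchSize": {batch_size},\\
"maxBatchDelay": {max_batch_delay},\\
"responseTimeout": {response_timeout}\\
}}\\
}}\\
}}'''

# Each rendered line ends with a literal backslash, which config.properties
# treats as a line continuation, so TorchServe parses the whole value as one
# JSON object.
print(f"models={models_string}")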
94 changes: 94 additions & 0 deletions src/sagemaker_pytorch_serving_container/ts_environment.py
@@ -0,0 +1,94 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality that provides access to system
characteristics, environment variables and configuration settings.
"""

from __future__ import absolute_import

from sagemaker_pytorch_serving_container import ts_parameters

import os
import logging

logger = logging.getLogger()

DEFAULT_TS_BATCH_SIZE = 1
DEFAULT_TS_MAX_BATCH_DELAY = 100
DEFAULT_TS_MIN_WORKERS = 1
DEFAULT_TS_MAX_WORKERS = 1
DEFAULT_TS_RESPONSE_TIMEOUT = 60


class TorchServeEnvironment():
    """Provides access to aspects of the TorchServe environment relevant to serving containers,
    including system characteristics, environment variables and configuration settings.
    The Environment is a read-only snapshot of the container environment.
    Attributes:
        batch_size (int): The maximum batch size, in number of requests, that a model is expected to handle
        max_batch_delay (int): The maximum delay, in milliseconds, that TorchServe waits to receive
            batch_size requests. If TorchServe does not receive batch_size requests before this
            timer times out, it sends whatever requests were received to the model handler
        min_workers (int): Minimum number of workers that TorchServe is allowed to scale down to
        max_workers (int): Maximum number of workers that TorchServe is allowed to scale up to
        response_timeout (int): Time delay after which inference will time out in the absence of a response
    """
    def __init__(self):
        self._batch_size = int(os.environ.get(ts_parameters.MODEL_SERVER_BATCH_SIZE, DEFAULT_TS_BATCH_SIZE))
        self._max_batch_delay = int(os.environ.get(ts_parameters.MODEL_SERVER_MAX_BATCH_DELAY,
                                                   DEFAULT_TS_MAX_BATCH_DELAY))
        self._min_workers = int(os.environ.get(ts_parameters.MODEL_SERVER_MIN_WORKERS, DEFAULT_TS_MIN_WORKERS))
        self._max_workers = int(os.environ.get(ts_parameters.MODEL_SERVER_MAX_WORKERS, DEFAULT_TS_MAX_WORKERS))
        self._response_timeout = int(os.environ.get(ts_parameters.MODEL_SERVER_RESPONSE_TIMEOUT,
                                                    DEFAULT_TS_RESPONSE_TIMEOUT))

    def is_env_set(self):  # type: () -> bool
        """bool: whether or not any of the SageMaker TorchServe environment variables have been set"""
        ts_env_list = [ts_parameters.MODEL_SERVER_BATCH_SIZE, ts_parameters.MODEL_SERVER_MAX_BATCH_DELAY,
                       ts_parameters.MODEL_SERVER_MIN_WORKERS, ts_parameters.MODEL_SERVER_MAX_WORKERS,
                       ts_parameters.MODEL_SERVER_RESPONSE_TIMEOUT]
        return any(env in ts_env_list for env in os.environ)

    @property
    def batch_size(self):  # type: () -> int
        """int: number of requests to batch before running inference on the server"""
        return self._batch_size

    @property
    def max_batch_delay(self):  # type: () -> int
        """int: time delay in milliseconds, to wait for incoming requests to be batched,
        before running inference on the server
        """
        return self._max_batch_delay

    @property
    def min_workers(self):  # type: () -> int
        """int: minimum number of workers for the model
        """
        return self._min_workers

    @property
    def max_workers(self):  # type: () -> int
        """int: maximum number of workers for the model
        """
        return self._max_workers

    @property
    def response_timeout(self):  # type: () -> int
        """int: time delay after which inference will time out in the absence of a response
        """
        return self._response_timeout
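
A small usage sketch of the snapshot above (illustrative values; run inside the container's Python environment):

import os
from sagemaker_pytorch_serving_container import ts_environment

os.environ["SAGEMAKER_TS_BATCH_SIZE"] = "8"   # any one recognized variable is enough

ts_env = ts_environment.TorchServeEnvironment()
print(ts_env.is_env_set())     # True: a SAGEMAKER_TS_* variable is present
print(ts_env.batch_size)       # 8, taken from the environment
print(ts_env.max_batch_delay)  # 100, the module default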
21 changes: 21 additions & 0 deletions src/sagemaker_pytorch_serving_container/ts_parameters.py
@@ -0,0 +1,21 @@
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains string constants that define inference toolkit
parameters."""
from __future__ import absolute_import

MODEL_SERVER_BATCH_SIZE = "SAGEMAKER_TS_BATCH_SIZE" # type: str
MODEL_SERVER_MAX_BATCH_DELAY = "SAGEMAKER_TS_MAX_BATCH_DELAY" # type: str
MODEL_SERVER_MIN_WORKERS = "SAGEMAKER_TS_MIN_WORKERS" # type: str
MODEL_SERVER_MAX_WORKERS = "SAGEMAKER_TS_MAX_WORKERS" # type: str
MODEL_SERVER_RESPONSE_TIMEOUT = "SAGEMAKER_TS_RESPONSE_TIMEOUT" # type: str
