Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backoff jitter #10486

Merged
merged 10 commits into from Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlflow/environment_variables.py
Expand Up @@ -101,6 +101,12 @@ def get(self):
"MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", int, 2
)

#: Specifies the backoff jitter between MLflow HTTP request failures
#: (default: ``1.0``)
MLFLOW_HTTP_REQUEST_BACKOFF_JITTER = _EnvironmentVariable(
"MLFLOW_HTTP_REQUEST_BACKOFF_JITTER", float, 1.0
)

#: Specifies the timeout in seconds for MLflow HTTP requests
#: (default: ``120``)
MLFLOW_HTTP_REQUEST_TIMEOUT = _EnvironmentVariable("MLFLOW_HTTP_REQUEST_TIMEOUT", int, 120)
Expand Down
52 changes: 47 additions & 5 deletions mlflow/utils/request_utils.py
Expand Up @@ -2,6 +2,7 @@
# This file is imported by download_cloud_file_chunk.py.
# Importing mlflow is time-consuming and we want to avoid that in artifact download subprocesses.
import os
import random
from functools import lru_cache

import requests
Expand All @@ -26,6 +27,25 @@
)


class JitteredRetry(Retry):
"""
urllib3 < 2 doesn't support `backoff_jitter`. This class is a workaround for that.
"""

def __init__(self, *args, backoff_jitter=0.0, **kwargs):
super().__init__(*args, **kwargs)
self.backoff_jitter = backoff_jitter

def get_backoff_time(self):
"""
Source: https://github.com/urllib3/urllib3/commit/214b184923388328919b0a4b0c15bff603aa51be
"""
backoff_value = super().get_backoff_time()
if self.backoff_jitter != 0.0:
backoff_value += random.random() * self.backoff_jitter
return float(max(0, min(Retry.DEFAULT_BACKOFF_MAX, backoff_value)))


def augmented_raise_for_status(response):
"""Wrap the standard `requests.response.raise_for_status()` method and return reason"""
try:
Expand Down Expand Up @@ -71,6 +91,7 @@ def download_chunk(*, range_start, range_end, headers, download_path, http_uri):
def _cached_get_request_session(
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
# To create a new Session object for each process, we use the process id as the cache key.
Expand All @@ -92,28 +113,35 @@ def _cached_get_request_session(
"status": max_retries,
"status_forcelist": retry_codes,
"backoff_factor": backoff_factor,
"backoff_jitter": backoff_jitter,
"raise_on_status": raise_on_status,
}
if Version(urllib3.__version__) >= Version("1.26.0"):
urllib3_version = Version(urllib3.__version__)
if urllib3_version >= Version("1.26.0"):
retry_kwargs["allowed_methods"] = None
else:
retry_kwargs["method_whitelist"] = None
retry = Retry(**retry_kwargs)

if urllib3_version < Version("2.0"):
retry = JitteredRetry(**retry_kwargs)
else:
retry = Retry(**retry_kwargs)
adapter = HTTPAdapter(max_retries=retry)
session = requests.Session()
session.mount("https://", adapter)
session.mount("http://", adapter)
return session


def _get_request_session(max_retries, backoff_factor, retry_codes, raise_on_status):
def _get_request_session(max_retries, backoff_factor, backoff_jitter, retry_codes, raise_on_status):
"""
Returns a `Requests.Session` object for making an HTTP request.

:param max_retries: Maximum total number of retries.
:param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
exponential backoff.
:param backoff_jitter: A random jitter to add to the backoff interval.
:param retry_codes: a list of HTTP response error codes that qualifies for retry.
:param raise_on_status: whether to raise an exception, or return a response, if status falls
in retry_codes range and retries have been exhausted.
Expand All @@ -122,14 +150,22 @@ def _get_request_session(max_retries, backoff_factor, retry_codes, raise_on_stat
return _cached_get_request_session(
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
_pid=os.getpid(),
)


def _get_http_response_with_retries(
method, url, max_retries, backoff_factor, retry_codes, raise_on_status=True, **kwargs
method,
url,
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status=True,
**kwargs,
):
"""
Performs an HTTP request using Python's `requests` module with an automatic retry policy.
Expand All @@ -140,14 +176,17 @@ def _get_http_response_with_retries(
:param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
exponential backoff.
:param backoff_jitter: A random jitter to add to the backoff interval.
:param retry_codes: a list of HTTP response error codes that qualifies for retry.
:param raise_on_status: whether to raise an exception, or return a response, if status falls
in retry_codes range and retries have been exhausted.
:param kwargs: Additional keyword arguments to pass to `requests.Session.request()`

:return: requests.Response object.
"""
session = _get_request_session(max_retries, backoff_factor, retry_codes, raise_on_status)
session = _get_request_session(
max_retries, backoff_factor, backoff_jitter, retry_codes, raise_on_status
)
return session.request(method, url, **kwargs)


Expand All @@ -156,6 +195,7 @@ def cloud_storage_http_request(
url,
max_retries=5,
backoff_factor=2,
backoff_jitter=1.0,
retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
timeout=None,
**kwargs,
Expand All @@ -169,6 +209,7 @@ def cloud_storage_http_request(
:param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
exponential backoff.
:param backoff_jitter: A random jitter to add to the backoff interval.
:param retry_codes: a list of HTTP response error codes that qualifies for retry.
:param timeout: wait for timeout seconds for response from remote server for connect and
read request. Default to None owing to long duration operation in read / write.
Expand All @@ -183,6 +224,7 @@ def cloud_storage_http_request(
url,
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
timeout=timeout,
**kwargs,
Expand Down
15 changes: 12 additions & 3 deletions mlflow/utils/rest_utils.py
Expand Up @@ -5,6 +5,7 @@

from mlflow.environment_variables import (
MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR,
MLFLOW_HTTP_REQUEST_BACKOFF_JITTER,
MLFLOW_HTTP_REQUEST_MAX_RETRIES,
MLFLOW_HTTP_REQUEST_TIMEOUT,
)
Expand All @@ -30,6 +31,7 @@ def http_request(
method,
max_retries=None,
backoff_factor=None,
backoff_jitter=None,
extra_headers=None,
retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
timeout=None,
Expand All @@ -50,6 +52,7 @@ def http_request(
:param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
exponential backoff.
:param backoff_jitter: A random jitter to add to the backoff interval.
:param extra_headers: a dict of HTTP header name-value pairs to be included in the request.
:param retry_codes: a list of HTTP response error codes that qualifies for retry.
:param timeout: wait for timeout seconds for response from remote server for connect and
Expand All @@ -60,9 +63,14 @@ def http_request(

:return: requests.Response object.
"""
max_retries = max_retries or MLFLOW_HTTP_REQUEST_MAX_RETRIES.get()
backoff_factor = backoff_factor or MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR.get()
timeout = timeout or MLFLOW_HTTP_REQUEST_TIMEOUT.get()
max_retries = MLFLOW_HTTP_REQUEST_MAX_RETRIES.get() if max_retries is None else max_retries
Comment on lines -63 to +66
Copy link
Member

@harupy harupy Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, max_retries = 0 defaults to MLFLOW_HTTP_REQUEST_MAX_RETRIES.get() but it should not.

backoff_factor = (
MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR.get() if backoff_factor is None else backoff_factor
)
backoff_jitter = (
MLFLOW_HTTP_REQUEST_BACKOFF_JITTER.get() if backoff_jitter is None else backoff_jitter
)
timeout = MLFLOW_HTTP_REQUEST_TIMEOUT.get() if timeout is None else timeout
hostname = host_creds.host
auth_str = None
if host_creds.username and host_creds.password:
Expand Down Expand Up @@ -101,6 +109,7 @@ def http_request(
url,
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
headers=headers,
Expand Down
30 changes: 24 additions & 6 deletions tests/gateway/test_integration.py
Expand Up @@ -660,8 +660,14 @@ def test_invalid_response_structure_raises(gateway):
async def mock_chat(self, payload):
return expected_output

def _mock_request_session(max_retries, backoff_factor, retry_codes, raise_on_status):
return _cached_get_request_session(1, 1, retry_codes, True, os.getpid())
def _mock_request_session(
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
):
return _cached_get_request_session(1, 1, 0.5, retry_codes, True, os.getpid())

with patch(
"mlflow.utils.request_utils._get_request_session", _mock_request_session
Expand Down Expand Up @@ -690,8 +696,14 @@ def test_invalid_response_structure_no_raises(gateway):
async def mock_chat(self, payload):
return expected_output

def _mock_request_session(max_retries, backoff_factor, retry_codes, raise_on_status):
return _cached_get_request_session(0, 1, retry_codes, False, os.getpid())
def _mock_request_session(
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
):
return _cached_get_request_session(0, 1, 0.5, retry_codes, False, os.getpid())

with patch(
"mlflow.utils.request_utils._get_request_session", _mock_request_session
Expand Down Expand Up @@ -728,8 +740,14 @@ def test_invalid_query_request_raises(gateway):
async def mock_chat(self, payload):
return expected_output

def _mock_request_session(max_retries, backoff_factor, retry_codes, raise_on_status):
return _cached_get_request_session(2, 1, retry_codes, True, os.getpid())
def _mock_request_session(
max_retries,
backoff_factor,
backoff_jitter,
retry_codes,
raise_on_status,
):
return _cached_get_request_session(2, 1, 0.5, retry_codes, True, os.getpid())

with patch(
"mlflow.utils.request_utils._get_request_session", _mock_request_session
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/test_rest_utils.py
Expand Up @@ -506,6 +506,7 @@ def test_http_request_customize_config(monkeypatch):
mock.ANY,
5,
2,
1.0,
mock.ANY,
True,
headers=mock.ANY,
Expand All @@ -515,13 +516,15 @@ def test_http_request_customize_config(monkeypatch):
mock_get_http_response_with_retries.reset_mock()
monkeypatch.setenv("MLFLOW_HTTP_REQUEST_MAX_RETRIES", "8")
monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR", "3")
monkeypatch.setenv("MLFLOW_HTTP_REQUEST_BACKOFF_JITTER", "1.0")
monkeypatch.setenv("MLFLOW_HTTP_REQUEST_TIMEOUT", "300")
http_request(host_only, "/my/endpoint", "GET")
mock_get_http_response_with_retries.assert_called_with(
mock.ANY,
mock.ANY,
8,
3,
1.0,
mock.ANY,
True,
headers=mock.ANY,
Expand Down