Create a new request session in each process (#8331)
* Create a new request session in each process

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Add test

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix docstring

Signed-off-by: harupy <hkawamura0130@gmail.com>

* Fix docstring again

Signed-off-by: harupy <hkawamura0130@gmail.com>

---------

Signed-off-by: harupy <hkawamura0130@gmail.com>
harupy authored and BenWilson2 committed Apr 27, 2023
1 parent b7d8406 commit 2470fd1
Showing 3 changed files with 50 additions and 23 deletions.
2 changes: 1 addition & 1 deletion mlflow/utils/file_utils.py
@@ -599,7 +599,7 @@ def download_chunk(request_index, chunk_size, headers, download_path, http_uri):
     range_end = range_start + chunk_size - 1
     combined_headers = {**headers, "Range": f"bytes={range_start}-{range_end}"}
     with cloud_storage_http_request(
-        "get", http_uri, stream=False, headers=combined_headers, cached_session=False
+        "get", http_uri, stream=False, headers=combined_headers
     ) as response:
         # File will have been created upstream. Use r+b to ensure chunks
         # don't overwrite the entire file.
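
In `download_chunk`, the `cached_session=False` argument is dropped: it existed so that multiprocessing callers would not share one cached `Session` (see the `cached_session` docstring removed in `rest_utils.py` below). With sessions now cached per process id, the opt-out is no longer needed. Below is a minimal sketch of the underlying hazard, using plain `requests` and fork-based `multiprocessing` (POSIX) rather than MLflow code:

import multiprocessing

import requests

# A session created in the parent process. With fork-based multiprocessing, workers
# inherit it, including any sockets its connection pool has already opened
# (the kind of issue referenced at https://stackoverflow.com/q/3724900 in the diff below).
_session = requests.Session()


def fetch(url):
    # Every forked worker reuses the parent's Session and its pooled connections, so
    # concurrent requests can end up interleaved on the same shared socket.
    return _session.get(url).status_code


if __name__ == "__main__":
    _session.get("https://example.com")  # open a connection in the parent before forking
    with multiprocessing.get_context("fork").Pool(2) as pool:
        print(pool.map(fetch, ["https://example.com"] * 4))
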
54 changes: 32 additions & 22 deletions mlflow/utils/rest_utils.py
@@ -1,5 +1,6 @@
 import base64
 import json
+import os
 
 import requests
 from requests.adapters import HTTPAdapter
@@ -40,20 +41,17 @@
 
 
 @lru_cache(maxsize=64)
-def _get_request_session(max_retries, backoff_factor, retry_codes):
-    return _get_request_session_uncached(max_retries, backoff_factor, retry_codes)
-
-
-def _get_request_session_uncached(max_retries, backoff_factor, retry_codes):
+def _cached_get_request_session(
+    max_retries,
+    backoff_factor,
+    retry_codes,
+    # To create a new Session object for each process, we use the process id as the cache key.
+    # This is to avoid sharing the same Session object across processes, which can lead to issues
+    # such as https://stackoverflow.com/q/3724900.
+    _pid,
+):
     """
-    Returns a Requests.Session object for making HTTP request.
-
-    :param max_retries: Maximum total number of retries.
-    :param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
-      request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
-      exponential backoff.
-    :param retry_codes: a list of HTTP response error codes that qualifies for retry.
-    :return: requests.Session object.
+    This function should not be called directly. Instead, use `_get_request_session` below.
     """
     assert 0 <= max_retries < 10
     assert 0 <= backoff_factor < 120
@@ -80,8 +78,27 @@ def _get_request_session_uncached(max_retries, backoff_factor, retry_codes):
     return session
 
 
+def _get_request_session(max_retries, backoff_factor, retry_codes):
+    """
+    Returns a `Requests.Session` object for making an HTTP request.
+
+    :param max_retries: Maximum total number of retries.
+    :param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
+      request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
+      exponential backoff.
+    :param retry_codes: a list of HTTP response error codes that qualifies for retry.
+    :return: requests.Session object.
+    """
+    return _cached_get_request_session(
+        max_retries,
+        backoff_factor,
+        retry_codes,
+        _pid=os.getpid(),
+    )
+
+
 def _get_http_response_with_retries(
-    method, url, max_retries, backoff_factor, retry_codes, cached_session=True, **kwargs
+    method, url, max_retries, backoff_factor, retry_codes, **kwargs
 ):
     """
     Performs an HTTP request using Python's `requests` module with an automatic retry policy.
@@ -93,15 +110,11 @@ def _get_http_response_with_retries(
       request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
       exponential backoff.
     :param retry_codes: a list of HTTP response error codes that qualifies for retry.
-    :param cached_session: Whether to cache session object. False used for multiprocessing contexts.
     :param kwargs: Additional keyword arguments to pass to `requests.Session.request()`
 
     :return: requests.Response object.
     """
-    if cached_session:
-        session = _get_request_session(max_retries, backoff_factor, retry_codes)
-    else:
-        session = _get_request_session_uncached(max_retries, backoff_factor, retry_codes)
+    session = _get_request_session(max_retries, backoff_factor, retry_codes)
     return session.request(method, url, **kwargs)
 
 
@@ -311,7 +324,6 @@ def cloud_storage_http_request(
     backoff_factor=2,
     retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
     timeout=None,
-    cached_session=True,
     **kwargs,
 ):
     """
@@ -326,7 +338,6 @@ def cloud_storage_http_request(
     :param retry_codes: a list of HTTP response error codes that qualifies for retry.
     :param timeout: wait for timeout seconds for response from remote server for connect and
       read request. Default to None owing to long duration operation in read / write.
-    :param cached_session: Whether to cache session object. False used for multiprocessing contexts.
     :param kwargs: Additional keyword arguments to pass to `requests.Session.request()`
 
     :return requests.Response object.
@@ -340,7 +351,6 @@ def cloud_storage_http_request(
         backoff_factor,
         retry_codes,
         timeout=timeout,
-        cached_session=cached_session,
         **kwargs,
     )
 
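
The body of `_cached_get_request_session` is collapsed in the diff above. As a rough sketch only (assumed helper names `_cached_session`/`get_session` and assumed default retry codes, not the exact MLflow implementation), the pattern is an `lru_cache` whose key includes the process id, wrapped around a `requests.Session` whose retry policy is mounted through `HTTPAdapter`, consistent with the visible imports and asserts:

import os
from functools import lru_cache

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry


@lru_cache(maxsize=64)
def _cached_session(max_retries, backoff_factor, retry_codes, _pid):
    # _pid only participates in the cache key: a worker process sees a different
    # os.getpid(), misses the cache, and builds its own Session instead of reusing
    # one inherited from the parent.
    retry = Retry(
        total=max_retries,
        backoff_factor=backoff_factor,  # exponential backoff factor, as described in the docstrings above
        status_forcelist=retry_codes,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session = requests.Session()
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    return session


def get_session(max_retries=5, backoff_factor=2, retry_codes=(429, 500, 502, 503, 504)):
    # Passing the current process id is what makes the cache per-process.
    return _cached_session(max_retries, backoff_factor, retry_codes, _pid=os.getpid())

Within one process, repeated calls with the same arguments return the same cached `Session`; a new worker process sees a different `os.getpid()`, so its first call builds a fresh session.
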
17 changes: 17 additions & 0 deletions tests/utils/test_rest_utils.py
@@ -4,6 +4,7 @@
 import numpy
 import pytest
 import requests
+from concurrent.futures import ProcessPoolExecutor
 
 from mlflow.environment_variables import MLFLOW_HTTP_REQUEST_TIMEOUT
 from mlflow.exceptions import MlflowException, RestException
@@ -16,6 +17,7 @@
     call_endpoints,
     augmented_raise_for_status,
     _can_parse_as_json_object,
+    _get_request_session,
 )
 from mlflow.tracking.request_header.default_request_header_provider import (
     DefaultRequestHeaderProvider,
@@ -523,3 +525,18 @@ def test_augmented_raise_for_status():
     assert e.value.response == response
     assert e.value.request == response.request
     assert response.text in str(e.value)
+
+
+def call_get_request_session():
+    return _get_request_session(max_retries=0, backoff_factor=0, retry_codes=(403,))
+
+
+def test_get_request_session():
+    sess = call_get_request_session()
+    another_sess = call_get_request_session()
+    assert sess is another_sess
+
+    with ProcessPoolExecutor(max_workers=2) as e:
+        futures = [e.submit(call_get_request_session) for _ in range(2)]
+        sess, another_sess = [f.result() for f in futures]
+        assert sess is not another_sess

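The new test checks both properties: repeated calls within a single process return the same cached `Session` object, while the calls submitted to the `ProcessPoolExecutor` run in worker processes and come back as distinct session objects rather than the parent's cached one.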