Skip to content

Commit

Permalink
Add env variable to disable redirects (#10655)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
  • Loading branch information
daniellok-db committed Dec 8, 2023
1 parent 6a4b1be commit f01767a
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 5 deletions.
6 changes: 6 additions & 0 deletions mlflow/environment_variables.py
Expand Up @@ -487,3 +487,9 @@ def get(self):
MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE = _EnvironmentVariable(
"MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE", int, 100 * 1024**2
)

#: Specifies whether or not to allow the MLflow server to follow redirects when
#: making HTTP requests. If set to False, the server will throw an exception if it
#: encounters a redirect response.
#: (default: ``True``)
MLFLOW_ALLOW_HTTP_REDIRECTS = _BooleanEnvironmentVariable("MLFLOW_ALLOW_HTTP_REDIRECTS", True)
18 changes: 17 additions & 1 deletion mlflow/utils/request_utils.py
Expand Up @@ -187,7 +187,23 @@ def _get_http_response_with_retries(
session = _get_request_session(
max_retries, backoff_factor, backoff_jitter, retry_codes, raise_on_status
)
return session.request(method, url, **kwargs)

# the environment variable is hardcoded here to avoid importing mlflow.
# however, documentation is available in environment_variables.py
env_value = os.getenv("MLFLOW_ALLOW_HTTP_REDIRECTS", "1").lower()
allow_redirects = env_value in ["true", "1"]

response = session.request(method, url, allow_redirects=allow_redirects, **kwargs)

if not allow_redirects and (response.is_redirect or 300 <= response.status_code < 400):
raise HTTPError(
"HTTP redirects are disabled through the MLFLOW_DISABLE_HTTP_REDIRECTS "
"environment variable, but the server responded with a redirect. Response text: "
f"{response.text}",
response=response,
)

return response


def cloud_storage_http_request(
Expand Down
1 change: 1 addition & 0 deletions tests/projects/test_databricks.py
Expand Up @@ -449,6 +449,7 @@ def confirm_request_params(*args, **kwargs):
headers["Authorization"] = "Basic dXNlcjpwYXNz"
assert args == ("PUT", "host/clusters/list")
assert kwargs == {
"allow_redirects": True,
"headers": headers,
"verify": True,
"json": {"a": "b"},
Expand Down
30 changes: 26 additions & 4 deletions tests/store/artifact/test_databricks_artifact_repo.py
Expand Up @@ -258,6 +258,7 @@ def test_log_artifact_azure_with_headers(
request_mock.assert_called_with(
"put",
f"{MOCK_AZURE_SIGNED_URI}?comp=blocklist",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
Expand Down Expand Up @@ -346,33 +347,38 @@ def test_log_artifact_adls_gen2_with_headers(
request_mock.assert_any_call(
"put",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?resource=file",
allow_redirects=True,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=0",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=5",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=10",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_called_with(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=flush&position=14",
allow_redirects=True,
headers=filtered_azure_headers,
timeout=None,
)
Expand Down Expand Up @@ -404,12 +410,14 @@ def test_log_artifact_adls_gen2_flush_error(databricks_artifact_repo, test_file)
mock.call(
"put",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?resource=file",
allow_redirects=True,
headers={},
timeout=None,
),
mock.call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=0&flush=true",
allow_redirects=True,
data=ANY,
headers={},
timeout=None,
Expand All @@ -436,7 +444,7 @@ def test_log_artifact_aws(databricks_artifact_repo, test_file, artifact_path, ex
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_AWS_SIGNED_URI, data=ANY, headers={}, timeout=None
"put", MOCK_AWS_SIGNED_URI, allow_redirects=True, data=ANY, headers={}, timeout=None
)


Expand Down Expand Up @@ -464,7 +472,12 @@ def test_log_artifact_aws_with_headers(
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_AWS_SIGNED_URI, data=ANY, headers=expected_headers, timeout=None
"put",
MOCK_AWS_SIGNED_URI,
allow_redirects=True,
data=ANY,
headers=expected_headers,
timeout=None,
)


Expand Down Expand Up @@ -502,7 +515,7 @@ def test_log_artifact_gcp(databricks_artifact_repo, test_file, artifact_path, ex
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_GCP_SIGNED_URL, data=ANY, headers={}, timeout=None
"put", MOCK_GCP_SIGNED_URL, allow_redirects=True, data=ANY, headers={}, timeout=None
)


Expand Down Expand Up @@ -530,7 +543,12 @@ def test_log_artifact_gcp_with_headers(
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_GCP_SIGNED_URL, data=ANY, headers=expected_headers, timeout=None
"put",
MOCK_GCP_SIGNED_URL,
allow_redirects=True,
data=ANY,
headers=expected_headers,
timeout=None,
)


Expand Down Expand Up @@ -1298,6 +1316,7 @@ def test_multipart_upload(databricks_artifact_repo, large_file, mock_chunk_size)
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand Down Expand Up @@ -1387,6 +1406,7 @@ def test_multipart_upload_retry_part_upload(databricks_artifact_repo, large_file
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand Down Expand Up @@ -1445,6 +1465,7 @@ def test_multipart_upload_abort(databricks_artifact_repo, large_file, mock_chunk
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand All @@ -1463,6 +1484,7 @@ def test_multipart_upload_abort(databricks_artifact_repo, large_file, mock_chunk
assert abort_call == mock.call(
"delete",
f"{MOCK_AWS_SIGNED_URI}uploadId=abort",
allow_redirects=True,
headers={"header": "abort"},
timeout=None,
)
1 change: 1 addition & 0 deletions tests/store/tracking/test_rest_store.py
Expand Up @@ -70,6 +70,7 @@ def mock_request(*args, **kwargs):
assert args == ("POST", "https://hello/api/2.0/mlflow/experiments/search")
kwargs = {k: v for k, v in kwargs.items() if v is not None}
assert kwargs == {
"allow_redirects": True,
"json": {"view_type": "ACTIVE_ONLY"},
"headers": DefaultRequestHeaderProvider().request_headers(),
"verify": True,
Expand Down
43 changes: 43 additions & 0 deletions tests/utils/test_request_utils.py
Expand Up @@ -58,3 +58,46 @@ def test_download_chunk_incomplete_read(tmp_path):
download_path=download_path,
http_uri="https://example.com",
)


@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirects_disabled_if_env_var_set(monkeypatch, env_value):
from requests.exceptions import HTTPError

monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

with pytest.raises(HTTPError, match="HTTP redirects are disabled"):
request_utils.cloud_storage_http_request("GET", "http://localhost:5000")


@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirects_enabled_if_env_var_set(monkeypatch, env_value):
monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET",
"http://localhost:5000",
)

assert response.text == "mock response"


def test_redirects_enabled_by_default():
with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET",
"http://localhost:5000",
)

assert response.text == "mock response"
15 changes: 15 additions & 0 deletions tests/utils/test_rest_utils.py
Expand Up @@ -116,6 +116,7 @@ def test_http_request_hostonly(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand All @@ -133,6 +134,7 @@ def test_http_request_cleans_hostname(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand All @@ -151,6 +153,7 @@ def test_http_request_with_basic_auth(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
headers=headers,
timeout=120,
Expand Down Expand Up @@ -183,6 +186,7 @@ def __eq__(self, other):
request.assert_called_once_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
Expand All @@ -207,6 +211,7 @@ def test_http_request_with_auth(fetch_auth, request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=mock.ANY,
headers=mock.ANY,
timeout=mock.ANY,
Expand All @@ -226,6 +231,7 @@ def test_http_request_with_token(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
headers=headers,
timeout=120,
Expand All @@ -242,6 +248,7 @@ def test_http_request_with_insecure(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=False,
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand All @@ -258,6 +265,7 @@ def test_http_request_client_cert_path(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
cert="/some/path",
headers=DefaultRequestHeaderProvider().request_headers(),
Expand All @@ -275,6 +283,7 @@ def test_http_request_server_cert_path(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify="/some/path",
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand All @@ -295,6 +304,7 @@ def test_http_request_with_content_type_header(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=True,
headers=headers,
timeout=120,
Expand All @@ -320,6 +330,7 @@ def test_http_request_request_headers(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify="/some/path",
headers={**DefaultRequestHeaderProvider().request_headers(), "test": "header"},
timeout=120,
Expand Down Expand Up @@ -356,6 +367,7 @@ def test_http_request_request_headers_user_agent(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify="/some/path",
headers=expected_headers,
timeout=120,
Expand Down Expand Up @@ -393,6 +405,7 @@ def test_http_request_request_headers_user_agent_and_extra_header(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify="/some/path",
headers=expected_headers,
timeout=120,
Expand Down Expand Up @@ -440,6 +453,7 @@ def test_http_request_wrapper(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=False,
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand All @@ -450,6 +464,7 @@ def test_http_request_wrapper(request):
request.assert_called_with(
"GET",
"http://my-host/my/endpoint",
allow_redirects=True,
verify=False,
headers=DefaultRequestHeaderProvider().request_headers(),
timeout=120,
Expand Down

0 comments on commit f01767a

Please sign in to comment.