Skip to content

Commit

Permalink
Add env var to disable redirects again (#10673)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Lok <daniel.lok@databricks.com>
  • Loading branch information
daniellok-db committed Dec 13, 2023
1 parent 65cfc3c commit 8174250
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 5 deletions.
6 changes: 6 additions & 0 deletions mlflow/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,9 @@ def get(self):
MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE = _EnvironmentVariable(
"MLFLOW_MULTIPART_DOWNLOAD_CHUNK_SIZE", int, 100 * 1024**2
)

#: Specifies whether or not to allow the MLflow server to follow redirects when
#: making HTTP requests. If set to False, the server will throw an exception if it
#: encounters a redirect response.
#: (default: ``True``)
MLFLOW_ALLOW_HTTP_REDIRECTS = _BooleanEnvironmentVariable("MLFLOW_ALLOW_HTTP_REDIRECTS", True)
9 changes: 8 additions & 1 deletion mlflow/utils/request_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def _get_http_response_with_retries(
backoff_jitter,
retry_codes,
raise_on_status=True,
allow_redirects=None,
**kwargs,
):
"""
Expand All @@ -187,7 +188,13 @@ def _get_http_response_with_retries(
session = _get_request_session(
max_retries, backoff_factor, backoff_jitter, retry_codes, raise_on_status
)
return session.request(method, url, **kwargs)

# the environment variable is hardcoded here to avoid importing mlflow.
# however, documentation is available in environment_variables.py
env_value = os.getenv("MLFLOW_ALLOW_HTTP_REDIRECTS", "true").lower() in ["true", "1"]
allow_redirects = env_value if allow_redirects is None else allow_redirects

return session.request(method, url, allow_redirects=allow_redirects, **kwargs)


def cloud_storage_http_request(
Expand Down
1 change: 1 addition & 0 deletions tests/projects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def confirm_request_params(*args, **kwargs):
headers["Authorization"] = "Basic dXNlcjpwYXNz"
assert args == ("PUT", "host/clusters/list")
assert kwargs == {
"allow_redirects": True,
"headers": headers,
"verify": True,
"json": {"a": "b"},
Expand Down
30 changes: 26 additions & 4 deletions tests/store/artifact/test_databricks_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def test_log_artifact_azure_with_headers(
request_mock.assert_called_with(
"put",
f"{MOCK_AZURE_SIGNED_URI}?comp=blocklist",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
Expand Down Expand Up @@ -344,33 +345,38 @@ def test_log_artifact_adls_gen2_with_headers(
request_mock.assert_any_call(
"put",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?resource=file",
allow_redirects=True,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=0",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=5",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_any_call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=10",
allow_redirects=True,
data=ANY,
headers=filtered_azure_headers,
timeout=None,
)
request_mock.assert_called_with(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=flush&position=14",
allow_redirects=True,
headers=filtered_azure_headers,
timeout=None,
)
Expand Down Expand Up @@ -402,12 +408,14 @@ def test_log_artifact_adls_gen2_flush_error(databricks_artifact_repo, test_file)
mock.call(
"put",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?resource=file",
allow_redirects=True,
headers={},
timeout=None,
),
mock.call(
"patch",
f"{MOCK_ADLS_GEN2_SIGNED_URI}?action=append&position=0&flush=true",
allow_redirects=True,
data=ANY,
headers={},
timeout=None,
Expand All @@ -434,7 +442,7 @@ def test_log_artifact_aws(databricks_artifact_repo, test_file, artifact_path, ex
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_AWS_SIGNED_URI, data=ANY, headers={}, timeout=None
"put", MOCK_AWS_SIGNED_URI, allow_redirects=True, data=ANY, headers={}, timeout=None
)


Expand Down Expand Up @@ -462,7 +470,12 @@ def test_log_artifact_aws_with_headers(
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_AWS_SIGNED_URI, data=ANY, headers=expected_headers, timeout=None
"put",
MOCK_AWS_SIGNED_URI,
allow_redirects=True,
data=ANY,
headers=expected_headers,
timeout=None,
)


Expand Down Expand Up @@ -500,7 +513,7 @@ def test_log_artifact_gcp(databricks_artifact_repo, test_file, artifact_path, ex
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_GCP_SIGNED_URL, data=ANY, headers={}, timeout=None
"put", MOCK_GCP_SIGNED_URL, allow_redirects=True, data=ANY, headers={}, timeout=None
)


Expand Down Expand Up @@ -528,7 +541,12 @@ def test_log_artifact_gcp_with_headers(
GetCredentialsForWrite, MOCK_RUN_ID, [expected_location]
)
request_mock.assert_called_with(
"put", MOCK_GCP_SIGNED_URL, data=ANY, headers=expected_headers, timeout=None
"put",
MOCK_GCP_SIGNED_URL,
allow_redirects=True,
data=ANY,
headers=expected_headers,
timeout=None,
)


Expand Down Expand Up @@ -1294,6 +1312,7 @@ def test_multipart_upload(databricks_artifact_repo, large_file, mock_chunk_size)
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand Down Expand Up @@ -1383,6 +1402,7 @@ def test_multipart_upload_retry_part_upload(databricks_artifact_repo, large_file
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand Down Expand Up @@ -1441,6 +1461,7 @@ def test_multipart_upload_abort(databricks_artifact_repo, large_file, mock_chunk
mock.call(
"put",
f"{MOCK_AWS_SIGNED_URI}partNumber={i + 1}",
allow_redirects=True,
data=f.read(mock_chunk_size),
headers={"header": f"part-{i + 1}"},
timeout=None,
Expand All @@ -1459,6 +1480,7 @@ def test_multipart_upload_abort(databricks_artifact_repo, large_file, mock_chunk
assert abort_call == mock.call(
"delete",
f"{MOCK_AWS_SIGNED_URI}uploadId=abort",
allow_redirects=True,
headers={"header": "abort"},
timeout=None,
)
1 change: 1 addition & 0 deletions tests/store/tracking/test_rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def mock_request(*args, **kwargs):
assert args == ("POST", "https://hello/api/2.0/mlflow/experiments/search")
kwargs = {k: v for k, v in kwargs.items() if v is not None}
assert kwargs == {
"allow_redirects": True,
"json": {"view_type": "ACTIVE_ONLY"},
"headers": DefaultRequestHeaderProvider().request_headers(),
"verify": True,
Expand Down
102 changes: 102 additions & 0 deletions tests/utils/test_request_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,105 @@ def test_download_chunk_incomplete_read(tmp_path):
download_path=download_path,
http_uri="https://example.com",
)


@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirects_disabled_if_env_var_set(monkeypatch, env_value):
monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request("GET", "http://localhost:5000")

assert response.text == "mock response"
mock_request.assert_called_once_with(
"GET",
"http://localhost:5000",
allow_redirects=False,
timeout=None,
)


@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirects_enabled_if_env_var_set(monkeypatch, env_value):
monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET",
"http://localhost:5000",
)

assert response.text == "mock response"
mock_request.assert_called_once_with(
"GET",
"http://localhost:5000",
allow_redirects=True,
timeout=None,
)


@pytest.mark.parametrize("env_value", ["0", "false", "False", "FALSE"])
def test_redirect_kwarg_overrides_env_value_false(monkeypatch, env_value):
monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET", "http://localhost:5000", allow_redirects=True
)

assert response.text == "mock response"
mock_request.assert_called_once_with(
"GET",
"http://localhost:5000",
allow_redirects=True,
timeout=None,
)


@pytest.mark.parametrize("env_value", ["1", "true", "True", "TRUE"])
def test_redirect_kwarg_overrides_env_value_true(monkeypatch, env_value):
monkeypatch.setenv("MLFLOW_ALLOW_HTTP_REDIRECTS", env_value)

with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET", "http://localhost:5000", allow_redirects=False
)

assert response.text == "mock response"
mock_request.assert_called_once_with(
"GET",
"http://localhost:5000",
allow_redirects=False,
timeout=None,
)


def test_redirects_enabled_by_default():
with mock.patch("requests.Session.request") as mock_request:
mock_request.return_value.status_code = 302
mock_request.return_value.text = "mock response"

response = request_utils.cloud_storage_http_request(
"GET",
"http://localhost:5000",
)

assert response.text == "mock response"
mock_request.assert_called_once_with(
"GET",
"http://localhost:5000",
allow_redirects=True,
timeout=None,
)
Loading

0 comments on commit 8174250

Please sign in to comment.