Skip to content

Commit

Permalink
Add Content-Type header validation (#10526)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed Nov 30, 2023
1 parent 4b8bb73 commit 28ff3f9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/auth/rest-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The API is hosted under the ``/api`` route on the MLflow tracking server. For ex
experiments on a tracking server hosted at ``http://localhost:5000``, access
``http://localhost:5000/api/2.0/mlflow/users/create``.

.. important::
The MLflow REST API requires content type ``application/json`` for all POST requests.

.. contents:: Table of Contents
:local:
Expand Down
4 changes: 3 additions & 1 deletion docs/source/rest-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ The API is hosted under the ``/api`` route on the MLflow tracking server. For ex
experiments on a tracking server hosted at ``http://localhost:5000``, make a POST request to
``http://localhost:5000/api/2.0/mlflow/experiments/search``.

.. important::
The MLflow REST API requires content type ``application/json`` for all POST requests.

.. contents:: Table of Contents
:local:
:depth: 1

===========================



.. _mlflowMlflowServicecreateExperiment:

Create Experiment
Expand Down
6 changes: 5 additions & 1 deletion mlflow/server/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,11 @@ def create_user():
user = store.create_user(username, password)
return make_response({"user": user.to_json()})
else:
return make_response(f"Invalid content type: '{content_type}'", 400)
message = (
"Invalid content type. Must be one of: "
"application/x-www-form-urlencoded, application/json"
)
return make_response(message, 400)


@catch_mlflow_exception
Expand Down
5 changes: 5 additions & 0 deletions mlflow/server/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
UpdateExperiment,
UpdateRun,
)
from mlflow.server.validation import _validate_content_type
from mlflow.store.artifact.artifact_repo import MultipartUploadMixin
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.db.db_types import DATABASE_ENGINES
Expand Down Expand Up @@ -403,6 +404,7 @@ def _validate_param_against_schema(schema, param, value, proto_parsing_succeeded


def _get_request_json(flask_request=request):
    """
    Return the JSON payload of ``flask_request``.

    The request's content type is validated against ``application/json`` first.
    The body is then read with ``force=True, silent=True``.

    :param flask_request: Flask request object (defaults to ``flask.request``)
    :return: Whatever ``flask_request.get_json(force=True, silent=True)`` yields
    """
    json_only = ["application/json"]
    _validate_content_type(flask_request, json_only)
    payload = flask_request.get_json(force=True, silent=True)
    return payload


Expand Down Expand Up @@ -1112,6 +1114,7 @@ def _default_history_bulk_impl():
@_disable_if_artifacts_only
def search_datasets_handler():
MAX_EXPERIMENT_IDS_PER_REQUEST = 20
_validate_content_type(request, ["application/json"])
experiment_ids = request.json.get("experiment_ids", [])
if not experiment_ids:
raise MlflowException(
Expand Down Expand Up @@ -1179,6 +1182,8 @@ def assert_arg_exists(arg_name, arg):
error_code=INVALID_PARAMETER_VALUE,
)

_validate_content_type(request, ["application/json"])

args = request.json
experiment_id = args.get("experiment_id")
assert_arg_exists("experiment_id", experiment_id)
Expand Down
31 changes: 31 additions & 0 deletions mlflow/server/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


def _validate_content_type(flask_request, allowed_content_types: List[str]):
    """
    Validates that the request content type is one of the allowed content types.

    Only ``POST`` and ``PUT`` requests are checked; any other method returns immediately.
    Media-type parameters (e.g. ``; charset=utf-8``) are ignored. The remaining media
    type is compared case-insensitively with surrounding whitespace removed, so that
    ``Application/JSON`` and ``application/json ; charset=utf-8`` are accepted when
    ``application/json`` is allowed.

    :param flask_request: Flask request object (flask.request)
    :param allowed_content_types: A list of allowed content types
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the Content-Type header
        is missing or its media type is not in ``allowed_content_types``.
    """
    if flask_request.method not in ["POST", "PUT"]:
        return

    if flask_request.content_type is None:
        raise MlflowException(
            message="Bad Request. Content-Type header is missing.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Remove any parameters e.g. "application/json; charset=utf-8" -> "application/json".
    # Media types are case-insensitive and may be followed by whitespace before the ";",
    # so normalise both sides before comparing.
    content_type = flask_request.content_type.split(";")[0].strip().lower()
    normalized_allowed = {allowed.strip().lower() for allowed in allowed_content_types}
    if content_type not in normalized_allowed:
        message = f"Bad Request. Content-Type must be one of {allowed_content_types}."

        raise MlflowException(
            message=message,
            error_code=INVALID_PARAMETER_VALUE,
        )
33 changes: 33 additions & 0 deletions tests/server/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def test_all_model_registry_endpoints_available():
def test_can_parse_json():
request = mock.MagicMock()
request.method = "POST"
request.content_type = "application/json"
request.get_json = mock.MagicMock()
request.get_json.return_value = {"name": "hello"}
msg = _get_request_message(CreateExperiment(), flask_request=request)
Expand All @@ -165,12 +166,23 @@ def test_can_parse_json():
def test_can_parse_post_json_with_unknown_fields():
    """A POST JSON body with an extra key next to ``name`` still yields ``name``."""
    body = {"name": "hello", "WHAT IS THIS FIELD EVEN": "DOING"}
    request = mock.MagicMock(method="POST", content_type="application/json")
    request.get_json = mock.MagicMock(return_value=body)

    msg = _get_request_message(CreateExperiment(), flask_request=request)

    assert msg.name == "hello"


def test_can_parse_post_json_with_content_type_params():
    """A content type carrying a ``charset`` parameter is accepted for POST JSON."""
    request = mock.MagicMock(
        method="POST",
        content_type="application/json; charset=utf-8",
    )
    request.get_json = mock.MagicMock(return_value={"name": "hello"})

    msg = _get_request_message(CreateExperiment(), flask_request=request)

    assert msg.name == "hello"


def test_can_parse_get_json_with_unknown_fields():
request = mock.MagicMock()
request.method = "GET"
Expand All @@ -184,12 +196,33 @@ def test_can_parse_get_json_with_unknown_fields():
def test_can_parse_json_string():
    """``get_json`` returning a JSON-encoded string is still parsed into the message."""
    encoded_body = '{"name": "hello2"}'
    request = mock.MagicMock(method="POST", content_type="application/json")
    request.get_json = mock.MagicMock(return_value=encoded_body)

    msg = _get_request_message(CreateExperiment(), flask_request=request)

    assert msg.name == "hello2"


def test_can_block_post_request_with_invalid_content_type():
    """A POST sent as ``text/plain`` raises an ``MlflowException`` about Content-Type."""
    request = mock.MagicMock(method="POST", content_type="text/plain")
    request.get_json = mock.MagicMock(return_value={"name": "hello"})

    expected_error = pytest.raises(MlflowException, match=r"Bad Request. Content-Type")
    with expected_error:
        _get_request_message(CreateExperiment(), flask_request=request)


def test_can_block_post_request_with_missing_content_type():
    """A POST whose ``content_type`` is ``None`` raises an ``MlflowException``."""
    request = mock.MagicMock(method="POST", content_type=None)
    request.get_json = mock.MagicMock(return_value={"name": "hello"})

    expected_error = pytest.raises(MlflowException, match=r"Bad Request. Content-Type")
    with expected_error:
        _get_request_message(CreateExperiment(), flask_request=request)


def test_search_runs_default_view_type(mock_get_request_message, mock_tracking_store):
"""
Search Runs default view type is filled in as ViewType.ACTIVE_ONLY
Expand Down

0 comments on commit 28ff3f9

Please sign in to comment.