Skip to content

Commit

Permalink
[Runtimes] Add validation on inputs types (#1550)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tankilevitch committed Dec 7, 2021
1 parent c187e5c commit 3f11240
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 1 deletion.
4 changes: 4 additions & 0 deletions mlrun/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class MLRunInvalidArgumentError(MLRunHTTPStatusError, ValueError):
error_status_code = HTTPStatus.BAD_REQUEST.value


class MLRunInvalidArgumentTypeError(MLRunHTTPStatusError, TypeError):
error_status_code = HTTPStatus.BAD_REQUEST.value


class MLRunConflictError(MLRunHTTPStatusError):
error_status_code = HTTPStatus.CONFLICT.value

Expand Down
3 changes: 2 additions & 1 deletion mlrun/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def run(
name: str = "",
project: str = "",
params: dict = None,
inputs: dict = None,
inputs: Dict[str, str] = None,
out_path: str = "",
workdir: str = "",
artifact_path: str = "",
Expand Down Expand Up @@ -286,6 +286,7 @@ def run(
:return: run context object (RunObject) with run metadata, results and status
"""
mlrun.utils.helpers.verify_dict_items_type("Inputs", inputs, [str], [str])

if self.spec.mode and self.spec.mode not in run_modes:
raise ValueError(f'run mode can only be {",".join(run_modes)}')
Expand Down
41 changes: 41 additions & 0 deletions mlrun/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,47 @@ def verify_field_list_of_type(
verify_field_of_type(field_name, element, expected_element_type)


def verify_dict_items_type(
name: str,
dictionary: dict,
expected_keys_types: list = None,
expected_values_types: list = None,
):
if dictionary:
if type(dictionary) != dict:
raise mlrun.errors.MLRunInvalidArgumentTypeError(
f"{name} expected to be of type dict, got type : {type(dictionary)}"
)
try:
verify_list_items_type(dictionary.keys(), expected_keys_types)
verify_list_items_type(dictionary.values(), expected_values_types)
except mlrun.errors.MLRunInvalidArgumentTypeError as exc:
raise mlrun.errors.MLRunInvalidArgumentTypeError(
f"{name} should be of type Dict[{get_pretty_types_names(expected_keys_types)},"
f"{get_pretty_types_names(expected_values_types)}]."
) from exc


def verify_list_items_type(list_, expected_types: list = None):
if list_ and expected_types:
list_items_types = set(map(type, list_))
expected_types = set(expected_types)

if not list_items_types.issubset(expected_types):
raise mlrun.errors.MLRunInvalidArgumentTypeError(
f"Found unexpected types in list items. expected: {expected_types},"
f" found: {list_items_types} in : {list_}"
)


def get_pretty_types_names(types):
if len(types) == 0:
return ""
if len(types) > 1:
return "Union[" + ",".join([ty.__name__ for ty in types]) + "]"
return types[0].__name__


def now_date():
return datetime.now(timezone.utc)

Expand Down
25 changes: 25 additions & 0 deletions tests/api/runtimes/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

import mlrun.errors
from mlrun.runtimes.base import BaseRuntime
from tests.api.runtimes.base import TestRuntimeBase


class TestBaseRunTime(TestRuntimeBase):
def custom_setup_after_fixtures(self):
self._mock_create_namespaced_pod()

@pytest.mark.parametrize(
"inputs", [{"input1": 123}, {"input1": None}, {"input1": None, "input2": 2}]
)
def test_run_with_invalid_inputs(self, db: Session, client: TestClient, inputs):
runtime = BaseRuntime()
with pytest.raises(mlrun.errors.MLRunInvalidArgumentTypeError):
self._execute_run(runtime, inputs=inputs)

def test_run_with_valid_inputs(self, db: Session, client: TestClient):
inputs = {"input1": "mlrun"}
runtime = BaseRuntime()
self._execute_run(runtime, inputs=inputs)
44 changes: 44 additions & 0 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
extend_hub_uri_if_needed,
fill_artifact_path_template,
get_parsed_docker_registry,
get_pretty_types_names,
verify_field_regex,
verify_list_items_type,
)
from mlrun.utils.regex import run_name

Expand Down Expand Up @@ -355,3 +357,45 @@ def test_fill_artifact_path_template():
case["artifact_path"], case.get("project")
)
assert case["expected_artifact_path"] == filled_artifact_path


@pytest.mark.parametrize("actual_list", [[1], [1, "asd"], [None], ["asd", 23]])
@pytest.mark.parametrize("expected_types", [[str]])
def test_verify_list_types_failure(actual_list, expected_types):
with pytest.raises(mlrun.errors.MLRunInvalidArgumentTypeError):
verify_list_items_type(actual_list, expected_types)


@pytest.mark.parametrize(
"actual_list", [[1.0, 8, "test"], ["test", 0.0], [None], [[["test"], 23]]]
)
@pytest.mark.parametrize("expected_types", [[str, int]])
def test_verify_list_multiple_types_failure(actual_list, expected_types):
with pytest.raises(mlrun.errors.MLRunInvalidArgumentTypeError):
verify_list_items_type(actual_list, expected_types)


@pytest.mark.parametrize("actual_list", [[], ["test"], ["test", "test1"]])
@pytest.mark.parametrize("expected_types", [[str]])
def test_verify_list_types_success(actual_list, expected_types):
verify_list_items_type(actual_list, expected_types)


@pytest.mark.parametrize(
"actual_list",
[[1, 8, "test"], ["test", 0], [], ["test", 23, "test"], ["test"], [1], [123, 123]],
)
@pytest.mark.parametrize("expected_types", [[str, int]])
def test_verify_list_multiple_types_success(actual_list, expected_types):
verify_list_items_type(actual_list, expected_types)


def test_get_pretty_types_names():
cases = [
([], ""),
([str], "str"),
([str, int], "Union[str,int]"),
]
for types, expected in cases:
pretty_result = get_pretty_types_names(types)
assert pretty_result == expected

0 comments on commit 3f11240

Please sign in to comment.