Skip to content

Commit

Permalink
[Serving] Fix node selectors and priority class name support (#1334)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed Sep 22, 2021
1 parent 8115bee commit 57704da
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 50 deletions.
8 changes: 8 additions & 0 deletions mlrun/runtimes/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def __init__(
track_models=None,
secret_sources=None,
default_content_type=None,
node_name=None,
node_selector=None,
affinity=None,
disable_auto_mount=False,
priority_class_name=None,
):

super().__init__(
Expand All @@ -134,7 +138,11 @@ def __init__(
service_account=service_account,
readiness_timeout=readiness_timeout,
build=build,
node_name=node_name,
node_selector=node_selector,
affinity=affinity,
disable_auto_mount=disable_auto_mount,
priority_class_name=priority_class_name,
)

self.models = models or {}
Expand Down
123 changes: 77 additions & 46 deletions tests/api/runtimes/test_nuclio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@


class TestNuclioRuntime(TestRuntimeBase):
@property
def runtime_kind(self):
# enables extending classes to run the same tests with different runtime
return "nuclio"

@property
def class_name(self):
# enables extending classes to run the same tests with different class
return "remote"

def custom_setup_after_fixtures(self):
self._mock_nuclio_deploy_config()

Expand All @@ -36,6 +46,11 @@ def custom_setup(self):
os.environ["V3IO_ACCESS_KEY"] = self.v3io_access_key = "1111-2222-3333-4444"
os.environ["V3IO_USERNAME"] = self.v3io_user = "test-user"

def _serialize_and_deploy_nuclio_function(self, function):
# simulating sending to API - serialization through dict
function = function.from_dict(function.to_dict())
deploy_nuclio_function(function)

@staticmethod
def _mock_nuclio_deploy_config():
nuclio.deploy.deploy_config = unittest.mock.Mock(return_value="some-server")
Expand Down Expand Up @@ -258,7 +273,7 @@ def test_enrich_with_ingress_no_overriding(self, db: Session, client: TestClient
Expect no ingress template to be created, thought its mode is "always",
since the function already have a pre-configured ingress
"""
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

# both ingress and node port
ingress_host = "something.com"
Expand All @@ -278,7 +293,7 @@ def test_enrich_with_ingress_always(self, db: Session, client: TestClient):
"""
Expect ingress template to be created as the configuration templated ingress mode is "always"
"""
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
function_name, project_name, config = compile_function_config(function)
service_type = "NodePort"
enrich_function_with_ingress(
Expand All @@ -292,7 +307,7 @@ def test_enrich_with_ingress_on_cluster_ip(self, db: Session, client: TestClient
Expect ingress template to be created as the configuration templated ingress mode is "onClusterIP" while the
function service type is ClusterIP
"""
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
function_name, project_name, config = compile_function_config(function)
service_type = "ClusterIP"
enrich_function_with_ingress(
Expand All @@ -305,7 +320,7 @@ def test_enrich_with_ingress_never(self, db: Session, client: TestClient):
"""
Expect no ingress to be created automatically as the configuration templated ingress mode is "never"
"""
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
function_name, project_name, config = compile_function_config(function)
service_type = "DoesNotMatter"
enrich_function_with_ingress(
Expand All @@ -315,23 +330,25 @@ def test_enrich_with_ingress_never(self, db: Session, client: TestClient):
assert ingresses == []

def test_deploy_basic_function(self, db: Session, client: TestClient):
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config()
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_class=self.class_name)

def test_deploy_function_with_labels(self, db: Session, client: TestClient):
labels = {
"key": "value",
"key-2": "value-2",
}
function = self._generate_runtime("nuclio", labels)
function = self._generate_runtime(self.runtime_kind, labels)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_labels=labels)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
expected_labels=labels, expected_class=self.class_name
)

def test_deploy_with_triggers(self, db: Session, client: TestClient):
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

http_trigger = {
"workers": 2,
Expand All @@ -353,32 +370,32 @@ def test_deploy_with_triggers(self, db: Session, client: TestClient):
function.with_http(**http_trigger)
function.add_v3io_stream_trigger(**v3io_trigger)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config()
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_class=self.class_name)
self._assert_triggers(http_trigger, v3io_trigger)

def test_deploy_with_v3io(self, db: Session, client: TestClient):
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
local_path = "/local/path"
remote_path = "/container/and/path"
function.with_v3io(local_path, remote_path)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config()
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_class=self.class_name)
self._assert_nuclio_v3io_mount(local_path, remote_path)

def test_deploy_with_node_selection(self, db: Session, client: TestClient):
mlconf.nuclio_version = "1.6.10"
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

node_name = "some-node-name"
function.with_node_selection(node_name=node_name)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config()
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_class=self.class_name)
self._assert_node_selections(function.spec, expected_node_name=node_name)

function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

node_selector = {
"label-1": "val1",
Expand All @@ -388,37 +405,45 @@ def test_deploy_with_node_selection(self, db: Session, client: TestClient):
json.dumps(node_selector).encode("utf-8")
)
function.with_node_selection(node_selector=node_selector)
deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=2)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=2, expected_class=self.class_name
)
self._assert_node_selections(
function.spec, expected_node_selector=node_selector
)

function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

node_selector = {
"label-3": "val3",
"label-4": "val4",
}
function.with_node_selection(node_selector=node_selector)
deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=3)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=3, expected_class=self.class_name
)
self._assert_node_selections(
function.spec, expected_node_selector=node_selector
)

function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
affinity = self._generate_affinity()

function.with_node_selection(affinity=affinity)
deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=4)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=4, expected_class=self.class_name
)
self._assert_node_selections(function.spec, expected_affinity=affinity)

function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)
function.with_node_selection(node_name, node_selector, affinity)
deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=5)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=5, expected_class=self.class_name
)
self._assert_node_selections(
function.spec,
expected_node_name=node_name,
Expand All @@ -432,37 +457,41 @@ def test_deploy_with_priority_class_name(self, db: Session, client: TestClient):
default_priority_class_name = "default-priority"
mlrun.mlconf.default_function_priority_class_name = default_priority_class_name
mlrun.mlconf.valid_function_priority_class_names = default_priority_class_name
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config()
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(expected_class=self.class_name)
args, _ = nuclio.deploy.deploy_config.call_args
deploy_spec = args[0]["spec"]

assert "priorityClassName" not in deploy_spec

mlconf.nuclio_version = "1.6.18"
mlrun.mlconf.valid_function_priority_class_names = ""
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=2)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=2, expected_class=self.class_name
)
args, _ = nuclio.deploy.deploy_config.call_args
deploy_spec = args[0]["spec"]

assert "priorityClassName" not in deploy_spec

mlrun.mlconf.valid_function_priority_class_names = default_priority_class_name
function = self._generate_runtime("nuclio")
function = self._generate_runtime(self.runtime_kind)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=3)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=3, expected_class=self.class_name
)
args, _ = nuclio.deploy.deploy_config.call_args
deploy_spec = args[0]["spec"]

assert deploy_spec["priorityClassName"] == default_priority_class_name

function = self._generate_runtime()
function = self._generate_runtime(self.runtime_kind)
medium_priority_class_name = "medium-priority"
mlrun.mlconf.valid_function_priority_class_names = medium_priority_class_name
mlconf.nuclio_version = "1.5.20"
Expand All @@ -476,8 +505,10 @@ def test_deploy_with_priority_class_name(self, db: Session, client: TestClient):
mlconf.nuclio_version = "1.6.18"
function.with_priority_class(medium_priority_class_name)

deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(call_count=4)
self._serialize_and_deploy_nuclio_function(function)
self._assert_deploy_called_basic_config(
call_count=4, expected_class=self.class_name
)
args, _ = nuclio.deploy.deploy_config.call_args
deploy_spec = args[0]["spec"]

Expand Down Expand Up @@ -540,7 +571,7 @@ def success():
success()

def test_load_function_with_source_archive_git(self):
fn = self._generate_runtime("nuclio")
fn = self._generate_runtime(self.runtime_kind)
fn.with_source_archive(
"git://github.com/org/repo#my-branch",
handler="path/inside/repo#main:handler",
Expand Down Expand Up @@ -572,7 +603,7 @@ def test_load_function_with_source_archive_git(self):
}

def test_load_function_with_source_archive_s3(self):
fn = self._generate_runtime("nuclio")
fn = self._generate_runtime(self.runtime_kind)
fn.with_source_archive(
"s3://my-bucket/path/in/bucket/my-functions-archive",
handler="path/inside/functions/archive#main:Handler",
Expand Down Expand Up @@ -610,7 +641,7 @@ def test_load_function_with_source_archive_s3(self):
}

def test_load_function_with_source_archive_v3io(self):
fn = self._generate_runtime("nuclio")
fn = self._generate_runtime(self.runtime_kind)
fn.with_source_archive(
"v3ios://host.com/container/my-functions-archive.zip",
handler="path/inside/functions/archive#main:handler",
Expand Down
18 changes: 14 additions & 4 deletions tests/api/runtimes/test_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@


class TestServingRuntime(TestNuclioRuntime):
@property
def runtime_kind(self):
# enables extending classes to run the same tests with different runtime
return "serving"

@property
def class_name(self):
# enables extending classes to run the same tests with different class
return "serving"

def custom_setup_after_fixtures(self):
self._mock_nuclio_deploy_config()
self._mock_vault_functionality()
Expand Down Expand Up @@ -60,7 +70,7 @@ def _remote_db_mock_function(func, with_mlrun):
SQLDB.get_builder_status = unittest.mock.Mock(return_value=("text", "last_log"))

def _create_serving_function(self):
function = self._generate_runtime("serving")
function = self._generate_runtime(self.runtime_kind)
graph = function.set_topology("flow", exist_ok=True, engine="sync")

graph.add_step(name="s1", class_name="Chain", secret="inline_secret1")
Expand Down Expand Up @@ -168,7 +178,7 @@ def test_remote_deploy_with_secrets(self, db: Session, client: TestClient):
function = self._create_serving_function()

function.deploy(verbose=True)
self._assert_deploy_called_basic_config(expected_class="serving")
self._assert_deploy_called_basic_config(expected_class=self.class_name)

self._assert_deploy_spec_has_secrets_config(
expected_secret_sources=self._generate_expected_secret_sources()
Expand Down Expand Up @@ -211,7 +221,7 @@ def test_serving_with_secrets_remote_build(self, db: Session, client: TestClient

assert response.status_code == HTTPStatus.OK.value

self._assert_deploy_called_basic_config(expected_class="serving")
self._assert_deploy_called_basic_config(expected_class=self.class_name)

def test_child_functions_with_secrets(self, db: Session, client: TestClient):
function = self._create_serving_function()
Expand Down Expand Up @@ -250,7 +260,7 @@ def test_child_functions_with_secrets(self, db: Session, client: TestClient):
]

self._assert_deploy_called_basic_config(
expected_class="serving",
expected_class=self.class_name,
call_count=2,
expected_params=expected_deploy_params,
)
Expand Down

0 comments on commit 57704da

Please sign in to comment.