Skip to content

Commit

Permalink
[API] Enable to enrich function with node selection attributes on job…
Browse files Browse the repository at this point in the history
… submission endpoint (#935)
  • Loading branch information
Hedingber committed May 17, 2021
1 parent 33d6192 commit 42a084f
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 18 deletions.
36 changes: 18 additions & 18 deletions mlrun/runtimes/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
self.image_pull_secret = image_pull_secret
self.node_name = node_name
self.node_selector = node_selector
self.affinity = affinity
self._affinity = affinity

@property
def volumes(self) -> list:
Expand All @@ -105,18 +105,20 @@ def volume_mounts(self, volume_mounts):
for volume_mount in volume_mounts:
self._set_volume_mount(volume_mount)

@property
def affinity(self) -> client.V1Affinity:
return self._affinity

@affinity.setter
def affinity(self, affinity):
self._affinity = self._transform_affinity_to_k8s_class_instance(affinity)

def to_dict(self, fields=None, exclude=None):
struct = super().to_dict(fields, exclude=["affinity"])
api = client.ApiClient()
struct["affinity"] = api.sanitize_for_serialization(self.affinity)
return struct

@classmethod
def from_dict(cls, struct=None, fields=None):
new_instance = super().from_dict(struct, fields)
new_instance.affinity = new_instance._get_affinity_as_k8s_class_instance()
return new_instance

def update_vols_and_mounts(self, volumes, volume_mounts):
if volumes:
for vol in volumes:
Expand All @@ -127,14 +129,16 @@ def update_vols_and_mounts(self, volumes, volume_mounts):
self._set_volume_mount(volume_mount)

def _get_affinity_as_k8s_class_instance(self):
if not self.affinity:
pass

def _transform_affinity_to_k8s_class_instance(self, affinity):
if not affinity:
return None
affinity = self.affinity
if isinstance(affinity, dict):
api = client.ApiClient()
# not ideal to use their private method, but looks like that's the only option
# Taken from https://github.com/kubernetes-client/python/issues/977
affinity = api._ApiClient__deserialize(self.affinity, "V1Affinity")
affinity = api._ApiClient__deserialize(affinity, "V1Affinity")
return affinity

def _get_sanitized_affinity(self):
Expand All @@ -146,8 +150,8 @@ def _get_sanitized_affinity(self):
if not self.affinity:
return {}
if isinstance(self.affinity, dict):
# if node_affinity is part of the dict it means to_dict on the kubernetes object performed, there's nothing
# we can do at that point to transform it to the sanitized version
# heuristic - if node_affinity is part of the dict it means to_dict on the kubernetes object performed,
# there's nothing we can do at that point to transform it to the sanitized version
if "node_affinity" in self.affinity:
raise mlrun.errors.MLRunInvalidArgumentError(
"Affinity must be instance of kubernetes' V1Affinity class"
Expand All @@ -156,8 +160,7 @@ def _get_sanitized_affinity(self):
# then it's already the sanitized version
return self.affinity
api = client.ApiClient()
affinity = self._get_affinity_as_k8s_class_instance()
return api.sanitize_for_serialization(affinity)
return api.sanitize_for_serialization(self.affinity)

def _set_volume_mount(self, volume_mount):
# calculate volume mount hash
Expand Down Expand Up @@ -410,15 +413,12 @@ def _add_vault_params_to_spec(self, runobj=None, project=None):
def kube_resource_spec_to_pod_spec(
kube_resource_spec: KubeResourceSpec, container: client.V1Container
):
affinity = kube_resource_spec.affinity
if kube_resource_spec.affinity and isinstance(kube_resource_spec.affinity, dict):
affinity = kube_resource_spec._get_affinity_as_k8s_class_instance()
return client.V1PodSpec(
containers=[container],
restart_policy="Never",
volumes=kube_resource_spec.volumes,
service_account=kube_resource_spec.service_account,
node_name=kube_resource_spec.node_name,
node_selector=kube_resource_spec.node_selector,
affinity=affinity,
affinity=kube_resource_spec.affinity,
)
3 changes: 3 additions & 0 deletions mlrun/runtimes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ def enrich_function_from_dict(function, function_dict):
"resources",
"image_pull_policy",
"replicas",
"node_name",
"node_selector",
"affinity",
]:
override_value = getattr(override_function.spec, attribute, None)
if override_value:
Expand Down
95 changes: 95 additions & 0 deletions tests/api/api/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,81 @@ def test_parse_submit_job_body_override_values(db: Session, client: TestClient):
},
"image_pull_policy": "Always",
"replicas": "3",
"node_name": "k8s-node1",
"node_selector": {"kubernetes.io/hostname": "k8s-node1"},
"affinity": {
"nodeAffinity": {
"preferredDuringSchedulingIgnoredDuringExecution": [
{
"preference": {
"matchExpressions": [
{
"key": "some_node_label",
"operator": "In",
"values": [
"possible-label-value-1",
"possible-label-value-2",
],
}
]
},
"weight": 1,
}
],
"requiredDuringSchedulingIgnoredDuringExecution": {
"nodeSelectorTerms": [
{
"matchExpressions": [
{
"key": "some_node_label",
"operator": "In",
"values": [
"required-label-value-1",
"required-label-value-2",
],
}
]
}
]
},
},
"podAffinity": {
"requiredDuringSchedulingIgnoredDuringExecution": [
{
"labelSelector": {
"matchLabels": {
"some-pod-label-key": "some-pod-label-value"
}
},
"namespaces": ["namespace-a", "namespace-b"],
"topologyKey": "key-1",
}
]
},
"podAntiAffinity": {
"preferredDuringSchedulingIgnoredDuringExecution": [
{
"podAffinityTerm": {
"labelSelector": {
"matchExpressions": [
{
"key": "some_pod_label",
"operator": "NotIn",
"values": [
"forbidden-label-value-1",
"forbidden-label-value-2",
],
}
]
},
"namespaces": ["namespace-c"],
"topologyKey": "key-2",
},
"weight": 1,
}
]
},
},
}
},
}
Expand Down Expand Up @@ -90,6 +165,26 @@ def test_parse_submit_job_body_override_values(db: Session, client: TestClient):
parsed_function_object, submit_job_body, original_function
)
_assert_env_vars(parsed_function_object, submit_job_body, original_function)
assert (
parsed_function_object.spec.node_name
== submit_job_body["function"]["spec"]["node_name"]
)
assert (
DeepDiff(
parsed_function_object.spec.node_selector,
submit_job_body["function"]["spec"]["node_selector"],
ignore_order=True,
)
== {}
)
assert (
DeepDiff(
parsed_function_object.spec._get_sanitized_affinity(),
submit_job_body["function"]["spec"]["affinity"],
ignore_order=True,
)
== {}
)


def test_parse_submit_job_body_keep_resources(db: Session, client: TestClient):
Expand Down

0 comments on commit 42a084f

Please sign in to comment.