Skip to content

Commit

Permalink
[Dask] Enhance extending env vars to avoid memory leak (#5490)
Browse files Browse the repository at this point in the history
  • Loading branch information
liranbg committed May 2, 2024
1 parent f3c0222 commit cb3ba86
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 7 deletions.
10 changes: 5 additions & 5 deletions mlrun/runtimes/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,12 +1086,12 @@ def enrich_runtime_spec(

def _set_env(self, name, value=None, value_from=None):
new_var = k8s_client.V1EnvVar(name=name, value=value, value_from=value_from)
i = 0
for v in self.spec.env:
if get_item_name(v) == name:
self.spec.env[i] = new_var

# ensure we don't have duplicate env vars with the same name
for env_index, value_item in enumerate(self.spec.env):
if get_item_name(value_item) == name:
self.spec.env[env_index] = new_var
return self
i += 1
self.spec.env.append(new_var)
return self

Expand Down
28 changes: 26 additions & 2 deletions server/api/runtime_handlers/daskjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,32 @@ def enrich_dask_cluster(
# TODO: we might never enter here, since running a function requires defining an image
or "daskdev/dask:latest"
)
env = spec.env
env.extend(function.generate_runtime_k8s_env())
env = function.generate_runtime_k8s_env()

# filter any spec.env that already exists in env
# in other words, dont let spec.env override env (or not even duplicate it)
# we dont want to override env to ensure k8s runtime envs are enforced and correct
# leaving no room for human mistakes
def get_env_name(env_: Union[client.V1EnvVar, dict]) -> str:
if isinstance(env_, client.V1EnvVar):
return env_.name
return env_.get("name", "")

env.extend(
filter(
lambda spec_env: not any(
[
True
for _env in env
# spec_env might be V1EnvVar or a dict
# _env is just a dict
if get_env_name(spec_env) == get_env_name(_env)
]
),
spec.env,
)
)

namespace = meta.namespace or config.namespace
if spec.extra_pip:
env.append(spec.extra_pip)
Expand Down
80 changes: 80 additions & 0 deletions tests/api/runtimes/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@

from dask import distributed
from fastapi.testclient import TestClient
from kubernetes import client as k8s_client
from sqlalchemy.orm import Session

import mlrun
import mlrun.common.schemas
import server.api.api.endpoints.functions
import server.api.runtime_handlers.daskjob
from mlrun import mlconf
from mlrun.platforms import auto_mount
from mlrun.runtimes.utils import generate_resources
Expand Down Expand Up @@ -432,6 +434,84 @@ def test_dask_with_security_context(self, db: Session, client: TestClient):
_ = runtime.client
self.assert_security_context(other_security_context)

def test_enrich_dask_cluster(self):
function = mlrun.runtimes.DaskCluster(
metadata=dict(
name="test",
project="project",
labels={"label1": "val1"},
annotations={"annotation1": "val1"},
),
spec=dict(
nthreads=1,
worker_resources={"limits": {"memory": "1Gi"}},
scheduler_resources={"limits": {"memory": "1Gi"}},
env=[
{"name": "MLRUN_NAMESPACE", "value": "other-namespace"},
k8s_client.V1EnvVar(name="MLRUN_TAG", value="latest"),
],
),
)

function.generate_runtime_k8s_env = unittest.mock.Mock(
return_value=[
{"name": "MLRUN_DEFAULT_PROJECT", "value": "project"},
{"name": "MLRUN_NAMESPACE", "value": "test-namespace"},
]
)

# add default envvars that expected to be on enriched pods
# do it to verify later on it is not duplicated and appears only once
function.spec.env.extend(function.generate_runtime_k8s_env())

expected_resources = {
"limits": {"memory": "1Gi"},
"requests": {},
}
expected_env = [
{"name": "MLRUN_DEFAULT_PROJECT", "value": "project"},
{"name": "MLRUN_NAMESPACE", "value": "test-namespace"},
k8s_client.V1EnvVar(name="MLRUN_TAG", value="latest"),
]
expected_labels = {
"mlrun/project": "project",
"mlrun/class": "dask",
"mlrun/function": "test",
"label1": "val1",
"mlrun/scrape-metrics": "True",
"mlrun/tag": "latest",
}

secrets = []
client_version = "1.6.0"
client_python_version = "3.9"
scheduler_pod, worker_pod, function, namespace = (
server.api.runtime_handlers.daskjob.enrich_dask_cluster(
function, secrets, client_version, client_python_version
)
)

assert scheduler_pod.metadata.namespace == namespace
assert worker_pod.metadata.namespace == namespace
assert scheduler_pod.metadata.labels == expected_labels
assert worker_pod.metadata.labels == expected_labels
assert scheduler_pod.spec.containers[0].args == ["dask", "scheduler"]
assert worker_pod.spec.containers[0].args == [
"dask",
"worker",
"--nthreads",
"1",
"--memory-limit",
"1Gi",
]
assert worker_pod.spec.containers[0].resources == expected_resources
assert scheduler_pod.spec.containers[0].resources == expected_resources
assert worker_pod.spec.containers[0].env == expected_env
assert scheduler_pod.spec.containers[0].env == expected_env

# used once by test, once by enrich_dask_cluster
assert function.generate_runtime_k8s_env.call_count == 2

def test_deploy_dask_function_with_enriched_security_context(
self, db: Session, client: TestClient, k8s_secrets_mock: K8sSecretsMock
):
Expand Down

0 comments on commit cb3ba86

Please sign in to comment.