Enable overwrites of default environment variables (#874)
* Enable overwrites of default environment variables

* Black formatting

* Include test for additional worker group; test overriding of environment variables

* Black

---------

Co-authored-by: Jonas Dedden <jonas.dedden@deepl.com>
jonded94 and Jonas Dedden committed Apr 2, 2024
1 parent b668cc6 commit 93fe171
Showing 3 changed files with 93 additions and 32 deletions.
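Previously the controller always appended its default DASK_WORKER_NAME and DASK_SCHEDULER_ADDRESS entries to every worker container, so a value the user had already set in the worker spec could be shadowed by the appended default. With this commit, user-supplied entries are kept and only the missing defaults are injected. A minimal sketch of a worker pod spec that relies on this behaviour (container name, image and scheduler address are illustrative, not taken from this commit):

worker_pod_spec = {  # e.g. what a user puts under DaskWorkerGroup.spec.worker.spec
    "containers": [
        {
            "name": "worker",
            "image": "ghcr.io/dask/dask:latest",  # illustrative image
            "env": [
                # With this change the operator keeps this value and only adds
                # DASK_WORKER_NAME, which is still missing from the container.
                {"name": "DASK_SCHEDULER_ADDRESS", "value": "tcp://my-scheduler:8786"},
            ],
        },
    ],
}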
32 changes: 18 additions & 14 deletions dask_kubernetes/operator/controller/controller.py
@@ -153,21 +153,25 @@ def build_worker_deployment_spec(
"metadata": metadata,
"spec": spec,
}
env = [
{
"name": "DASK_WORKER_NAME",
"value": worker_name,
},
{
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
},
]
worker_env = {
"name": "DASK_WORKER_NAME",
"value": worker_name,
}
scheduler_env = {
"name": "DASK_SCHEDULER_ADDRESS",
"value": f"tcp://{cluster_name}-scheduler.{namespace}.svc.cluster.local:8786",
}
for container in deployment_spec["spec"]["template"]["spec"]["containers"]:
if "env" in container:
container["env"].extend(env)
else:
container["env"] = env
if "env" not in container:
container["env"] = [worker_env, scheduler_env]
continue

container_env_names = [env_item["name"] for env_item in container["env"]]

if "DASK_WORKER_NAME" not in container_env_names:
container["env"].append(worker_env)
if "DASK_SCHEDULER_ADDRESS" not in container_env_names:
container["env"].append(scheduler_env)
return deployment_spec
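# A standalone sketch of the merge behaviour implemented above (not the operator's
# public API; the container and env values below are made up): defaults are appended
# only when a container does not already define them, so user overrides survive.
def merge_default_env(containers, worker_env, scheduler_env):
    for container in containers:
        if "env" not in container:
            container["env"] = [worker_env, scheduler_env]
            continue
        names = [item["name"] for item in container["env"]]
        if "DASK_WORKER_NAME" not in names:
            container["env"].append(worker_env)
        if "DASK_SCHEDULER_ADDRESS" not in names:
            container["env"].append(scheduler_env)
    return containers

containers = [
    {"name": "worker", "env": [{"name": "DASK_WORKER_NAME", "value": "test-worker"}]}
]
merge_default_env(
    containers,
    worker_env={"name": "DASK_WORKER_NAME", "value": "simple-default-worker-0"},
    scheduler_env={
        "name": "DASK_SCHEDULER_ADDRESS",
        "value": "tcp://simple-scheduler.ns.svc.cluster.local:8786",
    },
)
assert containers[0]["env"][0]["value"] == "test-worker"            # override kept
assert containers[0]["env"][1]["name"] == "DASK_SCHEDULER_ADDRESS"  # missing default added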


4 changes: 3 additions & 1 deletion dask_kubernetes/operator/controller/tests/resources/simpleworkergroup.yaml
@@ -5,7 +5,7 @@ metadata:
spec:
cluster: simple
worker:
replicas: 2
replicas: 1
spec:
containers:
- name: worker
@@ -23,3 +23,5 @@ spec:
env:
- name: WORKER_ENV
value: hello-world # We don't test the value, just the name
- name: DASK_WORKER_NAME
value: test-worker
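For the manifest above, the controller now leaves DASK_WORKER_NAME at the value set in the resource and injects only the missing DASK_SCHEDULER_ADDRESS default, which is what the updated test further down asserts. A sketch of the resulting container env (the scheduler hostname follows the {cluster}-scheduler.{namespace} pattern from the controller; the namespace is test-specific):

expected_env = [
    {"name": "WORKER_ENV", "value": "hello-world"},
    {"name": "DASK_WORKER_NAME", "value": "test-worker"},  # user override kept
    {
        "name": "DASK_SCHEDULER_ADDRESS",  # injected default
        "value": "tcp://simple-scheduler.<namespace>.svc.cluster.local:8786",
    },
]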
89 changes: 72 additions & 17 deletions dask_kubernetes/operator/controller/tests/test_controller.py
@@ -20,7 +20,6 @@

DIR = pathlib.Path(__file__).parent.absolute()


_EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"}
_EXPECTED_LABELS = {"test-label": "label-value"}
DEFAULT_CLUSTER_NAME = "simple"
@@ -47,7 +46,6 @@ def gen_cluster(k8s_cluster, ns, gen_cluster_manifest):

@asynccontextmanager
async def cm(cluster_name=DEFAULT_CLUSTER_NAME):

cluster_path = gen_cluster_manifest(cluster_name)
# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path)
@@ -95,6 +93,36 @@ async def cm(job_file):
yield cm


@pytest.fixture()
def gen_worker_group(k8s_cluster, ns):
"""Yields an instantiated context manager for creating/deleting a worker group."""

@asynccontextmanager
async def cm(worker_group_file):
worker_group_path = os.path.join(DIR, "resources", worker_group_file)
with open(worker_group_path) as f:
worker_group_name = yaml.load(f, yaml.Loader)["metadata"]["name"]

# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", worker_group_path)
while worker_group_name not in k8s_cluster.kubectl(
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
):
await asyncio.sleep(0.1)

try:
yield worker_group_name, ns
finally:
# Test: remove the wait=True, because I think this is blocking the operator
k8s_cluster.kubectl("delete", "-n", ns, "-f", worker_group_path)
while worker_group_name in k8s_cluster.kubectl(
"get", "daskworkergroups.kubernetes.dask.org", "-n", ns
):
await asyncio.sleep(0.1)

yield cm
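# The fixture yields a factory: a test calls it with a manifest filename and gets an
# async context manager that applies the resource and removes it again. A usage
# sketch (hypothetical test name; mirrors the updated test_object_dask_worker_group
# further down):
@pytest.mark.anyio
async def test_additional_worker_group_example(gen_worker_group):
    async with gen_worker_group("simpleworkergroup.yaml") as (name, ns):
        ...  # the DaskWorkerGroup called `name` exists in namespace `ns` here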


def test_customresources(k8s_cluster):
assert "daskclusters.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
assert "daskworkergroups.kubernetes.dask.org" in k8s_cluster.kubectl("get", "crd")
@@ -671,32 +699,59 @@ async def test_object_dask_cluster(k8s_cluster, kopf_runner, gen_cluster):


@pytest.mark.anyio
async def test_object_dask_worker_group(k8s_cluster, kopf_runner, gen_cluster):
async def test_object_dask_worker_group(
k8s_cluster, kopf_runner, gen_cluster, gen_worker_group
):
with kopf_runner:
async with gen_cluster() as (cluster_name, ns):
async with (
gen_cluster() as (cluster_name, ns),
gen_worker_group("simpleworkergroup.yaml") as (
additional_workergroup_name,
_,
),
):
cluster = await DaskCluster.get(cluster_name, namespace=ns)
additional_workergroup = await DaskWorkerGroup.get(
additional_workergroup_name, namespace=ns
)

worker_groups = []
while not worker_groups:
worker_groups = await cluster.worker_groups()
await asyncio.sleep(0.1)
assert len(worker_groups) == 1 # Just the default worker group
wg = worker_groups[0]
assert isinstance(wg, DaskWorkerGroup)
worker_groups = worker_groups + [additional_workergroup]

pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])
for wg in worker_groups:
assert isinstance(wg, DaskWorkerGroup)

deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])
deployments = []
while not deployments:
deployments = await wg.deployments()
await asyncio.sleep(0.1)
assert all([isinstance(d, Deployment) for d in deployments])

assert (await wg.cluster()).name == cluster.name
pods = []
while not pods:
pods = await wg.pods()
await asyncio.sleep(0.1)
assert all([isinstance(p, Pod) for p in pods])

assert (await wg.cluster()).name == cluster.name

for deployment in deployments:
assert deployment.labels["dask.org/cluster-name"] == cluster.name
for env in deployment.spec["template"]["spec"]["containers"][0][
"env"
]:
if env["name"] == "DASK_WORKER_NAME":
if wg.name == additional_workergroup_name:
assert env["value"] == "test-worker"
else:
assert env["value"] == deployment.name
if env["name"] == "DASK_SCHEDULER_ADDRESS":
scheduler_service = await cluster.scheduler_service()
assert f"{scheduler_service.name}.{ns}" in env["value"]


@pytest.mark.anyio
