From 0fe1a3bcdeb7d25936d1014201ad74b66f250b49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Zab=C5=82ocki?= Date: Fri, 10 Jun 2022 09:28:23 +0200 Subject: [PATCH] Allow specifying extra volumes (#129) Allow specifying extra volumes (#129) + Pydantic refactor Co-authored-by: marcin.zablocki --- .github/workflows/prepare-release.yml | 2 +- .github/workflows/publish.yml | 2 +- .github/workflows/python-package.yml | 4 +- CHANGELOG.md | 2 + README.md | 2 +- .../02_installation/02_configuration.md | 41 +++ kedro_kubeflow/config.py | 331 ++++++++---------- kedro_kubeflow/context_helper.py | 2 +- kedro_kubeflow/generators/utils.py | 32 +- kedro_kubeflow/kfpclient.py | 13 +- setup.py | 1 - tests/common.py | 16 + tests/test_cli.py | 4 +- tests/test_config.py | 237 ++++++++----- tests/test_config.yml | 2 + tests/test_context_helper.py | 7 +- tests/test_extra_volumes.py | 60 ++++ tests/test_kfpclient.py | 55 ++- tests/test_one_pod_pipeline_generator.py | 38 +- tests/test_pod_per_node_pipeline_generator.py | 61 +++- tox.ini | 2 +- 21 files changed, 604 insertions(+), 310 deletions(-) create mode 100644 tests/common.py create mode 100644 tests/test_extra_volumes.py diff --git a/.github/workflows/prepare-release.yml b/.github/workflows/prepare-release.yml index 4422e7a..0b4f327 100644 --- a/.github/workflows/prepare-release.yml +++ b/.github/workflows/prepare-release.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7] + python-version: [3.8] env: PYTHON_PACKAGE: kedro_kubeflow steps: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index cbb926f..1682055 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7] + python-version: [3.8] env: PYTHON_PACKAGE: kedro_kubeflow steps: diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index e206b91..22f3acb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -16,7 +16,7 @@ jobs: - name: Setup python uses: actions/setup-python@v2.2.1 with: - python-version: 3.7 + python-version: 3.8 - name: Setup virtualenv run: | @@ -33,7 +33,7 @@ jobs: - name: Test with tox run: | pip install tox-pip-version - tox -v -e py37 + tox -v -e py38 - name: Report coverage uses: paambaati/codeclimate-action@v2.7.5 diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a1c4ec..1e88b1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,8 @@ # Changelog ## [Unreleased] +- Added support for extra volumes per node +- Refactored configuration classes to Pydantic ## [0.6.4] - 2022-06-01 diff --git a/README.md b/README.md index f04d409..86aa07a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Kedro Kubeflow Plugin -[![Python Version](https://img.shields.io/badge/python-3.7%20%7C%203.8-blue.svg)](https://github.com/getindata/kedro-kubeflow) +[![Python Version](https://img.shields.io/badge/python-3.8-blue.svg)](https://github.com/getindata/kedro-kubeflow) [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![SemVer](https://img.shields.io/badge/semver-2.0.0-green)](https://semver.org/) [![PyPI version](https://badge.fury.io/py/kedro-kubeflow.svg)](https://pypi.org/project/kedro-kubeflow/) diff --git a/docs/source/02_installation/02_configuration.md b/docs/source/02_installation/02_configuration.md index 94955d2..bd53831 100644 --- 
a/docs/source/02_installation/02_configuration.md +++ b/docs/source/02_installation/02_configuration.md @@ -104,6 +104,18 @@ run_config: value: "gpu_workload" effect: "NoSchedule" + # Optional section to allow mounting additional volumes (such as EmptyDir) + # to specific nodes + extra_volumes: + tensorflow_step: + - mount_path: /dev/shm + volume: + name: shared_memory + empty_dir: + cls: V1EmptyDirVolumeSource + params: + medium: Memory + # Optional section allowing adjustment of the resources # reservations and limits for the nodes resources: @@ -151,3 +163,32 @@ can later inject in configuration file using `${name}` syntax. There are two special variables `KEDRO_CONFIG_COMMIT_ID`, `KEDRO_CONFIG_BRANCH_NAME` with support specifying default when variable is not set, e.g. `${commit_id|dirty}` + +## Extra volumes +You can mount additional volumes (such as `emptyDir`) to specific nodes by using `extra_volumes` config node. +The syntax of the configuration allows to define k8s SDK compatible class hierarchy similar to the way you would define it in the KFP DSL, e.g: +```python +# KFP DSL +volume = dsl.PipelineVolume(volume=k8s.client.V1Volume( + name="shared_memory", + empty_dir=k8s.client.V1EmptyDirVolumeSource(medium='Memory'))) + +training_op.add_pvolumes({'/dev/shm': volume}) +``` +will translate to the following Kedro-Kubeflow config: +```yaml +extra_volumes: + training_op: + - mount_path: /dev/shm + volume: + name: shared_memory + empty_dir: + cls: V1EmptyDirVolumeSource + params: + medium: Memory +``` + +In general, the `volume` key accepts a dictionary with the keys being the named parameters for the [V1Volume](https://github.com/kubernetes-client/python/blob/be9a47e57358e3701ad079c98e223d3437ba1f46/kubernetes/docs/V1Volume.md) and values being one of: +* dictionary with `cls` and `params` keys (to define nested objects) - see `kedro_kubeflow.config.ObjectKwargs` +* list of values / list of dictionaries (`kedro_kubeflow.config.ObjectKwargs`) as described above +* values (`str`, `int` etc.) diff --git a/kedro_kubeflow/config.py b/kedro_kubeflow/config.py index eb94d50..fd9d00e 100644 --- a/kedro_kubeflow/config.py +++ b/kedro_kubeflow/config.py @@ -1,6 +1,15 @@ +import logging import os +from collections import defaultdict +from enum import Enum +from importlib import import_module +from typing import Any, Dict, List, Optional, Type, Union -from kedro.config import MissingConfigException +from kubernetes import client as k8s_client +from kubernetes.client import V1Volume +from pydantic import BaseModel, validator + +logger = logging.getLogger(__name__) DEFAULT_CONFIG_TEMPLATE = """ # Base url of the Kubeflow Pipelines, should include the schema (http/https) @@ -12,7 +21,7 @@ # Name of the image to run as the pipeline steps image: {image} - # Pull pilicy to be used for the steps. Use Always if you push the images + # Pull policy to be used for the steps. 
Use Always if you push the images # on the same tag, or Never if you use only local images image_pull_policy: IfNotPresent @@ -114,6 +123,18 @@ cpu: 200m memory: 64Mi + # Optional section to allow mounting additional volumes (such as EmptyDir) + # to specific nodes + extra_volumes: + tensorflow_step: + - mount_path: /dev/shm + volume: + name: shared_memory + empty_dir: + cls: V1EmptyDirVolumeSource + params: + medium: Memory + # Optional section to provide retry policy for the steps # and default policy for steps with no policy specified retry_policy: @@ -146,211 +167,169 @@ """ -class Config(object): - def __init__(self, raw): - self._raw = raw - - def _get_or_default(self, prop, default): - return self._raw.get(prop, default) - - def _get_or_fail(self, prop): - if prop in self._raw.keys(): - return self._raw[prop] - else: - raise MissingConfigException( - f"Missing required configuration: '{self._get_prefix()}{prop}'." - ) - - def _get_prefix(self): - return "" - - def __eq__(self, other): - return self._raw == other._raw - - -class VolumeConfig(Config): - @property - def storageclass(self): - return self._get_or_default("storageclass", None) - - @property - def size(self): - return self._get_or_default("size", "1Gi") - - @property - def access_modes(self): - return self._get_or_default("access_modes", ["ReadWriteOnce"]) - - @property - def skip_init(self): - return self._get_or_default("skip_init", False) - - @property - def keep(self): - return self._get_or_default("keep", False) - - @property - def owner(self): - return self._get_or_default("owner", 0) - - def _get_prefix(self): - return "run_config.volume." - - -class NodeResources(Config): - def is_set_for(self, node_name): - return self.get_for(node_name) != {} - - def get_for(self, node_name): - defaults = self._get_or_default("__default__", {}) - node_specific = self._get_or_default(node_name, {}) - return {**defaults, **node_specific} - - -class Tolerations(Config): - def is_set_for(self, node_name): - return bool(self.get_for(node_name)) +class DefaultConfigDict(defaultdict): + def __getitem__(self, key): + defaults: BaseModel = super().__getitem__("__default__") + this: BaseModel = super().__getitem__(key) + return ( + defaults.copy(update=this.dict(exclude_none=True)) + if defaults + else this + ) - def get_for(self, node_name): - node_values = self._get_or_default(node_name, []) - if node_values: - return node_values - return self._get_or_default("__default__", []) +class ResourceConfig(BaseModel): + cpu: Optional[str] + memory: Optional[str] -class RetryPolicy(Config): - def is_set_for(self, node_name): - return self.get_for(node_name) != {} - def get_for(self, node_name): - defaults = self._get_or_default("__default__", {}) - node_specific = self._get_or_default(node_name, {}) - values = {**defaults, **node_specific} - if values == {}: - return {} - values["num_retries"] = int(values.get("num_retries", 0)) - values["backoff_factor"] = ( - float(values["backoff_factor"]) - if "backoff_factor" in values - else None - ) - values["backoff_duration"] = ( - str(values["backoff_duration"]) - if "backoff_duration" in values - else None - ) - return values +class TolerationConfig(BaseModel): + key: str + operator: str + value: Optional[str] = None + effect: str -class RunConfig(Config): - @property - def image(self): - return self._get_or_fail("image") +class RetryPolicyConfig(BaseModel): + num_retries: int + backoff_duration: str + backoff_factor: float - @property - def image_pull_policy(self): - return 
self._get_or_default("image_pull_policy", "IfNotPresent") - @property - def root(self): - return self._get_or_fail("root") +class VolumeConfig(BaseModel): + storageclass: Optional[str] = None + size: str = "1Gi" + access_modes: List[str] = ["ReadWriteOnce"] + skip_init: bool = False + keep: bool = False + owner: int = 0 - @property - def experiment_name(self): - return self._get_or_fail("experiment_name") - @property - def run_name(self): - return self._get_or_fail("run_name") +class NodeMergeStrategyEnum(str, Enum): + none = "none" + full = "full" - @property - def scheduled_run_name(self): - return self._get_or_default( - "scheduled_run_name", self._get_or_fail("run_name") - ) - @property - def description(self): - return self._get_or_default("description", None) +class ObjectKwargs(BaseModel): + cls: str + params: Dict[str, Union["ObjectKwargs", Any]] - @property - def resources(self): - return NodeResources(self._get_or_default("resources", {})) - @property - def tolerations(self): - return Tolerations(self._get_or_default("tolerations", {})) +class ExtraVolumeConfig(BaseModel): + volume: Dict[str, Union[ObjectKwargs, List[ObjectKwargs], Any]] + mount_path: str - @property - def retry_policy(self): - return RetryPolicy(self._get_or_default("retry_policy", {})) + def as_v1volume(self) -> V1Volume: + return self._construct_v1_volume(self.volume) - @property - def volume(self): - if "volume" in self._raw.keys(): - cfg = self._get_or_fail("volume") - return VolumeConfig(cfg) + @staticmethod + def _resolve_cls(cls_name): + if hasattr(k8s_client, cls_name): + return getattr(k8s_client, cls_name, None) else: - return None + module_name, class_name = cls_name.rsplit(".", 1) + module = import_module(module_name) + return getattr(module, class_name, None) - @property - def wait_for_completion(self): - return bool(self._get_or_default("wait_for_completion", False)) + @staticmethod + def _construct(value: Union[ObjectKwargs, Any]): + if isinstance(value, ObjectKwargs): + assert ( + actual_cls := ExtraVolumeConfig._resolve_cls(value.cls) + ) is not None, f"Cannot import class {value.cls}" + return actual_cls( + **{ + k: ExtraVolumeConfig._construct(v) + for k, v in value.params.items() + } + ) + elif isinstance(value, list): + return [ + ExtraVolumeConfig._construct(ObjectKwargs.parse_obj(v)) + for v in value + ] + else: + return value - @property - def store_kedro_outputs_as_kfp_artifacts(self): - return bool( - self._get_or_default("store_kedro_outputs_as_kfp_artifacts", True) + @classmethod + def _construct_v1_volume(cls, value: dict): + return V1Volume( + **{k: ExtraVolumeConfig._construct(v) for k, v in value.items()} ) - @property - def max_cache_staleness(self): - return str(self._get_or_default("max_cache_staleness", None)) - - @property - def ttl(self): - return int(self._get_or_default("ttl", 3600 * 24 * 7)) - - @property - def on_exit_pipeline(self): - return self._get_or_default("on_exit_pipeline", None) - - @property - def node_merge_strategy(self): - strategy = str(self._get_or_default("node_merge_strategy", "none")) - if strategy not in ["none", "full"]: - raise ValueError( - f"Invalid {self._get_prefix()}node_merge_strategy: {strategy}" + @validator("volume") + def volume_validator(cls, value): + try: + cls._construct_v1_volume(value) + except Exception as ex: + logger.exception( + "Cannot construct kubernetes.client.models.v1_volume.V1Volume " + "from the passed `volume` field", ) - else: - return strategy + raise ex + return value - def _get_prefix(self): - return "run_config." 
+class RunConfig(BaseModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) -class PluginConfig(Config): - @property - def host(self): - return self._get_or_fail("host") + if "scheduled_run_name" not in kwargs: + self.scheduled_run_name = kwargs["run_name"] - @property - def run_config(self): - cfg = self._get_or_fail("run_config") - return RunConfig(cfg) + @staticmethod + def _create_default_dict_with( + value: dict, default, dict_cls: Type = DefaultConfigDict + ): + default_value = (value := value or {}).get("__default__", default) + return dict_cls(lambda: default_value, value) + + @validator("resources", always=True) + def _validate_resources(cls, value): + return RunConfig._create_default_dict_with( + value, ResourceConfig(cpu="500m", memory="1024Mi") + ) + + @validator("retry_policy", always=True) + def _validate_retry_policy(cls, value): + return RunConfig._create_default_dict_with(value, None) + + @validator("tolerations", always=True) + def _validate_tolerations(cls, value): + return RunConfig._create_default_dict_with(value, [], defaultdict) + + @validator("extra_volumes", always=True) + def _validate_extra_volumes(cls, value): + return RunConfig._create_default_dict_with(value, [], defaultdict) + + image: str + image_pull_policy: str = "IfNotPresent" + root: Optional[str] + experiment_name: str + run_name: str + scheduled_run_name: Optional[str] + description: Optional[str] = None + resources: Optional[Dict[str, ResourceConfig]] + tolerations: Optional[Dict[str, List[TolerationConfig]]] + retry_policy: Optional[Dict[str, Optional[RetryPolicyConfig]]] + volume: Optional[VolumeConfig] = None + extra_volumes: Optional[Dict[str, List[ExtraVolumeConfig]]] = None + wait_for_completion: bool = False + store_kedro_outputs_as_kfp_artifacts: bool = True + max_cache_staleness: Optional[str] = None + ttl: int = 3600 * 24 * 7 + on_exit_pipeline: Optional[str] = None + node_merge_strategy: NodeMergeStrategyEnum = NodeMergeStrategyEnum.none + + +class PluginConfig(BaseModel): + host: str + run_config: RunConfig @staticmethod def sample_config(**kwargs): return DEFAULT_CONFIG_TEMPLATE.format(**kwargs) - @property - def project_id(self): - return self._get_or_fail("project_id") - - @property - def region(self): - return self._get_or_fail("region") - @staticmethod def initialize_github_actions(project_name, where, templates_dir): os.makedirs(where / ".github/workflows", exist_ok=True) diff --git a/kedro_kubeflow/context_helper.py b/kedro_kubeflow/context_helper.py index faeded5..b6755e0 100644 --- a/kedro_kubeflow/context_helper.py +++ b/kedro_kubeflow/context_helper.py @@ -67,7 +67,7 @@ def config(self) -> PluginConfig: raw = EnvTemplatedConfigLoader( self.context.config_loader.conf_paths ).get(self.CONFIG_FILE_PATTERN) - return PluginConfig(raw) + return PluginConfig(**raw) @property @lru_cache() diff --git a/kedro_kubeflow/generators/utils.py b/kedro_kubeflow/generators/utils.py index 5f9feac..5e6841d 100644 --- a/kedro_kubeflow/generators/utils.py +++ b/kedro_kubeflow/generators/utils.py @@ -10,6 +10,7 @@ from kfp.compiler._k8s_helper import sanitize_k8s_name from ..auth import IAP_CLIENT_ID +from ..config import RunConfig def ensure_json_serializable(value): @@ -121,23 +122,30 @@ def create_pipeline_exit_handler( ) -def customize_op(op, image_pull_policy, run_config): +def customize_op(op, image_pull_policy, run_config: RunConfig): op.container.set_image_pull_policy(image_pull_policy) if run_config.volume and run_config.volume.owner is not None: 
op.container.set_security_context( k8s.V1SecurityContext(run_as_user=run_config.volume.owner) ) - if run_config.resources.is_set_for(op.name): - op.container.resources = k8s.V1ResourceRequirements( - limits=run_config.resources.get_for(op.name), - requests=run_config.resources.get_for(op.name), - ) - if run_config.retry_policy.is_set_for(op.name): - op.set_retry( - policy="Always", **run_config.retry_policy.get_for(op.name) + resources = run_config.resources[op.name].dict(exclude_none=True) + op.container.resources = k8s.V1ResourceRequirements( + limits=resources, + requests=resources, + ) + + if retry_policy := run_config.retry_policy[op.name]: + op.set_retry(policy="Always", **retry_policy.dict()) + + for toleration in run_config.tolerations[op.name]: + op.add_toleration(k8s.V1Toleration(**toleration.dict())) + + if extra_volumes := run_config.extra_volumes[op.name]: + op.add_pvolumes( + { + ev.mount_path: dsl.PipelineVolume(volume=ev.as_v1volume()) + for ev in extra_volumes + } ) - if run_config.tolerations.is_set_for(op.name): - for toleration in run_config.tolerations.get_for(op.name): - op.add_toleration(k8s.V1Toleration(**toleration)) return op diff --git a/kedro_kubeflow/kfpclient.py b/kedro_kubeflow/kfpclient.py index 239cfb8..7585914 100644 --- a/kedro_kubeflow/kfpclient.py +++ b/kedro_kubeflow/kfpclient.py @@ -14,6 +14,7 @@ ) from .auth import AuthHandler +from .config import NodeMergeStrategyEnum, PluginConfig from .utils import clean_name WAIT_TIMEOUT = 24 * 60 * 60 @@ -23,7 +24,7 @@ class KubeflowClient(object): log = logging.getLogger(__name__) - def __init__(self, config, project_name, context): + def __init__(self, config: PluginConfig, project_name, context): client_params = {} token = AuthHandler().obtain_id_token() if token is not None: @@ -40,18 +41,16 @@ def __init__(self, config, project_name, context): self.project_name = project_name self.pipeline_description = config.run_config.description - if config.run_config.node_merge_strategy == "none": + if config.run_config.node_merge_strategy == NodeMergeStrategyEnum.none: self.generator = PodPerNodePipelineGenerator( config, project_name, context ) - elif config.run_config.node_merge_strategy == "full": + elif ( + config.run_config.node_merge_strategy == NodeMergeStrategyEnum.full + ): self.generator = OnePodPipelineGenerator( config, project_name, context ) - else: - raise Exception( - f"Invalid `node_merge_strategy`: {config.run_config.node_merge_strategy}" - ) def list_pipelines(self): pipelines = self.client.list_pipelines(page_size=30).pipelines diff --git a/setup.py b/setup.py index e73c925..2f7cf9f 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,6 @@ python_requires=">=3", classifiers=[ "Development Status :: 4 - Beta", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], keywords="kedro kubeflow plugin", diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 0000000..40baca8 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,16 @@ +from pydantic.utils import deep_update + + +class MinimalConfigMixin: + def minimal_config(self, override=None): + minimal = { + "run_config": { + "image": "asd", + "experiment_name": "exp", + "run_name": "unit tests", + }, + "host": "localhost:8080", + } + if override: + minimal = deep_update(minimal, override) + return minimal diff --git a/tests/test_cli.py b/tests/test_cli.py index 08b8e09..d15009c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -24,7 +24,7 @@ from kedro_kubeflow.context_helper import ContextHelper 
test_config = PluginConfig( - { + **{ "host": "https://example.com", "run_config": { "image": "gcr.io/project-image/test", @@ -35,7 +35,7 @@ "volume": { "storageclass": "default", "size": "3Gi", - "access_modes": "[ReadWriteOnce]", + "access_modes": ["ReadWriteOnce"], }, }, } diff --git a/tests/test_config.py b/tests/test_config.py index d41fbf0..bd33360 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,13 +1,13 @@ import unittest import yaml -from kedro.config.config import MissingConfigException +from pydantic import ValidationError from kedro_kubeflow.config import PluginConfig +from tests.common import MinimalConfigMixin CONFIG_YAML = """ host: https://example.com - run_config: image: "gcr.io/project-image/test" image_pull_policy: "Always" @@ -25,9 +25,9 @@ """ -class TestPluginConfig(unittest.TestCase): +class TestPluginConfig(unittest.TestCase, MinimalConfigMixin): def test_plugin_config(self): - cfg = PluginConfig(yaml.safe_load(CONFIG_YAML)) + cfg = PluginConfig(**yaml.safe_load(CONFIG_YAML)) assert cfg.host == "https://example.com" assert cfg.run_config.image == "gcr.io/project-image/test" assert cfg.run_config.image_pull_policy == "Always" @@ -39,12 +39,12 @@ def test_plugin_config(self): assert cfg.run_config.volume.size == "3Gi" assert cfg.run_config.volume.keep is True assert cfg.run_config.volume.access_modes == ["ReadWriteOnce"] - assert cfg.run_config.resources.is_set_for("node1") is False + assert cfg.run_config.resources["node1"] is not None assert cfg.run_config.description == "My awesome pipeline" assert cfg.run_config.ttl == 300 def test_defaults(self): - cfg = PluginConfig({"run_config": {}}) + cfg = PluginConfig(**self.minimal_config()) assert cfg.run_config.image_pull_policy == "IfNotPresent" assert cfg.run_config.description is None SECONDS_IN_ONE_WEEK = 3600 * 24 * 7 @@ -52,48 +52,57 @@ def test_defaults(self): assert cfg.run_config.volume is None def test_missing_required_config(self): - cfg = PluginConfig({}) - with self.assertRaises(MissingConfigException): - print(cfg.host) + with self.assertRaises(ValidationError): + PluginConfig(**{}) def test_resources_default_only(self): cfg = PluginConfig( - {"run_config": {"resources": {"__default__": {"cpu": "100m"}}}} + **self.minimal_config( + {"run_config": {"resources": {"__default__": {"cpu": "100m"}}}} + ) ) - assert cfg.run_config.resources.is_set_for("node2") - assert cfg.run_config.resources.get_for("node2") == {"cpu": "100m"} - assert cfg.run_config.resources.is_set_for("node3") - assert cfg.run_config.resources.get_for("node3") == {"cpu": "100m"} + assert cfg.run_config.resources["node2"].cpu == "100m" + assert cfg.run_config.resources["node3"].cpu == "100m" def test_resources_no_default(self): cfg = PluginConfig( - {"run_config": {"resources": {"node2": {"cpu": "100m"}}}} + **self.minimal_config( + {"run_config": {"resources": {"node2": {"cpu": "100m"}}}} + ) + ) + assert cfg.run_config.resources["node2"].cpu == "100m" + self.assertDictEqual( + cfg.run_config.resources["node3"].dict(), + cfg.run_config.resources["__default__"].dict(), ) - assert cfg.run_config.resources.is_set_for("node2") - assert cfg.run_config.resources.get_for("node2") == {"cpu": "100m"} - assert cfg.run_config.resources.is_set_for("node3") is False def test_resources_default_and_node_specific(self): cfg = PluginConfig( - { - "run_config": { - "resources": { - "__default__": {"cpu": "200m", "memory": "64Mi"}, - "node2": {"cpu": "100m"}, + **self.minimal_config( + { + "run_config": { + "resources": { + "__default__": 
{"cpu": "200m", "memory": "64Mi"}, + "node2": {"cpu": "100m"}, + } } } - } + ) + ) + self.assertDictEqual( + cfg.run_config.resources["node2"].dict(), + { + "cpu": "100m", + "memory": "64Mi", + }, + ) + self.assertDictEqual( + cfg.run_config.resources["node3"].dict(), + { + "cpu": "200m", + "memory": "64Mi", + }, ) - assert cfg.run_config.resources.is_set_for("node2") - assert cfg.run_config.resources.get_for("node2") == { - "cpu": "100m", - "memory": "64Mi", - } - assert cfg.run_config.resources.is_set_for("node3") - assert cfg.run_config.resources.get_for("node3") == { - "cpu": "200m", - "memory": "64Mi", - } def test_tolerations_default_only(self): toleration_config = [ @@ -105,12 +114,21 @@ def test_tolerations_default_only(self): } ] cfg = PluginConfig( - {"run_config": {"tolerations": {"__default__": toleration_config}}} + **self.minimal_config( + { + "run_config": { + "tolerations": {"__default__": toleration_config} + } + } + ) + ) + + self.assertDictEqual( + cfg.run_config.tolerations["node2"][0].dict(), toleration_config[0] + ) + self.assertDictEqual( + cfg.run_config.tolerations["node3"][0].dict(), toleration_config[0] ) - assert cfg.run_config.tolerations.is_set_for("node2") - assert cfg.run_config.tolerations.get_for("node2") == toleration_config - assert cfg.run_config.tolerations.is_set_for("node3") - assert cfg.run_config.tolerations.get_for("node3") == toleration_config def test_tolerations_no_default(self): toleration_config = [ @@ -122,11 +140,23 @@ def test_tolerations_no_default(self): } ] cfg = PluginConfig( - {"run_config": {"tolerations": {"node2": toleration_config}}} + **self.minimal_config( + {"run_config": {"tolerations": {"node2": toleration_config}}} + ) + ) + + self.assertDictEqual( + cfg.run_config.tolerations["node2"][0].dict(), toleration_config[0] + ) + + assert ( + isinstance(cfg.run_config.tolerations["node2"], list) + and len(cfg.run_config.tolerations["node2"]) == 1 + ) + assert ( + isinstance(cfg.run_config.tolerations["node3"], list) + and len(cfg.run_config.tolerations["node3"]) == 0 ) - assert cfg.run_config.tolerations.is_set_for("node2") - assert cfg.run_config.tolerations.get_for("node2") == toleration_config - assert cfg.run_config.tolerations.is_set_for("node3") is False def test_tolerations_default_and_node_specific(self): toleration_config = [ @@ -146,60 +176,103 @@ def test_tolerations_default_and_node_specific(self): } ] cfg = PluginConfig( - { - "run_config": { - "tolerations": { - "__default__": default_toleration_config, - "node2": toleration_config, + **self.minimal_config( + { + "run_config": { + "tolerations": { + "__default__": default_toleration_config, + "node2": toleration_config, + } } } - } + ) ) - assert cfg.run_config.tolerations.is_set_for("node2") - assert cfg.run_config.tolerations.get_for("node2") == toleration_config - assert cfg.run_config.tolerations.is_set_for("node3") - assert ( - cfg.run_config.tolerations.get_for("node3") - == default_toleration_config + + self.assertDictEqual( + cfg.run_config.tolerations["node2"][0].dict(), toleration_config[0] + ) + self.assertDictEqual( + cfg.run_config.tolerations["node3"][0].dict(), + default_toleration_config[0], ) def test_do_not_keep_volume_by_default(self): - cfg = PluginConfig({"run_config": {"volume": {}}}) + cfg = PluginConfig( + **self.minimal_config(override={"run_config": {"volume": {}}}) + ) assert cfg.run_config.volume.keep is False def test_reuse_run_name_for_scheduled_run_name(self): - cfg = PluginConfig({"run_config": {"run_name": "some run"}}) + cfg = 
PluginConfig( + **self.minimal_config({"run_config": {"run_name": "some run"}}) + ) assert cfg.run_config.run_name == "some run" assert cfg.run_config.scheduled_run_name == "some run" def test_retry_policy_default_and_node_specific(self): cfg = PluginConfig( + **self.minimal_config( + { + "run_config": { + "retry_policy": { + "__default__": { + "num_retries": 4, + "backoff_duration": "60s", + "backoff_factor": 2, + }, + "node3": { + "num_retries": "100", + "backoff_duration": "5m", + "backoff_factor": 1, + }, + } + } + } + ) + ) + + self.assertDictEqual( + cfg.run_config.retry_policy["node2"].dict(), { - "run_config": { - "retry_policy": { - "__default__": { - "num_retries": 4, - "backoff_duration": "60s", - "backoff_factor": 2, - }, - "node3": { - "num_retries": "100", - "backoff_duration": "5m", - "backoff_factor": 1, - }, + "backoff_duration": "60s", + "backoff_factor": 2, + "num_retries": 4, + }, + ) + + self.assertDictEqual( + cfg.run_config.retry_policy["node3"].dict(), + { + "backoff_duration": "5m", + "backoff_factor": 1, + "num_retries": 100, + }, + ) + + def test_retry_policy_no_default(self): + cfg = PluginConfig( + **self.minimal_config( + { + "run_config": { + "retry_policy": { + "node3": { + "num_retries": "100", + "backoff_duration": "5m", + "backoff_factor": 1, + }, + } } } - } + ) ) - assert cfg.run_config.retry_policy.is_set_for("node2") - assert cfg.run_config.retry_policy.get_for("node2") == { - "backoff_duration": "60s", - "backoff_factor": 2, - "num_retries": 4, - } - assert cfg.run_config.retry_policy.is_set_for("node3") - assert cfg.run_config.retry_policy.get_for("node3") == { - "backoff_duration": "5m", - "backoff_factor": 1, - "num_retries": 100, - } + + self.assertDictEqual( + cfg.run_config.retry_policy["node3"].dict(), + { + "num_retries": 100, + "backoff_duration": "5m", + "backoff_factor": 1.0, + }, + ) + + assert cfg.run_config.retry_policy["node2"] is None diff --git a/tests/test_config.yml b/tests/test_config.yml index c698dcb..2bec1ff 100644 --- a/tests/test_config.yml +++ b/tests/test_config.yml @@ -2,3 +2,5 @@ run_config: image: "gcr.io/project-image/${commit_id|dirty}" experiment_name: "[Test] ${branch_name|local}" run_name: "${xyz123|dirty}" + host: http://localhost:1234 + diff --git a/tests/test_context_helper.py b/tests/test_context_helper.py index 61af7ab..7df7292 100644 --- a/tests/test_context_helper.py +++ b/tests/test_context_helper.py @@ -11,10 +11,11 @@ EnvTemplatedConfigLoader, ) +from .common import MinimalConfigMixin from .utils import environment -class TestContextHelper(unittest.TestCase): +class TestContextHelper(unittest.TestCase, MinimalConfigMixin): def test_init_different_kedro_versions(self): with patch("kedro_kubeflow.context_helper.kedro_version", "0.16.0"): @@ -48,9 +49,9 @@ def test_config(self): with patch.object(KedroSession, "create", context), patch( "kedro_kubeflow.context_helper.EnvTemplatedConfigLoader" ) as config_loader: - config_loader.return_value.get.return_value = {} + config_loader.return_value.get.return_value = self.minimal_config() helper = ContextHelper.init(metadata, "test") - assert helper.config == PluginConfig({}) + assert helper.config == PluginConfig(**self.minimal_config()) class TestEnvTemplatedConfigLoader(unittest.TestCase): diff --git a/tests/test_extra_volumes.py b/tests/test_extra_volumes.py new file mode 100644 index 0000000..9cb06ee --- /dev/null +++ b/tests/test_extra_volumes.py @@ -0,0 +1,60 @@ +import unittest +from io import StringIO + +import yaml +from kubernetes.client import 
V1EmptyDirVolumeSource, V1KeyToPath, V1Volume + +from kedro_kubeflow.config import ExtraVolumeConfig + + +class TestExtraVolumes(unittest.TestCase): + def test_can_construct_volumes_object_from_yaml(self): + volumes_yaml = """ +mount_path: /dev/shm +volume: + name: unit_tests_volume + empty_dir: + cls: V1EmptyDirVolumeSource + params: + medium: Memory + """.strip() + + volumes_dict = yaml.safe_load(StringIO(volumes_yaml)) + volumes_cfg: ExtraVolumeConfig = ExtraVolumeConfig.parse_obj( + volumes_dict + ) + volume_def: V1Volume = volumes_cfg.as_v1volume() + assert volumes_cfg is not None and volume_def is not None + assert ( + volume_def.empty_dir is not None + and isinstance(volume_def.empty_dir, V1EmptyDirVolumeSource) + and volume_def.empty_dir.medium == "Memory" + ) + + def test_can_construct_volume_from_arbitrary_class(self): + volumes_yaml = """ +mount_path: /dev/shm +volume: + name: unit_tests_volume + config_map: + cls: V1ConfigMapVolumeSource + params: + default_mode: 644 + items: + - cls: kubernetes.client.models.v1_key_to_path.V1KeyToPath + params: {"key": "abc", "path": "./myfile"} + name: asdf + """.strip() + + volumes_cfg: ExtraVolumeConfig = ExtraVolumeConfig.parse_obj( + yaml.safe_load(StringIO(volumes_yaml)) + ) + + volume_def: V1Volume = volumes_cfg.as_v1volume() + assert volumes_cfg is not None and volume_def is not None + assert volume_def.config_map is not None + assert ( + volume_def.config_map.default_mode == 644 + and isinstance(volume_def.config_map.items, list) + and isinstance(volume_def.config_map.items[0], V1KeyToPath) + ) diff --git a/tests/test_kfpclient.py b/tests/test_kfpclient.py index aeef841..382d9b7 100644 --- a/tests/test_kfpclient.py +++ b/tests/test_kfpclient.py @@ -8,11 +8,15 @@ from kfp import dsl from kedro_kubeflow.config import PluginConfig +from kedro_kubeflow.generators.one_pod_pipeline_generator import ( + OnePodPipelineGenerator, +) from kedro_kubeflow.kfpclient import KubeflowClient from kedro_kubeflow.utils import strip_margin +from tests.common import MinimalConfigMixin -class TestKubeflowClient(unittest.TestCase): +class TestKubeflowClient(unittest.TestCase, MinimalConfigMixin): def create_experiment(self, id="123"): return type("obj", (object,), {"id": id}) @@ -182,7 +186,11 @@ def test_should_use_jwt_token_in_kfp_client( # when self.client_under_test = KubeflowClient( - PluginConfig({"host": "http://unittest", "run_config": {}}), + PluginConfig( + **self.minimal_config( + {"host": "http://unittest", "run_config": {}} + ) + ), None, None, ) @@ -206,7 +214,11 @@ def test_should_use_dex_session_in_kfp_client( # when self.client_under_test = KubeflowClient( - PluginConfig({"host": "http://unittest", "run_config": {}}), + PluginConfig( + **self.minimal_config( + {"host": "http://unittest", "run_config": {}} + ) + ), None, None, ) @@ -336,6 +348,24 @@ def test_should_upload_new_pipeline(self): ) assert kwargs["description"] == "Very Important Pipeline" + @patch("kedro_kubeflow.kfpclient.Client") + @patch("kedro.framework.context.context.KedroContext") + def test_can_create_client_with_node_strategy_full(self, context, _): + client = KubeflowClient( + PluginConfig( + **self.minimal_config( + { + "host": "http://unittest", + "run_config": {"node_merge_strategy": "full"}, + } + ) + ), + "unit-test-project", + context, + ) + + assert isinstance(client.generator, OnePodPipelineGenerator) + def test_should_truncated_the_pipeline_name_to_100_characters_on_upload( self, ): @@ -380,24 +410,31 @@ def 
test_should_upload_new_version_of_existing_pipeline(self): def test_should_raise_error_if_invalid_node_merge_strategy( self, kfp_client_mock ): - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as raises: KubeflowClient( PluginConfig( - { - "host": "http://unittest", - "run_config": {"node_merge_strategy": "other"}, - } + **self.minimal_config( + { + "host": "http://unittest", + "run_config": {"node_merge_strategy": "other"}, + } + ) ), None, None, ) + assert "validation error" in str(raises.exception) @patch("kedro_kubeflow.kfpclient.PodPerNodePipelineGenerator") @patch("kedro_kubeflow.kfpclient.Client") def create_client(self, config, kfp_client_mock, pipeline_generator_mock): project_name = "my-awesome-project" self.client_under_test = KubeflowClient( - PluginConfig({"host": "http://unittest", "run_config": config}), + PluginConfig( + **self.minimal_config( + {"host": "http://unittest", "run_config": config} + ) + ), project_name, None, # context, ) diff --git a/tests/test_one_pod_pipeline_generator.py b/tests/test_one_pod_pipeline_generator.py index 07da567..a8c6624 100644 --- a/tests/test_one_pod_pipeline_generator.py +++ b/tests/test_one_pod_pipeline_generator.py @@ -13,13 +13,14 @@ from kedro_kubeflow.generators.one_pod_pipeline_generator import ( OnePodPipelineGenerator, ) +from tests.common import MinimalConfigMixin def identity(input1: str): return input1 # pragma: no cover -class TestGenerator(unittest.TestCase): +class TestGenerator(unittest.TestCase, MinimalConfigMixin): def test_support_modification_of_pull_policy(self): # given self.create_generator() @@ -69,7 +70,7 @@ def test_should_support_params_and_inject_them_to_the_node(self): "{{pipelineparam:op=;name=param3}}", ] - def test_should_not_add_resources_spec_if_not_requested(self): + def test_should_use_default_resources_spec_if_not_requested(self): # given self.create_generator(config={}) @@ -80,7 +81,11 @@ def test_should_not_add_resources_spec_if_not_requested(self): )() # then - assert dsl_pipeline.ops["pipeline"].container.resources is None + assert dsl_pipeline.ops["pipeline"].container.resources is not None + assert dsl_pipeline.ops["pipeline"].container.resources.limits["cpu"] + assert dsl_pipeline.ops["pipeline"].container.resources.limits[ + "memory" + ] def test_should_add_resources_spec(self): # given @@ -275,6 +280,29 @@ def test_should_generate_exit_handler_if_requested(self): ) ) + def test_should_generate_exit_handler_with_max_staleness(self): + # given + self.create_generator( + config={ + "on_exit_pipeline": "notify_via_slack", + "max_cache_staleness": "P0D", + } + ) + + # when + with kfp.dsl.Pipeline(None) as dsl_pipeline: + pipeline = self.generator_under_test.generate_pipeline( + "pipeline", "unittest-image", "Always" + ) + pipeline() + + assert ( + dsl_pipeline.ops[ + "on-exit" + ].execution_options.caching_strategy.max_cache_staleness + == "P0D" + ) + def create_generator(self, config=None, params=None, catalog=None): if config is None: config = {} @@ -303,7 +331,9 @@ def create_generator(self, config=None, params=None, catalog=None): ) self.generator_under_test = OnePodPipelineGenerator( config=PluginConfig( - {"host": "http://unittest", "run_config": config} + **self.minimal_config( + {"host": "http://unittest", "run_config": config} + ) ), project_name="my-awesome-project", context=context, diff --git a/tests/test_pod_per_node_pipeline_generator.py b/tests/test_pod_per_node_pipeline_generator.py index b52c430..0e20ad3 100644 --- 
a/tests/test_pod_per_node_pipeline_generator.py +++ b/tests/test_pod_per_node_pipeline_generator.py @@ -12,13 +12,14 @@ from kedro_kubeflow.generators.pod_per_node_pipeline_generator import ( PodPerNodePipelineGenerator, ) +from tests.common import MinimalConfigMixin def identity(input1: str): return input1 # pragma: no cover -class TestGenerator(unittest.TestCase): +class TestGenerator(unittest.TestCase, MinimalConfigMixin): def test_support_modification_of_pull_policy(self): # given self.create_generator() @@ -264,7 +265,7 @@ def test_should_support_params_and_inject_them_to_the_nodes(self): "{{pipelineparam:op=;name=param2}}", ] - def test_should_not_add_resources_spec_if_not_requested(self): + def test_should_fallbackto_default_resources_spec_if_not_requested(self): # given self.create_generator(config={}) @@ -278,7 +279,7 @@ def test_should_not_add_resources_spec_if_not_requested(self): # then for node_name in ["node1", "node2"]: spec = dsl_pipeline.ops[node_name].container - assert spec.resources is None + assert spec.resources is not None def test_should_add_resources_spec(self): # given @@ -306,6 +307,35 @@ def test_should_add_resources_spec(self): assert node2_spec.limits == {"cpu": "100m"} assert node2_spec.requests == {"cpu": "100m"} + def test_can_add_extra_volumes(self): + self.create_generator( + config={ + "extra_volumes": { + "node1": [ + { + "mount_path": "/my/volume", + "volume": { + "name": "my_volume", + "empty_dir": { + "cls": "V1EmptyDirVolumeSource", + "params": {"medium": "Memory"}, + }, + }, + } + ] + } + } + ) + + pipeline = self.generator_under_test.generate_pipeline( + "pipeline", "unittest-image", "Always" + ) + with kfp.dsl.Pipeline(None) as dsl_pipeline: + pipeline() + + volume_mounts = dsl_pipeline.ops["node1"].container.volume_mounts + assert len(volume_mounts) == 1 + def test_should_not_add_retry_policy_if_not_requested(self): # given self.create_generator(config={}) @@ -364,6 +394,19 @@ def test_should_add_retry_policy(self): assert op2.backoff_duration == "60s" assert op2.backoff_max_duration is None + def test_should_add_max_cache_staleness(self): + self.create_generator(config={"max_cache_staleness": "P0D"}) + + with kfp.dsl.Pipeline(None) as dsl_pipeline: + self.generator_under_test.generate_pipeline( + "pipeline", "unittest-image", "Always" + )() + + op1 = dsl_pipeline.ops["node1"] + assert ( + op1.execution_options.caching_strategy.max_cache_staleness == "P0D" + ) + def test_should_set_description(self): # given self.create_generator(config={"description": "DESC"}) @@ -465,16 +508,16 @@ def test_should_pass_kedro_config_env_to_nodes(self): del os.environ["KEDRO_CONFIG_MY_KEY"] del os.environ["SOME_VALUE"] - def create_generator(self, config={}, params={}, catalog={}): + def create_generator(self, config=None, params=None, catalog=None): project_name = "my-awesome-project" config_loader = MagicMock() - config_loader.get.return_value = catalog + config_loader.get.return_value = catalog or {} context = type( "obj", (object,), { "env": "unittests", - "params": params, + "params": params or {}, "config_loader": config_loader, "pipelines": { "pipeline": Pipeline( @@ -487,7 +530,11 @@ def create_generator(self, config={}, params={}, catalog={}): }, ) self.generator_under_test = PodPerNodePipelineGenerator( - PluginConfig({"host": "http://unittest", "run_config": config}), + PluginConfig( + **self.minimal_config( + {"host": "http://unittest", "run_config": config or {}} + ) + ), project_name, context, ) diff --git a/tox.ini b/tox.ini index aaaa653..ab3af5f 
100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py37
+envlist = py38
 [testenv]
 pip_version =