Skip to content

Commit

Permalink
Merge pull request #202 from getindata/resource-cfg-rework
Browse files Browse the repository at this point in the history
WIP feat: Changed resource config to allow custom values
  • Loading branch information
Lasica committed Nov 23, 2022
2 parents 4cb366c + 17e80b2 commit ab70986
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 93 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -2,6 +2,8 @@

## [Unreleased]

- Removed field validation from the resources configuration field - it can now take any custom parameters, such as "nvidia.com/gpu": 1

## [0.7.3] - 2022-09-23

- Fixed plugin config provider so it respects environment provided by the user
Expand Down
20 changes: 15 additions & 5 deletions kedro_kubeflow/config.py
Expand Up @@ -178,9 +178,13 @@ def __getitem__(self, key):
)


class ResourceConfig(BaseModel):
cpu: Optional[str]
memory: Optional[str]
class ResourceConfig(dict):
    """Per-node resource settings with a ``"__default__"`` fallback entry.

    The mapping is keyed by node name (plus the special ``"__default__"``
    key); each value is a plain dict of resource name to quantity, e.g.
    ``{"cpu": "500m", "memory": "1024Mi"}``. Resource names are not
    validated here, so custom keys are passed through unchanged.
    """

    def __getitem__(self, key):
        """Return the resources for ``key`` layered on top of the defaults.

        Args:
            key: Node name to look up. Unknown names yield the defaults only.

        Returns:
            A new dict holding the ``"__default__"`` entries overridden by
            the entries stored under ``key``. The stored dicts are never
            returned directly, so callers may mutate the result freely.
        """
        # A missing or explicit-None "__default__" is treated as "no
        # defaults" instead of failing on ``None.copy()``.
        defaults = super().get("__default__") or {}
        # Likewise a node mapped to None (e.g. ``node: null`` in YAML)
        # contributes no overrides rather than breaking the merge.
        overrides = super().get(key) or {}
        return {**defaults, **overrides}


class TolerationConfig(BaseModel):
Expand Down Expand Up @@ -286,9 +290,15 @@ def _create_default_dict_with(

@validator("resources", always=True)
def _validate_resources(cls, value):
return RunConfig._create_default_dict_with(
value, ResourceConfig(cpu="500m", memory="1024Mi")
default = ResourceConfig(
{"__default__": {"cpu": "500m", "memory": "1024Mi"}}
)
if isinstance(value, dict):
default.update(value)
elif value is not None:
logger.error(f"Unknown type for resource config {type(value)}")
raise TypeError(f"Unknown type for resource config {type(value)}")
return default

@validator("retry_policy", always=True)
def _validate_retry_policy(cls, value):
Expand Down
2 changes: 1 addition & 1 deletion kedro_kubeflow/generators/utils.py
Expand Up @@ -131,7 +131,7 @@ def customize_op(op, image_pull_policy, run_config: RunConfig):
k8s.V1SecurityContext(run_as_user=run_config.volume.owner)
)

resources = run_config.resources[op.name].dict(exclude_none=True)
resources = run_config.resources[op.name]
op.container.resources = k8s.V1ResourceRequirements(
limits=resources,
requests=resources,
Expand Down
86 changes: 56 additions & 30 deletions tests/test_config.py
Expand Up @@ -28,27 +28,28 @@
class TestPluginConfig(unittest.TestCase, MinimalConfigMixin):
def test_plugin_config(self):
cfg = PluginConfig(**yaml.safe_load(CONFIG_YAML))
assert cfg.host == "https://example.com"
assert cfg.run_config.image == "gcr.io/project-image/test"
assert cfg.run_config.image_pull_policy == "Always"
assert cfg.run_config.experiment_name == "Test Experiment"
assert cfg.run_config.run_name == "test run"
assert cfg.run_config.scheduled_run_name == "scheduled run"
assert cfg.run_config.wait_for_completion
assert cfg.run_config.volume.storageclass == "default"
assert cfg.run_config.volume.size == "3Gi"
assert cfg.run_config.volume.keep is True
assert cfg.run_config.volume.access_modes == ["ReadWriteOnce"]
assert cfg.run_config.resources["node1"] is not None
assert cfg.run_config.description == "My awesome pipeline"
assert cfg.run_config.ttl == 300
self.assertEqual(cfg.host, "https://example.com")
self.assertEqual(cfg.run_config.image, "gcr.io/project-image/test")
self.assertEqual(cfg.run_config.image_pull_policy, "Always")
self.assertEqual(cfg.run_config.experiment_name, "Test Experiment")
self.assertEqual(cfg.run_config.run_name, "test run")
self.assertEqual(cfg.run_config.scheduled_run_name, "scheduled run")
self.assertTrue(cfg.run_config.wait_for_completion)
self.assertEqual(cfg.run_config.volume.storageclass, "default")
self.assertEqual(cfg.run_config.volume.size, "3Gi")
self.assertTrue(cfg.run_config.volume.keep)
self.assertEqual(cfg.run_config.volume.access_modes, ["ReadWriteOnce"])
self.assertIsNotNone(cfg.run_config.resources["node1"])
self.assertIsNotNone(cfg.run_config.resources["__default__"])
self.assertEqual(cfg.run_config.description, "My awesome pipeline")
self.assertEqual(cfg.run_config.ttl, 300)

def test_defaults(self):
cfg = PluginConfig(**self.minimal_config())
assert cfg.run_config.image_pull_policy == "IfNotPresent"
self.assertEqual(cfg.run_config.image_pull_policy, "IfNotPresent")
assert cfg.run_config.description is None
SECONDS_IN_ONE_WEEK = 3600 * 24 * 7
assert cfg.run_config.ttl == SECONDS_IN_ONE_WEEK
self.assertEqual(cfg.run_config.ttl, SECONDS_IN_ONE_WEEK)
assert cfg.run_config.volume is None

def test_missing_required_config(self):
Expand All @@ -61,19 +62,42 @@ def test_resources_default_only(self):
{"run_config": {"resources": {"__default__": {"cpu": "100m"}}}}
)
)
assert cfg.run_config.resources["node2"].cpu == "100m"
assert cfg.run_config.resources["node3"].cpu == "100m"
self.assertEqual(cfg.run_config.resources["node2"]["cpu"], "100m")
self.assertEqual(cfg.run_config.resources["node3"]["cpu"], "100m")

def test_resources_gpu_label(self):
    """Custom resource keys under ``__default__`` are kept as given and are
    also returned for a node that has no entry of its own."""
    default_resources = {
        "cpu": "100m",
        "nvidia.com/gpu": "1",
        "nvidia.com/tpu": "1",
    }
    overrides = {
        "run_config": {"resources": {"__default__": default_resources}}
    }
    cfg = PluginConfig(**self.minimal_config(overrides))

    resources = cfg.run_config.resources
    # Direct lookup of the default entry keeps the custom key.
    self.assertEqual(resources["__default__"]["nvidia.com/gpu"], "1")
    # "node3" is not configured, so it falls back to the defaults.
    self.assertEqual(resources["node3"]["nvidia.com/tpu"], "1")

def test_resources_no_default(self):
cfg = PluginConfig(
**self.minimal_config(
{"run_config": {"resources": {"node2": {"cpu": "100m"}}}}
)
)
assert cfg.run_config.resources["node2"].cpu == "100m"
self.assertEqual(cfg.run_config.resources["node2"]["cpu"], "100m")
self.assertDictEqual(
cfg.run_config.resources["node3"].dict(),
cfg.run_config.resources["__default__"].dict(),
cfg.run_config.resources["node3"],
cfg.run_config.resources["__default__"],
)

def test_resources_default_and_node_specific(self):
Expand All @@ -90,14 +114,14 @@ def test_resources_default_and_node_specific(self):
)
)
self.assertDictEqual(
cfg.run_config.resources["node2"].dict(),
cfg.run_config.resources["node2"],
{
"cpu": "100m",
"memory": "64Mi",
},
)
self.assertDictEqual(
cfg.run_config.resources["node3"].dict(),
cfg.run_config.resources["node3"],
{
"cpu": "200m",
"memory": "64Mi",
Expand Down Expand Up @@ -149,13 +173,15 @@ def test_tolerations_no_default(self):
cfg.run_config.tolerations["node2"][0].dict(), toleration_config[0]
)

assert (
self.assertEqual(
isinstance(cfg.run_config.tolerations["node2"], list)
and len(cfg.run_config.tolerations["node2"]) == 1
and len(cfg.run_config.tolerations["node2"]),
1,
)
assert (
self.assertEqual(
isinstance(cfg.run_config.tolerations["node3"], list)
and len(cfg.run_config.tolerations["node3"]) == 0
and len(cfg.run_config.tolerations["node3"]),
0,
)

def test_tolerations_default_and_node_specific(self):
Expand Down Expand Up @@ -206,8 +232,8 @@ def test_reuse_run_name_for_scheduled_run_name(self):
cfg = PluginConfig(
**self.minimal_config({"run_config": {"run_name": "some run"}})
)
assert cfg.run_config.run_name == "some run"
assert cfg.run_config.scheduled_run_name == "some run"
self.assertEqual(cfg.run_config.run_name, "some run")
self.assertEqual(cfg.run_config.scheduled_run_name, "some run")

def test_retry_policy_default_and_node_specific(self):
cfg = PluginConfig(
Expand Down Expand Up @@ -275,4 +301,4 @@ def test_retry_policy_no_default(self):
},
)

assert cfg.run_config.retry_policy["node2"] is None
self.assertIsNone(cfg.run_config.retry_policy["node2"])

0 comments on commit ab70986

Please sign in to comment.