feat: Changed resource config to allow custom values
Fixes an error where resource labels for NVIDIA GPUs/TPUs were not propagated to the execution environments properly.
Lasica committed Nov 21, 2022
1 parent 222b9d2 commit 34b3245
Showing 3 changed files with 57 additions and 40 deletions.
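
For context, here is a minimal sketch (not part of the commit) of the kind of resources configuration this change is meant to support. The node name and the nvidia.com/gpu key are illustrative assumptions; under the old ResourceConfig model, which only declared cpu and memory fields, such custom keys had nowhere to go.

# Illustrative only: a resources section with a custom key that the
# cpu/memory-only ResourceConfig model could not represent.
resources = {
    "__default__": {"cpu": "500m", "memory": "1024Mi"},
    "train_node": {                 # hypothetical node name
        "memory": "8Gi",
        "nvidia.com/gpu": "1",      # custom key, now preserved
    },
}
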
kedro_kubeflow/config.py: 24 changes (18 additions, 6 deletions)
@@ -178,9 +178,13 @@ def __getitem__(self, key):
         )
 
 
-# class ResourceConfig(BaseModel):
-#     cpu: Optional[str]
-#     memory: Optional[str]
+class ResourceConfig(dict):
+    def __getitem__(self, key):
+        defaults: dict = super().__getitem__("__default__")
+        this: dict = super().get(key, {})
+        updated_defaults = defaults.copy()
+        updated_defaults.update(this)
+        return updated_defaults
 
 
 class TolerationConfig(BaseModel):
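
To make the lookup semantics concrete, here is a standalone sketch of how the new dict-based ResourceConfig resolves a node's resources by overlaying its entry on top of __default__. The class body is copied from the hunk above; the node names and values are made up for illustration.

class ResourceConfig(dict):
    # Node entries overlay the "__default__" entry on lookup.
    def __getitem__(self, key):
        defaults: dict = super().__getitem__("__default__")
        this: dict = super().get(key, {})
        updated_defaults = defaults.copy()
        updated_defaults.update(this)
        return updated_defaults


config = ResourceConfig(
    {
        "__default__": {"cpu": "500m", "memory": "1024Mi"},
        "gpu_node": {"memory": "8Gi", "nvidia.com/gpu": "1"},  # hypothetical
    }
)
print(config["gpu_node"])
# -> {'cpu': '500m', 'memory': '8Gi', 'nvidia.com/gpu': '1'}
print(config["unknown_node"])
# -> {'cpu': '500m', 'memory': '1024Mi'}  (defaults only)
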
@@ -284,9 +288,17 @@ def _create_default_dict_with(
         default_value = (value := value or {}).get("__default__", default)
         return dict_cls(lambda: default_value, value)
 
-    # @validator("resources", always=True)
+    @validator("resources", always=True)
     def _validate_resources(cls, value):
-        return {"cpu":"500m", "memory":"1024Mi"}.update(value)
+        default = ResourceConfig(
+            {"__default__": {"cpu": "500m", "memory": "1024Mi"}}
+        )
+        if isinstance(value, dict):
+            default.update(value)
+        # else:
+        #     # throw some error?
+        #     logger.error(value, "Unknown type")
+        return default
 
     @validator("retry_policy", always=True)
     def _validate_retry_policy(cls, value):
@@ -307,7 +319,7 @@ def _validate_extra_volumes(cls, value):
     run_name: str
     scheduled_run_name: Optional[str]
     description: Optional[str] = None
-    resources: Optional[Dict[str, Dict[str,str]]]
+    resources: Optional[Dict[str, ResourceConfig]]
     tolerations: Optional[Dict[str, List[TolerationConfig]]]
     retry_policy: Optional[Dict[str, Optional[RetryPolicyConfig]]]
     volume: Optional[VolumeConfig] = None
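
A quick sketch of what the _validate_resources hook above is expected to produce, using plain dicts in place of ResourceConfig so the snippet stands alone; the node name and GPU key are again illustrative.

# Mirrors the validator body from the diff: install the plugin baseline
# under "__default__", then layer any user-supplied mapping on top.
def validate_resources_sketch(value):
    default = {"__default__": {"cpu": "500m", "memory": "1024Mi"}}
    if isinstance(value, dict):
        default.update(value)
    return default

print(validate_resources_sketch(None))
# -> {'__default__': {'cpu': '500m', 'memory': '1024Mi'}}
print(validate_resources_sketch({"node1": {"nvidia.com/gpu": "1"}}))
# -> {'__default__': {'cpu': '500m', 'memory': '1024Mi'},
#     'node1': {'nvidia.com/gpu': '1'}}
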
kedro_kubeflow/generators/utils.py: 10 changes (6 additions, 4 deletions)
@@ -130,10 +130,12 @@ def customize_op(op, image_pull_policy, run_config: RunConfig):
         op.container.set_security_context(
             k8s.V1SecurityContext(run_as_user=run_config.volume.owner)
         )
-
-    import IPython
-    IPython.embed()
-    resources = run_config.resources[op.name]#.dict(exclude_none=True)
+    resources = run_config.resources.get(
+        op.name, run_config.resources["__default__"]
+    )
+    for k, v in run_config.resources["__default__"].items():
+        if k not in resources:
+            resources[k] = v
     op.container.resources = k8s.V1ResourceRequirements(
         limits=resources,
         requests=resources,
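
A rough sketch of the effect of the customize_op change: the node's entry is looked up with a fallback to __default__, missing keys are back-filled from __default__, and the merged dict is used for both limits and requests. The node name and values here are assumptions.

# Illustrative: the merge performed for a single node.
node_name = "gpu_node"                      # hypothetical node
resources_cfg = {
    "__default__": {"cpu": "500m", "memory": "1024Mi"},
    "gpu_node": {"nvidia.com/gpu": "1"},
}

resources = resources_cfg.get(node_name, resources_cfg["__default__"])
for k, v in resources_cfg["__default__"].items():
    if k not in resources:
        resources[k] = v

print(resources)
# -> {'nvidia.com/gpu': '1', 'cpu': '500m', 'memory': '1024Mi'}
# This merged dict is then passed as both limits and requests to
# k8s.V1ResourceRequirements, as in the diff above.
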
tests/test_config.py: 63 changes (33 additions, 30 deletions)
@@ -28,27 +28,28 @@
 class TestPluginConfig(unittest.TestCase, MinimalConfigMixin):
     def test_plugin_config(self):
         cfg = PluginConfig(**yaml.safe_load(CONFIG_YAML))
-        assert cfg.host == "https://example.com"
-        assert cfg.run_config.image == "gcr.io/project-image/test"
-        assert cfg.run_config.image_pull_policy == "Always"
-        assert cfg.run_config.experiment_name == "Test Experiment"
-        assert cfg.run_config.run_name == "test run"
-        assert cfg.run_config.scheduled_run_name == "scheduled run"
-        assert cfg.run_config.wait_for_completion
-        assert cfg.run_config.volume.storageclass == "default"
-        assert cfg.run_config.volume.size == "3Gi"
-        assert cfg.run_config.volume.keep is True
-        assert cfg.run_config.volume.access_modes == ["ReadWriteOnce"]
-        assert cfg.run_config.resources["node1"] is not None
-        assert cfg.run_config.description == "My awesome pipeline"
-        assert cfg.run_config.ttl == 300
+        self.assertEqual(cfg.host, "https://example.com")
+        self.assertEqual(cfg.run_config.image, "gcr.io/project-image/test")
+        self.assertEqual(cfg.run_config.image_pull_policy, "Always")
+        self.assertEqual(cfg.run_config.experiment_name, "Test Experiment")
+        self.assertEqual(cfg.run_config.run_name, "test run")
+        self.assertEqual(cfg.run_config.scheduled_run_name, "scheduled run")
+        self.assertTrue(cfg.run_config.wait_for_completion)
+        self.assertEqual(cfg.run_config.volume.storageclass, "default")
+        self.assertEqual(cfg.run_config.volume.size, "3Gi")
+        self.assertTrue(cfg.run_config.volume.keep)
+        self.assertEqual(cfg.run_config.volume.access_modes, ["ReadWriteOnce"])
+        self.assertIsNotNone(cfg.run_config.resources["node1"])
+        self.assertIsNotNone(cfg.run_config.resources["__default__"])
+        self.assertEqual(cfg.run_config.description, "My awesome pipeline")
+        self.assertEqual(cfg.run_config.ttl, 300)
 
     def test_defaults(self):
         cfg = PluginConfig(**self.minimal_config())
-        assert cfg.run_config.image_pull_policy == "IfNotPresent"
+        self.assertEqual(cfg.run_config.image_pull_policy, "IfNotPresent")
         assert cfg.run_config.description is None
         SECONDS_IN_ONE_WEEK = 3600 * 24 * 7
-        assert cfg.run_config.ttl == SECONDS_IN_ONE_WEEK
+        self.assertEqual(cfg.run_config.ttl, SECONDS_IN_ONE_WEEK)
         assert cfg.run_config.volume is None
 
     def test_missing_required_config(self):
@@ -61,19 +62,19 @@ def test_resources_default_only(self):
                 {"run_config": {"resources": {"__default__": {"cpu": "100m"}}}}
             )
         )
-        assert cfg.run_config.resources["node2"].cpu == "100m"
-        assert cfg.run_config.resources["node3"].cpu == "100m"
+        self.assertEqual(cfg.run_config.resources["node2"]["cpu"], "100m")
+        self.assertEqual(cfg.run_config.resources["node3"]["cpu"], "100m")
 
     def test_resources_no_default(self):
         cfg = PluginConfig(
             **self.minimal_config(
                 {"run_config": {"resources": {"node2": {"cpu": "100m"}}}}
             )
         )
-        assert cfg.run_config.resources["node2"].cpu == "100m"
+        self.assertEqual(cfg.run_config.resources["node2"]["cpu"], "100m")
         self.assertDictEqual(
-            cfg.run_config.resources["node3"].dict(),
-            cfg.run_config.resources["__default__"].dict(),
+            cfg.run_config.resources["node3"],
+            cfg.run_config.resources["__default__"],
         )
 
     def test_resources_default_and_node_specific(self):
@@ -90,14 +91,14 @@ def test_resources_default_and_node_specific(self):
             )
         )
         self.assertDictEqual(
-            cfg.run_config.resources["node2"].dict(),
+            cfg.run_config.resources["node2"],
             {
                 "cpu": "100m",
                 "memory": "64Mi",
             },
         )
         self.assertDictEqual(
-            cfg.run_config.resources["node3"].dict(),
+            cfg.run_config.resources["node3"],
             {
                 "cpu": "200m",
                 "memory": "64Mi",
@@ -149,13 +150,15 @@ def test_tolerations_no_default(self):
             cfg.run_config.tolerations["node2"][0].dict(), toleration_config[0]
         )
 
-        assert (
+        self.assertEqual(
             isinstance(cfg.run_config.tolerations["node2"], list)
-            and len(cfg.run_config.tolerations["node2"]) == 1
+            and len(cfg.run_config.tolerations["node2"]),
+            1,
         )
-        assert (
+        self.assertEqual(
             isinstance(cfg.run_config.tolerations["node3"], list)
-            and len(cfg.run_config.tolerations["node3"]) == 0
+            and len(cfg.run_config.tolerations["node3"]),
+            0,
         )
 
     def test_tolerations_default_and_node_specific(self):
@@ -206,8 +209,8 @@ def test_reuse_run_name_for_scheduled_run_name(self):
         cfg = PluginConfig(
             **self.minimal_config({"run_config": {"run_name": "some run"}})
         )
-        assert cfg.run_config.run_name == "some run"
-        assert cfg.run_config.scheduled_run_name == "some run"
+        self.assertEqual(cfg.run_config.run_name, "some run")
+        self.assertEqual(cfg.run_config.scheduled_run_name, "some run")
 
     def test_retry_policy_default_and_node_specific(self):
         cfg = PluginConfig(
@@ -275,4 +278,4 @@ def test_retry_policy_no_default(self):
             },
         )
 
-        assert cfg.run_config.retry_policy["node2"] is None
+        self.assertIsNone(cfg.run_config.retry_policy["node2"])
