diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index e81728ddb4..c38ad33834 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple, Union from flyteidl.core import tasks_pb2 as _core_task @@ -18,6 +19,7 @@ def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> s return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-") +@dataclass class Pod(object): """ Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod. @@ -29,39 +31,17 @@ class Pod(object): :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec. """ - def __init__( - self, - pod_spec: V1PodSpec, - primary_container_name: str, - labels: Optional[Dict[str, str]] = None, - annotations: Optional[Dict[str, str]] = None, - ): - if not pod_spec: + pod_spec: V1PodSpec + primary_container_name: str = _PRIMARY_CONTAINER_NAME_FIELD + labels: Optional[Dict[str, str]] = None + annotations: Optional[Dict[str, str]] = None + + def __post_init_(self): + if not self.pod_spec: raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined") - if not primary_container_name: + if not self.primary_container_name: raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined") - self._pod_spec = pod_spec - self._primary_container_name = primary_container_name - self._labels = labels - self._annotations = annotations - - @property - def pod_spec(self) -> V1PodSpec: - return self._pod_spec - - @property - def primary_container_name(self) -> str: - return self._primary_container_name - - @property - def labels(self) -> Optional[Dict[str, str]]: - return self._labels - - @property - def annotations(self) -> Optional[Dict[str, str]]: - return self._annotations - class PodFunctionTask(PythonFunctionTask[Pod]): def __init__(self, task_config: Pod, task_function: Callable, **kwargs): @@ -114,7 +94,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] final_containers.append(container) - self.task_config._pod_spec.containers = final_containers + self.task_config.pod_spec.containers = final_containers return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)