Skip to content

Commit

Permalink
PodSpec should not require primary_container name (#1380)
Browse files Browse the repository at this point in the history
For Pod tasks, if the primary_container_name is not specified, it should default.

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
  • Loading branch information
kumare3 committed Dec 21, 2022
1 parent 604e9a6 commit 9ed0c18
Showing 1 changed file with 11 additions and 31 deletions.
42 changes: 11 additions & 31 deletions plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union

from flyteidl.core import tasks_pb2 as _core_task
Expand All @@ -18,6 +19,7 @@ def _sanitize_resource_name(resource: _task_models.Resources.ResourceEntry) -> s
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


@dataclass
class Pod(object):
"""
Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod.
Expand All @@ -29,39 +31,17 @@ class Pod(object):
:param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec.
"""

def __init__(
self,
pod_spec: V1PodSpec,
primary_container_name: str,
labels: Optional[Dict[str, str]] = None,
annotations: Optional[Dict[str, str]] = None,
):
if not pod_spec:
pod_spec: V1PodSpec
primary_container_name: str = _PRIMARY_CONTAINER_NAME_FIELD
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init_(self):
if not self.pod_spec:
raise _user_exceptions.FlyteValidationException("A pod spec cannot be undefined")
if not primary_container_name:
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")

self._pod_spec = pod_spec
self._primary_container_name = primary_container_name
self._labels = labels
self._annotations = annotations

@property
def pod_spec(self) -> V1PodSpec:
return self._pod_spec

@property
def primary_container_name(self) -> str:
return self._primary_container_name

@property
def labels(self) -> Optional[Dict[str, str]]:
return self._labels

@property
def annotations(self) -> Optional[Dict[str, str]]:
return self._annotations


class PodFunctionTask(PythonFunctionTask[Pod]):
def __init__(self, task_config: Pod, task_function: Callable, **kwargs):
Expand Down Expand Up @@ -114,7 +94,7 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]

final_containers.append(container)

self.task_config._pod_spec.containers = final_containers
self.task_config.pod_spec.containers = final_containers

return ApiClient().sanitize_for_serialization(self.task_config.pod_spec)

Expand Down

0 comments on commit 9ed0c18

Please sign in to comment.