diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/__init__.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/__init__.py index a5ae6b7de9..99e77ce39f 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/__init__.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/__init__.py @@ -1,3 +1,3 @@ from .agent import SkyPilotAgent -from .task import SkyPilot, SkyPilotFunctionTask # noqa from .metadata import SkyPilotMetadata # noqa +from .task import SkyPilot, SkyPilotFunctionTask # noqa diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/agent.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/agent.py index 98c1bcfc9f..df5b37077e 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/agent.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/agent.py @@ -1,36 +1,44 @@ -from typing import Optional, List, Dict, Any, Tuple, Callable -from dataclasses import dataclass, asdict import asyncio +import functools +import multiprocessing +import os +import traceback +from datetime import datetime, timezone +from typing import Callable, Dict, List, Optional + import sky -import sky.cli as sky_cli import sky.core import sky.exceptions -from sky.skylet import constants as skylet_constants import sky.resources -import os -import pdb -from flytekit.models.literals import LiteralMap -from flytekit import logger -from flytekit.models.task import TaskTemplate -from flytekit.extend.backend.base_agent import AsyncAgentBase, AgentRegistry, Resource, ResourceMeta -from flytekitplugins.skypilot.utils import skypilot_status_to_flyte_phase, execute_cmd_to_path, setup_cloud_credential -from flytekitplugins.skypilot.utils import LAUNCH_TYPE_TO_SKY_STATUS, COROUTINE_INTERVAL -from flytekitplugins.skypilot.utils import EventHandler, TaskStatus, SkyPathSetting, TaskRemotePathSetting +from flytekitplugins.skypilot.metadata import JobLaunchType, SkyPilotMetadata from flytekitplugins.skypilot.task_utils import get_sky_task_config -from flytekit.core.data_persistence import FileAccessProvider -from flytekitplugins.skypilot.metadata import SkyPilotMetadata, JobLaunchType -import multiprocessing -import functools -import traceback -from datetime import datetime, timezone +from flytekitplugins.skypilot.utils import ( + COROUTINE_INTERVAL, + LAUNCH_TYPE_TO_SKY_STATUS, + EventHandler, + SkyPathSetting, + TaskRemotePathSetting, + TaskStatus, + execute_cmd_to_path, + setup_cloud_credential, + skypilot_status_to_flyte_phase, +) +from sky.skylet import constants as skylet_constants +from flytekit import logger +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate TASK_TYPE = "skypilot" + class WrappedProcess(multiprocessing.Process): - ''' + """ Wrapper for multiprocessing.Process to catch exceptions in the target function - ''' + """ + def __init__(self, *args, **kwargs) -> None: multiprocessing.Process.__init__(self, *args, **kwargs) self._pconn, self._cconn = multiprocessing.Pipe() @@ -50,14 +58,14 @@ def exception(self): if self._pconn.poll(): self._exception = self._pconn.recv() return self._exception - + class BlockingProcessHandler: def __init__(self, fn: Callable) -> None: self._process = WrappedProcess(target=fn) self._process.start() self._check_interval = COROUTINE_INTERVAL - + async def status_poller(self, event_handler: EventHandler): while self._process.exitcode is None: await asyncio.sleep(self._check_interval) @@ -72,11 +80,11 @@ async def status_poller(self, event_handler: EventHandler): self.clean_up() if launch_exception is not None: raise Exception(launch_exception) - + def get_task(self, event_handler: EventHandler): task = asyncio.create_task(self.status_poller(event_handler)) return task - + def clean_up(self): self._process.terminate() self._process.join() @@ -84,7 +92,6 @@ def clean_up(self): class SkyTaskFuture(object): - _job_id: int = -1 # not executed yet _task_kwargs: TaskTemplate = None _launch_coro: asyncio.Task = None _status_check_coro: asyncio.Task = None @@ -95,6 +102,7 @@ class SkyTaskFuture(object): _event_handler: EventHandler = None _launched_process: BlockingProcessHandler = None _task_status: TaskStatus = TaskStatus.INIT + def __init__(self, task_template: TaskTemplate): self._task_kwargs = task_template args = execute_cmd_to_path(task_template.container.args) @@ -102,26 +110,25 @@ def __init__(self, task_template: TaskTemplate): self._cluster_name = sky.jobs.utils.JOB_CONTROLLER_NAME else: self._cluster_name = self.task_template.custom["cluster_name"] - + self._sky_path_setting = TaskRemotePathSetting( - file_access=FileAccessProvider( - local_sandbox_dir="/tmp", - raw_output_prefix=args["raw_output_data_prefix"] - ), + file_access=FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=args["raw_output_data_prefix"]), job_type=task_template.custom["job_launch_type"], - cluster_name=self._cluster_name + cluster_name=self._cluster_name, ).from_task_prefix(task_template.custom["task_name"]) self._task_name = self.sky_path_setting.task_name def launch(self): setup_cloud_credential(show_check=True) - + cluster_name: str = self.task_template.custom["cluster_name"] # sky_resources: List[Dict[str, str]] = task_template.custom["resource_config"] sky_resources = get_sky_task_config(self.task_template) - sky_resources.update({ - "name": self.task_name, - }) + sky_resources.update( + { + "name": self.task_name, + } + ) task = sky.Task.from_yaml_config(sky_resources) logger.warning(f"Launching task... \nSetup: \n{task.setup}\nRun: \n{task.run}") # task.set_resources(sky_resources).set_file_mounts(self.task_template.custom["file_mounts"]) @@ -129,21 +136,17 @@ def launch(self): job_id = -1 if self.task_template.custom["job_launch_type"] == JobLaunchType.MANAGED: - sky.jobs.launch( - task=task, - detach_run=True, - stream_logs=False - ) + sky.jobs.launch(task=task, detach_run=True, stream_logs=False) else: job_id, _ = sky.launch( - task=task, - cluster_name=cluster_name, - backend=backend, + task=task, + cluster_name=cluster_name, + backend=backend, idle_minutes_to_autostop=self.task_template.custom["stop_after"], down=self.task_template.custom["auto_down"], detach_run=True, detach_setup=True, - stream_logs=False + stream_logs=False, ) return job_id @@ -154,11 +157,7 @@ def sky_path_setting(self) -> TaskRemotePathSetting: @property def task_template(self) -> TaskTemplate: return self._task_kwargs - - @property - def job_id(self) -> int: - return self._job_id - + @property def task_name(self) -> str: return self._task_name @@ -172,12 +171,8 @@ def task_status(self) -> TaskStatus: return self._task_status def launch_failed_callback( - self, - task: asyncio.Task, - cause: str, - clean_ups: List[Callable]=None, - cancel_on_done: bool=False - ): + self, task: asyncio.Task, cause: str, clean_ups: List[Callable] = None, cancel_on_done: bool = False + ): if clean_ups is None: clean_ups = [] error_cause = None @@ -186,8 +181,8 @@ def launch_failed_callback( if task.exception() is not None: try: error_cause = "Exception" - result = task.result() - except Exception as e: + task.result() + except Exception: # re-raise the exception error = traceback.format_exc() self.sky_path_setting.put_error_log(error) @@ -196,7 +191,7 @@ def launch_failed_callback( f"Cause: \n{error}\n" f"error_log is at {self.sky_path_setting.remote_error_log}" ) - + self._event_handler.failed_event.set() else: # normal exit, if this marks the end of all coroutines, cancel them all @@ -212,44 +207,32 @@ def launch_failed_callback( self._task_status = TaskStatus.DONE if self._cancel_callback is not None: self._cancel_callback(self) - + def launch_process_wrapper(self, fn: Callable): self._launched_process = BlockingProcessHandler(fn) return self._launched_process.get_task(self._event_handler) - def start(self, callback: Callable=None): - ''' + def start(self, callback: Callable = None): + """ create launch coroutine. create status check coroutine. create remote deleted check coroutine. - ''' + """ self._cancel_callback = callback self._event_handler = EventHandler() self._launch_coro = self.launch_process_wrapper(self.launch) self._launch_coro.add_done_callback( - functools.partial( - self.launch_failed_callback, - cause="launch", - cancel_on_done=False - ) + functools.partial(self.launch_failed_callback, cause="launch", cancel_on_done=False) ) self._status_check_coro = asyncio.create_task(self.sky_path_setting.deletion_status(self._event_handler)) self._status_check_coro.add_done_callback( - functools.partial( - self.launch_failed_callback, - cause="deletion status", - cancel_on_done=True - ) + functools.partial(self.launch_failed_callback, cause="deletion status", cancel_on_done=True) ) self._status_upload_coro = asyncio.create_task(self.sky_path_setting.put_task_status(self._event_handler)) self._status_upload_coro.add_done_callback( - functools.partial( - self.launch_failed_callback, - cause="upload status", - cancel_on_done=True - ) + functools.partial(self.launch_failed_callback, cause="upload status", cancel_on_done=True) ) - + def cancel(self): self._event_handler.cancel_event.set() @@ -259,28 +242,22 @@ class SkyTaskTracker(object): _zip_coro: asyncio.Task = None _sky_path_setting: SkyPathSetting = None _hostname: str = sky.utils.common_utils.get_user_hash() + @classmethod def try_first_register(cls, task_template: TaskTemplate): - ''' + """ sets up coroutine for sky.zip upload - ''' + """ if cls._JOB_RESIGTRY: return setup_cloud_credential() args = execute_cmd_to_path(task_template.container.args) - file_access = FileAccessProvider( - local_sandbox_dir="/tmp", - raw_output_prefix=args["raw_output_data_prefix"] - ) - - cls._sky_path_setting = SkyPathSetting( - task_level_prefix=file_access.raw_output_prefix, - unique_id=cls._hostname - ) + file_access = FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=args["raw_output_data_prefix"]) + + cls._sky_path_setting = SkyPathSetting(task_level_prefix=file_access.raw_output_prefix, unique_id=cls._hostname) cls._zip_coro = asyncio.create_task(cls._sky_path_setting.zip_and_upload()) cls._zip_coro.add_done_callback(cls.zip_failed_callback) - @classmethod def zip_failed_callback(self, task: asyncio.Task): try: @@ -292,26 +269,28 @@ def zip_failed_callback(self, task: asyncio.Task): @classmethod def on_task_deleted(cls, deleted_task: SkyTaskFuture): - ''' + """ executed on the agent when task on its pod cancelled if no tasks on the cluster running, stop the cluster the sky.stop part can be disabled if we force autostop - ''' - running_tasks_on_cluster = list(filter( - lambda task: task.task_status == TaskStatus.INIT and task.cluster_name == deleted_task.cluster_name, - cls._JOB_RESIGTRY.values() - )) + """ + running_tasks_on_cluster = list( + filter( + lambda task: task.task_status == TaskStatus.INIT and task.cluster_name == deleted_task.cluster_name, + cls._JOB_RESIGTRY.values(), + ) + ) if not running_tasks_on_cluster: logger.warning(f"Stopping cluster {deleted_task.cluster_name}") try: # FIXME: this is a blocking call, delete needs long timeout sky.stop(deleted_task.cluster_name) - except sky.exceptions.NotSupportedError as e: + except sky.exceptions.NotSupportedError: logger.warning(f"Cluster {deleted_task.cluster_name} is not supported for stopping.") - + cls._sky_path_setting.remote_path_setting.delete_task(deleted_task.sky_path_setting.unique_id) # del cls._JOB_RESIGTRY[deleted_task.task_name] - + @classmethod def register_sky_task(cls, task_template: TaskTemplate): cls.try_first_register(task_template) @@ -326,10 +305,10 @@ def register_sky_task(cls, task_template: TaskTemplate): cls._JOB_RESIGTRY[new_task.task_name] = new_task return new_task + def remote_setup(remote_meta: SkyPilotMetadata, wrapped, **kwargs): sky_path_setting = SkyPathSetting( - task_level_prefix=remote_meta.task_metadata_prefix, - unique_id=remote_meta.tracker_hostname + task_level_prefix=remote_meta.task_metadata_prefix, unique_id=remote_meta.tracker_hostname ) sky_path_setting.download_and_unzip() home_sky_dir = sky_path_setting.local_path_setting.home_sky_dir @@ -341,12 +320,10 @@ def remote_setup(remote_meta: SkyPilotMetadata, wrapped, **kwargs): sky.authentication.PUBLIC_SSH_KEY_PATH = os.path.join(home_key_dir, public_key_base) # mock db path sky.global_user_state._DB = sky.utils.db_utils.SQLiteConn( - os.path.join(home_sky_dir, "state.db"), - sky.global_user_state.create_table + os.path.join(home_sky_dir, "state.db"), sky.global_user_state.create_table ) sky.skylet.job_lib._DB = sky.utils.db_utils.SQLiteConn( - os.path.join(home_sky_dir, "skylet.db"), - sky.skylet.job_lib.create_table + os.path.join(home_sky_dir, "skylet.db"), sky.skylet.job_lib.create_table ) sky.skylet.job_lib._CURSOR = sky.skylet.job_lib._DB.cursor sky.skylet.job_lib._CONN = sky.skylet.job_lib._DB.conn @@ -357,29 +334,24 @@ def remote_setup(remote_meta: SkyPilotMetadata, wrapped, **kwargs): # run the wrapped function wrapped_result = wrapped(**kwargs) return wrapped_result - + def query_job_status(resource_meta: SkyPilotMetadata): # task on another agent pod may fail to launch, check for launch error log sky_path_setting = TaskRemotePathSetting( - file_access=FileAccessProvider( - local_sandbox_dir="/tmp", - raw_output_prefix=resource_meta.task_metadata_prefix - ), + file_access=FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=resource_meta.task_metadata_prefix), job_type=resource_meta.job_launch_type, cluster_name=resource_meta.cluster_name, - task_name=resource_meta.job_name + task_name=resource_meta.job_name, ) - # check job status - return LAUNCH_TYPE_TO_SKY_STATUS[resource_meta.job_launch_type](sky_path_setting.get_task_status().task_status) - - # return sky.JobStatus.INIT # job not found, in most cases this is in setup stage + task_status = sky_path_setting.get_task_status().task_status + return LAUNCH_TYPE_TO_SKY_STATUS[resource_meta.job_launch_type](task_status) + def check_remote_agent_alive(resource_meta: SkyPilotMetadata): sky_path_setting = SkyPathSetting( - task_level_prefix=resource_meta.task_metadata_prefix, - unique_id=resource_meta.tracker_hostname + task_level_prefix=resource_meta.task_metadata_prefix, unique_id=resource_meta.tracker_hostname ) utc_time = datetime.now(timezone.utc) last_upload_time = sky_path_setting.last_upload_time() @@ -387,33 +359,27 @@ def check_remote_agent_alive(resource_meta: SkyPilotMetadata): time_diff = utc_time - last_upload_time if time_diff.total_seconds() > skylet_constants.CONTROLLER_IDLE_MINUTES_TO_AUTOSTOP * 60: return False - + return True - + def remote_deletion(resource_meta: SkyPilotMetadata): # this part can be removed if sky job controller down is supported # if the zip is not updated for a long time, the agent pod is considered down, so we need to delete the controller if not check_remote_agent_alive(resource_meta): with multiprocessing.Pool(1) as p: - starmap_results = p.starmap( - functools.partial(remote_setup, cluster_name=resource_meta.cluster_name), - [(resource_meta, sky.down)] + p.starmap( + functools.partial(remote_setup, cluster_name=resource_meta.cluster_name), [(resource_meta, sky.down)] ) sky_task_settings = TaskRemotePathSetting( - file_access=FileAccessProvider( - local_sandbox_dir="/tmp", - raw_output_prefix=resource_meta.task_metadata_prefix - ), + file_access=FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=resource_meta.task_metadata_prefix), job_type=resource_meta.job_launch_type, cluster_name=resource_meta.cluster_name, - task_name=resource_meta.job_name + task_name=resource_meta.job_name, ) sky_task_settings.to_proto_and_upload(resource_meta, sky_task_settings.remote_delete_proto) - - class SkyPilotAgent(AsyncAgentBase): def __init__(self): super().__init__(task_type_name=TASK_TYPE, metadata_type=SkyPilotMetadata) @@ -424,12 +390,10 @@ async def create( inputs: Optional[LiteralMap] = None, **kwargs, ) -> SkyPilotMetadata: - logger.warning(f"Creating... SkyPilot {task_template.container.args} | {task_template.container.image}") # pdb.set_trace() task = SkyTaskTracker.register_sky_task(task_template=task_template) logger.warning(f"Created SkyPilot {task.task_name}") - # await SkyTaskTracker._JOB_RESIGTRY[job_id]._launch_coro meta = SkyPilotMetadata( job_name=task.task_name, cluster_name=task.cluster_name, @@ -437,10 +401,7 @@ async def create( tracker_hostname=SkyTaskTracker._hostname, job_launch_type=task_template.custom["job_launch_type"], ) - return meta - - async def get(self, resource_meta: SkyPilotMetadata, **kwargs) -> Resource: # pdb.set_trace() @@ -448,7 +409,7 @@ async def get(self, resource_meta: SkyPilotMetadata, **kwargs) -> Resource: job_status = None outputs = None job_status = query_job_status(resource_meta) - + logger.warning(f"Getting... {job_status}, took {(datetime.now(timezone.utc) - received_time).total_seconds()}") phase = skypilot_status_to_flyte_phase(job_status) return Resource(phase=phase, outputs=outputs, message=None) @@ -457,10 +418,10 @@ async def delete(self, resource_meta: SkyPilotMetadata, **kwargs): # pdb.set_trace() if resource_meta.job_name not in SkyTaskTracker._JOB_RESIGTRY: remote_deletion(resource_meta) - else: existed_task = SkyTaskTracker._JOB_RESIGTRY[resource_meta.job_name] existed_task.cancel() - + + # To register the skypilot agent AgentRegistry.register(SkyPilotAgent()) diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/cloud_registry.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/cloud_registry.py index fa1e06de08..4838462244 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/cloud_registry.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/cloud_registry.py @@ -1,15 +1,18 @@ -from typing import Dict, Optional, List, Type -import flytekit -from flytekit import FlyteContext, PythonFunctionTask, logger -from asyncio.subprocess import PIPE +import os import subprocess +from asyncio.subprocess import PIPE from dataclasses import dataclass -import os +from typing import Dict, List, Optional, Type + +import flytekit +from flytekit import logger + class CloudNotInstalledError(ValueError): """ This is the base error for cloud credential errors. """ + pass @@ -17,32 +20,34 @@ class CloudCredentialError(ValueError): """ This is the base error for cloud credential errors. """ + pass + @dataclass class CloudCredentialMount(object): vm_path: str container_path: str -class BaseCloudCredentialProvider: - _CLOUD_TYPE: str = "base cloud", +class BaseCloudCredentialProvider: + _CLOUD_TYPE: str = ("base cloud",) _SECRET_GROUP: Optional[str] = None def __init__( - self, + self, ): self._secret_manager = flytekit.current_context().secrets self.check_cloud_dependency() - + def check_cloud_dependency(self) -> None: raise NotImplementedError - + def setup_cloud_credential( self, ) -> None: raise NotImplementedError - + @property def secrets(self): return self._secret_manager @@ -51,6 +56,7 @@ def secrets(self): def get_mount_envs() -> Dict[str, CloudCredentialMount]: return {} + class CloudRegistry(object): """ This is the registry for all agents. @@ -85,10 +91,10 @@ def __init__( self, ): super().__init__() - + def check_cloud_dependency(self) -> None: try: - version_check = subprocess.run( + subprocess.run( [ "aws", "--version", @@ -100,7 +106,6 @@ def check_cloud_dependency(self) -> None: raise CloudNotInstalledError( f"AWS CLI not found. Please install it with 'pip install skypilot[aws]' and try again. Error: \n{type(e)}\n{e}" ) - def setup_cloud_credential( self, @@ -117,7 +122,7 @@ def setup_cloud_credential( key="aws_secret_access_key", ), } - + for key, secret in aws_config_dict.items(): configure_result = subprocess.run( [ @@ -130,8 +135,10 @@ def setup_cloud_credential( stdout=PIPE, stderr=PIPE, ) - if configure_result.returncode!= 0: - raise CloudCredentialError(f"Failed to configure AWS credentials for {key}: {configure_result.stderr.decode('utf-8')}") + if configure_result.returncode != 0: + raise CloudCredentialError( + f"Failed to configure AWS credentials for {key}: {configure_result.stderr.decode('utf-8')}" + ) @staticmethod def get_mount_envs(): @@ -143,7 +150,7 @@ def get_mount_envs(): "AWS_SHARED_CREDENTIALS_FILE": CloudCredentialMount( vm_path=("~/.aws/credentials"), container_path="/tmp/aws/credentials", - ) + ), } @@ -156,10 +163,10 @@ def __init__( self, ): super().__init__() - + def check_cloud_dependency(self) -> None: try: - version_check = subprocess.run( + subprocess.run( [ "gcloud", "--version", @@ -171,7 +178,6 @@ def check_cloud_dependency(self) -> None: raise CloudNotInstalledError( f"AWS CLI not found. Please install it with 'pip install skypilot[aws]' and try again. Error: \n{type(e)}\n{e}" ) - def setup_cloud_credential( self, @@ -192,13 +198,15 @@ def setup_cloud_credential( key="project_id", ), } - gcp_config_dict.update({ - "type": "service_account", - "token_uri": "https://oauth2.googleapis.com/token", - }) - - import tempfile + gcp_config_dict.update( + { + "type": "service_account", + "token_uri": "https://oauth2.googleapis.com/token", + } + ) + import json + # FIXME: it looks insecure since the key file is on agent pod with open(self._GCLOUD_KEY_FILE, "w") as f: f.write(json.dumps(gcp_config_dict, indent=4)) @@ -214,8 +222,10 @@ def setup_cloud_credential( stdout=PIPE, stderr=PIPE, ) - if configure_result.returncode!= 0: - raise CloudCredentialError(f"Failed to configure GCP credentials: {configure_result.stderr.decode('utf-8')}") + if configure_result.returncode != 0: + raise CloudCredentialError( + f"Failed to configure GCP credentials: {configure_result.stderr.decode('utf-8')}" + ) os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self._GCLOUD_KEY_FILE @@ -228,7 +238,6 @@ def get_mount_envs(): ) } + CloudRegistry.register(AWSCredentialProvider._CLOUD_TYPE, AWSCredentialProvider) CloudRegistry.register(GCPCredentialProvider._CLOUD_TYPE, GCPCredentialProvider) - - diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/metadata.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/metadata.py index f4f1d3aebd..3e0a02ab3f 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/metadata.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/metadata.py @@ -1,18 +1,22 @@ +import enum from dataclasses import dataclass + from flytekit.extend.backend.base_agent import ResourceMeta -import enum + + class JobLaunchType(int, enum.Enum): NORMAL = 0 # sky launch MANAGED = 1 # sky jobs launch - + @dataclass class SkyPilotMetadata(ResourceMeta): """ This is the metadata for the job. """ + job_name: str cluster_name: str task_metadata_prefix: str tracker_hostname: str - job_launch_type: JobLaunchType \ No newline at end of file + job_launch_type: JobLaunchType diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/task.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/task.py index 6168708502..b76a5bf5a0 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/task.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/task.py @@ -1,22 +1,17 @@ -from typing import Any, Dict, Optional, Union, Callable, List, Set -from dataclasses import dataclass, asdict -import os -from typing import Any, Callable, Dict, Optional, Union, cast - -from google.protobuf.json_format import MessageToDict -import enum -from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger -from flytekit.configuration import DefaultImages, SerializationSettings +from dataclasses import asdict, dataclass +from typing import Any, Callable, Dict, Optional, Union + +from flytekitplugins.skypilot.metadata import JobLaunchType +from flytekitplugins.skypilot.task_utils import ContainerRunType + +from flytekit import FlyteContextManager, PythonFunctionTask, logger +from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import ExecutionParameters from flytekit.core.python_auto_container import get_registerable_container_image -from flytekit.extend import ExecutionState, TaskPlugins +from flytekit.extend import TaskPlugins from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec -from flytekitplugins.skypilot.metadata import JobLaunchType -from flytekitplugins.skypilot.task_utils import ContainerRunType -import sky -from sky import resources as resources_lib from flytekit.models.literals import LiteralMap FLYTE_LOCAL_CONFIG = { @@ -26,7 +21,6 @@ } - @dataclass class SkyPilot(object): cluster_name: str @@ -41,17 +35,17 @@ class SkyPilot(object): job_launch_type: JobLaunchType = JobLaunchType.NORMAL auto_down: bool = False stop_after: int = None - + def __post_init__(self): if self.resource_config is None: self.resource_config = {} if self.local_config is None: self.local_config = {"local_envs": {}} + class SkyPilotFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SkyPilot]): - _TASK_TYPE = "skypilot" - + def __init__( self, task_config: SkyPilot, @@ -59,7 +53,6 @@ def __init__( container_image: Optional[Union[str, ImageSpec]] = None, **kwargs, ): - # for local testing and remote cloud # container_image = replace_local_registry(container_image) super(SkyPilotFunctionTask, self).__init__( @@ -72,31 +65,31 @@ def __init__( def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return asdict(self.task_config) - + # deprecated def pre_execute(self, user_params: ExecutionParameters | None) -> ExecutionParameters | None: import sys + print("Pre executing...", sys.argv) return super().pre_execute(user_params) - + def execute(self: PythonTask, **kwargs) -> LiteralMap: if isinstance(self.task_config, SkyPilot): # Use the Skypilot agent to run it by default. try: ctx = FlyteContextManager.current_context() if not ctx.file_access.is_remote(ctx.file_access.raw_output_prefix): - pass - # raise ValueError( - # "To submit a Skypilot job locally," - # " please set --raw-output-data-prefix to a remote path. e.g. s3://, gcs//, etc." - # ) + raise ValueError( + "To submit a Skypilot job locally," + " please set --raw-output-data-prefix to a remote path. e.g. s3://, gcs//, etc." + ) if ctx.execution_state and ctx.execution_state.is_local_execution(): return AsyncAgentExecutorMixin.execute(self, **kwargs) except Exception as e: logger.error(f"Agent failed to run the task with error: {e}") logger.info("Falling back to local execution") return PythonFunctionTask.execute(self, **kwargs) - + def get_image(self, settings: SerializationSettings) -> str: if isinstance(self.container_image, ImageSpec): # Ensure that the code is always copied into the image, even during fast-registration. @@ -105,5 +98,4 @@ def get_image(self, settings: SerializationSettings) -> str: return get_registerable_container_image(self.container_image, settings.image_config) - TaskPlugins.register_pythontask_plugin(SkyPilot, SkyPilotFunctionTask) diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/task_utils.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/task_utils.py index 6cb1d98ed4..725c2c8e33 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/task_utils.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/task_utils.py @@ -1,11 +1,13 @@ -import textwrap +import enum +import os import shlex -from flytekit.models.task import TaskTemplate +import textwrap from typing import Any, Dict -import enum + import sky -import os -from flytekitplugins.skypilot.cloud_registry import CloudRegistry, CloudCredentialError, CloudNotInstalledError +from flytekitplugins.skypilot.cloud_registry import CloudRegistry + +from flytekit.models.task import TaskTemplate class ContainerRunType(int, enum.Enum): @@ -17,47 +19,59 @@ def parse_sky_resources(task_template: TaskTemplate) -> Dict[str, Any]: sky_task_config = {} resources: Dict[str, Any] = task_template.custom["resource_config"] container_image: str = task_template.container.image - if resources.get('image_id', None) is None and task_template.custom["container_run_type"] == ContainerRunType.RUNTIME: + if ( + resources.get("image_id", None) is None + and task_template.custom["container_run_type"] == ContainerRunType.RUNTIME + ): resources["image_id"] = f"docker:{container_image}" - - sky_task_config.update({ - "resources": resources, - "file_mounts": task_template.custom["file_mounts"], - }) + + sky_task_config.update( + { + "resources": resources, + "file_mounts": task_template.custom["file_mounts"], + } + ) return sky_task_config - + class SetupCommand(object): docker_pull: str = None flytekit_pip: str = None full_setup: str = None + def __init__(self, task_template: TaskTemplate) -> None: task_setup = task_template.custom["setup"] if task_template.custom["container_run_type"] == ContainerRunType.APP: self.docker_pull = f"docker pull {task_template.container.image}" else: # HACK, change back to normal flytekit - self.flytekit_pip = textwrap.dedent("""\ + self.flytekit_pip = textwrap.dedent( + """\ python -m pip uninstall flytekit -y python -m pip install -e /flytekit - """) - + """ + ) + self.full_setup = "\n".join(filter(None, [task_setup, self.docker_pull, self.flytekit_pip])).strip() - + class RunCommand(object): full_task_command: str = None _use_gpu: bool = False + def __init__(self, task_template: TaskTemplate) -> None: raw_task_command = shlex.join(task_template.container.args) - local_env_prefix = "\n".join([f"export {k}='{v}'" for k, v in task_template.custom["local_config"]["local_envs"].items()]) + local_env_prefix = "\n".join( + [f"export {k}='{v}'" for k, v in task_template.custom["local_config"]["local_envs"].items()] + ) self.check_resource(task_template) if task_template.custom["container_run_type"] == ContainerRunType.RUNTIME: python_path_command = f"export PYTHONPATH=$PYTHONPATH:$HOME/{sky.backends.docker_utils.SKY_DOCKER_WORKDIR}" - self.full_task_command = "\n".join(filter(None, [local_env_prefix, python_path_command, raw_task_command])).strip() + self.full_task_command = "\n".join( + filter(None, [local_env_prefix, python_path_command, raw_task_command]) + ).strip() else: container_entrypoint, container_args = task_template.container.args[0], task_template.container.args[1:] - docker_run_prefix = f"docker run {'--gpus=all' if self.use_gpu else ''} --entrypoint {container_entrypoint}" volume_setups, cloud_cred_envs = [], [] for cloud in CloudRegistry.list_clouds(): @@ -68,18 +82,20 @@ def __init__(self, task_template: TaskTemplate) -> None: cloud_cred_envs.append(f"-e {env_key}={path_mapping.container_path}") volume_command = " ".join(volume_setups) cloud_cred_env_command = " ".join(cloud_cred_envs) - self.full_task_command = " ".join([ - docker_run_prefix, - volume_command, - cloud_cred_env_command, - task_template.container.image, - *container_args - ]) + self.full_task_command = " ".join( + [ + docker_run_prefix, + volume_command, + cloud_cred_env_command, + task_template.container.image, + *container_args, + ] + ) @property def use_gpu(self) -> bool: return self._use_gpu - + def check_resource(self, task_template: TaskTemplate) -> None: gpu_config = parse_sky_resources(task_template) gpu_task = sky.Task.from_yaml_config(gpu_config) @@ -87,13 +103,16 @@ def check_resource(self, task_template: TaskTemplate) -> None: if resource.accelerators is not None: self._use_gpu = True break - + + def get_sky_task_config(task_template: TaskTemplate) -> Dict[str, Any]: sky_task_config = parse_sky_resources(task_template) - sky_task_config.update({ - # build setup commands - "setup": SetupCommand(task_template).full_setup, - # build run commands - "run": RunCommand(task_template).full_task_command - }) - return sky_task_config \ No newline at end of file + sky_task_config.update( + { + # build setup commands + "setup": SetupCommand(task_template).full_setup, + # build run commands + "run": RunCommand(task_template).full_task_command, + } + ) + return sky_task_config diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/utils.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/utils.py index 2374096ef2..c66ed4fa08 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/utils.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/utils.py @@ -1,28 +1,22 @@ -from asyncio.subprocess import PIPE -from decimal import ROUND_CEILING, Decimal -from typing import Optional, Tuple, Any, Dict, Union import asyncio -from flyteidl.core.execution_pb2 import TaskExecution -from typing import List -from flytekit import logger -import flytekit import enum -from flytekit.core.resources import Resources -from flytekit.tools.fast_registration import download_distribution as _download_distribution -from flytekitplugins.skypilot.metadata import SkyPilotMetadata, JobLaunchType -from flytekitplugins.skypilot.cloud_registry import CloudRegistry, CloudCredentialError, CloudNotInstalledError -import sky -import pdb -import pathlib -from datetime import datetime import os -import rich_click as _click -from dataclasses import dataclass, field, asdict -from google.protobuf.struct_pb2 import Struct -from google.protobuf.json_format import MessageToDict +import shutil import tarfile import tempfile -import shutil +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple, Union + +import rich_click as _click +import sky +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.skypilot.cloud_registry import CloudCredentialError, CloudNotInstalledError, CloudRegistry +from flytekitplugins.skypilot.metadata import JobLaunchType, SkyPilotMetadata +from google.protobuf.json_format import MessageToDict +from google.protobuf.struct_pb2 import Struct + +from flytekit import logger from flytekit.core.data_persistence import FileAccessProvider SKYPILOT_STATUS_TO_FLYTE_PHASE = { @@ -33,7 +27,6 @@ "SUCCEEDED": TaskExecution.SUCCEEDED, "FAILED": TaskExecution.FAILED, "FAILED_SETUP": TaskExecution.FAILED, - "CANCELLED": TaskExecution.FAILED, "STARTING": TaskExecution.INITIALIZING, "RECOVERING": TaskExecution.WAITING_FOR_RESOURCES, "CANCELLING": TaskExecution.WAITING_FOR_RESOURCES, @@ -58,6 +51,7 @@ def skypilot_status_to_flyte_phase(status: sky.JobStatus) -> TaskExecution.Phase """ return SKYPILOT_STATUS_TO_FLYTE_PHASE[status.value] + # use these commands from entrypoint to help resolve the task_template.container.args @_click.group() def _pass_through(): @@ -111,9 +105,8 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec if arg == "--resolver": cmd.extend(["--dynamic-addl-distro", additional_distribution, "--dynamic-dest-dir", dest_dir]) cmd.append(arg) - + return cmd - @_pass_through.command("pyflyte-map-execute") @@ -154,6 +147,7 @@ def map_execute_task_cmd( map_execute_task_cmd.name: map_execute_task_cmd, } + def execute_cmd_to_path(cmd: List[str]) -> Dict[str, Any]: assert len(cmd) > 0 args = {} @@ -164,39 +158,41 @@ def execute_cmd_to_path(cmd: List[str]) -> Dict[str, Any]: if cmd_entrypoint.name == fast_execute_task_cmd.name: args = {} pyflyte_args = fast_execute_task_cmd.invoke(ctx) - pyflyte_ctx = ENTRYPOINT_MAP[pyflyte_args[0]].make_context( - info_name="", - args=list(pyflyte_args)[1:] - ) + pyflyte_ctx = ENTRYPOINT_MAP[pyflyte_args[0]].make_context(info_name="", args=list(pyflyte_args)[1:]) args.update(pyflyte_ctx.params) # args["full-command"] = pyflyte_args break - + # raise error if args is empty or cannot find raw_output_data_prefix if not args or args.get("raw_output_data_prefix", None) is None: raise ValueError(f"Bad command for {cmd}") return args - - + + class RemoteDeletedError(ValueError): """ This is the base error for cloud credential errors. """ + pass + @dataclass class TaskFutureStatus(object): """ This is the status for the task future. """ + job_type: int cluster_status: str = sky.ClusterStatus.INIT.value task_status: str = sky.JobStatus.PENDING.value + class TaskStatus(int, enum.Enum): INIT = 0 DONE = 1 + class EventHandler(object): def __init__(self) -> None: self.cancel_event = asyncio.Event() @@ -205,21 +201,22 @@ def __init__(self) -> None: self.finished_event = asyncio.Event() def is_terminal(self): - return self.cancel_event.is_set()\ - or self.failed_event.is_set()\ - or self.finished_event.is_set() - + return self.cancel_event.is_set() or self.failed_event.is_set() or self.finished_event.is_set() + def __repr__(self) -> str: - return f"EventHandler(cancel_event={self.cancel_event.is_set()},"\ - f"failed_event={self.failed_event.is_set()},"\ - f", launch_done_event={self.launch_done_event.is_set()},"\ - f"finished_event={self.finished_event.is_set()})"\ + return ( + f"EventHandler(cancel_event={self.cancel_event.is_set()}," + f"failed_event={self.failed_event.is_set()}," + f", launch_done_event={self.launch_done_event.is_set()}," + f"finished_event={self.finished_event.is_set()})" + ) @dataclass class BaseRemotePathSetting: file_access: FileAccessProvider remote_sky_dir: str = field(init=False) + def __post_init__(self): self.remote_sky_dir = self.file_access.join(self.file_access.raw_output_prefix, ".skys") self.file_access.raw_output_fs.makedirs(self.remote_sky_dir, exist_ok=True) @@ -229,34 +226,35 @@ def remote_exists(self): def remote_failed(self): raise NotImplementedError - + def find_delete_proto(self) -> Dict[str, Any]: raise NotImplementedError - + def touch_task(self, task_id: str): raise NotImplementedError - + + @dataclass class TrackerRemotePathSetting(BaseRemotePathSetting): # raw_output_prefix: s3://{bucket}/{SKY_DIRNAME}/{hostname} unique_id: str _HOME_SKY = "home_sky.tar.gz" _KEY_SKY = "sky_key.tar.gz" + def __post_init__(self): super().__post_init__() self.remote_sky_zip = self.file_access.join(self.remote_sky_dir, self._HOME_SKY) self.remote_key_zip = self.file_access.join(self.remote_sky_dir, self._KEY_SKY) def remote_exists(self): - return self.file_access.exists(self.remote_sky_zip)\ - or self.file_access.exists(self.remote_key_zip) - + return self.file_access.exists(self.remote_sky_zip) or self.file_access.exists(self.remote_key_zip) + def delete_task(self, task_id: str): self.file_access.raw_output_fs.rm_file(self.file_access.join(self.remote_sky_dir, task_id)) - + def touch_task(self, task_id: str): self.file_access.raw_output_fs.touch(self.file_access.join(self.remote_sky_dir, task_id)) - + def task_exists(self, task_id: str): file_system = self.file_access.raw_output_fs remote_sky_parts = self.remote_sky_dir.rstrip(file_system.sep).split(file_system.sep) @@ -268,8 +266,9 @@ def task_exists(self, task_id: str): for task in registered_tasks: if task_id in task: return True - return False + + @dataclass class TaskRemotePathSetting(BaseRemotePathSetting): # raw_output_prefix: s3://{bucket}/data/.../{id} @@ -278,17 +277,12 @@ class TaskRemotePathSetting(BaseRemotePathSetting): unique_id: str = None task_id: int = None task_name: str = None - - + def from_task_prefix(self, task_prefix: str): task_name = f"{task_prefix}.{self.unique_id}" return TaskRemotePathSetting( - file_access=self.file_access, - job_type=self.job_type, - cluster_name=self.cluster_name, - task_name=task_name + file_access=self.file_access, job_type=self.job_type, cluster_name=self.cluster_name, task_name=task_name ) - def __post_init__(self): super().__post_init__() @@ -300,7 +294,6 @@ def __post_init__(self): # get last part of the path as unique_id self.unique_id = self.file_access.raw_output_prefix.rstrip(sep).split(sep)[-1] logger.warning(self) - def remote_failed(self): return self.file_access.exists(self.remote_error_log) @@ -310,16 +303,16 @@ def find_delete_proto(self) -> SkyPilotMetadata: return None temp_proto = self.file_access.get_random_local_path() self.file_access.get_data(self.remote_delete_proto, temp_proto) - with open(temp_proto, 'rb') as f: + with open(temp_proto, "rb") as f: meta_proto = Struct() meta_proto.ParseFromString(f.read()) - + os.remove(temp_proto) return SkyPilotMetadata(**MessageToDict(meta_proto)) - + def put_error_log(self, error_log: Exception): # open file for write and read - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as log_file: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as log_file: log_file.write(str(error_log)) log_file.flush() log_file.seek(0) @@ -329,7 +322,7 @@ def delete_job(self): status = self.get_task_status() job_status = LAUNCH_TYPE_TO_SKY_STATUS[status.job_type](status.task_status) if not self.task_id: - # task hasn't been submitted, try query again + # task hasn't been submitted, try query again (this rarely happens) try: loop = asyncio.get_event_loop() job_list = asyncio.run_coroutine_threadsafe(self.get_job_list(), loop).result() @@ -342,16 +335,18 @@ def delete_job(self): if not job_status.is_terminal(): try: if self.job_type == JobLaunchType.NORMAL: - sky.cancel(cluster_name=self.cluster_name, job_ids=[self.task_id], _try_cancel_if_cluster_is_init=True) + sky.cancel( + cluster_name=self.cluster_name, job_ids=[self.task_id], _try_cancel_if_cluster_is_init=True + ) else: sky.jobs.cancel(name=self.task_name) except sky.exceptions.ClusterNotUpError: pass - + def to_proto_and_upload(self, target, remote_path): proto_struct = Struct() proto_struct.update(asdict(target)) - with tempfile.NamedTemporaryFile(mode='wb') as f: + with tempfile.NamedTemporaryFile(mode="wb") as f: f.write(proto_struct.SerializeToString()) f.flush() f.seek(0) @@ -360,7 +355,7 @@ def to_proto_and_upload(self, target, remote_path): def get_task_status(self): temp_proto = self.file_access.get_random_local_path() self.file_access.get_data(self.remote_status_proto, temp_proto) - with open(temp_proto, 'rb') as f: + with open(temp_proto, "rb") as f: status_proto = Struct() status_proto.ParseFromString(f.read()) os.remove(temp_proto) @@ -368,7 +363,7 @@ def get_task_status(self): if self.remote_failed(): local_log_file = self.file_access.get_random_local_path() self.file_access.get_data(self.remote_error_log, local_log_file) - with open(local_log_file, 'r') as f: + with open(local_log_file, "r") as f: logger.error(f.read()) os.remove(local_log_file) return TaskFutureStatus(status.job_type, status.cluster_status, sky.JobStatus.FAILED.value) @@ -388,9 +383,9 @@ def get_status_from_list(self, job_list) -> Tuple[sky.JobStatus, Optional[int]]: async def put_task_status(self, event_handler: EventHandler): init_status = TaskFutureStatus(job_type=self.job_type) self.to_proto_and_upload(init_status, self.remote_status_proto) - # pdb.set_trace() # FIXME too long while True: + # other coroutines cancelled / failed if event_handler.is_terminal(): self.handle_return(event_handler) return @@ -398,9 +393,11 @@ async def put_task_status(self, event_handler: EventHandler): prev_status = self.get_task_status() status = prev_status logger.warning(event_handler) + # task finished if LAUNCH_TYPE_TO_SKY_STATUS[self.job_type](prev_status.task_status).is_terminal(): event_handler.finished_event.set() return + # task submitted, can get task_id now if event_handler.launch_done_event.is_set(): job_list = await self.get_job_list() current_status, task_id = self.get_status_from_list(job_list) @@ -409,10 +406,10 @@ async def put_task_status(self, event_handler: EventHandler): status = TaskFutureStatus( job_type=self.job_type, cluster_status=prev_status.cluster_status, - task_status=current_status or prev_status.task_status + task_status=current_status or prev_status.task_status, ) logger.warning(status) - + # task not submitted, query cluster status else: cluster_status = sky.status(self.cluster_name) if not cluster_status: @@ -420,23 +417,25 @@ async def put_task_status(self, event_handler: EventHandler): cluster_status = [{"status": sky.ClusterStatus(init_status.cluster_status)}] status = TaskFutureStatus( job_type=self.job_type, - cluster_status=cluster_status[0]['status'].value, - task_status=prev_status.task_status + cluster_status=cluster_status[0]["status"].value, + task_status=prev_status.task_status, ) - + self.task_id = task_id or self.task_id self.to_proto_and_upload(status, self.remote_status_proto) await asyncio.sleep(COROUTINE_INTERVAL) - + def handle_return(self, event_handler: EventHandler): + # task submitted, need to cancel if event_handler.launch_done_event.is_set(): self.delete_job() + # task failed, put the failed status if event_handler.failed_event.is_set(): task_status = self.get_task_status() task_status.task_status = LAUNCH_TYPE_TO_SKY_STATUS[task_status.job_type]("FAILED").value self.to_proto_and_upload(task_status, self.remote_status_proto) return - + async def deletion_status(self, event_handler: EventHandler): # checks if the task has been deleted from another pod while True: @@ -449,6 +448,7 @@ async def deletion_status(self, event_handler: EventHandler): return await asyncio.sleep(COROUTINE_INTERVAL) + @dataclass class LocalPathSetting: file_access: FileAccessProvider @@ -458,6 +458,7 @@ class LocalPathSetting: sky_key_zip: str = None home_sky_dir: str = None home_key_dir: str = None + def __post_init__(self): self.local_sky_prefix = os.path.join(self.file_access.local_sandbox_dir, self.execution_id, ".skys") self.home_sky_zip = os.path.join(self.local_sky_prefix, "home_sky") @@ -465,18 +466,18 @@ def __post_init__(self): self.home_sky_dir = os.path.join(self.local_sky_prefix, ".sky") self.home_key_dir = os.path.join(self.local_sky_prefix, ".ssh") - def zip_sky_info(self): # compress ~/.sky to home_sky.tar.gz - sky_zip = shutil.make_archive(self.home_sky_zip, 'gztar', os.path.expanduser("~/.sky")) + sky_zip = shutil.make_archive(self.home_sky_zip, "gztar", os.path.expanduser("~/.sky")) # tar ~/.ssh/sky-key* to sky_key.tar.gz local_key_dir = os.path.expanduser("~/.ssh") archived_keys = [file for file in os.listdir(local_key_dir) if file.startswith("sky-key")] - with tarfile.open(self.sky_key_zip, 'w:gz') as key_tar: + with tarfile.open(self.sky_key_zip, "w:gz") as key_tar: for key_file in archived_keys: key_tar.add(os.path.join(local_key_dir, key_file), arcname=os.path.basename(key_file)) return sky_zip + @dataclass class SkyPathSetting: task_level_prefix: str # for the filesystem parsing @@ -485,15 +486,18 @@ class SkyPathSetting: local_path_setting: LocalPathSetting = None remote_path_setting: TrackerRemotePathSetting = None file_access: FileAccessProvider = None + def __post_init__(self): file_provider = FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=self.task_level_prefix) - bucket_name = file_provider.raw_output_fs._strip_protocol(file_provider.raw_output_prefix)\ - .split(file_provider.raw_output_fs.sep)[0] + bucket_name = file_provider.raw_output_fs._strip_protocol(file_provider.raw_output_prefix).split( + file_provider.raw_output_fs.sep + )[0] + # s3://{bucket}/{SKY_DIRNAME}/{host_id} working_dir = file_provider.join( - file_provider.raw_output_fs.unstrip_protocol(bucket_name), - SKY_DIRNAME, + file_provider.raw_output_fs.unstrip_protocol(bucket_name), + SKY_DIRNAME, ) - + self.working_dir = file_provider.join(working_dir, self.unique_id) self.file_access = FileAccessProvider(local_sandbox_dir="/tmp", raw_output_prefix=self.working_dir) self.file_access.raw_output_fs.makedirs(self.working_dir, exist_ok=True) @@ -510,25 +514,22 @@ async def zip_and_upload(self): self.file_access.put_data(self.local_path_setting.sky_key_zip, self.remote_path_setting.remote_key_zip) await asyncio.sleep(COROUTINE_INTERVAL) - def download_and_unzip(self): local_sky_zip = os.path.join( - os.path.dirname(self.local_path_setting.home_sky_zip), - self.remote_path_setting._HOME_SKY + os.path.dirname(self.local_path_setting.home_sky_zip), self.remote_path_setting._HOME_SKY ) self.file_access.get_data(self.remote_path_setting.remote_sky_zip, local_sky_zip) self.file_access.get_data(self.remote_path_setting.remote_key_zip, self.local_path_setting.sky_key_zip) shutil.unpack_archive(local_sky_zip, self.local_path_setting.home_sky_dir) - with tarfile.open(self.local_path_setting.sky_key_zip, 'r:gz') as key_tar: + with tarfile.open(self.local_path_setting.sky_key_zip, "r:gz") as key_tar: key_tar.extractall(self.local_path_setting.home_key_dir) return - - + def last_upload_time(self) -> Optional[datetime]: if self.file_access.exists(self.remote_path_setting.remote_sky_zip): self.file_access.raw_output_fs.ls(self.remote_path_setting.remote_sky_zip, refresh=True) - return self.file_access.raw_output_fs.info(self.remote_path_setting.remote_sky_zip)['LastModified'] - + return self.file_access.raw_output_fs.modified(self.remote_path_setting.remote_sky_zip) + def setup_cloud_credential(show_check: bool = False): cloud_provider_types = CloudRegistry.list_clouds() @@ -540,16 +541,14 @@ def setup_cloud_credential(show_check: bool = False): try: provider.setup_cloud_credential() installed_cloud_providers.append(provider._CLOUD_TYPE) - except CloudCredentialError as e: + except CloudCredentialError: cred_not_provided_clouds.append(provider._CLOUD_TYPE) continue - except CloudNotInstalledError as e: + except CloudNotInstalledError: continue - + logger.warning(f"Installed cloud providers: {installed_cloud_providers}") if show_check: sky.check.check() return - - diff --git a/plugins/flytekit-skypilot/flytekitplugins/skypilot/workflows.py b/plugins/flytekit-skypilot/flytekitplugins/skypilot/workflows.py index f198276fd9..fef0f60fa5 100644 --- a/plugins/flytekit-skypilot/flytekitplugins/skypilot/workflows.py +++ b/plugins/flytekit-skypilot/flytekitplugins/skypilot/workflows.py @@ -1,29 +1,17 @@ -from typing import Any, Dict, Optional, Union, Callable, List, Set -from dataclasses import dataclass, asdict import os -from typing import Any, Callable, Dict, Optional, Union, cast - -from google.protobuf.json_format import MessageToDict - -from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger, Workflow, workflow, task -from flytekit.configuration import DefaultImages, SerializationSettings -from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import ExecutionParameters -from flytekit.core.python_auto_container import get_registerable_container_image -from flytekit.extend import ExecutionState, TaskPlugins -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin -from flytekit.image_spec import ImageSpec -from flytekit.types.file import FlyteFile -from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir -import flytekit +from typing import Callable, Dict, List + import sky -from sky import resources as resources_lib -from flytekit.models.literals import LiteralMap from flytekitplugins.skypilot.task import SkyPilot -from flytekitplugins.skypilot.utils import parse_sky_resources, setup_cloud_credential, LocalPathSetting -import tempfile +from flytekitplugins.skypilot.utils import LocalPathSetting, setup_cloud_credential +from sky import resources as resources_lib + +import flytekit +from flytekit import FlyteContextManager, PythonFunctionTask, logger, task +from flytekit.types.file import FlyteFile + -def sky_config_to_resource(sky_config: SkyPilot, container_image: str=None) -> resources_lib.Resources: +def sky_config_to_resource(sky_config: SkyPilot, container_image: str = None) -> resources_lib.Resources: resources: List[Dict[str, str]] = sky_config.resource_config new_resource_list = [] for resource in resources: @@ -41,7 +29,7 @@ def sky_config_to_resource(sky_config: SkyPilot, container_image: str=None) -> r logger.info(resource) new_resource = sky.resources.Resources(**resource) new_resource_list.append(new_resource) - + if not new_resource_list: new_resource_list.append(sky.resources.Resources(image_id=f"docker:{container_image}")) return new_resource_list @@ -56,30 +44,27 @@ def clean_up(user_hash: str, controller_name: str) -> None: sky.jobs.utils.JOB_CONTROLLER_NAME = controller_name sky.down(controller_name) + def create_cluster(user_hash: str, cluster_name: str) -> tuple[FlyteFile, FlyteFile, str]: # get job controller file config_url = os.environ.get("SKYPILOT_CONFIG_URL", None) if config_url: download_and_set_sky_config(config_url) - + setup_cloud_credential() - sample_task_config = { - 'resources': { - 'cpu': '1', - 'memory': '1', - 'use_spot': True - } - } + sample_task_config = {"resources": {"cpu": "1", "memory": "1", "use_spot": True}} sample_task = sky.Task.from_yaml_config(sample_task_config) sky.utils.common_utils.get_user_hash = lambda: user_hash sky.jobs.utils.JOB_CONTROLLER_NAME = cluster_name sky.jobs.launch(sample_task) - path_setting = LocalPathSetting(file_access=FlyteContextManager.current_context().file_access, execution_id=flytekit.current_context().task_id.version) + path_setting = LocalPathSetting( + file_access=FlyteContextManager.current_context().file_access, + execution_id=flytekit.current_context().task_id.version, + ) path_setting.zip_sky_info() return FlyteFile(path_setting.home_sky_zip), FlyteFile(path_setting.sky_key_zip), cluster_name - # TODO: Trying to separate tasks, but I don't think this would be any better given skypilot's slow api. def load_sky_config(): secret_manager = flytekit.current_context().secrets @@ -88,10 +73,10 @@ def load_sky_config(): group="sky", key="config", ) - except ValueError as e: - logger.warning(f"sky config not set, will use default controller setting") + except ValueError: + logger.warning("sky config not set, will use default controller setting") return - + download_and_set_sky_config(config_url) @@ -101,16 +86,11 @@ def download_and_set_sky_config(config_url: str): file_access.get_data(config_url, os.path.expanduser(sky.skypilot_config.CONFIG_PATH)) sky.skypilot_config._try_load_config() + # write a decorator for function, the decorator must be able to take in the task_config and return a new function -def sky_pilot_task( - task_config: SkyPilot, - **kwargs -) -> Callable: +def sky_pilot_task(task_config: SkyPilot, **kwargs) -> Callable: def wrapper(func: Callable) -> Callable: - create_cluster_func = task( - create_cluster, - **kwargs - ) - pass - - return wrapper \ No newline at end of file + create_cluster_func = task(create_cluster, **kwargs) + assert isinstance(create_cluster_func, PythonFunctionTask) + + return wrapper diff --git a/plugins/flytekit-skypilot/setup.py b/plugins/flytekit-skypilot/setup.py index 7278321ec8..707ddc03dd 100644 --- a/plugins/flytekit-skypilot/setup.py +++ b/plugins/flytekit-skypilot/setup.py @@ -33,4 +33,4 @@ "Topic :: Software Development :: Libraries :: Python Modules", ], entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, -) \ No newline at end of file +) diff --git a/plugins/flytekit-skypilot/tests/test_skypilot.py b/plugins/flytekit-skypilot/tests/test_skypilot.py index e34ea2c5c8..0da5ee4dbb 100644 --- a/plugins/flytekit-skypilot/tests/test_skypilot.py +++ b/plugins/flytekit-skypilot/tests/test_skypilot.py @@ -1,50 +1,38 @@ -import mock import asyncio -import pytest -import sky -import textwrap -import fsspec -import grpc import functools -import pdb -from flyteidl.core.execution_pb2 import TaskExecution -import time +import os import shutil -import multiprocessing +import tempfile +import textwrap +import time from collections import OrderedDict -from flytekit import Resources, task -import flytekit + +import grpc +import mock +import pytest +import sky +from flyteidl.core.execution_pb2 import TaskExecution +from flytekitplugins.skypilot import SkyPilot, SkyPilotAgent +from flytekitplugins.skypilot.agent import SkyTaskFuture, SkyTaskTracker +from flytekitplugins.skypilot.task_utils import ContainerRunType, get_sky_task_config from flytekitplugins.skypilot.utils import COROUTINE_INTERVAL, SkyPathSetting -from flytekitplugins.skypilot import SkyPilotAgent, SkyPilot -from flytekit.core.data_persistence import FileAccessProvider -from flytekitplugins.skypilot.agent import SkyTaskTracker, SkyTaskFuture -from flytekit.extend.backend.base_agent import AgentRegistry -from flytekit.core.data_persistence import FileAccessProvider -from flytekitplugins.skypilot.task_utils import ContainerRunType, get_sky_task_config, RunCommand, SetupCommand + +from flytekit import Resources, task from flytekit.configuration import DefaultImages, ImageConfig, SerializationSettings +from flytekit.core.data_persistence import FileAccessProvider from flytekit.extend import get_serializable -import os -import tempfile +from flytekit.extend.backend.base_agent import AgentRegistry -@pytest.mark.parametrize( - "container_run_type", - [(ContainerRunType.APP), - (ContainerRunType.RUNTIME)] -) +@pytest.mark.parametrize("container_run_type", [(ContainerRunType.APP), (ContainerRunType.RUNTIME)]) def test_skypilot_task(container_run_type): from flytekitplugins.skypilot import SkyPilot, SkyPilotFunctionTask task_config = SkyPilot( cluster_name="mock_cluster", setup="echo 'Hello, World!'", - resource_config={ - "ordered": [ - {"instance_type": "e2-small"}, - {"cloud": "aws"} - ] - }, - container_run_type=container_run_type + resource_config={"ordered": [{"instance_type": "e2-small"}, {"cloud": "aws"}]}, + container_run_type=container_run_type, ) requests = Resources(cpu="2", mem="4Gi") limits = Resources(cpu="4") @@ -69,37 +57,49 @@ def say_hello(name: str) -> str: task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello) template = task_spec.template container = template.container - + config = get_sky_task_config(template) sky_task = sky.Task.from_yaml_config(config) if container_run_type == ContainerRunType.APP: - assert sky_task.setup == textwrap.dedent(f"""\ + assert ( + sky_task.setup + == textwrap.dedent( + f"""\ {task_config.setup} docker pull {container.image} - """).strip() - assert sky_task.run.startswith(f"docker run") - + """ + ).strip() + ) + assert sky_task.run.startswith("docker run") + else: - assert sky_task.setup == textwrap.dedent(f"""\ + assert ( + sky_task.setup + == textwrap.dedent( + f"""\ {task_config.setup} python -m pip uninstall flytekit -y python -m pip install -e /flytekit - """).strip() + """ + ).strip() + ) assert sky_task.run.startswith("export PYTHONPATH") assert container.args[0] in sky_task.run + def check_task_all_done(task: SkyTaskFuture): assert task._launch_coro.done() assert task._status_upload_coro.done() assert task._status_check_coro.done() + def check_task_not_done(task: SkyTaskFuture): assert not task._launch_coro.done() assert not task._status_upload_coro.done() assert not task._status_check_coro.done() - + def mock_launch(obj, sleep_time=5): time.sleep(sleep_time) @@ -111,11 +111,13 @@ async def stop_sky_path(): except asyncio.exceptions.CancelledError: pass + def mock_sky_queues(): - sky.status = mock.MagicMock(return_value=[{'status': sky.ClusterStatus.INIT}]) + sky.status = mock.MagicMock(return_value=[{"status": sky.ClusterStatus.INIT}]) sky.stop = mock.MagicMock() sky.down = mock.MagicMock() + @pytest.fixture def mock_fs(): random_dir = tempfile.mkdtemp() @@ -130,14 +132,10 @@ def mock_fs(): task_config = SkyPilot( cluster_name="mock_cluster", setup="echo 'Hello, World!'", - resource_config={ - "image_id": "a/b:c", - "ordered": [ - {"instance_type": "e2-small"}, - {"cloud": "aws"} - ] - }, + resource_config={"image_id": "a/b:c", "ordered": [{"instance_type": "e2-small"}, {"cloud": "aws"}]}, ) + + @task( task_config=task_config, container_image=DefaultImages.default_image(), @@ -174,14 +172,19 @@ def get_container_args(mock_fs): "say_hello0", ] + @pytest.fixture def mock_provider(mock_fs): with mock.patch( - "flytekitplugins.skypilot.agent.SkyPathSetting", autospec=True, return_value=SkyPathSetting(task_level_prefix=str(mock_fs.local_sandbox_dir), unique_id="sky_mock") + "flytekitplugins.skypilot.agent.SkyPathSetting", + autospec=True, + return_value=SkyPathSetting(task_level_prefix=str(mock_fs.local_sandbox_dir), unique_id="sky_mock"), ) as mock_path, mock.patch( "flytekitplugins.skypilot.agent.setup_cloud_credential", autospec=True ) as cloud_setup, mock.patch( - "flytekitplugins.skypilot.agent.SkyTaskFuture.launch", autospec=True, side_effect=functools.partial(mock_launch, sleep_time=5) + "flytekitplugins.skypilot.agent.SkyTaskFuture.launch", + autospec=True, + side_effect=functools.partial(mock_launch, sleep_time=5), ) as mock_provider: yield (mock_path, cloud_setup, mock_provider) sky_path_fs = SkyTaskTracker._sky_path_setting.file_access.raw_output_fs @@ -189,34 +192,26 @@ def mock_provider(mock_fs): SkyTaskTracker._JOB_RESIGTRY.clear() - @pytest.mark.asyncio async def test_async_agent(mock_provider, mock_fs): - ( - mock_path, - cloud_setup, - skypath_mock - ) = mock_provider + (mock_path, cloud_setup, skypath_mock) = mock_provider # mock_provider.return_value = mock.MagicMock() serialization_settings = SerializationSettings(image_config=ImageConfig()) context = mock.MagicMock(spec=grpc.ServicerContext) mock_sky_queues() - + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello0) agent = AgentRegistry.get_agent(task_spec.template.type) assert isinstance(agent, SkyPilotAgent) task_spec.template.container._args = get_container_args(mock_fs) - create_task_response = await agent.create( - context=context, - task_template=task_spec.template - ) + create_task_response = await agent.create(context=context, task_template=task_spec.template) resource_meta = create_task_response await asyncio.sleep(0) remote_task = SkyTaskTracker._JOB_RESIGTRY.get(resource_meta.job_name) check_task_not_done(remote_task) - get_task_response = await (agent.get(context=context, resource_meta=resource_meta)) + get_task_response = await agent.get(context=context, resource_meta=resource_meta) phase = get_task_response.phase await asyncio.sleep(0) assert phase == TaskExecution.INITIALIZING @@ -226,14 +221,9 @@ async def test_async_agent(mock_provider, mock_fs): await stop_sky_path() - @pytest.mark.asyncio async def test_agent_coro_failed(mock_provider, mock_fs): - ( - mock_path, - cloud_setup, - skypath_mock - ) = mock_provider + (mock_path, cloud_setup, skypath_mock) = mock_provider serialization_settings = SerializationSettings(image_config=ImageConfig()) context = mock.MagicMock(spec=grpc.ServicerContext) @@ -242,14 +232,11 @@ async def test_agent_coro_failed(mock_provider, mock_fs): agent = AgentRegistry.get_agent(task_spec.template.type) assert isinstance(agent, SkyPilotAgent) task_spec.template.container._args = get_container_args(mock_fs) - create_task_response = await agent.create( - context=context, - task_template=task_spec.template - ) + create_task_response = await agent.create(context=context, task_template=task_spec.template) resource_meta = create_task_response await asyncio.sleep(0) remote_task = SkyTaskTracker._JOB_RESIGTRY.get(resource_meta.job_name) - get_task_response = await (agent.get(context=context, resource_meta=resource_meta)) + get_task_response = await agent.get(context=context, resource_meta=resource_meta) phase = get_task_response.phase await asyncio.sleep(0) # causing task to fail diff --git a/plugins/flytekit-skypilot/tests/test_utils.py b/plugins/flytekit-skypilot/tests/test_utils.py index 8475ec81f7..192940fb4c 100644 --- a/plugins/flytekit-skypilot/tests/test_utils.py +++ b/plugins/flytekit-skypilot/tests/test_utils.py @@ -1,9 +1,10 @@ -import unittest from flytekitplugins.skypilot import utils + def test_execute_cmd_to_path(): random_dir = "/tmp/abc" - args = ["pyflyte-fast-execute", + args = [ + "pyflyte-fast-execute", "--additional-distribution", "{{ .remote_package_path }}", "--dest-dir", @@ -28,6 +29,6 @@ def test_execute_cmd_to_path(): "task-name", "say_hello0", ] - + new_args = utils.execute_cmd_to_path(args) - assert new_args["raw_output_data_prefix"] == random_dir \ No newline at end of file + assert new_args["raw_output_data_prefix"] == random_dir