Skip to content

Commit

Permalink
feat:add retry config & ignore exit code setting (#361)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
xiaoyeqiannian and pre-commit-ci[bot] committed Aug 23, 2023
1 parent 4408389 commit f8470e8
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 18 deletions.
27 changes: 24 additions & 3 deletions dpdispatcher/dp_cloud_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(self, context):
phone = context.remote_profile.get("phone", None)
username = context.remote_profile.get("username", None)
password = context.remote_profile.get("password", None)
self.retry_count = context.remote_profile.get("retry_count", 3)
self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True)

ticket = os.environ.get("BOHR_TICKET", None)
if ticket:
Expand Down Expand Up @@ -110,7 +112,6 @@ def do_submit(self, job):
# oss_task_zip = 'indicate/' + job.job_hash + '/' + zip_filename
oss_task_zip = self._gen_oss_path(job, zip_filename)
job_resources = ALI_OSS_BUCKET_URL + oss_task_zip

input_data = self.input_data.copy()

if not input_data.get("job_resources"):
Expand Down Expand Up @@ -187,7 +188,9 @@ def check_status(self, job):
f"cannot find job information in bohrium for job {job.job_id} {check_return} {retry_return}"
)

job_state = self.map_dp_job_state(dp_job_status)
job_state = self.map_dp_job_state(
dp_job_status, check_return.get("exitCode", 0), self.ignore_exit_code
)
if job_state == JobStatus.finished:
job_log = self.api.get_log(job_id)
if self.input_data.get("output_log"):
Expand Down Expand Up @@ -232,7 +235,7 @@ def check_if_recover(self, submission):
# pass

@staticmethod
def map_dp_job_state(status):
def map_dp_job_state(status, exit_code, ignore_exit_code=True):
if isinstance(status, JobStatus):
return status
map_dict = {
Expand All @@ -244,10 +247,13 @@ def map_dp_job_state(status):
4: JobStatus.running,
5: JobStatus.terminated,
6: JobStatus.running,
9: JobStatus.waiting,
}
if status not in map_dict:
dlog.error(f"unknown job status {status}")
return JobStatus.unknown
if status == -1 and exit_code != 0 and ignore_exit_code:
return JobStatus.finished
return map_dict[status]

def kill(self, job):
Expand All @@ -261,6 +267,21 @@ def kill(self, job):
job_id = job.job_id
self.api.kill(job_id)

def get_exit_code(self, job) -> int:
job_id = self._parse_job_id(job.job_id)
if job_id <= 0:
raise RuntimeError(f"cannot parse job id {job.job_id}")

check_return = self._get_job_detail(job_id, self.group_id)
return check_return.get("exitCode", -999) # type: ignore

def _parse_job_id(self, str_job_id: str) -> int:
job_id = 0
if "job_group_id" in str_job_id:
ids = str_job_id.split(":job_group_id:")
job_id, _ = int(ids[0]), int(ids[1])
return job_id


DpCloudServer = Bohrium
Lebesgue = Bohrium
24 changes: 21 additions & 3 deletions dpdispatcher/dp_cloud_server_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from dpdispatcher import dlog
from dpdispatcher.base_context import BaseContext
from dpdispatcher.dpcloudserver.config import ALI_STS_BUCKET_NAME, ALI_STS_ENDPOINT

# from dpdispatcher.submission import Machine
# from . import dlog
Expand All @@ -20,8 +21,6 @@
DP_CLOUD_SERVER_HOME_DIR = os.path.join(
os.path.expanduser("~"), ".dpdispatcher/", "dp_cloud_server/"
)
ENDPOINT = "http://oss-cn-shenzhen.aliyuncs.com"
BUCKET_NAME = os.environ.get("BUCKET_NAME", "dpcloudserver")


class BohriumContext(BaseContext):
Expand Down Expand Up @@ -124,7 +123,9 @@ def upload_job(self, job, common_files=None):
upload_zip = zip_file.zip_file_list(
self.local_root, zip_task_file, file_list=upload_file_list
)
result = self.api.upload(oss_task_zip, upload_zip, ENDPOINT, BUCKET_NAME)
result = self.api.upload(
oss_task_zip, upload_zip, ALI_STS_ENDPOINT, ALI_STS_BUCKET_NAME
)
retry_count = 0
self._backup(self.local_root, upload_zip)

Expand Down Expand Up @@ -285,6 +286,9 @@ def machine_subfields(cls) -> List[Argument]:
doc_remote_profile = (
"The information used to maintain the connection with remote machine."
)
doc_retry_count = "The retry count when a job is terminated"
doc_ignore_exit_code = """The job state will be marked as finished if the exit code is non-zero when set to True. Otherwise,
the job state will be designated as terminated."""
return [
Argument(
"remote_profile",
Expand All @@ -299,6 +303,20 @@ def machine_subfields(cls) -> List[Argument]:
alias=["project_id"],
doc="Program ID",
),
Argument(
"retry_count",
[int, type(None)],
optional=True,
default=3,
doc=doc_retry_count,
),
Argument(
"ignore_exit_code",
bool,
optional=True,
default=True,
doc=doc_ignore_exit_code,
),
Argument(
"keep_backup",
bool,
Expand Down
1 change: 0 additions & 1 deletion dpdispatcher/dpcloudserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,3 @@
"DPDISPATCHER_LEBESGUE_ALI_OSS_BUCKET_URL",
"https://dpcloudserver.oss-cn-shenzhen.aliyuncs.com/",
)
# ALI_OSS_BUCKET_URL = 'https://dpcloudserver.oss-cn-shenzhen.aliyuncs.com/
12 changes: 12 additions & 0 deletions dpdispatcher/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,15 @@ def kill(self, job):
job
"""
dlog.warning("Job %s should be manually killed" % job.job_id)

def get_exit_code(self, job):
"""Get exit code of the job.
Parameters
----------
job : Job
job
"""
raise NotImplementedError(
"abstract method get_exit_code should be implemented by derived class"
)
33 changes: 25 additions & 8 deletions dpdispatcher/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __init__(self, context):
self.remote_profile = context.remote_profile.copy()

self.grouped = self.remote_profile.get("grouped", True)
self.retry_count = self.remote_profile.get("retry_count", 3)
self.ignore_exit_code = context.remote_profile.get("ignore_exit_code", True)
self.client = Client()
self.job = Job(client=self.client)
self.storage = Storage(client=self.client)
Expand Down Expand Up @@ -80,8 +82,9 @@ def do_submit(self, job):
"out_files": self._gen_backward_files_list(job),
"platform": self.remote_profile.get("platform", "ali"),
"image_address": self.remote_profile.get("image_address", ""),
"job_id": job.job_id,
}
if job.job_state == JobStatus.unsubmitted:
openapi_params["job_id"] = job.job_id

data = self.job.insert(**openapi_params)

Expand Down Expand Up @@ -126,12 +129,13 @@ def check_status(self, job):
f"cannot find job information in bohrium for job {job.job_id} {check_return} {retry_return}"
)

job_state = self.map_dp_job_state(dp_job_status)
job_state = self.map_dp_job_state(
dp_job_status, check_return.get("exitCode", 0), self.ignore_exit_code # type: ignore
)
if job_state == JobStatus.finished:
job_log = self.job.log(job_id)
if self.remote_profile.get("output_log"):
print(job_log, end="")
# print(job.job_id)
self._download_job(job)
elif self.remote_profile.get("output_log") and job_state == JobStatus.running:
job_log = self.job.log(job_id)
Expand All @@ -140,7 +144,6 @@ def check_status(self, job):

def _download_job(self, job):
data = self.job.detail(job.job_id)
# print(data)
job_url = data["jobFiles"]["outFiles"][0]["url"] # type: ignore
if not job_url:
return
Expand Down Expand Up @@ -174,7 +177,7 @@ def check_if_recover(self, submission):
# pass

@staticmethod
def map_dp_job_state(status):
def map_dp_job_state(status, exit_code, ignore_exit_code=True):
if isinstance(status, JobStatus):
return status
map_dict = {
Expand All @@ -191,6 +194,8 @@ def map_dp_job_state(status):
if status not in map_dict:
dlog.error(f"unknown job status {status}")
return JobStatus.unknown
if status == -1 and exit_code != 0 and ignore_exit_code:
return JobStatus.finished
return map_dict[status]

def kill(self, job):
Expand All @@ -204,6 +209,18 @@ def kill(self, job):
job_id = job.job_id
self.job.kill(job_id)

# def check_finish_tag(self, job):
# job_tag_finished = job.job_hash + '_job_tag_finished'
# return self.context.check_file_exists(job_tag_finished)
def get_exit_code(self, job):
"""Get exit code of the job.
Parameters
----------
job : Job
job
Returns
-------
int
exit code
"""
check_return = self.job.detail(job.job_id)
return check_return.get("exitCode", -999) # type: ignore
1 change: 0 additions & 1 deletion dpdispatcher/ssh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def arginfo():
doc_look_for_keys = (
"enable searching for discoverable private key files in ~/.ssh/"
)

ssh_remote_profile_args = [
Argument("hostname", str, optional=False, doc=doc_hostname),
Argument("username", str, optional=False, doc=doc_username),
Expand Down
7 changes: 5 additions & 2 deletions dpdispatcher/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,6 @@ def __init__(
# self.job_work_base = job_work_base
self.resources = resources
self.machine = machine

self.job_state = None # JobStatus.unsubmitted
self.job_id = ""
self.fail_count = 0
Expand Down Expand Up @@ -839,7 +838,11 @@ def handle_unexpected_job_state(self):
f"job: {self.job_hash} {self.job_id} terminated;"
f"fail_cout is {self.fail_count}; resubmitting job"
)
if (self.fail_count) > 0 and (self.fail_count % 3 == 0):
retry_count = 3
assert self.machine is not None
if hasattr(self.machine, "retry_count") and self.machine.retry_count > 0:
retry_count = self.machine.retry_count
if (self.fail_count) > 0 and (self.fail_count % retry_count == 0):
raise RuntimeError(
f"job:{self.job_hash} {self.job_id} failed {self.fail_count} times.job_detail:{self}"
)
Expand Down

0 comments on commit f8470e8

Please sign in to comment.