Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcornel committed Dec 20, 2022
2 parents 6c760b0 + 7859a47 commit 096e66c
Show file tree
Hide file tree
Showing 80 changed files with 217 additions and 162 deletions.
63 changes: 59 additions & 4 deletions geti_sdk/rest_clients/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TaskConfiguration,
)
from geti_sdk.data_models.containers import AlgorithmList
from geti_sdk.data_models.enums import JobState
from geti_sdk.data_models.enums import JobState, JobType
from geti_sdk.data_models.project import Dataset
from geti_sdk.http_session import GetiRequestException, GetiSession
from geti_sdk.platform_versions import GETI_11_VERSION
Expand Down Expand Up @@ -143,7 +143,8 @@ def train_task(
enable_pot_optimization: bool = False,
hyper_parameters: Optional[TaskConfiguration] = None,
hpo_parameters: Optional[Dict[str, Any]] = None,
timeout: int = 300,
await_running_jobs: bool = True,
timeout: int = 3600,
) -> Job:
"""
Start training of a specific task in the project.
Expand All @@ -163,10 +164,16 @@ def train_task(
:param hyper_parameters: Optional hyper parameters to use for training
:param hpo_parameters: Optional set of parameters to use for automatic hyper
parameter optimization. Only supported for version 1.1 and up
:param await_running_jobs: True to wait for currently running jobs to
complete. This will guarantee that the training request can be submitted
successfully. Setting this to False will cause an error to be raised when
a training request is submitted for a task for which a training job is
already in progress.
:param timeout: Timeout (in seconds) to wait for the Job to be created. If a
training request is submitted successfully, a training job should be
instantiated on the Geti server. If the Job does not appear on the server
job list within the `timeout`, an error will be raised.
job list within the `timeout`, an error will be raised. This parameter only
takes effect when `await_running_jobs` is set to True.
:return: The training job that has been created
"""
if isinstance(task, int):
Expand Down Expand Up @@ -201,6 +208,32 @@ def train_task(
else:
data = {"training_parameters": [request_data]}

task_jobs = self.get_jobs_for_task(task=task)
task_training_jobs = [job for job in task_jobs if job.type == JobType.TRAIN]
log_warning_msg = ""
if len(task_training_jobs) >= 1:
if len(task_training_jobs) == 1:
msg_start = f"A training job for task '{task.title}' is"
else:
msg_start = f"Multiple training jobs for task '{task.title}' are"
log_warning_msg = msg_start + " already in progress on the server."

if len(task_training_jobs) >= 1:
if not await_running_jobs:
raise RuntimeError(
log_warning_msg + " Unable to submit training request. Please "
"wait for the current training job to finish or "
"set the `await_running_jobs` parameter in this "
"method to `True`."
)
else:
logging.info(
log_warning_msg + f" Awaiting completion of currently running jobs "
f"before a new train request can be submitted. "
f"Maximum waiting time set to {timeout} seconds."
)
self.monitor_jobs(jobs=task_training_jobs, timeout=timeout)

response = self.session.get_rest_response(
url=f"{self.base_url}/train", method="POST", data=data
)
Expand All @@ -216,7 +249,7 @@ def train_task(
logging.info(f"Training job with ID {job_id} submitted successfully.")
else:
t_start = time.time()
while job is None and (time.time() - t_start < timeout):
while job is None and (time.time() - t_start < 10):
logging.info(
f"Training request was submitted but the training job status could "
f"not be retrieved from the platform yet. Re-attempting to fetch "
Expand Down Expand Up @@ -291,3 +324,25 @@ def monitor_jobs(
f"Monitoring stopped after {t_elapsed:.1f} seconds due to timeout."
)
return jobs

def get_jobs_for_task(self, task: Task, running_only: bool = True) -> List[Job]:
"""
Return a list of current jobs for the task, if any
:param task: Task to retrieve the jobs for
:param running_only: True to return only jobs that are currently running,
False to return all jobs (including cancelled, finished or errored jobs)
:return: List of Jobs running on the server for this particular task
"""
project_jobs = self.get_jobs(project_only=True)
task_jobs: List[Job] = []
for job in project_jobs:
if job.metadata is not None:
if job.metadata.task is not None:
if job.metadata.task.task_id == task.id:
if running_only:
if job.status.state == JobState.RUNNING:
task_jobs.append(job)
else:
task_jobs.append(job)
return task_jobs
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
4 changes: 2 additions & 2 deletions tests/fixtures/cassettes/TestGetiSession.test_logout.cassette
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
4 changes: 2 additions & 2 deletions tests/fixtures/cassettes/geti.cassette
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Loading

0 comments on commit 096e66c

Please sign in to comment.