[API] Close DB session (#5623)
alonmr committed May 26, 2024
1 parent 48485ec commit f7240c7
Showing 1 changed file with 42 additions and 37 deletions.
server/api/main.py
@@ -278,51 +278,55 @@ async def _verify_log_collection_started_on_startup(
     :param start_logs_limit: Semaphore which limits the number of concurrent log collection tasks
     """
     db_session = await fastapi.concurrency.run_in_threadpool(create_session)
-    logger.debug(
-        "Getting all runs which are in non terminal state and require logs collection"
-    )
-    runs = await fastapi.concurrency.run_in_threadpool(
-        get_db().list_distinct_runs_uids,
-        db_session,
-        requested_logs_modes=[None, False],
-        only_uids=False,
-        states=mlrun.common.runtimes.constants.RunStates.non_terminal_states(),
-    )
-    logger.debug(
-        "Getting all runs which might have reached terminal state while the API was down",
-        api_downtime_grace_period=config.log_collector.api_downtime_grace_period,
-    )
-    runs.extend(
-        await fastapi.concurrency.run_in_threadpool(
+    try:
+        logger.debug(
+            "Getting all runs which are in non terminal state and require logs collection"
+        )
+        runs = await fastapi.concurrency.run_in_threadpool(
             get_db().list_distinct_runs_uids,
             db_session,
             requested_logs_modes=[None, False],
             only_uids=False,
-            # We take the minimum between the api_downtime_grace_period and the runtime_resources_deletion_grace_period
-            # because we want to make sure that we don't miss any runs which might have reached terminal state while the
-            # API was down, and their runtime resources are not deleted
-            last_update_time_from=datetime.datetime.now(datetime.timezone.utc)
-            - datetime.timedelta(
-                seconds=min(
-                    int(config.log_collector.api_downtime_grace_period),
-                    int(config.runtime_resources_deletion_grace_period),
-                )
-            ),
-            states=mlrun.common.runtimes.constants.RunStates.terminal_states(),
+            states=mlrun.common.runtimes.constants.RunStates.non_terminal_states(),
         )
-    )
-    if runs:
         logger.debug(
-            "Found runs which require logs collection on startup",
-            runs_uids=[run.get("metadata", {}).get("uid", None) for run in runs],
+            "Getting all runs which might have reached terminal state while the API was down",
+            api_downtime_grace_period=config.log_collector.api_downtime_grace_period,
         )
-
-        # we're using best_effort=True so the api will mark the runs as requested logs collection even in cases
-        # where the log collection failed (e.g. when the pod is not found for runs that might have reached terminal
-        # state while the API was down)
-        await _start_log_and_update_runs(
-            start_logs_limit, db_session, runs, best_effort=True
+        runs.extend(
+            await fastapi.concurrency.run_in_threadpool(
+                get_db().list_distinct_runs_uids,
+                db_session,
+                requested_logs_modes=[None, False],
+                only_uids=False,
+                # We take the minimum between the api_downtime_grace_period and the
+                # runtime_resources_deletion_grace_period because we want to make sure that we don't miss any runs
+                # which might have reached terminal state while the API was down, and their runtime resources
+                # are not deleted
+                last_update_time_from=datetime.datetime.now(datetime.timezone.utc)
+                - datetime.timedelta(
+                    seconds=min(
+                        int(config.log_collector.api_downtime_grace_period),
+                        int(config.runtime_resources_deletion_grace_period),
+                    )
+                ),
+                states=mlrun.common.runtimes.constants.RunStates.terminal_states(),
+            )
         )
+        if runs:
+            logger.debug(
+                "Found runs which require logs collection on startup",
+                runs_uids=[run.get("metadata", {}).get("uid", None) for run in runs],
+            )
+
+            # we're using best_effort=True so the api will mark the runs as requested logs collection even in cases
+            # where the log collection failed (e.g. when the pod is not found for runs that might have reached terminal
+            # state while the API was down)
+            await _start_log_and_update_runs(
+                start_logs_limit, db_session, runs, best_effort=True
+            )
+    finally:
+        await fastapi.concurrency.run_in_threadpool(close_session, db_session)
 
 
 async def _initiate_logs_collection(start_logs_limit: asyncio.Semaphore):
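
The first hunk is mechanically simple but is the whole point of the commit: the body of _verify_log_collection_started_on_startup moves into a try: block so that a new finally: clause can close the session created by create_session, even when listing runs or starting log collection raises. Before this change, any exception on the startup path leaked the session. The same lifecycle can also be packaged as a reusable async context manager; the sketch below is illustrative only (the managed_session wrapper and the stub create_session/close_session helpers are not MLRun's actual API, though the threadpool pattern mirrors the diff):

import contextlib

import fastapi.concurrency


def create_session():
    # Stand-in for the blocking SQLAlchemy session factory used in the diff.
    return object()


def close_session(session):
    # Stand-in for the blocking helper that returns the session to the pool.
    pass


@contextlib.asynccontextmanager
async def managed_session():
    # Session creation does blocking I/O, so it runs in the threadpool to
    # keep the event loop free, the same pattern the diff uses.
    session = await fastapi.concurrency.run_in_threadpool(create_session)
    try:
        yield session
    finally:
        # Executes on success and on exception alike, so the session can
        # no longer leak when an early step raises.
        await fastapi.concurrency.run_in_threadpool(close_session, session)

With such a wrapper, the function body would read "async with managed_session() as db_session: ...". The commit keeps the explicit try/finally instead, which keeps the change local to the one function that leaked.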
@@ -417,6 +421,7 @@ async def _start_log_and_update_runs(
     logger.debug(
         "Updating runs to indicate that we requested logs collection for them",
         runs_uids=runs_to_mark_as_requested_logs,
+        start_log_request_counters_len=len(_run_uid_start_log_request_counters),
     )
     # update the runs to indicate that we have requested log collection for them
     await fastapi.concurrency.run_in_threadpool(
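
The second hunk only enriches an existing debug log with the size of _run_uid_start_log_request_counters. That structure's definition is not part of this diff; judging by the name, it is module-level state mapping a run UID to the number of times log collection has been requested for that run, and logging its length is a cheap way to spot unbounded growth from the debug logs alone. A hedged sketch of what such a counter might look like (the cap and helper below are assumptions for illustration, not taken from the source):

import collections

# Hypothetical module-level state: run UID -> number of start-log requests.
_run_uid_start_log_request_counters: collections.Counter = collections.Counter()

# Illustrative retry cap; the real limit, if any, is not visible in this diff.
_MAX_START_LOG_REQUESTS = 3


def _should_request_start_log(run_uid: str) -> bool:
    # Record the attempt and allow it only while under the cap, so a run whose
    # pod is already gone does not trigger start-log requests forever.
    _run_uid_start_log_request_counters[run_uid] += 1
    return _run_uid_start_log_request_counters[run_uid] <= _MAX_START_LOG_REQUESTS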
