diff --git a/cads_processing_api_service/clients.py b/cads_processing_api_service/clients.py index 3c287716..8adfe7d3 100644 --- a/cads_processing_api_service/clients.py +++ b/cads_processing_api_service/clients.py @@ -202,7 +202,9 @@ def post_process_execution( job_kwargs = adaptors.make_system_job_kwargs( resource, execution_content, adaptor.resources ) - compute_sessionmaker = db_utils.get_compute_sessionmaker() + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.write + ) with compute_sessionmaker() as compute_session: job = cads_broker.database.create_request( session=compute_session, @@ -284,7 +286,9 @@ def get_jobs( statement, self.job_table, back, sort_key, sort_dir ) statement = utils.apply_limit(statement, limit) - compute_sessionmaker = db_utils.get_compute_sessionmaker() + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.read + ) catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker() with compute_sessionmaker() as compute_session: job_entries = compute_session.scalars(statement).all() @@ -351,9 +355,22 @@ def get_job( """ user_uid = auth.authenticate_user(auth_header, portal_header) portals = [p.strip() for p in portal_header.split(",")] - compute_sessionmaker = db_utils.get_compute_sessionmaker() - with compute_sessionmaker() as compute_session: - job = utils.get_job_from_broker_db(job_id=job_id, session=compute_session) + try: + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.read + ) + with compute_sessionmaker() as compute_session: + job = utils.get_job_from_broker_db( + job_id=job_id, session=compute_session + ) + except ogc_api_processes_fastapi.exceptions.NoSuchJob: + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.write + ) + with compute_sessionmaker() as compute_session: + job = utils.get_job_from_broker_db( + job_id=job_id, session=compute_session + ) if job["portal"] not in portals: raise ogc_api_processes_fastapi.exceptions.NoSuchJob() auth.verify_permission(user_uid, job) @@ -407,12 +424,34 @@ def get_job_results( structlog.contextvars.bind_contextvars(job_id=job_id) user_uid = auth.authenticate_user(auth_header, portal_header) structlog.contextvars.bind_contextvars(user_id=user_uid) - compute_sessionmaker = db_utils.get_compute_sessionmaker() - with compute_sessionmaker() as compute_session: - job = utils.get_job_from_broker_db(job_id=job_id, session=compute_session) - auth.verify_permission(user_uid, job) - results = utils.get_results_from_broker_db(job=job, session=compute_session) - handle_download_metrics(job, results) + try: + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.read + ) + with compute_sessionmaker() as compute_session: + job = utils.get_job_from_broker_db( + job_id=job_id, session=compute_session + ) + auth.verify_permission(user_uid, job) + results = utils.get_results_from_broker_db( + job=job, session=compute_session + ) + except ( + ogc_api_processes_fastapi.exceptions.NoSuchJob, + ogc_api_processes_fastapi.exceptions.ResultsNotReady, + ): + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.write + ) + with compute_sessionmaker() as compute_session: + job = utils.get_job_from_broker_db( + job_id=job_id, session=compute_session + ) + auth.verify_permission(user_uid, job) + results = utils.get_results_from_broker_db( + job=job, session=compute_session + ) + handle_download_metrics(job, results) return results def delete_job( @@ -441,7 +480,9 @@ def delete_job( structlog.contextvars.bind_contextvars(job_id=job_id) user_uid = auth.authenticate_user(auth_header, portal_header) structlog.contextvars.bind_contextvars(user_id=user_uid) - compute_sessionmaker = db_utils.get_compute_sessionmaker() + compute_sessionmaker = db_utils.get_compute_sessionmaker( + mode=db_utils.ConnectionMode.write + ) with compute_sessionmaker() as compute_session: job = utils.get_job_from_broker_db(job_id=job_id, session=compute_session) auth.verify_permission(user_uid, job) diff --git a/cads_processing_api_service/db_utils.py b/cads_processing_api_service/db_utils.py index e481c898..fed56ac9 100644 --- a/cads_processing_api_service/db_utils.py +++ b/cads_processing_api_service/db_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License +import enum import functools import cads_broker.config @@ -22,18 +23,39 @@ import sqlalchemy.orm +class ConnectionMode(str, enum.Enum): + """Database connection mode.""" + + read = "read" + write = "write" + + @functools.lru_cache() -def get_compute_sessionmaker() -> sqlalchemy.orm.sessionmaker[sqlalchemy.orm.Session]: +def get_compute_sessionmaker( + mode: ConnectionMode = ConnectionMode.write, +) -> sqlalchemy.orm.sessionmaker[sqlalchemy.orm.Session]: """Get an sqlalchemy.orm.sessionmaker object bound to the Broker database. + Parameters + ---------- + mode: ConnectionMode + Connection mode to the database. If ConnectionMode.read, the sessionmaker + will open a connection to a read-only hostname. + Returns ------- sqlalchemy.orm.sessionmaker sqlalchemy.orm.sessionmaker object bound to the Broker database. """ broker_settings = cads_broker.config.ensure_settings() + if mode == ConnectionMode.write: + connection_string = broker_settings.connection_string + elif mode == ConnectionMode.read: + connection_string = broker_settings.connection_string_read + else: + raise ValueError(f"Invalid connection mode: {str(mode)}") broker_engine = sqlalchemy.create_engine( - broker_settings.connection_string, + connection_string, pool_timeout=broker_settings.pool_timeout, pool_recycle=broker_settings.pool_recycle, )