Skip to content

Commit

Permalink
Merge pull request #17317 from mvdbeek/close_session_after_task
Browse files Browse the repository at this point in the history
[23.2] Discard connection after task completion
  • Loading branch information
natefoo committed Jan 17, 2024
2 parents 8118f07 + ef26fc4 commit b3d31bf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 33 deletions.
8 changes: 8 additions & 0 deletions lib/galaxy/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
from functools import (
lru_cache,
wraps,
Expand Down Expand Up @@ -167,6 +168,10 @@ def wrapper(*args, **kwds):
app = get_galaxy_app()
assert app

# Ensure sqlalchemy session registry scope is specific to this instance of the celery task
scoped_id = str(uuid.uuid4())
app.model.set_request_id(scoped_id)

desc = func.__name__
if action is not None:
desc += f" to {action}"
Expand All @@ -184,6 +189,9 @@ def wrapper(*args, **kwds):
except Exception:
log.warning(f"Celery task execution failed for {desc} {timer}")
raise
finally:
# Close and remove any open session this task has created
app.model.unset_request_id(scoped_id)

return wrapper

Expand Down
15 changes: 15 additions & 0 deletions test/integration/test_celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
purge_hda,
)
from galaxy.model import HistoryDatasetAssociation
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema import PdfDocumentType
from galaxy.schema.schema import CreatePagePayload
from galaxy.schema.tasks import GeneratePdfDownload
Expand All @@ -31,13 +32,27 @@ def process_page(request: CreatePagePayload):
return f"content_format is {request.content_format} with annotation {request.annotation}"


@galaxy_task
def invalidate_connection(sa_session: galaxy_scoped_session):
sa_session().connection().invalidate()


@galaxy_task
def use_session(sa_session: galaxy_scoped_session):
sa_session().query(HistoryDatasetAssociation).get(1)


class TestCeleryTasksIntegration(IntegrationTestCase):
dataset_populator: DatasetPopulator

def setUp(self):
super().setUp()
self.dataset_populator = DatasetPopulator(self.galaxy_interactor)

def test_recover_from_invalid_connection(self):
invalidate_connection.delay().get()
use_session.delay().get()

def test_random_simple_task_to_verify_framework_for_testing(self):
assert mul.delay(4, 4).get(timeout=10) == 16

Expand Down
51 changes: 18 additions & 33 deletions test/unit/app/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from contextlib import contextmanager
from typing import (
Iterator,
List,
)
from typing import List

from galaxy.celery import set_thread_app
from galaxy.app_unittest_utils.galaxy_mock import MockApp
from galaxy.celery.tasks import clean_object_store_caches
from galaxy.di import Container
from galaxy.objectstore import BaseObjectStore
from galaxy.objectstore.caching import CacheTarget

Expand All @@ -20,34 +15,24 @@ def cache_targets(self) -> List[CacheTarget]:


def test_clean_object_store_caches(tmp_path):
with celery_injected_app_container() as container:
cache_targets: List[CacheTarget] = []
container[BaseObjectStore] = MockObjectStore(cache_targets) # type: ignore[assignment]
container = MockApp()
cache_targets: List[CacheTarget] = []
container[BaseObjectStore] = MockObjectStore(cache_targets) # type: ignore[assignment]

# similar code used in object store unit tests
cache_dir = tmp_path
path = cache_dir / "a_file_0"
path.write_text("this is an example file")
# similar code used in object store unit tests
cache_dir = tmp_path
path = cache_dir / "a_file_0"
path.write_text("this is an example file")

# works fine on an empty list of cache targets...
clean_object_store_caches()
# works fine on an empty list of cache targets...
clean_object_store_caches()

assert path.exists()
assert path.exists()

# place the file in mock object store's cache targets and
# run the task again and the above file should be gone.
cache_targets.append(CacheTarget(cache_dir, 1, 0.000000001))
# works fine on an empty list of cache targets...
clean_object_store_caches()
# place the file in mock object store's cache targets and
# run the task again and the above file should be gone.
cache_targets.append(CacheTarget(cache_dir, 1, 0.000000001))
# works fine on an empty list of cache targets...
clean_object_store_caches()

assert not path.exists()


@contextmanager
def celery_injected_app_container() -> Iterator[Container]:
container = Container()
set_thread_app(container)
try:
yield container
finally:
set_thread_app(None)
assert not path.exists()

0 comments on commit b3d31bf

Please sign in to comment.