diff --git a/newrelic/config.py b/newrelic/config.py index cf029bd08..729c0739c 100644 --- a/newrelic/config.py +++ b/newrelic/config.py @@ -4374,6 +4374,11 @@ def _process_module_builtin_defaults(): "newrelic.hooks.application_celery", "instrument_celery_app_task", ) + _process_module_definition( + "celery.app.trace", + "newrelic.hooks.application_celery", + "instrument_celery_app_trace", + ) _process_module_definition("celery.worker", "newrelic.hooks.application_celery", "instrument_celery_worker") _process_module_definition( "celery.concurrency.processes", diff --git a/newrelic/hooks/application_celery.py b/newrelic/hooks/application_celery.py index 25a86a4a6..f3261289d 100644 --- a/newrelic/hooks/application_celery.py +++ b/newrelic/hooks/application_celery.py @@ -28,16 +28,18 @@ from newrelic.api.message_trace import MessageTrace from newrelic.api.pre_function import wrap_pre_function from newrelic.api.transaction import current_transaction -from newrelic.common.object_wrapper import FunctionWrapper, wrap_function_wrapper +from newrelic.common.object_wrapper import FunctionWrapper, wrap_function_wrapper, _NRBoundFunctionWrapper from newrelic.core.agent import shutdown_agent UNKNOWN_TASK_NAME = "" MAPPING_TASK_NAMES = {"celery.starmap", "celery.map"} -def task_name(*args, **kwargs): +def task_info(instance, *args, **kwargs): # Grab the current task, which can be located in either place - if args: + if instance: + task = instance + elif args: task = args[0] elif "task" in kwargs: task = kwargs["task"] @@ -46,6 +48,7 @@ def task_name(*args, **kwargs): # Task can be either a task instance or a signature, which subclasses dict, or an actual dict in some cases. task_name = getattr(task, "name", None) or task.get("task", UNKNOWN_TASK_NAME) + task_source = task # Under mapping tasks, the root task name isn't descriptive enough so we append the # subtask name to differentiate between different mapping tasks @@ -53,20 +56,19 @@ def task_name(*args, **kwargs): try: subtask = kwargs["task"]["task"] task_name = "/".join((task_name, subtask)) + task_source = task.app._tasks[subtask] except Exception: pass - return task_name + return task_name, task_source def CeleryTaskWrapper(wrapped): def wrapper(wrapped, instance, args, kwargs): transaction = current_transaction(active_only=False) - if instance is not None: - _name = task_name(instance, *args, **kwargs) - else: - _name = task_name(*args, **kwargs) + # Grab task name and source + _name, _source = task_info(instance, *args, **kwargs) # A Celery Task can be called either outside of a transaction, or # within the context of an existing transaction. There are 3 @@ -93,11 +95,11 @@ def wrapper(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) elif transaction: - with FunctionTrace(_name, source=instance): + with FunctionTrace(_name, source=_source): return wrapped(*args, **kwargs) else: - with BackgroundTask(application_instance(), _name, "Celery", source=instance) as transaction: + with BackgroundTask(application_instance(), _name, "Celery", source=_source) as transaction: # Attempt to grab distributed tracing headers try: # Headers on earlier versions of Celery may end up as attributes @@ -200,6 +202,26 @@ def wrap_Celery_send_task(wrapped, instance, args, kwargs): return wrapped(*args, **kwargs) +def wrap_worker_optimizations(wrapped, instance, args, kwargs): + # Attempt to uninstrument BaseTask before stack protection is installed or uninstalled + try: + from celery.app.task import BaseTask + + if isinstance(BaseTask.__call__, _NRBoundFunctionWrapper): + BaseTask.__call__ = BaseTask.__call__.__wrapped__ + except Exception: + BaseTask = None + + # Allow metaprogramming to run + result = wrapped(*args, **kwargs) + + # Rewrap finalized BaseTask + if BaseTask: # Ensure imports succeeded + BaseTask.__call__ = CeleryTaskWrapper(BaseTask.__call__) + + return result + + def instrument_celery_app_base(module): if hasattr(module, "Celery") and hasattr(module.Celery, "send_task"): wrap_function_wrapper(module, "Celery.send_task", wrap_Celery_send_task) @@ -239,3 +261,12 @@ def force_agent_shutdown(*args, **kwargs): if hasattr(module, "Worker"): wrap_pre_function(module, "Worker._do_exit", force_agent_shutdown) + + +def instrument_celery_app_trace(module): + # Uses same wrapper for setup and reset worker optimizations to prevent patching and unpatching from removing wrappers + if hasattr(module, "setup_worker_optimizations"): + wrap_function_wrapper(module, "setup_worker_optimizations", wrap_worker_optimizations) + + if hasattr(module, "reset_worker_optimizations"): + wrap_function_wrapper(module, "reset_worker_optimizations", wrap_worker_optimizations) diff --git a/tests/application_celery/test_task_methods.py b/tests/application_celery/test_task_methods.py index f1d78f32f..509129b09 100644 --- a/tests/application_celery/test_task_methods.py +++ b/tests/application_celery/test_task_methods.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest from _target_application import add, tsum -from celery import chain, chord, group +from testing_support.validators.validate_code_level_metrics import ( + validate_code_level_metrics, +) from testing_support.validators.validate_transaction_count import ( validate_transaction_count, ) @@ -22,20 +25,19 @@ validate_transaction_metrics, ) -FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)] +import celery -def test_task_wrapping_detection(): - """ - Ensure celery detects our monkeypatching properly and will run our instrumentation - on __call__ and runs that instead of micro-optimizing it away to a run() call. +FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)] - If this is not working, most other tests in this file will fail as the different ways - of running celery tasks will not all run our instrumentation. - """ - from celery.app.trace import task_has_custom - assert task_has_custom(add, "__call__") +@pytest.fixture(scope="module", autouse=True, params=[False, True], ids=["unpatched", "patched"]) +def with_worker_optimizations(request, celery_worker_available): + if request.param: + celery.app.trace.setup_worker_optimizations(celery_worker_available.app) + + yield request.param + celery.app.trace.reset_worker_optimizations() @validate_transaction_metrics( @@ -45,6 +47,7 @@ def test_task_wrapping_detection(): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_task_call(): """ @@ -61,6 +64,7 @@ def test_celery_task_call(): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_task_apply(): """ @@ -78,6 +82,7 @@ def test_celery_task_apply(): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_task_delay(): """ @@ -95,6 +100,7 @@ def test_celery_task_delay(): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_task_apply_async(): """ @@ -112,6 +118,7 @@ def test_celery_task_apply_async(): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_app_send_task(celery_session_app): """ @@ -129,6 +136,7 @@ def test_celery_app_send_task(celery_session_app): rollup_metrics=FORGONE_TASK_METRICS, background_task=True, ) +@validate_code_level_metrics("_target_application", "add") @validate_transaction_count(1) def test_celery_task_signature(): """ @@ -154,6 +162,8 @@ def test_celery_task_signature(): background_task=True, index=-2, ) +@validate_code_level_metrics("_target_application", "add") +@validate_code_level_metrics("_target_application", "add", index=-2) @validate_transaction_count(2) def test_celery_task_link(): """ @@ -179,12 +189,14 @@ def test_celery_task_link(): background_task=True, index=-2, ) +@validate_code_level_metrics("_target_application", "add") +@validate_code_level_metrics("_target_application", "add", index=-2) @validate_transaction_count(2) def test_celery_chain(): """ Executes multiple tasks on worker process and returns an AsyncResult. """ - result = chain(add.s(3, 4), add.s(5))() + result = celery.chain(add.s(3, 4), add.s(5))() result = result.get() assert result == 12 @@ -205,12 +217,14 @@ def test_celery_chain(): background_task=True, index=-2, ) +@validate_code_level_metrics("_target_application", "add") +@validate_code_level_metrics("_target_application", "add", index=-2) @validate_transaction_count(2) def test_celery_group(): """ Executes multiple tasks on worker process and returns an AsyncResult. """ - result = group(add.s(3, 4), add.s(1, 2))() + result = celery.group(add.s(3, 4), add.s(1, 2))() result = result.get() assert result == [7, 3] @@ -238,12 +252,15 @@ def test_celery_group(): background_task=True, index=-3, ) +@validate_code_level_metrics("_target_application", "tsum") +@validate_code_level_metrics("_target_application", "add", index=-2) +@validate_code_level_metrics("_target_application", "add", index=-3) @validate_transaction_count(3) def test_celery_chord(): """ Executes 2 add tasks, followed by a tsum task on the worker process and returns an AsyncResult. """ - result = chord([add.s(3, 4), add.s(1, 2)])(tsum.s()) + result = celery.chord([add.s(3, 4), add.s(1, 2)])(tsum.s()) result = result.get() assert result == 10 @@ -255,6 +272,7 @@ def test_celery_chord(): rollup_metrics=[("Function/_target_application.tsum", 2)], background_task=True, ) +@validate_code_level_metrics("_target_application", "tsum", count=3) @validate_transaction_count(1) def test_celery_task_map(): """ @@ -272,6 +290,7 @@ def test_celery_task_map(): rollup_metrics=[("Function/_target_application.add", 2)], background_task=True, ) +@validate_code_level_metrics("_target_application", "add", count=3) @validate_transaction_count(1) def test_celery_task_starmap(): """ @@ -297,6 +316,8 @@ def test_celery_task_starmap(): background_task=True, index=-2, ) +@validate_code_level_metrics("_target_application", "add", count=2) +@validate_code_level_metrics("_target_application", "add", count=2, index=-2) @validate_transaction_count(2) def test_celery_task_chunks(): """ diff --git a/tests/application_celery/test_wrappers.py b/tests/application_celery/test_wrappers.py new file mode 100644 index 000000000..1bca1b436 --- /dev/null +++ b/tests/application_celery/test_wrappers.py @@ -0,0 +1,46 @@ +# Copyright 2010 New Relic, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from _target_application import add + +import celery + +from newrelic.common.object_wrapper import _NRBoundFunctionWrapper + + +FORGONE_TASK_METRICS = [("Function/_target_application.add", None), ("Function/_target_application.tsum", None)] + + +def test_task_wrapping_detection(): + """ + Ensure celery detects our monkeypatching properly and will run our instrumentation + on __call__ and runs that instead of micro-optimizing it away to a run() call. + + If this is not working, most other tests in this file will fail as the different ways + of running celery tasks will not all run our instrumentation. + """ + assert celery.app.trace.task_has_custom(add, "__call__") + + +def test_worker_optimizations_preserve_instrumentation(celery_worker_available): + is_instrumented = lambda: isinstance(celery.app.task.BaseTask.__call__, _NRBoundFunctionWrapper) + + celery.app.trace.reset_worker_optimizations() + assert is_instrumented(), "Instrumentation not initially applied." + + celery.app.trace.setup_worker_optimizations(celery_worker_available.app) + assert is_instrumented(), "setup_worker_optimizations removed instrumentation." + + celery.app.trace.reset_worker_optimizations() + assert is_instrumented(), "reset_worker_optimizations removed instrumentation." diff --git a/tests/testing_support/validators/validate_transaction_count.py b/tests/testing_support/validators/validate_transaction_count.py index 3ceea7725..ffd4567cf 100644 --- a/tests/testing_support/validators/validate_transaction_count.py +++ b/tests/testing_support/validators/validate_transaction_count.py @@ -17,18 +17,22 @@ def validate_transaction_count(count): - _transactions = [] + transactions = [] @transient_function_wrapper('newrelic.core.stats_engine', 'StatsEngine.record_transaction') def _increment_count(wrapped, instance, args, kwargs): - _transactions.append(getattr(args[0], "name", True)) + transactions.append(getattr(args[0], "name", True)) return wrapped(*args, **kwargs) @function_wrapper def _validate_transaction_count(wrapped, instance, args, kwargs): _new_wrapped = _increment_count(wrapped) result = _new_wrapped(*args, **kwargs) + + _transactions = list(transactions) + del transactions[:] # Clear list for subsequent test runs + assert count == len(_transactions), (count, len(_transactions), _transactions) return result