Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions clients/python/src/examples/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,12 @@ def will_fail_with_silenced_ignored_exception() -> None:
)
def will_retry_on_deadline_exceeded() -> None:
timed_task(sleep_seconds=2)


@exampletasks.register(name="examples.task_with_headers", pass_headers=True)
def task_with_headers(value: str, headers: dict[str, str]) -> None:
redis = StrictRedis(host="localhost", port=6379, decode_responses=True)
redis.set("task-headers-value", value)
redis.set("task-headers-count", str(len(headers)))
if "x-custom-header" in headers:
redis.set("task-headers-custom", headers["x-custom-header"])
7 changes: 7 additions & 0 deletions clients/python/src/taskbroker_client/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def register(
compression_type: CompressionType = CompressionType.PLAINTEXT,
report_timeout_errors: bool = True,
silenced_exceptions: tuple[type[BaseException], ...] | None = None,
pass_headers: bool = False,
) -> Callable[[Callable[P, R]], Task[P, R]]:
"""
Register a task.
Expand Down Expand Up @@ -121,6 +122,9 @@ def register(
Enable reporting of ProcessingDeadlineExceededError to Sentry.
silenced_exceptions: tuple[type[BaseException], ...] | None
A tuple of exception types that will not be reported by Sentry.
pass_headers: bool
If True, the task function will receive task activation headers
as a keyword argument named `headers` (dict[str, str]).
"""

def wrapped(func: Callable[P, R]) -> Task[P, R]:
Expand All @@ -141,6 +145,7 @@ def wrapped(func: Callable[P, R]) -> Task[P, R]:
compression_type=compression_type,
report_timeout_errors=report_timeout_errors,
silenced_exceptions=silenced_exceptions,
pass_headers=pass_headers,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're going to inject parameters into the task, should we validate that the task doesn't already have a headers parameter with an incompatible type?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another potential solution would be to pass the entire activation into the task,

as in, you have a pass_activation option?

Since we're going to inject parameters into the task, should we validate that the task doesn't already have a headers parameter with an incompatible type?

this feels magical but i can do it

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can enforce that the headers kwarg is there explicitly. def mytask(**kwargs): is then just invalid.

Copy link
Copy Markdown
Member

@markstory markstory May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can enforce that the headers kwarg is there explicitly. def mytask(**kwargs): is then just invalid.

I was more thinking of the scenario where a task has an incompatible headers parameter like:

def some_task(headers: set[str]):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as in, you have a pass_activation option?

Yeah, but as I think about it more I like this idea less, and pass_headers is the better solution.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the current version covers this. it requires the header arg to be there, to have the right type, and to be explicit

)
# TODO(taskworker) tasks should be registered into the registry
# so that we can ensure task names are globally unique
Expand Down Expand Up @@ -234,6 +239,7 @@ def register(
compression_type: CompressionType = CompressionType.PLAINTEXT,
report_timeout_errors: bool = True,
silenced_exceptions: tuple[type[BaseException], ...] | None = None,
pass_headers: bool = False,
) -> Callable[[Callable[P, R]], ExternalTask[P, R]]:
"""
Register an external task stub.
Expand Down Expand Up @@ -285,6 +291,7 @@ def wrapped(func: Callable[P, R]) -> ExternalTask[P, R]:
compression_type=compression_type,
report_timeout_errors=report_timeout_errors,
silenced_exceptions=silenced_exceptions,
pass_headers=pass_headers,
)
self._registered_tasks[name] = task
return task
Expand Down
60 changes: 58 additions & 2 deletions clients/python/src/taskbroker_client/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import base64
import datetime
import inspect
import os
import time
from collections.abc import Callable, Collection, Mapping, MutableMapping
from functools import update_wrapper
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, get_origin
from uuid import uuid4

import msgpack
Expand Down Expand Up @@ -54,6 +55,42 @@ def _get_parameters_format() -> ParametersFormat:
R = TypeVar("R")


def assert_typed_kwarg(
func: Callable[..., Any],
param_name: str,
expected_types: tuple[type, ...],
context: str,
) -> None:
"""
Validate that a function has a keyword argument with a compatible type annotation.

Raises TypeError if:
- The parameter does not exist
- The parameter is positional-only
- The parameter has a type annotation that is not in expected_types
"""
sig = inspect.signature(func, eval_str=True)
if param_name not in sig.parameters:
raise TypeError(f"{context}: function does not have a {param_name!r} parameter")

param = sig.parameters[param_name]
if param.kind == inspect.Parameter.POSITIONAL_ONLY:
raise TypeError(
f"{context}: {param_name!r} parameter is positional-only. "
f"It must be a keyword argument."
)

if param.annotation is not inspect.Parameter.empty:
origin = get_origin(param.annotation)
if origin is None:
origin = param.annotation
if origin not in expected_types:
raise TypeError(
f"{context}: {param_name!r} parameter has type {param.annotation!r}. "
f"Expected one of: {', '.join(t.__name__ for t in expected_types)}."
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any.__name__ crashes error formatting on Python 3.10

Low Severity

The error message in assert_typed_kwarg uses t.__name__ to format all types in expected_types, but typing.Any does not have a __name__ attribute on Python 3.10 (where it's a _SpecialForm instance, not a class). This causes an AttributeError to be raised instead of the intended helpful TypeError when a user provides an incompatible type annotation for headers. The expected_types tuple at the call site includes Any, so the generator expression t.__name__ for t in expected_types will crash when it reaches Any.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6e002f9. Configure here.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't care. Pick better types.

)
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
untitaker marked this conversation as resolved.


class Task(Generic[P, R]):
def __init__(
self,
Expand All @@ -68,6 +105,7 @@ def __init__(
compression_type: CompressionType = CompressionType.PLAINTEXT,
report_timeout_errors: bool = True,
silenced_exceptions: tuple[type[BaseException], ...] | None = None,
pass_headers: bool = False,
):
self.name = name
self._func = func
Expand All @@ -88,6 +126,16 @@ def __init__(
self.compression_type = compression_type
self.report_timeout_errors = report_timeout_errors
self.silenced_exceptions = silenced_exceptions or ()
self.pass_headers = pass_headers

if pass_headers:
assert_typed_kwarg(
func,
"headers",
(dict, Mapping, MutableMapping, Any),
f"Task {name!r} with pass_headers=True",
)

update_wrapper(self, func)

@property
Expand Down Expand Up @@ -154,7 +202,15 @@ def apply_async(

def _call_func(self, *args: Any, **kwargs: Any) -> None:
# Overridden in ExternalTask
self._func(*args, **kwargs)
if self.pass_headers:
if "headers" in kwargs:
raise TypeError(
f"Task '{self.name}' has pass_headers=True, but 'headers' was passed in kwargs. "
"The 'headers' parameter is injected by the worker and cannot be passed by the caller."
)
self._func(*args, headers={}, **kwargs) # type: ignore[arg-type]
Comment thread
untitaker marked this conversation as resolved.
else:
self._func(*args, **kwargs)

def _signal_send(self, task: Task[Any, Any], args: Any, kwargs: Any) -> None:
"""
Expand Down
3 changes: 3 additions & 0 deletions clients/python/src/taskbroker_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from arroyo.types import BrokerValue, Topic
from sentry_protos.taskbroker.v1.taskbroker_pb2 import TaskActivation, TaskActivationStatus

TaskHeaders = dict[str, str]
"""Headers passed to a task function when pass_headers=True is set."""


class ContextHook(Protocol):
"""
Expand Down
10 changes: 9 additions & 1 deletion clients/python/src/taskbroker_client/worker/workerchild.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,15 @@ def _execute_activation(
):
for hook in context_hooks:
stack.enter_context(hook.on_execute(headers))
task_func(*args, **kwargs)
if task_func.pass_headers:
if "headers" in kwargs:
raise TypeError(
f"Task '{task_func.name}' has pass_headers=True, but 'headers' was passed in kwargs. "
"The 'headers' parameter is injected by the worker and cannot be passed by the caller."
)
task_func(*args, headers=headers, **kwargs)
Comment thread
cursor[bot] marked this conversation as resolved.
else:
task_func(*args, **kwargs)
transaction.set_status(SPANSTATUS.OK)
except Exception:
transaction.set_status(SPANSTATUS.INTERNAL_ERROR)
Expand Down
88 changes: 88 additions & 0 deletions clients/python/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,91 @@ def multi_task() -> None:
activation = multi_task.create_activation([], {})
assert activation.headers["x-test-context"] == "dispatched"
assert activation.headers["x-another"] == "also-here"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for the case when the user passes headers as one of the kwargs to their function?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. testing for a specific error doesn't make a lot of sense if there's no explicit errorhandling for it, so I added more checks to the code too. i think it's a bit bloated now, but probably more user-friendly.


def test_task_pass_headers_attribute(task_namespace: TaskNamespace) -> None:
"""Tasks can opt into receiving headers via pass_headers=True."""

@task_namespace.register(name="test.with_headers", pass_headers=True)
def with_headers(org_id: int, headers: dict[str, str]) -> None:
pass

assert with_headers.pass_headers is True

@task_namespace.register(name="test.without_headers")
def without_headers(org_id: int) -> None:
pass

assert without_headers.pass_headers is False


def test_pass_headers_requires_headers_parameter(task_namespace: TaskNamespace) -> None:
"""Tasks with pass_headers=True must have a 'headers' parameter."""
with pytest.raises(TypeError, match="does not have a 'headers' parameter"):

@task_namespace.register(name="test.missing_headers", pass_headers=True)
def missing_headers(org_id: int) -> None:
pass


def test_pass_headers_rejects_positional_only_headers(task_namespace: TaskNamespace) -> None:
"""Tasks with pass_headers=True cannot have a positional-only 'headers' parameter."""
with pytest.raises(TypeError, match="positional-only"):

@task_namespace.register(name="test.positional_headers", pass_headers=True)
def positional_headers(org_id: int, headers: dict[str, str], /) -> None:
pass


def test_pass_headers_rejects_incompatible_type_annotation(
task_namespace: TaskNamespace,
) -> None:
"""Tasks with pass_headers=True must have a dict-like type annotation for 'headers'."""
with pytest.raises(TypeError, match="Expected one of: dict"):

@task_namespace.register(name="test.wrong_type_headers", pass_headers=True)
def wrong_type_headers(org_id: int, headers: str) -> None:
pass


def test_delay_immediate_mode_with_pass_headers(task_namespace: TaskNamespace) -> None:
"""In ALWAYS_EAGER mode, tasks with pass_headers=True receive empty headers."""
calls: list[dict[str, Any]] = []

@task_namespace.register(name="test.headers_task", pass_headers=True)
def headers_task(value: str, headers: dict[str, str]) -> None:
calls.append({"value": value, "headers": headers})

with patch("taskbroker_client.task.ALWAYS_EAGER", True):
headers_task.delay("test") # type: ignore[call-arg]

assert len(calls) == 1
assert calls[0]["value"] == "test"
assert calls[0]["headers"] == {}


def test_pass_headers_rejects_headers_in_kwargs(task_namespace: TaskNamespace) -> None:
"""Tasks with pass_headers=True reject 'headers' passed in kwargs."""

@task_namespace.register(name="test.headers_collision", pass_headers=True)
def headers_task(value: str, headers: dict[str, str]) -> None:
pass

with patch("taskbroker_client.task.ALWAYS_EAGER", True):
with pytest.raises(TypeError, match="cannot be passed by the caller"):
headers_task.delay("test", headers={"x-custom": "value"})


def test_delay_immediate_mode_without_pass_headers(task_namespace: TaskNamespace) -> None:
"""In ALWAYS_EAGER mode, tasks without pass_headers do not receive headers kwarg."""
calls: list[dict[str, Any]] = []

@task_namespace.register(name="test.no_headers_task")
def no_headers_task(value: str) -> None:
calls.append({"value": value})

with patch("taskbroker_client.task.ALWAYS_EAGER", True):
no_headers_task.delay("test")

assert len(calls) == 1
assert calls[0] == {"value": "test"}
44 changes: 44 additions & 0 deletions clients/python/tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,22 @@
),
)

TASK_WITH_HEADERS = InflightTaskActivation(
host="localhost:50051",
receive_timestamp=0,
activation=TaskActivation(
id="headers_task_123",
taskname="examples.task_with_headers",
namespace="examples",
parameters_bytes=msgpack.packb({"args": ["test_value"], "kwargs": {}}, use_bin_type=True),
headers={
"x-custom-header": "custom_value",
"sentry-trace": "trace-id",
},
processing_deadline_duration=2,
),
)


class TestTaskWorker(TestCase):
def test_fetch_task(self) -> None:
Expand Down Expand Up @@ -814,6 +830,34 @@ def test_child_process_record_checkin(mock_capture_checkin: mock.Mock) -> None:
)


def test_child_process_pass_headers() -> None:
"""Task with pass_headers=True receives headers from the activation."""
todo: queue.Queue[InflightTaskActivation] = queue.Queue()
processed: queue.Queue[ProcessingResult] = queue.Queue()
shutdown = Event()

todo.put(TASK_WITH_HEADERS)
child_process(
"examples.app:app",
todo,
processed,
shutdown,
max_task_count=1,
processing_pool_name="test",
process_type="fork",
)

assert todo.empty()
result = processed.get()
assert result.task_id == TASK_WITH_HEADERS.activation.id
assert result.status == TASK_ACTIVATION_STATUS_COMPLETE

redis = StrictRedis(host="localhost", port=6379, decode_responses=True)
assert redis.get("task-headers-value") == "test_value"
assert redis.get("task-headers-custom") == "custom_value"
redis.delete("task-headers-value", "task-headers-count", "task-headers-custom")


@mock.patch("taskbroker_client.worker.workerchild.sentry_sdk.capture_exception")
def test_child_process_terminate_task(mock_capture: mock.Mock) -> None:
todo: queue.Queue[InflightTaskActivation] = queue.Queue()
Expand Down
Loading