Skip to content

Commit

Permalink
Add contextvars to sync threaded activities
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Feb 7, 2023
1 parent 24bce5e commit 10e20d2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 6 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,12 @@ poe test

This runs against [Temporalite](https://github.com/temporalio/temporalite). To run against the time-skipping test
server, pass `--workflow-environment time-skipping`. To run against the `default` namespace of an already-running
server, pass the `host:port` to `--workflow-environment`.
server, pass the `host:port` to `--workflow-environment`. Can also use regular pytest arguments. For example, here's how
to run a single test with debug logs on the console:

```bash
poe test -s --log-cli-level=DEBUG -k test_sync_activity_thread_cancel_caught
```

#### Proto Generation and Testing

Expand Down
18 changes: 14 additions & 4 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import concurrent.futures
import contextvars
import inspect
import logging
import multiprocessing
Expand Down Expand Up @@ -668,9 +669,9 @@ async def heartbeat_with_context(*details: Any) -> None:
assert cancelled_event
worker_shutdown_event = self._worker._worker_shutdown_event
assert worker_shutdown_event
return await loop.run_in_executor(
input.executor,
_execute_sync_activity,
# Prepare func and args
func: Callable = _execute_sync_activity
args = [
info,
heartbeat,
self._running_activity.cancel_thread_raiser,
Expand All @@ -679,7 +680,16 @@ async def heartbeat_with_context(*details: Any) -> None:
worker_shutdown_event.thread_event,
input.fn,
*input.args,
)
]
# If we're threaded, we want to pass the context through. We
# have to do this manually, see
# https://github.com/python/cpython/issues/78195.
if isinstance(input.executor, concurrent.futures.ThreadPoolExecutor):
current_context = contextvars.copy_context()
args.insert(0, func)
func = current_context.run
# Invoke
return await loop.run_in_executor(input.executor, func, *args)
finally:
if shared_manager:
await shared_manager.unregister_heartbeater(info.task_token)
Expand Down
44 changes: 43 additions & 1 deletion tests/worker/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
import uuid
from concurrent.futures.process import BrokenProcessPool
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, List, NoReturn, Optional, Sequence
Expand All @@ -32,7 +33,14 @@
TimeoutType,
)
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import SharedStateManager, Worker, WorkerConfig
from temporalio.worker import (
ActivityInboundInterceptor,
ExecuteActivityInput,
Interceptor,
SharedStateManager,
Worker,
WorkerConfig,
)
from tests.helpers.worker import (
ExternalWorker,
KSAction,
Expand Down Expand Up @@ -1072,6 +1080,40 @@ async def test_activity_async_cancel(
assert list(err.value.cause.cause.details) == ["cancel details"]


some_context_var: ContextVar[str] = ContextVar("some_context_var", default="unset")


class ContextVarInterceptor(Interceptor):
def intercept_activity(
self, next: ActivityInboundInterceptor
) -> ActivityInboundInterceptor:
return super().intercept_activity(ContextVarActivityInboundInterceptor(next))


class ContextVarActivityInboundInterceptor(ActivityInboundInterceptor):
async def execute_activity(self, input: ExecuteActivityInput) -> Any:
some_context_var.set("some value!")
return await super().execute_activity(input)


async def test_sync_activity_contextvars(client: Client, worker: ExternalWorker):
@activity.defn
def some_activity() -> str:
return f"context var: {some_context_var.get()}"

with concurrent.futures.ThreadPoolExecutor() as executor:
result = await _execute_workflow_with_activity(
client,
worker,
some_activity,
worker_config={
"activity_executor": executor,
"interceptors": [ContextVarInterceptor()],
},
)
assert result.result == "context var: some value!"


@dataclass
class _ActivityResult:
act_task_queue: str
Expand Down

0 comments on commit 10e20d2

Please sign in to comment.