Skip to content
40 changes: 26 additions & 14 deletions clients/python/src/examples/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,34 @@ def scheduler() -> None:
type=int,
)
def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None:
from taskbroker_client.worker import TaskWorker
from taskbroker_client.worker import PushTaskWorker, TaskWorker

click.echo("Starting worker")
worker = TaskWorker(
app_module="examples.app:app",
broker_hosts=[rpc_host],
max_child_task_count=100,
concurrency=concurrency,
child_tasks_queue_maxsize=concurrency * 2,
result_queue_maxsize=concurrency * 2,
rebalance_after=32,
processing_pool_name="examples",
process_type="forkserver",
push_mode=push_mode,
grpc_port=grpc_port,
)
if push_mode:
worker: PushTaskWorker | TaskWorker = PushTaskWorker(
app_module="examples.app:app",
broker_service=rpc_host,
max_child_task_count=100,
concurrency=concurrency,
child_tasks_queue_maxsize=concurrency * 2,
result_queue_maxsize=concurrency * 2,
rebalance_after=32,
processing_pool_name="examples",
process_type="forkserver",
grpc_port=grpc_port,
)
else:
worker = TaskWorker(
app_module="examples.app:app",
broker_hosts=[rpc_host],
max_child_task_count=100,
concurrency=concurrency,
child_tasks_queue_maxsize=concurrency * 2,
result_queue_maxsize=concurrency * 2,
rebalance_after=32,
processing_pool_name="examples",
process_type="forkserver",
)
exitcode = worker.start()
raise SystemExit(exitcode)

Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/taskbroker_client/worker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .worker import TaskWorker
from .worker import PushTaskWorker, TaskWorker

__all__ = ("TaskWorker",)
__all__ = ("TaskWorker", "PushTaskWorker")
103 changes: 103 additions & 0 deletions clients/python/src/taskbroker_client/worker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,106 @@ def update_task(
receive_timestamp=time.monotonic(),
)
return None


class PushTaskbrokerClient:
"""
Taskworker RPC client wrapper

Push brokers are a deployment so they don't need to be connected to individually. There is one service provided
that works for all the brokers.
"""

def __init__(
self,
service: str,
application: str,
metrics: MetricsBackend,
health_check_settings: HealthCheckSettings | None = None,
rpc_secret: str | None = None,
grpc_config: str | None = None,
) -> None:
self._application = application
self._service = service
self._rpc_secret = rpc_secret
self._metrics = metrics

self._grpc_options: list[tuple[str, Any]] = [
("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE)
]
if grpc_config:
self._grpc_options.append(("grpc.service_config", grpc_config))
Comment thread
evanh marked this conversation as resolved.

logger.info(
"taskworker.push_client.start",
extra={"service": service, "options": self._grpc_options},
)

self._stub = self._connect_to_host(service)

self._health_check_settings = health_check_settings
self._timestamp_since_touch_lock = threading.Lock()
self._timestamp_since_touch = 0.0

def _emit_health_check(self) -> None:
if self._health_check_settings is None:
return

with self._timestamp_since_touch_lock:
cur_time = time.time()
if (
cur_time - self._timestamp_since_touch
< self._health_check_settings.touch_interval_sec
):
return

self._health_check_settings.file_path.touch()
self._metrics.incr(
"taskworker.client.health_check.touched",
)
self._timestamp_since_touch = cur_time

def _connect_to_host(self, host: str) -> ConsumerServiceStub:
logger.info("taskworker.push_client.connect", extra={"host": host})
channel = grpc.insecure_channel(host, options=self._grpc_options)
secrets = parse_rpc_secret_list(self._rpc_secret)
if secrets:
channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets))
return ConsumerServiceStub(channel)

def update_task(
self,
processing_result: ProcessingResult,
) -> None:
"""
Update the status for a given task activation.
"""
self._emit_health_check()

request = SetTaskStatusRequest(
id=processing_result.task_id,
status=processing_result.status,
fetch_next_task=None,
)

retries = 0
exception = None
while retries < 3:
try:
with self._metrics.timer(
"taskworker.update_task.rpc", tags={"service": self._service}
):
self._stub.SetTaskStatus(request)
exception = None
break
except grpc.RpcError as err:
exception = err
Comment thread
evanh marked this conversation as resolved.
self._metrics.incr(
"taskworker.client.rpc_error",
tags={"method": "SetTaskStatus", "status": err.code().name},
)
finally:
retries += 1

if exception:
raise exception
Loading
Loading