140 changes: 115 additions & 25 deletions src/launchpad/kafka.py
@@ -6,6 +6,7 @@
import multiprocessing
import os
import sys
import threading

from dataclasses import dataclass
from functools import partial
@@ -53,9 +54,30 @@ def _process_in_subprocess(decoded_message: Any, log_queue: multiprocessing.Queue
        sys.exit(1)


def _kill_process(process: multiprocessing.Process, artifact_id: str) -> None:
    """Gracefully terminate, then force kill a subprocess."""
    process.terminate()
    process.join(timeout=5)
    if process.is_alive():
        logger.warning(
            "Process did not terminate gracefully, force killing",
            extra={"artifact_id": artifact_id},
        )
        process.kill()
        process.join(timeout=1)  # Brief timeout to reap zombie, avoid infinite block
        if process.is_alive():
            logger.error(
                "Process could not be killed, may become zombie",
                extra={"artifact_id": artifact_id},
            )
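The helper above escalates from a graceful terminate() to a hard kill(), with bounded join() calls so the consumer thread never blocks forever on a stuck child. A minimal, self-contained POSIX sketch of that escalation pattern; the worker function and timings below are invented for illustration and are not part of this PR:

# Toy demo of the terminate -> kill escalation: a child that ignores SIGTERM
# only exits after SIGKILL. All names here are made up for illustration.
import multiprocessing
import signal
import time


def ignores_sigterm() -> None:
    signal.signal(signal.SIGTERM, signal.SIG_IGN)  # refuse graceful shutdown
    time.sleep(60)


if __name__ == "__main__":
    child = multiprocessing.Process(target=ignores_sigterm)
    child.start()
    time.sleep(1)              # give the child time to install its handler
    child.terminate()          # SIGTERM: ignored by this child
    child.join(timeout=2)
    print("alive after terminate:", child.is_alive())  # True
    child.kill()               # SIGKILL: cannot be ignored
    child.join(timeout=1)      # brief join to reap the child
    print("alive after kill:", child.is_alive())       # False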


def process_kafka_message_with_service(
    msg: Message[KafkaPayload],
    log_queue: multiprocessing.Queue[Any],
    process_registry: dict[int, tuple[multiprocessing.Process, str]],
    registry_lock: threading.Lock,
    factory: LaunchpadStrategyFactory,
) -> Any:
    """Process a Kafka message by spawning a fresh subprocess with timeout protection."""
    timeout = int(os.getenv("KAFKA_TASK_TIMEOUT_SECONDS", "720"))  # 12 minutes default
@@ -71,32 +93,49 @@ def process_kafka_message_with_service(
     # Spawn actual processing in a subprocess
     process = multiprocessing.Process(target=_process_in_subprocess, args=(decoded, log_queue))
     process.start()
-    process.join(timeout=timeout)
 
-    if process.is_alive():
-        logger.error(
-            "Launchpad task killed after exceeding timeout",
-            extra={"timeout_seconds": timeout, "artifact_id": artifact_id},
-        )
-        process.terminate()
-        process.join(timeout=5)  # Give it 5s to terminate gracefully
+    # Register the process for tracking (PID is always set after start())
+    with registry_lock:
+        process_registry[process.pid] = (process, artifact_id)  # type: ignore[index]
+
+    try:
+        process.join(timeout=timeout)
+
+        # Handle timeout (process still alive after full timeout)
         if process.is_alive():
-            logger.warning(
-                "Process did not terminate gracefully, force killing",
-                extra={"artifact_id": artifact_id},
+            logger.error(
+                "Launchpad task killed after exceeding timeout",
+                extra={"timeout_seconds": timeout, "artifact_id": artifact_id},
             )
-            process.kill()
-            process.join()
-        return None  # type: ignore[return-value]
-
-    if process.exitcode != 0:
-        logger.error(
-            "Process exited with non-zero code",
-            extra={"exit_code": process.exitcode, "artifact_id": artifact_id},
-        )
-        return None  # type: ignore[return-value]
+            _kill_process(process, artifact_id)
+            return None  # type: ignore[return-value]
+
+        if process.exitcode != 0:
+            # Check if we killed it during rebalance - if so, don't commit offset
+            pid = process.pid
+            if pid is not None:
+                with registry_lock:
+                    was_killed_by_rebalance = pid in factory._killed_during_rebalance
+                    factory._killed_during_rebalance.discard(pid)
+
+                if was_killed_by_rebalance:
+                    logger.warning(
+                        "Process killed during rebalance, message will be reprocessed",
+                        extra={"artifact_id": artifact_id},
+                    )
+                    raise TimeoutError("Subprocess killed during rebalance")
+
+            # All other failures - skip message
+            logger.error(
+                "Process exited with non-zero code",
+                extra={"exit_code": process.exitcode, "artifact_id": artifact_id},
            )
+            return None  # type: ignore[return-value]
 
-    return decoded  # type: ignore[no-any-return]
+        return decoded  # type: ignore[no-any-return]
+    finally:
+        with registry_lock:
+            process_registry.pop(process.pid, None)
Member: So this also handles removing PIDs of successful message processes, correct?

Contributor Author: yes!
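The guarantee confirmed here comes from the try/finally added to process_kafka_message_with_service: the finally block pops the PID from the shared registry on success, on non-zero exit, and on timeout alike. A toy sketch of that lifecycle, with invented names rather than the PR code:

# Illustrative only: the registry entry exists while work is in flight and is
# removed in `finally` no matter which way the function returns.
import threading

registry: dict[int, str] = {}
lock = threading.Lock()


def run_tracked(pid: int, artifact_id: str, fail: bool) -> str | None:
    with lock:
        registry[pid] = artifact_id
    try:
        if fail:
            return None        # mirrors the timeout / non-zero-exit early returns
        return artifact_id     # mirrors the success path
    finally:
        with lock:
            registry.pop(pid, None)  # runs on every exit path


run_tracked(101, "artifact-a", fail=False)
run_tracked(102, "artifact-b", fail=True)
assert not registry  # empty again after both calls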



def create_kafka_consumer() -> LaunchpadKafkaConsumer:
@@ -155,6 +194,32 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:
    return LaunchpadKafkaConsumer(processor, strategy_factory, healthcheck_path)


class ShutdownAwareStrategy(ProcessingStrategy[KafkaPayload]):
Contributor Author: id love to not have to do this.. 🤔

Contributor Author: but i dont think there is another way to intercept the _close_strategy call that we care about..

    """Wrapper that kills active subprocesses during rebalance."""

    def __init__(self, inner: ProcessingStrategy[KafkaPayload], factory: LaunchpadStrategyFactory):
        self._inner = inner
        self._factory = factory

    def submit(self, message: Message[KafkaPayload]) -> None:
        self._inner.submit(message)

    def poll(self) -> None:
        self._inner.poll()

    def close(self) -> None:
        # Kill all active subprocesses BEFORE closing inner strategy
        self._factory.kill_active_processes()
        self._inner.close()

    def terminate(self) -> None:
        self._factory.kill_active_processes()
        self._inner.terminate()

    def join(self, timeout: float | None = None) -> None:
        self._inner.join(timeout)
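The wrapper's whole job is to run kill_active_processes() before delegating close() and terminate() to the inner strategy, so that ordering is the property worth pinning down in a test. A hypothetical test sketch; the import path launchpad.kafka and the mock setup are assumptions, not part of this PR:

# Hypothetical ordering test: close() must kill subprocesses before the inner
# strategy is closed. Mock names and the import path are assumptions.
from unittest.mock import MagicMock, call

from launchpad.kafka import ShutdownAwareStrategy

parent = MagicMock()
inner = parent.inner      # stands in for the RunTaskInThreads strategy
factory = parent.factory  # stands in for LaunchpadStrategyFactory

ShutdownAwareStrategy(inner, factory).close()

calls = parent.mock_calls
assert calls.index(call.factory.kill_active_processes()) < calls.index(call.inner.close())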


class LaunchpadKafkaConsumer:
    processor: StreamProcessor[KafkaPayload]
    strategy_factory: LaunchpadStrategyFactory
@@ -214,6 +279,10 @@ def __init__(
        self._queue_listener = self._setup_queue_listener()
        self._queue_listener.start()

        self._active_processes: dict[int, tuple[multiprocessing.Process, str]] = {}
        self._processes_lock = threading.Lock()
        self._killed_during_rebalance: set[int] = set()

        self.concurrency = concurrency
        self.max_pending_futures = max_pending_futures
        self.healthcheck_file = healthcheck_file
@@ -225,6 +294,21 @@ def _setup_queue_listener(self) -> QueueListener:

        return QueueListener(self._log_queue, *handlers, respect_handler_level=True)

    def kill_active_processes(self) -> None:
        """Kill all active subprocesses. Called during rebalancing."""
        with self._processes_lock:
            if self._active_processes:
                logger.info(
                    "Killing %d active subprocess(es) during rebalance",
                    len(self._active_processes),
                )
            for pid, (process, artifact_id) in list(self._active_processes.items()):
                if process.is_alive():
                    self._killed_during_rebalance.add(pid)
                    logger.info("Terminating subprocess with PID %d", pid)
                    _kill_process(process, artifact_id)
            self._active_processes.clear()
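Read together with the exit-code handling in process_kafka_message_with_service, marking the PID in _killed_during_rebalance is what lets the worker thread distinguish a subprocess killed for a rebalance from an ordinary crash, raise TimeoutError instead of committing the offset, and leave the message to be reprocessed. A condensed sketch of that handshake, simplified and not the PR code:

# Simplified handshake between the rebalance path (consumer thread) and the
# worker thread that owns the subprocess. Names are illustrative.
import threading

killed_during_rebalance: set[int] = set()
lock = threading.Lock()


def on_rebalance(pid: int) -> None:
    with lock:
        killed_during_rebalance.add(pid)   # mark before killing the subprocess


def on_nonzero_exit(pid: int) -> None:
    with lock:
        was_rebalance = pid in killed_during_rebalance
        killed_during_rebalance.discard(pid)
    if was_rebalance:
        # Raising here keeps the offset uncommitted so the message is retried.
        raise TimeoutError("Subprocess killed during rebalance")
    # Otherwise: log the failure and drop the message.


on_rebalance(4242)
try:
    on_nonzero_exit(4242)
except TimeoutError as exc:
    print("will be reprocessed:", exc)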

    def create_with_partitions(
        self,
        commit: Commit,
@@ -235,15 +319,21 @@ def create_with_partitions(
         assert self.healthcheck_file
         next_step = Healthcheck(self.healthcheck_file, next_step)
 
-        processing_function = partial(process_kafka_message_with_service, log_queue=self._log_queue)
-        strategy = RunTaskInThreads(
+        processing_function = partial(
+            process_kafka_message_with_service,
+            log_queue=self._log_queue,
+            process_registry=self._active_processes,
+            registry_lock=self._processes_lock,
+            factory=self,
+        )
+        inner_strategy = RunTaskInThreads(
             processing_function=processing_function,
             concurrency=self.concurrency,
             max_pending_futures=self.max_pending_futures,
             next_step=next_step,
         )
 
-        return strategy
+        return ShutdownAwareStrategy(inner_strategy, self)
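The partial() call pre-binds everything the worker needs besides the message itself: the shared registry, its lock, and the factory. RunTaskInThreads can then keep invoking the processing function with a single argument. A small illustration of that pre-binding, with invented names rather than the PR code:

# Toy illustration of pre-binding shared state with functools.partial.
from functools import partial


def handle(msg: str, *, registry: dict[int, str], tag: str) -> str:
    registry[len(registry)] = msg
    return f"{tag}:{msg}"


shared: dict[int, str] = {}
processing_function = partial(handle, registry=shared, tag="launchpad")

print(processing_function("payload-1"))  # launchpad:payload-1
print(shared)                            # {0: 'payload-1'}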

    def close(self) -> None:
        """Clean up the logging queue and listener."""