Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
72cb2aa
Add TCP send-side batching to eliminate per-message drain() overhead
kPsarakis May 16, 2026
742e69a
Add ruff noqa suppression for print statements in microbenchmark
kPsarakis May 16, 2026
c695475
Fix test_tcp_networking_coverage to match buffer_message API
kPsarakis May 16, 2026
dba3eb9
Fix all ruff complaints in tcp_networking.py and microbench_messaging.py
kPsarakis May 16, 2026
6910652
ruff format
kPsarakis May 16, 2026
4f0bb04
fix ruff
kPsarakis May 16, 2026
b0488c3
Switch flush loop from timer-based to asyncio.sleep(0)
kPsarakis May 16, 2026
898942a
Add phase-based microbenchmark modelling Styx's round-trip-per-phase …
kPsarakis May 16, 2026
1ed0945
Remove global socket lock from send_message hot path
kPsarakis May 16, 2026
34b876b
Revert TCP send-side batching, keep only the lock-free pool lookup
kPsarakis May 16, 2026
65a5085
TCP perf: client-side TCP_NODELAY + buffers, larger default pool
kPsarakis May 16, 2026
6b7bfca
TCP perf: cache 2-byte framing headers per (msg_type, serializer_id)
kPsarakis May 16, 2026
71303b6
TCP perf: load-aware connection pick in SocketPool
kPsarakis May 16, 2026
e95e11f
Add profile_live_worker.py: py-spy wrapper for live-process profiling
kPsarakis May 16, 2026
be22f6c
Bump socket buffers to 4 MB on both client and server
kPsarakis May 17, 2026
47e8105
Batch RunFunRemote+Ack, inline protocol dispatch, fix deepcopy fallth…
kPsarakis May 17, 2026
e3fcb88
Trim per-call async plumbing on Aria hot path
kPsarakis May 17, 2026
6f220aa
fix ruff
kPsarakis May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions coordinator/coordinator_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import botocore
from coordinator_metadata import Coordinator
from prometheus_client import Gauge, start_http_server
from styx.common.base_networking import SOCKET_RCV_BUF, SOCKET_SND_BUF
from styx.common.logging import logging
from styx.common.message_types import MessageType
from styx.common.protocols import Protocols
Expand Down Expand Up @@ -87,8 +88,8 @@ def __init__(self) -> None:
struct.pack("ii", 1, 0),
) # Enable LINGER, timeout 0
self.coor_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.coor_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024)
self.coor_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024)
self.coor_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, SOCKET_SND_BUF)
self.coor_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, SOCKET_RCV_BUF)
self.coor_socket.bind(("0.0.0.0", SERVER_PORT)) # noqa: S104
self.coor_socket.setblocking(False)

Expand All @@ -99,16 +100,8 @@ def __init__(self) -> None:
struct.pack("ii", 1, 0),
) # Enable LINGER, timeout 0
self.protocol_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.protocol_socket.setsockopt(
socket.SOL_SOCKET,
socket.SO_SNDBUF,
1024 * 1024,
)
self.protocol_socket.setsockopt(
socket.SOL_SOCKET,
socket.SO_RCVBUF,
1024 * 1024,
)
self.protocol_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, SOCKET_SND_BUF)
self.protocol_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, SOCKET_RCV_BUF)
self.protocol_socket.bind(("0.0.0.0", SERVER_PORT + 1)) # noqa: S104
self.protocol_socket.setblocking(False)

Expand Down
139 changes: 139 additions & 0 deletions scripts/profile_live_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""Profile a *running* Styx worker with py-spy.

The existing profile_hotpaths.py / profile_tpcc.py / profile_ycsb.py scripts
profile components in isolation (synthetic workloads). This one attaches to a
live worker process during a real TPC-C run and produces a flame graph showing
where wall-clock time actually goes — the only reliable way to find real
bottlenecks that synthetic microbenchmarks miss.

Prerequisite:
pip install py-spy

Usage:
# Local worker, known PID
python scripts/profile_live_worker.py --pid 12345 --duration 30

# Local worker by process name (matches the python script that contains it)
python scripts/profile_live_worker.py --name worker_service --duration 30

# All matching processes (e.g. multiple worker threads)
python scripts/profile_live_worker.py --name worker_service --all --duration 30

# Top-style live view (no SVG, just see-it-now)
python scripts/profile_live_worker.py --pid 12345 --top

Kubernetes (run inside the pod):
kubectl exec -it <worker-pod> -- bash -c 'pip install py-spy && py-spy record --pid 1 --duration 30 -o /tmp/p.svg'
kubectl cp <worker-pod>:/tmp/p.svg ./worker.svg

Output:
profile-<pid>-<timestamp>.svg (flame graph, open in browser)

Notes:
- py-spy is a sampling profiler so overhead is negligible (~1%).
- On Linux it needs CAP_SYS_PTRACE or root for attach. Try:
sudo setcap cap_sys_ptrace=eip $(which py-spy)
- On Windows it just works (admin not required for own processes).
"""

# ruff: noqa: S603, PLC0415

from __future__ import annotations

import argparse
from pathlib import Path
import shutil
import subprocess
import sys
import time


def find_pids_by_name(name: str) -> list[int]:
"""Find Python process PIDs whose command line contains `name`."""
try:
import psutil
except ImportError:
sys.exit("Need psutil for --name lookup. Run: pip install psutil")

pids: list[int] = []
for p in psutil.process_iter(["pid", "name", "cmdline"]):
try:
cmd = " ".join(p.info.get("cmdline") or [])
if name in cmd and "python" in (p.info.get("name") or "").lower():
pids.append(p.info["pid"])
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
return pids


def require_pyspy() -> str:
path = shutil.which("py-spy")
if path is None:
sys.exit("py-spy not found on PATH. Install with: pip install py-spy")
return path


def record(pyspy: str, pid: int, duration: int, output: Path) -> int:
cmd = [pyspy, "record", "--pid", str(pid), "--duration", str(duration), "-o", str(output)]
print(f" $ {' '.join(cmd)}")
return subprocess.call(cmd)


def top(pyspy: str, pid: int) -> int:
cmd = [pyspy, "top", "--pid", str(pid)]
print(f" $ {' '.join(cmd)}")
return subprocess.call(cmd)


def main() -> None:
ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
target = ap.add_mutually_exclusive_group(required=True)
target.add_argument("--pid", type=int, help="PID of the Styx worker process")
target.add_argument("--name", type=str, help="Process-name substring to locate worker(s)")
ap.add_argument("--duration", type=int, default=30, help="Recording duration in seconds (default: 30)")
ap.add_argument("--all", action="store_true", help="With --name: profile all matches in parallel")
ap.add_argument("--top", action="store_true", help="Show live top view instead of recording SVG")
args = ap.parse_args()

pyspy = require_pyspy()

if args.pid is not None:
pids = [args.pid]
else:
pids = find_pids_by_name(args.name)
if not pids:
sys.exit(f"No process found matching name={args.name!r}")
print(f"Found {len(pids)} match(es) for name={args.name!r}: {pids}")
if not args.all:
pids = pids[:1]
print(f"Using first match: {pids[0]} (pass --all to profile all)")

if args.top:
if len(pids) != 1:
sys.exit("--top requires a single PID (omit --all or use --pid)")
sys.exit(top(pyspy, pids[0]))

ts = time.strftime("%Y%m%d-%H%M%S")
procs: list[tuple[int, subprocess.Popen]] = []
for pid in pids:
out = Path(f"profile-{pid}-{ts}.svg")
cmd = [pyspy, "record", "--pid", str(pid), "--duration", str(args.duration), "-o", str(out)]
print(f"Recording PID {pid} -> {out}")
procs.append((pid, subprocess.Popen(cmd)))

failed = 0
for pid, proc in procs:
rc = proc.wait()
if rc != 0:
print(f" PID {pid}: py-spy exit {rc}", file=sys.stderr)
failed += 1
else:
print(f" PID {pid}: done")

if failed:
sys.exit(failed)


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions styx-package/styx/common/base_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@
USE_COMPRESSION: bool = bool(strtobool(os.getenv("ENABLE_COMPRESSION", "true")))
COMPRESS_AFTER: int = int(os.getenv("COMPRESS_AFTER", "4096"))

# Per-socket kernel buffer sizes. Applied on both client-side accepted sockets
# (StyxSocketClient) and server-side listening sockets (inherited to accepted).
# A TCP connection's in-flight window = min(sender.SNDBUF, receiver.RCVBUF), so
# both sides must agree or the smaller wins.
#
# Default 4 MB matches what Linux auto-tuning would settle on for typical
# datacenter BDPs (10-25 Gbps at 0.5-1 ms RTT). Note: setting these explicitly
# *disables* Linux kernel auto-tuning for the socket, so we trade dynamic
# growth for a known cap.
SOCKET_SND_BUF: int = int(os.getenv("SOCKET_SND_BUF", str(4 << 20)))
SOCKET_RCV_BUF: int = int(os.getenv("SOCKET_RCV_BUF", str(4 << 20)))


class MessagingMode(IntEnum):
WORKER_COR = 0
Expand Down
2 changes: 2 additions & 0 deletions styx-package/styx/common/message_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,5 @@ class MessageType(IntEnum):
MigrationInitDone = 39
InitDataComplete = 40
UpdateExecutionGraph = 41
RunFunRemoteBatch = 42
AckBatch = 43
25 changes: 15 additions & 10 deletions styx-package/styx/common/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ async def run_function(
await self._send_chain_abort(str(resp), ack_host, ack_port, ack_id)
success = False
elif n_remote_calls == 0:
await self.__send_ack(
# Non-awaiting: buffers the ack; flush happens automatically
# at the next event-loop tick (`NetworkingManager.enqueue_ack`).
self.__send_ack(
ack_host,
ack_port,
ack_id,
Expand Down Expand Up @@ -214,25 +216,28 @@ async def _send_chain_abort(
serializer=Serializer.MSGPACK,
)

async def __send_ack(
def __send_ack(
self,
ack_host: str,
ack_port: int,
ack_id: int,
fraction_str: str,
chain_participants: list[int],
) -> None:
"""Sends an acknowledgement to the worker that holds the root of a distributed chain during normal operation.
"""Records an acknowledgement for the root of a distributed chain.

Same-network: updates the local fraction map directly. Cross-network:
buffers the ack for the next event-loop tick; the networking layer
coalesces all acks for the same peer into a single `AckBatch` send.

Args:
ack_host: Hostname or IP of the next worker.
ack_port: Port number of the next worker.
ack_host: Hostname or IP of the chain root's worker.
ack_port: Port number of the chain root's worker.
ack_id: Acknowledgement ID for the chain.
fraction_str: Fraction of chain progress.
chain_participants: List of worker IDs that participated in the chain.
"""
if self.__networking.in_the_same_network(ack_host, ack_port):
# case when the ack host is the same worker
self.__networking.add_ack_fraction_str(
ack_id,
fraction_str,
Expand All @@ -241,12 +246,12 @@ async def __send_ack(
else:
if self.__networking.worker_id not in chain_participants:
chain_participants.append(self.__networking.worker_id)
await self.__networking.send_message(
self.__networking.enqueue_ack(
ack_host,
ack_port,
msg=(ack_id, fraction_str, chain_participants),
msg_type=MessageType.Ack,
serializer=Serializer.MSGPACK,
ack_id,
fraction_str,
chain_participants,
)

def __materialize_function(
Expand Down
91 changes: 62 additions & 29 deletions styx-package/styx/common/stateful_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,22 @@ async def __send_async_calls(
self.__networking.prepare_function_chain(self.__t_id)
elif not self.__networking.in_the_same_network(ack_host, ack_port):
chain_participants.append(self.__networking.worker_id)
# Every call in this chain step shares the same ack metadata; build
# the tuple once instead of allocating one per entry.
ack_payload: tuple[str, int, int, str, list[int]] = (
ack_host,
ack_port,
self.__t_id,
new_share_fraction,
chain_participants,
)
# Bucket remote (off-worker) calls by peer so we can send one
# RunFunRemoteBatch per peer instead of one RunFunRemote per call.
# Local calls still dispatch directly into the in-process protocol.
remote_calls: list[Awaitable] = []
remote_batches: dict[tuple[str, int], list[tuple]] = {}
for entry in self.__async_remote_calls:
operator_name, function_name, partition, key, params, is_local = entry
ack_payload: tuple[str, int, int, str, list[int]] = (
ack_host,
ack_port,
self.__t_id,
new_share_fraction,
chain_participants,
)
if is_local:
payload = RunFuncPayload(
request_id=self.__request_id,
Expand Down Expand Up @@ -306,17 +312,31 @@ async def __send_async_calls(
self.__protocol.run_function(t_id=self.__t_id, payload=payload),
)
else:
remote_calls.append(
self.__call_remote_function_no_response(
operator_name=operator_name,
function_name=function_name,
key=key,
partition=partition,
params=params,
ack_payload=ack_payload,
),
wire_payload, peer_host, peer_port = self.__prepare_message_transmission(
operator_name,
key,
function_name,
partition,
params,
ack_payload,
)
await asyncio.gather(*remote_calls)
remote_batches.setdefault((peer_host, peer_port), []).append(wire_payload)
for (peer_host, peer_port), batch in remote_batches.items():
remote_calls.append(
self.__networking.send_message(
peer_host,
peer_port,
msg=batch,
msg_type=MessageType.RunFunRemoteBatch,
serializer=Serializer.MSGPACK,
),
)
# Many chain stages have exactly one downstream call (e.g. NewOrder's
# warehouse → district hop). Skip the `gather` machinery in that case.
if len(remote_calls) == 1:
await remote_calls[0]
else:
await asyncio.gather(*remote_calls)
return n_remote_calls

def __get_partition(self, operator_name: str, key: K) -> int:
Expand Down Expand Up @@ -408,18 +428,31 @@ def __prepare_message_transmission(
tuple: (payload, operator_host, operator_port)
"""
try:
payload = (
self.__t_id, # __T_ID__
self.__request_id, # __RQ_ID__
operator_name, # __OP_NAME__
function_name, # __FUN_NAME__
key, # __KEY__
partition, # __PARTITION__
self.__fallback_enabled,
params,
) # __PARAMS__
if ack_payload is not None:
payload += (ack_payload,)
# Single-shot tuple build: avoids the `payload += (ack_payload,)`
# re-allocation on every remote call.
if ack_payload is None:
payload = (
self.__t_id, # __T_ID__
self.__request_id, # __RQ_ID__
operator_name, # __OP_NAME__
function_name, # __FUN_NAME__
key, # __KEY__
partition, # __PARTITION__
self.__fallback_enabled,
params, # __PARAMS__
)
else:
payload = (
self.__t_id,
self.__request_id,
operator_name,
function_name,
key,
partition,
self.__fallback_enabled,
params,
ack_payload,
)
operator_host = self.__dns[operator_name][partition][0]
operator_port = self.__dns[operator_name][partition][2]
except KeyError:
Expand Down
Loading
Loading