@codeflash-ai codeflash-ai bot commented Oct 4, 2025

📄 18% (0.18x) speedup for WSHandlerComm.write in distributed/comm/ws.py

⏱️ Runtime : 1.55 milliseconds → 1.31 milliseconds (best of 170 runs)

📝 Explanation and details

The optimization improves performance by avoiding expensive size calculations for small control messages. The key change is adding a _is_small_control_message() function that quickly identifies common lightweight message types (strings, bytes, small dictionaries/lists with primitive values, None) and skips the costly safe_sizeof() call for them.
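The PR text does not show the body of `_is_small_control_message()`, so the following is only a minimal sketch of what such a check could look like. The function name without the leading underscore, the helper `_is_small_primitive`, and the limits `_MAX_ITEMS` and `_MAX_LEN` are all assumptions made for illustration, not values taken from the actual change.

```python
from typing import Any

# Hypothetical limits; the PR text does not state the real thresholds.
_PRIMITIVES = (type(None), bool, int, float, str, bytes)
_MAX_ITEMS = 16
_MAX_LEN = 1024


def _is_small_primitive(obj: Any) -> bool:
    """True for None, bool, int, float, and short str/bytes values."""
    if not isinstance(obj, _PRIMITIVES):
        return False
    if isinstance(obj, (str, bytes)):
        return len(obj) <= _MAX_LEN
    return True


def is_small_control_message(msg: Any) -> bool:
    """Cheap, conservative check; False means "use the full size estimate"."""
    if _is_small_primitive(msg):
        return True
    if isinstance(msg, dict):
        return len(msg) <= _MAX_ITEMS and all(
            _is_small_primitive(k) and _is_small_primitive(v)
            for k, v in msg.items()
        )
    if isinstance(msg, (list, tuple)):
        return len(msg) <= _MAX_ITEMS and all(_is_small_primitive(v) for v in msg)
    # Anything nested or of an unknown type is treated as "not small".
    return False
```

The check is deliberately one level deep: a nested container returns False and falls back to the slower path, so a misclassification can only cost time, never skip a size estimate for a payload that might be large.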

What was optimized:

  • Added fast-path detection for small control messages (None, primitives, small dicts/lists with primitive values)
  • When a message is identified as small, the expensive safe_sizeof() calculation is completely bypassed
  • Large or complex messages still use the original logic with safe_sizeof() and potential offloading
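The three points above describe a branch order: fast path first, size estimate second, offloading only for large payloads. A rough sketch of that control flow follows. The function `serialize`, the `OFFLOAD_THRESHOLD` value, the `repr`-based serializer, and the use of `sys.getsizeof` in place of `safe_sizeof()` are stand-ins chosen for illustration; they are not the code in `distributed`.

```python
import asyncio
import sys
from typing import Any, Callable

OFFLOAD_THRESHOLD = 10_000_000  # hypothetical byte threshold for offloading


async def serialize(
    msg: Any,
    is_small: Callable[[Any], bool],
    estimate_size: Callable[[Any], int] = sys.getsizeof,
) -> tuple[bytes, str]:
    """Return (payload, path), where path names the branch that was taken."""

    def dump() -> bytes:
        # Stand-in for real serialization.
        return repr(msg).encode()

    if is_small(msg):
        # Fast path: the size estimate is never computed.
        return dump(), "fast"
    if estimate_size(msg) > OFFLOAD_THRESHOLD:
        # Large payload: serialize in a worker thread, off the event loop.
        return await asyncio.to_thread(dump), "offloaded"
    return dump(), "inline"
```

Passing the predicate and the size estimator as parameters makes it easy to confirm in a test that the estimator is not called when the predicate returns True.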

Why this leads to speedup:

  • The line profiler shows safe_sizeof() consumed 37% of the original function's time (1.23ms out of 3.32ms total)
  • In the optimized version, this expensive call is eliminated for small messages, reducing the function's total time to 2.43ms
  • The _is_small_control_message() check (515ns per call) is much faster than safe_sizeof() (2671ns per call in the original)
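The profiler figures quoted above can be cross-checked with a few lines of arithmetic. The variable names are illustrative; the numbers are the ones stated in the list.

```python
# Figures quoted from the line-profiler summary above.
total_before_ms = 3.32
total_after_ms = 2.43
sizeof_ms = 1.23
sizeof_ns_per_call = 2671
check_ns_per_call = 515

sizeof_share = sizeof_ms / total_before_ms      # fraction of time in safe_sizeof()
saved_ms = total_before_ms - total_after_ms     # profiled time removed overall
per_call_ratio = sizeof_ns_per_call / check_ns_per_call

print(f"safe_sizeof() share: {sizeof_share:.0%}")         # 37%
print(f"profiled time saved: {saved_ms:.2f} ms")          # 0.89 ms
print(f"per-call cost ratio: {per_call_ratio:.1f}x")      # 5.2x
```

The 0.89 ms saved is less than the 1.23 ms originally spent in `safe_sizeof()`, which is what one would expect if the new check adds its own per-call cost and some messages still take the original path.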

Performance characteristics:

  • Runtime improvement: 18% faster (1.55ms → 1.31ms)
  • Throughput improvement: 0.6% (78,247 → 78,710 operations/second)
  • The optimization is most effective for workloads with many small control messages, which are common in distributed systems for coordination and status updates
  • Large messages still benefit from proper offloading decisions via the original safe_sizeof() logic
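The "18%" and "0.18x" figures appear to follow the convention speedup = old / new - 1, which is not the same as the fraction of runtime removed. That reading is an inference from the numbers, not something the PR states. The helper names below are illustrative.

```python
def speedup(old: float, new: float) -> float:
    """Relative speedup: extra work done per unit time, old / new - 1."""
    return old / new - 1


def time_reduction(old: float, new: float) -> float:
    """Fraction of the original runtime that was removed."""
    return (old - new) / old


runtime_speedup = speedup(1.55, 1.31)            # about 0.18
runtime_reduction = time_reduction(1.55, 1.31)   # about 0.155
throughput_gain = 78_710 / 78_247 - 1            # about 0.006

print(f"speedup:         {runtime_speedup:.0%}")
print(f"time reduction:  {runtime_reduction:.1%}")
print(f"throughput gain: {throughput_gain:.1%}")
```

Under this reading the runtime dropped by about 15.5%, while measured throughput rose by only about 0.6%, so the two headline numbers describe different measurements and should not be expected to match.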

This optimization targets the common case where distributed systems exchange many small control messages, avoiding unnecessary computational overhead while preserving the original behavior for complex or large payloads.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 498 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 84.6% |
🌀 Generated Regression Tests and Runtime
import asyncio  # used to run async functions

import pytest  # used for our unit tests
from distributed.comm.ws import WSHandlerComm


# Mocks and stubs for dependencies
class DummyWebSocketClosedError(Exception):
    pass

class DummyCommClosedError(Exception):
    pass

class DummyHandler:
    """A dummy handler to simulate write_message for testing."""
    def __init__(self):
        self.messages = []
        self.closed = False

    async def write_message(self, msg, binary=True):
        # Simulate normal write
        if self.closed:
            raise DummyWebSocketClosedError("WebSocket is closed")
        self.messages.append((msg, binary))

class DummyWSHandlerComm:
    """Minimal stub of WSHandlerComm for unit testing."""
    def __init__(self, handler, deserialize=True, allow_offload=True):
        self.handler = handler
        self.allow_offload = allow_offload
        self.local_info = {"id": "local"}
        self.remote_info = {"id": "remote"}
        self.handshake_options = {}
        self.deserialize = deserialize

    async def write(self, msg, serializers=None, on_error=None):
        # Simulate serialization: turn msg into list of bytes frames
        if isinstance(msg, (bytes, bytearray)):
            frames = [bytes(msg)]
        elif isinstance(msg, str):
            frames = [msg.encode()]
        elif isinstance(msg, list):
            frames = [str(item).encode() for item in msg]
        elif isinstance(msg, dict):
            frames = [str(msg).encode()]
        elif msg is None:
            frames = [b"None"]
        else:
            frames = [str(msg).encode()]
        n = len(frames).to_bytes(8, "little")
        nbytes_frames = 0
        try:
            await self.handler.write_message(n, binary=True)
            for frame in frames:
                if type(frame) is not bytes:
                    frame = bytes(frame)
                await self.handler.write_message(frame, binary=True)
                nbytes_frames += len(frame)
        except DummyWebSocketClosedError as e:
            raise DummyCommClosedError(str(e))
        return nbytes_frames

# Alias for clarity in tests
write = DummyWSHandlerComm.write

# ------------------- BASIC TEST CASES -------------------

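@pytest.mark.asyncio
async def test_write_basic_bytes():
    """Test writing a simple bytes message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = b"hello world"
    nbytes = await comm.write(msg)
    # One 8-byte frame-count header, then the payload frame.
    assert nbytes == len(msg)
    assert handler.messages == [((1).to_bytes(8, "little"), True), (msg, True)]

@pytest.mark.asyncio
async def test_write_basic_str():
    """Test writing a simple string message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = "hello world"
    nbytes = await comm.write(msg)
    assert nbytes == len(msg.encode())
    assert handler.messages[-1] == (msg.encode(), True)

@pytest.mark.asyncio
async def test_write_basic_list():
    """Test writing a list of integers."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = [1, 2, 3]
    nbytes = await comm.write(msg)
    # Each list item becomes its own single-byte frame.
    assert nbytes == 3
    assert [m for m, _ in handler.messages] == [(3).to_bytes(8, "little"), b"1", b"2", b"3"]

@pytest.mark.asyncio
async def test_write_basic_dict():
    """Test writing a dict message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = {"foo": "bar"}
    nbytes = await comm.write(msg)
    assert nbytes == len(str(msg).encode())

@pytest.mark.asyncio
async def test_write_basic_none():
    """Test writing None."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = None
    nbytes = await comm.write(msg)
    assert nbytes == len(b"None")
    assert handler.messages[-1] == (b"None", True)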

# ------------------- EDGE TEST CASES -------------------

@pytest.mark.asyncio
async def test_write_closed_handler_raises():
    """Test writing to a closed handler raises CommClosedError."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    handler.closed = True
    with pytest.raises(DummyCommClosedError):
        await comm.write(b"test")

@pytest.mark.asyncio
async def test_write_concurrent_execution():
    """Test concurrent writes do not interfere."""
    handler1 = DummyHandler()
    handler2 = DummyHandler()
    comm1 = DummyWSHandlerComm(handler1)
    comm2 = DummyWSHandlerComm(handler2)
    msgs = [b"one", b"two"]
    results = await asyncio.gather(comm1.write(msgs[0]), comm2.write(msgs[1]))
    assert results == [3, 3]
    assert handler1.messages[-1] == (b"one", True)
    assert handler2.messages[-1] == (b"two", True)

@pytest.mark.asyncio
async def test_write_empty_bytes():
    """Test writing an empty bytes message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = b""
    nbytes = await comm.write(msg)
    # The header is still sent, followed by one empty frame.
    assert nbytes == 0
    assert handler.messages == [((1).to_bytes(8, "little"), True), (b"", True)]

@pytest.mark.asyncio
async def test_write_large_message():
    """Test writing a large message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = b"x" * 1024  # 1KB
    nbytes = await comm.write(msg)
    assert nbytes == 1024

# ------------------- LARGE SCALE TEST CASES -------------------

@pytest.mark.asyncio
async def test_write_many_concurrent_writes():
    """Test many concurrent writes."""
    handlers = [DummyHandler() for _ in range(10)]
    comms = [DummyWSHandlerComm(h) for h in handlers]
    msgs = [f"msg_{i}".encode() for i in range(10)]
    results = await asyncio.gather(*[comm.write(msg) for comm, msg in zip(comms, msgs)])
    assert results == [len(msg) for msg in msgs]
    for i, handler in enumerate(handlers):
        # Each handler receives only its own payload.
        assert handler.messages[-1] == (msgs[i], True)

@pytest.mark.asyncio
async def test_write_large_list_message():
    """Test writing a large list message."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msg = list(range(100))
    nbytes = await comm.write(msg)
    assert nbytes == sum(len(str(item)) for item in msg)
    # One header message plus one frame per list item.
    assert len(handler.messages) == len(msg) + 1

# ------------------- THROUGHPUT TEST CASES -------------------

@pytest.mark.asyncio
async def test_write_throughput_small_load():
    """Test throughput with small load."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msgs = [b"a", b"b", b"c"]
    results = await asyncio.gather(*[comm.write(msg) for msg in msgs])
    assert results == [1, 1, 1]
    # Each write sends one header and one frame to the shared handler.
    assert len(handler.messages) == 2 * len(msgs)

@pytest.mark.asyncio
async def test_write_throughput_medium_load():
    """Test throughput with medium load."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msgs = [bytes([i]) * 10 for i in range(20)]
    results = await asyncio.gather(*[comm.write(msg) for msg in msgs])
    for result, msg in zip(results, msgs):
        assert result == len(msg)

@pytest.mark.asyncio
async def test_write_throughput_high_volume():
    """Test throughput with high volume (100 concurrent writes)."""
    handlers = [DummyHandler() for _ in range(100)]
    comms = [DummyWSHandlerComm(h) for h in handlers]
    msgs = [f"high_{i}".encode() for i in range(100)]
    results = await asyncio.gather(*[comm.write(msg) for comm, msg in zip(comms, msgs)])
    assert results == [len(msg) for msg in msgs]
    for i, handler in enumerate(handlers):
        assert handler.messages[-1] == (msgs[i], True)

@pytest.mark.asyncio
async def test_write_throughput_varied_types():
    """Test throughput with varied message types."""
    handler = DummyHandler()
    comm = DummyWSHandlerComm(handler)
    msgs = [b"bytes", "str", 123, None, {"a": 1}]
    results = await asyncio.gather(*[comm.write(msg) for msg in msgs])
    # Results should match frame lengths (b"bytes" is 5 bytes long, not 6)
    expected_lengths = [5, 3, 3, 4, len(str({"a": 1}).encode())]
    assert results == expected_lengths
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import asyncio  # used to run async functions
import logging
import struct
# Patch protocol and exceptions for testing
import sys
import threading
import warnings
import weakref
from abc import ABC
from collections.abc import Awaitable, Callable
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, ClassVar, TypeVar

import dask
import pytest  # used for our unit tests
from dask.sizeof import sizeof
from dask.utils import format_bytes, parse_bytes, typename
from distributed import protocol
from distributed.comm.core import Comm, CommClosedError
from distributed.comm.utils import to_frames
from distributed.comm.ws import WSHandlerComm
from distributed.metrics import context_meter
from distributed.sizeof import safe_sizeof
from distributed.utils import offload
from tornado.websocket import WebSocketClosedError
from typing_extensions import ParamSpec


# Mocks and helpers for testing
class DummyWebSocketClosedError(Exception):
    pass

class DummyHandler:
    """
    Dummy handler to simulate the behavior of a Tornado WebSocket handler.
    It records messages written for verification.
    """
    def __init__(self, fail_on_write=False):
        self.messages = []
        self.fail_on_write = fail_on_write
        self.closed = False

    async def write_message(self, message, binary=False):
        # Simulate possible WebSocketClosedError
        if self.fail_on_write or self.closed:
            raise DummyWebSocketClosedError("WebSocket is closed")
        self.messages.append((message, binary))


BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(
    dask.config.get("distributed.comm.websockets.shard")
)

# Unit tests for WSHandlerComm.write

@pytest.mark.asyncio
async def test_write_basic_bytes():
    """Test writing a simple bytes message."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = b"hello world"
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_basic_str():
    """Test writing a simple string message."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = "foo"
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_basic_dict():
    """Test writing a basic dictionary."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = {"bar": 123}
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_empty_bytes():
    """Test writing an empty bytes message."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = b""
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_empty_dict():
    """Test writing an empty dict."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = {}
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_large_bytes():
    """Test writing a large bytes message."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = b"x" * 100
    result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_serializers_and_on_error():
    """Test writing with serializers and on_error arguments."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msg = "baz"
    result = await comm.write(msg, serializers=["msgpack"], on_error="raise")

@pytest.mark.asyncio
async def test_write_concurrent():
    """Test concurrent writes to different handlers."""
    handler1 = DummyHandler()
    handler2 = DummyHandler()
    comm1 = WSHandlerComm(handler1)
    comm2 = WSHandlerComm(handler2)
    msg1 = b"abc"
    msg2 = "def"
    # Run both writes concurrently
    results = await asyncio.gather(
        comm1.write(msg1),
        comm2.write(msg2)
    )

@pytest.mark.asyncio
async def test_write_multiple_types():
    """Test writing multiple types of messages."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    messages = [
        b"hello",
        "world",
        123,
        {"key": "value"},
        [1, 2, 3]
    ]
    for msg in messages:
        result = await comm.write(msg)

@pytest.mark.asyncio
async def test_write_concurrent_same_handler():
    """Test concurrent writes to the same handler."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [b"one", b"two", b"three"]
    results = await asyncio.gather(
        comm.write(msgs[0]),
        comm.write(msgs[1]),
        comm.write(msgs[2])
    )

@pytest.mark.asyncio
async def test_write_throughput_small_load():
    """Test throughput under small load (10 messages)."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [f"msg{i}" for i in range(10)]
    results = await asyncio.gather(*(comm.write(msg) for msg in msgs))

@pytest.mark.asyncio
async def test_write_throughput_medium_load():
    """Test throughput under medium load (100 messages)."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [b"x" * 10 for _ in range(100)]
    results = await asyncio.gather(*(comm.write(msg) for msg in msgs))

@pytest.mark.asyncio
async def test_write_throughput_large_load():
    """Test throughput under large load (250 messages)."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [f"large_message_{i}" for i in range(250)]
    results = await asyncio.gather(*(comm.write(msg) for msg in msgs))

@pytest.mark.asyncio
async def test_write_throughput_concurrent_handlers():
    """Test throughput with concurrent writes to multiple handlers."""
    handlers = [DummyHandler() for _ in range(10)]
    comms = [WSHandlerComm(h) for h in handlers]
    msgs = [f"msg_{i}" for i in range(10)]
    coros = [comm.write(msg) for comm, msg in zip(comms, msgs)]
    results = await asyncio.gather(*coros)
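    assert len(results) == len(handlers)
    for h in handlers:
        # A successful write should have recorded at least one message.
        assert h.messages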

@pytest.mark.asyncio
async def test_write_throughput_large_bytes():
    """Test throughput with large byte messages."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [b"x" * 100 for _ in range(20)]
    results = await asyncio.gather(*(comm.write(msg) for msg in msgs))

@pytest.mark.asyncio
async def test_write_throughput_varied_types():
    """Test throughput with varied message types."""
    handler = DummyHandler()
    comm = WSHandlerComm(handler)
    msgs = [b"bytes", "str", 123, {"x": 1}, [1, 2, 3]] * 10
    results = await asyncio.gather(*(comm.write(msg) for msg in msgs))
    # One result per message; every write should report an integer byte count.
    assert len(results) == len(msgs)
    for r in results:
        assert isinstance(r, int)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-WSHandlerComm.write-mgbrkmn5 and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 4, 2025 04:19
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Oct 4, 2025