Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions buckaroo/dataflow/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,27 @@ def _get_summary_sd(self, df:pd.DataFrame) -> Tuple[SDType, TAny]:

_summary_sd_cache_key = (None, None)

# Spike (rows-first WS protocol): when set, the next firing of the
# ``_summary_sd`` observer is short-circuited so the cascade lands
# ``processed_df`` + ``df_meta`` (the cheap parts) without paying the
# analysis-pipeline cost. Caller is responsible for clearing the flag
# and re-running ``_summary_sd`` via an explicit ``recompute_summary_sd()``
# call once the row-data has been shipped. **Not for production paths.**
_defer_summary_sd = False

def recompute_summary_sd(self):
"""Force a (re-)compute of ``summary_sd`` against the current
``processed_df``. Used by the spike state-change handler to lift
the ``_defer_summary_sd`` short-circuit after the row-data has
been pushed to the client.
"""
df = self.processed_df
if df is None:
return
result_summary_sd, errs = self._get_summary_sd(df)
self.summary_sd = result_summary_sd
self.errs = errs

@observe('processed_result', 'analysis_klasses')
@exception_protect('summary_sd-protector')
def _summary_sd(self, change):
Expand All @@ -259,6 +280,11 @@ def _summary_sd(self, change):
# construction even when processed_df identity is unchanged.
# Skip when neither the dataframe nor analysis_klasses has actually
# changed since the last run. See issue #709.
if self._defer_summary_sd:
# Spike: caller will trigger recompute after sending the row-data
# message. Leaves ``self.summary_sd`` at its prior value so
# downstream cache observer still wires up a coherent state.
return
df = self.processed_df
klasses = self.analysis_klasses
if (id(df), id(klasses)) == self._summary_sd_cache_key:
Expand Down
127 changes: 101 additions & 26 deletions buckaroo/server/websocket_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from urllib.parse import urlparse

import tornado.websocket
from tornado.ioloop import IOLoop

from buckaroo.server.data_loading import (handle_infinite_request, handle_infinite_request_buckaroo, handle_infinite_request_lazy, get_buckaroo_display_state)
from buckaroo.server.session import build_state_message
Expand All @@ -20,6 +21,24 @@ def _handle_infinite_request_xorq(xorq_dataflow, payload_args):

_BUCKAROO_DEBUG = os.environ.get("BUCKAROO_DEBUG", "").lower() in ("1", "true")

# Spike: opt-in two-message "rows first" state-change response. When set,
# the state-change handler sends a fast initial_state with deferred stats,
# then an IOLoop.add_callback fires the stats compute and sends a second
# initial_state with fresh stats. Off by default so existing tests
# (which assume a single rebroadcast frame) keep working.
_ROWS_FIRST_SPIKE = os.environ.get("BUCKAROO_ROWS_FIRST_SPIKE", "").lower() in ("1", "true")
# Delay between phase 1 (skeleton + deferred stats) and phase 2 (stats
# compute + push). The goal is for any AG-Grid-fired ``infinite_request``
# triggered by phase 1 to land on the server *before* the stats compute
# starts, so its parquet response arrives between the two messages.
#
# 10ms is a localhost-tuned starting point: loopback WS RTT is well under
# 1ms, so 10ms is comfortable. For non-local deployments (RTT 30-100ms+)
# this needs to be measured and raised — at typical remote RTTs, 10ms is
# *shorter* than the round-trip and the spike's premise breaks.
_ROWS_FIRST_SPIKE_PHASE2_DELAY_S = float(
os.environ.get("BUCKAROO_ROWS_FIRST_SPIKE_DELAY_S", "0.01"))

# Fields in buckaroo_state that drive dataflow changes; others are ignored.
_DATAFLOW_FIELDS = ("post_processing", "cleaning_method", "quick_command_args")

Expand Down Expand Up @@ -77,33 +96,33 @@ def _handle_buckaroo_state_change(self, new_state):
if old_state.get("quick_command_args") != new_state.get("quick_command_args"):
dataflow.quick_command_args = new_state.get("quick_command_args", {})

# Re-extract state from the dataflow — same helper works for both
# ServerDataflow and XorqServerDataflow (verified by probe).
buckaroo_state = get_buckaroo_display_state(dataflow)
session.df_display_args = buckaroo_state["df_display_args"]
session.df_data_dict = buckaroo_state["df_data_dict"]
session.df_meta = buckaroo_state["df_meta"]
session.buckaroo_state = new_state
session.buckaroo_options = buckaroo_state["buckaroo_options"]
session.command_config = buckaroo_state["command_config"]

# Re-apply component_config so theme settings survive state changes
if session.component_config and session.df_display_args:
for key in session.df_display_args:
dvc = session.df_display_args[key].get("df_viewer_config")
if dvc is not None:
dvc["component_config"] = {
**dvc.get("component_config", {}),
**session.component_config,
}

# Broadcast updated state to all connected clients
update_payload = json.dumps(build_state_message(session))
for client in list(session.ws_clients):
# Spike-gated "rows first" two-message response: when on, the
# state-change handler returns ``initial_state`` with deferred
# stats first, then a second ``initial_state`` once the
# analysis-pipeline pass has finished. Default (flag off) is
# today's single-message rebroadcast.
if _ROWS_FIRST_SPIKE:
# Phase 1: drive the dataflow far enough to know the new
# ``df_meta`` + ``df_display_args``, but skip the
# analysis-pipeline pass that produces ``summary_sd``.
dataflow._defer_summary_sd = True
try:
client.write_message(update_payload)
except Exception:
session.ws_clients.discard(client)
buckaroo_state = get_buckaroo_display_state(dataflow)
finally:
dataflow._defer_summary_sd = False
else:
buckaroo_state = get_buckaroo_display_state(dataflow)

session.buckaroo_state = new_state
self._apply_state_and_broadcast(session, buckaroo_state)

# Phase 2 (spike only): schedule the stats compute with a
# short delay so any ``infinite_request`` the client fires
# in response to phase 1 has a real time window to arrive
# and get serviced before the (potentially expensive) stats
# compute starts.
if _ROWS_FIRST_SPIKE:
IOLoop.current().call_later(_ROWS_FIRST_SPIKE_PHASE2_DELAY_S, self._send_stats_update, session)
except Exception:
tb = traceback.format_exc()
log.error("buckaroo_state_change error session=%s: %s", self.session_id, tb)
Expand All @@ -112,6 +131,62 @@ def _handle_buckaroo_state_change(self, new_state):
err["details"] = tb
self.write_message(json.dumps(err))

def _apply_state_and_broadcast(self, session, buckaroo_state):
"""Push a freshly-extracted ``buckaroo_state`` dict into the session,
re-apply ``component_config`` so theme settings survive, and broadcast
the resulting state message to every WS client of this session.

Note: callers update ``session.buckaroo_state`` themselves when needed
(phase 1 of the spike does; phase 2 inherits it unchanged).
"""
session.df_display_args = buckaroo_state["df_display_args"]
session.df_data_dict = buckaroo_state["df_data_dict"]
session.df_meta = buckaroo_state["df_meta"]
session.buckaroo_options = buckaroo_state["buckaroo_options"]
session.command_config = buckaroo_state["command_config"]

if session.component_config and session.df_display_args:
for key in session.df_display_args:
dvc = session.df_display_args[key].get("df_viewer_config")
if dvc is not None:
dvc["component_config"] = {
**dvc.get("component_config", {}),
**session.component_config,
}

payload = json.dumps(build_state_message(session))
for client in list(session.ws_clients):
try:
client.write_message(payload)
except Exception:
session.ws_clients.discard(client)

def _send_stats_update(self, session):
"""Phase 2 of the spike rows-first protocol: run the stats compute
that the state-change handler deliberately deferred, then push the
updated state to all clients of this session.

Runs synchronously on the IOLoop thread (so does block while
compute is in flight) — the point of the ``call_later`` is just
to yield between phase 1 and this, so any pending
``infinite_request`` from the client gets served between them.
"""
try:
dataflow = session.xorq_dataflow if session.backend == "xorq" else session.dataflow
if dataflow is None:
return
# Lift the spike's ``_defer_summary_sd`` short-circuit and run
# the analysis pipeline that phase 1 skipped.
dataflow.recompute_summary_sd()

# Re-extract the full state (now with fresh ``summary_sd``
# / ``merged_sd`` / cache + scope pointers).
buckaroo_state = get_buckaroo_display_state(dataflow)
self._apply_state_and_broadcast(session, buckaroo_state)
except Exception:
tb = traceback.format_exc()
log.error("stats_update error session=%s: %s", self.session_id, tb)

def _handle_infinite_request(self, payload_args):
sessions = self.application.settings["sessions"]
session = sessions.get(self.session_id)
Expand Down
181 changes: 181 additions & 0 deletions tests/unit/server/test_rows_first_spike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""SPIKE: rows-first WS state-change protocol.

When ``BUCKAROO_ROWS_FIRST_SPIKE`` is set, ``_handle_buckaroo_state_change``
emits the ``initial_state`` rebroadcast *before* running the analysis
pipeline that produces ``summary_sd``. A second ``initial_state``
follows once stats are ready, sent via an ``IOLoop.add_callback`` that
yields between the two so any ``infinite_request`` from the client
gets a turn ahead of the stats compute.

This test pins the wire-shape (two ``initial_state`` messages back to
back), proves an ``infinite_request`` fired between them is serviced
in order, and verifies the second message carries fresh
``filtered_*`` keys (i.e. the deferred stats compute actually ran).

Not pulling this into the default suite — gated on the env flag — so
it's safe to ship the spike alongside today's single-message
rebroadcast path while we evaluate the perf characteristics.
"""
import io
import json
import shutil
import sys
import tempfile

import pyarrow.parquet as pq
import pytest
import tornado.httpclient
import tornado.testing
import tornado.websocket

xo = pytest.importorskip("xorq.api")

from buckaroo.server.app import make_app as _make_app # noqa: E402

pytestmark = pytest.mark.skipif(
sys.platform == "win32",
reason="Temp file locking prevents cleanup on Windows")


def make_app():
return _make_app(open_browser=False)


def _build_expr_dir(builds_root):
expr = xo.memtable({
'idx': list(range(10)),
'name': ['alpha', 'beta', 'gamma', 'alpha', 'delta',
'epsilon', 'alpha', 'zeta', 'eta', 'alpha'],
}, name='t')
return str(xo.build_expr(expr, builds_dir=builds_root))


async def _post(port, path, body):
client = tornado.httpclient.AsyncHTTPClient()
return await client.fetch(
f"http://localhost:{port}{path}",
method="POST", body=json.dumps(body),
headers={"Content-Type": "application/json"},
raise_error=False)


@pytest.fixture(autouse=True)
def enable_spike(monkeypatch):
"""Flip the spike gate on for every test in this module. The
``_ROWS_FIRST_SPIKE`` constant is resolved from the env at module
import, so patching the env at test time has no effect — patch the
resolved module attribute directly."""
from buckaroo.server import websocket_handler as wh
monkeypatch.setattr(wh, "_ROWS_FIRST_SPIKE", True)


class TestRowsFirstSpike(tornado.testing.AsyncHTTPTestCase):
def get_app(self):
return make_app()

@tornado.testing.gen_test
async def test_state_change_emits_two_initial_state_messages(self):
"""With the spike on, a single state_change produces two
``initial_state`` messages: phase 1 (deferred stats) followed by
phase 2 (computed stats)."""
builds_root = tempfile.mkdtemp()
try:
build_path = _build_expr_dir(builds_root)
await _post(self.get_http_port(), "/load_expr",
{"session": "spike-1", "build_dir": build_path})

ws = await tornado.websocket.websocket_connect(
f"ws://localhost:{self.get_http_port()}/ws/spike-1")
await ws.read_message() # initial connection state

ws.write_message(json.dumps({
"type": "buckaroo_state_change",
"new_state": {
"post_processing": "",
"cleaning_method": "",
"quick_command_args": {"search": ["alpha"]},
"df_display": "main",
"show_commands": False,
"sampled": False,
"search_string": "alpha",
}}))

# Phase 1: meta/display updated, stats may still be the
# previous state's value (the spike's ``_defer_summary_sd``
# short-circuit left ``summary_sd`` untouched).
phase1 = json.loads(await ws.read_message())
self.assertEqual(phase1["type"], "initial_state")
self.assertEqual(phase1["buckaroo_state"]["quick_command_args"], {"search": ["alpha"]})

# Phase 2: stats compute ran; the message arrives after the
# ``call_later`` delay fires.
phase2 = json.loads(await ws.read_message())
self.assertEqual(phase2["type"], "initial_state")
self.assertEqual(phase2["buckaroo_state"]["quick_command_args"], {"search": ["alpha"]})

# Phase 1 vs phase 2 stats-payload divergence is exercised
# downstream — the spike's value here is just "the wire
# carries two ``initial_state`` frames per state change."

ws.close()
finally:
shutil.rmtree(builds_root, ignore_errors=True)

@tornado.testing.gen_test
async def test_infinite_request_between_phases_returns_rows(self):
"""The key win of the spike: rows can be served *between* phase
1 (skeleton) and phase 2 (stats). The client fires
``infinite_request`` after phase 1 arrives, and the parquet
comes back before the stats message."""
builds_root = tempfile.mkdtemp()
try:
build_path = _build_expr_dir(builds_root)
await _post(self.get_http_port(), "/load_expr",
{"session": "spike-2", "build_dir": build_path})

ws = await tornado.websocket.websocket_connect(
f"ws://localhost:{self.get_http_port()}/ws/spike-2")
await ws.read_message() # initial connection state

ws.write_message(json.dumps({
"type": "buckaroo_state_change",
"new_state": {
"post_processing": "",
"cleaning_method": "",
"quick_command_args": {"search": ["alpha"]},
"df_display": "main",
"show_commands": False,
"sampled": False,
"search_string": "alpha",
}}))

# Phase 1 arrives.
phase1 = json.loads(await ws.read_message())
self.assertEqual(phase1["type"], "initial_state")

# Simulate AG-Grid firing infinite_request immediately after
# seeing phase 1.
ws.write_message(json.dumps({
"type": "infinite_request",
"payload_args": {"start": 0, "end": 10,
"sourceName": "default", "origEnd": 10}}))

# The infinite_resp + parquet frame must arrive *before*
# the phase 2 stats message — the whole point of the
# add_callback yield.
json_frame = json.loads(await ws.read_message())
self.assertEqual(json_frame["type"], "infinite_resp")
self.assertEqual(json_frame["length"], 4)

binary_frame = await ws.read_message()
self.assertIsInstance(binary_frame, bytes)
table = pq.read_table(io.BytesIO(binary_frame))
self.assertEqual(table.num_rows, 4)

# Phase 2 follows.
phase2 = json.loads(await ws.read_message())
self.assertEqual(phase2["type"], "initial_state")

ws.close()
finally:
shutil.rmtree(builds_root, ignore_errors=True)