diff --git a/python/simpler/task_interface.py b/python/simpler/task_interface.py index fedb98cc6..43adf1f16 100644 --- a/python/simpler/task_interface.py +++ b/python/simpler/task_interface.py @@ -352,7 +352,7 @@ def comm_destroy(self, comm_handle: int) -> None: """Destroy the communicator and release its resources.""" self._impl.comm_destroy(int(comm_handle)) - def bootstrap_context( + def bootstrap_context( # noqa: PLR0912 -- config validation + comm setup + window carving + H2D staging in one linear flow; splitting would obscure the ordered failure semantics self, device_id: int, cfg: ChipBootstrapConfig, @@ -374,6 +374,29 @@ def bootstrap_context( ordering is the caller's responsibility. """ try: + # Validate host-staging symmetry up-front — before any device or + # communicator state is touched — so a missing staging entry + # surfaces as a clean ValueError on the channel rather than a + # KeyError from deep inside the flush/H2D loop (which would leave + # the parent waiting on a silent chip child). + for spec in cfg.buffers: + if spec.load_from_host: + try: + cfg.input_staging(spec.name) + except KeyError: + raise ValueError( + f"ChipBufferSpec(name={spec.name!r}, load_from_host=True) requires a " + f"matching HostBufferStaging in host_inputs; none found" + ) from None + if spec.store_to_host: + try: + cfg.output_staging(spec.name) + except KeyError: + raise ValueError( + f"ChipBufferSpec(name={spec.name!r}, store_to_host=True) requires a " + f"matching HostBufferStaging in host_outputs; none found" + ) from None + self.set_device(device_id) device_ctx = 0 diff --git a/python/simpler/worker.py b/python/simpler/worker.py index 9f2fa9cfa..43376b097 100644 --- a/python/simpler/worker.py +++ b/python/simpler/worker.py @@ -397,19 +397,38 @@ def _chip_process_loop_with_bootstrap( # noqa: PLR0912 except Exception as e: # noqa: BLE001 code = 1 msg = _format_exc(f"chip_process dev={device_id}", e) + + # Flush store_to_host buffers *before* publishing TASK_DONE so + # the parent cannot observe the mailbox transition (and start + # reading the output SharedMemory) while the D2H DMA is still + # in flight. Only flush on a successful kernel run: on + # failure the device output region is undefined and stamping + # garbage into the parent's SharedMemory would mask the real + # error in any post-mortem. + if code == 0: + for dev_ptr, staging in _store_to_host: + # Skip zero-byte stagings up-front — mirrors the + # load_from_host H2D path in task_interface.py and + # avoids a spurious ValueError from + # ``ctypes.c_char.from_buffer`` on an empty buffer. + if staging.size == 0: + continue + try: + shm = SharedMemory(name=staging.shm_name) + try: + shm_buf = shm.buf + assert shm_buf is not None + host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf)) + cw.copy_from(host_ptr, dev_ptr, staging.size) + finally: + shm.close() + except Exception as e: # noqa: BLE001 + code = 1 + msg = _format_exc(f"chip_process dev={device_id} store_to_host={staging.name!r}", e) + break + _write_error(buf, code, msg) _mailbox_store_i32(state_addr, _TASK_DONE) - - # Post-task: flush store_to_host buffers to SharedMemory. - for dev_ptr, staging in _store_to_host: - shm = SharedMemory(name=staging.shm_name) - try: - shm_buf = shm.buf - assert shm_buf is not None - host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf)) - cw._impl.copy_from(host_ptr, dev_ptr, staging.size) - finally: - shm.close() elif state == _CONTROL_REQUEST: sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0] code = 0 diff --git a/tests/ut/py/test_worker/test_bootstrap_context_sim.py b/tests/ut/py/test_worker/test_bootstrap_context_sim.py index c76377ffe..04008dbf5 100644 --- a/tests/ut/py/test_worker/test_bootstrap_context_sim.py +++ b/tests/ut/py/test_worker/test_bootstrap_context_sim.py @@ -328,6 +328,185 @@ def test_load_from_host_round_trip(self): assert results[0].get("readback") == payload, "round-trip payload mismatch" +# --------------------------------------------------------------------------- +# 2b. store_to_host — payload written by the child ends up in host_outputs shm. +# --------------------------------------------------------------------------- + + +def _store_rank_entry( # noqa: PLR0913 + rank: int, + nranks: int, + rootinfo_path: str, + window_size: int, + host_lib: str, + aicpu_path: str, + aicore_path: str, + sim_context_path: str, + buffer_specs: list[dict], + host_output_specs: list[dict], + payload: bytes | None, + result_queue: mp.Queue, # type: ignore[type-arg] +) -> None: + """Forked rank that exercises the store_to_host flush path. + + Mirrors the ``store_to_host=True`` flush that ``_chip_process_loop_with_bootstrap`` + runs after a successful task: write a known payload into the device buffer + via ``copy_to``, then D2H-copy it into the parent's ``host_outputs`` + SharedMemory. Leaves ``bootstrap_context`` to validate the + store_to_host ↔ host_outputs pairing before any comm work runs. + """ + result: dict[str, object] = {"rank": rank, "ok": False} + try: + from simpler.task_interface import ( + ChipBootstrapConfig, + ChipBufferSpec, + ChipCommBootstrapConfig, + ChipWorker, + HostBufferStaging, + ) + + worker = ChipWorker() + worker.init(host_lib, aicpu_path, aicore_path, sim_context_path) + + cfg = ChipBootstrapConfig( + comm=ChipCommBootstrapConfig( + rank=rank, + nranks=nranks, + rootinfo_path=rootinfo_path, + window_size=window_size, + ), + buffers=[ChipBufferSpec(**s) for s in buffer_specs], + host_outputs=[HostBufferStaging(**s) for s in host_output_specs], + ) + + res = worker.bootstrap_context(device_id=rank, cfg=cfg) + + if payload is not None and res.buffer_ptrs: + src = (ctypes.c_char * len(payload)).from_buffer_copy(payload) + worker.copy_to(res.buffer_ptrs[0], ctypes.addressof(src), len(payload)) + + # Manually run the same flush logic worker.py uses on TASK_DONE, + # so this test covers the exact D2H handshake without needing a + # full dispatch loop. + for spec, ptr in zip(cfg.buffers, res.buffer_ptrs): + if not spec.store_to_host or spec.nbytes == 0: + continue + staging = cfg.output_staging(spec.name) + shm = SharedMemory(name=staging.shm_name) + try: + shm_buf = shm.buf + assert shm_buf is not None + host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf)) + worker.copy_from(host_ptr, ptr, staging.size) + finally: + shm.close() + + worker.shutdown_bootstrap() + worker.finalize() + result["ok"] = True + except Exception: # noqa: BLE001 + result["error"] = traceback.format_exc() + finally: + result_queue.put(result) + + +class TestBootstrapContextStoreToHost: + def test_store_to_host_round_trip(self): + """Round-trip a payload via the store_to_host + host_outputs pairing. + + Rank 0 writes a known pattern into its window buffer and flushes it to + a parent-owned SharedMemory. Rank 1 participates only so + ``comm_alloc_windows`` can clear its internal barrier. The parent + reads the output shm after both children exit and asserts the payload + round-tripped unchanged. + """ + nbytes = 64 + payload = bytes(range(nbytes)) + + shm = SharedMemory(create=True, size=nbytes) + try: + buf = shm.buf + assert buf is not None + buf[:nbytes] = b"\x00" * nbytes + + buffer_specs_r0 = [ + { + "name": "y", + "dtype": "float32", + "count": 16, + "placement": "window", + "nbytes": nbytes, + "store_to_host": True, + }, + ] + buffer_specs_r1 = [ + { + "name": "y", + "dtype": "float32", + "count": 16, + "placement": "window", + "nbytes": nbytes, + "store_to_host": False, + }, + ] + host_outputs_r0 = [{"name": "y", "shm_name": shm.name, "size": nbytes}] + + bins = _sim_binaries() + host_lib = str(bins.host_path) + aicpu_path = str(bins.aicpu_path) + aicore_path = str(bins.aicore_path) + sim_context_path = str(bins.sim_context_path) if bins.sim_context_path else "" + + rootinfo_path = f"/tmp/pto_bootstrap_sim_{os.getpid()}_store.bin" + ctx = mp.get_context("fork") + result_queue: mp.Queue = ctx.Queue() # type: ignore[type-arg] + procs = [] + for rank, specs, outputs, pay in ( + (0, buffer_specs_r0, host_outputs_r0, payload), + (1, buffer_specs_r1, [], None), + ): + p = ctx.Process( + target=_store_rank_entry, + args=( + rank, + 2, + rootinfo_path, + 4096, + host_lib, + aicpu_path, + aicore_path, + sim_context_path, + specs, + outputs, + pay, + result_queue, + ), + daemon=False, + ) + p.start() + procs.append(p) + + results: dict[int, dict] = {} + for _ in range(2): + r = result_queue.get(timeout=180) + results[int(r["rank"])] = r + for p in procs: + p.join(timeout=60) + try: + os.unlink(rootinfo_path) + except FileNotFoundError: + pass + + assert results[0].get("ok"), f"rank 0 failed: {results[0].get('error')}" + assert results[1].get("ok"), f"rank 1 failed: {results[1].get('error')}" + readback = bytes(shm.buf[:nbytes]) # type: ignore[index] + finally: + shm.close() + shm.unlink() + + assert readback == payload, f"store_to_host round-trip mismatch: got {readback!r}" + + # --------------------------------------------------------------------------- # 3. Channel integration — parent reads SUCCESS fields from the mailbox. # --------------------------------------------------------------------------- @@ -485,3 +664,116 @@ def test_invalid_placement_publishes_error(self): finally: shm.close() shm.unlink() + + +# --------------------------------------------------------------------------- +# 4b. Error path — store_to_host=True without a matching host_outputs entry. +# --------------------------------------------------------------------------- + + +def _missing_output_staging_rank_entry( + host_lib: str, + aicpu_path: str, + aicore_path: str, + sim_context_path: str, + channel_shm_name: str, + result_queue: mp.Queue, # type: ignore[type-arg] +) -> None: + """Trip the store_to_host ↔ host_outputs symmetry check in bootstrap_context. + + Runs single-process: the new validation fires before any communicator + work, so no peer rank is required. Verifies both the child-side + exception and the channel payload the parent will see. + """ + result: dict[str, object] = {"raised": False, "state": None, "message": None} + try: + from simpler.task_interface import ( + ChipBootstrapChannel, + ChipBootstrapConfig, + ChipBufferSpec, + ChipWorker, + ) + + worker = ChipWorker() + worker.init(host_lib, aicpu_path, aicore_path, sim_context_path) + + shm = SharedMemory(name=channel_shm_name) + try: + channel = ChipBootstrapChannel(_shm_addr(shm), max_buffer_count=376) + + cfg = ChipBootstrapConfig( + comm=None, + buffers=[ + ChipBufferSpec( + name="y", + dtype="float32", + count=1, + placement="window", + nbytes=4, + store_to_host=True, + ) + ], + host_outputs=[], + ) + try: + worker.bootstrap_context(device_id=0, cfg=cfg, channel=channel) + except ValueError as e: + result["raised"] = True + result["exc_msg"] = str(e) + + result["state"] = int(channel.state) + result["message"] = channel.error_message + finally: + shm.close() + worker.shutdown_bootstrap() + worker.finalize() + except Exception: # noqa: BLE001 + result["error"] = traceback.format_exc() + finally: + result_queue.put(result) + + +class TestBootstrapContextMissingOutputStaging: + def test_store_to_host_without_host_outputs_raises(self): + from _task_interface import ( # pyright: ignore[reportMissingImports] + CHIP_BOOTSTRAP_MAILBOX_SIZE, + ChipBootstrapChannel, + ChipBootstrapMailboxState, + ) + + bins = _sim_binaries() + host_lib = str(bins.host_path) + aicpu_path = str(bins.aicpu_path) + aicore_path = str(bins.aicore_path) + sim_context_path = str(bins.sim_context_path) if bins.sim_context_path else "" + + shm = SharedMemory(create=True, size=CHIP_BOOTSTRAP_MAILBOX_SIZE) + buf = shm.buf + assert buf is not None + for off in range(0, CHIP_BOOTSTRAP_MAILBOX_SIZE, 8): + struct.pack_into("Q", buf, off, 0) + try: + ctx = mp.get_context("fork") + result_queue: mp.Queue = ctx.Queue() # type: ignore[type-arg] + p = ctx.Process( + target=_missing_output_staging_rank_entry, + args=(host_lib, aicpu_path, aicore_path, sim_context_path, shm.name, result_queue), + daemon=False, + ) + p.start() + r = result_queue.get(timeout=60) + p.join(timeout=30) + + assert r.get("raised"), f"expected ValueError; got {r}" + exc_msg = str(r.get("exc_msg", "")) + assert "store_to_host=True" in exc_msg, f"exc_msg missing sentinel: {exc_msg!r}" + assert "host_outputs" in exc_msg, f"exc_msg missing 'host_outputs': {exc_msg!r}" + + channel = ChipBootstrapChannel(_shm_addr(shm), max_buffer_count=376) + assert channel.state == ChipBootstrapMailboxState.ERROR + assert channel.error_code == 1 + assert channel.error_message.startswith("ValueError: ") + assert "store_to_host=True" in channel.error_message + finally: + shm.close() + shm.unlink()