Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion python/simpler/task_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def comm_destroy(self, comm_handle: int) -> None:
"""Destroy the communicator and release its resources."""
self._impl.comm_destroy(int(comm_handle))

def bootstrap_context(
def bootstrap_context( # noqa: PLR0912 -- config validation + comm setup + window carving + H2D staging in one linear flow; splitting would obscure the ordered failure semantics
self,
device_id: int,
cfg: ChipBootstrapConfig,
Expand All @@ -374,6 +374,29 @@ def bootstrap_context(
ordering is the caller's responsibility.
"""
try:
# Validate host-staging symmetry up-front — before any device or
# communicator state is touched — so a missing staging entry
# surfaces as a clean ValueError on the channel rather than a
# KeyError from deep inside the flush/H2D loop (which would leave
# the parent waiting on a silent chip child).
for spec in cfg.buffers:
if spec.load_from_host:
try:
cfg.input_staging(spec.name)
except KeyError:
raise ValueError(
f"ChipBufferSpec(name={spec.name!r}, load_from_host=True) requires a "
f"matching HostBufferStaging in host_inputs; none found"
) from None
if spec.store_to_host:
try:
cfg.output_staging(spec.name)
except KeyError:
raise ValueError(
f"ChipBufferSpec(name={spec.name!r}, store_to_host=True) requires a "
f"matching HostBufferStaging in host_outputs; none found"
) from None

self.set_device(device_id)

device_ctx = 0
Expand Down
41 changes: 30 additions & 11 deletions python/simpler/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,19 +397,38 @@ def _chip_process_loop_with_bootstrap( # noqa: PLR0912
except Exception as e: # noqa: BLE001
code = 1
msg = _format_exc(f"chip_process dev={device_id}", e)

# Flush store_to_host buffers *before* publishing TASK_DONE so
# the parent cannot observe the mailbox transition (and start
# reading the output SharedMemory) while the D2H DMA is still
# in flight. Only flush on a successful kernel run: on
# failure the device output region is undefined and stamping
# garbage into the parent's SharedMemory would mask the real
# error in any post-mortem.
if code == 0:
for dev_ptr, staging in _store_to_host:
# Skip zero-byte stagings up-front — mirrors the
# load_from_host H2D path in task_interface.py and
# avoids a spurious ValueError from
# ``ctypes.c_char.from_buffer`` on an empty buffer.
if staging.size == 0:
continue
try:
shm = SharedMemory(name=staging.shm_name)
try:
shm_buf = shm.buf
assert shm_buf is not None
host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf))
cw.copy_from(host_ptr, dev_ptr, staging.size)
finally:
shm.close()
except Exception as e: # noqa: BLE001
code = 1
msg = _format_exc(f"chip_process dev={device_id} store_to_host={staging.name!r}", e)
break
Comment thread
ChaoWao marked this conversation as resolved.

_write_error(buf, code, msg)
_mailbox_store_i32(state_addr, _TASK_DONE)

# Post-task: flush store_to_host buffers to SharedMemory.
for dev_ptr, staging in _store_to_host:
shm = SharedMemory(name=staging.shm_name)
try:
shm_buf = shm.buf
assert shm_buf is not None
host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf))
cw._impl.copy_from(host_ptr, dev_ptr, staging.size)
finally:
shm.close()
elif state == _CONTROL_REQUEST:
sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0]
code = 0
Expand Down
292 changes: 292 additions & 0 deletions tests/ut/py/test_worker/test_bootstrap_context_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,185 @@ def test_load_from_host_round_trip(self):
assert results[0].get("readback") == payload, "round-trip payload mismatch"


# ---------------------------------------------------------------------------
# 2b. store_to_host — payload written by the child ends up in host_outputs shm.
# ---------------------------------------------------------------------------


def _store_rank_entry( # noqa: PLR0913
rank: int,
nranks: int,
rootinfo_path: str,
window_size: int,
host_lib: str,
aicpu_path: str,
aicore_path: str,
sim_context_path: str,
buffer_specs: list[dict],
host_output_specs: list[dict],
payload: bytes | None,
result_queue: mp.Queue, # type: ignore[type-arg]
) -> None:
"""Forked rank that exercises the store_to_host flush path.

Mirrors the ``store_to_host=True`` flush that ``_chip_process_loop_with_bootstrap``
runs after a successful task: write a known payload into the device buffer
via ``copy_to``, then D2H-copy it into the parent's ``host_outputs``
SharedMemory. Leaves ``bootstrap_context`` to validate the
store_to_host ↔ host_outputs pairing before any comm work runs.
"""
result: dict[str, object] = {"rank": rank, "ok": False}
try:
from simpler.task_interface import (
ChipBootstrapConfig,
ChipBufferSpec,
ChipCommBootstrapConfig,
ChipWorker,
HostBufferStaging,
)

worker = ChipWorker()
worker.init(host_lib, aicpu_path, aicore_path, sim_context_path)

cfg = ChipBootstrapConfig(
comm=ChipCommBootstrapConfig(
rank=rank,
nranks=nranks,
rootinfo_path=rootinfo_path,
window_size=window_size,
),
buffers=[ChipBufferSpec(**s) for s in buffer_specs],
host_outputs=[HostBufferStaging(**s) for s in host_output_specs],
)

res = worker.bootstrap_context(device_id=rank, cfg=cfg)

if payload is not None and res.buffer_ptrs:
src = (ctypes.c_char * len(payload)).from_buffer_copy(payload)
worker.copy_to(res.buffer_ptrs[0], ctypes.addressof(src), len(payload))

# Manually run the same flush logic worker.py uses on TASK_DONE,
# so this test covers the exact D2H handshake without needing a
# full dispatch loop.
for spec, ptr in zip(cfg.buffers, res.buffer_ptrs):
if not spec.store_to_host or spec.nbytes == 0:
continue
staging = cfg.output_staging(spec.name)
shm = SharedMemory(name=staging.shm_name)
try:
shm_buf = shm.buf
assert shm_buf is not None
host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(shm_buf))
worker.copy_from(host_ptr, ptr, staging.size)
finally:
shm.close()
Comment thread
ChaoWao marked this conversation as resolved.

worker.shutdown_bootstrap()
worker.finalize()
result["ok"] = True
except Exception: # noqa: BLE001
result["error"] = traceback.format_exc()
finally:
result_queue.put(result)


class TestBootstrapContextStoreToHost:
def test_store_to_host_round_trip(self):
"""Round-trip a payload via the store_to_host + host_outputs pairing.

Rank 0 writes a known pattern into its window buffer and flushes it to
a parent-owned SharedMemory. Rank 1 participates only so
``comm_alloc_windows`` can clear its internal barrier. The parent
reads the output shm after both children exit and asserts the payload
round-tripped unchanged.
"""
nbytes = 64
payload = bytes(range(nbytes))

shm = SharedMemory(create=True, size=nbytes)
try:
buf = shm.buf
assert buf is not None
buf[:nbytes] = b"\x00" * nbytes

buffer_specs_r0 = [
{
"name": "y",
"dtype": "float32",
"count": 16,
"placement": "window",
"nbytes": nbytes,
"store_to_host": True,
},
]
buffer_specs_r1 = [
{
"name": "y",
"dtype": "float32",
"count": 16,
"placement": "window",
"nbytes": nbytes,
"store_to_host": False,
},
]
host_outputs_r0 = [{"name": "y", "shm_name": shm.name, "size": nbytes}]

bins = _sim_binaries()
host_lib = str(bins.host_path)
aicpu_path = str(bins.aicpu_path)
aicore_path = str(bins.aicore_path)
sim_context_path = str(bins.sim_context_path) if bins.sim_context_path else ""

rootinfo_path = f"/tmp/pto_bootstrap_sim_{os.getpid()}_store.bin"
ctx = mp.get_context("fork")
result_queue: mp.Queue = ctx.Queue() # type: ignore[type-arg]
procs = []
for rank, specs, outputs, pay in (
(0, buffer_specs_r0, host_outputs_r0, payload),
(1, buffer_specs_r1, [], None),
):
p = ctx.Process(
target=_store_rank_entry,
args=(
rank,
2,
rootinfo_path,
4096,
host_lib,
aicpu_path,
aicore_path,
sim_context_path,
specs,
outputs,
pay,
result_queue,
),
daemon=False,
)
p.start()
procs.append(p)

results: dict[int, dict] = {}
for _ in range(2):
r = result_queue.get(timeout=180)
results[int(r["rank"])] = r
for p in procs:
p.join(timeout=60)
try:
os.unlink(rootinfo_path)
except FileNotFoundError:
pass

assert results[0].get("ok"), f"rank 0 failed: {results[0].get('error')}"
assert results[1].get("ok"), f"rank 1 failed: {results[1].get('error')}"
readback = bytes(shm.buf[:nbytes]) # type: ignore[index]
finally:
shm.close()
shm.unlink()

assert readback == payload, f"store_to_host round-trip mismatch: got {readback!r}"


# ---------------------------------------------------------------------------
# 3. Channel integration — parent reads SUCCESS fields from the mailbox.
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -485,3 +664,116 @@ def test_invalid_placement_publishes_error(self):
finally:
shm.close()
shm.unlink()


# ---------------------------------------------------------------------------
# 4b. Error path — store_to_host=True without a matching host_outputs entry.
# ---------------------------------------------------------------------------


def _missing_output_staging_rank_entry(
host_lib: str,
aicpu_path: str,
aicore_path: str,
sim_context_path: str,
channel_shm_name: str,
result_queue: mp.Queue, # type: ignore[type-arg]
) -> None:
"""Trip the store_to_host ↔ host_outputs symmetry check in bootstrap_context.

Runs single-process: the new validation fires before any communicator
work, so no peer rank is required. Verifies both the child-side
exception and the channel payload the parent will see.
"""
result: dict[str, object] = {"raised": False, "state": None, "message": None}
try:
from simpler.task_interface import (
ChipBootstrapChannel,
ChipBootstrapConfig,
ChipBufferSpec,
ChipWorker,
)

worker = ChipWorker()
worker.init(host_lib, aicpu_path, aicore_path, sim_context_path)

shm = SharedMemory(name=channel_shm_name)
try:
channel = ChipBootstrapChannel(_shm_addr(shm), max_buffer_count=376)

cfg = ChipBootstrapConfig(
comm=None,
buffers=[
ChipBufferSpec(
name="y",
dtype="float32",
count=1,
placement="window",
nbytes=4,
store_to_host=True,
)
],
host_outputs=[],
)
try:
worker.bootstrap_context(device_id=0, cfg=cfg, channel=channel)
except ValueError as e:
result["raised"] = True
result["exc_msg"] = str(e)

result["state"] = int(channel.state)
result["message"] = channel.error_message
finally:
shm.close()
worker.shutdown_bootstrap()
worker.finalize()
except Exception: # noqa: BLE001
result["error"] = traceback.format_exc()
finally:
result_queue.put(result)


class TestBootstrapContextMissingOutputStaging:
def test_store_to_host_without_host_outputs_raises(self):
from _task_interface import ( # pyright: ignore[reportMissingImports]
CHIP_BOOTSTRAP_MAILBOX_SIZE,
ChipBootstrapChannel,
ChipBootstrapMailboxState,
)

bins = _sim_binaries()
host_lib = str(bins.host_path)
aicpu_path = str(bins.aicpu_path)
aicore_path = str(bins.aicore_path)
sim_context_path = str(bins.sim_context_path) if bins.sim_context_path else ""

shm = SharedMemory(create=True, size=CHIP_BOOTSTRAP_MAILBOX_SIZE)
buf = shm.buf
assert buf is not None
for off in range(0, CHIP_BOOTSTRAP_MAILBOX_SIZE, 8):
struct.pack_into("Q", buf, off, 0)
try:
ctx = mp.get_context("fork")
result_queue: mp.Queue = ctx.Queue() # type: ignore[type-arg]
p = ctx.Process(
target=_missing_output_staging_rank_entry,
args=(host_lib, aicpu_path, aicore_path, sim_context_path, shm.name, result_queue),
daemon=False,
)
p.start()
r = result_queue.get(timeout=60)
p.join(timeout=30)

assert r.get("raised"), f"expected ValueError; got {r}"
exc_msg = str(r.get("exc_msg", ""))
assert "store_to_host=True" in exc_msg, f"exc_msg missing sentinel: {exc_msg!r}"
assert "host_outputs" in exc_msg, f"exc_msg missing 'host_outputs': {exc_msg!r}"

channel = ChipBootstrapChannel(_shm_addr(shm), max_buffer_count=376)
assert channel.state == ChipBootstrapMailboxState.ERROR
assert channel.error_code == 1
assert channel.error_message.startswith("ValueError: ")
assert "store_to_host=True" in channel.error_message
finally:
shm.close()
shm.unlink()
Loading