diff --git a/hyperactor/src/config/global.rs b/hyperactor/src/config/global.rs index 59f14ad63..8e352a75b 100644 --- a/hyperactor/src/config/global.rs +++ b/hyperactor/src/config/global.rs @@ -524,8 +524,7 @@ pub fn create_or_merge(source: Source, attrs: Attrs) { /// contribute to resolution in [`get`], [`get_cloned`], or /// [`attrs`]. Defaults and any remaining layers continue to apply /// in their normal priority order. -#[allow(dead_code)] -pub(crate) fn clear(source: Source) { +pub fn clear(source: Source) { let mut g = LAYERS.write().unwrap(); g.ordered.retain(|l| layer_source(l) != source); } @@ -586,6 +585,34 @@ pub fn attrs() -> Attrs { merged } +/// Return a snapshot of the attributes for a specific configuration +/// source. +/// +/// If a layer with the given [`Source`] exists, this clones and +/// returns its [`Attrs`]. Otherwise an empty [`Attrs`] is returned. +/// The returned map is detached from the global store – mutating it +/// does **not** affect the underlying layer; use [`set`] or +/// [`create_or_merge`] to modify layers. +fn layer_attrs_for(source: Source) -> Attrs { + let layers = LAYERS.read().unwrap(); + if let Some(layer) = layers.ordered.iter().find(|l| layer_source(l) == source) { + layer_attrs(layer).clone() + } else { + Attrs::new() + } +} + +/// Snapshot the current attributes in the **Runtime** configuration +/// layer. +/// +/// This returns a cloned [`Attrs`] containing only values explicitly +/// set in the [`Source::Runtime`] layer (no merging with +/// Env/File/Defaults). If no Runtime layer is present, an empty +/// [`Attrs`] is returned. +pub fn runtime_attrs() -> Attrs { + layer_attrs_for(Source::Runtime) +} + /// Reset the global configuration to only Defaults (for testing). /// /// This clears all explicit layers (`File`, `Env`, `Runtime`, and diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index 62cdb4aaa..347d8927c 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -99,7 +99,34 @@ where val.map(|v| v.into_py_any(py)).transpose() } -fn set_global_config(key: &'static dyn ErasedKey, value: T) -> PyResult<()> { +/// Fetch a config value from the **Runtime** layer only and convert +/// it to Python. +/// +/// This mirrors [`get_global_config`] but restricts the lookup to the +/// `Source::Runtime` layer (ignoring TestOverride/Env/File/defaults). +/// If the key has a runtime override, it is cloned as `T`, converted +/// to `P`, then to a `PyObject`; otherwise `Ok(None)` is returned. +fn get_runtime_config<'py, P, T>( + py: Python<'py>, + key: &'static dyn ErasedKey, +) -> PyResult> +where + T: AttrValue + TryInto
<P>,
+    P: IntoPyObjectExt<'py>,
+    PyErr: From<<T as TryInto<P>>::Error>,
+{
+    let key = key.downcast_ref::<Key<T>>().expect("cannot fail");
+    let runtime = hyperactor::config::global::runtime_attrs();
+    let val: Option<P>
= runtime + .get(key.clone()) + .cloned() + .map(|v| v.try_into()) + .transpose()?; + val.map(|v| v.into_py_any(py)).transpose() +} + +/// Note that this function writes strictly into the `Runtime` layer. +fn set_runtime_config(key: &'static dyn ErasedKey, value: T) -> PyResult<()> { // Again, can't fail unless there's a bug in the code in this file. let key = key.downcast_ref().expect("cannot fail"); let mut attrs = Attrs::new(); @@ -108,7 +135,7 @@ fn set_global_config(key: &'static dyn ErasedKey, value: T Ok(()) } -fn set_global_config_from_py_obj(py: Python<'_>, name: &str, val: PyObject) -> PyResult<()> { +fn set_runtime_config_from_py_obj(py: Python<'_>, name: &str, val: PyObject) -> PyResult<()> { // Get the `ErasedKey` from the kwarg `name` passed to `monarch.configure(...)`. let key = match KEY_BY_NAME.get(name) { None => { @@ -128,7 +155,7 @@ fn set_global_config_from_py_obj(py: Python<'_>, name: &str, val: PyObject) -> P name, key.typename() ))), - Some(info) => (info.set_global_config)(py, key, val), + Some(info) => (info.set_runtime_config)(py, key, val), } } @@ -137,10 +164,15 @@ fn set_global_config_from_py_obj(py: Python<'_>, name: &str, val: PyObject) -> P /// `T::typehash() == PythonConfigTypeInfo::typehash()`. struct PythonConfigTypeInfo { typehash: fn() -> u64, - set_global_config: - fn(py: Python<'_>, key: &'static dyn ErasedKey, val: PyObject) -> PyResult<()>, + get_global_config: fn(py: Python<'_>, key: &'static dyn ErasedKey) -> PyResult>, + + set_runtime_config: + fn(py: Python<'_>, key: &'static dyn ErasedKey, val: PyObject) -> PyResult<()>, + + get_runtime_config: + fn(py: Python<'_>, key: &'static dyn ErasedKey) -> PyResult>, } inventory::collect!(PythonConfigTypeInfo); @@ -160,15 +192,18 @@ macro_rules! declare_py_config_type { hyperactor::submit! { PythonConfigTypeInfo { typehash: $ty::typehash, - set_global_config: |py, key, val| { + set_runtime_config: |py, key, val| { let val: $ty = val.extract::<$ty>(py).map_err(|err| PyTypeError::new_err(format!( "invalid value `{}` for configuration key `{}` ({})", val, key.name(), err )))?; - set_global_config(key, val) + set_runtime_config(key, val) }, get_global_config: |py, key| { get_global_config::<$ty, $ty>(py, key) + }, + get_runtime_config: |py, key| { + get_runtime_config::<$ty, $ty>(py, key) } } } @@ -180,15 +215,18 @@ macro_rules! declare_py_config_type { hyperactor::submit! { PythonConfigTypeInfo { typehash: $ty::typehash, - set_global_config: |py, key, val| { + set_runtime_config: |py, key, val| { let val: $ty = val.extract::<$py_ty>(py).map_err(|err| PyTypeError::new_err(format!( "invalid value `{}` for configuration key `{}` ({})", val, key.name(), err )))?.into(); - set_global_config(key, val) + set_runtime_config(key, val) }, get_global_config: |py, key| { get_global_config::<$py_ty, $ty>(py, key) + }, + get_runtime_config: |py, key| { + get_runtime_config::<$py_ty, $ty>(py, key) } } } @@ -212,7 +250,7 @@ fn configure(py: Python<'_>, kwargs: Option>) -> PyRes .map(|kwargs| { kwargs .into_iter() - .try_for_each(|(key, val)| set_global_config_from_py_obj(py, &key, val)) + .try_for_each(|(key, val)| set_runtime_config_from_py_obj(py, &key, val)) }) .transpose()?; Ok(()) @@ -236,6 +274,62 @@ fn get_configuration(py: Python<'_>) -> PyResult> { .collect() } +/// Get only the Runtime layer configuration (Python-exposed keys). +/// +/// The Runtime layer is effectively the "Python configuration layer", +/// populated exclusively via `configure(**kwargs)` from Python. 
This +/// function returns only the Python-exposed keys (those with +/// `@meta(CONFIG = ConfigAttr { py_name: Some(...), .. })`) that are +/// currently set in the Runtime layer. +/// +/// This is used by Python's `configured()` context manager to +/// snapshot and restore the Runtime layer for composable, nested +/// configuration overrides: +/// +/// ```python +/// prev = get_runtime_configuration() +/// try: +/// configure(**overrides) +/// yield get_configuration() +/// finally: +/// clear_runtime_configuration() +/// configure(**prev) +/// ``` +/// +/// Unlike `get_configuration()`, which returns the merged view across +/// all layers (File, Env, Runtime, TestOverride), this returns only +/// what's explicitly set in the Runtime layer. +#[pyfunction] +fn get_runtime_configuration(py: Python<'_>) -> PyResult> { + KEY_BY_NAME + .iter() + .filter_map(|(name, key)| match TYPEHASH_TO_INFO.get(&key.typehash()) { + None => None, + Some(info) => match (info.get_runtime_config)(py, *key) { + Err(err) => Some(Err(err)), + Ok(val) => val.map(|val| Ok(((*name).into(), val))), + }, + }) + .collect() +} + +/// Clear runtime configuration overrides. +/// +/// This removes all entries from the Runtime config layer for this +/// process. The Runtime layer is exclusively populated via Python's +/// `configure(**kwargs)`, so clearing it is SAFE — it will not +/// destroy configuration from other sources (environment variables, +/// config files, or built-in defaults). +/// +/// This is primarily used by Python's `configured()` context manager +/// to restore configuration state after applying temporary overrides. +/// Other layers (Env, File, TestOverride, defaults) are unaffected. +#[pyfunction] +fn clear_runtime_configuration(_py: Python<'_>) -> PyResult<()> { + hyperactor::config::global::clear(Source::Runtime); + Ok(()) +} + /// Register Python bindings for the config module pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { let reload = wrap_pyfunction!(reload_config_from_env, module)?; @@ -266,5 +360,19 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { )?; module.add_function(get_configuration)?; + let get_runtime_configuration = wrap_pyfunction!(get_runtime_configuration, module)?; + get_runtime_configuration.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.config", + )?; + module.add_function(get_runtime_configuration)?; + + let clear_runtime_configuration = wrap_pyfunction!(clear_runtime_configuration, module)?; + clear_runtime_configuration.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.config", + )?; + module.add_function(clear_runtime_configuration)?; + Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index 25bc42c9e..c5b1c5c11 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -38,8 +38,42 @@ def configure( tail_log_lines: int = ..., **kwargs: object, ) -> None: - """Change a configuration value in the global configuration. If called with - no arguments, makes no changes. Does not reset any configuration""" + """Configure Hyperactor runtime defaults for this process. + + This updates the **runtime** configuration layer from Python, + setting the default channel transport and optional logging + behaviour (forwarding, file capture, and how many lines to tail). + """ + ... 
+ +def get_configuration() -> Dict[str, Any]: + """Return a snapshot of the current Hyperactor configuration. + + The result is a plain dictionary view of the merged configuration + (defaults plus any overrides from environment or Python), useful + for debugging and tests. + """ ... -def get_configuration() -> Dict[str, Any]: ... +def get_runtime_configuration() -> Dict[str, Any]: + """Return a snapshot of the Runtime layer configuration. + + The Runtime layer contains only configuration values set from + Python via configure(). This returns only those Python-exposed + keys currently in the Runtime layer (not merged across all layers + like get_configuration). + + This can be used to snapshot/restore Runtime state. + """ + ... + +def clear_runtime_configuration() -> None: + """Clear all Runtime layer configuration overrides. + + Safely removes all entries from the Runtime config layer. Since + the Runtime layer is exclusively populated via Python's + configure(), this will not affect configuration from environment + variables, config files, or built-in defaults. + """ + + ... diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 47cafffa5..c6a15cb6b 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -24,7 +24,7 @@ import unittest.mock from tempfile import TemporaryDirectory from types import ModuleType -from typing import Any, cast, Iterator, NamedTuple, Tuple +from typing import Any, cast, Dict, Iterator, NamedTuple, Tuple import monarch.actor import pytest @@ -35,8 +35,10 @@ ) from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc, AllocSpec from monarch._rust_bindings.monarch_hyperactor.config import ( + clear_runtime_configuration, configure, get_configuration, + get_runtime_configuration, ) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( PortId, @@ -459,23 +461,55 @@ def _handle_undeliverable_message( return True -class RedirectedPaths(NamedTuple): - stdout: str - stderr: str | None - - @contextlib.contextmanager -def configured(**overrides): - prev = get_configuration().copy() - configure(**overrides) +def configured(**overrides) -> Iterator[Dict[str, Any]]: + # Retrieve runtime + prev = get_runtime_configuration() try: - yield get_configuration().copy() + # Merge overrides into runtime + configure(**overrides) + + # Snapshot of merged config (global - all layers) + yield get_configuration() finally: + # Restore previous runtime + clear_runtime_configuration() configure(**prev) +class RedirectedPaths(NamedTuple): + """Filesystem paths for captured stdio from redirected_stdio(). + + `stdout` is the path to the temporary file holding captured + stdout. + + `stderr` is the path to the temporary file holding captured + stderr, or None if stderr capture was disabled. + + """ + + stdout: str + stderr: str | None + + @contextlib.contextmanager def redirected_stdio(capture_stderr: bool = True) -> Iterator[RedirectedPaths]: + """Temporarily redirect process stdout (and optionally stderr) to + temp files. + + This is a context manager / generator that: + * dup()s the original OS-level stdout/stderr FDs, + * points fd 1 (and optionally fd 2) at temporary files, + * swaps sys.stdout/sys.stderr to those files, + * yields the on-disk paths of the capture files, and + * on exit, flushes/fsyncs, restores the original FDs and + sys.std*, and closes the temp files. 
+ + It is primarily intended for tests that need to assert on what a + subprocess or library printed to stdout/stderr at the OS FD level. + + """ + # Save original OS-level FDs original_stdout_fd = os.dup(1) original_stderr_fd = os.dup(2) if capture_stderr else None @@ -554,112 +588,127 @@ def redirected_stdio(capture_stderr: bool = True) -> Iterator[RedirectedPaths]: pass +@contextlib.contextmanager +def configured_with_redirected_stdio( + capture_stderr: bool = True, + **config_overrides: Any, +) -> Iterator[tuple[Dict[str, Any], RedirectedPaths]]: + """Apply config overrides and capture stdio for the duration of + the block.""" + with configured(**config_overrides) as config, redirected_stdio( + capture_stderr + ) as paths: + yield (config, paths) + + # oss_skip: pytest keeps complaining about mocking get_ipython module @pytest.mark.oss_skip async def test_actor_log_streaming() -> None: - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=True) as paths: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) - - # Disable streaming logs to client - await pm.logging_option(stream_to_client=False, aggregate_window_sec=None) - await asyncio.sleep(1) + with configured_with_redirected_stdio( + capture_stderr=True, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # These should not be streamed to client initially - for _ in range(5): - await am.print.call("no print streaming") - await am.log.call("no log streaming") - await asyncio.sleep(1) + # Disable streaming logs to client + await pm.logging_option(stream_to_client=False, aggregate_window_sec=None) + await asyncio.sleep(1) - # Enable streaming logs to client - await pm.logging_option( - stream_to_client=True, - aggregate_window_sec=1, - level=logging.FATAL, - ) - # Give it some time to reflect + # These should not be streamed to client initially + for _ in range(5): + await am.print.call("no print streaming") + await am.log.call("no log streaming") await asyncio.sleep(1) - # These should be streamed to client - for _ in range(5): - await am.print.call("has print streaming") - await am.log.call("no log streaming due to level mismatch") - await asyncio.sleep(1) + # Enable streaming logs to client + await pm.logging_option( + stream_to_client=True, + aggregate_window_sec=1, + level=logging.FATAL, + ) + # Give it some time to reflect + await asyncio.sleep(1) - # Enable streaming logs to client - await pm.logging_option( - stream_to_client=True, - aggregate_window_sec=1, - level=logging.ERROR, - ) - # Give it some time to reflect + # These should be streamed to client + for _ in range(5): + await am.print.call("has print streaming") + await am.log.call("no log streaming due to level mismatch") await asyncio.sleep(1) - # These should be streamed to client - for _ in range(5): - await am.print.call("has print streaming too") - await am.log.call("has log streaming as level matched") + # Enable streaming logs to client + await pm.logging_option( + stream_to_client=True, + aggregate_window_sec=1, + level=logging.ERROR, + ) + # Give it some time to reflect + await asyncio.sleep(1) - await asyncio.sleep(1) + # These should be streamed to client + for _ in range(5): + await am.print.call("has print streaming too") + await am.log.call("has log streaming as level matched") - # Flush to ensure all 
output is written before reading - sys.stdout.flush() - sys.stderr.flush() + await asyncio.sleep(1) - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() + # Flush to ensure all output is written before reading + sys.stdout.flush() + sys.stderr.flush() - assert paths.stderr is not None - with open(paths.stderr, "r") as f: - stderr_content = f.read() - - # Assertions on the captured output - # Has a leading context so we can distinguish between streamed log and - # the log directly printed by the child processes as they share the same stdout/stderr - assert not re.search( - r"similar log lines.*no print streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*no print streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*no log streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*no log streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*no log streaming due to level mismatch", - stdout_content, - ), stdout_content - assert not re.search( - r"similar log lines.*no log streaming due to level mismatch", - stderr_content, - ), stderr_content - - assert re.search( - r"similar log lines.*has print streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*has print streaming", stderr_content - ), stderr_content - assert re.search( - r"similar log lines.*has print streaming too", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*has print streaming too", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*log streaming as level matched", stdout_content - ), stdout_content - assert re.search( - r"similar log lines.*log streaming as level matched", - stderr_content, - ), stderr_content + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() + + assert paths.stderr is not None + with open(paths.stderr, "r") as f: + stderr_content = f.read() + + # Assertions on the captured output + # Has a leading context so we can distinguish between streamed log and + # the log directly printed by the child processes as they share the same stdout/stderr + assert not re.search( + r"similar log lines.*no print streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*no print streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*no log streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*no log streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*no log streaming due to level mismatch", + stdout_content, + ), stdout_content + assert not re.search( + r"similar log lines.*no log streaming due to level mismatch", + stderr_content, + ), stderr_content + + assert re.search( + r"similar log lines.*has print streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*has print streaming", stderr_content + ), stderr_content + assert re.search( + r"similar log lines.*has print streaming too", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*has print streaming too", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*log streaming as level matched", stdout_content + ), stdout_content + assert re.search( + r"similar log lines.*log streaming as 
level matched", + stderr_content, + ), stderr_content # oss_skip: pytest keeps complaining about mocking get_ipython module @@ -669,66 +718,66 @@ async def test_alloc_based_log_streaming() -> None: """Test both AllocHandle.stream_logs = False and True cases.""" async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None: - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=False) as paths: - # Create proc mesh with custom stream_logs setting - class ProcessAllocatorStreamLogs(ProcessAllocator): - def allocate_nonblocking( - self, spec: AllocSpec - ) -> PythonTask[Alloc]: - return super().allocate_nonblocking(spec) - - def _stream_logs(self) -> bool: - return stream_logs - - alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args()) - - host_mesh = HostMesh.allocate_nonblocking( - "host", - Extent(["hosts"], [1]), - alloc, - bootstrap_cmd=_bootstrap_cmd(), - ) + with configured_with_redirected_stdio( + capture_stderr=False, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + # Create proc mesh with custom stream_logs setting + class ProcessAllocatorStreamLogs(ProcessAllocator): + def allocate_nonblocking(self, spec: AllocSpec) -> PythonTask[Alloc]: + return super().allocate_nonblocking(spec) + + def _stream_logs(self) -> bool: + return stream_logs + + alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args()) + + host_mesh = HostMesh.allocate_nonblocking( + "host", + Extent(["hosts"], [1]), + alloc, + bootstrap_cmd=_bootstrap_cmd(), + ) - pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2}) + pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) + am = pm.spawn("printer", Printer) - await pm.initialized + await pm.initialized - for _ in range(5): - await am.print.call(f"{test_name} print streaming") - - # Wait for at least the aggregation window (3 seconds) - await asyncio.sleep(5) - - # Flush to ensure all output is written before reading - sys.stdout.flush() - - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() - - if not stream_logs: - # When stream_logs=False, logs should not be streamed to client - assert not re.search( - rf"similar log lines.*{test_name} print streaming", - stdout_content, - ), f"stream_logs=False case: {stdout_content}" - assert re.search( - rf"{test_name} print streaming", stdout_content - ), f"stream_logs=False case: {stdout_content}" - else: - # When stream_logs=True, logs should be streamed to client (no aggregation by default) - assert re.search( - rf"similar log lines.*{test_name} print streaming", - stdout_content, - ), f"stream_logs=True case: {stdout_content}" - assert not re.search( - rf"\[[0-9]\]{test_name} print streaming", stdout_content - ), f"stream_logs=True case: {stdout_content}" + for _ in range(5): + await am.print.call(f"{test_name} print streaming") + + # Wait for at least the aggregation window (3 seconds) + await asyncio.sleep(5) + + # Flush to ensure all output is written before reading + sys.stdout.flush() + + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() + + if not stream_logs: + # When stream_logs=False, logs should not be streamed to client + assert not re.search( + rf"similar log lines.*{test_name} print streaming", + stdout_content, + ), f"stream_logs=False case: {stdout_content}" + assert re.search( + rf"{test_name} print streaming", stdout_content + 
), f"stream_logs=False case: {stdout_content}" + else: + # When stream_logs=True, logs should be streamed to client (no aggregation by default) + assert re.search( + rf"similar log lines.*{test_name} print streaming", + stdout_content, + ), f"stream_logs=True case: {stdout_content}" + assert not re.search( + rf"\[[0-9]\]{test_name} print streaming", stdout_content + ), f"stream_logs=True case: {stdout_content}" # Test both cases await test_stream_logs_case(False, "stream_logs_false") @@ -738,46 +787,48 @@ def _stream_logs(self) -> bool: # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. @pytest.mark.oss_skip async def test_logging_option_defaults() -> None: - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=True) as paths: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) - - for _ in range(5): - await am.print.call("print streaming") - await am.log.call("log streaming") + with configured_with_redirected_stdio( + capture_stderr=True, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Wait for > default aggregation window (3 seconds) - await asyncio.sleep(5) + for _ in range(5): + await am.print.call("print streaming") + await am.log.call("log streaming") - # Flush to ensure all output is written before reading - sys.stdout.flush() - sys.stderr.flush() + # Wait for > default aggregation window (3 seconds) + await asyncio.sleep(5) - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() + # Flush to ensure all output is written before reading + sys.stdout.flush() + sys.stderr.flush() - assert paths.stderr is not None - with open(paths.stderr, "r") as f: - stderr_content = f.read() - - # Assertions on the captured output - assert not re.search( - r"similar log lines.*print streaming", stdout_content - ), stdout_content - assert re.search(r"print streaming", stdout_content), stdout_content - assert not re.search( - r"similar log lines.*print streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*log streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*log streaming", stderr_content - ), stderr_content + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() + + assert paths.stderr is not None + with open(paths.stderr, "r") as f: + stderr_content = f.read() + + # Assertions on the captured output + assert not re.search( + r"similar log lines.*print streaming", stdout_content + ), stdout_content + assert re.search(r"print streaming", stdout_content), stdout_content + assert not re.search( + r"similar log lines.*print streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*log streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*log streaming", stderr_content + ), stderr_content class MockEvents: @@ -838,61 +889,61 @@ async def test_flush_called_only_once() -> None: @pytest.mark.timeout(180) async def test_flush_logs_ipython() -> None: """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=False) as paths: - 
mock_ipython = MockIPython() - - with unittest.mock.patch( - "monarch._src.actor.logging.get_ipython", - lambda: mock_ipython, - ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): - # Make sure we can register and unregister callbacks - for _ in range(3): - pm1 = this_host().spawn_procs(per_host={"gpus": 2}) - pm2 = this_host().spawn_procs(per_host={"gpus": 2}) - am1 = pm1.spawn("printer", Printer) - am2 = pm2.spawn("printer", Printer) - - # Set aggregation window to ensure logs are buffered - await pm1.logging_option( - stream_to_client=True, aggregate_window_sec=600 - ) - await pm2.logging_option( - stream_to_client=True, aggregate_window_sec=600 - ) - - # Generate some logs that will be aggregated - for _ in range(5): - await am1.print.call("ipython1 test log") - await am2.print.call("ipython2 test log") - - # Trigger the post_run_cell event which should flush logs - mock_ipython.events.trigger( - "post_run_cell", unittest.mock.MagicMock() - ) - - # We expect to register post_run_cell hook only once per notebook/ipython session - assert mock_ipython.events.registers == 1 - assert len(mock_ipython.events.callbacks["post_run_cell"]) == 1 + with configured_with_redirected_stdio( + capture_stderr=False, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + mock_ipython = MockIPython() - # Flush to ensure all output is written before reading - sys.stdout.flush() + with unittest.mock.patch( + "monarch._src.actor.logging.get_ipython", + lambda: mock_ipython, + ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): + # Make sure we can register and unregister callbacks + for _ in range(3): + pm1 = this_host().spawn_procs(per_host={"gpus": 2}) + pm2 = this_host().spawn_procs(per_host={"gpus": 2}) + am1 = pm1.spawn("printer", Printer) + am2 = pm2.spawn("printer", Printer) + + # Set aggregation window to ensure logs are buffered + await pm1.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + await pm2.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() + # Generate some logs that will be aggregated + for _ in range(5): + await am1.print.call("ipython1 test log") + await am2.print.call("ipython2 test log") + + # Trigger the post_run_cell event which should flush logs + mock_ipython.events.trigger("post_run_cell", unittest.mock.MagicMock()) + + # We expect to register post_run_cell hook only once per notebook/ipython session + assert mock_ipython.events.registers == 1 + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 1 - # We triggered post_run_cell three times; in the current - # implementation that yields three aggregated groups per - # message type (though the counts may be 10, 10, 8 rather than - # all 10). - pattern1 = r"\[\d+ similar log lines\].*ipython1 test log" - pattern2 = r"\[\d+ similar log lines\].*ipython2 test log" + # Flush to ensure all output is written before reading + sys.stdout.flush() + + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() + + # We triggered post_run_cell three times; in the current + # implementation that yields three aggregated groups per + # message type (though the counts may be 10, 10, 8 rather than + # all 10). 
+ pattern1 = r"\[\d+ similar log lines\].*ipython1 test log" + pattern2 = r"\[\d+ similar log lines\].*ipython2 test log" - assert len(re.findall(pattern1, stdout_content)) >= 3, stdout_content - assert len(re.findall(pattern2, stdout_content)) >= 3, stdout_content + assert len(re.findall(pattern1, stdout_content)) >= 3, stdout_content + assert len(re.findall(pattern2, stdout_content)) >= 3, stdout_content # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @@ -935,60 +986,60 @@ async def test_flush_on_disable_aggregation() -> None: This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation." """ - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=False) as paths: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) - - # Set a long aggregation window to ensure logs aren't flushed immediately - await pm.logging_option(stream_to_client=True, aggregate_window_sec=60) - - # Generate some logs that will be aggregated but not flushed immediately - for _ in range(5): - await am.print.call("aggregated log line") - await asyncio.sleep(1) + with configured_with_redirected_stdio( + capture_stderr=False, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Now disable aggregation - this should trigger an immediate flush - await pm.logging_option(stream_to_client=True, aggregate_window_sec=None) + # Set a long aggregation window to ensure logs aren't flushed immediately + await pm.logging_option(stream_to_client=True, aggregate_window_sec=60) - # Wait a bit to ensure logs are collected - await asyncio.sleep(1) - for _ in range(5): - await am.print.call("single log line") - - # Wait for > default aggregation window (3 secs) - await asyncio.sleep(5) + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(5): + await am.print.call("aggregated log line") + await asyncio.sleep(1) - # Flush to ensure all output is written before reading - sys.stdout.flush() + # Now disable aggregation - this should trigger an immediate flush + await pm.logging_option(stream_to_client=True, aggregate_window_sec=None) - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() + # Wait a bit to ensure logs are collected + await asyncio.sleep(1) + for _ in range(5): + await am.print.call("single log line") - # Verify that logs were flushed when aggregation was disabled - # We should see the aggregated logs in the output - # 10 = 5 log lines * 2 procs - assert re.search( - r"\[10 similar log lines\].*aggregated log line", stdout_content - ), stdout_content + # Wait for > default aggregation window (3 secs) + await asyncio.sleep(5) - # No aggregated single log lines - assert not re.search( - r"similar log lines.*single log line", stdout_content - ), stdout_content + # Flush to ensure all output is written before reading + sys.stdout.flush() - # 10 = 5 log lines * 2 procs - total_single = len( - re.findall( - r"\[.* [0-9]+\](?: \[[0-9]+\])? 
single log line", stdout_content - ) - ) - assert ( - total_single == 10 - ), f"Expected 10 single log lines, got {total_single} from {stdout_content}" + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() + + # Verify that logs were flushed when aggregation was disabled + # We should see the aggregated logs in the output + # 10 = 5 log lines * 2 procs + assert re.search( + r"\[10 similar log lines\].*aggregated log line", stdout_content + ), stdout_content + + # No aggregated single log lines + assert not re.search( + r"similar log lines.*single log line", stdout_content + ), stdout_content + + # 10 = 5 log lines * 2 procs + total_single = len( + re.findall(r"\[.* [0-9]+\](?: \[[0-9]+\])? single log line", stdout_content) + ) + assert ( + total_single == 10 + ), f"Expected 10 single log lines, got {total_single} from {stdout_content}" @pytest.mark.timeout(120) @@ -1033,47 +1084,49 @@ async def test_adjust_aggregation_window() -> None: This tests the corner case: "This can happen if the user has adjusted the aggregation window." """ - with configured( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ): - with redirected_stdio(capture_stderr=False) as paths: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) + with configured_with_redirected_stdio( + capture_stderr=False, + enable_log_forwarding=True, + enable_file_capture=True, + tail_log_lines=100, + ) as (_, paths): + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Set a long aggregation window initially - await pm.logging_option(stream_to_client=True, aggregate_window_sec=100) + # Set a long aggregation window initially + await pm.logging_option(stream_to_client=True, aggregate_window_sec=100) - # Generate some logs that will be aggregated - for _ in range(3): - await am.print.call("first batch of logs") - await asyncio.sleep(1) + # Generate some logs that will be aggregated + for _ in range(3): + await am.print.call("first batch of logs") + await asyncio.sleep(1) - # Now adjust to a shorter window - this should update the flush deadline - await pm.logging_option(stream_to_client=True, aggregate_window_sec=2) + # Now adjust to a shorter window - this should update the flush deadline + await pm.logging_option(stream_to_client=True, aggregate_window_sec=2) - # Generate more logs - for _ in range(3): - await am.print.call("second batch of logs") + # Generate more logs + for _ in range(3): + await am.print.call("second batch of logs") - # Wait for > aggregation window (2 secs) - await asyncio.sleep(4) + # Wait for > aggregation window (2 secs) + await asyncio.sleep(4) - # Flush to ensure all output is written before reading - sys.stdout.flush() + # Flush to ensure all output is written before reading + sys.stdout.flush() - # Read the captured output - with open(paths.stdout, "r") as f: - stdout_content = f.read() + # Read the captured output + with open(paths.stdout, "r") as f: + stdout_content = f.read() - # Verify that logs were flushed when the aggregation window was adjusted - # We should see both batches of logs in the output - assert re.search( - r"\[6 similar log lines\].*first batch of logs", stdout_content - ), stdout_content + # Verify that logs were flushed when the aggregation window was adjusted + # We should see both batches of logs in the output + assert re.search( + r"\[6 similar log lines\].*first batch of logs", stdout_content + ), stdout_content - assert re.search( - 
r"similar log lines.*second batch of logs", stdout_content - ), stdout_content + assert re.search( + r"similar log lines.*second batch of logs", stdout_content + ), stdout_content class SendAlot(Actor):