From f732834aeeebcac5196d7f88713fc3c57682d540 Mon Sep 17 00:00:00 2001 From: Shayne Fletcher Date: Wed, 19 Nov 2025 14:01:06 -0800 Subject: [PATCH 1/2] : python/tests: logging: use configure api for log env (#1932) Summary: replace env-var overrides with `configure`. relax `test_flush_logs_ipython`, which was previously over-constrained and consequently highly flaky. Reviewed By: mariusae Differential Revision: D87449483 --- python/tests/test_python_actors.py | 342 +++++++++-------------------- 1 file changed, 106 insertions(+), 236 deletions(-) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index d6527d88a..759263704 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -33,6 +33,10 @@ PythonMessageKind, ) from monarch._rust_bindings.monarch_hyperactor.alloc import Alloc, AllocSpec +from monarch._rust_bindings.monarch_hyperactor.config import ( + configure, + get_configuration, +) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( PortId, PortRef, @@ -424,98 +428,6 @@ async def awaitit(f): return await f -# def test_actor_future() -> None: -# v = 0 - -# async def incr(): -# nonlocal v -# v += 1 -# return v - -# # can use async implementation from sync -# # if no non-blocking is provided -# f = Future(impl=incr, requires_loop=False) -# assert f.get() == 1 -# assert v == 1 -# assert f.get() == 1 -# assert asyncio.run(awaitit(f)) == 1 - -# f = Future(impl=incr, requires_loop=False) -# assert asyncio.run(awaitit(f)) == 2 -# assert f.get() == 2 - -# async def incr2(): -# nonlocal v -# v += 2 -# return v - -# # Use non-blocking optimization if provided -# f = Future(impl=incr2) -# assert f.get() == 4 - -# async def nope(): -# nonlocal v -# v += 1 -# raise ValueError("nope") - -# f = Future(impl=nope, requires_loop=False) - -# with pytest.raises(ValueError): -# f.get() - -# assert v == 5 - -# with pytest.raises(ValueError): -# f.get() - -# assert v == 5 - -# with
pytest.raises(ValueError): -# asyncio.run(awaitit(f)) - -# assert v == 5 - -# async def nope2(): -# nonlocal v -# v += 1 -# raise ValueError("nope") - -# f = Future(impl=nope2) - -# with pytest.raises(ValueError): -# f.get() - -# assert v == 6 - -# with pytest.raises(ValueError): -# f.result() - -# assert f.exception() is not None - -# assert v == 6 - -# with pytest.raises(ValueError): -# asyncio.run(awaitit(f)) - -# assert v == 6 - -# async def seven(): -# return 7 - -# f = Future(impl=seven, requires_loop=False) - -# assert 7 == f.get(timeout=0.001) - -# async def neverfinish(): -# f = asyncio.Future() -# await f - -# f = Future(impl=neverfinish, requires_loop=True) - -# with pytest.raises(asyncio.exceptions.TimeoutError): -# f.get(timeout=0.1) - - class Printer(Actor): def __init__(self) -> None: self._logger: logging.Logger = logging.getLogger() @@ -548,15 +460,11 @@ def _handle_undeliverable_message( # oss_skip: pytest keeps complaining about mocking get_ipython module @pytest.mark.oss_skip async def test_actor_log_streaming() -> None: - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -695,11 +603,12 @@ async def test_actor_log_streaming() -> None: ), stderr_content finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + # Restore config to defaults + configure( + enable_log_forwarding=enable_log_forwarding, + 
enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: @@ -718,15 +627,13 @@ async def test_alloc_based_log_streaming() -> None: """Test both AllocHandle.stream_logs = False and True cases.""" async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None: - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ) # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -807,11 +714,11 @@ def _stream_logs(self) -> bool: ), f"stream_logs=True case: {stdout_content}" finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: os.dup2(original_stdout_fd, 1) @@ -827,15 +734,11 @@ def _stream_logs(self) -> bool: # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. 
@pytest.mark.oss_skip async def test_logging_option_defaults() -> None: - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -915,11 +818,11 @@ async def test_logging_option_defaults() -> None: ), stderr_content finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: @@ -957,15 +860,11 @@ def __init__(self): @pytest.mark.oss_skip async def test_flush_called_only_once() -> None: """Test that flush is called only once when ending an ipython cell""" - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) mock_ipython = MockIPython() with unittest.mock.patch( "monarch._src.actor.logging.get_ipython", @@ -988,11 +887,11 @@ async def 
test_flush_called_only_once() -> None: mock_ipython.events.trigger("post_run_cell", unittest.mock.MagicMock()) assert mock_flush.call_count == 1 - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # oss_skip: pytest keeps complaining about mocking get_ipython module @@ -1000,15 +899,12 @@ async def test_flush_called_only_once() -> None: @pytest.mark.timeout(180) async def test_flush_logs_ipython() -> None: """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) + # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -1079,32 +975,22 @@ async def test_flush_logs_ipython() -> None: # Clean up temp files os.unlink(stdout_path) - # Verify that logs were flushed when the post_run_cell event was triggered - # We should see the aggregated logs in the output - assert ( - len( - re.findall( - r"\[10 similar log lines\].*ipython1 test log", stdout_content - ) - ) - == 3 - ), stdout_content + # We triggered post_run_cell three times; in the current + # implementation that yields three aggregated groups per + # message type (though the counts may be 10, 10, 8 rather than + # all 10). 
+ pattern1 = r"\[\d+ similar log lines\].*ipython1 test log" + pattern2 = r"\[\d+ similar log lines\].*ipython2 test log" - assert ( - len( - re.findall( - r"\[10 similar log lines\].*ipython2 test log", stdout_content - ) - ) - == 3 - ), stdout_content + assert len(re.findall(pattern1, stdout_content)) >= 3, stdout_content + assert len(re.findall(pattern2, stdout_content)) >= 3, stdout_content finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: os.dup2(original_stdout_fd, 1) @@ -1116,15 +1002,11 @@ async def test_flush_logs_ipython() -> None: # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @pytest.mark.oss_skip async def test_flush_logs_fast_exit() -> None: - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) # We use a subprocess to run the test so we can handle the flushed logs at the end. # Otherwise, it is hard to restore the original stdout/stderr. 
@@ -1151,11 +1033,11 @@ async def test_flush_logs_fast_exit() -> None: == 1 ), process.stdout - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. @@ -1165,15 +1047,11 @@ async def test_flush_on_disable_aggregation() -> None: This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation." """ - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -1254,11 +1132,11 @@ async def test_flush_on_disable_aggregation() -> None: ), f"Expected 10 single log lines, got {total_single} from {stdout_content}" finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: @@ -1275,15 +1153,11 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None: Because now a flush call is purely sync, it is very easy to get into a deadlock. So we assert the last flush call will not get into such a state. 
""" - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) pm = this_host().spawn_procs(per_host={"gpus": 4}) am = pm.spawn("printer", Printer) @@ -1306,11 +1180,11 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None: # The last flush should not block futures[-1].get() - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. @@ -1320,15 +1194,11 @@ async def test_adjust_aggregation_window() -> None: This tests the corner case: "This can happen if the user has adjusted the aggregation window." 
""" - old_env = {} - env_vars = { - "HYPERACTOR_MESH_ENABLE_LOG_FORWARDING": "true", - "HYPERACTOR_MESH_ENABLE_FILE_CAPTURE": "true", - "HYPERACTOR_MESH_TAIL_LOG_LINES": "100", - } - for key, value in env_vars.items(): - old_env[key] = os.environ.get(key) - os.environ[key] = value + config = get_configuration() + enable_log_forwarding = config["enable_log_forwarding"] + enable_file_capture = config["enable_file_capture"] + tail_log_lines = config["tail_log_lines"] + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) # Save original file descriptors original_stdout_fd = os.dup(1) # stdout @@ -1396,11 +1266,11 @@ async def test_adjust_aggregation_window() -> None: ), stdout_content finally: - for key, value in old_env.items(): - if value is None: - os.environ.pop(key, None) - else: - os.environ[key] = value + configure( + enable_log_forwarding=enable_log_forwarding, + enable_file_capture=enable_file_capture, + tail_log_lines=tail_log_lines, + ) # Ensure file descriptors are restored even if something goes wrong try: From a1f60d3dd2a33429fbd7ac9841385effc18f90a8 Mon Sep 17 00:00:00 2001 From: Shayne Fletcher Date: Wed, 19 Nov 2025 14:01:06 -0800 Subject: [PATCH 2/2] : python/tests/test_actors.py: use context mgr for config Summary: remove manual config save/restore boilerplate: ``` config = get_configuration() enable_log_forwarding = config["enable_log_forwarding"] # ... etc configure(enable_log_forwarding=True, ...) try: # test code finally: configure(enable_log_forwarding=enable_log_forwarding, ...) 
``` replace with context manager phrasing: ``` with configured(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100): # test code ``` i intend to follow this diff up with one for the FD + sys.stdout/sys.stderr redirection Differential Revision: D87478665 --- python/tests/test_python_actors.py | 1211 +++++++++++++--------------- 1 file changed, 583 insertions(+), 628 deletions(-) diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 759263704..c9e90b3f2 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -7,6 +7,7 @@ # pyre-unsafe import asyncio +import contextlib import ctypes import importlib.resources import io @@ -457,159 +458,165 @@ def _handle_undeliverable_message( return True +@contextlib.contextmanager +def configured(**overrides): + prev = get_configuration().copy() + configure(**overrides) + try: + yield get_configuration().copy() + finally: + configure(**prev) + + # oss_skip: pytest keeps complaining about mocking get_ipython module @pytest.mark.oss_skip async def test_actor_log_streaming() -> None: - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - # Save original file descriptors original_stdout_fd = os.dup(1) # stdout original_stderr_fd = os.dup(2) # stderr try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile( - mode="w+", delete=False - ) as stdout_file, tempfile.NamedTemporaryFile( - mode="w+", delete=False - ) as stderr_file: - stdout_path = stdout_file.name - stderr_path = stderr_file.name - - # Redirect file descriptors to our temp files - # This will capture both Python and Rust output - os.dup2(stdout_file.fileno(), 1) - os.dup2(stderr_file.fileno(), 2) - - # Also redirect Python's 
sys.stdout/stderr for completeness - original_sys_stdout = sys.stdout - original_sys_stderr = sys.stderr - sys.stdout = stdout_file - sys.stderr = stderr_file + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Create temporary files to capture output + with tempfile.NamedTemporaryFile( + mode="w+", delete=False + ) as stdout_file, tempfile.NamedTemporaryFile( + mode="w+", delete=False + ) as stderr_file: + stdout_path = stdout_file.name + stderr_path = stderr_file.name - try: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) + # Redirect file descriptors to our temp files + # This will capture both Python and Rust output + os.dup2(stdout_file.fileno(), 1) + os.dup2(stderr_file.fileno(), 2) - # Disable streaming logs to client - await pm.logging_option( - stream_to_client=False, aggregate_window_sec=None - ) - await asyncio.sleep(1) + # Also redirect Python's sys.stdout/stderr for completeness + original_sys_stdout = sys.stdout + original_sys_stderr = sys.stderr + sys.stdout = stdout_file + sys.stderr = stderr_file - # These should not be streamed to client initially - for _ in range(5): - await am.print.call("no print streaming") - await am.log.call("no log streaming") - await asyncio.sleep(1) + try: + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Enable streaming logs to client - await pm.logging_option( - stream_to_client=True, aggregate_window_sec=1, level=logging.FATAL - ) - # Give it some time to reflect - await asyncio.sleep(1) - - # These should be streamed to client - for _ in range(5): - await am.print.call("has print streaming") - await am.log.call("no log streaming due to level mismatch") - await asyncio.sleep(1) - - # Enable streaming logs to client - await pm.logging_option( - stream_to_client=True, aggregate_window_sec=1, level=logging.ERROR - ) - # Give it some time to reflect - await asyncio.sleep(1) - - # These 
should be streamed to client - for _ in range(5): - await am.print.call("has print streaming too") - await am.log.call("has log streaming as level matched") - - await asyncio.sleep(1) - - # Flush all outputs - stdout_file.flush() - stderr_file.flush() - os.fsync(stdout_file.fileno()) - os.fsync(stderr_file.fileno()) - - finally: - # Restore Python's sys.stdout/stderr - sys.stdout = original_sys_stdout - sys.stderr = original_sys_stderr - - # Restore original file descriptors - os.dup2(original_stdout_fd, 1) - os.dup2(original_stderr_fd, 2) - - # Read the captured output - with open(stdout_path, "r") as f: - stdout_content = f.read() - - with open(stderr_path, "r") as f: - stderr_content = f.read() - - # Clean up temp files - os.unlink(stdout_path) - os.unlink(stderr_path) - - # Assertions on the captured output - # Has a leading context so we can distinguish between streamed log and - # the log directly printed by the child processes as they share the same stdout/stderr - assert not re.search( - r"similar log lines.*no print streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*no print streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*no log streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*no log streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*no log streaming due to level mismatch", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*no log streaming due to level mismatch", stderr_content - ), stderr_content - - assert re.search( - r"similar log lines.*has print streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*has print streaming", stderr_content - ), stderr_content - assert re.search( - r"similar log lines.*has print streaming too", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*has 
print streaming too", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*log streaming as level matched", stdout_content - ), stdout_content - assert re.search( - r"similar log lines.*log streaming as level matched", - stderr_content, - ), stderr_content + # Disable streaming logs to client + await pm.logging_option( + stream_to_client=False, aggregate_window_sec=None + ) + await asyncio.sleep(1) - finally: - # Restore config to defaults - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + # These should not be streamed to client initially + for _ in range(5): + await am.print.call("no print streaming") + await am.log.call("no log streaming") + await asyncio.sleep(1) + + # Enable streaming logs to client + await pm.logging_option( + stream_to_client=True, + aggregate_window_sec=1, + level=logging.FATAL, + ) + # Give it some time to reflect + await asyncio.sleep(1) + + # These should be streamed to client + for _ in range(5): + await am.print.call("has print streaming") + await am.log.call("no log streaming due to level mismatch") + await asyncio.sleep(1) + + # Enable streaming logs to client + await pm.logging_option( + stream_to_client=True, + aggregate_window_sec=1, + level=logging.ERROR, + ) + # Give it some time to reflect + await asyncio.sleep(1) + + # These should be streamed to client + for _ in range(5): + await am.print.call("has print streaming too") + await am.log.call("has log streaming as level matched") + + await asyncio.sleep(1) + + # Flush all outputs + stdout_file.flush() + stderr_file.flush() + os.fsync(stdout_file.fileno()) + os.fsync(stderr_file.fileno()) + + finally: + # Restore Python's sys.stdout/stderr + sys.stdout = original_sys_stdout + sys.stderr = original_sys_stderr + + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) + os.dup2(original_stderr_fd, 2) + + # Read the captured output + with 
open(stdout_path, "r") as f: + stdout_content = f.read() + + with open(stderr_path, "r") as f: + stderr_content = f.read() + + # Clean up temp files + os.unlink(stdout_path) + os.unlink(stderr_path) + + # Assertions on the captured output + # Has a leading context so we can distinguish between streamed log and + # the log directly printed by the child processes as they share the same stdout/stderr + assert not re.search( + r"similar log lines.*no print streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*no print streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*no log streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*no log streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*no log streaming due to level mismatch", + stdout_content, + ), stdout_content + assert not re.search( + r"similar log lines.*no log streaming due to level mismatch", + stderr_content, + ), stderr_content + + assert re.search( + r"similar log lines.*has print streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*has print streaming", stderr_content + ), stderr_content + assert re.search( + r"similar log lines.*has print streaming too", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*has print streaming too", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*log streaming as level matched", stdout_content + ), stdout_content + assert re.search( + r"similar log lines.*log streaming as level matched", + stderr_content, + ), stderr_content + finally: # Ensure file descriptors are restored even if something goes wrong try: os.dup2(original_stdout_fd, 1) @@ -627,98 +634,92 @@ async def test_alloc_based_log_streaming() -> None: """Test both AllocHandle.stream_logs = False and True cases.""" async def 
test_stream_logs_case(stream_logs: bool, test_name: str) -> None: - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure( - enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 - ) - # Save original file descriptors original_stdout_fd = os.dup(1) # stdout try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: - stdout_path = stdout_file.name - os.dup2(stdout_file.fileno(), 1) - original_sys_stdout = sys.stdout - sys.stdout = stdout_file - - try: - # Create proc mesh with custom stream_logs setting - class ProcessAllocatorStreamLogs(ProcessAllocator): - def allocate_nonblocking( - self, spec: AllocSpec - ) -> PythonTask[Alloc]: - return super().allocate_nonblocking(spec) - - def _stream_logs(self) -> bool: - return stream_logs - - alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args()) - - host_mesh = HostMesh.allocate_nonblocking( - "host", - Extent(["hosts"], [1]), - alloc, - bootstrap_cmd=_bootstrap_cmd(), - ) - - pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2}) - - am = pm.spawn("printer", Printer) + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Create temporary files to capture output + with tempfile.NamedTemporaryFile( + mode="w+", delete=False + ) as stdout_file: + stdout_path = stdout_file.name + os.dup2(stdout_file.fileno(), 1) + original_sys_stdout = sys.stdout + sys.stdout = stdout_file + + try: + # Create proc mesh with custom stream_logs setting + class ProcessAllocatorStreamLogs(ProcessAllocator): + def allocate_nonblocking( + self, spec: AllocSpec + ) -> PythonTask[Alloc]: + return super().allocate_nonblocking(spec) + + def _stream_logs(self) -> bool: + return stream_logs + + alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args()) + + host_mesh = 
HostMesh.allocate_nonblocking( + "host", + Extent(["hosts"], [1]), + alloc, + bootstrap_cmd=_bootstrap_cmd(), + ) - await pm.initialized + pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2}) - for _ in range(5): - await am.print.call(f"{test_name} print streaming") + am = pm.spawn("printer", Printer) - # Wait for at least the aggregation window (3 seconds) - await asyncio.sleep(5) + await pm.initialized - # Flush all outputs - stdout_file.flush() - os.fsync(stdout_file.fileno()) + for _ in range(5): + await am.print.call(f"{test_name} print streaming") - finally: - # Restore Python's sys.stdout - sys.stdout = original_sys_stdout + # Wait for at least the aggregation window (3 seconds) + await asyncio.sleep(5) - # Restore original file descriptors - os.dup2(original_stdout_fd, 1) + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) - # Read the captured output - with open(stdout_path, "r") as f: - stdout_content = f.read() + finally: + # Restore Python's sys.stdout + sys.stdout = original_sys_stdout - # Clean up temp files - os.unlink(stdout_path) + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) - if not stream_logs: - # When stream_logs=False, logs should not be streamed to client - assert not re.search( - rf"similar log lines.*{test_name} print streaming", stdout_content - ), f"stream_logs=False case: {stdout_content}" - assert re.search( - rf"{test_name} print streaming", stdout_content - ), f"stream_logs=False case: {stdout_content}" - else: - # When stream_logs=True, logs should be streamed to client (no aggregation by default) - assert re.search( - rf"similar log lines.*{test_name} print streaming", stdout_content - ), f"stream_logs=True case: {stdout_content}" - assert not re.search( - rf"\[[0-9]\]{test_name} print streaming", stdout_content - ), f"stream_logs=True case: {stdout_content}" + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() + + # Clean up temp 
files + os.unlink(stdout_path) + + if not stream_logs: + # When stream_logs=False, logs should not be streamed to client + assert not re.search( + rf"similar log lines.*{test_name} print streaming", + stdout_content, + ), f"stream_logs=False case: {stdout_content}" + assert re.search( + rf"{test_name} print streaming", stdout_content + ), f"stream_logs=False case: {stdout_content}" + else: + # When stream_logs=True, logs should be streamed to client (no aggregation by default) + assert re.search( + rf"similar log lines.*{test_name} print streaming", + stdout_content, + ), f"stream_logs=True case: {stdout_content}" + assert not re.search( + rf"\[[0-9]\]{test_name} print streaming", stdout_content + ), f"stream_logs=True case: {stdout_content}" finally: - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) # Ensure file descriptors are restored even if something goes wrong try: os.dup2(original_stdout_fd, 1) @@ -734,96 +735,87 @@ def _stream_logs(self) -> bool: # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. 
@pytest.mark.oss_skip async def test_logging_option_defaults() -> None: - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - # Save original file descriptors original_stdout_fd = os.dup(1) # stdout original_stderr_fd = os.dup(2) # stderr try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile( - mode="w+", delete=False - ) as stdout_file, tempfile.NamedTemporaryFile( - mode="w+", delete=False - ) as stderr_file: - stdout_path = stdout_file.name - stderr_path = stderr_file.name - - # Redirect file descriptors to our temp files - # This will capture both Python and Rust output - os.dup2(stdout_file.fileno(), 1) - os.dup2(stderr_file.fileno(), 2) - - # Also redirect Python's sys.stdout/stderr for completeness - original_sys_stdout = sys.stdout - original_sys_stderr = sys.stderr - sys.stdout = stdout_file - sys.stderr = stderr_file + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Create temporary files to capture output + with tempfile.NamedTemporaryFile( + mode="w+", delete=False + ) as stdout_file, tempfile.NamedTemporaryFile( + mode="w+", delete=False + ) as stderr_file: + stdout_path = stdout_file.name + stderr_path = stderr_file.name - try: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) - - for _ in range(5): - await am.print.call("print streaming") - await am.log.call("log streaming") - - # Wait for > default aggregation window (3 seconds) - await asyncio.sleep(5) - - # Flush all outputs - stdout_file.flush() - stderr_file.flush() - os.fsync(stdout_file.fileno()) - os.fsync(stderr_file.fileno()) - - finally: - # Restore Python's sys.stdout/stderr - sys.stdout = original_sys_stdout - sys.stderr = original_sys_stderr - - # 
Restore original file descriptors - os.dup2(original_stdout_fd, 1) - os.dup2(original_stderr_fd, 2) - - # Read the captured output - with open(stdout_path, "r") as f: - stdout_content = f.read() - - with open(stderr_path, "r") as f: - stderr_content = f.read() - - # Clean up temp files - os.unlink(stdout_path) - os.unlink(stderr_path) - - # Assertions on the captured output - assert not re.search( - r"similar log lines.*print streaming", stdout_content - ), stdout_content - assert re.search(r"print streaming", stdout_content), stdout_content - assert not re.search( - r"similar log lines.*print streaming", stderr_content - ), stderr_content - assert not re.search( - r"similar log lines.*log streaming", stdout_content - ), stdout_content - assert not re.search( - r"similar log lines.*log streaming", stderr_content - ), stderr_content + # Redirect file descriptors to our temp files + # This will capture both Python and Rust output + os.dup2(stdout_file.fileno(), 1) + os.dup2(stderr_file.fileno(), 2) - finally: - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + # Also redirect Python's sys.stdout/stderr for completeness + original_sys_stdout = sys.stdout + original_sys_stderr = sys.stderr + sys.stdout = stdout_file + sys.stderr = stderr_file + + try: + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) + + for _ in range(5): + await am.print.call("print streaming") + await am.log.call("log streaming") + + # Wait for > default aggregation window (3 seconds) + await asyncio.sleep(5) + + # Flush all outputs + stdout_file.flush() + stderr_file.flush() + os.fsync(stdout_file.fileno()) + os.fsync(stderr_file.fileno()) + + finally: + # Restore Python's sys.stdout/stderr + sys.stdout = original_sys_stdout + sys.stderr = original_sys_stderr + + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) + os.dup2(original_stderr_fd, 2) + + # 
Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() + + with open(stderr_path, "r") as f: + stderr_content = f.read() + + # Clean up temp files + os.unlink(stdout_path) + os.unlink(stderr_path) + + # Assertions on the captured output + assert not re.search( + r"similar log lines.*print streaming", stdout_content + ), stdout_content + assert re.search(r"print streaming", stdout_content), stdout_content + assert not re.search( + r"similar log lines.*print streaming", stderr_content + ), stderr_content + assert not re.search( + r"similar log lines.*log streaming", stdout_content + ), stdout_content + assert not re.search( + r"similar log lines.*log streaming", stderr_content + ), stderr_content + finally: # Ensure file descriptors are restored even if something goes wrong try: os.dup2(original_stdout_fd, 1) @@ -860,38 +852,31 @@ def __init__(self): @pytest.mark.oss_skip async def test_flush_called_only_once() -> None: """Test that flush is called only once when ending an ipython cell""" - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - mock_ipython = MockIPython() - with unittest.mock.patch( - "monarch._src.actor.logging.get_ipython", - lambda: mock_ipython, - ), unittest.mock.patch( - "monarch._src.actor.logging.IN_IPYTHON", True - ), unittest.mock.patch( - "monarch._src.actor.logging.flush_all_proc_mesh_logs" - ) as mock_flush: - # Create 2 proc meshes with a large aggregation window - pm1 = this_host().spawn_procs(per_host={"gpus": 2}) - _ = this_host().spawn_procs(per_host={"gpus": 2}) - # flush not yet called unless post_run_cell - assert mock_flush.call_count == 0 - assert mock_ipython.events.registers == 0 - await pm1.logging_option(stream_to_client=True, aggregate_window_sec=600) - assert 
mock_ipython.events.registers == 1 - - # now, flush should be called only once - mock_ipython.events.trigger("post_run_cell", unittest.mock.MagicMock()) - - assert mock_flush.call_count == 1 - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + mock_ipython = MockIPython() + with unittest.mock.patch( + "monarch._src.actor.logging.get_ipython", + lambda: mock_ipython, + ), unittest.mock.patch( + "monarch._src.actor.logging.IN_IPYTHON", True + ), unittest.mock.patch( + "monarch._src.actor.logging.flush_all_proc_mesh_logs" + ) as mock_flush: + # Create 2 proc meshes with a large aggregation window + pm1 = this_host().spawn_procs(per_host={"gpus": 2}) + _ = this_host().spawn_procs(per_host={"gpus": 2}) + # flush not yet called unless post_run_cell + assert mock_flush.call_count == 0 + assert mock_ipython.events.registers == 0 + await pm1.logging_option(stream_to_client=True, aggregate_window_sec=600) + assert mock_ipython.events.registers == 1 + + # now, flush should be called only once + mock_ipython.events.trigger("post_run_cell", unittest.mock.MagicMock()) + + assert mock_flush.call_count == 1 # oss_skip: pytest keeps complaining about mocking get_ipython module @@ -899,145 +884,131 @@ async def test_flush_called_only_once() -> None: @pytest.mark.timeout(180) async def test_flush_logs_ipython() -> None: """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered.""" - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - - # Save original file descriptors - original_stdout_fd = os.dup(1) # stdout + with configured( + 
enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Save original file descriptors + original_stdout_fd = os.dup(1) # stdout - try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: - stdout_path = stdout_file.name + try: + # Create temporary files to capture output + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: + stdout_path = stdout_file.name - # Redirect file descriptors to our temp files - os.dup2(stdout_file.fileno(), 1) + # Redirect file descriptors to our temp files + os.dup2(stdout_file.fileno(), 1) - # Also redirect Python's sys.stdout - original_sys_stdout = sys.stdout - sys.stdout = stdout_file + # Also redirect Python's sys.stdout + original_sys_stdout = sys.stdout + sys.stdout = stdout_file - try: - mock_ipython = MockIPython() + try: + mock_ipython = MockIPython() + + with unittest.mock.patch( + "monarch._src.actor.logging.get_ipython", + lambda: mock_ipython, + ), unittest.mock.patch( + "monarch._src.actor.logging.IN_IPYTHON", True + ): + # Make sure we can register and unregister callbacks + for _ in range(3): + pm1 = this_host().spawn_procs(per_host={"gpus": 2}) + pm2 = this_host().spawn_procs(per_host={"gpus": 2}) + am1 = pm1.spawn("printer", Printer) + am2 = pm2.spawn("printer", Printer) + + # Set aggregation window to ensure logs are buffered + await pm1.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + await pm2.logging_option( + stream_to_client=True, aggregate_window_sec=600 + ) + + # Generate some logs that will be aggregated + for _ in range(5): + await am1.print.call("ipython1 test log") + await am2.print.call("ipython2 test log") + + # Trigger the post_run_cell event which should flush logs + mock_ipython.events.trigger( + "post_run_cell", unittest.mock.MagicMock() + ) + + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) + + # We expect to register post_run_cell 
hook only once per notebook/ipython session + assert mock_ipython.events.registers == 1 + assert len(mock_ipython.events.callbacks["post_run_cell"]) == 1 + finally: + # Restore Python's sys.stdout + sys.stdout = original_sys_stdout - with unittest.mock.patch( - "monarch._src.actor.logging.get_ipython", - lambda: mock_ipython, - ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True): - # Make sure we can register and unregister callbacks - for _ in range(3): - pm1 = this_host().spawn_procs(per_host={"gpus": 2}) - pm2 = this_host().spawn_procs(per_host={"gpus": 2}) - am1 = pm1.spawn("printer", Printer) - am2 = pm2.spawn("printer", Printer) - - # Set aggregation window to ensure logs are buffered - await pm1.logging_option( - stream_to_client=True, aggregate_window_sec=600 - ) - await pm2.logging_option( - stream_to_client=True, aggregate_window_sec=600 - ) + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) - # Generate some logs that will be aggregated - for _ in range(5): - await am1.print.call("ipython1 test log") - await am2.print.call("ipython2 test log") + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() - # Trigger the post_run_cell event which should flush logs - mock_ipython.events.trigger( - "post_run_cell", unittest.mock.MagicMock() - ) + # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils - # Flush all outputs - stdout_file.flush() - os.fsync(stdout_file.fileno()) + # Clean up temp files + os.unlink(stdout_path) - # We expect to register post_run_cell hook only once per notebook/ipython session - assert mock_ipython.events.registers == 1 - assert len(mock_ipython.events.callbacks["post_run_cell"]) == 1 - finally: - # Restore Python's sys.stdout - sys.stdout = original_sys_stdout + # We triggered post_run_cell three times; in the current + # implementation that yields three aggregated groups per + # message type (though the counts may be 10, 
10, 8 rather than + # all 10). + pattern1 = r"\[\d+ similar log lines\].*ipython1 test log" + pattern2 = r"\[\d+ similar log lines\].*ipython2 test log" - # Restore original file descriptors - os.dup2(original_stdout_fd, 1) + assert len(re.findall(pattern1, stdout_content)) >= 3, stdout_content + assert len(re.findall(pattern2, stdout_content)) >= 3, stdout_content - # Read the captured output - with open(stdout_path, "r") as f: - stdout_content = f.read() + finally: + # Ensure file descriptors are restored even if something goes wrong + try: + os.dup2(original_stdout_fd, 1) + os.close(original_stdout_fd) + except OSError: + pass - # TODO: there are quite a lot of code dups and boilerplate; make them contextmanager utils - # Clean up temp files - os.unlink(stdout_path) +# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited +@pytest.mark.oss_skip +async def test_flush_logs_fast_exit() -> None: + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # We use a subprocess to run the test so we can handle the flushed logs at the end. + # Otherwise, it is hard to restore the original stdout/stderr. - # We triggered post_run_cell three times; in the current - # implementation that yields three aggregated groups per - # message type (though the counts may be 10, 10, 8 rather than - # all 10). 
- pattern1 = r"\[\d+ similar log lines\].*ipython1 test log" - pattern2 = r"\[\d+ similar log lines\].*ipython2 test log" + test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin") - assert len(re.findall(pattern1, stdout_content)) >= 3, stdout_content - assert len(re.findall(pattern2, stdout_content)) >= 3, stdout_content + # Run the binary in a separate process and capture stdout and stderr + cmd = [str(test_bin), "flush-logs"] - finally: - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) - # Ensure file descriptors are restored even if something goes wrong - try: - os.dup2(original_stdout_fd, 1) - os.close(original_stdout_fd) - except OSError: - pass + process = subprocess.run(cmd, capture_output=True, timeout=60, text=True) + # Check if the process ended without error + if process.returncode != 0: + raise RuntimeError(f"{cmd} ended with error code {process.returncode}. ") -# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited -@pytest.mark.oss_skip -async def test_flush_logs_fast_exit() -> None: - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - # We use a subprocess to run the test so we can handle the flushed logs at the end. - # Otherwise, it is hard to restore the original stdout/stderr. - - test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin") - - # Run the binary in a separate process and capture stdout and stderr - cmd = [str(test_bin), "flush-logs"] - - process = subprocess.run(cmd, capture_output=True, timeout=60, text=True) - - # Check if the process ended without error - if process.returncode != 0: - raise RuntimeError(f"{cmd} ended with error code {process.returncode}. 
") - - # Assertions on the captured output, 160 = 32 procs * 5 logs per proc - # 32 and 5 are specified in the test_bin flush-logs. - assert ( - len( - re.findall( - r"160 similar log lines.*has print streaming", - process.stdout, + # Assertions on the captured output, 160 = 32 procs * 5 logs per proc + # 32 and 5 are specified in the test_bin flush-logs. + assert ( + len( + re.findall( + r"160 similar log lines.*has print streaming", + process.stdout, + ) ) - ) - == 1 - ), process.stdout - - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + == 1 + ), process.stdout # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. @@ -1047,103 +1018,98 @@ async def test_flush_on_disable_aggregation() -> None: This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation." """ - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - - # Save original file descriptors - original_stdout_fd = os.dup(1) # stdout - - try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: - stdout_path = stdout_file.name - - # Redirect file descriptors to our temp files - os.dup2(stdout_file.fileno(), 1) - - # Also redirect Python's sys.stdout - original_sys_stdout = sys.stdout - sys.stdout = stdout_file + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Save original file descriptors + original_stdout_fd = os.dup(1) # stdout - try: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) + try: + # Create temporary files to capture output + with tempfile.NamedTemporaryFile(mode="w+", 
delete=False) as stdout_file: + stdout_path = stdout_file.name - # Set a long aggregation window to ensure logs aren't flushed immediately - await pm.logging_option(stream_to_client=True, aggregate_window_sec=60) + # Redirect file descriptors to our temp files + os.dup2(stdout_file.fileno(), 1) - # Generate some logs that will be aggregated but not flushed immediately - for _ in range(5): - await am.print.call("aggregated log line") - await asyncio.sleep(1) + # Also redirect Python's sys.stdout + original_sys_stdout = sys.stdout + sys.stdout = stdout_file - # Now disable aggregation - this should trigger an immediate flush - await pm.logging_option( - stream_to_client=True, aggregate_window_sec=None - ) + try: + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Wait a bit to ensure logs are collected - await asyncio.sleep(1) - for _ in range(5): - await am.print.call("single log line") + # Set a long aggregation window to ensure logs aren't flushed immediately + await pm.logging_option( + stream_to_client=True, aggregate_window_sec=60 + ) - # Wait for > default aggregation window (3 secs) - await asyncio.sleep(5) + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(5): + await am.print.call("aggregated log line") + await asyncio.sleep(1) - # Flush all outputs - stdout_file.flush() - os.fsync(stdout_file.fileno()) + # Now disable aggregation - this should trigger an immediate flush + await pm.logging_option( + stream_to_client=True, aggregate_window_sec=None + ) - finally: - # Restore Python's sys.stdout - sys.stdout = original_sys_stdout + # Wait a bit to ensure logs are collected + await asyncio.sleep(1) + for _ in range(5): + await am.print.call("single log line") - # Restore original file descriptors - os.dup2(original_stdout_fd, 1) + # Wait for > default aggregation window (3 secs) + await asyncio.sleep(5) - # Read the captured output - with open(stdout_path, "r") as f: - 
stdout_content = f.read() + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) - # Clean up temp files - os.unlink(stdout_path) + finally: + # Restore Python's sys.stdout + sys.stdout = original_sys_stdout - # Verify that logs were flushed when aggregation was disabled - # We should see the aggregated logs in the output - # 10 = 5 log lines * 2 procs - assert re.search( - r"\[10 similar log lines\].*aggregated log line", stdout_content - ), stdout_content + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) - # No aggregated single log lines - assert not re.search( - r"similar log lines.*single log line", stdout_content - ), stdout_content + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() - # 10 = 5 log lines * 2 procs - total_single = len( - re.findall(r"\[.* [0-9]+\](?: \[[0-9]+\])? single log line", stdout_content) - ) - assert ( - total_single == 10 - ), f"Expected 10 single log lines, got {total_single} from {stdout_content}" + # Clean up temp files + os.unlink(stdout_path) - finally: - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + # Verify that logs were flushed when aggregation was disabled + # We should see the aggregated logs in the output + # 10 = 5 log lines * 2 procs + assert re.search( + r"\[10 similar log lines\].*aggregated log line", stdout_content + ), stdout_content + + # No aggregated single log lines + assert not re.search( + r"similar log lines.*single log line", stdout_content + ), stdout_content + + # 10 = 5 log lines * 2 procs + total_single = len( + re.findall( + r"\[.* [0-9]+\](?: \[[0-9]+\])? 
single log line", stdout_content + ) + ) + assert ( + total_single == 10 + ), f"Expected 10 single log lines, got {total_single} from {stdout_content}" - # Ensure file descriptors are restored even if something goes wrong - try: - os.dup2(original_stdout_fd, 1) - os.close(original_stdout_fd) - except OSError: - pass + finally: + # Ensure file descriptors are restored even if something goes wrong + try: + os.dup2(original_stdout_fd, 1) + os.close(original_stdout_fd) + except OSError: + pass @pytest.mark.timeout(120) @@ -1153,38 +1119,32 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None: Because now a flush call is purely sync, it is very easy to get into a deadlock. So we assert the last flush call will not get into such a state. """ - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - pm = this_host().spawn_procs(per_host={"gpus": 4}) - am = pm.spawn("printer", Printer) - - # Generate some logs that will be aggregated but not flushed immediately - for _ in range(10): - await am.print.call("aggregated log line") - - log_mesh = pm._logging_manager._logging_mesh_client - assert log_mesh is not None - futures = [] - for _ in range(5): - # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. 
- await asyncio.sleep(0.1) - futures.append( - Future( - coro=log_mesh.flush(context().actor_instance._as_rust()).spawn().task() + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + pm = this_host().spawn_procs(per_host={"gpus": 4}) + am = pm.spawn("printer", Printer) + + # Generate some logs that will be aggregated but not flushed immediately + for _ in range(10): + await am.print.call("aggregated log line") + + log_mesh = pm._logging_manager._logging_mesh_client + assert log_mesh is not None + futures = [] + for _ in range(5): + # FIXME: the order of futures doesn't necessarily mean the order of flushes due to the async nature. + await asyncio.sleep(0.1) + futures.append( + Future( + coro=log_mesh.flush(context().actor_instance._as_rust()) + .spawn() + .task() + ) ) - ) - # The last flush should not block - futures[-1].get() - - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - tail_log_lines=tail_log_lines, - ) + # The last flush should not block + futures[-1].get() # oss_skip: (SF) broken in GitHub by D86994420. Passes internally. @@ -1194,90 +1154,85 @@ async def test_adjust_aggregation_window() -> None: This tests the corner case: "This can happen if the user has adjusted the aggregation window." 
""" - config = get_configuration() - enable_log_forwarding = config["enable_log_forwarding"] - enable_file_capture = config["enable_file_capture"] - tail_log_lines = config["tail_log_lines"] - configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) - - # Save original file descriptors - original_stdout_fd = os.dup(1) # stdout - - try: - # Create temporary files to capture output - with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: - stdout_path = stdout_file.name + with configured( + enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100 + ): + # Save original file descriptors + original_stdout_fd = os.dup(1) # stdout - # Redirect file descriptors to our temp files - os.dup2(stdout_file.fileno(), 1) + try: + # Create temporary files to capture output + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as stdout_file: + stdout_path = stdout_file.name - # Also redirect Python's sys.stdout - original_sys_stdout = sys.stdout - sys.stdout = stdout_file + # Redirect file descriptors to our temp files + os.dup2(stdout_file.fileno(), 1) - try: - pm = this_host().spawn_procs(per_host={"gpus": 2}) - am = pm.spawn("printer", Printer) + # Also redirect Python's sys.stdout + original_sys_stdout = sys.stdout + sys.stdout = stdout_file - # Set a long aggregation window initially - await pm.logging_option(stream_to_client=True, aggregate_window_sec=100) + try: + pm = this_host().spawn_procs(per_host={"gpus": 2}) + am = pm.spawn("printer", Printer) - # Generate some logs that will be aggregated - for _ in range(3): - await am.print.call("first batch of logs") - await asyncio.sleep(1) + # Set a long aggregation window initially + await pm.logging_option( + stream_to_client=True, aggregate_window_sec=100 + ) - # Now adjust to a shorter window - this should update the flush deadline - await pm.logging_option(stream_to_client=True, aggregate_window_sec=2) + # Generate some logs that will be aggregated + for _ 
in range(3): + await am.print.call("first batch of logs") + await asyncio.sleep(1) - # Generate more logs - for _ in range(3): - await am.print.call("second batch of logs") + # Now adjust to a shorter window - this should update the flush deadline + await pm.logging_option( + stream_to_client=True, aggregate_window_sec=2 + ) - # Wait for > aggregation window (2 secs) - await asyncio.sleep(4) + # Generate more logs + for _ in range(3): + await am.print.call("second batch of logs") - # Flush all outputs - stdout_file.flush() - os.fsync(stdout_file.fileno()) + # Wait for > aggregation window (2 secs) + await asyncio.sleep(4) - finally: - # Restore Python's sys.stdout/stderr - sys.stdout = original_sys_stdout + # Flush all outputs + stdout_file.flush() + os.fsync(stdout_file.fileno()) - # Restore original file descriptors - os.dup2(original_stdout_fd, 1) + finally: + # Restore Python's sys.stdout/stderr + sys.stdout = original_sys_stdout - # Read the captured output - with open(stdout_path, "r") as f: - stdout_content = f.read() + # Restore original file descriptors + os.dup2(original_stdout_fd, 1) - # Clean up temp files - os.unlink(stdout_path) + # Read the captured output + with open(stdout_path, "r") as f: + stdout_content = f.read() - # Verify that logs were flushed when the aggregation window was adjusted - # We should see both batches of logs in the output - assert re.search( - r"\[6 similar log lines\].*first batch of logs", stdout_content - ), stdout_content + # Clean up temp files + os.unlink(stdout_path) - assert re.search( - r"similar log lines.*second batch of logs", stdout_content - ), stdout_content + # Verify that logs were flushed when the aggregation window was adjusted + # We should see both batches of logs in the output + assert re.search( + r"\[6 similar log lines\].*first batch of logs", stdout_content + ), stdout_content - finally: - configure( - enable_log_forwarding=enable_log_forwarding, - enable_file_capture=enable_file_capture, - 
tail_log_lines=tail_log_lines, - ) + assert re.search( + r"similar log lines.*second batch of logs", stdout_content + ), stdout_content - # Ensure file descriptors are restored even if something goes wrong - try: - os.dup2(original_stdout_fd, 1) - os.close(original_stdout_fd) - except OSError: - pass + finally: + # Ensure file descriptors are restored even if something goes wrong + try: + os.dup2(original_stdout_fd, 1) + os.close(original_stdout_fd) + except OSError: + pass class SendAlot(Actor):