Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions monarch_hyperactor/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ pub fn reload_config_from_env() -> PyResult<()> {
Ok(())
}

#[pyfunction()]
pub fn reset_config_to_defaults() -> PyResult<()> {
// Set all config values to defaults, ignoring even environment variables.
hyperactor::config::global::reset_to_defaults();
Ok(())
}

/// Map from the kwarg name passed to `monarch.configure(...)` to the
/// `Key<T>` associated with that kwarg. This contains all attribute
/// keys whose `@meta(CONFIG = ConfigAttr { py_name: Some(...), .. })`
Expand Down Expand Up @@ -238,6 +245,13 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
)?;
module.add_function(reload)?;

let reset = wrap_pyfunction!(reset_config_to_defaults, module)?;
reset.setattr(
"__module__",
"monarch._rust_bindings.monarch_hyperactor.config",
)?;
module.add_function(reset)?;

let configure = wrap_pyfunction!(configure, module)?;
configure.setattr(
"__module__",
Expand Down
23 changes: 18 additions & 5 deletions python/monarch/_rust_bindings/monarch_hyperactor/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,26 @@ def reload_config_from_env() -> None:

This reads all HYPERACTOR_* environment variables and updates
the global configuration.
For any configuration setting not present in environment variables,
this function will not change its value.
"""
...

def reset_config_to_defaults() -> None:
"""Reset all configuration to default values, ignoring environment variables.
Call reload_config_from_env() to reload the environment variables.
"""
...

def configure(
default_transport: ChannelTransport = ChannelTransport.Unix,
enable_log_forwarding: bool = False,
enable_file_capture: bool = False,
tail_log_lines: int = 0,
) -> None: ...
default_transport: ChannelTransport = ...,
enable_log_forwarding: bool = ...,
enable_file_capture: bool = ...,
tail_log_lines: int = ...,
**kwargs: object,
) -> None:
"""Change a configuration value in the global configuration. If called with
no arguments, makes no changes. Does not reset any configuration"""
...

def get_configuration() -> Dict[str, Any]: ...
65 changes: 47 additions & 18 deletions python/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,76 @@

# pyre-unsafe

from contextlib import contextmanager

import pytest
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import (
configure,
get_configuration,
reload_config_from_env,
reset_config_to_defaults,
)


@contextmanager
def configure_temporary(*args, **kwargs):
"""Call configure, and then reset the configuration to the default values after
exiting. Always use this when testing so that other tests are not affected by any
changes made."""

try:
configure(*args, **kwargs)
yield
finally:
reset_config_to_defaults()
# Re-apply any environment variables that were set for this test.
reload_config_from_env()


def test_get_set_transport() -> None:
for transport in (
ChannelTransport.Unix,
ChannelTransport.TcpWithLocalhost,
ChannelTransport.TcpWithHostname,
ChannelTransport.MetaTlsWithHostname,
):
configure(default_transport=transport)
assert get_configuration()["default_transport"] == transport
# Succeed even if we don't specify the transport
configure()
assert (
get_configuration()["default_transport"] == ChannelTransport.MetaTlsWithHostname
)
with configure_temporary(default_transport=transport):
assert get_configuration()["default_transport"] == transport
# Succeed even if we don't specify the transport, but does not change the
# previous value.
with configure_temporary():
assert get_configuration()["default_transport"] == ChannelTransport.Unix
with pytest.raises(TypeError):
configure(default_transport="unix") # type: ignore
with configure_temporary(default_transport="unix"): # type: ignore
pass
with pytest.raises(TypeError):
configure(default_transport=42) # type: ignore
with configure_temporary(default_transport=42): # type: ignore
pass
with pytest.raises(TypeError):
configure(default_transport={}) # type: ignore
with configure_temporary(default_transport={}): # type: ignore
pass


def test_nonexistent_config_key() -> None:
with pytest.raises(ValueError):
configure(does_not_exist=42) # type: ignore
with configure_temporary(does_not_exist=42): # type: ignore
pass


def test_get_set_multiple() -> None:
configure(default_transport=ChannelTransport.TcpWithLocalhost)
configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100)
with configure_temporary(default_transport=ChannelTransport.TcpWithLocalhost):
with configure_temporary(
enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100
):
config = get_configuration()
assert config["enable_log_forwarding"]
assert config["enable_file_capture"]
assert config["tail_log_lines"] == 100
assert config["default_transport"] == ChannelTransport.TcpWithLocalhost
# Make sure the previous values are restored.
config = get_configuration()

assert config["enable_log_forwarding"]
assert config["enable_file_capture"]
assert config["tail_log_lines"] == 100
assert config["default_transport"] == ChannelTransport.TcpWithLocalhost
assert not config["enable_log_forwarding"]
assert not config["enable_file_capture"]
assert config["tail_log_lines"] == 0
assert config["default_transport"] == ChannelTransport.Unix
2 changes: 2 additions & 0 deletions scripts/common-setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ run_test_groups() {
# sustainable/most robust solution.
export CONDA_LIBSTDCPP="${CONDA_PREFIX}/lib/libstdc++.so.6"
export LD_PRELOAD="${CONDA_LIBSTDCPP}${LD_PRELOAD:+:$LD_PRELOAD}"
# Backtraces help with debugging remotely.
export RUST_BACKTRACE=1
local FAILED_GROUPS=()
for GROUP in $(seq 1 10); do
echo "Running test group $GROUP of 10..."
Expand Down
Loading