diff --git a/hyperactor/src/attrs.rs b/hyperactor/src/attrs.rs index bc95585b3..2ba806a1e 100644 --- a/hyperactor/src/attrs.rs +++ b/hyperactor/src/attrs.rs @@ -556,6 +556,15 @@ impl Attrs { pub(crate) fn get_value_by_name(&self, name: &'static str) -> Option<&dyn SerializableValue> { self.values.get(name).map(|b| b.as_ref()) } + + /// Merge all attributes from `other` into this set, consuming + /// `other`. + /// + /// For each key in `other`, moves its value into `self`, + /// overwriting any existing value for the same key. + pub(crate) fn merge(&mut self, other: Attrs) { + self.values.extend(other.values); + } } impl Clone for Attrs { diff --git a/hyperactor/src/config/global.rs b/hyperactor/src/config/global.rs index 908a757ba..d4389a96e 100644 --- a/hyperactor/src/config/global.rs +++ b/hyperactor/src/config/global.rs @@ -440,6 +440,24 @@ pub fn try_get_cloned(key: Key) -> Option { key.default().cloned() } +/// Construct a [`Layer`] for the given [`Source`] using the provided +/// `attrs`. +/// +/// Used by [`set`] and [`create_or_merge`] when installing a new +/// configuration layer. +fn make_layer(source: Source, attrs: Attrs) -> Layer { + match source { + Source::File => Layer::File(attrs), + Source::Env => Layer::Env(attrs), + Source::Runtime => Layer::Runtime(attrs), + Source::TestOverride => Layer::TestOverride { + attrs, + stacks: HashMap::new(), + }, + Source::ClientOverride => Layer::ClientOverride(attrs), + } +} + /// Insert or replace a configuration layer for the given source. /// /// If a layer with the same [`Source`] already exists, its @@ -457,16 +475,34 @@ pub fn set(source: Source, attrs: Attrs) { if let Some(l) = g.ordered.iter_mut().find(|l| layer_source(l) == source) { *layer_attrs_mut(l) = attrs; } else { - g.ordered.push(match source { - Source::File => Layer::File(attrs), - Source::Env => Layer::Env(attrs), - Source::Runtime => Layer::Runtime(attrs), - Source::TestOverride => Layer::TestOverride { - attrs, - stacks: HashMap::new(), - }, - Source::ClientOverride => Layer::ClientOverride(attrs), - }); + g.ordered.push(make_layer(source, attrs)); + } + g.ordered.sort_by_key(|l| priority(layer_source(l))); // TestOverride < Runtime < Env < File < ClientOverride +} + +/// Insert or update a configuration layer for the given [`Source`]. +/// +/// If a layer with the same [`Source`] already exists, its attributes +/// are **updated in place**: all keys present in `attrs` are absorbed +/// into the existing layer, overwriting any previous values for those +/// keys while leaving all other keys in that layer unchanged. +/// +/// If no layer for `source` exists yet, this behaves like [`set`]: a +/// new layer is created with the provided `attrs`. +/// +/// This is useful for incremental / additive updates (for example, +/// runtime configuration driven by a Python API), where callers want +/// to change a subset of keys without discarding previously installed +/// values in the same layer. +/// +/// By contrast, [`set`] replaces the entire layer for `source` with +/// `attrs`, discarding any existing values in that layer. +pub fn create_or_merge(source: Source, attrs: Attrs) { + let mut g = LAYERS.write().unwrap(); + if let Some(layer) = g.ordered.iter_mut().find(|l| layer_source(l) == source) { + layer_attrs_mut(layer).merge(attrs); + } else { + g.ordered.push(make_layer(source, attrs)); } g.ordered.sort_by_key(|l| priority(layer_source(l))); // TestOverride < Runtime < Env < File < ClientOverride } @@ -1206,4 +1242,37 @@ mod tests { assert!(priority(Env) < priority(File)); assert!(priority(File) < priority(ClientOverride)); } + + #[test] + fn test_create_or_merge_runtime_merges_keys() { + let _lock = lock(); + reset_to_defaults(); + + // Seed Runtime with one key. + let mut rt = Attrs::new(); + rt[MESSAGE_TTL_DEFAULT] = 10; + set(Source::Runtime, rt); + + // Now update Runtime with a different key via + // `create_or_merge`. + let mut update = Attrs::new(); + update[MESSAGE_ACK_EVERY_N_MESSAGES] = 123; + create_or_merge(Source::Runtime, update); + + // Both keys should now be visible from Runtime. + assert_eq!(get(MESSAGE_TTL_DEFAULT), 10); + assert_eq!(get(MESSAGE_ACK_EVERY_N_MESSAGES), 123); + } + + #[test] + fn test_create_or_merge_runtime_creates_layer_if_missing() { + let _lock = lock(); + reset_to_defaults(); + + let mut rt = Attrs::new(); + rt[MESSAGE_TTL_DEFAULT] = 42; + create_or_merge(Source::Runtime, rt); + + assert_eq!(get(MESSAGE_TTL_DEFAULT), 42); + } } diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index df8f45383..87f1f07de 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -96,7 +96,7 @@ declare_attrs! { /// piping) or via [`StreamFwder`] when piping is active. @meta(CONFIG = ConfigAttr { env_name: Some("HYPERACTOR_MESH_ENABLE_LOG_FORWARDING".to_string()), - py_name: None, + py_name: Some("enable_log_forwarding".to_string()), }) pub attr MESH_ENABLE_LOG_FORWARDING: bool = false; @@ -121,7 +121,7 @@ declare_attrs! { /// buffer used for peeking—independent of file capture. @meta(CONFIG = ConfigAttr { env_name: Some("HYPERACTOR_MESH_ENABLE_FILE_CAPTURE".to_string()), - py_name: None, + py_name: Some("enable_file_capture".to_string()), }) pub attr MESH_ENABLE_FILE_CAPTURE: bool = false; @@ -130,7 +130,7 @@ declare_attrs! { /// pipes. Default: 100 @meta(CONFIG = ConfigAttr { env_name: Some("HYPERACTOR_MESH_TAIL_LOG_LINES".to_string()), - py_name: None, + py_name: Some("tail_log_lines".to_string()), }) pub attr MESH_TAIL_LOG_LINES: usize = 0; diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index f3215c35f..bea80671b 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -97,7 +97,7 @@ fn set_global_config(key: &'static dyn ErasedKey, value: T let key = key.downcast_ref().expect("cannot fail"); let mut attrs = Attrs::new(); attrs.set(key.clone(), value); - hyperactor::config::global::set(Source::Runtime, attrs); + hyperactor::config::global::create_or_merge(Source::Runtime, attrs); Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index ff2802d87..c9dba4c57 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -25,5 +25,8 @@ def reload_config_from_env() -> None: def configure( default_transport: ChannelTransport = ChannelTransport.Unix, + enable_log_forwarding: bool = False, + enable_file_capture: bool = False, + tail_log_lines: int = 0, ) -> None: ... def get_configuration() -> Dict[str, Any]: ... diff --git a/python/tests/test_config.py b/python/tests/test_config.py index 8d80c5de6..deee1065c 100644 --- a/python/tests/test_config.py +++ b/python/tests/test_config.py @@ -39,3 +39,14 @@ def test_get_set_transport() -> None: def test_nonexistent_config_key() -> None: with pytest.raises(ValueError): configure(does_not_exist=42) # type: ignore + + +def test_get_set_multiple() -> None: + configure(default_transport=ChannelTransport.TcpWithLocalhost) + configure(enable_log_forwarding=True, enable_file_capture=True, tail_log_lines=100) + config = get_configuration() + + assert config["enable_log_forwarding"] + assert config["enable_file_capture"] + assert config["tail_log_lines"] == 100 + assert config["default_transport"] == ChannelTransport.TcpWithLocalhost