From 05353e6417def5ab1942e0fcbdff0fcc44cd8009 Mon Sep 17 00:00:00 2001 From: Mark Sandler Date: Fri, 6 Sep 2024 12:17:03 -0700 Subject: [PATCH] Fixes a small bug in flax.linen.share_scope, where the scopes of children of the module being merged that were created before setup(), were not being updated to point to the new scope, and so they would end up staying under the original tree. PiperOrigin-RevId: 671852299 --- flax/linen/module.py | 25 +++++++++++++++++++++---- tests/linen/linen_module_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/flax/linen/module.py b/flax/linen/module.py index 5e8985c51..406912d22 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -3244,7 +3244,7 @@ def share_scope(module: Module, other: Module, /): >>> list(params['DenseLoRA_0'].keys()) ['A', 'B', 'kernel', 'bias'] """ - if module.scope is None: + if module.scope is None or other.scope is None: raise errors.CallShareScopeOnUnboundModuleError() def _is_child_scope(scope: Scope, other: Scope) -> bool: @@ -3256,10 +3256,27 @@ def _is_child_scope(scope: Scope, other: Scope) -> bool: target = target.parent return False - if other.scope is not None and _is_child_scope(module.scope, other.scope): + if _is_child_scope(module.scope, other.scope): # Child is a true child, overwrite its scope - object.__setattr__(other, 'scope', module.scope) + module_to_update = other + new_scope = module.scope else: # Child has its own independent scope, overwrite # parent scope, so that we preserve the sharing - object.__setattr__(module, 'scope', other.scope) + module_to_update = module + new_scope = other.scope + + old_scope = module_to_update.scope + object.__setattr__(module_to_update, 'scope', new_scope) + + # Reattach all the children to the new scope as well. 
+ for m in module_to_update._state.children.values(): + if not isinstance(m, Module): + continue + # Should we go recursively to check if any of the ancestors point to the old + # scope? + if m.scope and m.scope.parent == old_scope: + # Reserve the scope, so that if there is a conflict we can raise an error. + if isinstance(m.scope.name, str): + new_scope.reserve(m.scope.name) + m.scope.parent = new_scope diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 30bfe993c..fdee1443b 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -3307,6 +3307,34 @@ def __call__(self, x: jax.Array): self.assertIn('A', params['dense_lora']) self.assertIn('B', params['dense_lora']) + def test_external_grandchild_scope_correct(self): + class GrandChild(nn.Module): + @nn.compact + def __call__(self): + return nn.Dense(50)(jnp.zeros(10)) + + class Child(nn.Module): + child: GrandChild + + @nn.compact + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.child(*args, **kwargs) + + class Parent(nn.Module): + main_child: Child + + def setup(self): + nn.share_scope(self, self.main_child) + + @nn.compact + def __call__(self, *args: Any, **kwargs: Any) -> Any: + nn.Dense(10)(jnp.zeros(10)) + r = self.main_child(*args, **kwargs) + return r + + params = Parent(Child(GrandChild())).init(jax.random.key(0)) + self.assertNotIn('main_child', params['params']) + self.assertIn('child', params['params']) if __name__ == '__main__': absltest.main()