Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3244,7 +3244,7 @@ def share_scope(module: Module, other: Module, /):
>>> list(params['DenseLoRA_0'].keys())
['A', 'B', 'kernel', 'bias']
"""
if module.scope is None:
if module.scope is None or other.scope is None:
raise errors.CallShareScopeOnUnboundModuleError()

def _is_child_scope(scope: Scope, other: Scope) -> bool:
Expand All @@ -3256,10 +3256,27 @@ def _is_child_scope(scope: Scope, other: Scope) -> bool:
target = target.parent
return False

if other.scope is not None and _is_child_scope(module.scope, other.scope):
if _is_child_scope(module.scope, other.scope):
# Child is a true child, overwrite its scope
object.__setattr__(other, 'scope', module.scope)
module_to_update = other
new_scope = module.scope
else:
# Child has its own independent scope, overwrite
# parent scope, so that we preserve the sharing
object.__setattr__(module, 'scope', other.scope)
module_to_update = module
new_scope = other.scope

old_scope = module_to_update.scope
object.__setattr__(module_to_update, 'scope', new_scope)

# Reattach all the children to the new scope as well.
for m in module_to_update._state.children.values():
if not isinstance(m, Module):
continue
# Should we go recursively to check if any of the ancestors point to the old
# scope?
if m.scope and m.scope.parent == old_scope:
# Reserve the scope, so that if there is a conflict we can raise an error.
if isinstance(m.scope.name, str):
new_scope.reserve(m.scope.name)
m.scope.parent = new_scope
28 changes: 28 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3307,6 +3307,34 @@ def __call__(self, x: jax.Array):
self.assertIn('A', params['dense_lora'])
self.assertIn('B', params['dense_lora'])

def test_external_grandchild_scope_correct(self):
  # Scenario: Parent calls share_scope with an externally constructed Child,
  # which in turn wraps an externally constructed GrandChild. After init, the
  # top level of 'params' must contain a 'child' entry and must not contain
  # a 'main_child' entry.

  class GrandChild(nn.Module):
    # Innermost module: applies a Dense(50) layer to a fixed zero vector.

    @nn.compact
    def __call__(self):
      inputs = jnp.zeros(10)
      return nn.Dense(50)(inputs)

  class Child(nn.Module):
    # Pass-through wrapper around the GrandChild it receives as a field.
    child: GrandChild

    @nn.compact
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
      return self.child(*args, **kwargs)

  class Parent(nn.Module):
    # Outer module: calls share_scope on itself and `main_child` in setup().
    main_child: Child

    def setup(self):
      nn.share_scope(self, self.main_child)

    @nn.compact
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
      # The Dense output is discarded; the layer is still invoked here before
      # delegating to the child that shares this module's scope.
      nn.Dense(10)(jnp.zeros(10))
      return self.main_child(*args, **kwargs)

  model = Parent(Child(GrandChild()))
  variables = model.init(jax.random.key(0))
  param_tree = variables['params']
  self.assertNotIn('main_child', param_tree)
  self.assertIn('child', param_tree)

# Script entry point: when this file is executed directly (not imported),
# hand control to absltest.main().
if __name__ == '__main__':
  absltest.main()
Loading