Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Support calling submodules directly (#47)
Browse files Browse the repository at this point in the history
* Clean up param_box after calling patched forward.

* Add unit test for directly calling submodule.

* Add support for directly calling submodules.
  • Loading branch information
egrefen committed Apr 24, 2020
1 parent 0c2ed8d commit 41e0eee
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
51 changes: 40 additions & 11 deletions higher/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,21 @@ def __init__(self, original_params, root) -> None:
)
self._modules: _typing.Dict[str, _MonkeyPatchBase] = _OrderedDict()

@property
def direct_submodule_call(self):
return params_box[0] is None

@property
def is_root(self):
return self._root_ref is None

@property
def root(self):
if self.is_root:
return self
else:
return self._root_ref()

def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
Expand All @@ -223,13 +238,8 @@ def remove_from(*dicts):
if not self._being_modifed_internally:
# Additional behaviour for when fast weights are being
# directly modified goes here:
if not self._root_ref:
root = self
else:
root = self._root_ref()

old_value = self._parameters[name]
fast_params = root.fast_params[:]
fast_params = self.root.fast_params[:]
if not fast_params:
raise Exception(
"Cannot assign parameters to patched module which "
Expand Down Expand Up @@ -316,7 +326,15 @@ def remove_from(*dicts):

true_forward = type(module).forward

def patched_forward(self, *args, **kwargs):
def patched_forward(self, *args, params=None, **kwargs):
if self.direct_submodule_call:
# If submodule was called directly, run intialisation that happens
# at top level call. If *full set of params* is provided here, it
# will use those. If not, it will fall back on fast weights.
# In the future, we should be able to support passing only the
# submodule (+ children) weights here, but that's not simple.
self.root._refill_params_box(params)

with _modify_internally(self):
for name, param in zip(
self._param_names,
Expand Down Expand Up @@ -387,18 +405,28 @@ def make_functional(
param_mapping = _utils._get_param_mapping(module, [], [])
setattr(fmodule, "_param_mapping", param_mapping)

def _patched_forward(self, *args, **kwargs):
if "params" in kwargs:
params = kwargs.pop('params')
def _refill_params_box(self, params):
if params is not None:
self.fast_params = params # update view on latest fast params
elif self.fast_params is None:
raise ValueError(
"params keyword must be provided if patched module not "
"tracking its own fast parameters"
)

# Copy fast parameters into params_box for use in boxed_forward
params_box[0] = self._expand_params(self.fast_params)
return self.boxed_forward(*args, **kwargs)


def _patched_forward(self, *args, params=None, **kwargs):
self._refill_params_box(params)

output = self.boxed_forward(*args, **kwargs)

# Clean up
params_box[0] = None

return output

def _update_params(self, params):
self.fast_params = params
Expand All @@ -408,6 +436,7 @@ def _update_params(self, params):
setattr(MonkeyPatched, "forward", _patched_forward)
setattr(MonkeyPatched, "parameters", _patched_parameters)
setattr(MonkeyPatched, "update_params", _update_params)
setattr(MonkeyPatched, "_refill_params_box", _refill_params_box)

if encapsulator is not None:
encapsulator(fmodule, module)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,23 @@ def testMiniMAMLTestTime(self, _, model_builder):
self.assertIsNone(p.grad)
self.assertIsNone(g)

def testSubModuleDirectCall(self):
"""Check that patched submodules can be called directly."""
class Module(nn.Module):
def __init__(self):
super().__init__()
self.submodule = nn.Linear(3, 4)

def forward(self, inputs):
return self.submodule(inputs)

module = _NestedEnc(nn.Linear(3, 4))
fmodule = higher.monkeypatch(module)

xs = torch.randn(2, 3)
fsubmodule = fmodule.f

self.assertTrue(torch.equal(fmodule(xs), fsubmodule(xs)))

if __name__ == '__main__':
unittest.main()

0 comments on commit 41e0eee

Please sign in to comment.