From 9c8a330376aa14115f3eb0a4129a4fcf677b9f76 Mon Sep 17 00:00:00 2001
From: seba-1511
Date: Sun, 30 Aug 2020 15:37:07 -0400
Subject: [PATCH 1/3] Fix clone_module with shared parameters.

---
 learn2learn/utils.py               | 32 +++++++++++++---
 tests/unit/algorithms/maml_test.py | 61 ++++++++++++++++++++++++------
 tests/unit/utils_test.py           | 45 +++++++++++++++++++++-
 3 files changed, 121 insertions(+), 17 deletions(-)

diff --git a/learn2learn/utils.py b/learn2learn/utils.py
index b5ab8dbf..f6a5e0a8 100644
--- a/learn2learn/utils.py
+++ b/learn2learn/utils.py
@@ -48,7 +48,7 @@ def clone_parameters(param_list):
     return [p.clone() for p in param_list]
 
 
-def clone_module(module):
+def clone_module(module, memo=None):
     """
 
     [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
@@ -91,6 +91,12 @@ def clone_module(module):
     # clone = recursive_shallow_copy(model)
     # clone._apply(lambda t: t.clone())
 
+    if memo is None:
+        # Maps original data_ptr to the cloned tensor.
+        # Useful when a Module uses parameters from another Module; see:
+        # https://github.com/learnables/learn2learn/issues/174
+        memo = {}
+
     # First, create a copy of the module.
     # Adapted from:
     # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
@@ -106,20 +112,36 @@ def clone_module(module):
     if hasattr(clone, '_parameters'):
         for param_key in module._parameters:
             if module._parameters[param_key] is not None:
-                cloned = module._parameters[param_key].clone()
-                clone._parameters[param_key] = cloned
+                param = module._parameters[param_key]
+                param_ptr = param.data_ptr
+                if param_ptr in memo:
+                    clone._parameters[param_key] = memo[param_ptr]
+                else:
+                    cloned = param.clone()
+                    clone._parameters[param_key] = cloned
+                    memo[param_ptr] = cloned
 
     # Third, handle the buffers if necessary
     if hasattr(clone, '_buffers'):
         for buffer_key in module._buffers:
             if clone._buffers[buffer_key] is not None and \
                clone._buffers[buffer_key].requires_grad:
-                clone._buffers[buffer_key] = module._buffers[buffer_key].clone()
+                buff = module._buffers[buffer_key]
+                buff_ptr = buff.data_ptr
+                if buff_ptr in memo:
+                    clone._buffers[buffer_key] = memo[buff_ptr]
+                else:
+                    cloned = buff.clone()
+                    clone._buffers[buffer_key] = cloned
+                    memo[buff_ptr] = cloned
 
     # Then, recurse for each submodule
     if hasattr(clone, '_modules'):
         for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(module._modules[module_key])
+            clone._modules[module_key] = clone_module(
+                module._modules[module_key],
+                memo=memo,
+            )
 
     # Finally, rebuild the flattened parameters for RNNs
     # See this issue for more details:
diff --git a/tests/unit/algorithms/maml_test.py b/tests/unit/algorithms/maml_test.py
index 2601e75d..48e4d329 100644
--- a/tests/unit/algorithms/maml_test.py
+++ b/tests/unit/algorithms/maml_test.py
@@ -18,12 +18,14 @@ def close(x, y):
 class TestMAMLAlgorithm(unittest.TestCase):
 
     def setUp(self):
-        self.model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
-                                         torch.nn.ReLU(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Sigmoid(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Softmax())
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
+            torch.nn.ReLU(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Softmax(),
+        )
 
         self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4))
 
@@ -101,7 +103,11 @@ def test_allow_nograd(self):
         try:
             # Check that without allow_nograd, adaptation fails
             clone.adapt(loss)
-            self.assertTrue(False, 'adaptation successful despite requires_grad=False')  # Check that execution never gets here
+            # Check that execution never gets here
+            self.assertTrue(
+                False,
+                'adaptation successful despite requires_grad=False',
+            )
         except:
             # Check that with allow_nograd, adaptation succeeds
             clone.adapt(loss, allow_nograd=True)
@@ -112,10 +118,12 @@ def test_allow_nograd(self):
             if p.requires_grad:
                 self.assertTrue(p.grad is not None)
 
-        maml = l2l.algorithms.MAML(self.model,
-                                   lr=INNER_LR,
-                                   first_order=False,
-                                   allow_nograd=True)
+        maml = l2l.algorithms.MAML(
+            self.model,
+            lr=INNER_LR,
+            first_order=False,
+            allow_nograd=True,
+        )
         clone = maml.clone()
         loss = sum([p.norm(p=2) for p in clone.parameters()])
         # Check that without allow_nograd, adaptation succeeds thanks to init.
@@ -123,6 +131,37 @@ def test_allow_nograd(self):
         clone.adapt(loss)
         self.assertTrue(close(orig_weight, self.model[2].weight))
 
+    def test_module_shared_params(self):
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(*[
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
+                )
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        module = TestModule()
+        maml = l2l.algorithms.MAML(module, lr=0.1)
+        clone = maml.clone()
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        clone.adapt(loss)
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        loss.backward()
 
 
 if __name__ == '__main__':
diff --git a/tests/unit/utils_test.py b/tests/unit/utils_test.py
index b65f9a3d..e3946a53 100644
--- a/tests/unit/utils_test.py
+++ b/tests/unit/utils_test.py
@@ -13,6 +13,11 @@ def ref_clone_module(module):
     each forward call.
     See this issue for more details:
     https://github.com/learnables/learn2learn/issues/139
+
+    Note: This implementation also does not work for Modules that re-use
+    parameters from another Module.
+    See this issue for more details:
+    https://github.com/learnables/learn2learn/issues/174
     """
     # First, create a copy of the module.
     clone = copy.deepcopy(module)
@@ -191,10 +196,48 @@ def test_rnn_clone(self):
         # Ensure we did better
         self.assertTrue(first_loss > second_loss)
 
+    def test_module_clone_shared_params(self):
+        # Tests proper use of memo parameter
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(*[
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
+                )
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        original = TestModule()
+        clone = l2l.clone_module(original)
+        self.assertTrue(
+            len(list(clone.parameters())) == len(list(original.parameters())),
+            'clone and original do not have same number of parameters.',
+        )
+
+        orig_params = [p.data_ptr() for p in original.parameters()]
+        duplicates = [p.data_ptr() in orig_params for p in clone.parameters()]
+        self.assertTrue(not any(duplicates), 'clone() forgot some parameters.')
 
     def test_module_detach(self):
         original_output = self.model(self.input)
-        original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
+        original_loss = self.loss_func(
+            original_output,
+            torch.tensor([[0., 0.]])
+        )
         original_gradients = torch.autograd.grad(original_loss,
                                                  self.model.parameters(),

From 5a8c49aed277a03d46e3e7dc9d4e6ebf0767acad Mon Sep 17 00:00:00 2001
From: seba-1511
Date: Sun, 30 Aug 2020 12:42:13 -0700
Subject: [PATCH 2/3] Add _notravis for benchmarks too.

---
 .../vision/{benchmarks_test.py => benchmarks_test_notravis.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename tests/unit/vision/{benchmarks_test.py => benchmarks_test_notravis.py} (100%)

diff --git a/tests/unit/vision/benchmarks_test.py b/tests/unit/vision/benchmarks_test_notravis.py
similarity index 100%
rename from tests/unit/vision/benchmarks_test.py
rename to tests/unit/vision/benchmarks_test_notravis.py

From ef3bb27e69daacab59a3e270e33d63f4aa473e88 Mon Sep 17 00:00:00 2001
From: seba-1511
Date: Sun, 30 Aug 2020 15:47:49 -0400
Subject: [PATCH 3/3] Update CHANGELOG.

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8cc051ee..8a518aec 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+* Fix `clone_module` for Modules whose submodules share parameters.
+
 ## v0.1.2
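
Usage sketch (illustrative, not part of the patches above): a module that registers the same submodules twice, mirroring the new test_module_shared_params and test_module_clone_shared_params tests. The SharedBody class, layer sizes, and learning rate below are invented for the example; the calls (l2l.clone_module, MAML.clone, adapt) are the ones exercised by the tests.

import torch
import learn2learn as l2l


class SharedBody(torch.nn.Module):
    # Hypothetical module for illustration: `self.net` re-registers the same
    # submodules, so their parameters are shared across the module tree.

    def __init__(self):
        super(SharedBody, self).__init__()
        self.body = torch.nn.Linear(8, 8)
        self.head = torch.nn.Linear(8, 2)
        self.net = torch.nn.Sequential(self.body, self.head)

    def forward(self, x):
        return self.net(x)


model = SharedBody()
clone = l2l.clone_module(model)

# With the memo argument, the clone keeps the sharing structure and owns its
# own storage: same number of (deduplicated) parameters, no data_ptr in common.
assert len(list(clone.parameters())) == len(list(model.parameters()))
original_ptrs = {p.data_ptr() for p in model.parameters()}
assert all(p.data_ptr() not in original_ptrs for p in clone.parameters())

# clone_module preserves the computation graph, so adapting a MAML clone of
# the module and backpropagating reaches the original parameters.
maml = l2l.algorithms.MAML(model, lr=0.1)
learner = maml.clone()
adaptation_loss = sum(p.norm(p=2) for p in learner.parameters())
learner.adapt(adaptation_loss)
evaluation_loss = sum(p.norm(p=2) for p in learner.parameters())
evaluation_loss.backward()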