2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed
 
+* Fix `clone_module` for Modules whose submodules share parameters.
+
 
 ## v0.1.2

32 changes: 27 additions & 5 deletions learn2learn/utils.py
@@ -48,7 +48,7 @@ def clone_parameters(param_list):
     return [p.clone() for p in param_list]
 
 
-def clone_module(module):
+def clone_module(module, memo=None):
     """
 
     [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
@@ -91,6 +91,12 @@ def clone_module(module):
     # clone = recursive_shallow_copy(model)
     # clone._apply(lambda t: t.clone())
 
+    if memo is None:
+        # Maps an original tensor's data_ptr to its clone.
+        # Useful when a Module uses parameters from another Module; see:
+        # https://github.com/learnables/learn2learn/issues/174
+        memo = {}
+
     # First, create a copy of the module.
     # Adapted from:
     # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
@@ -106,20 +112,36 @@
     if hasattr(clone, '_parameters'):
         for param_key in module._parameters:
             if module._parameters[param_key] is not None:
-                cloned = module._parameters[param_key].clone()
-                clone._parameters[param_key] = cloned
+                param = module._parameters[param_key]
+                param_ptr = param.data_ptr()
+                if param_ptr in memo:
+                    clone._parameters[param_key] = memo[param_ptr]
+                else:
+                    cloned = param.clone()
+                    clone._parameters[param_key] = cloned
+                    memo[param_ptr] = cloned
 
     # Third, handle the buffers if necessary
     if hasattr(clone, '_buffers'):
         for buffer_key in module._buffers:
             if clone._buffers[buffer_key] is not None and \
                     clone._buffers[buffer_key].requires_grad:
-                clone._buffers[buffer_key] = module._buffers[buffer_key].clone()
+                buff = module._buffers[buffer_key]
+                buff_ptr = buff.data_ptr()
+                if buff_ptr in memo:
+                    clone._buffers[buffer_key] = memo[buff_ptr]
+                else:
+                    cloned = buff.clone()
+                    clone._buffers[buffer_key] = cloned
+                    memo[buff_ptr] = cloned
 
     # Then, recurse for each submodule
     if hasattr(clone, '_modules'):
         for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(module._modules[module_key])
+            clone._modules[module_key] = clone_module(
+                module._modules[module_key],
+                memo=memo,
+            )
 
     # Finally, rebuild the flattened parameters for RNNs
     # See this issue for more details:
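An editorial sketch of the behavior this change enables (not part of the diff; the two-Linear weight-tying setup below is illustrative only): two submodules that alias one Parameter should still alias a single cloned Parameter after `clone_module`.

```python
import torch
import learn2learn as l2l

linear = torch.nn.Linear(4, 4)
tied = torch.nn.Linear(4, 4)
tied.weight = linear.weight  # both submodules now hold the same Parameter

model = torch.nn.Sequential(linear, tied)
clone = l2l.clone_module(model)

# With the memo, both occurrences resolve to one cloned tensor, so an
# update through one path stays visible through the other, as in the original.
assert clone[0].weight.data_ptr() == clone[1].weight.data_ptr()
# Without the memo, each occurrence would be cloned separately and the
# tied weights would silently diverge during adaptation.
```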
61 changes: 50 additions & 11 deletions tests/unit/algorithms/maml_test.py
@@ -18,12 +18,14 @@ def close(x, y):
 class TestMAMLAlgorithm(unittest.TestCase):
 
     def setUp(self):
-        self.model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
-                                         torch.nn.ReLU(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Sigmoid(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Softmax())
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
+            torch.nn.ReLU(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Softmax(),
+        )
 
         self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4))

@@ -101,7 +103,11 @@ def test_allow_nograd(self):
         try:
             # Check that without allow_nograd, adaptation fails
             clone.adapt(loss)
-            self.assertTrue(False, 'adaptation successful despite requires_grad=False')  # Check that execution never gets here
+            # Check that execution never gets here
+            self.assertTrue(
+                False,
+                'adaptation successful despite requires_grad=False',
+            )
         except:
             # Check that with allow_nograd, adaptation succeeds
             clone.adapt(loss, allow_nograd=True)
@@ -112,17 +118,50 @@ def test_allow_nograd(self):
             if p.requires_grad:
                 self.assertTrue(p.grad is not None)
 
-        maml = l2l.algorithms.MAML(self.model,
-                                   lr=INNER_LR,
-                                   first_order=False,
-                                   allow_nograd=True)
+        maml = l2l.algorithms.MAML(
+            self.model,
+            lr=INNER_LR,
+            first_order=False,
+            allow_nograd=True,
+        )
         clone = maml.clone()
         loss = sum([p.norm(p=2) for p in clone.parameters()])
         # Check that without allow_nograd, adaptation succeeds thanks to init.
         orig_weight = self.model[2].weight.clone().detach()
         clone.adapt(loss)
         self.assertTrue(close(orig_weight, self.model[2].weight))
 
+    def test_module_shared_params(self):
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1),
+                )
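+                # Note: seq and head are registered twice (as direct
+                # attributes and again inside net), so their parameters
+                # appear at two paths in the module tree.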
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        module = TestModule()
+        maml = l2l.algorithms.MAML(module, lr=0.1)
+        clone = maml.clone()
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        clone.adapt(loss)
+        loss = sum(p.norm(p=2) for p in clone.parameters())
+        loss.backward()


if __name__ == '__main__':
45 changes: 44 additions & 1 deletion tests/unit/utils_test.py
@@ -13,6 +13,11 @@ def ref_clone_module(module):
     each forward call.
     See this issue for more details:
     https://github.com/learnables/learn2learn/issues/139
+
+    Note: This implementation also does not work for Modules that re-use
+    parameters from another Module.
+    See this issue for more details:
+    https://github.com/learnables/learn2learn/issues/174
     """
     # First, create a copy of the module.
     clone = copy.deepcopy(module)
@@ -191,10 +196,48 @@ def test_rnn_clone(self):
         # Ensure we did better
         self.assertTrue(first_loss > second_loss)
 
+    def test_module_clone_shared_params(self):
+        # Tests proper use of the memo parameter
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                cnn = [
+                    torch.nn.Conv2d(3, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                ]
+                self.seq = torch.nn.Sequential(*cnn)
+                self.head = torch.nn.Sequential(
+                    torch.nn.Conv2d(32, 32, 3, 2, 1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(32, 100, 3, 2, 1),
+                )
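+                # Note: seq and head are registered twice (as direct
+                # attributes and again inside net), so their parameters
+                # appear at two paths in the module tree.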
+                self.net = torch.nn.Sequential(self.seq, self.head)
+
+            def forward(self, x):
+                return self.net(x)
+
+        original = TestModule()
+        clone = l2l.clone_module(original)
+        self.assertTrue(
+            len(list(clone.parameters())) == len(list(original.parameters())),
+            'clone and original do not have the same number of parameters.',
+        )
+
+        orig_ptrs = [p.data_ptr() for p in original.parameters()]
+        duplicates = [p.data_ptr() in orig_ptrs for p in clone.parameters()]
+        self.assertTrue(not any(duplicates), 'clone() did not clone some parameters.')

     def test_module_detach(self):
         original_output = self.model(self.input)
-        original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
+        original_loss = self.loss_func(
+            original_output,
+            torch.tensor([[0., 0.]])
+        )
 
         original_gradients = torch.autograd.grad(original_loss,
                                                  self.model.parameters(),
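Closing editorial note: a sketch of a complementary assertion (not part of the diff) that sharing is preserved inside the clone, assuming the `TestModule` defined in the test above is in scope:

```python
import learn2learn as l2l

# Sketch only: TestModule as defined in test_module_clone_shared_params.
original = TestModule()
clone = l2l.clone_module(original)

# seq is reachable both as original.seq and as original.net[0]; thanks to
# the memo, both paths resolve to the same cloned tensor in the clone.
assert clone.seq[0].weight.data_ptr() == clone.net[0][0].weight.data_ptr()
```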