#!/usr/bin/env python3

from torch.autograd import grad

from learn2learn.algorithms.base_learner import BaseLearner
from learn2learn.utils import clone_module

def maml_update(model, lr, grads=None):
    """
    **Description**

    Performs a MAML update on model using grads and lr.
    The function re-routes the Python object, thus avoiding in-place
    operations.

    NOTE: The model itself is updated in-place (no deepcopy), but the
          parameters' tensors are not.

    **Arguments**

    * **model** (Module) - The model to update.
    * **lr** (float) - The learning rate used to update the model.
    * **grads** (list, *optional*, default=None) - A list of gradients for each parameter
        of the model. If None, will use the gradients in .grad attributes.

    **Example**
    ~~~python
    maml = l2l.algorithms.MAML(Model(), lr=0.1)
    model = maml.clone()  # The next two lines essentially implement model.adapt(loss)
    grads = autograd.grad(loss, model.parameters(), create_graph=True)
    maml_update(model, lr=0.1, grads=grads)
    ~~~
    """
    if grads is not None:
        params = list(model.parameters())
        if len(grads) != len(params):
            msg = 'WARNING:maml_update(): Parameters and gradients have different length. ('
            msg += str(len(params)) + ' vs ' + str(len(grads)) + ')'
            print(msg)
        for p, g in zip(params, grads):
            p.grad = g

    # Update the params
    for param_key in model._parameters:
        p = model._parameters[param_key]
        if p is not None and p.grad is not None:
            model._parameters[param_key] = p - lr * p.grad

    # Second, handle the buffers if necessary
    for buffer_key in model._buffers:
        buff = model._buffers[buffer_key]
        if buff is not None and buff.grad is not None:
            model._buffers[buffer_key] = buff - lr * buff.grad

    # Then, recurse for each submodule
    for module_key in model._modules:
        model._modules[module_key] = maml_update(model._modules[module_key],
                                                 lr=lr,
                                                 grads=None)
    return model
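

# Minimal illustrative sketch (not part of the library): driving `maml_update`
# directly with explicit gradients, as the docstring example above outlines.
# The toy model, loss, and data are assumed placeholders; in practice one would
# usually call `MAML.adapt` (defined below) instead.
def _example_manual_adaptation():
    import torch
    from torch import nn

    model = nn.Linear(4, 2)                      # assumed toy model
    loss_fn = nn.MSELoss()
    X, y = torch.randn(8, 4), torch.randn(8, 2)  # assumed toy data

    # Work on a differentiable clone so the original module keeps its parameters.
    learner = clone_module(model)
    loss = loss_fn(learner(X), y)

    # Differentiate the adaptation loss w.r.t. the clone's parameters,
    # keeping the graph so meta-gradients can flow through the update.
    grads = grad(loss, learner.parameters(), create_graph=True)
    return maml_update(learner, lr=0.1, grads=grads)
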
class MAML(BaseLearner):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py)

    **Description**

    High-level implementation of *Model-Agnostic Meta-Learning*.

    This class wraps an arbitrary nn.Module and augments it with `clone()` and `adapt()`
    methods.

    For the first-order version of MAML (i.e. FOMAML), set the `first_order` flag to `True`
    upon initialization.

    **Arguments**

    * **model** (Module) - Module to be wrapped.
    * **lr** (float) - Fast adaptation learning rate.
    * **first_order** (bool, *optional*, default=False) - Whether to use the first-order
        approximation of MAML (FOMAML).

    **References**

    1. Finn et al. 2017. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.”

    **Example**
    ~~~python
    linear = l2l.algorithms.MAML(nn.Linear(20, 10), lr=0.01)
    clone = linear.clone()
    error = loss(clone(X), y)
    clone.adapt(error)
    error = loss(clone(X), y)
    error.backward()
    ~~~
    """
    def __init__(self, model, lr, first_order=False):
        super(MAML, self).__init__()
        self.module = model
        self.lr = lr
        self.first_order = first_order

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def adapt(self, loss, first_order=None):
        """
        **Description**

        Updates the clone parameters in place using the MAML update.

        **Arguments**

        * **loss** (Tensor) - Loss to minimize upon update.
        * **first_order** (bool, *optional*, default=None) - Whether to use first- or
            second-order updates. Defaults to self.first_order.
        """
        if first_order is None:
            first_order = self.first_order
        second_order = not first_order
        gradients = grad(loss,
                         self.module.parameters(),
                         retain_graph=second_order,
                         create_graph=second_order)
        self.module = maml_update(self.module, self.lr, gradients)

    def clone(self, first_order=None):
        """
        **Description**

        Returns a `MAML`-wrapped copy of the module whose parameters and buffers
        are `torch.clone`d from the original module.

        This implies that back-propagating losses on the cloned module will
        populate the buffers of the original module.

        For more information, refer to learn2learn.clone_module().

        **Arguments**

        * **first_order** (bool, *optional*, default=None) - Whether the clone uses first-
            or second-order updates. Defaults to self.first_order.
        """
        if first_order is None:
            first_order = self.first_order
        return MAML(clone_module(self.module),
                    lr=self.lr,
                    first_order=first_order)
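

# Minimal end-to-end sketch (illustrative only, not part of the library) of the
# outer meta-training loop implied by the class docstring: clone, adapt on a
# support set, evaluate on a query set, and back-propagate through the adaptation
# step into the original parameters. The task data below is a random stand-in.
if __name__ == '__main__':
    import torch
    from torch import nn

    model = nn.Linear(20, 10)                      # assumed toy model
    maml = MAML(model, lr=0.01)                    # fast-adaptation wrapper
    opt = torch.optim.Adam(maml.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()

    for step in range(10):
        opt.zero_grad()
        # Stand-in for sampling a task's support and query sets.
        X_support, y_support = torch.randn(8, 20), torch.randn(8, 10)
        X_query, y_query = torch.randn(8, 20), torch.randn(8, 10)

        learner = maml.clone()                     # differentiable copy
        learner.adapt(loss_fn(learner(X_support), y_support))
        query_loss = loss_fn(learner(X_query), y_query)
        query_loss.backward()                      # gradients reach model's parameters
        opt.step()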