#!/usr/bin/env python3
import copy
import torch
def magic_box(x):
"""
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
**Description**
The magic box operator, which evaluates to 1 but whose gradient is \\(dx\\):
$$\\boxdot (x) = \\exp(x - \\bot(x))$$
where \\(\\bot\\) is the stop-gradient (or detach) operator.
This operator is useful when computing higher-order derivatives of stochastic graphs.
    For more information, please refer to the DiCE paper (Reference 1).
**References**
1. Foerster et al. 2018. “DiCE: The Infinitely Differentiable Monte-Carlo Estimator.” arXiv.
**Arguments**
* **x** (Variable) - Variable to transform.
**Return**
    * (Variable) - Tensor of 1, but its gradient is the gradient of x.
**Example**
~~~python
loss = (magic_box(cum_log_probs) * advantages).mean() # loss is the mean advantage
loss.backward()
~~~
"""
if isinstance(x, torch.Tensor):
return torch.exp(x - x.detach())
return x
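
# Illustrative sketch (not part of the original module): a minimal check of the
# magic_box property described above, assuming a scalar input. The forward value
# is exactly 1, while the gradient w.r.t. x is d/dx exp(x - detach(x)) = exp(0) = 1.
# The helper name `_demo_magic_box` is hypothetical.
def _demo_magic_box():
    x = torch.tensor(2.0, requires_grad=True)
    y = magic_box(x)
    assert torch.allclose(y, torch.tensor(1.0))  # evaluates to 1 in the forward pass
    y.backward()
    assert torch.allclose(x.grad, torch.tensor(1.0))  # but gradients flow as if y = x
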
def clone_parameters(param_list):
    # Differentiably clone each parameter; gradients taken through the clones
    # still flow back to the original parameters.
    return [p.clone() for p in param_list]
def clone_module(module):
"""
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
**Description**
Creates a copy of a module, whose parameters/buffers/submodules
are created using PyTorch's torch.clone().
This implies that the computational graph is kept, and you can compute
    the derivatives of the new modules' parameters w.r.t. the original
parameters.
**Arguments**
* **module** (Module) - Module to be cloned.
**Return**
* (Module) - The cloned module.
**Example**
~~~python
    net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
error = loss(clone(X), y)
    error.backward()  # Gradients are back-propagated all the way to net.
~~~
"""
    # NOTE: This function might break in future versions of PyTorch.
    # TODO: This function might require that module.forward()
    #       was called in order to work properly, if forward() instantiates
    #       new variables.
    # TODO: deepcopy is expensive. We can probably get away with a shallow copy.
    #       However, since shallow copy does not recurse, we need to write a
    #       recursive version of shallow copy.
    # NOTE: This can probably be implemented more cleanly with
    #       clone = recursive_shallow_copy(model)
    #       clone._apply(lambda t: t.clone())
clone = copy.deepcopy(module)
# First, re-write all parameters
for param_key in module._parameters:
if module._parameters[param_key] is not None:
cloned = module._parameters[param_key].clone()
clone._parameters[param_key] = cloned
# Second, handle the buffers if necessary
for buffer_key in module._buffers:
if clone._buffers[buffer_key] is not None and \
clone._buffers[buffer_key].requires_grad:
clone._buffers[buffer_key] = module._buffers[buffer_key].clone()
# Then, recurse for each submodule
for module_key in clone._modules:
clone._modules[module_key] = clone_module(module._modules[module_key])
return clone
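
# Illustrative sketch (not part of the original module): gradients computed through
# a cloned network reach the original network's parameters, because the clone's
# parameters are differentiable torch.clone() copies of the originals. The helper
# name `_demo_clone_module` is hypothetical.
def _demo_clone_module():
    net = torch.nn.Sequential(torch.nn.Linear(20, 10),
                              torch.nn.ReLU(),
                              torch.nn.Linear(10, 2))
    clone = clone_module(net)
    X = torch.randn(8, 20)
    clone(X).sum().backward()
    # The original (leaf) parameters receive gradients through the cloned graph.
    assert all(p.grad is not None for p in net.parameters())
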
def detach_module(module):
"""
[[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
**Description**
Detaches all parameters/buffers of a previously cloned module from its computational graph.
Note: detach works in-place, so it does not return a copy.
**Arguments**
* **module** (Module) - Module to be detached.
**Example**
~~~python
    net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
detach_module(clone)
error = loss(clone(X), y)
    error.backward()  # Gradients are back-propagated to clone, not to net.
~~~
"""
# First, re-write all parameters
for param_key in module._parameters:
if module._parameters[param_key] is not None:
            module._parameters[param_key].detach_()
# Second, handle the buffers if necessary
for buffer_key in module._buffers:
if module._buffers[buffer_key] is not None and \
module._buffers[buffer_key].requires_grad:
module._buffers[buffer_key] = module._buffers[buffer_key].detach_()
# Then, recurse for each submodule
for module_key in module._modules:
detach_module(module._modules[module_key])
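
# Illustrative sketch (not part of the original module): after detach_module, the
# clone's parameters are leaf tensors with no grad_fn, so nothing computed from the
# clone can back-propagate into the original net. Note that Tensor.detach_() also
# clears requires_grad, so re-enable it (p.requires_grad_()) before training the
# detached clone. The helper name `_demo_detach_module` is hypothetical.
def _demo_detach_module():
    net = torch.nn.Sequential(torch.nn.Linear(20, 10),
                              torch.nn.ReLU(),
                              torch.nn.Linear(10, 2))
    clone = clone_module(net)
    detach_module(clone)
    # The detached parameters are leaves, severed from net's computational graph.
    assert all(p.grad_fn is None for p in clone.parameters())
    assert all(not p.requires_grad for p in clone.parameters())
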
def clone_distribution(dist):
# TODO: This function was never tested.
clone = copy.deepcopy(dist)
    for param_key in clone.__dict__:
        item = clone.__dict__[param_key]
        if isinstance(item, torch.Tensor):
            if item.requires_grad:
                clone.__dict__[param_key] = dist.__dict__[param_key].clone()
        elif isinstance(item, torch.nn.Module):
            clone.__dict__[param_key] = clone_module(dist.__dict__[param_key])
        elif isinstance(item, torch.distributions.Distribution):
            clone.__dict__[param_key] = clone_distribution(dist.__dict__[param_key])
return clone
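
# Illustrative sketch (not part of the original module, and untested upstream per the
# TODO above): cloning a Normal distribution whose parameters require grad keeps the
# clone differentiable w.r.t. the original loc/scale tensors. The helper name
# `_demo_clone_distribution` is hypothetical.
def _demo_clone_distribution():
    loc = torch.zeros(3, requires_grad=True)
    scale = torch.ones(3, requires_grad=True)
    normal = torch.distributions.Normal(loc, scale)
    cloned = clone_distribution(normal)
    cloned.log_prob(torch.zeros(3)).sum().backward()
    # Gradients flow back to the original distribution parameters.
    assert loc.grad is not None and scale.grad is not None
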
def detach_distribution(dist):
# TODO: This function was never tested.
    for param_key in dist.__dict__:
        item = dist.__dict__[param_key]
        if isinstance(item, torch.Tensor):
            if item.requires_grad:
                dist.__dict__[param_key] = dist.__dict__[param_key].detach()
        elif isinstance(item, torch.nn.Module):
            # detach_module works in-place and returns None, so do not re-assign.
            detach_module(dist.__dict__[param_key])
        elif isinstance(item, torch.distributions.Distribution):
            dist.__dict__[param_key] = detach_distribution(dist.__dict__[param_key])
return dist
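
# Illustrative sketch (not part of the original module, and untested upstream per the
# TODO above): detach_distribution severs the distribution's tensors from the graph
# in place, so quantities computed from it no longer require grad. The helper name
# `_demo_detach_distribution` is hypothetical.
def _demo_detach_distribution():
    loc = torch.zeros(3, requires_grad=True)
    scale = torch.ones(3, requires_grad=True)
    normal = torch.distributions.Normal(loc, scale)
    detached = detach_distribution(normal)  # returns the same (mutated) object
    assert not detached.log_prob(torch.zeros(3)).requires_grad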