# TorchOpt for Implicit Differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)

By treating the solution $\phi^{\star}$ as an implicit function of $\theta$, implicit differentiation obtains the analytical best-response derivative $\partial \phi^{\star}(\theta)/ \partial \theta$ directly from the implicit function theorem. This suits algorithms whose inner-level solution is optimal, ${\left. \frac{\partial F (\phi, \theta)}{\partial \phi} \right\rvert}_{\phi = \phi^{\star}} = 0$, or satisfies some stationary condition $F (\phi^{\star}, \theta) = 0$, such as [iMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377).
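
For reference, the best-response derivative follows from differentiating the stationary condition $F (\phi^{\star}, \theta) = 0$ with respect to $\theta$. The short derivation below is a sketch under two assumptions that the text above does not state explicitly: $F$ is continuously differentiable, and $\partial F / \partial \phi$ is invertible at $\phi^{\star}$.

$$
F \left( \phi^{\star} (\theta), \theta \right) = 0
\quad \Longrightarrow \quad
\frac{\partial F}{\partial \phi} \frac{\partial \phi^{\star} (\theta)}{\partial \theta} + \frac{\partial F}{\partial \theta} = 0
\quad \Longrightarrow \quad
\frac{\partial \phi^{\star} (\theta)}{\partial \theta} = - {\left( \frac{\partial F}{\partial \phi} \right)}^{-1} \frac{\partial F}{\partial \theta},
$$

where all partial derivatives of $F$ are evaluated at $(\phi^{\star}, \theta)$. The inverse is usually not formed explicitly; its product with a vector is obtained from a linear solver instead.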

In this tutorial, we will introduce how TorchOpt can be used to conduct implicit differentiation.

In [1]:
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt

## 1. Functional API

The basic functional API is `torchopt.diff.implicit.custom_root`, a decorator for the forward (inner-loop) solver that enables implicit gradient computation. Users implement the stationary condition of the inner-loop process and pass it to the `custom_root` decorator. The pseudocode below illustrates the pattern.

```python
# Functional API for implicit gradient
def stationary(params, meta_params, data):
    # stationary condition construction
    return stationary_condition

# Decorator that wraps the function
# Optionally specify the linear solver (conjugate gradient or Neumann series)
@torchopt.diff.implicit.custom_root(stationary, solve=linear_solver)
def solve(params, meta_params, data):
    # Forward optimization process for params
    return optimal_params

# Define params, meta_params and get data
params, meta_params, data = ..., ..., ...
optimal_params = solve(params, meta_params, data)
loss = outer_loss(optimal_params)

meta_grads = torch.autograd.grad(loss, meta_params)
```

Here we use [iMAML](https://arxiv.org/abs/1909.04630) as a concrete example. For iMAML, the inner-loop objective is given by the following equation.

$$
{\mathcal{Alg}}^{\star} \left( \boldsymbol{\theta}, \mathcal{D}_{i}^{\text{tr}} \right) = \underset{\boldsymbol{\phi}'}{\operatorname{arg\,min}} ~ G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \triangleq \mathcal{L} \left( \boldsymbol{\phi}', \mathcal{D}_{i}^{\text{tr}} \right) + \frac{\lambda}{2} {\left\| \boldsymbol{\phi}' - \boldsymbol{\theta} \right\|}^{2}
$$

Based on this objective, we can define the forward function `inner_solver`, which solves the inner problem by running a sufficient number of gradient descent steps. For such an inner-loop process, the optimality condition is that the gradient w.r.t. the inner-loop parameters is $0$.

$$
{\left. \nabla_{\boldsymbol{\phi}'} G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \right\rvert}_{\boldsymbol{\phi}' = \boldsymbol{\phi}^{\star}} = 0
$$

Thus we can define the optimality function by defining `imaml_objective` and requiring its first-order gradient w.r.t. the inner-loop parameters to be $0$. We obtain this gradient function by calling `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated with `@torchopt.diff.implicit.custom_root`, which takes the optimality condition we defined.
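
To see the optimality condition at work, here is a minimal scalar sketch in plain Python (no TorchOpt involved). It assumes a toy loss $\mathcal{L}(\phi) = \frac{1}{2} (\phi - a)^2$, so the inner objective is $G(\phi, \theta) = \frac{1}{2} (\phi - a)^2 + \frac{\lambda}{2} (\phi - \theta)^2$. All names (`grad_G_phi`, `inner_solve`, `implicit_grad`, `a`, `lam`) are illustrative and not part of any library. The implicit derivative is compared with a finite-difference derivative taken through the gradient descent solver.

```python
# Scalar toy version of the iMAML inner problem (illustrative names only):
#   G(phi, theta) = 0.5 * (phi - a)**2 + 0.5 * lam * (phi - theta)**2


def grad_G_phi(phi, theta, a, lam):
    """Stationary condition dG/dphi, which is zero at the inner optimum."""
    return (phi - a) + lam * (phi - theta)


def inner_solve(theta, a, lam, lr=0.1, steps=500):
    """Gradient descent on G, started from the meta-parameter theta."""
    phi = theta
    for _ in range(steps):
        phi -= lr * grad_G_phi(phi, theta, a, lam)
    return phi


def implicit_grad(a, lam):
    """d phi*/d theta obtained from the implicit function theorem."""
    dF_dphi = 1.0 + lam  # derivative of the stationary condition w.r.t. phi
    dF_dtheta = -lam  # derivative of the stationary condition w.r.t. theta
    return -dF_dtheta / dF_dphi


a, lam, theta = 2.0, 1.0, 0.5
phi_star = inner_solve(theta, a, lam)

# Central finite difference through the whole inner solver, for comparison
eps = 1e-4
fd_grad = (inner_solve(theta + eps, a, lam) - inner_solve(theta - eps, a, lam)) / (2 * eps)

print(phi_star, implicit_grad(a, lam), fd_grad)
```

For these values the closed-form derivative is $\lambda / (1 + \lambda) = 0.5$, and the finite-difference estimate should agree with it closely. The implicit route needs only the derivatives of the stationary condition at the solution, not the 500 solver steps.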

In [2]:
sigma = 0.01
fmodel, params = functorch.make_functional(torch.nn.Linear(32, 1))
x = torch.randn(64, 32)
y = torch.randn(64)
distribution = torch.distributions.normal.Normal(loc=0, scale=1)


@torchopt.diff.zero_order.zero_order(distribution, sigma=0.001)
def forward_process(params, f, x, y):
    y_pred = f(params, x)
    loss = torch.mean((y - y_pred) ** 2)
    return loss


out = forward_process(params, fmodel, x, y)

torch.random.manual_seed(0)
torch.autograd.grad(out, params)


(
    tensor([[-1234.2908, -1263.3658,  -274.7166,  -475.6739,   930.4658,   758.6697,
              -346.4539, -2318.9763,   353.3193, -1385.0305,   383.6967,   337.8162,
               131.3857,  1356.8802,  1224.3552,  -271.0982, -1482.9535, -1859.2985,
               621.2355,   869.9462,   656.5251, -1704.8958,  -374.2433,  2031.5043,
               822.4545,  -641.8979,  -190.0999,   201.1522,  1523.2024,  1739.1442,
              1037.4543,  -924.9473]]),
    tensor([-620.8703])
)
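
The `zero_order` decorator used in the cell above takes a sampling distribution and a `sigma`, which points to gradients estimated from function evaluations at randomly perturbed parameters instead of backpropagation. As a rough illustration of that idea, the plain-Python sketch below implements an antithetic Gaussian-smoothing estimator on a toy quadratic. The estimator choice and the names `loss` and `zero_order_grad` are assumptions made for this example; this is not TorchOpt's implementation, whose estimator and defaults may differ.

```python
import random


def loss(theta):
    """Toy objective f(theta) = sum(theta_i ** 2); its true gradient is 2 * theta."""
    return sum(t * t for t in theta)


def zero_order_grad(f, theta, sigma=0.01, num_samples=20000, seed=0):
    """Average of (f(theta + sigma*u) - f(theta - sigma*u)) / (2*sigma) * u over Gaussian u."""
    rng = random.Random(seed)
    grad = [0.0] * len(theta)
    for _ in range(num_samples):
        u = [rng.gauss(0.0, 1.0) for _ in theta]
        plus = f([t + sigma * ui for t, ui in zip(theta, u)])
        minus = f([t - sigma * ui for t, ui in zip(theta, u)])
        scale = (plus - minus) / (2.0 * sigma)  # directional derivative along u
        for i, ui in enumerate(u):
            grad[i] += scale * ui / num_samples
    return grad


theta = [1.0, -2.0]
estimate = zero_order_grad(loss, theta)
print(estimate)  # should be close to the true gradient [2.0, -4.0]
```

The estimate is noisy for a small number of samples and only approaches the true gradient as more perturbations are averaged.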


Next, we consider a specific case: fitting linear data with a one-layer neural network.

In [3]:
def sample(shape):
    x = torch.distributions.normal.Normal(loc=0, scale=1)
    return x.sample(shape)


@torchopt.diff.zero_order.zero_order(sample, sigma=0.001)
def forward_process(params, f, x, y):
    y_pred = f(params, x)
    loss = torch.mean((y - y_pred) ** 2)
    return loss


out = forward_process(params, fmodel, x, y)

torch.random.manual_seed(0)
print(torch.autograd.grad(out, params))

In [4]:
torch.manual_seed(0)

x = torch.randn(20, 4)
w = torch.randn(4, 1)
b = torch.randn(1)
y = x @ w + b + 0.5 * torch.randn(20, 1)

We instantiate a one-layer neural network whose weights and bias are initialized with constants.

In [5]:
class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


model = Net(4)
fmodel, meta_params = functorch.make_functional(model)
data = (x, y, fmodel)

# Clone function for parameters
def clone(params):
    cloned = []
    for item in params:
        if isinstance(item, torch.Tensor):
            cloned.append(item.clone().detach_().requires_grad_(True))
        else:
            cloned.append(item)
    return tuple(cloned)

We run the forward process by calling the forward function, then pass the optimal parameters to the outer-loop loss function.

In [6]:
optimal_params = inner_solver(clone(meta_params), meta_params, data)

outer_loss = fmodel(optimal_params, x).mean()

Finally, we can get the meta gradient as shown below.

In [7]:
torch.autograd.grad(outer_loss, meta_params)

We can also switch to the linear solver based on Neumann series inversion.

In [8]:
optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)
outer_loss = fmodel(optimal_params, x).mean()
torch.autograd.grad(outer_loss, meta_params)
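
To make the Neumann series option concrete: for a matrix $A$ and a step size $\alpha$ such that the spectral radius of $I - \alpha A$ is below $1$, we have $A^{-1} b = \alpha \sum_{k \ge 0} (I - \alpha A)^{k} b$. The plain-Python sketch below truncates this sum. The function names, the step size, and the iteration count are illustrative assumptions, and the code is not TorchOpt's solver.

```python
def matvec(A, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def neumann_solve(A, b, alpha=0.1, maxiter=200):
    """Approximate A^{-1} b by alpha * sum_{k=0}^{maxiter} (I - alpha * A)^k b."""
    term = list(b)  # holds (I - alpha * A)^k b, starting at k = 0
    total = list(b)  # running sum of the terms
    for _ in range(maxiter):
        Av = matvec(A, term)
        term = [t - alpha * av for t, av in zip(term, Av)]
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]


A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = neumann_solve(A, b)
residual = [bi - axi for bi, axi in zip(b, matvec(A, x))]
print(x, residual)
```

Only matrix-vector products are needed, which is why this kind of solver fits implicit gradients, where $A$ is a Jacobian or Hessian that is never materialized.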

## 2. OOP API

The basic OOP class is `ImplicitMetaGradientModule`. We define the network as an `nn.Module`, following the classical PyTorch style. Users need to define the stationary condition or the objective function, together with the inner-loop solve function, to enable implicit gradient computation. The pseudocode below illustrates the pattern.

```python
from torchopt.nn import ImplicitMetaGradientModule

# Inherited from the class ImplicitMetaGradientModule
# Optionally specify the linear solver (conjugate gradient or Neumann series)
class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):
    def __init__(self, meta_module):
        ...

    def forward(self, batch):
        # Forward process
        ...

    def optimality(self, batch, labels):
        # Stationary condition construction for calculating implicit gradient
        # NOTE: If this method is not implemented, it will be automatically derived from the
        # gradient of the `objective` function.
        ...

    def objective(self, batch, labels):
        # Define the inner-loop optimization objective
        # NOTE: This method is optional if method `optimality` is implemented.
        ...

    def solve(self, batch, labels):
        # Conduct the inner-loop optimization
        ...
        return self  # optimized module

# Get the meta-module and data
meta_module, data = ..., ...
inner_net = InnerNet(meta_module)

# Solve the inner-loop process, which depends on the meta-parameters
optimal_inner_net = inner_net.solve(*data)

# Get outer-loss and solve for meta-gradient
loss = outer_loss(optimal_inner_net)
meta_grad = torch.autograd.grad(loss, tuple(meta_module.parameters()))
```

In [9]:
class InnerNet(
    torchopt.nn.ImplicitMetaGradientModule,
    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
):
    def __init__(self, meta_net, n_inner_iter, reg_param):
        super().__init__()
        self.meta_net = meta_net
        # Get a deepcopy
        self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
        self.n_inner_iter = n_inner_iter
        self.reg_param = reg_param

    def forward(self, x):
        return self.net(x)

    def objective(self, x, y):
        # We do not implement the optimality conditions, so it will be automatically derived from
        # the gradient of the `objective` function.
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        regularization_loss = 0
        for p1, p2 in zip(
            self.parameters(),  # parameters of `self.net`
            self.meta_parameters(),  # parameters of `self.meta_net`
        ):
            regularization_loss += (
                0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            )
        return loss + regularization_loss

    def solve(self, x, y):
        params = tuple(self.parameters())
        inner_optim = torchopt.SGD(params, lr=2e-2)
        with torch.enable_grad():
            # Temporarily enable gradient computation for conducting the optimization
            for _ in range(self.n_inner_iter):
                loss = self.objective(x, y)
                inner_optim.zero_grad()
                # NOTE: The parameter inputs should be explicitly specified in `backward` function
                # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into
                # all the leaf Tensors (including the meta-parameters) that were used to compute the
                # objective output. Alternatively, please use `torch.autograd.grad` instead.
                loss.backward(inputs=params)  # backward pass in inner-loop
                inner_optim.step()  # update inner parameters
        return self


# Initialize the meta network
meta_net = Net(4)
inner_net = InnerNet(meta_net, 100, reg_param=1)

# Solve for inner-loop
optimal_inner_net = inner_net.solve(x, y)
outer_loss = optimal_inner_net(x).mean()

# Derive the meta-gradient
torch.autograd.grad(outer_loss, meta_net.parameters())

(tensor([[-0.0369,  0.0248,  0.0347,  0.0067]]), tensor([0.3156]))
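
The `linear_solve` argument of `InnerNet` above selects `solve_normal_cg`, which, as its name suggests, applies conjugate gradient to the normal equations. As a hedged illustration of that technique, the pure-Python sketch below solves $A^{\top} A x = A^{\top} b$ with plain conjugate gradient. The helper names are hypothetical and the code is not taken from TorchOpt.

```python
def matvec(A, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normal_cg(A, b, maxiter=10, tol=1e-12):
    """Solve A^T A x = A^T b with conjugate gradient, starting from x = 0."""
    At = transpose(A)

    def apply_normal(v):
        return matvec(At, matvec(A, v))  # v -> A^T A v

    x = [0.0] * len(A[0])
    r = matvec(At, b)  # residual of the normal equations at x = 0
    p = list(r)  # first search direction
    rs = dot(r, r)
    for _ in range(maxiter):
        if rs < tol:  # squared residual norm is small enough
            break
        Ap = apply_normal(p)
        step = rs / dot(p, Ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


A = [[2.0, 1.0], [0.0, 3.0]]  # non-symmetric, so plain CG on A itself does not apply
b = [3.0, 3.0]
x = normal_cg(A, b)
print(x)
```

Working with $A^{\top} A$ makes the system symmetric positive definite whenever $A$ has full column rank, at the cost of squaring the condition number. A small `maxiter`, as in the cell above, trades accuracy of the meta-gradient for speed.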
