# TorchOpt for implicit differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)

By treating the solution $\phi^{\star}$ as an implicit function of $\theta$, implicit differentiation obtains the analytical best-response derivative $\partial \phi^{\star}(\theta) / \partial \theta$ directly from the implicit function theorem. This suits algorithms whose inner-level solution either satisfies the first-order optimality condition ${\left. \frac{\partial F (\phi, \theta)}{\partial \phi} \right\rvert}_{\phi = \phi^{\star}} = 0$ or reaches some stationary condition $F (\phi^{\star}, \theta) = 0$, such as [IMAML](https://arxiv.org/abs/1909.04630) and [DEQ](https://arxiv.org/abs/1909.01377).
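
For reference, the derivative follows from differentiating the stationary condition with respect to $\theta$. The sketch below is the standard implicit function theorem argument and assumes that $\partial F / \partial \phi$ is invertible at $\phi^{\star}$:

$$
F \left( \phi^{\star} (\theta), \theta \right) = 0
\quad \Longrightarrow \quad
\frac{\partial F}{\partial \phi} \frac{\partial \phi^{\star} (\theta)}{\partial \theta} + \frac{\partial F}{\partial \theta} = 0
\quad \Longrightarrow \quad
\frac{\partial \phi^{\star} (\theta)}{\partial \theta} = - {\left( \frac{\partial F}{\partial \phi} \right)}^{-1} \frac{\partial F}{\partial \theta}
$$

All partial derivatives are evaluated at $(\phi^{\star}, \theta)$. The inverse is usually not formed explicitly; a linear solver is applied to the corresponding linear system instead.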

In this tutorial, we show how to use TorchOpt for implicit differentiation.

In [1]:
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt

## 1. Functional API

The basic functional API is `torchopt.diff.implicit.custom_root`, a decorator that attaches an implicit gradient procedure to the forward (inner-loop) solver. Users implement the stationary condition of the inner-loop process and pass it to the `custom_root` decorator. The pseudocode below outlines the pattern.

```python
# Functional API for implicit gradient
def stationary(params, meta_params, data):
    # Build the stationary condition; it should evaluate to zero at the optimal params
    stationary_condition = ...
    return stationary_condition

# Decorator that wraps the forward solver
# `argnums=1` selects the argument of `solve` that receives the implicit gradient (meta_params)
# Optionally specify the linear solver (conjugate gradient or Neumann series)
@torchopt.diff.implicit.custom_root(stationary, argnums=1, solve=linear_solver)
def solve(params, meta_params, data):
    # Forward optimization process for params
    optimal_params = ...
    return optimal_params

# Define params, meta params and get data
params, meta_params, data = ..., ..., ...
optimal_params = solve(params, meta_params, data)
loss = outer_loss(optimal_params)

meta_grads = torch.autograd.grad(loss, meta_params)
```
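
To make the pattern concrete without any library machinery, here is a minimal scalar sketch in plain Python. It is not TorchOpt code: the stationary condition $F(\phi, \theta) = \phi^{3} + \phi - \theta$, the Newton solver, and the function names are all illustrative assumptions. It computes the implicit derivative from the stationary condition and compares it with a finite-difference estimate obtained by re-running the solver.

```python
def stationary(phi, theta):
    """Hypothetical stationary condition F(phi, theta); the solution satisfies F = 0."""
    return phi**3 + phi - theta


def solve(theta, steps=50):
    """Forward process: Newton's method on F(phi, theta) = 0, starting from phi = 0."""
    phi = 0.0
    for _ in range(steps):
        phi -= stationary(phi, theta) / (3.0 * phi**2 + 1.0)
    return phi


def implicit_grad(phi_star):
    """d phi*/d theta = -(dF/dphi)^{-1} * dF/dtheta, evaluated at the solution."""
    dF_dphi = 3.0 * phi_star**2 + 1.0  # derivative of F w.r.t. phi
    dF_dtheta = -1.0  # derivative of F w.r.t. theta
    return -dF_dtheta / dF_dphi


theta = 2.0
phi_star = solve(theta)
grad_implicit = implicit_grad(phi_star)

# Finite-difference check: differentiate through the solver numerically
eps = 1e-6
grad_fd = (solve(theta + eps) - solve(theta - eps)) / (2 * eps)
print(phi_star, grad_implicit, grad_fd)
```

The implicit derivative needs only the solution and the stationary condition, not the solver's intermediate iterates, which is the property the decorator relies on.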

We use [IMAML](https://arxiv.org/abs/1909.04630) as a concrete example. For IMAML, the inner-loop objective is given by the following equation.

$$
{\mathcal{Alg}}^{\star} \left( \boldsymbol{\theta}, \mathcal{D}_{i}^{\text{tr}} \right) = \underset{\boldsymbol{\phi}'}{\operatorname{arg\,min}} ~ G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \triangleq \mathcal{L} \left( \boldsymbol{\phi}', \mathcal{D}_{i}^{\text{tr}} \right) + \frac{\lambda}{2} {\left\| \boldsymbol{\phi}' - \boldsymbol{\theta} \right\|}^{2}
$$

Based on this objective, we can define the forward function `inner_solver`, which solves the problem approximately by running a sufficient number of gradient descent steps. For such an inner-loop process, the optimality condition is that the gradient w.r.t. the inner-loop parameters is $0$.

$$
{\left. \nabla_{\boldsymbol{\phi}'} G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \right\rvert}_{\boldsymbol{\phi}' = \boldsymbol{\phi}^{\star}} = 0
$$

Thus we can define the optimality function from `imaml_objective`: it is the first-order gradient of the objective w.r.t. the inner-loop parameters, which must equal $0$ at the solution. We obtain it by calling `functorch.grad(imaml_objective, argnums=0)`. Finally, the forward function is decorated with `@torchopt.diff.implicit.custom_root`, which takes the optimality condition we defined.
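
Written out for the IMAML objective, the optimality condition and the resulting best-response Jacobian take the following form. This is a sketch that assumes $\mathcal{L}$ is twice differentiable and that $\nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} + \lambda I$ is invertible at $\boldsymbol{\phi}^{\star}$:

$$
\nabla_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) + \lambda \left( \boldsymbol{\phi}^{\star} - \boldsymbol{\theta} \right) = 0
\quad \Longrightarrow \quad
\frac{\partial \boldsymbol{\phi}^{\star}}{\partial \boldsymbol{\theta}} = \lambda {\left( \nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) + \lambda I \right)}^{-1} = {\left( I + \frac{1}{\lambda} \nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) \right)}^{-1}
$$

The functional example in this tutorial uses a regularization coefficient of $0.5$, which corresponds to $\lambda = 1$. The `solve` argument of `custom_root` chooses how the linear system involving this matrix is solved.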

In [2]:
# Inner-loop objective function
# The optimality function: grad(imaml_objective)
def imaml_objective(params, meta_params, data):
    x, y, fmodel = data
    y_pred = fmodel(params, x)
    regularization_loss = 0.0
    for p1, p2 in zip(params, meta_params):
        regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
    loss = F.mse_loss(y_pred, y) + regularization_loss
    return loss


# Optimality condition: the gradient of the objective w.r.t. the inner-loop params is 0. We obtain
# it by specifying argnums=0 in functorch.grad, which differentiates w.r.t. the first argument
# (params).

# Separately, argnums=1 passed to custom_root selects the argument of inner_solver that receives
# the implicit gradient: the argument at index 1 (counting from 0), i.e., meta_params.
# You can also set argnums to (1, 2) to backpropagate to multiple meta-parameters.
# torchopt.linear_solve.solve_normal_cg specifies the conjugate-gradient-based linear solver.
@torchopt.diff.implicit.custom_root(
    functorch.grad(imaml_objective, argnums=0),  # optimality function
    argnums=1,
    solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
)
def inner_solver(params, meta_params, data):
    # Initialize the functional optimizer from TorchOpt
    x, y, fmodel = data
    optimizer = torchopt.sgd(lr=2e-2)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        # Temporarily enable gradient computation for conducting the optimization
        for i in range(100):
            pred = fmodel(params, x)
            loss = F.mse_loss(pred, y)  # compute loss

            # Compute regularization loss
            regularization_loss = 0.0
            for p1, p2 in zip(params, meta_params):
                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            final_loss = loss + regularization_loss

            grads = torch.autograd.grad(final_loss, params)  # compute gradients
            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
            params = torchopt.apply_updates(params, updates, inplace=True)

    optimal_params = params
    return optimal_params


# torchopt.linear_solve.solve_inv specifies that we use the Neumann series inversion linear solver
@torchopt.diff.implicit.custom_root(
    functorch.grad(imaml_objective, argnums=0),  # optimality function
    argnums=1,
    solve=torchopt.linear_solve.solve_inv(ns=True, maxiter=100, alpha=0.1),
)
def inner_solver_inv_ns(params, meta_params, data):
    # Initialize the functional optimizer from TorchOpt
    x, y, fmodel = data
    optimizer = torchopt.sgd(lr=2e-2)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        # Temporarily enable gradient computation for conducting the optimization
        for i in range(100):
            pred = fmodel(params, x)
            loss = F.mse_loss(pred, y)  # compute loss

            # Compute regularization loss
            regularization_loss = 0.0
            for p1, p2 in zip(params, meta_params):
                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            final_loss = loss + regularization_loss

            grads = torch.autograd.grad(final_loss, params)  # compute gradients
            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
            params = torchopt.apply_updates(params, updates, inplace=True)

    optimal_params = params
    return optimal_params
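
The implicit gradient is only as accurate as the stationary condition is satisfied, so it can help to check how close the inner loop gets to a stationary point. The scalar sketch below is not part of TorchOpt. It mirrors the update rule above (SGD with `lr=2e-2`, starting from the meta-parameter) on a hypothetical quadratic loss, where `a`, `h`, and `lam` are made-up constants, and records the remaining gradient magnitude after different numbers of steps.

```python
def inner_gd(theta, a, h, lam, lr=2e-2, steps=100):
    """Gradient descent on G(phi) = 0.5*h*(phi - a)**2 + 0.5*lam*(phi - theta)**2."""
    phi = theta  # start from the meta-parameter, as in the tutorial
    for _ in range(steps):
        grad = h * (phi - a) + lam * (phi - theta)  # dG/dphi
        phi -= lr * grad
    return phi


theta, a, h, lam = 0.0, 3.0, 1.0, 1.0
phi_exact = (h * a + lam * theta) / (h + lam)  # closed-form minimizer of G

residuals = {}
for steps in (100, 1000):
    phi = inner_gd(theta, a, h, lam, steps=steps)
    # Magnitude of the stationary condition; zero at an exact solution
    residuals[steps] = abs(h * (phi - a) + lam * (phi - theta))
    print(steps, phi, residuals[steps])
```

In this toy setting the residual after 100 steps is small but clearly nonzero, while 1000 steps drive it to numerical precision. How many steps are enough for a real model depends on the learning rate and the conditioning of the problem.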

Next, we consider a specific case: a one-layer neural network fitted to linear data.

In [3]:
torch.manual_seed(0)
x = torch.randn(20, 4)
w = torch.randn(4, 1)
b = torch.randn(1)
y = x @ w + b + 0.5 * torch.randn(20, 1)

We instantiate a one-layer neural network whose weights and bias are initialized with constants.

In [4]:
class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


model = Net(4)
fmodel, meta_params = functorch.make_functional(model)
data = (x, y, fmodel)

# Clone function for parameters
def clone(params):
    cloned = []
    for item in params:
        if isinstance(item, torch.Tensor):
            cloned.append(item.clone().detach_().requires_grad_(True))
        else:
            cloned.append(item)
    return tuple(cloned)
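
Cloning matters because the inner solver updates its parameters in place, so the inner loop should start from a copy of the meta-parameters rather than from the meta-parameters themselves. The sketch below assumes only PyTorch and uses made-up tensors to show that a detached clone is a fresh leaf tensor whose updates leave the original untouched.

```python
import torch

# Stand-in for a meta-parameter (hypothetical values)
meta_param = torch.ones(3, requires_grad=True)

# Same recipe as in `clone`: copy the data, cut the graph, and track gradients again
inner_param = meta_param.clone().detach_().requires_grad_(True)

# An in-place update of the copy, similar in spirit to an in-place optimizer step
with torch.no_grad():
    inner_param.add_(1.0)

print(meta_param)  # unchanged
print(inner_param, inner_param.is_leaf)
```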

We run the forward process by calling the inner solver, then pass the optimal parameters to the outer-loop loss function.

In [5]:
optimal_params = inner_solver(clone(meta_params), meta_params, data)

outer_loss = fmodel(optimal_params, x).mean()

Finally, we can get the meta gradient as shown below.

In [6]:
torch.autograd.grad(outer_loss, meta_params)

(tensor([[-0.0369,  0.0248,  0.0347,  0.0067]]), tensor([0.3156]))
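
The meta-gradient above relies on a linear solve. As an illustration of the idea behind a conjugate gradient solver on the normal equations, the plain-Python sketch below solves $A x = b$ by running conjugate gradient on $A^{\top} A x = A^{\top} b$. It is a toy written for this tutorial, not TorchOpt's implementation, and the function name, matrix, and right-hand side are made up.

```python
def matvec(A, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def normal_cg_sketch(A, b, maxiter=10, tol=1e-12):
    """Solve A x = b via conjugate gradient on A^T A x = A^T b."""
    At = [list(col) for col in zip(*A)]

    def normal_matvec(v):
        return matvec(At, matvec(A, v))

    x = [0.0] * len(b)
    r = matvec(At, b)  # residual of the normal equations at x = 0
    p = r[:]  # first search direction
    rs_old = dot(r, r)
    for _ in range(maxiter):
        if rs_old < tol:
            break
        Ap = normal_matvec(p)
        step = rs_old / dot(p, Ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[3.0, 1.0], [1.0, 2.0]]
b = [9.0, 8.0]
x_cg = normal_cg_sketch(A, b)
print(x_cg)
```

Only matrix-vector products are needed, which is why this family of solvers fits implicit differentiation: the products can be computed with automatic differentiation without building the full matrix.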


We can also switch to the Neumann series inversion linear solver.

In [7]:
optimal_params = inner_solver_inv_ns(clone(meta_params), meta_params, data)
outer_loss = fmodel(optimal_params, x).mean()
torch.autograd.grad(outer_loss, meta_params)

(tensor([[-0.0369,  0.0248,  0.0347,  0.0067]]), tensor([0.3156]))
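
For intuition, a Neumann series approximates the inverse as $A^{-1} \approx \alpha \sum_{k=0}^{K} {(I - \alpha A)}^{k}$, which converges when the spectral radius of $I - \alpha A$ is below $1$. The plain-Python sketch below applies this to a made-up $2 \times 2$ system. It is an illustration only: the names `alpha` and `maxiter` are borrowed from the call above by analogy, and it is an assumption that they play the same roles inside TorchOpt.

```python
def matvec(A, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def neumann_solve_sketch(A, b, alpha=0.1, maxiter=100):
    """Approximate A^{-1} b with alpha * sum_{k=0}^{maxiter} (I - alpha*A)^k b."""
    term = b[:]  # (I - alpha*A)^0 b
    total = b[:]  # running sum of the series applied to b
    for _ in range(maxiter):
        a_term = matvec(A, term)
        term = [t - alpha * at for t, at in zip(term, a_term)]  # next power applied to b
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]


A = [[3.0, 1.0], [1.0, 2.0]]
b = [9.0, 8.0]
x_ns = neumann_solve_sketch(A, b)
print(x_ns)
```

The truncation error shrinks geometrically with the number of terms, so a step size that is too small for the matrix at hand needs many more iterations, and one that is too large makes the series diverge.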


## 2. OOP API

The basic OOP class is `ImplicitMetaGradientModule`. We define the network as an `nn.Module` in the usual PyTorch style. Users need to define the stationary condition or the objective function, together with the inner-loop solve function, to enable implicit gradient computation. The pseudocode below outlines the pattern.

```python
from torchopt.diff.implicit import ImplicitMetaGradientModule

# Inherited from the class ImplicitMetaGradientModule
# Optionally specify the linear solver (conjugate gradient or Neumann series)
class InnerNet(ImplicitMetaGradientModule, linear_solve=linear_solver):
    def __init__(self, meta_module):
        ...

    def forward(self, batch):
        # Forward process
        ...

    def optimality(self, batch, labels):
        # Stationary condition construction for calculating implicit gradient
        # NOTE: If this method is not implemented, it will be automatically derived from the
        # gradient of the `objective` function.
        ...

    def objective(self, batch, labels):
        # Define the inner-loop optimization objective
        # NOTE: This method is optional if method `optimality` is implemented.
        ...

    def solve(self, batch, labels):
        # Conduct the inner-loop optimization
        ...
        return self  # optimized module

# Get the meta module and data
meta_module, data = ..., ...
inner_net = InnerNet(meta_module)

# Solve the inner-loop process, which depends on the meta-parameters
optimal_inner_net = inner_net.solve(*data)

# Get outer-loss and solve for meta-gradient
loss = outer_loss(optimal_inner_net)
meta_grad = torch.autograd.grad(loss, meta_module.parameters())
```

For a custom network, users need to define the meta-parameters (typically the meta module's parameters) and the inner parameters (typically `self.parameters()`) before calling the `solve` function. By default, `ImplicitMetaGradientModule` treats every `nn.Module` or `torch.Tensor` in the input as a meta-parameter and treats the rest as inner parameters (including any `nn.Module` defined in `__init__` or a copied network).

After `solve` is called, `ImplicitMetaGradientModule` automatically computes the implicit gradient defined by the optimality function and connects the gradient flow between the meta-parameters and the inner parameters. Below is an example based on implicit MAML.

In [8]:
from torchopt.diff.implicit import ImplicitMetaGradientModule


class InnerNet(
    ImplicitMetaGradientModule,
    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
):
    def __init__(self, meta_net, n_inner_iter, reg_param):
        super().__init__()
        # treated as meta-parameter
        self.meta_net = meta_net
        # Get a deepcopy, treated as inner-parameter
        self.net = torchopt.module_clone(meta_net, by='deepcopy', detach_buffers=True)
        self.n_inner_iter = n_inner_iter
        self.reg_param = reg_param

    def forward(self, x):
        return self.net(x)

    def objective(self, x, y):
        # We do not implement the optimality conditions, so it will be automatically derived from
        # the gradient of the `objective` function.
        y_pred = self(x)
        loss = F.mse_loss(y_pred, y)
        regularization_loss = 0
        for p1, p2 in zip(
            self.parameters(),  # parameters of `self.net`
            self.meta_parameters(),  # parameters of `self.meta_net`
        ):
            regularization_loss += (
                0.5 * self.reg_param * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            )
        return loss + regularization_loss

    def solve(self, x, y):
        params = tuple(self.parameters())
        inner_optim = torchopt.SGD(params, lr=2e-2)
        with torch.enable_grad():
            # Temporarily enable gradient computation for conducting the optimization
            for _ in range(self.n_inner_iter):
                loss = self.objective(x, y)
                inner_optim.zero_grad()
                # NOTE: The parameter inputs should be explicitly specified in `backward` function
                # as argument `inputs`. Otherwise, if not provided, the gradient is accumulated into
                # all the leaf Tensors (including the meta-parameters) that were used to compute the
                # objective output. Alternatively, please use `torch.autograd.grad` instead.
                loss.backward(inputs=params)  # backward pass in inner-loop
                inner_optim.step()  # update inner parameters
        # TorchOpt connects the implicit gradient from self.parameters() to meta_net.parameters()
        return self


# Initialize the meta network
meta_net = Net(4)
inner_net = InnerNet(meta_net, 100, reg_param=1)

# Solve for inner-loop
optimal_inner_net = inner_net.solve(x, y)
outer_loss = optimal_inner_net(x).mean()

# Derive the meta gradient
torch.autograd.grad(outer_loss, meta_net.parameters())

(tensor([[-0.0369,  0.0248,  0.0347,  0.0067]]), tensor([0.3156]))
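
The note about `inputs` in `solve` can be checked in isolation. The sketch below assumes only PyTorch and uses two made-up leaf tensors standing in for an inner parameter and a meta-parameter. With `inputs` specified, `backward` accumulates gradients only into the listed tensors.

```python
import torch

inner = torch.tensor([1.0, 2.0], requires_grad=True)  # stand-in for an inner parameter
meta = torch.tensor([0.5, 0.5], requires_grad=True)  # stand-in for a meta-parameter

# A loss that depends on both tensors, like the regularized inner objective
loss = torch.sum((inner - meta) ** 2)

# Restrict gradient accumulation to the inner parameter only
loss.backward(inputs=[inner])

print(inner.grad)  # populated with 2 * (inner - meta)
print(meta.grad)  # stays None because `meta` was not listed in `inputs`
```

Without `inputs`, the same call would also write into `meta.grad`, which is the accumulation into meta-parameters that the note warns about.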


The OOP API can also implement DEQ, where the inner loop is a fixed-point iteration. DEQ has the following fixed-point optimality condition.

$$
F \left( \boldsymbol{\theta}, x^{\star} \right) = x^{\star}
$$

We derive the implicit gradient from the inner parameter $x^{\star}$ to the meta-parameter $\boldsymbol{\theta}$.
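
Differentiating the fixed-point condition gives $\partial x^{\star} / \partial \boldsymbol{\theta} = {\left( I - \partial F / \partial x \right)}^{-1} \partial F / \partial \boldsymbol{\theta}$, evaluated at $x^{\star}$. The plain-Python sketch below checks this on a scalar toy layer. The layer `f`, its constants, and the function names are illustrative assumptions and are unrelated to the network used in this tutorial.

```python
import math


def f(theta, x):
    """Hypothetical scalar layer; a contraction in x when |theta| < 1."""
    return math.tanh(theta * x + 1.0)


def fixed_point(theta, steps=200):
    """Inner loop: iterate x <- f(theta, x) until it settles."""
    x = 0.0
    for _ in range(steps):
        x = f(theta, x)
    return x


def implicit_grad(theta, x_star):
    """dx*/dtheta = (1 - df/dx)^{-1} * df/dtheta, evaluated at the fixed point."""
    sech2 = 1.0 - math.tanh(theta * x_star + 1.0) ** 2  # derivative of tanh at its argument
    df_dx = theta * sech2
    df_dtheta = x_star * sech2
    return df_dtheta / (1.0 - df_dx)


theta = 0.5
x_star = fixed_point(theta)
grad_implicit = implicit_grad(theta, x_star)

# Finite-difference check: re-run the fixed-point iteration at perturbed theta
eps = 1e-6
grad_fd = (fixed_point(theta + eps) - fixed_point(theta - eps)) / (2 * eps)
print(x_star, grad_implicit, grad_fd)
```

The implicit formula uses only derivatives at the fixed point, so the cost of the backward pass does not grow with the number of forward iterations.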

In [6]:
from torchopt.diff.implicit import ImplicitMetaGradientModule


class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # nn.init.ones_(self.fc.weight)
        # nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


class InnerNet(
    ImplicitMetaGradientModule,
    linear_solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
):
    def __init__(self, meta_net):
        super().__init__()
        self.meta_net = meta_net

    def forward(self, x):
        return self.meta_net(x)

    def optimality(self):
        # Return a one-element tuple so that the residual keeps the shape of `self.x`.
        # NOTE: `tuple(tensor)` would instead split the tensor along its first dimension (shape
        # (4,) rather than (1, 4)) and raise a shape-mismatch RuntimeError in the backward pass.
        return (self.x - self(self.x),)

    def register_x(self, x):
        # The inner parameter must be a parameter of `self`, so we register it here
        # such that `self.x` belongs to `self.parameters()`
        self.x = nn.Parameter(x)

    def solve(self):
        # Conduct the iterative fixed-point process
        for _ in range(10):
            self.x = nn.Parameter(self(self.x))
        # TorchOpt connects the implicit gradient from self.x to meta_net.parameters()
        return self


# Initialize the meta network
meta_net = Net(4)
x = torch.randn(1, 4)
inner_net = InnerNet(meta_net)
# Register the inner parameter (`self.x`) before calling `solve`;
# do not pass `x` as an input to `solve`
inner_net.register_x(x)
# Solve for inner-loop
optimal_inner_net = inner_net.solve()
outer_loss = optimal_inner_net.x.mean()

# Derive the meta gradient
torch.autograd.grad(outer_loss, meta_net.parameters())
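
The shape of what `optimality` returns matters here. The residual is paired with the inner parameters when the backward pass is formed, so each returned tensor should have the same shape as the corresponding parameter (this pairing is an assumption inferred from the shapes involved, not a statement about TorchOpt internals). A pitfall worth knowing is that calling `tuple()` on a tensor iterates over its first dimension instead of wrapping the tensor. The sketch below assumes only PyTorch and uses a made-up tensor.

```python
import torch

residual = torch.zeros(1, 4)  # same shape as the inner parameter `x` in the DEQ example

split = tuple(residual)  # iterates over dim 0: one tensor of shape (4,)
wrapped = (residual,)  # one-element tuple holding the (1, 4) tensor

print(len(split), split[0].shape)
print(len(wrapped), wrapped[0].shape)
```

For a residual of shape `(1, 4)`, the first form silently drops the leading dimension, while the second keeps the shape aligned with the parameter.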