# TorchOpt for implicit differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)

In this tutorial, we introduce how TorchOpt can be used to perform implicit differentiation, using [IMAML](https://arxiv.org/abs/1909.04630) as the illustrative example.

```python
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
```

## 1. Basic API

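The basic API is `torchopt.diff.implicit.custom_root`, which is used as a decorator on the forward (solver) function so that its output is differentiated implicitly through an optimality condition, instead of by backpropagating through the solver's iterations.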
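The mechanism behind the decorator can be illustrated with a toy, stdlib-only scalar analogue. `toy_custom_root`, `sqrt_condition`, and `newton_sqrt` are hypothetical names made up for this sketch and are not part of TorchOpt. The decorator receives an optimality function $F(x, \theta)$ whose root is the solver's output, runs the solver as a black box, and then obtains $dx^{\star}/d\theta = -(\partial F/\partial x)^{-1} \, \partial F/\partial \theta$ from the implicit function theorem, using finite differences of $F$ in place of autograd.

```python
# Toy, scalar-only analogue of the custom_root idea (hypothetical names,
# not TorchOpt's implementation).


def toy_custom_root(optimality_fn, eps=1e-6):
    """Decorate a solver whose output x* satisfies optimality_fn(x*, theta) = 0."""

    def decorator(solver):
        def solve_with_grad(x0, theta):
            x_star = solver(x0, theta)  # the solver is treated as a black box
            # Central finite differences of F at the solution
            dF_dx = (optimality_fn(x_star + eps, theta) - optimality_fn(x_star - eps, theta)) / (2 * eps)
            dF_dtheta = (optimality_fn(x_star, theta + eps) - optimality_fn(x_star, theta - eps)) / (2 * eps)
            # Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} * dF/dtheta
            return x_star, -dF_dtheta / dF_dx

        return solve_with_grad

    return decorator


def sqrt_condition(x, theta):
    # The root of x^2 - theta = 0 (for x > 0) is x* = sqrt(theta)
    return x * x - theta


@toy_custom_root(sqrt_condition)
def newton_sqrt(x, theta):
    # 50 Newton steps; none of these iterations are differentiated
    for _ in range(50):
        x = x - (x * x - theta) / (2 * x)
    return x


x_star, dx_dtheta = newton_sqrt(1.0, 4.0)
print(x_star, dx_dtheta)  # approximately 2.0 and 0.25 (= 1 / (2 * sqrt(4)))
```

This toy version simply returns the derivative next to the solution. With tensors instead of scalars, the division becomes a linear system involving the Jacobian of the optimality function, and the real decorator works with PyTorch autograd rather than finite differences.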

### 1.1 Forward Process and Backward Process with Optimality Conditions
For IMAML, the inner-loop objective is described by the following equation.

$$
\mathcal{A}lg^{\star} \left( \boldsymbol{\theta}, \mathcal{D}_{i}^{\text{tr}} \right) = \underset{\boldsymbol{\phi}' \in \Phi}{\operatorname{\arg \min}} ~ G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \triangleq \mathcal{L} \left( \boldsymbol{\phi}', \mathcal{D}_{i}^{\text{tr}} \right) + \frac{\lambda}{2} {\left\| \boldsymbol{\phi}' - \boldsymbol{\theta} \right\|}^{2}
$$

Based on this objective, we define the forward function `inner_solver`, which approximately solves the minimization by running a sufficient number of gradient-descent steps. For this inner-loop problem, the optimality condition is that the gradient of $G$ w.r.t. the inner-loop parameters is $0$ at the solution $\boldsymbol{\phi}^{\star}$:

$$
{\left. \nabla_{\boldsymbol{\phi}'} G \left( \boldsymbol{\phi}', \boldsymbol{\theta} \right) \right|}_{\boldsymbol{\phi}' = \boldsymbol{\phi}^{\star}} = 0
$$
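Differentiating this condition with respect to $\boldsymbol{\theta}$ (the implicit function theorem) gives the Jacobian of the inner-loop solution without unrolling the solver; this is the expression derived in the IMAML paper:

$$
\begin{aligned}
\nabla_{\boldsymbol{\phi}'} G \left( \boldsymbol{\phi}^{\star}, \boldsymbol{\theta} \right) &= \nabla_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) + \lambda \left( \boldsymbol{\phi}^{\star} - \boldsymbol{\theta} \right) = 0 \\
\Rightarrow \quad \left( \nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) + \lambda \boldsymbol{I} \right) \frac{d \boldsymbol{\phi}^{\star}}{d \boldsymbol{\theta}} - \lambda \boldsymbol{I} &= 0 \\
\Rightarrow \quad \frac{d \boldsymbol{\phi}^{\star}}{d \boldsymbol{\theta}} &= {\left( \boldsymbol{I} + \frac{1}{\lambda} \nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} \left( \boldsymbol{\phi}^{\star}, \mathcal{D}_{i}^{\text{tr}} \right) \right)}^{-1}
\end{aligned}
$$

Only the solution $\boldsymbol{\phi}^{\star}$ and the optimality condition enter this expression, not the individual iterates of the inner loop, which is why the solver can be treated as a black box during differentiation.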

We therefore define the inner-loop objective `imaml_objective` and obtain the optimality function as its gradient w.r.t. the inner-loop parameters, `functorch.grad(imaml_objective, argnums=0)`, whose root is the optimal $\boldsymbol{\phi}^{\star}$. Finally, we decorate the forward function with `@torchopt.diff.implicit.custom_root`, passing it this optimality function.
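Before the PyTorch version, the same structure can be sketched in plain Python on a scalar stand-in problem. This is an illustrative assumption, not the tutorial's model: one weight $\phi$, no bias, made-up data, and $\lambda = 1$; `scalar_objective`, `scalar_optimality_fn`, and `scalar_inner_solver` are hypothetical helpers. The optimality function is written out by hand as $\partial G/\partial \phi$, and gradient descent drives it to zero.

```python
# Scalar stand-in for the tutorial's setup (illustrative assumption): one weight
# phi, no bias, inner loss L(phi) = mean((phi * x_i - y_i)^2), and lambda = 1.
xs = [0.5, -1.0, 2.0, 1.5]
ys = [1.0, -2.5, 3.5, 3.0]
lam = 1.0


def scalar_objective(phi, theta):
    # G(phi, theta) = L(phi) + lam / 2 * (phi - theta)^2
    loss = sum((phi * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return loss + 0.5 * lam * (phi - theta) ** 2


def scalar_optimality_fn(phi, theta):
    # Hand-derived dG/dphi; its root is the inner-loop solution phi*
    grad_loss = sum(2 * (phi * x - y) * x for x, y in zip(xs, ys)) / len(xs)
    return grad_loss + lam * (phi - theta)


def scalar_inner_solver(phi, theta, lr=0.1, steps=500):
    # Plain gradient descent on G, analogous to the SGD loop in inner_solver
    for _ in range(steps):
        phi -= lr * scalar_optimality_fn(phi, theta)
    return phi


theta = 0.0
phi_star = scalar_inner_solver(theta, theta)  # start from the meta-parameter

# Finite-difference check that scalar_optimality_fn really is dG/dphi
h = 1e-5
fd_grad = (scalar_objective(phi_star + h, theta) - scalar_objective(phi_star - h, theta)) / (2 * h)
print(phi_star, scalar_optimality_fn(phi_star, theta), fd_grad)
# phi* is about 1.5263 (= 7.25 / 4.75); both gradients are ~0
```

In the PyTorch code, `functorch.grad` plays the role of the hand-derived `scalar_optimality_fn`, so the optimality function never has to be written out manually.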

```python
# Optimality function: the IMAML inner-loop objective G(phi', theta) with lambda = 1
def imaml_objective(params, meta_params, data):
    x, y, fmodel = data
    y_pred = fmodel(params, x)
    regularization_loss = 0.0
    for p1, p2 in zip(params, meta_params):
        regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
    loss = F.mse_loss(y_pred, y) + regularization_loss
    return loss


# Optimality condition: the gradient w.r.t. the inner-loop optimal params is 0. We obtain this
# gradient by differentiating imaml_objective w.r.t. its first argument (argnums=0 in functorch.grad).
#
# The argnums=1 passed to custom_root specifies which argument of inner_solver we compute the
# gradient of optimal_params with respect to. Indices are 0-based and count params as argument 0,
# so argnums=1 refers to meta_params: we backpropagate to the meta-parameters. You can also pass a
# tuple such as argnums=(1, 2) to backpropagate through several meta-parameter arguments.
@torchopt.diff.implicit.custom_root(functorch.grad(imaml_objective, argnums=0), argnums=1)
def inner_solver(params, meta_params, data):
    """Solve the regularized inner-loop problem by gradient descent."""
    # Initialize a functional SGD optimizer from TorchOpt
    x, y, fmodel = data
    optimizer = torchopt.sgd(lr=2e-2)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        # Temporarily enable gradient computation for conducting the optimization
        for _ in range(100):
            pred = fmodel(params, x)
            loss = F.mse_loss(pred, y)  # compute loss

            # Compute regularization loss
            regularization_loss = 0.0
            for p1, p2 in zip(params, meta_params):
                regularization_loss += 0.5 * torch.sum(torch.square(p1.view(-1) - p2.view(-1)))
            final_loss = loss + regularization_loss

            grads = torch.autograd.grad(final_loss, params)  # compute gradients
            updates, opt_state = optimizer.update(grads, opt_state, inplace=True)  # get updates
            params = torchopt.apply_updates(params, updates, inplace=True)

    optimal_params = params
    return optimal_params
```
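Because `inner_solver` is decorated, the gradient of its output w.r.t. `meta_params` is computed from the optimality condition rather than by backpropagating through the 100 SGD steps. The following stdlib-only sketch checks the underlying identity on a scalar stand-in problem (an illustrative assumption: one weight, no bias, made-up data, $\lambda = 1$; `scalar_solve` is a hypothetical helper): the implicit-function-theorem derivative $\left(1 + L''(\phi^{\star})/\lambda\right)^{-1}$ should match a finite-difference derivative of the converged solver.

```python
# Self-contained scalar check (illustrative assumption, not TorchOpt code) for
# G(phi, theta) = mean((phi * x - y)^2) + lam / 2 * (phi - theta)^2.
xs = [0.5, -1.0, 2.0, 1.5]
ys = [1.0, -2.5, 3.5, 3.0]
lam = 1.0


def scalar_solve(theta, lr=0.1, steps=500):
    # Gradient descent on G, starting from theta; returns an approximation of phi*
    phi = theta
    n = len(xs)
    for _ in range(steps):
        grad = sum(2 * (phi * x - y) * x for x, y in zip(xs, ys)) / n + lam * (phi - theta)
        phi -= lr * grad
    return phi


# Implicit gradient: dphi*/dtheta = (1 + L''(phi*) / lam)^{-1}, where L'' = 2 * mean(x^2)
hess_L = 2 * sum(x * x for x in xs) / len(xs)
implicit_grad = 1.0 / (1.0 + hess_L / lam)

# Finite-difference derivative of the converged solver output w.r.t. theta
eps = 1e-4
fd_grad = (scalar_solve(0.3 + eps) - scalar_solve(0.3 - eps)) / (2 * eps)

print(implicit_grad, fd_grad)  # both about 0.2105 (= 1 / 4.75)
```

The finite-difference route has to rerun the whole solver for every perturbation, while the implicit route only needs quantities evaluated at the solution.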

Next, we consider a concrete case: a one-layer neural network fitted to synthetic linear data.

```python
# Synthetic linear data: y = x @ w + b plus Gaussian noise
torch.manual_seed(0)
x = torch.randn(20, 4)
w = torch.randn(4, 1)
b = torch.randn(1)
y = x @ w + b + 0.5 * torch.randn(20, 1)
```

We instantiate a one-layer neural network whose weights and bias are initialised with constants (ones and zeros, respectively).

```python
class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


model = Net(4)
fmodel, meta_params = functorch.make_functional(model)
data = (x, y, fmodel)


# Clone the parameters as fresh leaf tensors that require grad, so that the
# in-place updates inside inner_solver do not modify meta_params
def clone(params):
    cloned = []
    for item in params:
        if isinstance(item, torch.Tensor):
            cloned.append(item.clone().detach_().requires_grad_(True))
        else:
            cloned.append(item)
    return tuple(cloned)
```

We run the forward process by calling the forward function `inner_solver`, starting from a copy of the meta-parameters, and then pass the optimal parameters into the outer-loop loss function.

```python
optimal_params = inner_solver(clone(meta_params), meta_params, data)

outer_loss = fmodel(optimal_params, x).mean()
```

Finally, we can get the meta gradient as shown below.

```python
torch.autograd.grad(outer_loss, meta_params)
```

```
(tensor([[-0.0369,  0.0248,  0.0347,  0.0067]]), tensor([0.3156]))
```
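Because the inner loss here is the mean-squared error of a linear model and the code uses $\lambda = 1$ (the `0.5 * ...` regularization term), the meta gradient can be written out explicitly as a sanity check. The notation below is introduced only for this check: stack the inputs with a column of ones as $\boldsymbol{A} = [\boldsymbol{x}, \boldsymbol{1}] \in \mathbb{R}^{n \times 5}$ with $n = 20$, so that $\boldsymbol{\phi} = (\boldsymbol{w}, b)$ and the prediction is $\boldsymbol{A} \boldsymbol{\phi}$. The outer loss $\mathcal{L}^{\text{out}}$ is the mean prediction, so its gradient w.r.t. $\boldsymbol{\phi}$ is $(\bar{\boldsymbol{x}}, 1)$, where $\bar{\boldsymbol{x}}$ is the column-wise mean of $\boldsymbol{x}$:

$$
\nabla^{2}_{\boldsymbol{\phi}'} \mathcal{L} = \frac{2}{n} \boldsymbol{A}^{\top} \boldsymbol{A}, \qquad
\nabla_{\boldsymbol{\theta}} \mathcal{L}^{\text{out}} = {\left( \frac{d \boldsymbol{\phi}^{\star}}{d \boldsymbol{\theta}} \right)}^{\top} \nabla_{\boldsymbol{\phi}} \mathcal{L}^{\text{out}} = {\left( \boldsymbol{I} + \frac{2}{n} \boldsymbol{A}^{\top} \boldsymbol{A} \right)}^{-1} \begin{bmatrix} \bar{\boldsymbol{x}} \\ 1 \end{bmatrix}
$$

The transpose disappears because the matrix is symmetric, and since this inner loss is quadratic its Hessian does not depend on $\boldsymbol{\phi}^{\star}$. The first four entries correspond to the weight gradient (printed with the layer's $1 \times 4$ shape) and the last to the bias gradient; they should agree with the printed values up to numerical error.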
