# TorchOpt for implicit differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)

In this tutorial, we show how TorchOpt can be used to conduct implicit differentiation, using [IMAML](https://arxiv.org/abs/1909.04630) as the running example.

In [5]:
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
from torchopt import implicit_diff, sgd

## 1. Basic API

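The basic API is **implicit_diff**, whose `custom_root` decorator wraps the forward (inner-loop) solver. The solver runs as usual in the forward pass. In the backward pass, the gradient is computed implicitly from an optimality condition instead of by backpropagating through the solver's iterations.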
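To make the idea behind the decorator concrete, here is a minimal sketch in plain PyTorch of the quantity implicit differentiation computes. If an optimality condition $F(x^\star(\theta), \theta) = 0$ defines the solution, the implicit function theorem gives $dx^\star/d\theta = -(\partial F/\partial x)^{-1}\,\partial F/\partial \theta$. The helper `implicit_jacobian` and the toy quadratic problem are hypothetical illustrations, not TorchOpt's API; the sketch forms dense Jacobians, so it only suits small problems, and TorchOpt's actual implementation may differ.

```python
import torch


def implicit_jacobian(optimality_fn, x_star, theta):
    """Hypothetical helper (not TorchOpt API): dx*/dtheta via the implicit function theorem.

    If F(x*(theta), theta) = 0, then dF/dx @ dx*/dtheta + dF/dtheta = 0, so
    dx*/dtheta = -(dF/dx)^{-1} @ dF/dtheta, evaluated at (x*, theta).
    """
    # Dense Jacobians of the optimality condition w.r.t. each argument
    J_x, J_theta = torch.autograd.functional.jacobian(optimality_fn, (x_star, theta))
    # Solve (dF/dx) X = -(dF/dtheta) rather than forming an explicit inverse
    return torch.linalg.solve(J_x, -J_theta)


torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)
b = torch.randn(5, dtype=torch.float64)
theta = torch.randn(3, dtype=torch.float64)
lam = 1.0


def optimality_fn(x, theta):
    # Gradient w.r.t. x of G(x, theta) = 0.5 * ||A x - b||^2 + 0.5 * lam * ||x - theta||^2
    return A.T @ (A @ x - b) + lam * (x - theta)


# Forward solve: closed form here, but any solver that reaches the root would do
M = A.T @ A + lam * torch.eye(3, dtype=torch.float64)
x_star = torch.linalg.solve(M, A.T @ b + lam * theta)

jac = implicit_jacobian(optimality_fn, x_star, theta)
# For this quadratic, the implicit function theorem gives dx*/dtheta = lam * M^{-1}
expected = lam * torch.linalg.inv(M)
print(torch.allclose(jac, expected))
```

The point of the design is that only the solution `x_star` is needed in the backward pass, not the trajectory of the solver that produced it.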

In [6]:
from torchopt import implicit_diff, sgd

### 1.1 Forward Process, Backward Process with Optimality Conditions
For IMAML, the inner-loop objective is described by the following equation.

$$
\mathcal{A}lg^{\star}\left(\boldsymbol{\theta}, \mathcal{D}_{i}^{\operatorname{tr}}\right)=\underset{\boldsymbol{\phi}^{\prime} \in \Phi}{\operatorname{argmin}} \; \mathcal{L}\left(\boldsymbol{\phi}^{\prime}, \mathcal{D}_{i}^{\operatorname{tr}}\right)+\frac{\lambda}{2}\left\|\boldsymbol{\phi}^{\prime}-\boldsymbol{\theta}\right\|^{2}
$$

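Based on this objective, we define the forward function **inner_solver**, which approximately solves the problem by running a sufficient number of gradient-descent steps. Writing the inner objective as $G(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}) = \mathcal{L}(\boldsymbol{\phi}^{\prime}, \mathcal{D}_{i}^{\operatorname{tr}}) + \frac{\lambda}{2}\|\boldsymbol{\phi}^{\prime} - \boldsymbol{\theta}\|^{2}$, the optimality condition for the inner loop is that the gradient of $G$ with respect to the inner-loop parameters is zero at the solution $\boldsymbol{\phi}$: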

$$
\left.\nabla_{\boldsymbol{\phi}^{\prime}} G\left(\boldsymbol{\phi}^{\prime}, \boldsymbol{\theta}\right)\right|_{\boldsymbol{\phi}^{\prime}=\boldsymbol{\phi}}=0
$$
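Differentiating this optimality condition with respect to $\boldsymbol{\theta}$ (the implicit function theorem) shows what the backward pass has to compute. Here $G$ is the regularised inner objective minimised above, so $\nabla^2_{\boldsymbol{\phi}^{\prime}\boldsymbol{\phi}^{\prime}} G = \nabla^2_{\boldsymbol{\phi}} \mathcal{L} + \lambda I$ and $\nabla^2_{\boldsymbol{\phi}^{\prime}\boldsymbol{\theta}} G = -\lambda I$. Assuming $\mathcal{L}$ is twice differentiable and $\nabla^2_{\boldsymbol{\phi}} \mathcal{L} + \lambda I$ is invertible at $\boldsymbol{\phi}$:

$$
\nabla^2_{\boldsymbol{\phi}^{\prime}\boldsymbol{\phi}^{\prime}} G \, \frac{d\boldsymbol{\phi}}{d\boldsymbol{\theta}} + \nabla^2_{\boldsymbol{\phi}^{\prime}\boldsymbol{\theta}} G = 0
\;\Longrightarrow\;
\left(\nabla^2_{\boldsymbol{\phi}} \mathcal{L}(\boldsymbol{\phi}, \mathcal{D}_{i}^{\operatorname{tr}}) + \lambda I\right) \frac{d\boldsymbol{\phi}}{d\boldsymbol{\theta}} - \lambda I = 0
\;\Longrightarrow\;
\frac{d\boldsymbol{\phi}}{d\boldsymbol{\theta}} = \left(I + \frac{1}{\lambda} \nabla^2_{\boldsymbol{\phi}} \mathcal{L}(\boldsymbol{\phi}, \mathcal{D}_{i}^{\operatorname{tr}})\right)^{-1}
$$

This Jacobian depends only on the solution $\boldsymbol{\phi}$, not on the path the inner solver took, which is why the inner loop does not need to be unrolled for the backward pass.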

We therefore write the inner objective as **imaml_objective** (with $\lambda = 1$ in the code below) and obtain the optimality function as its gradient with respect to the inner-loop parameters, via **functorch.grad(imaml_objective, argnums=0)**; the optimality condition requires this gradient to be zero. Finally, the forward function is decorated with **@implicit_diff.custom_root**, which receives this optimality condition.
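A quick way to see what an optimality function built with `grad(..., argnums=0)` returns is to evaluate it at a point where the inner problem is solved exactly. The sketch below assumes PyTorch 2.x, where `torch.func.grad` is the successor of `functorch.grad`, and a linear model with $\lambda = 1$ so that the inner solution has a closed form. The names `objective` and `optimality_fn` are illustrative; `objective` mirrors `imaml_objective` but captures the data as globals.

```python
import torch
import torch.nn.functional as F
from torch.func import grad  # PyTorch >= 2.0 successor of functorch.grad

torch.manual_seed(0)
n = 20
x = torch.randn(n, 4, dtype=torch.float64)
y = torch.randn(n, 1, dtype=torch.float64)


def objective(optimal_params, init_params):
    # Same structure as imaml_objective: MSE + 0.5 * ||phi' - theta||^2 (lambda = 1),
    # for a linear model with params (weight of shape (1, 4), bias of shape (1,))
    w, b = optimal_params
    pred = x @ w.T + b
    reg = sum(0.5 * torch.sum((p1 - p2) ** 2) for p1, p2 in zip(optimal_params, init_params))
    return F.mse_loss(pred, y) + reg


# Optimality function: gradient w.r.t. the first argument (the inner-loop params)
optimality_fn = grad(objective, argnums=0)

# Initial (meta) parameters, initialised like the tutorial's Net: weight = 1, bias = 0
init_params = (torch.ones(1, 4, dtype=torch.float64), torch.zeros(1, dtype=torch.float64))

# Closed-form inner solution: (2/n X^T X + I) z = 2/n X^T y + z0, with X = [x, 1]
X = torch.cat([x, torch.ones(n, 1, dtype=torch.float64)], dim=1)
z0 = torch.cat([init_params[0].reshape(-1), init_params[1]])
lhs = 2.0 / n * X.T @ X + torch.eye(5, dtype=torch.float64)
rhs = 2.0 / n * X.T @ y.squeeze(1) + z0
z_star = torch.linalg.solve(lhs, rhs)
optimal_params = (z_star[:4].reshape(1, 4), z_star[4:])

# The optimality condition holds (up to round-off) at the exact solution ...
residual_at_opt = optimality_fn(optimal_params, init_params)
# ... but not at the initial parameters
residual_at_init = optimality_fn(init_params, init_params)
```

The returned residual has the same tuple structure as the parameters, which is what lets it serve as a root condition for tuple-valued parameters.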

In [10]:
# Inner-loop objective G(phi', theta) = MSE loss + (lambda / 2) * ||phi' - theta||^2 with lambda = 1;
# its gradient w.r.t. phi' (the first argument) is the optimality function
def imaml_objective(optimal_params, init_params, data):
    x, y, f = data
    y_pred = f(optimal_params, x)
    regularisation_loss = 0
    for p1, p2 in zip(optimal_params, init_params):
        regularisation_loss += 0.5 * torch.sum((p1.view(-1) - p2.view(-1))**2)
    loss = F.mse_loss(y_pred, y) + regularisation_loss
    return loss

# Optimality condition: the gradient w.r.t. the inner-loop optimal params is 0
# (we achieve this by specifying argnums=0 in functorch.grad).
# argnums=1 in custom_root specifies which meta-parameter to backpropagate to; here we
# backpropagate to the initial parameters, which are the second argument, so we set it to 1.
# You can also set argnums=(1, 2) if you want to backpropagate through multiple meta-parameters.

@implicit_diff.custom_root(functorch.grad(imaml_objective, argnums=0), argnums=1)
def inner_solver(init_params_copy, init_params, data):
    """Approximately solve the IMAML inner problem with 100 steps of SGD."""
    # Initialise a functional optimiser from TorchOpt
    x, y, f = data
    params = init_params_copy
    optimizer = sgd(lr=2e-2)
    opt_state = optimizer.init(params)
    with torch.enable_grad():
        # Temporarily enable gradient computation to run the inner optimisation
        for _ in range(100):
            pred = f(params, x)
            loss = F.mse_loss(pred, y)                               # compute data loss
            regularisation_loss = 0
            # compute proximal regularisation loss (lambda = 1)
            for p1, p2 in zip(params, init_params):
                regularisation_loss += 0.5 * torch.sum((p1.view(-1) - p2.view(-1))**2)
            final_loss = loss + regularisation_loss
            grads = torch.autograd.grad(final_loss, params)          # compute gradients
            updates, opt_state = optimizer.update(grads, opt_state)  # get updates
            params = torchopt.apply_updates(params, updates)         # apply updates
    return params

Next, we consider a concrete case: a one-layer neural network fitted to synthetic linear data.

In [12]:
x = torch.randn(20, 4)
w = torch.randn(4, 1)
b = torch.randn(1)
y = x @ w + b + torch.randn(20, 1) * 0.5

We instantiate a one-layer neural network whose weights are initialised to ones and whose bias is initialised to zero.

In [13]:
class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)
    
model = Net(4)
f, p = functorch.make_functional(model)
data = (x, y, f)

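# Return detached copies of the parameters as fresh leaf tensors (requires_grad=True),
# so the inner solver optimises the copies while p remains the meta-parameter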
def clone(p):
    p_out = []
    for item in p:
        if isinstance(item, torch.Tensor):
            p_out.append(item.clone().detach_().requires_grad_(True))
        else:
            p_out.append(item)
    return tuple(p_out)
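`functorch.make_functional` is deprecated in PyTorch 2.x, where functorch was merged into `torch.func`. If it is unavailable in your environment, the following is a sketch of an equivalent functional wrapper built on `torch.func.functional_call`. It assumes PyTorch 2.x, and note that it takes a dict of named parameters rather than a tuple, so the tutorial's tuple-based code (`clone`, the `zip` over parameters) would need adapting; the wrapper name `f` is kept only to mirror the tutorial.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


model = Net(4)
# Parameters as a dict {"fc.weight": ..., "fc.bias": ...} instead of a tuple
params = {name: p.detach().clone().requires_grad_(True) for name, p in model.named_parameters()}


def f(params, x):
    # Run the module's forward pass with the supplied parameters
    return functional_call(model, params, (x,))


x = torch.randn(20, 4)
out = f(params, x)
# Gradients flow to the dict entries rather than to model.parameters()
grads = torch.autograd.grad(out.mean(), tuple(params.values()))
```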

We run the forward process by calling the decorated solver, then pass the optimal parameters into the outer-loop loss function (here, simply the mean prediction on `x`).

In [14]:
optimal_params = inner_solver(clone(p), p, data)

outer_loss = f(optimal_params, x).mean()

Finally, we compute the meta-gradient of the outer loss with respect to the initial parameters `p`:

In [15]:
torch.autograd.grad(outer_loss, p)

(tensor([[-0.0582, -0.0163,  0.0379, -0.0265]]), tensor([0.2984]))
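Because the inner problem in this example is least squares with $\lambda = 1$, the meta-gradient can be cross-checked in closed form. The sketch below regenerates synthetic data with its own seed (so the numbers will not match the output above) and compares two routes: autograd straight through the exact inner solution, and the implicit-function-theorem formula $(I + \frac{1}{\lambda} H)^{-1} \nabla_{\boldsymbol{\phi}} \mathcal{L}_{\text{outer}}$. The weight and bias gradients are concatenated into one 5-vector. This is a verification aid under these assumptions, not how TorchOpt computes the gradient. Also note that the tutorial's inner solver runs a fixed 100 SGD steps, so if those steps have not fully converged, the implicit gradient it yields is only approximate.

```python
import torch

torch.manual_seed(0)
n, lam = 20, 1.0
# Synthetic linear data in the same style as the tutorial (different random draw)
x = torch.randn(n, 4, dtype=torch.float64)
w_true = torch.randn(4, 1, dtype=torch.float64)
b_true = torch.randn(1, dtype=torch.float64)
y = (x @ w_true + b_true + torch.randn(n, 1, dtype=torch.float64) * 0.5).squeeze(1)

# Linear model with z = [weight (4), bias (1)] acting on X = [x, 1]
X = torch.cat([x, torch.ones(n, 1, dtype=torch.float64)], dim=1)
H = 2.0 / n * X.T @ X  # Hessian of the MSE loss w.r.t. z (constant for least squares)
I = torch.eye(5, dtype=torch.float64)

# Meta-parameters theta: weight = 1, bias = 0, as in the tutorial's Net
theta = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0], dtype=torch.float64, requires_grad=True)

# Exact inner solution: (H + lam I) z* = 2/n X^T y + lam * theta
z_star = torch.linalg.solve(H + lam * I, 2.0 / n * X.T @ y + lam * theta)
outer_loss = (X @ z_star).mean()  # same outer loss as the tutorial: mean prediction

# (a) Differentiate straight through the closed-form solve
meta_grad_autograd = torch.autograd.grad(outer_loss, theta)[0]

# (b) Implicit function theorem: (I + H / lam)^{-1} @ d(outer_loss)/d(z*)
d_outer_d_z = X.mean(dim=0)
meta_grad_ift = torch.linalg.solve(I + H / lam, d_outer_d_z)
```

Route (b) never touches the solver that produced `z_star`; it only needs the Hessian at the solution and the outer-loss gradient, which is the property that implicit differentiation exploits.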