# TorchOpt for Zero-Order Differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)

When the inner-loop process is non-differentiable, or when one wants to avoid the heavy computational burden of the previous two modes (caused by Hessian computations), one can choose zero-order differentiation (ZD). ZD typically estimates gradients using zero-order methods such as finite differences or Evolution Strategies (ES).
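As a point of reference, the simplest zero-order estimator is a coordinate-wise central finite difference, which needs only function values. The sketch below is plain PyTorch and does not use TorchOpt. The toy objective `f` and the step size `eps` are illustrative choices. It compares the estimate with the analytic gradient. Note that it costs $2d$ function evaluations for $d$ parameters, which becomes expensive when $d$ is large; ES-style estimators instead average over a chosen number of random directions.

```python
import torch


def f(theta: torch.Tensor) -> torch.Tensor:
    # Toy black-box objective (illustrative); only its values are used below
    return (theta ** 2).sum() + torch.sin(theta).sum()


def central_fd_grad(f, theta: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Estimate each partial derivative as (f(theta + eps e_i) - f(theta - eps e_i)) / (2 eps)
    grad = torch.zeros_like(theta)
    for i in range(theta.numel()):
        e_i = torch.zeros_like(theta)
        e_i[i] = eps
        grad[i] = (f(theta + e_i) - f(theta - e_i)) / (2 * eps)
    return grad


theta = torch.tensor([0.3, -1.2, 2.0], dtype=torch.float64)
fd = central_fd_grad(f, theta)
exact = 2 * theta + torch.cos(theta)  # analytic gradient of the toy objective
print(torch.allclose(fd, exact, atol=1e-6))  # True
```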

TorchOpt offers an API for ES-based differentiation. Instead of optimizing the objective $f$ directly, ES optimizes a Gaussian-smoothed objective defined as $\tilde{f}_{\sigma} (\theta) = \mathbb{E}_{{z} \sim \mathcal{N}( {0}, {I}_d )} [ f ({\theta} + \sigma \, z) ]$, where $\sigma$ denotes the precision, i.e., the standard deviation of the perturbation $\sigma z$. The gradient of this objective is $\nabla_\theta \tilde{f}_{\sigma} (\theta) = \frac{1}{\sigma} \mathbb{E}_{{z} \sim \mathcal{N}( {0}, {I}_d )} [ f({\theta} + \sigma \, z) \cdot z ]$, which requires only evaluations of $f$, not its gradient. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details.
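The gradient identity above follows from differentiating the Gaussian density rather than $f$. This assumes $f$ is regular enough that differentiation and expectation can be exchanged. Write $u = \theta + \sigma z \sim \mathcal{N}(\theta, \sigma^2 I_d)$ with density $p_\theta(u) \propto \exp\!\big(-\|u - \theta\|^2 / (2\sigma^2)\big)$. Then:

$$
\nabla_\theta \tilde{f}_{\sigma}(\theta)
= \int f(u)\, \nabla_\theta p_\theta(u)\, du
= \int f(u)\, p_\theta(u)\, \frac{u - \theta}{\sigma^2}\, du
= \frac{1}{\sigma}\, \mathbb{E}_{z \sim \mathcal{N}(0, I_d)} \big[ f(\theta + \sigma z)\, z \big].
$$

As a worked example, take $f(\theta) = \|\theta\|^2$. Then $\tilde{f}_{\sigma}(\theta) = \|\theta\|^2 + \sigma^2 d$, so the smoothed gradient equals the true gradient $2\theta$ for every $\sigma$. The estimator formula agrees with this, using $\mathbb{E}[z] = 0$, $\mathbb{E}[z z^\top] = I_d$, and $\mathbb{E}[\|z\|^2 z] = 0$:

$$
\frac{1}{\sigma}\,\mathbb{E}\big[(\|\theta\|^2 + 2\sigma\,\theta^\top z + \sigma^2\|z\|^2)\, z\big]
= \frac{1}{\sigma}\big(0 + 2\sigma\,\theta + 0\big) = 2\theta.
$$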

In this tutorial, we will introduce how TorchOpt can be used for ES-based differentiation.

```python
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
```

## 1. Functional API

The basic functional API is `torchopt.diff.zero_order.zero_order`. It is a decorator applied to the forward process so that gradients through that process are computed by zero-order estimation. Users must provide the noise sampling distribution, passed to the decorator as an argument. It can be a custom class, a sampling function, or a `torch.distributions.Distribution` instance. The decorator's parameters are:

- `distribution`: the distribution from which the noise is sampled.
- `method`: the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).
- `argnums`: the argument(s) of the decorated function with respect to which the meta-gradient is traced.
- `sigma`: the precision $\sigma$, i.e., the scale of the perturbation.
- `num_samples`: the number of noise samples drawn to estimate the gradient.
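To build intuition for the `method` options, the following is a plain-PyTorch sketch of the textbook forms of the three estimators on the toy quadratic $f(\theta) = \|\theta\|^2$, whose true gradient is $2\theta$. The sketch assumes the forms behind the method names. `'naive'` weights the noise by raw function values. `'forward'` subtracts the baseline $f(\theta)$. `'antithetic'` uses mirrored pairs $\theta \pm \sigma z$. This is not TorchOpt's code, and the library may differ in details such as normalization. `es_estimate` is a hypothetical helper.

```python
import torch


def f(theta: torch.Tensor) -> torch.Tensor:
    # Toy objective evaluated row-wise: f(theta) = ||theta||^2
    return (theta ** 2).sum(dim=-1)


def es_estimate(f, theta, method, sigma=0.01, num_samples=10_000, generator=None):
    # Hypothetical helper: draw z ~ N(0, I_d), one row per sample
    z = torch.randn(num_samples, theta.numel(), dtype=theta.dtype, generator=generator)
    if method == 'naive':
        # (1 / sigma) * f(theta + sigma z)
        weights = f(theta + sigma * z) / sigma
    elif method == 'forward':
        # Forward difference: subtract the baseline f(theta)
        weights = (f(theta + sigma * z) - f(theta)) / sigma
    elif method == 'antithetic':
        # Central difference along z; costs two evaluations of f per sample
        weights = (f(theta + sigma * z) - f(theta - sigma * z)) / (2 * sigma)
    else:
        raise ValueError(f'unknown method: {method!r}')
    # Monte Carlo average of weights * z over the samples
    return (weights[:, None] * z).mean(dim=0)


theta = torch.tensor([1.0, -2.0, 0.5], dtype=torch.float64)
true_grad = 2 * theta

errors = {}
for method in ('naive', 'forward', 'antithetic'):
    gen = torch.Generator().manual_seed(0)  # same noise for every method
    est = es_estimate(f, theta, method, generator=gen)
    errors[method] = (est - true_grad).norm().item()
    print(f'{method:>10}: error = {errors[method]:.3f}')
```

With identical noise, the naive estimate is much noisier at small `sigma`. Its extra term $f(\theta)\,z / \sigma$ has zero mean but large variance. Subtracting a baseline or using mirrored pairs removes that term.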

The following pseudocode illustrates how the API is used.

```python
import torch
import torchopt

# Functional API for zero-order differentiation
# Option 1: customize the noise distribution via a class with a `sample` method
class Distribution:
    def sample(self, sample_shape=torch.Size()):
        # Return noise with the requested shape, e.g., standard Gaussian noise
        return torch.randn(sample_shape)

distribution = Distribution()

# Option 2: customize the noise distribution via a plain sampling function
def distribution(sample_shape=torch.Size()):
    # Return noise with the requested shape, e.g., standard Gaussian noise
    return torch.randn(sample_shape)

# Option 3: use an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
distribution = torch.distributions.Normal(loc=0.0, scale=1.0)

# Decorator that wraps the forward process
@torchopt.diff.zero_order.zero_order(
    distribution=distribution, method='naive', argnums=0, sigma=0.01, num_samples=100
)
def forward(params, data):
    # Forward optimization process for params
    return output

# Define params and get data
params, data = ..., ...
loss = forward(params, data)

# Zero-order estimate of the meta-gradient w.r.t. params (argnums=0)
meta_grads = torch.autograd.grad(loss, params)
```
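The three distribution options above differ only in packaging. Each is expected to return noise of the requested `sample_shape`. The quick check below runs in plain PyTorch and is independent of TorchOpt. Returning standard Gaussian noise from the custom options is an illustrative choice. The `(1000, 32)` shape is arbitrary and is not necessarily what TorchOpt requests internally. The function is named `sample_fn` here only to avoid shadowing the class.

```python
import torch

torch.manual_seed(0)


# Option 1: a class exposing `sample(sample_shape)`; standard Gaussian noise here
class Distribution:
    def sample(self, sample_shape=torch.Size()):
        return torch.randn(sample_shape)


# Option 2: a plain sampling function
def sample_fn(sample_shape=torch.Size()):
    return torch.randn(sample_shape)


# Option 3: a built-in `torch.distributions.Distribution` instance
normal = torch.distributions.Normal(loc=0.0, scale=1.0)

# Arbitrary illustrative shape: 1000 draws of a 32-dimensional vector
shape = torch.Size([1000, 32])
samples = {
    'class': Distribution().sample(shape),
    'function': sample_fn(shape),
    'torch.distributions': normal.sample(shape),
}
for name, noise in samples.items():
    # Each source should give the requested shape, mean near 0 and std near 1
    print(f'{name}: shape={tuple(noise.shape)}, '
          f'mean={noise.mean().item():.3f}, std={noise.std().item():.3f}')
```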

Here we use a linear layer as an example. This is only a toy example showing that a linear layer can be trained with ES-based gradients.

```python
torch.random.manual_seed(0)

fmodel, params = functorch.make_functional(torch.nn.Linear(32, 1))
x = torch.randn(64, 32) * 0.1
y = torch.randn(64) * 0.1
distribution = torch.distributions.Normal(loc=0, scale=1)


@torchopt.diff.zero_order.zero_order(
    distribution=distribution, method='forward', argnums=0, sigma=0.01, num_samples=1000
)
def forward_process(params, fn, x, y):
    y_pred = fn(params, x)  # shape (64, 1), so `y - y_pred` broadcasts to (64, 64)
    loss = torch.mean((y - y_pred) ** 2)
    return loss


optimizer = torchopt.adam(lr=0.01)
opt_state = optimizer.init(params)

for i in range(25):
    opt_state = optimizer.init(params)  # re-init optimizer (resets Adam's moment estimates every iteration)
    loss = forward_process(params, fmodel, x, y)  # compute loss

    grads = torch.autograd.grad(loss, params)  # compute gradients
    updates, opt_state = optimizer.update(grads, opt_state)  # get updates
    params = torchopt.apply_updates(params, updates)  # update network parameters

    print(f'{i + 1:03d}: {loss!r}')
```

001: tensor(0.0269, grad_fn=<ZeroOrderBackward>)
002: tensor(0.0246, grad_fn=<ZeroOrderBackward>)
003: tensor(0.0225, grad_fn=<ZeroOrderBackward>)
004: tensor(0.0205, grad_fn=<ZeroOrderBackward>)
005: tensor(0.0187, grad_fn=<ZeroOrderBackward>)
006: tensor(0.0171, grad_fn=<ZeroOrderBackward>)
007: tensor(0.0156, grad_fn=<ZeroOrderBackward>)
008: tensor(0.0144, grad_fn=<ZeroOrderBackward>)
009: tensor(0.0134, grad_fn=<ZeroOrderBackward>)
010: tensor(0.0128, grad_fn=<ZeroOrderBackward>)
011: tensor(0.0122, grad_fn=<ZeroOrderBackward>)
012: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
013: tensor(0.0120, grad_fn=<ZeroOrderBackward>)
014: tensor(0.0117, grad_fn=<ZeroOrderBackward>)
015: tensor(0.0117, grad_fn=<ZeroOrderBackward>)
016: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
017: tensor(0.0121, grad_fn=<ZeroOrderBackward>)
018: tensor(0.0117, grad_fn=<ZeroOrderBackward>)
019: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
020: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
021: tensor(0.0115, 