# TorchOpt for Zero-Order Differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/6_Zero_Order_Differentiation.ipynb)

When the inner-loop process is non-differentiable, or when one wants to avoid the heavy computational burden that Hessian-related computations bring to the previous two modes (explicit and implicit differentiation), one can choose zero-order differentiation (ZD). ZD typically estimates gradients from function evaluations alone, for example via finite differences or Evolution Strategies (ES).
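As an illustration of the non-differentiable case, here is a minimal plain-PyTorch sketch that does not use TorchOpt. The toy objective quantizes its input with `torch.round`, so autograd returns an all-zero gradient, while a Gaussian-smoothed ES estimate still points downhill. `quantized_loss`, `es_grad`, the choice of an antithetic estimator, and all constants are illustrative assumptions, not part of TorchOpt.

```python
import torch


def quantized_loss(theta: torch.Tensor) -> torch.Tensor:
    # Toy objective (hypothetical): torch.round has a zero gradient almost everywhere
    q = torch.round(theta * 10) / 10
    return (q ** 2).sum(dim=-1)


def es_grad(f, theta, sigma=0.5, num_samples=20_000, seed=0):
    # Antithetic ES estimate of the gradient of E_z[f(theta + sigma * z)], z ~ N(0, I)
    gen = torch.Generator().manual_seed(seed)
    z = torch.randn(num_samples, theta.numel(), generator=gen)
    weights = (f(theta + sigma * z) - f(theta - sigma * z)) / (2 * sigma)
    return (weights.unsqueeze(-1) * z).mean(dim=0)


theta = torch.tensor([1.0, -1.0, 2.0], requires_grad=True)
quantized_loss(theta).backward()
print("autograd:", theta.grad)  # all zeros: no signal flows through torch.round

estimate = es_grad(quantized_loss, theta.detach())
print("ES estimate:", estimate)  # roughly 2 * theta, the gradient of the unquantized loss
```

The smoothing scale `sigma` is chosen much larger than the quantization step (0.1), so the smoothed objective is close to the unquantized quadratic.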

TorchOpt offers an API for ES-based differentiation. Instead of optimizing the objective $f$ directly, ES optimizes its Gaussian smoothing, defined as $\tilde{f}_{\sigma} (\theta) = \mathbb{E}_{{z} \sim \mathcal{N}( {0}, {I}_d )} [ f ({\theta} + \sigma \, z) ]$, where $\sigma > 0$ denotes the precision (the scale of the Gaussian perturbation). The gradient of this smoothed objective is $\nabla_\theta \tilde{f}_{\sigma} (\theta) = \frac{1}{\sigma} \mathbb{E}_{{z} \sim \mathcal{N}( {0}, {I}_d )} [ f({\theta} + \sigma \, z) \cdot z ]$, which can be estimated by Monte Carlo sampling using only evaluations of $f$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details.
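The gradient formula above follows from a change of variables, and writing it out also shows the Monte Carlo estimators that sampling-based methods build on. Here $\phi$ is the standard normal density and $z_1, \dots, z_N$ are i.i.d. draws from $\mathcal{N}(0, I_d)$; both symbols are introduced only for this derivation.

$$
\tilde{f}_{\sigma}(\theta) = \int f(\theta + \sigma z)\, \phi(z)\, dz = \int f(u)\, \sigma^{-d}\, \phi\!\left(\frac{u - \theta}{\sigma}\right) du, \qquad \phi(z) = (2\pi)^{-d/2} e^{-\lVert z \rVert^2 / 2}.
$$

Since $\nabla_w \phi(w) = -w\, \phi(w)$, differentiating under the integral and substituting back $u = \theta + \sigma z$ gives

$$
\nabla_\theta \tilde{f}_{\sigma}(\theta) = \int f(u)\, \sigma^{-d}\, \phi\!\left(\frac{u - \theta}{\sigma}\right) \frac{u - \theta}{\sigma^2}\, du = \frac{1}{\sigma}\, \mathbb{E}_{z \sim \mathcal{N}(0, I_d)} \left[ f(\theta + \sigma z)\, z \right].
$$

Replacing the expectation by a sample average yields the standard ES estimators from the literature. They correspond in spirit to the `method` options of the API, but TorchOpt's implementation may differ in details such as scaling:

$$
\begin{aligned}
\hat{g}_{\text{naive}} &= \frac{1}{N \sigma} \sum_{i=1}^{N} f(\theta + \sigma z_i)\, z_i, \\
\hat{g}_{\text{forward}} &= \frac{1}{N \sigma} \sum_{i=1}^{N} \bigl[ f(\theta + \sigma z_i) - f(\theta) \bigr] z_i, \\
\hat{g}_{\text{antithetic}} &= \frac{1}{2 N \sigma} \sum_{i=1}^{N} \bigl[ f(\theta + \sigma z_i) - f(\theta - \sigma z_i) \bigr] z_i.
\end{aligned}
$$

All three are unbiased for $\nabla_\theta \tilde{f}_{\sigma}(\theta)$: subtracting $f(\theta)$ changes nothing in expectation because $\mathbb{E}[z] = 0$, and $z$ and $-z$ have the same distribution.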

In this tutorial, we will introduce how TorchOpt can be used for ES-based differentiation.

```python
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
```

## 1. Functional API

The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as a decorator on the forward process to enable zero-order gradient estimation. Users need to provide a noise sampling function (or distribution), which is passed to the `zero_order` decorator. The meaning of each decorator parameter is explained below.

- `distribution` specifies the noise sampling distribution. The distribution $\lambda$ should be spherically symmetric and have a constant variance of $1$ for each element, i.e.:
    - Spherically symmetric (in particular, zero mean): $\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ \boldsymbol{z} ] = \boldsymbol{0}$.
    - Constant variance of $1$ for each element: $\mathbb{E}_{\boldsymbol{z} \sim \lambda} [ {\lvert \boldsymbol{z}_i \rvert}^2 ] = 1$.
- `method` selects the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).
- `argnums` specifies the argument(s) of the decorated function with respect to which the meta-gradient is traced.
- `num_samples` specifies the number of noise samples drawn to estimate the gradient.
- `sigma` controls the precision: it is the scaling factor applied to samples from the noise distribution.
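The three `method` options trade computation for variance. The plain-PyTorch sketch below (no TorchOpt) uses the standard textbook forms of the naive, forward-difference, and antithetic estimators; it is an assumption that TorchOpt's variants match them up to details such as scaling. It compares them on a toy quadratic whose smoothed gradient is exactly $2\theta$ for any $\sigma$, reusing one batch of noise so that only the estimator differs. `toy_objective` and `per_sample_estimates` are hypothetical names.

```python
import torch


def toy_objective(theta: torch.Tensor) -> torch.Tensor:
    # f(theta) = ||theta||^2 evaluated row-wise; its Gaussian smoothing has gradient 2 * theta
    return (theta ** 2).sum(dim=-1)


def per_sample_estimates(f, theta, z, sigma, method):
    # Returns one gradient estimate per noise sample, shape [num_samples, dim]
    if method == "naive":
        weights = f(theta + sigma * z) / sigma
    elif method == "forward":
        # Subtracting f(theta) keeps the mean unchanged because E[z] = 0
        weights = (f(theta + sigma * z) - f(theta)) / sigma
    elif method == "antithetic":
        # Pairs z with -z; unbiased because z and -z have the same distribution
        weights = (f(theta + sigma * z) - f(theta - sigma * z)) / (2 * sigma)
    else:
        raise ValueError(f"unknown method: {method!r}")
    return weights.unsqueeze(-1) * z


theta = torch.tensor([1.0, -2.0, 0.5])
sigma, num_samples = 0.5, 200_000
z = torch.randn(num_samples, theta.numel(), generator=torch.Generator().manual_seed(0))

results = {}
for method in ("naive", "forward", "antithetic"):
    samples = per_sample_estimates(toy_objective, theta, z, sigma, method)
    results[method] = (samples.mean(dim=0), samples.std(dim=0))
    print(f"{method:>10}: mean={results[method][0]}, per-sample std={results[method][1]}")
# All three means should be close to 2 * theta = [2., -4., 1.];
# the naive estimator has a much larger per-sample spread on this problem.
```

The naive estimator multiplies the raw objective value by the noise, so a large $f(\theta)$ inflates its variance; the forward and antithetic forms cancel that offset at the cost of one (forward) or $N$ (antithetic) extra evaluations.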

The pseudocode is shown below.

```python
import torch
import torchopt

# Functional API for zero-order differentiation
# Define `distribution` in ONE of the following three ways.
# Option 1: customize the noise distribution via a class with a `sample` method
class Distribution:
    def sample(self, sample_shape=torch.Size()):
        # Sampling function for noise
        # NOTE: The distribution should be spherically symmetric with a constant variance of 1 per element.
        ...
        return noise_batch

distribution = Distribution()

# Option 2: customize the noise distribution via a sampling function
def distribution(sample_shape=torch.Size()):
    # Sampling function for noise
    # NOTE: The distribution should be spherically symmetric with a constant variance of 1 per element.
    ...
    return noise_batch

# Option 3: use an instance of `torch.distributions.Distribution`, e.g., `torch.distributions.Normal(...)`
distribution = torch.distributions.Normal(loc=0, scale=1)

# Decorator that wraps the function
@torchopt.diff.zero_order(distribution=distribution, method='naive', argnums=0, num_samples=100, sigma=0.01)
def forward(params, data):
    # Forward optimization process for params
    ...
    return objective  # the returned tensor should be a scalar tensor

# Define params and get data
params, data = ..., ...

# Forward pass
loss = forward(params, data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, params)
```
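As a concrete, assumed instantiation of the three options above, the sketch below draws standard normal noise, which is spherically symmetric with unit variance per element, and empirically checks the two moment conditions. `StandardNormal`, `standard_normal`, and `builtin` are hypothetical names; the only interface assumed is the one in the pseudocode, i.e. a `sample(sample_shape)` method or a function taking `sample_shape` and returning a tensor of that shape.

```python
import torch


class StandardNormal:
    # Option 1: an object exposing a `sample(sample_shape)` method
    def sample(self, sample_shape=torch.Size()):
        return torch.randn(sample_shape)


def standard_normal(sample_shape=torch.Size()):
    # Option 2: a plain sampling function with the same signature
    return torch.randn(sample_shape)


# Option 3: a built-in distribution; with scalar loc/scale, sample(shape) returns `shape`
builtin = torch.distributions.Normal(loc=0.0, scale=1.0)

torch.manual_seed(0)
for name, draw in [
    ("class", StandardNormal().sample),
    ("function", standard_normal),
    ("torch.distributions", builtin.sample),
]:
    noise = draw(torch.Size([100_000, 8]))
    # Empirical check of the two conditions: zero mean and unit variance per element
    print(f"{name:>20}: shape={tuple(noise.shape)}, mean={noise.mean():.3f}, var={noise.var():.3f}")
```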

Here we use a linear layer as an example. Note that this example is only meant to show that a linear layer can be trained with ES.

```python
torch.random.manual_seed(0)

fmodel, params = functorch.make_functional(nn.Linear(32, 1))
x = torch.randn(64, 32) * 0.1
y = torch.randn(64, 1) * 0.1
distribution = torch.distributions.Normal(loc=0, scale=1)


@torchopt.diff.zero_order(
    distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01
)
def forward_process(params, fn, x, y):
    y_pred = fn(params, x)
    loss = F.mse_loss(y_pred, y)
    return loss


optimizer = torchopt.adam(lr=0.01)
opt_state = optimizer.init(params)  # init optimizer

for i in range(25):
    loss = forward_process(params, fmodel, x, y)  # compute loss

    grads = torch.autograd.grad(loss, params)  # compute gradients
    updates, opt_state = optimizer.update(grads, opt_state)  # get updates
    params = torchopt.apply_updates(params, updates)  # update network parameters

    print(f'{i + 1:03d}: {loss!r}')
```

```text
001: tensor(0.0265, grad_fn=<ZeroOrderBackward>)
002: tensor(0.0243, grad_fn=<ZeroOrderBackward>)
003: tensor(0.0222, grad_fn=<ZeroOrderBackward>)
004: tensor(0.0202, grad_fn=<ZeroOrderBackward>)
005: tensor(0.0184, grad_fn=<ZeroOrderBackward>)
006: tensor(0.0170, grad_fn=<ZeroOrderBackward>)
007: tensor(0.0157, grad_fn=<ZeroOrderBackward>)
008: tensor(0.0146, grad_fn=<ZeroOrderBackward>)
009: tensor(0.0137, grad_fn=<ZeroOrderBackward>)
010: tensor(0.0130, grad_fn=<ZeroOrderBackward>)
011: tensor(0.0123, grad_fn=<ZeroOrderBackward>)
012: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
013: tensor(0.0114, grad_fn=<ZeroOrderBackward>)
014: tensor(0.0111, grad_fn=<ZeroOrderBackward>)
015: tensor(0.0111, grad_fn=<ZeroOrderBackward>)
016: tensor(0.0111, grad_fn=<ZeroOrderBackward>)
017: tensor(0.0113, grad_fn=<ZeroOrderBackward>)
018: tensor(0.0115, grad_fn=<ZeroOrderBackward>)
019: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
020: tensor(0.0120, grad_fn=<ZeroOrderBackward>)
021: tensor(0.0121, 
```
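Because a linear layer with an MSE loss is differentiable, this setting also makes it possible to check an ES estimate against the exact autograd gradient. The sketch below does so in plain PyTorch (no TorchOpt) with the textbook forward-difference estimator, which is an assumption about what `method='forward'` computes. `mse_batch` and `es_forward_grad` are hypothetical helpers, and the data only mirror the shapes used above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(64, 32) * 0.1
y = torch.randn(64, 1) * 0.1


def mse_batch(theta):
    # theta: [num_candidates, 33], i.e. flattened (weight, bias) of a Linear(32, 1) layer
    weight, bias = theta[:, :32], theta[:, 32]
    y_pred = x @ weight.T + bias  # [64, num_candidates]
    return ((y_pred - y) ** 2).mean(dim=0)  # one MSE value per candidate


def es_forward_grad(f, theta, sigma=0.01, num_samples=2000, seed=0):
    # Forward-difference ES estimate: mean of (f(theta + sigma z) - f(theta)) / sigma * z
    gen = torch.Generator().manual_seed(seed)
    z = torch.randn(num_samples, theta.numel(), generator=gen)
    base = f(theta.unsqueeze(0))  # [1]
    weights = (f(theta + sigma * z) - base) / sigma  # [num_samples]
    return (weights.unsqueeze(-1) * z).mean(dim=0)


theta = torch.cat([torch.randn(32) * 0.1, torch.zeros(1)]).requires_grad_()
exact = torch.autograd.grad(mse_batch(theta.unsqueeze(0)).sum(), theta)[0]

for n in (100, 2000):
    estimate = es_forward_grad(mse_batch, theta.detach(), num_samples=n)
    cos = F.cosine_similarity(estimate, exact, dim=0).item()
    print(f"num_samples={n:5d}: cosine similarity with autograd = {cos:.3f}")
```

More samples should bring the estimate closer to the exact gradient, at the cost of more objective evaluations; this is the trade-off that `num_samples` controls.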

## 2. OOP API

The basic OOP API is the class `ZeroOrderGradientModule`. The network is defined as an `nn.Module` in the classical PyTorch style: users implement the forward process `forward()`, whose gradient is estimated with zero-order methods, and a noise sampling function `sample()`. The meaning of each class parameter is explained below.

- `method` selects the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).
- `num_samples` specifies the number of noise samples drawn to estimate the gradient.
- `sigma` controls the precision: it is the scaling factor applied to samples from the noise distribution.
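The `class Net(ZeroOrderGradientModule, method=..., ...)` syntax passes keyword arguments when the class is created. In plain Python, a base class can receive such arguments through `__init_subclass__`. The sketch below shows that language feature with a hypothetical `ZeroOrderConfigBase`; its defaults are illustrative, and it makes no claim about how TorchOpt actually implements `ZeroOrderGradientModule`.

```python
class ZeroOrderConfigBase:
    # Hypothetical base class: stores sampling options passed in the class statement
    # (the default values here are illustrative assumptions)
    def __init_subclass__(cls, method="naive", num_samples=100, sigma=0.01, **kwargs):
        super().__init_subclass__(**kwargs)
        if method not in ("naive", "forward", "antithetic"):
            raise ValueError(f"unsupported method: {method!r}")
        if num_samples <= 0 or sigma <= 0:
            raise ValueError("num_samples and sigma must be positive")
        cls.method = method
        cls.num_samples = num_samples
        cls.sigma = sigma


class Net(ZeroOrderConfigBase, method="forward", num_samples=50):
    pass  # sigma falls back to the default 0.01


print(Net.method, Net.num_samples, Net.sigma)  # forward 50 0.01
```

Keyword arguments that the base class does not consume are forwarded to `super().__init_subclass__`, which is the usual pattern for cooperative multiple inheritance.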

The pseudocode is shown below.

```python
import torch
from torchopt.nn import ZeroOrderGradientModule

# Inherit from ZeroOrderGradientModule
# Optionally specify the `method` and/or `num_samples` and/or `sigma` used for sampling
class Net(ZeroOrderGradientModule, method='naive', num_samples=100, sigma=0.01):
    def __init__(self, ...):
        ...

    def forward(self, batch):
        # Forward process
        ...
        return objective  # the returned tensor should be a scalar tensor

    def sample(self, sample_shape=torch.Size()):
        # Generate a batch of noise samples
        # NOTE: The distribution should be spherically symmetric with a constant variance of 1 per element.
        ...
        return noise_batch

# Get model and data
net = Net(...)
data = ...

# Forward pass
loss = net(data)
# Backward pass using zero-order differentiation
grads = torch.autograd.grad(loss, net.parameters())
```
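In both APIs the returned loss carries a custom backward node, printed as `ZeroOrderBackward` in the training logs. The idea behind such a node can be sketched with a plain `torch.autograd.Function`: the forward pass evaluates the objective, and the backward pass replaces the chain rule with a forward-difference ES estimate. This is a simplified, hypothetical sketch (`make_es_objective` is not a TorchOpt function), not TorchOpt's implementation.

```python
import torch
import torch.nn.functional as F


def make_es_objective(fn, num_samples=1000, sigma=0.01, seed=0):
    """Hypothetical helper: wraps a scalar-valued ``fn(*params)`` so that backward()
    returns a forward-difference ES estimate instead of the autograd gradient."""
    gen = torch.Generator().manual_seed(seed)

    class ESObjective(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *params):
            ctx.save_for_backward(*params)
            return fn(*params)  # autograd is disabled inside Function.forward

        @staticmethod
        def backward(ctx, grad_output):
            params = ctx.saved_tensors
            with torch.no_grad():
                base = fn(*params)
                grads = [torch.zeros_like(p) for p in params]
                for _ in range(num_samples):
                    noises = [torch.randn(p.shape, generator=gen, dtype=p.dtype) for p in params]
                    perturbed = fn(*(p + sigma * n for p, n in zip(params, noises)))
                    weight = (perturbed - base) / sigma
                    for g, n in zip(grads, noises):
                        g.add_(weight * n)
            # One gradient per input, scaled by the incoming gradient (chain rule)
            return tuple(grad_output * g / num_samples for g in grads)

    return ESObjective.apply


torch.manual_seed(0)
x, y = torch.randn(64, 4), torch.randn(64, 1)
weight = torch.zeros(4, 1, requires_grad=True)
bias = torch.zeros(1, requires_grad=True)


def mse(weight, bias):
    return F.mse_loss(x @ weight + bias, y)


objective = make_es_objective(mse, num_samples=2000)
loss = objective(weight, bias)
loss.backward()  # fills weight.grad and bias.grad with ES estimates

exact = torch.autograd.grad(mse(weight, bias), (weight, bias))
es_flat = torch.cat([weight.grad.flatten(), bias.grad.flatten()])
exact_flat = torch.cat([g.flatten() for g in exact])
print("cosine similarity:", F.cosine_similarity(es_flat, exact_flat, dim=0).item())
```

Each backward pass in this sketch costs `num_samples` extra evaluations of the objective, which is how zero-order methods trade Hessian-related computation for many forward passes.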

Here we reimplement the functional API example above with the OOP API.

```python
torch.random.manual_seed(0)


class Net(torchopt.nn.ZeroOrderGradientModule, method='forward', num_samples=100, sigma=0.01):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1)
        self.distribution = torch.distributions.Normal(loc=0, scale=1)

    def forward(self, x, y):
        y_pred = self.fc(x)
        loss = F.mse_loss(y_pred, y)
        return loss

    def sample(self, sample_shape=torch.Size()):
        return self.distribution.sample(sample_shape)


x = torch.randn(64, 32) * 0.1
y = torch.randn(64, 1) * 0.1
net = Net(dim=32)


optimizer = torchopt.Adam(net.parameters(), lr=0.01)

for i in range(25):
    loss = net(x, y)  # compute loss

    optimizer.zero_grad()
    loss.backward()  # backward pass
    optimizer.step()  # update network parameters

    print(f'{i + 1:03d}: {loss!r}')
```

```text
001: tensor(0.0201, grad_fn=<ZeroOrderBackward>)
002: tensor(0.0181, grad_fn=<ZeroOrderBackward>)
003: tensor(0.0167, grad_fn=<ZeroOrderBackward>)
004: tensor(0.0153, grad_fn=<ZeroOrderBackward>)
005: tensor(0.0142, grad_fn=<ZeroOrderBackward>)
006: tensor(0.0133, grad_fn=<ZeroOrderBackward>)
007: tensor(0.0125, grad_fn=<ZeroOrderBackward>)
008: tensor(0.0119, grad_fn=<ZeroOrderBackward>)
009: tensor(0.0116, grad_fn=<ZeroOrderBackward>)
010: tensor(0.0114, grad_fn=<ZeroOrderBackward>)
011: tensor(0.0112, grad_fn=<ZeroOrderBackward>)
012: tensor(0.0112, grad_fn=<ZeroOrderBackward>)
013: tensor(0.0113, grad_fn=<ZeroOrderBackward>)
014: tensor(0.0116, grad_fn=<ZeroOrderBackward>)
015: tensor(0.0118, grad_fn=<ZeroOrderBackward>)
016: tensor(0.0121, grad_fn=<ZeroOrderBackward>)
017: tensor(0.0123, grad_fn=<ZeroOrderBackward>)
018: tensor(0.0125, grad_fn=<ZeroOrderBackward>)
019: tensor(0.0127, grad_fn=<ZeroOrderBackward>)
020: tensor(0.0127, grad_fn=<ZeroOrderBackward>)
021: tensor(0.0125, 
```