# TorchOpt for zero-order differentiation

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/metaopt/torchopt/blob/main/tutorials/5_Implicit_Differentiation.ipynb)

When the inner-loop process is non-differentiable, or when one wants to avoid the heavy computational burden of the previous two modes (caused by Hessian-related computations), one can choose zero-order differentiation (ZD). ZD estimates gradients from function evaluations alone, typically via finite differences or Evolution Strategies (ES).

TorchOpt offers an API for ES-based differentiation. Instead of optimizing the objective $f$ directly, ES optimizes a Gaussian-smoothed objective defined as $\tilde{f}_\sigma(\theta) = \mathbb{E}_{z \sim \mathcal{N}(0, I_d)} \left[ f(\theta + \sigma z) \right]$, where $\sigma > 0$ is the smoothing parameter (the standard deviation of the Gaussian perturbation), which controls the precision: a smaller $\sigma$ keeps $\tilde{f}_\sigma$ closer to $f$. The gradient of this objective is $\nabla_\theta \tilde{f}_\sigma(\theta) = \frac{1}{\sigma} \mathbb{E}_{z \sim \mathcal{N}(0, I_d)} \left[ f(\theta + \sigma z) \, z \right]$, which can be estimated by Monte Carlo sampling using only evaluations of $f$. Refer to [ES-MAML](https://arxiv.org/pdf/1910.01215.pdf) for more details.
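To make the estimator concrete, here is a minimal sketch in plain PyTorch (not TorchOpt's implementation) that approximates the expectation above with Monte Carlo draws. The helper `smoothed_grad` and the objective `f` are hypothetical: `f` is piecewise constant because of `torch.round`, so autograd returns a zero gradient, while the ES estimate still carries a useful signal. With $\sigma = 1$ the smoothing spans several unit-width steps, and the estimate should come out close to $2\theta$.

```python
import torch


def smoothed_grad(f, theta, sigma, num_samples, seed=0):
    # Monte Carlo estimate of (1 / sigma) * E_z[f(theta + sigma * z) * z], z ~ N(0, I_d)
    gen = torch.Generator().manual_seed(seed)
    z = torch.randn(num_samples, theta.numel(), generator=gen)
    values = f(theta + sigma * z)  # one function value per noise sample, shape (num_samples,)
    return (values.unsqueeze(1) * z).mean(dim=0) / sigma


def f(x):
    # Hypothetical piecewise-constant objective: sum of squared rounded coordinates
    return (torch.round(x) ** 2).sum(dim=-1)


theta = torch.tensor([2.0, 3.0], requires_grad=True)
f(theta).backward()
autograd_grad = theta.grad.clone()  # round() has zero derivative, so this is all zeros

with torch.no_grad():
    es_grad = smoothed_grad(f, theta.detach(), sigma=1.0, num_samples=100_000)

print(autograd_grad)  # tensor([0., 0.])
print(es_grad)        # approximately 2 * theta = [4, 6]
```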

In this tutorial, we introduce how TorchOpt can be used for ES-based differentiation.

```python
# Copyright 2022 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt
```

## 1. Functional API

The basic functional API is `torchopt.diff.zero_order.zero_order`, which is used as a decorator on the forward function whose gradient should be estimated with zero-order methods. Users need to provide a noise sampling distribution, which is passed to the `zero_order` decorator. The parameters of the decorator are:
* `distribution`: the distribution used to sample the noise.
* `sigma`: the smoothing parameter $\sigma$, which controls the precision.
* `method`: the estimation algorithm. We support `'naive'` ([ES-RL](https://arxiv.org/abs/1703.03864)), `'forward'` ([Forward-FD](http://proceedings.mlr.press/v80/choromanski18a/choromanski18a.pdf)), and `'antithetic'` ([antithetic](https://d1wqtxts1xzle7.cloudfront.net/75609515/coredp2011_1web-with-cover-page-v2.pdf?Expires=1670215467&Signature=RfP~mQhhhI7aGknwXbRBgSggFrKuNTPYdyUSdMmfTxOa62QoOJAm-Xhr3F1PLyjUQc2JVxmKIKGGuyYvyfCTpB31dfmMtuVQxZMWVF-SfErTN05SliC93yjA1x1g2kjhn8bkBFdQqGl~1RQSKnhj88BakgSeDNzyCxwbD5VgR89BXRs4YIK5RBIKYtgLhoyz5jar7wHS3TJhRzs3WNeTIAjAmLqJ068oGFZ0Jr7maGquTe3w~8LEEIprJ6cyCMc6b1UUJkmwjNq0RLTVbxgFjfi4Z9kyxyJB9IOS1J25OOON4jfwh5JlXS7MVskuONUyHJim1TQ8OwCraKlBsQLPQw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA)).
* `argnums`: the index of the positional argument(s) with respect to which the (meta-)gradient is computed.
* `num_samples`: the number of noise samples drawn to estimate the gradient.
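The three `method` options differ in how each noise sample is weighted. The sketch below writes out the standard forms of these estimators as we understand them from the cited papers: naive $\frac{1}{\sigma} f(\theta+\sigma z)\, z$, forward difference $\frac{1}{\sigma}\left(f(\theta+\sigma z) - f(\theta)\right) z$, and antithetic $\frac{1}{2\sigma}\left(f(\theta+\sigma z) - f(\theta-\sigma z)\right) z$, each averaged over `num_samples` draws. It is a plain-PyTorch illustration with a hypothetical helper `zero_order_grad`, not TorchOpt's internal code. On the quadratic $f(\theta) = \lVert\theta\rVert^2$, whose smoothed gradient is exactly $2\theta$ for any $\sigma$, the baseline subtraction in `'forward'` and the symmetric pair in `'antithetic'` typically give much smaller errors than `'naive'` for the same number of samples.

```python
import torch


def zero_order_grad(f, theta, method, sigma=0.1, num_samples=100_000, seed=0):
    """Sketch of three zero-order gradient estimators (illustrative, not TorchOpt's code)."""
    gen = torch.Generator().manual_seed(seed)
    z = torch.randn(num_samples, theta.numel(), generator=gen)
    if method == "naive":         # f(theta + sigma z) / sigma
        weights = f(theta + sigma * z) / sigma
    elif method == "forward":     # (f(theta + sigma z) - f(theta)) / sigma
        weights = (f(theta + sigma * z) - f(theta)) / sigma
    elif method == "antithetic":  # (f(theta + sigma z) - f(theta - sigma z)) / (2 sigma)
        weights = (f(theta + sigma * z) - f(theta - sigma * z)) / (2 * sigma)
    else:
        raise ValueError(f"unknown method: {method!r}")
    # Weight each noise vector and average over samples
    return (weights.unsqueeze(1) * z).mean(dim=0)


def f(x):
    # Quadratic objective ||x||^2, evaluated along the last dimension
    return (x ** 2).sum(dim=-1)


theta = torch.tensor([1.0, -2.0, 0.5])
true_grad = 2 * theta
estimates = {m: zero_order_grad(f, theta, m) for m in ("naive", "forward", "antithetic")}
for m, g in estimates.items():
    print(f"{m:>10}: {g.tolist()}  error = {(g - true_grad).norm().item():.3f}")
```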

The pseudocode is shown below.

```python
# Functional API for zero-order gradient
# Customize the noise distribution
class Distribution:
    def sample(self, sample_shape=torch.Size()):
        # Sampling function for noise
        return noise

distribution = Distribution()

# distribution can also be a torch.distributions instance, e.g., torch.distributions.Normal(loc=0.0, scale=1.0)

# Decorator that wraps the function
@torchopt.diff.zero_order.zero_order(distribution=distribution, sigma=0.01, method='naive', argnums=1, num_samples=100)
def solve(params, meta_params, data):
    # Forward optimization process for params
    optimal_params = ...
    # The decorated function should return a scalar objective
    return outer_loss(optimal_params)

# Define params, meta params and get data
params, meta_params, data = ..., ..., ...
loss = solve(params, meta_params, data)

meta_grads = torch.autograd.grad(loss, meta_params)
```
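The decorator hides how a sampled gradient is plugged into autograd. As a hypothetical illustration only, a custom `torch.autograd.Function` can evaluate the function in its forward pass and return an ES estimate in its backward pass, so that `backward()` or `torch.autograd.grad` work as in the pseudocode. The names `AntitheticES` and `es_decorator` are made up; the sketch supports a single tensor argument, fixes the noise to a standard normal, and uses only the antithetic estimator, whereas TorchOpt's decorator also handles `argnums`, custom distributions, and the other methods.

```python
import torch


class AntitheticES(torch.autograd.Function):
    """Hypothetical sketch: forward calls fn, backward returns an antithetic ES estimate."""

    @staticmethod
    def forward(ctx, theta, fn, sigma, num_samples):
        ctx.save_for_backward(theta)
        ctx.fn, ctx.sigma, ctx.num_samples = fn, sigma, num_samples
        return fn(theta)  # fn must return a scalar tensor

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        sigma = ctx.sigma
        grad = torch.zeros_like(theta)
        with torch.no_grad():
            for _ in range(ctx.num_samples):
                z = torch.randn_like(theta)
                diff = ctx.fn(theta + sigma * z) - ctx.fn(theta - sigma * z)
                grad += diff / (2 * sigma) * z
        grad /= ctx.num_samples
        # Chain rule with the upstream gradient; fn, sigma, num_samples get no gradient
        return grad_output * grad, None, None, None


def es_decorator(sigma=0.01, num_samples=100):
    """Hypothetical decorator mirroring the pseudocode, for a single tensor argument."""
    def decorator(fn):
        def wrapped(theta):
            return AntitheticES.apply(theta, fn, sigma, num_samples)
        return wrapped
    return decorator


torch.manual_seed(0)
target = torch.tensor([1.0, -2.0, 0.5])


@es_decorator(sigma=0.1, num_samples=5000)
def objective(theta):
    return ((theta - target) ** 2).sum()


theta = torch.zeros(3, requires_grad=True)
loss = objective(theta)
loss.backward()
print(loss.item())  # 5.25
print(theta.grad)   # approximately 2 * (theta - target) = [-2, 4, -1]
```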

Here we use a linear layer as an example. Note that this is only a toy demonstration that a linear layer can be trained with ES-based gradients.

```python
import functorch
import torch

import torchopt

torch.random.manual_seed(0)

a = torch.nn.Linear(32, 1)
fmodel, params = functorch.make_functional(a)
x = torch.randn(64, 32) * 0.1
y = torch.randn(64) * 0.1
distribution = torch.distributions.normal.Normal(loc=0, scale=1)


@torchopt.diff.zero_order.zero_order(distribution=distribution, sigma=0.01, method='forward', argnums=0, num_samples=1000)
def forward_process(params, f, x, y):
    y_pred = f(params, x)
    loss = torch.mean((y - y_pred) ** 2)
    return loss


optimizer = torchopt.adam(lr=0.01)
opt_state = optimizer.init(params)

for i in range(20):
    opt_state = optimizer.init(params)                       # re-initialize optimizer state (resets Adam moments)
    loss = forward_process(params, fmodel, x, y)             # compute loss

    grads = torch.autograd.grad(loss, params)                # compute zero-order gradient estimates
    updates, opt_state = optimizer.update(grads, opt_state)  # get updates
    params = torchopt.apply_updates(params, updates)         # update network parameters

    print(loss)
```

```
tensor(0.0269, grad_fn=<ZeroOrderBackward>)
tensor(0.0246, grad_fn=<ZeroOrderBackward>)
tensor(0.0225, grad_fn=<ZeroOrderBackward>)
tensor(0.0205, grad_fn=<ZeroOrderBackward>)
tensor(0.0187, grad_fn=<ZeroOrderBackward>)
tensor(0.0171, grad_fn=<ZeroOrderBackward>)
tensor(0.0156, grad_fn=<ZeroOrderBackward>)
tensor(0.0144, grad_fn=<ZeroOrderBackward>)
tensor(0.0134, grad_fn=<ZeroOrderBackward>)
tensor(0.0128, grad_fn=<ZeroOrderBackward>)
tensor(0.0122, grad_fn=<ZeroOrderBackward>)
tensor(0.0118, grad_fn=<ZeroOrderBackward>)
tensor(0.0120, grad_fn=<ZeroOrderBackward>)
tensor(0.0117, grad_fn=<ZeroOrderBackward>)
tensor(0.0117, grad_fn=<ZeroOrderBackward>)
tensor(0.0118, grad_fn=<ZeroOrderBackward>)
tensor(0.0121, grad_fn=<ZeroOrderBackward>)
tensor(0.0117, grad_fn=<ZeroOrderBackward>)
tensor(0.0118, grad_fn=<ZeroOrderBackward>)
tensor(0.0118, grad_fn=<ZeroOrderBackward>)
```
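For comparison, the sketch below reproduces the spirit of this experiment in plain PyTorch, without `functorch` or TorchOpt. The linear layer is written as a flat parameter vector `theta` (32 weights followed by a bias, a hypothetical layout), gradients come from a forward-difference estimate with `sigma=0.01` and 1000 samples per step, and plain SGD replaces Adam for brevity, so the printed values will not match the losses above. The prediction is kept at shape `(64,)` to match `y`; `torch.nn.Linear(32, 1)` returns shape `(64, 1)`, and subtracting that from a `(64,)` target broadcasts to a `(64, 64)` tensor.

```python
import torch

torch.manual_seed(0)
x = torch.randn(64, 32) * 0.1
y = torch.randn(64) * 0.1
theta = torch.randn(33) * 0.1  # hypothetical flat layout: 32 weights followed by 1 bias


def loss_fn(theta):
    # theta: (33,) or (n, 33) -> scalar loss or (n,) losses
    w, b = theta[..., :32], theta[..., 32]
    y_pred = w @ x.T + b.unsqueeze(-1)  # (64,) or (n, 64), aligned with y
    return ((y_pred - y) ** 2).mean(dim=-1)


def forward_fd_grad(f, theta, sigma=0.01, num_samples=1000):
    # Forward-difference ES estimate: mean over z of (f(theta + sigma z) - f(theta)) z / sigma
    z = torch.randn(num_samples, theta.numel())
    diffs = f(theta + sigma * z) - f(theta)  # (num_samples,)
    return (diffs.unsqueeze(1) * z).mean(dim=0) / sigma


lr = 0.1
initial_loss = loss_fn(theta).item()
for step in range(200):
    theta = theta - lr * forward_fd_grad(loss_fn, theta)  # plain SGD step
final_loss = loss_fn(theta).item()
print(f"initial loss {initial_loss:.4f}, final loss {final_loss:.4f}")
```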
