# TorchOpt as Functional Optimizer

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/drive/1yfi-ETyIptlIM7WFYWF_IFhX4WF3LldP?usp=sharing)

In this tutorial, we show how TorchOpt can be used as a functional optimizer to carry out ordinary optimization in a functional programming style. We also illustrate how to perform differentiable optimization with functional programming in PyTorch.

## 1. Basic API

In this first part, we illustrate how TorchOpt can be used as a functional optimizer. We compare it with the corresponding APIs in [JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org) to highlight the similarities and differences. Every example uses a simple network, the Adam optimizer, and an MSE loss objective.

In [1]:
from collections import OrderedDict

import functorch
import jax
import jax.numpy as jnp
import optax
import torch
import torch.autograd
import torch.nn as nn

import torchopt


class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=True)
        nn.init.ones_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        return self.fc(x)


def mse(inputs, targets):
    return ((inputs - targets) ** 2).mean()

### 1.1 Original JAX implementation

The first example is a JAX implementation coupled with [Optax](https://github.com/deepmind/optax), which follows a functional programming style.

In [2]:
def origin_jax():
    batch_size = 1
    dim = 1
    params = OrderedDict([('weight', jnp.ones((dim, 1))), ('bias', jnp.zeros((1,)))])

    def model(params, x):
        return jnp.matmul(x, params['weight']) + params['bias']

    # Obtain the `opt_state` that contains statistics for the optimizer
    learning_rate = 1.
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)

    def compute_loss(params, x, y):
        pred = model(params, x)
        return mse(pred, y)

    xs = 2 * jnp.ones((batch_size, dim))
    ys = jnp.ones((batch_size, 1))

    grads = jax.grad(compute_loss)(params, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state)

    print('Parameters before update:', params)
    params = optax.apply_updates(params, updates)
    print('Parameters after update:', params)

In [3]:
origin_jax()

Parameters before update: {
    'weight': DeviceArray([[1.]], dtype=float32),
    'bias': DeviceArray([0.], dtype=float32)
}
Parameters after update: {
    'weight': DeviceArray([[6.735325e-06]], dtype=float32),
    'bias': DeviceArray([-0.99999326], dtype=float32)
}
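
The printed parameters above move by almost exactly one learning-rate step (weight 1 to about 0, bias 0 to about -1). A minimal pure-Python sketch of the first Adam step, using the textbook update rule with the usual default hyperparameters, suggests why: after bias correction the first update is roughly `-lr * sign(grad)`. The helper name `adam_first_step` is hypothetical and this is not Optax's or TorchOpt's implementation. It also does not try to reproduce the tiny floating-point residuals in the printed output.

```python
import math


def adam_first_step(grad, lr=1.0, b1=0.9, b2=0.999, eps=1e-8):
    """Return the textbook Adam update for step t = 1, starting from zero moments."""
    m = (1 - b1) * grad              # first moment after one step
    v = (1 - b2) * grad ** 2         # second moment after one step
    m_hat = m / (1 - b1 ** 1)        # bias correction recovers grad
    v_hat = v / (1 - b2 ** 1)        # bias correction recovers grad ** 2
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


# Same setup as the tutorial: weight = 1, bias = 0, x = 2, y = 1, MSE loss.
weight, bias, x, y = 1.0, 0.0, 2.0, 1.0
pred = weight * x + bias
grad_weight = 2.0 * (pred - y) * x   # d(loss)/d(weight) = 4
grad_bias = 2.0 * (pred - y)         # d(loss)/d(bias) = 2

new_weight = weight + adam_first_step(grad_weight)
new_bias = bias + adam_first_step(grad_bias)
print(new_weight, new_bias)          # close to 0 and -1
```

Because the gradient magnitude cancels in the first step, both parameters shift by about `lr = 1` even though their gradients differ by a factor of two.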

### 1.2 `functorch` with TorchOpt

The second example is [`functorch`](https://pytorch.org/functorch) coupled with TorchOpt. It follows essentially the same structure as the JAX example.

In [4]:
def interact_with_functorch():
    batch_size = 1
    dim = 1
    net = Net(dim)
    model, params = functorch.make_functional(net)  # get the functional version of the model

    # Obtain the `opt_state` that contains statistics for the optimizer
    learning_rate = 1.
    optimizer = torchopt.adam(learning_rate)
    opt_state = optimizer.init(params)

    xs = 2 * torch.ones((batch_size, dim))
    ys = torch.ones((batch_size, 1))

    pred = model(params, xs)
    loss = mse(pred, ys)

    grads = torch.autograd.grad(loss, params)
    updates, opt_state = optimizer.update(grads, opt_state)
    
    print('Parameters before update:', params)
    params = torchopt.apply_updates(params, updates)
    print('Parameters after update:', params)

In [5]:
interact_with_functorch()

Parameters before update: (
    Parameter containing: tensor([[1.]], requires_grad=True),
    Parameter containing: tensor([0.], requires_grad=True)
)
Parameters after update: (
    Parameter containing: tensor([[0.]], requires_grad=True),
    Parameter containing: tensor([-1.], requires_grad=True)
)
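
The `init` / `update` / `apply_updates` pattern shared by the JAX and `functorch` examples can be sketched in plain Python. This is an illustrative toy under stated assumptions, not the code of either library: `toy_sgd_momentum`, `toy_apply_updates`, and `FunctionalOptimizer` are hypothetical names, and parameters are plain floats instead of tensors. The point is that the optimizer holds no hidden state. The state is passed in and returned explicitly, and the original parameters are left untouched.

```python
from typing import Callable, NamedTuple


class FunctionalOptimizer(NamedTuple):
    """A pair of pure functions, mirroring the init/update interface."""
    init: Callable
    update: Callable


def toy_sgd_momentum(lr, momentum=0.0):
    def init(params):
        # One momentum buffer per parameter, starting at zero.
        return tuple(0.0 for _ in params)

    def update(grads, state):
        # Return new values instead of mutating the inputs.
        new_state = tuple(momentum * s + g for s, g in zip(state, grads))
        updates = tuple(-lr * s for s in new_state)
        return updates, new_state

    return FunctionalOptimizer(init, update)


def toy_apply_updates(params, updates):
    return tuple(p + u for p, u in zip(params, updates))


params = (1.0, 0.0)                        # (weight, bias)
optimizer = toy_sgd_momentum(lr=0.1, momentum=0.9)
opt_state = optimizer.init(params)

grads = (4.0, 2.0)                         # gradients for x = 2, y = 1
updates, opt_state = optimizer.update(grads, opt_state)
new_params = toy_apply_updates(params, updates)

print(params)      # the old parameters are unchanged
print(new_params)  # roughly (0.6, -0.2)
```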

### 1.3 Full TorchOpt

The third example shows that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference between the functional `torchopt.adam()` and the class-based `torchopt.Adam()`.

In [6]:
def full_torchopt():
    batch_size = 1
    dim = 1
    net = Net(dim)

    learning_rate = 1.
    optim = torchopt.Adam(net.parameters(), lr=learning_rate)

    xs = 2 * torch.ones((batch_size, dim))
    ys = torch.ones((batch_size, 1))

    pred = net(xs)
    loss = mse(pred, ys)

    print('Parameters before update:', dict(net.named_parameters()))
    optim.zero_grad()
    loss.backward()
    optim.step()
    print('Parameters after update:', dict(net.named_parameters()))

In [7]:
full_torchopt()

Parameters before update: {
    'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),
    'fc.bias': Parameter containing: tensor([0.], requires_grad=True)
}
Parameters after update: {
    'fc.weight': Parameter containing: tensor([[0.]], requires_grad=True),
    'fc.bias': Parameter containing: tensor([-1.], requires_grad=True)
}
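
One way to picture the relation between a lowercase functional optimizer and a capitalized `torch.optim`-style class is a thin stateful wrapper around the functional pieces. The sketch below is an assumption-laden toy in plain Python: `ToyParam`, `toy_sgd`, and `ToySGD` are hypothetical names, and it does not claim to mirror how TorchOpt is implemented internally.

```python
class ToyParam:
    """Stand-in for a parameter that carries a value and a gradient."""

    def __init__(self, value):
        self.value = value
        self.grad = None


def toy_sgd(lr):
    """Functional form: pure init/update functions with explicit state."""

    def init(params):
        return ()  # plain SGD keeps no statistics

    def update(grads, state):
        return [-lr * g for g in grads], state

    return init, update


class ToySGD:
    """Class form: owns the parameters and the state, and updates in place."""

    def __init__(self, params, lr):
        self.params = list(params)
        self._init, self._update = toy_sgd(lr)
        self.state = self._init(self.params)

    def zero_grad(self):
        for p in self.params:
            p.grad = None

    def step(self):
        grads = [p.grad for p in self.params]
        updates, self.state = self._update(grads, self.state)
        for p, u in zip(self.params, updates):
            p.value += u  # in-place update, as in torch.optim-style usage


weight, bias = ToyParam(1.0), ToyParam(0.0)
optim = ToySGD([weight, bias], lr=0.1)
optim.zero_grad()
weight.grad, bias.grad = 4.0, 2.0  # gradients for x = 2, y = 1
optim.step()
print(weight.value, bias.value)    # roughly 0.6 and -0.2
```

The class form is convenient for ordinary training loops, while the functional form keeps every intermediate value explicit.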

### 1.4 Original PyTorch

The final example is the original PyTorch implementation with `torch.optim`.

In [8]:
def origin_torch():
    batch_size = 1
    dim = 1
    net = Net(dim)

    learning_rate = 1.
    optim = torch.optim.Adam(net.parameters(), lr=learning_rate)

    xs = 2 * torch.ones((batch_size, dim))
    ys = torch.ones((batch_size, 1))

    pred = net(xs)
    loss = mse(pred, ys)

    print('Parameters before update:', dict(net.named_parameters()))
    optim.zero_grad()
    loss.backward()
    optim.step()
    print('Parameters after update:', dict(net.named_parameters()))

In [9]:
origin_torch()

Parameters before update: {
    'fc.weight': Parameter containing: tensor([[1.]], requires_grad=True),
    'fc.bias': Parameter containing: tensor([0.], requires_grad=True)
}
Parameters after update: {
    'fc.weight': Parameter containing: tensor([[1.1921e-07]], requires_grad=True),
    'fc.bias': Parameter containing: tensor([-1.0000], requires_grad=True)
}

## 2. Differentiable Optimization with Functional Optimizer

Coupled with a functional optimizer, you can conduct differentiable optimization by setting the `inplace` flag to `False` in the `update` and `apply_updates` functions. This can be helpful when implementing meta-learning algorithms in a functional programming style.

Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta-Optimizer notebook for PyTorch-like differentiable optimizers.

In [10]:
def differentiable():
    batch_size = 1
    dim = 1
    net = Net(dim)
    model, params = functorch.make_functional(net)  # get the functional version of the model

    # Meta-parameter
    meta_param = nn.Parameter(torch.ones(1))

    # SGD example
    learning_rate = 1.
    optimizer = torchopt.sgd(learning_rate)
    opt_state = optimizer.init(params)

    xs = torch.ones((batch_size, dim))
    ys = torch.ones((batch_size, 1))

    pred = model(params, xs)
    # Where meta_param is used
    pred = pred + meta_param
    loss = mse(pred, ys)

    grads = torch.autograd.grad(loss, params, create_graph=True)
    updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)  # update parameters with single step SGD update

    pred = model(params, xs)
    loss = mse(pred, ys)
    loss.backward()

    print('Gradient for the meta-parameter:', meta_param.grad)

In [11]:
differentiable()

Gradient for the meta-parameter: tensor([32.])
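
The value 32 can be checked by hand. With weight 1, bias 0, x = 1, y = 1 and learning rate 1, one SGD step gives an updated prediction of `1 - 4 * meta_param`, so the outer loss is `16 * meta_param ** 2` and its derivative at `meta_param = 1` is 32. The pure-Python sketch below replays the same computation and confirms the value with a central finite difference. The helper names `outer_loss` and `central_difference` are hypothetical, and no autograd is involved.

```python
def outer_loss(meta_param, w=1.0, b=0.0, x=1.0, y=1.0, lr=1.0):
    """Loss after one inner SGD step, as a function of the meta-parameter."""
    # Inner loss uses the meta-parameter: pred = w * x + b + meta_param.
    pred = w * x + b + meta_param
    grad_w = 2.0 * (pred - y) * x
    grad_b = 2.0 * (pred - y)

    # Single SGD step on the model parameters.
    w_new = w - lr * grad_w
    b_new = b - lr * grad_b

    # Outer loss is evaluated without adding the meta-parameter again.
    pred_new = w_new * x + b_new
    return (pred_new - y) ** 2


def central_difference(f, at, h=1e-4):
    return (f(at + h) - f(at - h)) / (2.0 * h)


meta_grad = central_difference(outer_loss, 1.0)
print(meta_grad)  # approximately 32
```

The outer loss depends on the meta-parameter only through the updated parameters, which is why the update itself has to stay differentiable.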


### 2.1 Track the Gradient of Momentum

Note that most modern optimizers involve a momentum term in the gradient update (essentially only SGD with `momentum=0` does not). We provide an option for users to choose whether to also track the meta-gradient through the momentum term. The default option is `moment_requires_grad=True`.

In [12]:
optim = torchopt.adam(lr=1., moment_requires_grad=False)

In [13]:
optim = torchopt.adam(lr=1., moment_requires_grad=True)

In [14]:
optim = torchopt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)

## 3. Accelerated Optimizer

Users can enable the accelerated optimizer by setting `use_accelerated_op` to `True`. Currently, only the Adam optimizer is supported.

Check whether the `accelerated_op` is available:

In [15]:
torchopt.accelerated_op_available(torch.device('cpu'))

True


In [16]:
torchopt.accelerated_op_available(torch.device('cuda'))

True


In [17]:
net = Net(1).cuda()
optim = torchopt.Adam(net.parameters(), lr=1., use_accelerated_op=True)

In [18]:
optim = torchopt.adam(lr=1., use_accelerated_op=True)