# TorchOpt as Functional Optimizer

[<img align="left" src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/drive/1h005zH00arR5IgeSUjNETnP_r4A_-oMr?usp=sharing)

In this tutorial, we show how TorchOpt can be used as a functional optimizer to perform standard optimization in a functional programming style. We also illustrate how to perform differentiable optimization with functional programming in PyTorch.

## 1. Basic API

In this first part, we illustrate how TorchOpt can be used as a functional optimizer. We compare it with the corresponding APIs in JAX and PyTorch to highlight the similarities and differences. All examples use a simple network (a single bias-free linear layer with weights initialized to one), the Adam optimizer, and a squared-error (MSE) loss.

In [2]:
import torch
import functorch
import torch.autograd
import torch.nn as nn
import optax
import jax
from jax import numpy as jnp

import torchopt


class Net(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 1, bias=False)
        self.fc.weight.data = torch.ones_like(self.fc.weight.data)

    def forward(self, x):
        return self.fc(x)

- Original JAX implementation

The first example is a JAX implementation using Optax, written in a functional programming style.

In [2]:
def origin_jax():
    learning_rate = 1.
    batch_size = 1
    dim = 1
    optimizer = optax.adam(learning_rate)
    # Obtain the `opt_state` that contains statistics for the optimizer.
    params = {'w': jnp.ones((dim, 1))}
    opt_state = optimizer.init(params)

    def compute_loss(params, x, y):
        return ((jnp.matmul(x, params['w']) - y) ** 2).sum()

    xs = 2 * jnp.ones((batch_size, dim))
    ys = jnp.ones((batch_size, ))
    grads = jax.grad(compute_loss)(params, xs, ys)
    updates, opt_state = optimizer.update(grads, opt_state)
    print(params)
    params = optax.apply_updates(params, updates)
    print(params)

In [3]:
origin_jax()



{'w': DeviceArray([[1.]], dtype=float32)}
{'w': DeviceArray([[6.67572e-06]], dtype=float32)}


- Functorch with TorchOpt

The second example uses functorch together with TorchOpt. It follows the same structure as the JAX example.

In [4]:
def interact_with_functorch():
    batch_size = 1
    dim = 1
    net = Net(dim)
    func, params = functorch.make_functional(net)

    lr = 1.
    optimizer = torchopt.adam(lr)

    opt_state = optimizer.init(params)

    xs = 2 * torch.ones(batch_size, dim)
    ys = torch.ones(batch_size)

    pred = func(params, xs)
    loss = ((pred - ys) ** 2).sum()
    grad = torch.autograd.grad(loss, params)
    updates, opt_state = optimizer.update(grad, opt_state)
    print(params)
    params = torchopt.apply_updates(params, updates)
    print(params)

In [5]:
interact_with_functorch()

(Parameter containing:
tensor([[1.]], requires_grad=True),)
(Parameter containing:
tensor([[0.]], requires_grad=True),)
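Both examples above follow the same functional pattern: the optimizer is a pair of functions, where `init` builds an explicit state from the parameters and `update` maps gradients and state to updates plus a new state, and a separate `apply_updates` adds the updates to the parameters. The plain-Python sketch below mimics that pattern for Adam on scalar parameters, using the same data as above (w = 1, x = 2, y = 1). The names `adam`, `AdamState`, `FunctionalOptimizer`, `apply_updates` and `grad_fn` are hypothetical stand-ins written for this illustration, not the TorchOpt or Optax source.

```python
import math
from typing import Callable, NamedTuple


class AdamState(NamedTuple):
    """Explicit optimizer state that the caller carries between steps."""
    count: int
    mu: dict  # first-moment estimates, keyed like params
    nu: dict  # second-moment estimates, keyed like params


class FunctionalOptimizer(NamedTuple):
    init: Callable
    update: Callable


def adam(lr, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical functional Adam on dicts of floats (illustration only)."""

    def init(params):
        zeros = {k: 0.0 for k in params}
        return AdamState(count=0, mu=dict(zeros), nu=dict(zeros))

    def update(grads, state):
        count = state.count + 1
        mu = {k: b1 * state.mu[k] + (1 - b1) * g for k, g in grads.items()}
        nu = {k: b2 * state.nu[k] + (1 - b2) * g * g for k, g in grads.items()}
        updates = {}
        for k in grads:
            mu_hat = mu[k] / (1 - b1 ** count)  # bias correction
            nu_hat = nu[k] / (1 - b2 ** count)
            updates[k] = -lr * mu_hat / (math.sqrt(nu_hat) + eps)
        # Nothing is mutated: a fresh state is returned instead.
        return updates, AdamState(count, mu, nu)

    return FunctionalOptimizer(init, update)


def apply_updates(params, updates):
    """Return new parameters; the old dict is left untouched."""
    return {k: params[k] + updates[k] for k in params}


def grad_fn(params, x=2.0, y=1.0):
    # d/dw (w * x - y)^2 = 2 * (w * x - y) * x, the same loss as above
    return {'w': 2 * (params['w'] * x - y) * x}


optimizer = adam(lr=1.0)
params = {'w': 1.0}
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grad_fn(params), opt_state)
params = apply_updates(params, updates)
print(params)  # 'w' ends up very close to 0, like the outputs above
```

Because the state is an ordinary value rather than a hidden attribute, the caller decides when to keep, copy, or discard it, which is what makes this style compose well with function transforms.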


- Full TorchOpt

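The third example illustrates that TorchOpt can also directly replace `torch.optim` with exactly the same usage. Note the API difference between `torchopt.adam()` and `torchopt.Adam()`: the lowercase `torchopt.adam(lr)` returns a functional optimizer whose state is handled explicitly through `init` and `update`, while the capitalized `torchopt.Adam(net.parameters(), lr=lr)` holds its own state and is driven with `zero_grad()` and `step()`, like `torch.optim.Adam`.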

In [6]:
def full_torchopt():
    batch_size = 1
    dim = 1
    net = Net(dim)

    lr = 1.
    optim = torchopt.Adam(net.parameters(), lr=lr)

    xs = 2 * torch.ones(batch_size, dim)
    ys = torch.ones(batch_size)

    pred = net(xs)
    loss = ((pred - ys) ** 2).sum()

    print(net.fc.weight)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(net.fc.weight)

In [7]:
full_torchopt()

Parameter containing:
tensor([[1.]], requires_grad=True)
Parameter containing:
tensor([[0.]], requires_grad=True)
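The difference between the two APIs comes down to who owns the optimizer state. The sketch below shows one plausible way an object-oriented optimizer can be layered on top of a functional one: the object stores the state and mutates the parameters in `step()`, so the training loop reads like `torch.optim`. The helpers `sgd_init`, `sgd_update` and the class `StatefulSGD` are hypothetical and written for this illustration only; this is not how TorchOpt implements `torchopt.Adam`.

```python
import torch


def sgd_init(params, momentum):
    # Functional state: one momentum buffer per parameter (None if unused).
    return [torch.zeros_like(p) if momentum else None for p in params]


def sgd_update(grads, state, lr, momentum):
    # Pure function: returns updates and a new state, mutates nothing.
    updates, new_state = [], []
    for g, buf in zip(grads, state):
        if buf is not None:
            buf = momentum * buf + g
            g = buf
        new_state.append(buf)
        updates.append(-lr * g)
    return updates, new_state


class StatefulSGD:
    """Hypothetical object-oriented wrapper around the functional pair."""

    def __init__(self, params, lr, momentum=0.0):
        self.params = list(params)
        self.lr, self.momentum = lr, momentum
        self.state = sgd_init(self.params, momentum)  # owned by the object

    def zero_grad(self):
        for p in self.params:
            p.grad = None

    @torch.no_grad()
    def step(self):
        grads = [p.grad for p in self.params]
        updates, self.state = sgd_update(grads, self.state, self.lr, self.momentum)
        for p, u in zip(self.params, updates):
            p.add_(u)  # in-place parameter update, like torch.optim


w = torch.nn.Parameter(torch.ones(1, 1))
optim = StatefulSGD([w], lr=0.1, momentum=0.9)
x, y = 2 * torch.ones(1, 1), torch.ones(1)
for _ in range(2):
    loss = ((x @ w - y) ** 2).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
print(w)
```

With these numbers the first step moves w from 1.0 to 0.6 and the second, which reuses the stored momentum buffer, moves it to 0.16.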


- Original PyTorch

The final example is the original PyTorch implementation using `torch.optim`.

In [8]:
def origin_torch():
    batch_size = 1
    dim = 1
    net = Net(dim)

    lr = 1.
    optim = torch.optim.Adam(net.parameters(), lr=lr)

    xs = 2 * torch.ones(batch_size, dim)
    ys = torch.ones(batch_size)

    pred = net(xs)
    loss = ((pred - ys) ** 2).sum()

    print(net.fc.weight)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(net.fc.weight)

In [9]:
origin_torch()

Parameter containing:
tensor([[1.]], requires_grad=True)
Parameter containing:
tensor([[1.1921e-07]], requires_grad=True)
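All four implementations move $w$ from 1 to (almost) 0 in a single step. Working through the first Adam step by hand shows why, assuming the default hyperparameters $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ (the examples only set the learning rate $\eta = 1$) and zero-initialized moments, with $w_0 = 1$, $x = 2$, $y = 1$:

$$
\begin{aligned}
g &= \frac{\partial}{\partial w}(xw - y)^2 = 2x(xw_0 - y) = 2 \cdot 2 \cdot (2 - 1) = 4, \\
m_1 &= (1 - \beta_1)\,g, \qquad v_1 = (1 - \beta_2)\,g^2, \\
\hat m_1 &= \frac{m_1}{1 - \beta_1} = g, \qquad \hat v_1 = \frac{v_1}{1 - \beta_2} = g^2, \\
w_1 &= w_0 - \eta\,\frac{\hat m_1}{\sqrt{\hat v_1} + \epsilon} = 1 - \frac{4}{4 + 10^{-8}} \approx 0.
\end{aligned}
$$

In exact arithmetic the first Adam step has magnitude $\eta\,|g| / (|g| + \epsilon) \approx \eta$ regardless of the gradient's scale, so with $\eta = 1$ the weight lands essentially at 0. The small residuals in the printed outputs are most likely float32 rounding in the bias-correction terms rather than an algorithmic difference between the libraries.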


## 2. Differentiable Optimization with a functional optimizer
Combined with a functional optimizer, you can perform differentiable optimization by passing `inplace=False` to both `update` and `apply_updates`, and by computing the inner gradients with `torch.autograd.grad(..., create_graph=True)`, as in the example below. This can be helpful for implementing meta-learning algorithms in a functional programming style.

Note that `torchopt.SGD` and `torchopt.Adam` do not support differentiable optimization. Refer to the Meta Optimizer notebook for PyTorch-like differentiable optimizers.

In [28]:
def differentiable():
    batch_size = 1
    dim = 1
    net = Net(dim)
    func, params = functorch.make_functional(net)

    lr = 1.
    # sgd example
    optimizer = torchopt.sgd(lr)
    meta_param = torch.tensor(1., requires_grad=True)

    opt_state = optimizer.init(params)

    xs = torch.ones(batch_size, dim)
    ys = torch.ones(batch_size)

    pred = func(params, xs)
    # where meta_param is used
    pred = pred + meta_param
    loss = ((pred - ys) ** 2).sum()
    grad = torch.autograd.grad(loss, params, create_graph=True)
    updates, opt_state = optimizer.update(grad, opt_state, inplace=False)
    params = torchopt.apply_updates(params, updates, inplace=False)

    pred = func(params, xs)
    loss = ((pred - ys) ** 2).sum()
    loss.backward()

    print(meta_param.grad)

In [29]:
differentiable()

tensor(8.)
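The printed meta-gradient of 8 can be checked by hand. With $w_0 = x = y = 1$ and meta-parameter $\theta$ (`meta_param`), the inner loss is $(w_0 x + \theta - y)^2$, one SGD step with learning rate 1 gives $w_1 = w_0 - 2x(w_0 x + \theta - y) = 1 - 2\theta$, and the outer loss is $(w_1 x - y)^2 = 4\theta^2$, whose derivative at $\theta = 1$ is $8\theta = 8$. The plain-Python sketch below unrolls the same inner step with a hand-written gradient and confirms the value with a central finite difference. It is only a numerical check: TorchOpt obtains the meta-gradient by back-propagating through the graph kept by `create_graph=True` and `inplace=False`, not by finite differences. The helper names are hypothetical.

```python
def inner_sgd_step(w, theta, x=1.0, y=1.0, lr=1.0):
    # Inner loss: (w * x + theta - y) ** 2; its gradient w.r.t. w, written by hand.
    grad_w = 2 * (w * x + theta - y) * x
    return w - lr * grad_w  # one out-of-place SGD step


def outer_loss(theta, w0=1.0, x=1.0, y=1.0):
    w1 = inner_sgd_step(w0, theta, x, y)
    # The outer loss uses the updated weight; theta enters only through w1.
    return (w1 * x - y) ** 2


def meta_grad_fd(theta, h=1e-4):
    # Central finite difference of the outer loss w.r.t. the meta-parameter.
    return (outer_loss(theta + h) - outer_loss(theta - h)) / (2 * h)


print(meta_grad_fd(1.0))  # approximately 8.0
```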


## 2.1. Track the gradient through moment terms
Note that most modern optimizers include moment terms in their updates (essentially only SGD with `momentum=0` does not). We provide an option that lets users choose whether to also track the meta-gradient through the moment terms. The default option is `moment_requires_grad=True`.

In [22]:
optim = torchopt.adam(lr=1., moment_requires_grad=False)

In [23]:
optim = torchopt.adam(lr=1., moment_requires_grad=True)

In [27]:
optim = torchopt.sgd(lr=1., momentum=0.8, moment_requires_grad=True)
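To see what tracking the meta-gradient through a moment term means, the PyTorch-only sketch below unrolls two SGD-with-momentum steps by hand and differentiates an outer loss with respect to a meta-parameter. The flag `track_moment` is a hypothetical stand-in for the idea behind `moment_requires_grad`, not TorchOpt's implementation: when it is off, the momentum buffer is detached before it is reused, so the path from the meta-parameter through the accumulated moment is dropped from the meta-gradient.

```python
import torch


def unrolled_meta_grad(track_moment, lr=0.1, momentum=0.8, steps=2):
    """Meta-gradient of an outer loss after `steps` inner SGD-momentum steps.

    `track_moment` is a hypothetical flag for this sketch: when False the
    momentum buffer is detached before reuse, so no meta-gradient flows
    through the accumulated moment.
    """
    theta = torch.tensor(1.0, requires_grad=True)  # meta-parameter
    w = torch.tensor(1.0, requires_grad=True)      # inner parameter
    buf = torch.zeros(())                          # momentum buffer
    x, y = 1.0, 1.0
    for _ in range(steps):
        inner_loss = (w * x + theta - y) ** 2
        # create_graph=True keeps the gradient itself differentiable w.r.t. theta
        (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
        prev = buf if track_moment else buf.detach()
        buf = momentum * prev + g
        w = w - lr * buf  # out-of-place update keeps the graph intact
    outer_loss = (w * x - y) ** 2
    (meta_grad,) = torch.autograd.grad(outer_loss, theta)
    return meta_grad.item()


print(unrolled_meta_grad(track_moment=True))   # ≈ 0.5408
print(unrolled_meta_grad(track_moment=False))  # ≈ 0.3744
```

Both runs produce the same parameter values; only the meta-gradient changes, because detaching the buffer removes the first step's contribution to the second update from the derivative.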

## 3. Accelerated Optimizer
Users can enable the accelerated optimizer by setting `use_accelerated_op=True`. Currently, only the Adam optimizer is supported.

Check whether the accelerated op is available on a given device:

In [3]:
torchopt.accelerated_op_available(torch.device("cpu"))

True

In [4]:
torchopt.accelerated_op_available(torch.device("cuda"))

True

In [24]:
net = Net(1).cuda()
optim = torchopt.Adam(net.parameters(), lr=1., use_accelerated_op=True)

In [25]:
optim = torchopt.adam(lr=1., use_accelerated_op=True)