torchopt
The core design of TorchOpt follows the philosophy of functional programming. Aligned with functorch, users can conduct functional-style programming with models, optimizers, and training in PyTorch. We first introduce our functional optimizers, which treat the optimization process as a functional transformation.
Currently, TorchOpt supports functional optimizers including sgd, adam, rmsprop, and adamw.
Related APIs: torchopt.FuncOptimizer, torchopt.adadelta, torchopt.adagrad, torchopt.adam, torchopt.adamw, torchopt.adamax, torchopt.radam, torchopt.rmsprop, torchopt.sgd
TorchOpt offers a functional API: gradients and the optimizer state are passed to the optimizer function to compute updates, which are then applied to the parameters.
Related API: torchopt.apply_updates
Here is an example of functional optimization coupled with functorch:
import functorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchopt
from torch.utils.data import DataLoader

class Net(nn.Module): ...
class Loader(DataLoader): ...

net = Net()  # init network
loader = Loader()
optimizer = torchopt.adam(lr=1e-3)  # functional Adam (learning rate value is illustrative)
model, params = functorch.make_functional(net)  # use functorch to extract network parameters
opt_state = optimizer.init(params)  # init optimizer state
xs, ys = next(iter(loader))  # get a batch of data
pred = model(params, xs)  # forward
loss = F.cross_entropy(pred, ys)  # compute loss
grads = torch.autograd.grad(loss, params)  # compute gradients
updates, opt_state = optimizer.update(grads, opt_state)  # compute updates and new optimizer state
params = torchopt.apply_updates(params, updates)  # apply updates to network parameters
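Because the optimizer state is explicit, running multiple training steps simply means threading params and opt_state through a loop. A minimal sketch, reusing the names from the example above:
for xs, ys in loader:  # iterate over the data
    pred = model(params, xs)  # forward
    loss = F.cross_entropy(pred, ys)  # compute loss
    grads = torch.autograd.grad(loss, params)  # compute gradients
    updates, opt_state = optimizer.update(grads, opt_state)  # compute updates and new optimizer state
    params = torchopt.apply_updates(params, updates)  # apply updates to network parameters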
We also provide a wrapper, torchopt.FuncOptimizer, to make maintaining the optimizer state easier:
net = Net()  # init network
loader = Loader()
optimizer = torchopt.FuncOptimizer(torchopt.adam())  # wrap with `torchopt.FuncOptimizer`
model, params = functorch.make_functional(net)  # use functorch to extract network parameters

for xs, ys in loader:  # iterate over the data
    pred = model(params, xs)  # forward
    loss = F.cross_entropy(pred, ys)  # compute loss
    params = optimizer.step(loss, params)  # update network parameters
Combined with the functional optimizers above, we can define our classic OOP optimizers. We designed a base class torchopt.Optimizer that has the same interface as torch.optim.Optimizer. We offer the original PyTorch APIs (e.g., zero_grad() or step()) for traditional PyTorch-like (OOP) parameter updates.
Related APIs: torchopt.Optimizer, torchopt.AdaDelta, torchopt.Adadelta, torchopt.AdaGrad, torchopt.Adagrad, torchopt.Adam, torchopt.AdamW, torchopt.AdaMax, torchopt.Adamax, torchopt.RAdam, torchopt.RMSProp, torchopt.SGD
By combining the low-level torchopt.Optimizer API with the functional optimizers above, we can recover the high-level API:
learning_rate = 1.0
# High-level API
optim = torchopt.Adam(net.parameters(), lr=learning_rate)
# which can be achieved by low-level API:
optim = torchopt.Optimizer(net.parameters(), torchopt.adam(lr=learning_rate))
Here is an example of PyTorch-like APIs:
net = Net() # init
loader = Loader()
optimizer = torchopt.Adam(net.parameters())
xs, ys = next(iter(loader)) # get a batch of data
pred = net(xs) # forward
loss = F.cross_entropy(pred, ys) # compute loss
optimizer.zero_grad() # zero gradients
loss.backward() # backward
optimizer.step() # step updates
Users often need to conduct multiple gradient transformations (functions) before the final parameter update. In the design of TorchOpt, these transformations are composed with torchopt.chain. For example, we can build our own chain like torchopt.chain(torchopt.clip_grad_norm(max_norm=1.), torchopt.sgd(lr=1., moment_requires_grad=True)) to clip the gradients and then update the parameters using sgd.
Related API: torchopt.chain
Note
torchopt.chain applies the transformations sequentially, so the order matters. For example, gradient clipping must be conducted before the optimizer step, so the order should be (clip, sgd) in the torchopt.chain call.
Here is an example of chaining torchopt.clip_grad_norm and torchopt.adam for a functional optimizer and an OOP optimizer.
func_optimizer = torchopt.chain(torchopt.clip_grad_norm(max_norm=2.0), torchopt.adam(1e-1))
oop_optimizer = torchopt.Optimizer(net.parameters(), func_optimizer)
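The chained transformation behaves like any other functional optimizer, so it can also be used directly with the functional API. A minimal sketch, reusing the params and grads names from the functional example above:
opt_state = func_optimizer.init(params)  # init state for the whole chain
updates, opt_state = func_optimizer.update(grads, opt_state)  # clip the gradients, then compute Adam updates
params = torchopt.apply_updates(params, updates)  # apply updates to the parameters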
Users can also register optimizer hooks to control the gradient flow.
Related APIs: torchopt.hook.register_hook, torchopt.hook.zero_nan_hook, torchopt.hook.nan_to_num_hook
For example, torchopt.hook.zero_nan_hook registers a hook on the first-order gradients. During backpropagation, NaN gradients are set to 0. Here is an example of such an operation coupled with torchopt.chain.
impl = torchopt.chain(torchopt.hook.register_hook(torchopt.hook.zero_nan_hook), torchopt.adam(1e-1))
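Conceptually, such a hook is just a function applied to each gradient tensor. A minimal sketch of a hand-written NaN-zeroing hook (a hypothetical my_zero_nan_hook, not the library implementation), passed to torchopt.hook.register_hook in the same way:
import torch

def my_zero_nan_hook(g: torch.Tensor) -> torch.Tensor:
    # Replace NaN entries of a gradient tensor with zeros.
    return torch.where(torch.isnan(g), torch.zeros_like(g), g)

impl = torchopt.chain(torchopt.hook.register_hook(my_zero_nan_hook), torchopt.adam(1e-1))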
TorchOpt also provides learning rate schedulers, which can be used to control the learning rate during training. Currently, it mainly offers the linear and polynomial learning rate schedulers.
Related APIs: torchopt.schedule.linear_schedule, torchopt.schedule.polynomial_schedule
Here is an example of combining an optimizer with a learning rate scheduler.
functional_adam = torchopt.adam(
    lr=torchopt.schedule.linear_schedule(
        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
    )
)

adam = torchopt.Adam(
    net.parameters(),
    lr=torchopt.schedule.linear_schedule(
        init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
    ),
)
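The schedule itself is a callable that maps a step count to a learning rate. A minimal sketch of how the schedule above is expected to behave, assuming it follows the usual Optax-style semantics (constant before transition_begin, then linear interpolation over transition_steps); the commented values are approximate expectations, not verified outputs:
schedule = torchopt.schedule.linear_schedule(
    init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000
)
print(schedule(0))      # ~1e-3   (before transition_begin, the learning rate stays at init_value)
print(schedule(7000))   # ~5.5e-4 (halfway through the 10000-step transition)
print(schedule(20000))  # ~1e-4   (after the transition, the learning rate stays at end_value)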
Check the notebook tutorial at Functional Optimizer.