[Update] Support SVD method to calculate M^{-1/p} #103

Merged: 47 commits, Feb 5, 2023
Commits
c6c1c72
update: grafting
kozistr Feb 4, 2023
85e0cd2
refactor: is_precondition_step
kozistr Feb 4, 2023
ca9de76
feature: support SVD method to calculate M^{-1/p}
kozistr Feb 4, 2023
607f6ae
update: compute_power_svd
kozistr Feb 4, 2023
385be0a
refactor: compute_power_schur_newton
kozistr Feb 4, 2023
fc020c1
docs: compute_power_svd() docstring
kozistr Feb 4, 2023
e6595cd
feature: use_svd
kozistr Feb 4, 2023
e6f6e1a
update: test_compute_power
kozistr Feb 4, 2023
4f5829a
update: test_compute_power
kozistr Feb 4, 2023
b56b794
update: test_shampoo_pre_conditioner
kozistr Feb 4, 2023
010f8d1
docs: compute_power_svd docstring
kozistr Feb 4, 2023
49de081
update: compute_power_svd
kozistr Feb 4, 2023
c85eb1e
docs: compute_power_schur_newton, _compute_power_svd
kozistr Feb 4, 2023
31fa001
update: compute_pre_conditioners
kozistr Feb 4, 2023
b1b1dc3
fix: typo
kozistr Feb 4, 2023
466f6e7
update: Shampoo recipes
kozistr Feb 4, 2023
bfa2cf4
update: Shampoo recipes
kozistr Feb 4, 2023
06ae7b6
update: recipes
kozistr Feb 4, 2023
60ae998
update: recipes
kozistr Feb 4, 2023
d2c5316
update: recipes
kozistr Feb 4, 2023
ebf1240
feature: perform batch svd when the shapes of the pre-conditioners ar…
kozistr Feb 5, 2023
c70fed0
update: block_size to 512
kozistr Feb 5, 2023
a05ffd2
docs: Shampoo docstring
kozistr Feb 5, 2023
898a8ef
update: block size to 256
kozistr Feb 5, 2023
347bffb
update: optimizers
kozistr Feb 5, 2023
b2fca31
docs: Shampoo optimizer
kozistr Feb 5, 2023
5004def
update: shampoo recipe
kozistr Feb 5, 2023
42ebe2b
update: test_get_supported_optimizers
kozistr Feb 5, 2023
21765ee
feature: Shampoo optimizer
kozistr Feb 5, 2023
995bdd6
update: test_scalable_shampoo_optimizer
kozistr Feb 5, 2023
dc0b41e
update: test_update_frequency
kozistr Feb 5, 2023
cb16c18
update: test_bf16_gradient
kozistr Feb 5, 2023
c94e7ac
update: Shampoo recipe
kozistr Feb 5, 2023
001a077
update: NO_SPARSE_OPTIMIZERS
kozistr Feb 5, 2023
5dbc7b5
style: fix ERA001
kozistr Feb 5, 2023
c15e503
update: __name__ to __str__
kozistr Feb 5, 2023
0e94d1f
update: cases
kozistr Feb 5, 2023
02696a7
update: compute_pre_conditioners
kozistr Feb 5, 2023
a8a0c4b
update: default value of matrix_eps to 1e-6
kozistr Feb 5, 2023
0e879a4
update: copy to inv_pre_cond
kozistr Feb 5, 2023
ca2d5ae
update: Shampoo optimizer
kozistr Feb 5, 2023
10ce92e
docs: compute_pre_conditioners docstring
kozistr Feb 5, 2023
ceef415
update: power_iter
kozistr Feb 5, 2023
2b9221e
update: power_iter
kozistr Feb 5, 2023
24282a9
update: compute_power_schur_newton
kozistr Feb 5, 2023
e650bae
update: use_svd to False
kozistr Feb 5, 2023
01b5c5a
update: test_scalable_shampoo_pre_conditioner
kozistr Feb 5, 2023
8 changes: 8 additions & 0 deletions docs/optimizer_api.rst
@@ -193,6 +193,14 @@ Shampoo
.. autoclass:: pytorch_optimizer.Shampoo
:members:

.. _ScalableShampoo:

ScalableShampoo
---------------

.. autoclass:: pytorch_optimizer.ScalableShampoo
:members:

.. _GSAM:

GSAM
16 changes: 12 additions & 4 deletions docs/util_api.rst
@@ -154,12 +154,20 @@ matrix_power
.. autoclass:: pytorch_optimizer.matrix_power
:members:

.. _compute_power:
.. _compute_power_schur_newton:

compute_power
-------------
compute_power_schur_newton
--------------------------

.. autoclass:: pytorch_optimizer.compute_power
.. autoclass:: pytorch_optimizer.compute_power_schur_newton
:members:

.. _compute_power_svd:

compute_power_svd
-----------------

.. autoclass:: pytorch_optimizer.compute_power_svd
:members:

.. _merge_small_dims:
6 changes: 4 additions & 2 deletions pytorch_optimizer/__init__.py
@@ -41,7 +41,7 @@
from pytorch_optimizer.optimizer.ranger21 import Ranger21
from pytorch_optimizer.optimizer.sam import SAM
from pytorch_optimizer.optimizer.sgdp import SGDP
from pytorch_optimizer.optimizer.shampoo import Shampoo
from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
from pytorch_optimizer.optimizer.shampoo_utils import (
AdaGradGraft,
BlockPartitioner,
@@ -52,7 +52,8 @@
RMSPropGraft,
SGDGraft,
SQRTNGraft,
compute_power,
compute_power_schur_newton,
compute_power_svd,
matrix_power,
merge_small_dims,
power_iter,
@@ -86,6 +87,7 @@
Ranger21,
SGDP,
Shampoo,
ScalableShampoo,
DAdaptAdaGrad,
DAdaptAdam,
DAdaptSGD,
153 changes: 142 additions & 11 deletions pytorch_optimizer/optimizer/shampoo.py
@@ -13,12 +13,134 @@
RMSPropGraft,
SGDGraft,
SQRTNGraft,
compute_power_svd,
)


class Shampoo(Optimizer, BaseOptimizer):
r"""Preconditioned Stochastic Tensor Optimization.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
:param momentum: float. momentum.
:param weight_decay: float. weight decay (L2 penalty).
:param preconditioning_compute_steps: int. performance tuning parameter for controlling memory and compute
requirements. How often to compute the pre-conditioner.
:param matrix_eps: float. term added to the denominator to improve numerical stability.
"""

def __init__(
self,
params: PARAMETERS,
lr: float = 1e-3,
momentum: float = 0.0,
weight_decay: float = 0.0,
preconditioning_compute_steps: int = 1,
matrix_eps: float = 1e-6,
):
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.preconditioning_compute_steps = preconditioning_compute_steps
self.matrix_eps = matrix_eps

self.validate_parameters()

defaults: DEFAULTS = {
'lr': lr,
'momentum': momentum,
'weight_decay': weight_decay,
}
super().__init__(params, defaults)

def validate_parameters(self):
self.validate_learning_rate(self.lr)
self.validate_momentum(self.momentum)
self.validate_weight_decay(self.weight_decay)
self.validate_update_frequency(self.preconditioning_compute_steps)
self.validate_epsilon(self.matrix_eps)

@property
def __str__(self) -> str:
return 'Shampoo'

@torch.no_grad()
def reset(self):
for group in self.param_groups:
for p in group['params']:
state = self.state[p]

state['step'] = 0

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
momentum = group['momentum']
for p in group['params']:
if p.grad is None:
continue

grad = p.grad
if grad.is_sparse:
raise NoSparseGradientError(self.__str__)

state = self.state[p]
if len(state) == 0:
state['step'] = 0

if momentum > 0.0:
state['momentum_buffer'] = grad.clone()

for dim_id, dim in enumerate(grad.size()):
state[f'pre_cond_{dim_id}'] = self.matrix_eps * torch.eye(dim, out=grad.new(dim, dim))
state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()

state['step'] += 1

if momentum > 0.0:
grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)

if group['weight_decay'] > 0.0:
grad.add_(p, alpha=group['weight_decay'])

order: int = grad.ndimension()
original_size: int = grad.size()
for dim_id, dim in enumerate(grad.size()):
pre_cond = state[f'pre_cond_{dim_id}']
inv_pre_cond = state[f'inv_pre_cond_{dim_id}']

grad = grad.transpose_(0, dim_id).contiguous()
transposed_size = grad.size()

grad = grad.view(dim, -1)

grad_t = grad.t()
pre_cond.add_(grad @ grad_t)
if state['step'] % self.preconditioning_compute_steps == 0:
inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))

if dim_id == order - 1:
grad = grad_t @ inv_pre_cond
grad = grad.view(original_size)
else:
grad = inv_pre_cond @ grad
grad = grad.view(transposed_size)

state['momentum_buffer'] = grad

p.add_(grad, alpha=-group['lr'])

return loss


class ScalableShampoo(Optimizer, BaseOptimizer):
r"""Scalable Preconditioned Stochastic Tensor Optimization.

Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -45,6 +167,10 @@ class Shampoo(Optimizer, BaseOptimizer):
:param nesterov: bool. Nesterov momentum.
:param diagonal_eps: float. term added to the denominator to improve numerical stability.
:param matrix_eps: float. term added to the denominator to improve numerical stability.
:param use_svd: bool. use SVD instead of the Schur-Newton method to calculate M^{-1/p}.
    Theoretically, the Schur-Newton method is faster than SVD at calculating M^{-1/p};
    in practice, however, its loop-heavy implementation makes the SVD method much faster.
    see https://github.com/kozistr/pytorch_optimizer/pull/103
"""

def __init__(
@@ -60,14 +186,15 @@ def __init__(
start_preconditioning_step: int = 5,
preconditioning_compute_steps: int = 1,
statistics_compute_steps: int = 1,
block_size: int = 128,
block_size: int = 256,
no_preconditioning_for_layers_with_dim_gt: int = 8192,
shape_interpretation: bool = True,
graft_type: int = LayerWiseGrafting.SGD,
pre_conditioner_type: int = PreConditionerType.ALL,
nesterov: bool = True,
diagonal_eps: float = 1e-10,
matrix_eps: float = 1e-6,
use_svd: bool = False,
):
self.lr = lr
self.betas = betas
@@ -87,6 +214,7 @@ def __init__(
self.nesterov = nesterov
self.diagonal_eps = diagonal_eps
self.matrix_eps = matrix_eps
self.use_svd = use_svd

self.validate_parameters()

@@ -109,7 +237,7 @@ def validate_parameters(self):

@property
def __str__(self) -> str:
return 'Shampoo'
return 'ScalableShampoo'

@torch.no_grad()
def reset(self):
@@ -128,6 +256,7 @@ def reset(self):
self.shape_interpretation,
self.matrix_eps,
self.pre_conditioner_type,
self.use_svd,
)
if self.graft_type == LayerWiseGrafting.ADAGRAD:
state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -140,6 +269,9 @@ def reset(self):
else:
state['graft'] = Graft(p)

def is_precondition_step(self, step: int) -> bool:
return step >= self.start_preconditioning_step

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
@@ -170,6 +302,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
self.shape_interpretation,
self.matrix_eps,
self.pre_conditioner_type,
self.use_svd,
)
if self.graft_type == LayerWiseGrafting.ADAGRAD:
state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -185,27 +318,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
state['step'] += 1
pre_conditioner, graft = state['pre_conditioner'], state['graft']

# gather statistics, compute pre-conditioners
is_precondition_step: bool = self.is_precondition_step(state['step'])

graft.add_statistics(grad, beta2)
if state['step'] % self.statistics_compute_steps == 0:
pre_conditioner.add_statistics(grad)
if state['step'] % self.preconditioning_compute_steps == 0:
pre_conditioner.compute_pre_conditioners()

# pre-condition gradients
pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
shampoo_grad: torch.Tensor = grad
if state['step'] >= self.start_preconditioning_step:
if is_precondition_step:
shampoo_grad = pre_conditioner.preconditioned_grad(grad)

# grafting
graft_norm = torch.norm(graft_grad)
shampoo_norm = torch.norm(shampoo_grad)
if self.graft_type != LayerWiseGrafting.NONE:
graft_norm = torch.norm(graft_grad)
shampoo_norm = torch.norm(shampoo_grad)

shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))

# apply weight decay (adam style)
if group['weight_decay'] > 0.0:
if not self.decoupled_weight_decay:
shampoo_grad.add_(p, alpha=group['weight_decay'])
@@ -214,11 +346,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])

# Momentum and Nesterov momentum, if needed
state['momentum'].mul_(beta1).add_(shampoo_grad)
graft_momentum = graft.update_momentum(grad, beta1)

if state['step'] >= self.start_preconditioning_step:
if is_precondition_step:
momentum_update = state['momentum']
wd_update = shampoo_grad
else:
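
For reference, the core of the SVD path added in this PR fits in a few lines. Shampoo.step() above accumulates pre_cond += G @ G^T per tensor mode and then multiplies the matricized gradient by pre_cond^{-1/order}. The snippet below is a minimal sketch of that matrix-power step, not the repository's exact compute_power_svd: the name power_via_svd, the eps clamping, and the example values are assumptions, while the call shape mirrors compute_power_svd(pre_cond, -1.0 / order) used in the diff above.

import torch

def power_via_svd(matrix: torch.Tensor, power: float, eps: float = 1e-6) -> torch.Tensor:
    # For a symmetric PSD pre-conditioner M, M = U diag(s) U^T, so M^p = U diag(s^p) U^T.
    # torch.linalg.svd accepts batched input of shape (..., n, n), so same-sized
    # pre-conditioners can be stacked and decomposed in a single call.
    u, s, vh = torch.linalg.svd(matrix)
    # Clamp tiny singular values before raising them to a negative power.
    s = s.clamp(min=eps).pow(power)
    return u @ torch.diag_embed(s) @ vh

# Example: inverse 4th root of a random PSD matrix, as needed for a 4-dimensional gradient.
g = torch.randn(8, 4)
m = g @ g.t() + 1e-6 * torch.eye(8)
inv_root = power_via_svd(m, -1.0 / 4)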
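
A short usage sketch of the new flag, assuming the public exports shown in the __init__.py diff above; the linear model and single training step are placeholders, and only lr, block_size, and use_svd are passed explicitly, leaving the other arguments at the defaults from the ScalableShampoo signature in this diff.

import torch
import torch.nn.functional as F
from pytorch_optimizer import ScalableShampoo

model = torch.nn.Linear(32, 10)
optimizer = ScalableShampoo(
    model.parameters(),
    lr=1e-3,
    block_size=256,  # the new default introduced in this PR
    use_svd=True,    # compute M^{-1/p} via SVD instead of Schur-Newton
)

x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()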