[ADD] Clean DirectionalDampedNewtonComputation (#27)
Adds directionally damped Newton step computation with cleaned up API.

Also fixes a bug in the eigenvalue criterion in the tests:
it always picked one more eigenvalue than specified.
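
A minimal usage sketch of the new API, assembled from the tests in this commit (the toy model/data, the parameter-group keys ``params``/``criterion``, and the assumption that all constructor arguments have defaults are mine):

from backpack import backpack, extend
from torch import nn, ones, rand

from vivit.optim import DirectionalDampedNewtonComputation

# Hypothetical toy problem; extend() makes model and loss BackPACK-aware.
model = extend(nn.Linear(10, 2))
loss_func = extend(nn.MSELoss())
X, y = rand(8, 10), rand(8, 2)

computation = DirectionalDampedNewtonComputation()
param_groups = [
    {
        "params": [p for p in model.parameters() if p.requires_grad],
        # keep all K directions returned by the eigenvalue criterion
        "criterion": lambda evals: list(range(evals.numel())),
        # constant damping of 1.0 along each of the K directions
        "damping": lambda evals, gram_evecs, gammas, lambdas: ones(
            gammas.shape[1], dtype=gammas.dtype, device=gammas.device
        ),
    }
]

loss = loss_func(model(X), y)
with backpack(
    *computation.get_extensions(),
    extension_hook=computation.get_extension_hook(param_groups),
):
    loss.backward()

newton_step = computation.get_result(param_groups[0])  # one tensor per parameter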

---

* [BUG] Fix top_k criterion

* [ADD] Re-introduce damped Newton step with tests

* [DOC] Add docstring, clean up unused variables/prints

* [REF] Remove unused import

* [REF] Move damping cases to settings

* [DOC] Improve test function documentation

* [DEL] Remove blank lines

* [DOC] Add documentation for private methods

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
f-dangel committed Jun 22, 2022
1 parent 1f6c00d commit d871592
Showing 8 changed files with 590 additions and 5 deletions.
7 changes: 4 additions & 3 deletions docs/rtd/computations.rst
@@ -20,7 +20,8 @@ GGN eigenpairs (eigenvalues + eigenvector)
.. autoclass:: vivit.DirectionalDerivativesComputation
   :members: __init__, get_extensions, get_extension_hook, get_result

-Newton steps
---------------
+Directionally damped Newton steps
+---------------------------------

-TODO
+.. autoclass:: vivit.DirectionalDampedNewtonComputation
+   :members: __init__, get_extensions, get_extension_hook, get_result
34 changes: 34 additions & 0 deletions test/implementation/optim_autograd.py
@@ -2,6 +2,8 @@

from test.implementation.autograd import AutogradExtensions

from torch import einsum


class AutogradOptimExtensions(AutogradExtensions):
"""Autograd implementation of optimizer functionality with similar API."""
@@ -23,3 +25,35 @@ def directional_derivatives(
        )

        return gammas, lambdas

    def directional_damped_newton(
        self, param_groups, subsampling_grad=None, subsampling_ggn=None
    ):
        """Compute directionally damped Newton steps via autograd (ground truth)."""
        group_gammas, group_evecs = self.gammas_ggn(
            param_groups,
            ggn_subsampling=subsampling_ggn,
            grad_subsampling=subsampling_grad,
            directions=True,
        )
        group_lambdas = self.lambdas_ggn(
            param_groups,
            ggn_subsampling=subsampling_ggn,
            lambda_subsampling=subsampling_ggn,
        )

        newton_steps = []

        for group, gammas, lambdas, evecs in zip(
            param_groups, group_gammas, group_lambdas, group_evecs
        ):
            # The damping function also expects eigenvalues and Gram eigenvectors;
            # the dampings used in the tests ignore them, so ``None`` suffices.
            dummy_evals = None
            dummy_gram_evecs = None
            deltas = group["damping"](dummy_evals, dummy_gram_evecs, gammas, lambdas)

            # newton[i] = sum_d evecs[i, d] * coefficients[d]
            coefficients = -gammas.mean(0) / (lambdas.mean(0) + deltas)
            newton = einsum("id,d->i", evecs, coefficients)

            newton_steps.append(newton)

        return newton_steps
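
In symbols, the loop above assembles the directionally damped Newton step (a sketch, using the shape conventions of the damping functions below: ``gammas``/``lambdas`` are ``[N, K]``, ``evecs`` holds the directions as columns):

$$ s = -\sum_{k=1}^{K} \frac{\frac{1}{N}\sum_{n=1}^{N} \gamma_{nk}}{\frac{1}{N}\sum_{n=1}^{N} \lambda_{nk} + \delta_k}\, e_k $$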
37 changes: 36 additions & 1 deletion test/implementation/optim_backpack.py
@@ -3,8 +3,12 @@
from test.implementation.backpack import BackpackExtensions

from backpack import backpack
from torch import cat

-from vivit.optim import DirectionalDerivativesComputation
+from vivit.optim import (
+    DirectionalDampedNewtonComputation,
+    DirectionalDerivativesComputation,
+)


class BackpackOptimExtensions(BackpackExtensions):
@@ -38,3 +42,34 @@ def directional_derivatives(
            lambdas.append(group_lambdas)

        return gammas, lambdas

    def directional_damped_newton(
        self,
        param_groups,
        subsampling_grad=None,
        subsampling_ggn=None,
        mc_samples_ggn=0,
    ):
        """Compute directionally damped Newton steps with BackPACK."""
        computations = DirectionalDampedNewtonComputation(
            subsampling_grad=subsampling_grad,
            subsampling_ggn=subsampling_ggn,
            mc_samples_ggn=mc_samples_ggn,
        )

        _, _, loss = self.problem.forward_pass()

        with backpack(
            *computations.get_extensions(),
            extension_hook=computations.get_extension_hook(param_groups),
        ):
            loss.backward()

        newton_steps = []

        for group in param_groups:
            group_newton_step = computations.get_result(group)
            # flatten and concatenate over parameters in group
            group_newton_step = cat([n.flatten() for n in group_newton_step])
            newton_steps.append(group_newton_step)

        return newton_steps
43 changes: 42 additions & 1 deletion test/optim/settings.py
@@ -2,6 +2,9 @@

from test.problem import make_test_problems
from test.settings import SETTINGS
from typing import Callable

from torch import Tensor, ones

PROBLEMS = make_test_problems(SETTINGS)
IDS = [problem.make_id() for problem in PROBLEMS]
@@ -35,7 +38,7 @@ def criterion(evals):
            shift = 0
            candidates = evals
        else:
-            shift = num_evals - 1 - k
+            shift = num_evals - k
            candidates = evals[shift:]

        return [idx + shift for idx, ev in enumerate(candidates) if ev > must_exceed]
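
For illustration, the off-by-one the commit message mentions (assuming ``evals`` is sorted in ascending order, so the top-``k`` eigenvalues sit at the end):

# Hypothetical numbers: num_evals = 5 ascending eigenvalues, keep the top k = 2.
# Old: shift = num_evals - 1 - k = 2  ->  evals[2:] has 3 candidates (one too many).
# New: shift = num_evals - k = 3      ->  evals[3:] has exactly the top 2.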
@@ -99,3 +102,41 @@ def is_bias(name, param):

PARAM_BLOCKS_FN.append(weights_and_biases)
PARAM_BLOCKS_FN_IDS.append("param_groups=weights_and_biases")


def create_constant_damping(
    damping: float,
) -> Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]:
    """Create a damping function with constant damping along all directions.

    Args:
        damping: Scale of the constant damping.

    Returns:
        Function that can be used as ``'damping'`` entry in a parameter group to
        specify the directional damping.
    """

    def constant_damping(
        evals: Tensor, gram_evecs: Tensor, gammas: Tensor, lambdas: Tensor
    ) -> Tensor:
        """Constant directional damping function.

        Args:
            evals: Eigenvalues along the directions. Shape ``[K]``.
            gram_evecs: Directions in Gram space. Shape ``[NC, K]``.
            gammas: Directional gradients. Shape ``[N, K]``.
            lambdas: Directional curvatures. Shape ``[N, K]``.

        Returns:
            Directional dampings of shape ``[K]``.
        """
        K = gammas.shape[1]
        return damping * ones(K, dtype=gammas.dtype, device=gammas.device)

    return constant_damping


DAMPING_VALUES = [1.0]
DAMPINGS = [create_constant_damping(d) for d in DAMPING_VALUES]
DAMPING_IDS = [f"damping={d}" for d in DAMPING_VALUES]
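
The same factory pattern supports direction-dependent dampings. A hedged sketch, not part of this commit (``create_curvature_scaled_damping`` is a hypothetical name, reusing the ``Callable``/``Tensor`` imports from the top of this file):

def create_curvature_scaled_damping(
    scale: float,
) -> Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]:
    """Create a damping that scales with each direction's mean curvature."""

    def curvature_scaled_damping(
        evals: Tensor, gram_evecs: Tensor, gammas: Tensor, lambdas: Tensor
    ) -> Tensor:
        # lambdas has shape [N, K]; average over samples and clamp to keep
        # the damping strictly positive.
        return scale * lambdas.mean(0).clamp(min=1e-8)

    return curvature_scaled_damping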
74 changes: 74 additions & 0 deletions test/optim/test_directional_damped_newton.py
@@ -0,0 +1,74 @@
"""Test ``vivit.optim.directional_damped_newton``."""

from test.implementation.optim_autograd import AutogradOptimExtensions
from test.implementation.optim_backpack import BackpackOptimExtensions
from test.optim.settings import (
    CRITERIA,
    CRITERIA_IDS,
    DAMPING_IDS,
    DAMPINGS,
    IDS_REDUCTION_MEAN,
    PARAM_BLOCKS_FN,
    PARAM_BLOCKS_FN_IDS,
    PROBLEMS_REDUCTION_MEAN,
    SUBSAMPLINGS_GGN,
    SUBSAMPLINGS_GGN_IDS,
    SUBSAMPLINGS_GRAD,
    SUBSAMPLINGS_GRAD_IDS,
)
from test.problem import ExtensionsTestProblem
from test.utils import check_sizes_and_values
from typing import Callable, List, Union

from pytest import mark
from torch import Tensor


@mark.parametrize("param_groups_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS)
@mark.parametrize("subsampling_ggn", SUBSAMPLINGS_GGN, ids=SUBSAMPLINGS_GGN_IDS)
@mark.parametrize("subsampling_grad", SUBSAMPLINGS_GRAD, ids=SUBSAMPLINGS_GRAD_IDS)
@mark.parametrize("criterion", CRITERIA, ids=CRITERIA_IDS)
@mark.parametrize("damping", DAMPINGS, ids=DAMPING_IDS)
@mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN)
def test_directional_damped_newton(
    problem: ExtensionsTestProblem,
    criterion: Callable[[Tensor], List[int]],
    subsampling_grad: Union[List[int], None],
    subsampling_ggn: Union[List[int], None],
    param_groups_fn: Callable,
    damping: Callable[[Tensor, Tensor, Tensor, Tensor], Tensor],
):
"""Compare damped Newton steps.
Args:
problem: Test case.
criterion: Filter function to select directions from eigenvalues.
subsampling_grad: Indices of samples used for gradient sub-sampling.
``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch
samples to compute directional gradients . Defaults to ``None`` (no
gradient sub-sampling).
subsampling_ggn: Indices of samples used for GGN curvature sub-sampling.
``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch
samples to compute directions and directional curvatures. Defaults to
``None`` (no curvature sub-sampling).
param_groups_fn: Function that creates the `param_groups` from the model's
named parameters and ``criterion``.
damping: Function that generates the directional dampings.
"""
    problem.set_up()

    param_groups = param_groups_fn(problem.model.named_parameters(), criterion)
    for group in param_groups:
        group["damping"] = damping

    ag_newton = AutogradOptimExtensions(problem).directional_damped_newton(
        param_groups, subsampling_grad=subsampling_grad, subsampling_ggn=subsampling_ggn
    )

    bp_newton = BackpackOptimExtensions(problem).directional_damped_newton(
        param_groups, subsampling_grad=subsampling_grad, subsampling_ggn=subsampling_ggn
    )

    check_sizes_and_values(ag_newton, bp_newton, rtol=1e-5, atol=1e-5)

    problem.tear_down()
2 changes: 2 additions & 0 deletions vivit/__init__.py
@@ -4,6 +4,7 @@
from vivit import extensions
from vivit.linalg.eigh import EighComputation
from vivit.linalg.eigvalsh import EigvalshComputation
from vivit.optim.directional_damped_newton import DirectionalDampedNewtonComputation
from vivit.optim.directional_derivatives import DirectionalDerivativesComputation

__all__ = [
@@ -12,4 +13,5 @@
"EigvalshComputation",
"EighComputation",
"DirectionalDerivativesComputation",
"DirectionalDampedNewtonComputation",
]
2 changes: 2 additions & 0 deletions vivit/optim/__init__.py
@@ -1,7 +1,9 @@
"""Optimization methods using low-rank representations of the GGN/Fisher."""

from vivit.optim.directional_damped_newton import DirectionalDampedNewtonComputation
from vivit.optim.directional_derivatives import DirectionalDerivativesComputation

__all__ = [
"DirectionalDerivativesComputation",
"DirectionalDampedNewtonComputation",
]
