[ADD] Simplify DirectionalDerivatives API (#17)
Exotic features, like using different GGNs to compute directions and
directional curvatures, as well as full control over which intermediate
buffers to keep, have been deprecated in favor of a simpler API (see the
usage sketch below).

- Remove Newton step computation for now as it was internally relying on
 `DirectionalDerivatives`
- Remove many utilities and associated tests tied to the exotic features
- Forbid duplicate indices in `subsampling`
- Always delete intermediate buffers other than the target quantities
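
For reference, a minimal sketch of the simplified usage, distilled from the updated tests below. The setup (`model`, `loss_function`, the data, and the `criterion` entry of the parameter group) is illustrative, not prescriptive:

```python
from backpack import backpack, extend
from torch import nn, rand

from vivit.optim import DirectionalDerivativesComputation

model = extend(nn.Linear(3, 2))
loss_function = extend(nn.MSELoss())
X, y = rand(8, 3), rand(8, 2)

# One group; 'criterion' maps GGN eigenvalues to the kept indices (as in the tests)
param_groups = [
    {
        "params": list(model.parameters()),
        "criterion": lambda evals: list(range(evals.numel())),
    }
]

computations = DirectionalDerivativesComputation(
    subsampling_grad=None,  # all samples for per-sample gradients
    subsampling_ggn=None,   # shared by directions and directional curvatures
)

loss = loss_function(model(X), y)
with backpack(
    *computations.get_extensions(),
    extension_hook=computations.get_extension_hook(param_groups),
):
    loss.backward()

# the tests read the results from internal buffers, keyed by group id
gammas = [computations._gammas[id(group)] for group in param_groups]
lambdas = [computations._lambdas[id(group)] for group in param_groups]
```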

---

* [REF] Rename `GramComputations` → `DirectionalDerivativesComputation`

* [DEL] Remove `compute_gammas` and `keep_gammas` arguments

Always compute and keep first-order directional derivatives.

* [DEL] Remove `compute_lambdas` and `keep_lambdas` arguments

* [DEL] Remove Newton step and Newton optimizer

Currently, the Newton step depends on the directional derivatives.
It will be easier to clean up the latter's API before reactivating the
Newton step feature.

* [DEL] Remove `keep_gram_mat` argument

Always delete the Gram matrix.

* [DEL] Remove `keep_gram_evecs` argument

Always delete the Gram matrix eigenvectors.

* [DEL] Remove `keep_gram_evals` argument

Always delete the Gram matrix eigenvalues.

* [DEL] Remove `keep_batch_size` argument

Always delete the batch size.

* [DEL] Remove `keep_backpack_buffers` argument

Always delete BackPACK buffers.

* [DEL] Remove `param_groups` from `get_extensions`

* [ADD] Simplify `DirectionalDerivativesComputation` API and tests

* [ADD] Forbid `subsampling` with repeated indices (#16)

* [REF] Rename file containing argument checks

* [ADD] Forbid sub-sampling with repeated indices (see the sketch below)

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
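
The check itself is small; a sketch of a plausible form (the helper name is hypothetical, not necessarily what the renamed argument-check file uses):

```python
from typing import List, Optional


def check_subsampling_unique(subsampling: Optional[List[int]]) -> None:
    """Raise if sub-sampling indices contain duplicates (hypothetical helper)."""
    if subsampling is not None and len(set(subsampling)) != len(subsampling):
        raise ValueError(f"Detected repeated indices in subsampling: {subsampling}")
```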

* [ADD] Simplify internals of `DirectionalDerivativesComputation`

Forbid duplicates in subsampling and share the subsampling between
directions and directional curvatures (sketched below).

Co-authored-by: Felix Dangel <fdangel@tue.mpg.de>
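
For context: the target quantities are per-sample directional derivatives along GGN eigenvectors `e_k`, first-order `gamma[n, k] = g_n^T e_k` and second-order `lambda[n, k] = e_k^T G_n e_k`. A toy sketch with explicit per-sample matrices (illustrative only, not the package's API):

```python
import torch

N, D, K = 8, 5, 3                    # batch size, #parameters, #directions
grads = torch.randn(N, D)            # per-sample gradients g_n
mats = torch.randn(N, D, D)
ggns = mats @ mats.transpose(1, 2)   # per-sample curvatures G_n (symmetric PSD)
evecs = torch.linalg.qr(torch.randn(D, K)).Q  # orthonormal directions e_k

# gamma[n, k] = g_n^T e_k
gammas = torch.einsum("nd,dk->nk", grads, evecs)
# lambda[n, k] = e_k^T G_n e_k
lambdas = torch.einsum("dk,nde,ek->nk", evecs, ggns, evecs)
```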
f-dangel committed Feb 21, 2022
1 parent fff88a0 commit 5310b0c
Showing 22 changed files with 713 additions and 3,307 deletions.
93 changes: 11 additions & 82 deletions test/implementation/optim_autograd.py
@@ -2,95 +2,24 @@

from test.implementation.autograd import AutogradExtensions

import torch
from backpack.utils.convert_parameters import vector_to_parameter_list


class AutogradOptimExtensions(AutogradExtensions):
"""Autograd implementation of optimizer functionality with similar API."""

def gammas_ggn(self, top_k, subsampling_directions=None, subsampling_first=None):
"""First-order directional derivatives along the top-k GGN eigenvectors.
Args:
top_k (int): Number of leading eigenvectors used as directions. Will be
clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial
eigenvalues.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
"""
return super().gammas_ggn(
top_k,
ggn_subsampling=subsampling_directions,
grad_subsampling=subsampling_first,
)

def lambdas_ggn(self, top_k, subsampling_directions=None, subsampling_second=None):
"""Second-order directional derivatives along the top-k GGN eigenvectors.
Args:
top_k (int): Number of leading eigenvectors used as directions. Will be
clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial
eigenvalues.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
return super().lambdas_ggn(
top_k,
ggn_subsampling=subsampling_directions,
lambda_subsampling=subsampling_second,
)

def newton_step(
self,
param_groups,
damping,
subsampling_directions=None,
subsampling_first=None,
subsampling_second=None,
def directional_derivatives(
self, param_groups, subsampling_grad=None, subsampling_ggn=None
):
"""Directionally-damped Newton step along the top-k GGN eigenvectors.
Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
damping (vivit.optim.damping._Damping): Policy for selecting
dampings along a direction from first- and second- order directional
derivatives.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
group_gammas, group_evecs = super().gammas_ggn(
gammas = self.gammas_ggn(
param_groups,
ggn_subsampling=subsampling_directions,
grad_subsampling=subsampling_first,
directions=True,
grad_subsampling=subsampling_grad,
ggn_subsampling=subsampling_ggn,
directions=False,
)
group_lambdas = super().lambdas_ggn(

lambdas = self.lambdas_ggn(
param_groups,
ggn_subsampling=subsampling_directions,
lambda_subsampling=subsampling_second,
ggn_subsampling=subsampling_ggn,
lambda_subsampling=subsampling_ggn,
)

newton_steps = []

for group, gammas, evecs, lambdas in zip(
param_groups, group_gammas, group_evecs, group_lambdas
):
deltas = damping(gammas, lambdas)

batch_axis = 0
scale = -gammas.mean(batch_axis) / (lambdas.mean(batch_axis) + deltas)

step = torch.einsum("d,id->i", scale, evecs)

newton_steps.append(vector_to_parameter_list(step, group["params"]))

return newton_steps
return gammas, lambdas
218 changes: 14 additions & 204 deletions test/implementation/optim_backpack.py
@@ -1,225 +1,35 @@
"""BackPACK implementation of operations used in ``vivit.optim``."""

from test.implementation.backpack import BackpackExtensions
from typing import Any, Dict, List, Optional

from backpack import backpack
from torch import Tensor

from vivit.optim import GramComputations
from vivit.optim.computations import BaseComputations
from vivit.optim.damped_newton import DampedNewton
from vivit.optim.damping import _DirectionalCoefficients
from vivit.optim import DirectionalDerivativesComputation


class BackpackOptimExtensions(BackpackExtensions):
def gammas_ggn(
self, param_groups, subsampling_directions=None, subsampling_first=None
):
"""First-order directional derivatives along leading GGN eigenvectors via
``vivit.optim.computations``.
Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
"""
computations = GramComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
)

_, _, loss = self.problem.forward_pass()

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
keep_backpack_buffers=False,
keep_gram_mat=False,
keep_gram_evecs=False,
keep_batch_size=False,
keep_gammas=True,
keep_lambdas=False,
keep_gram_evals=False,
),
):
loss.backward()

return [computations._gammas[id(group)] for group in param_groups]

def lambdas_ggn(
self, param_groups, subsampling_directions=None, subsampling_second=None
):
"""Second-order directional derivatives along leading GGN eigenvectors via
``vivit.optim.computations``.
Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
computations = GramComputations(
subsampling_directions=subsampling_directions,
subsampling_second=subsampling_second,
)

_, _, loss = self.problem.forward_pass()

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
keep_backpack_buffers=False,
keep_gram_mat=False,
keep_gram_evecs=False,
keep_batch_size=False,
keep_gammas=False,
keep_lambdas=True,
keep_gram_evals=False,
),
):
loss.backward()

return [computations._lambdas[id(group)] for group in param_groups]

def newton_step(
def directional_derivatives(
self,
param_groups,
damping,
subsampling_directions=None,
subsampling_first=None,
subsampling_second=None,
subsampling_grad=None,
subsampling_ggn=None,
mc_samples_ggn=0,
):
"""Directionally-damped Newton step along the top-k GGN eigenvectors.
Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
damping (vivit.optim.damping._Damping): Policy for selecting
dampings along a direction from first- and second- order directional
derivatives.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
computations = BaseComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
subsampling_second=subsampling_second,
"""Compute 1st and 2nd-order directional derivatives along GGN eigenvectors."""
computations = DirectionalDerivativesComputation(
subsampling_grad=subsampling_grad,
subsampling_ggn=subsampling_ggn,
mc_samples_ggn=mc_samples_ggn,
)

_, _, loss = self.problem.forward_pass()

savefield = "test_newton_step"

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
damping,
savefield,
keep_gram_mat=False,
keep_gram_evals=False,
keep_gram_evecs=False,
keep_gammas=False,
keep_lambdas=False,
keep_batch_size=False,
keep_coefficients=False,
keep_newton_step=False,
keep_backpack_buffers=False,
),
*computations.get_extensions(),
extension_hook=computations.get_extension_hook(param_groups),
):
loss.backward()

newton_step = [
[getattr(param, savefield) for param in group["params"]]
for group in param_groups
return [computations._gammas[id(group)] for group in param_groups], [
computations._lambdas[id(group)] for group in param_groups
]

return newton_step

def optim_newton_step(
self,
param_groups: List[Dict[str, Any]],
damping: _DirectionalCoefficients,
subsampling_directions: Optional[List[int]] = None,
subsampling_first: Optional[List[int]] = None,
subsampling_second: Optional[List[int]] = None,
use_closure: bool = False,
):
"""Directionally-damped Newton step along the top-k GGN eigenvectors.
Uses the ``DampedNewton`` optimizer to compute Newton steps.
Args:
param_groups: Parameter groups like for ``torch.nn.Optimizer``s.
damping: Computes Newton coefficients from first- and second- order
directional derivatives.
subsampling_directions: Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first: Sample indices used for individual gradients.
subsampling_second: Sample indices used for individual curvature matrices.
use_closure: Whether to use a closure in the optimizer. Default: ``False``.
"""
computations = BaseComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
subsampling_second=subsampling_second,
)

opt = DampedNewton(
param_groups,
coefficients=damping,
computations=computations,
criterion=None,
)

savefield = "test_newton_step"
DampedNewton.SAVEFIELD = savefield

if use_closure:

def closure() -> Tensor:
"""Evaluate the loss on a fixed mini-batch.
Returns:
Mini-batch loss.
"""
_, _, loss = self.problem.forward_pass()
return loss

opt.step(closure=closure)

else:
_, _, loss = self.problem.forward_pass()

with backpack(
*opt.get_extensions(),
extension_hook=opt.get_extension_hook(
keep_gram_mat=False,
keep_gram_evals=False,
keep_gram_evecs=False,
keep_gammas=False,
keep_lambdas=False,
keep_batch_size=False,
keep_coefficients=False,
keep_newton_step=False,
keep_backpack_buffers=False,
),
):
loss.backward()
opt.step()

newton_step = [
[getattr(param, savefield) for param in group["params"]]
for group in param_groups
]

return newton_step
8 changes: 2 additions & 6 deletions test/linalg/test_eigh.py
@@ -40,9 +40,7 @@ def test_ggn_eigh_eigenvalues(
"""
problem.set_up()

param_groups = param_groups_fn(problem.model.named_parameters())
for group in param_groups:
group["criterion"] = keep_all
param_groups = param_groups_fn(problem.model.named_parameters(), keep_all)

backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn(
param_groups, subsampling
@@ -97,9 +95,7 @@ def test_ggn_eigh_eigenvectors(
"""
problem.set_up()

param_groups = param_groups_fn(problem.model.named_parameters())
for group in param_groups:
group["criterion"] = keep_nonzero
param_groups = param_groups_fn(problem.model.named_parameters(), keep_nonzero)

backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn(
param_groups, subsampling
5 changes: 2 additions & 3 deletions test/linalg/test_eigvalsh.py
@@ -39,9 +39,8 @@ def test_ggn_eigvalsh(
"""
problem.set_up()

param_groups = param_groups_fn(problem.model.named_parameters())
for group in param_groups:
group["criterion"] = keep_all
# TODO Remove dependency on 'criterion' from autograd implementation
param_groups = param_groups_fn(problem.model.named_parameters(), keep_all)

backpack_result = BackpackLinalgExtensions(problem).eigvalsh_ggn(
param_groups, subsampling
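
The `keep_all` and `keep_nonzero` objects above are eigenvalue-filter criteria. A sketch under the assumed interface, where a criterion maps eigenvalues to the list of kept indices (a hypothetical re-implementation, not the test suite's code):

```python
from typing import List

from torch import Tensor


def keep_all(evals: Tensor) -> List[int]:
    """Keep every eigenvalue."""
    return list(range(evals.numel()))


def keep_nonzero(evals: Tensor, atol: float = 1e-7) -> List[int]:
    """Keep eigenvalues that are numerically nonzero (tolerance is illustrative)."""
    return [i for i, ev in enumerate(evals) if ev.abs() > atol]
```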
