[ADD] Simplify DirectionalDerivatives API #17

Merged · 15 commits · Feb 21, 2022
93 changes: 11 additions & 82 deletions test/implementation/optim_autograd.py
@@ -2,95 +2,24 @@
 
 from test.implementation.autograd import AutogradExtensions
 
-import torch
-from backpack.utils.convert_parameters import vector_to_parameter_list
-
 
 class AutogradOptimExtensions(AutogradExtensions):
     """Autograd implementation of optimizer functionality with similar API."""
 
-    def gammas_ggn(self, top_k, subsampling_directions=None, subsampling_first=None):
-        """First-order directional derivatives along the top-k GGN eigenvectors.
-
-        Args:
-            top_k (int): Number of leading eigenvectors used as directions. Will be
-                clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial
-                eigenvalues.
-            subsampling_directions ([int] or None): Indices of samples used to compute
-                Newton directions. If ``None``, all samples in the batch will be used.
-            subsampling_first ([int], optional): Sample indices used for individual
-                gradients.
-        """
-        return super().gammas_ggn(
-            top_k,
-            ggn_subsampling=subsampling_directions,
-            grad_subsampling=subsampling_first,
-        )
-
-    def lambdas_ggn(self, top_k, subsampling_directions=None, subsampling_second=None):
-        """Second-order directional derivatives along the top-k GGN eigenvectors.
-
-        Args:
-            top_k (int): Number of leading eigenvectors used as directions. Will be
-                clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial
-                eigenvalues.
-            subsampling_directions ([int] or None): Indices of samples used to compute
-                Newton directions. If ``None``, all samples in the batch will be used.
-            subsampling_second ([int], optional): Sample indices used for individual
-                curvature matrices.
-        """
-        return super().lambdas_ggn(
-            top_k,
-            ggn_subsampling=subsampling_directions,
-            lambda_subsampling=subsampling_second,
-        )
-
-    def newton_step(
-        self,
-        param_groups,
-        damping,
-        subsampling_directions=None,
-        subsampling_first=None,
-        subsampling_second=None,
+    def directional_derivatives(
+        self, param_groups, subsampling_grad=None, subsampling_ggn=None
     ):
-        """Directionally-damped Newton step along the top-k GGN eigenvectors.
-
-        Args:
-            param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
-            damping (vivit.optim.damping._Damping): Policy for selecting
-                dampings along a direction from first- and second- order directional
-                derivatives.
-            subsampling_directions ([int] or None): Indices of samples used to compute
-                Newton directions. If ``None``, all samples in the batch will be used.
-            subsampling_first ([int], optional): Sample indices used for individual
-                gradients.
-            subsampling_second ([int], optional): Sample indices used for individual
-                curvature matrices.
-        """
-        group_gammas, group_evecs = super().gammas_ggn(
+        gammas = self.gammas_ggn(
             param_groups,
-            ggn_subsampling=subsampling_directions,
-            grad_subsampling=subsampling_first,
-            directions=True,
+            grad_subsampling=subsampling_grad,
+            ggn_subsampling=subsampling_ggn,
+            directions=False,
         )
-        group_lambdas = super().lambdas_ggn(
+
+        lambdas = self.lambdas_ggn(
             param_groups,
-            ggn_subsampling=subsampling_directions,
-            lambda_subsampling=subsampling_second,
+            ggn_subsampling=subsampling_ggn,
+            lambda_subsampling=subsampling_ggn,
        )
-
-        newton_steps = []
-
-        for group, gammas, evecs, lambdas in zip(
-            param_groups, group_gammas, group_evecs, group_lambdas
-        ):
-            deltas = damping(gammas, lambdas)
-
-            batch_axis = 0
-            scale = -gammas.mean(batch_axis) / (lambdas.mean(batch_axis) + deltas)
-
-            step = torch.einsum("d,id->i", scale, evecs)
-
-            newton_steps.append(vector_to_parameter_list(step, group["params"]))
-
-        return newton_steps
+        return gammas, lambdas
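
The autograd reference now returns the per-sample first-order directional derivatives (gammas) and second-order directional derivatives (lambdas) along GGN eigenvectors, instead of assembling a damped Newton step from them. Roughly, for a direction `e`, `gamma_n = g_nᵀ e` and `lambda_n = eᵀ G_n e`, with `g_n` and `G_n` the per-sample gradient and GGN. Below is a minimal, self-contained sketch of these quantities on a toy least-squares problem; the toy setup and all names in it are illustrative only, not code from this PR.

```python
import torch

torch.manual_seed(0)

N, D = 4, 3  # batch size, number of parameters
theta = torch.randn(D, requires_grad=True)
X, y = torch.randn(N, D), torch.randn(N)

# Per-sample losses of a linear least-squares problem.
losses = 0.5 * (X @ theta - y) ** 2

# Per-sample gradients g_n (one backward per sample, for clarity).
grads = torch.stack(
    [torch.autograd.grad(loss_n, theta, retain_graph=True)[0] for loss_n in losses]
)

# For this loss, the per-sample GGN is x_n x_n^T.
ggns = torch.einsum("ni,nj->nij", X, X)

# Direction: leading eigenvector of the mean GGN.
_, evecs = torch.linalg.eigh(ggns.mean(0))
e = evecs[:, -1]

gammas = grads @ e  # first-order directional derivatives, shape [N]
lambdas = torch.einsum("i,nij,j->n", e, ggns, e)  # second-order, shape [N]
print(gammas, lambdas)
```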
218 changes: 14 additions & 204 deletions test/implementation/optim_backpack.py
@@ -1,225 +1,35 @@
"""BackPACK implementation of operations used in ``vivit.optim``."""

from test.implementation.backpack import BackpackExtensions
from typing import Any, Dict, List, Optional

from backpack import backpack
from torch import Tensor

from vivit.optim import GramComputations
from vivit.optim.computations import BaseComputations
from vivit.optim.damped_newton import DampedNewton
from vivit.optim.damping import _DirectionalCoefficients
from vivit.optim import DirectionalDerivativesComputation


class BackpackOptimExtensions(BackpackExtensions):
def gammas_ggn(
self, param_groups, subsampling_directions=None, subsampling_first=None
):
"""First-order directional derivatives along leading GGN eigenvectors via
``vivit.optim.computations``.

Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
"""
computations = GramComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
)

_, _, loss = self.problem.forward_pass()

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
keep_backpack_buffers=False,
keep_gram_mat=False,
keep_gram_evecs=False,
keep_batch_size=False,
keep_gammas=True,
keep_lambdas=False,
keep_gram_evals=False,
),
):
loss.backward()

return [computations._gammas[id(group)] for group in param_groups]

def lambdas_ggn(
self, param_groups, subsampling_directions=None, subsampling_second=None
):
"""Second-order directional derivatives along leading GGN eigenvectors via
``vivit.optim.computations``.

Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
computations = GramComputations(
subsampling_directions=subsampling_directions,
subsampling_second=subsampling_second,
)

_, _, loss = self.problem.forward_pass()

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
keep_backpack_buffers=False,
keep_gram_mat=False,
keep_gram_evecs=False,
keep_batch_size=False,
keep_gammas=False,
keep_lambdas=True,
keep_gram_evals=False,
),
):
loss.backward()

return [computations._lambdas[id(group)] for group in param_groups]

def newton_step(
def directional_derivatives(
self,
param_groups,
damping,
subsampling_directions=None,
subsampling_first=None,
subsampling_second=None,
subsampling_grad=None,
subsampling_ggn=None,
mc_samples_ggn=0,
):
"""Directionally-damped Newton step along the top-k GGN eigenvectors.

Args:
param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s.
damping (vivit.optim.damping._Damping): Policy for selecting
dampings along a direction from first- and second- order directional
derivatives.
subsampling_directions ([int] or None): Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first ([int], optional): Sample indices used for individual
gradients.
subsampling_second ([int], optional): Sample indices used for individual
curvature matrices.
"""
computations = BaseComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
subsampling_second=subsampling_second,
"""Compute 1st and 2nd-order directional derivatives along GGN eigenvectors."""
computations = DirectionalDerivativesComputation(
subsampling_grad=subsampling_grad,
subsampling_ggn=subsampling_ggn,
mc_samples_ggn=mc_samples_ggn,
)

_, _, loss = self.problem.forward_pass()

savefield = "test_newton_step"

with backpack(
*computations.get_extensions(param_groups),
extension_hook=computations.get_extension_hook(
param_groups,
damping,
savefield,
keep_gram_mat=False,
keep_gram_evals=False,
keep_gram_evecs=False,
keep_gammas=False,
keep_lambdas=False,
keep_batch_size=False,
keep_coefficients=False,
keep_newton_step=False,
keep_backpack_buffers=False,
),
*computations.get_extensions(),
extension_hook=computations.get_extension_hook(param_groups),
):
loss.backward()

newton_step = [
[getattr(param, savefield) for param in group["params"]]
for group in param_groups
return [computations._gammas[id(group)] for group in param_groups], [
computations._lambdas[id(group)] for group in param_groups
]

return newton_step

def optim_newton_step(
self,
param_groups: List[Dict[str, Any]],
damping: _DirectionalCoefficients,
subsampling_directions: Optional[List[int]] = None,
subsampling_first: Optional[List[int]] = None,
subsampling_second: Optional[List[int]] = None,
use_closure: bool = False,
):
"""Directionally-damped Newton step along the top-k GGN eigenvectors.

Uses the ``DampedNewton`` optimizer to compute Newton steps.

Args:
param_groups: Parameter groups like for ``torch.nn.Optimizer``s.
damping: Computes Newton coefficients from first- and second- order
directional derivatives.
subsampling_directions: Indices of samples used to compute
Newton directions. If ``None``, all samples in the batch will be used.
subsampling_first: Sample indices used for individual gradients.
subsampling_second: Sample indices used for individual curvature matrices.
use_closure: Whether to use a closure in the optimizer. Default: ``False``.
"""
computations = BaseComputations(
subsampling_directions=subsampling_directions,
subsampling_first=subsampling_first,
subsampling_second=subsampling_second,
)

opt = DampedNewton(
param_groups,
coefficients=damping,
computations=computations,
criterion=None,
)

savefield = "test_newton_step"
DampedNewton.SAVEFIELD = savefield

if use_closure:

def closure() -> Tensor:
"""Evaluate the loss on a fixed mini-batch.

Returns:
Mini-batch loss.
"""
_, _, loss = self.problem.forward_pass()
return loss

opt.step(closure=closure)

else:
_, _, loss = self.problem.forward_pass()

with backpack(
*opt.get_extensions(),
extension_hook=opt.get_extension_hook(
keep_gram_mat=False,
keep_gram_evals=False,
keep_gram_evecs=False,
keep_gammas=False,
keep_lambdas=False,
keep_batch_size=False,
keep_coefficients=False,
keep_newton_step=False,
keep_backpack_buffers=False,
),
):
loss.backward()
opt.step()

newton_step = [
[getattr(param, savefield) for param in group["params"]]
for group in param_groups
]

return newton_step
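
For orientation, this is how the simplified API from the diff above is driven end to end: build a `DirectionalDerivativesComputation`, run the backward pass inside a `backpack` context with the computation's extensions and extension hook, then read the per-group results. The model, data, and the `criterion` entry of the parameter group below are placeholders chosen for illustration; the calls on `computations`, including the access to the private `_gammas`/`_lambdas` buffers, mirror the test code above.

```python
import torch
from backpack import backpack, extend
from vivit.optim import DirectionalDerivativesComputation

# Toy model and loss, extended for BackPACK (placeholders, not from the PR).
model = extend(torch.nn.Linear(10, 3))
loss_func = extend(torch.nn.CrossEntropyLoss())
X, y = torch.randn(8, 10), torch.randint(0, 3, (8,))

# One group over all parameters. The 'criterion' selects which directions to
# keep from the GGN eigenvalues (assumed signature: evals -> kept indices).
param_groups = [
    {
        "params": list(model.parameters()),
        "criterion": lambda evals: list(range(len(evals))),
    }
]

computations = DirectionalDerivativesComputation(
    subsampling_grad=None, subsampling_ggn=None, mc_samples_ggn=0
)

loss = loss_func(model(X), y)
with backpack(
    *computations.get_extensions(),
    extension_hook=computations.get_extension_hook(param_groups),
):
    loss.backward()

# Per-group first- and second-order directional derivatives, as in the test.
gammas = [computations._gammas[id(group)] for group in param_groups]
lambdas = [computations._lambdas[id(group)] for group in param_groups]
```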
8 changes: 2 additions & 6 deletions test/linalg/test_eigh.py
@@ -40,9 +40,7 @@ def test_ggn_eigh_eigenvalues(
     """
     problem.set_up()
 
-    param_groups = param_groups_fn(problem.model.named_parameters())
-    for group in param_groups:
-        group["criterion"] = keep_all
+    param_groups = param_groups_fn(problem.model.named_parameters(), keep_all)
 
     backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn(
         param_groups, subsampling
@@ -97,9 +95,7 @@ def test_ggn_eigh_eigenvectors(
     """
     problem.set_up()
 
-    param_groups = param_groups_fn(problem.model.named_parameters())
-    for group in param_groups:
-        group["criterion"] = keep_nonzero
+    param_groups = param_groups_fn(problem.model.named_parameters(), keep_nonzero)
 
     backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn(
         param_groups, subsampling
5 changes: 2 additions & 3 deletions test/linalg/test_eigvalsh.py
@@ -39,9 +39,8 @@ def test_ggn_eigvalsh(
     """
    problem.set_up()
 
-    param_groups = param_groups_fn(problem.model.named_parameters())
-    for group in param_groups:
-        group["criterion"] = keep_all
+    # TODO Remove dependency on 'criterion' from autograd implementation
+    param_groups = param_groups_fn(problem.model.named_parameters(), keep_all)
 
     backpack_result = BackpackLinalgExtensions(problem).eigvalsh_ggn(
         param_groups, subsampling
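
In the linalg tests, the eigenvalue-filtering `criterion` is now passed into `param_groups_fn` instead of being patched onto each group afterwards. A hypothetical helper of that shape, purely for illustration (the real helpers and the exact behavior of `keep_all` live in the test utilities):

```python
def one_group(named_parameters, criterion):
    """Collect all parameters into a single group with the given criterion."""
    return [{"params": [p for _, p in named_parameters], "criterion": criterion}]


def keep_all(evals):
    """Criterion that keeps every direction (all eigenvalue indices)."""
    return list(range(len(evals)))
```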