diff --git a/test/implementation/optim_autograd.py b/test/implementation/optim_autograd.py index 67a811f..7878153 100644 --- a/test/implementation/optim_autograd.py +++ b/test/implementation/optim_autograd.py @@ -2,95 +2,24 @@ from test.implementation.autograd import AutogradExtensions -import torch -from backpack.utils.convert_parameters import vector_to_parameter_list - class AutogradOptimExtensions(AutogradExtensions): """Autograd implementation of optimizer functionality with similar API.""" - def gammas_ggn(self, top_k, subsampling_directions=None, subsampling_first=None): - """First-order directional derivatives along the top-k GGN eigenvectors. - - Args: - top_k (int): Number of leading eigenvectors used as directions. Will be - clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial - eigenvalues. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - """ - return super().gammas_ggn( - top_k, - ggn_subsampling=subsampling_directions, - grad_subsampling=subsampling_first, - ) - - def lambdas_ggn(self, top_k, subsampling_directions=None, subsampling_second=None): - """Second-order directional derivatives along the top-k GGN eigenvectors. - - Args: - top_k (int): Number of leading eigenvectors used as directions. Will be - clipped to ``[1, max]`` with ``max`` the maximum number of nontrivial - eigenvalues. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. - """ - return super().lambdas_ggn( - top_k, - ggn_subsampling=subsampling_directions, - lambda_subsampling=subsampling_second, - ) - - def newton_step( - self, - param_groups, - damping, - subsampling_directions=None, - subsampling_first=None, - subsampling_second=None, + def directional_derivatives( + self, param_groups, subsampling_grad=None, subsampling_ggn=None ): - """Directionally-damped Newton step along the top-k GGN eigenvectors. - - Args: - param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s. - damping (vivit.optim.damping._Damping): Policy for selecting - dampings along a direction from first- and second- order directional - derivatives. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. 
- """ - group_gammas, group_evecs = super().gammas_ggn( + gammas = self.gammas_ggn( param_groups, - ggn_subsampling=subsampling_directions, - grad_subsampling=subsampling_first, - directions=True, + grad_subsampling=subsampling_grad, + ggn_subsampling=subsampling_ggn, + directions=False, ) - group_lambdas = super().lambdas_ggn( + + lambdas = self.lambdas_ggn( param_groups, - ggn_subsampling=subsampling_directions, - lambda_subsampling=subsampling_second, + ggn_subsampling=subsampling_ggn, + lambda_subsampling=subsampling_ggn, ) - newton_steps = [] - - for group, gammas, evecs, lambdas in zip( - param_groups, group_gammas, group_evecs, group_lambdas - ): - deltas = damping(gammas, lambdas) - - batch_axis = 0 - scale = -gammas.mean(batch_axis) / (lambdas.mean(batch_axis) + deltas) - - step = torch.einsum("d,id->i", scale, evecs) - - newton_steps.append(vector_to_parameter_list(step, group["params"])) - - return newton_steps + return gammas, lambdas diff --git a/test/implementation/optim_backpack.py b/test/implementation/optim_backpack.py index fef4645..9c970c9 100644 --- a/test/implementation/optim_backpack.py +++ b/test/implementation/optim_backpack.py @@ -1,225 +1,35 @@ """BackPACK implementation of operations used in ``vivit.optim``.""" from test.implementation.backpack import BackpackExtensions -from typing import Any, Dict, List, Optional from backpack import backpack -from torch import Tensor -from vivit.optim import GramComputations -from vivit.optim.computations import BaseComputations -from vivit.optim.damped_newton import DampedNewton -from vivit.optim.damping import _DirectionalCoefficients +from vivit.optim import DirectionalDerivativesComputation class BackpackOptimExtensions(BackpackExtensions): - def gammas_ggn( - self, param_groups, subsampling_directions=None, subsampling_first=None - ): - """First-order directional derivatives along leading GGN eigenvectors via - ``vivit.optim.computations``. - - Args: - param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - """ - computations = GramComputations( - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - ) - - _, _, loss = self.problem.forward_pass() - - with backpack( - *computations.get_extensions(param_groups), - extension_hook=computations.get_extension_hook( - param_groups, - keep_backpack_buffers=False, - keep_gram_mat=False, - keep_gram_evecs=False, - keep_batch_size=False, - keep_gammas=True, - keep_lambdas=False, - keep_gram_evals=False, - ), - ): - loss.backward() - - return [computations._gammas[id(group)] for group in param_groups] - - def lambdas_ggn( - self, param_groups, subsampling_directions=None, subsampling_second=None - ): - """Second-order directional derivatives along leading GGN eigenvectors via - ``vivit.optim.computations``. - - Args: - param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. 
- """ - computations = GramComputations( - subsampling_directions=subsampling_directions, - subsampling_second=subsampling_second, - ) - - _, _, loss = self.problem.forward_pass() - - with backpack( - *computations.get_extensions(param_groups), - extension_hook=computations.get_extension_hook( - param_groups, - keep_backpack_buffers=False, - keep_gram_mat=False, - keep_gram_evecs=False, - keep_batch_size=False, - keep_gammas=False, - keep_lambdas=True, - keep_gram_evals=False, - ), - ): - loss.backward() - - return [computations._lambdas[id(group)] for group in param_groups] - - def newton_step( + def directional_derivatives( self, param_groups, - damping, - subsampling_directions=None, - subsampling_first=None, - subsampling_second=None, + subsampling_grad=None, + subsampling_ggn=None, + mc_samples_ggn=0, ): - """Directionally-damped Newton step along the top-k GGN eigenvectors. - - Args: - param_groups ([dict]): Parameter groups like for ``torch.nn.Optimizer``s. - damping (vivit.optim.damping._Damping): Policy for selecting - dampings along a direction from first- and second- order directional - derivatives. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. - """ - computations = BaseComputations( - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, + """Compute 1st and 2nd-order directional derivatives along GGN eigenvectors.""" + computations = DirectionalDerivativesComputation( + subsampling_grad=subsampling_grad, + subsampling_ggn=subsampling_ggn, + mc_samples_ggn=mc_samples_ggn, ) _, _, loss = self.problem.forward_pass() - savefield = "test_newton_step" - with backpack( - *computations.get_extensions(param_groups), - extension_hook=computations.get_extension_hook( - param_groups, - damping, - savefield, - keep_gram_mat=False, - keep_gram_evals=False, - keep_gram_evecs=False, - keep_gammas=False, - keep_lambdas=False, - keep_batch_size=False, - keep_coefficients=False, - keep_newton_step=False, - keep_backpack_buffers=False, - ), + *computations.get_extensions(), + extension_hook=computations.get_extension_hook(param_groups), ): loss.backward() - newton_step = [ - [getattr(param, savefield) for param in group["params"]] - for group in param_groups + return [computations._gammas[id(group)] for group in param_groups], [ + computations._lambdas[id(group)] for group in param_groups ] - - return newton_step - - def optim_newton_step( - self, - param_groups: List[Dict[str, Any]], - damping: _DirectionalCoefficients, - subsampling_directions: Optional[List[int]] = None, - subsampling_first: Optional[List[int]] = None, - subsampling_second: Optional[List[int]] = None, - use_closure: bool = False, - ): - """Directionally-damped Newton step along the top-k GGN eigenvectors. - - Uses the ``DampedNewton`` optimizer to compute Newton steps. - - Args: - param_groups: Parameter groups like for ``torch.nn.Optimizer``s. - damping: Computes Newton coefficients from first- and second- order - directional derivatives. - subsampling_directions: Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first: Sample indices used for individual gradients. 
- subsampling_second: Sample indices used for individual curvature matrices. - use_closure: Whether to use a closure in the optimizer. Default: ``False``. - """ - computations = BaseComputations( - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - ) - - opt = DampedNewton( - param_groups, - coefficients=damping, - computations=computations, - criterion=None, - ) - - savefield = "test_newton_step" - DampedNewton.SAVEFIELD = savefield - - if use_closure: - - def closure() -> Tensor: - """Evaluate the loss on a fixed mini-batch. - - Returns: - Mini-batch loss. - """ - _, _, loss = self.problem.forward_pass() - return loss - - opt.step(closure=closure) - - else: - _, _, loss = self.problem.forward_pass() - - with backpack( - *opt.get_extensions(), - extension_hook=opt.get_extension_hook( - keep_gram_mat=False, - keep_gram_evals=False, - keep_gram_evecs=False, - keep_gammas=False, - keep_lambdas=False, - keep_batch_size=False, - keep_coefficients=False, - keep_newton_step=False, - keep_backpack_buffers=False, - ), - ): - loss.backward() - opt.step() - - newton_step = [ - [getattr(param, savefield) for param in group["params"]] - for group in param_groups - ] - - return newton_step diff --git a/test/linalg/test_eigh.py b/test/linalg/test_eigh.py index b61828b..3c450d0 100644 --- a/test/linalg/test_eigh.py +++ b/test/linalg/test_eigh.py @@ -40,9 +40,7 @@ def test_ggn_eigh_eigenvalues( """ problem.set_up() - param_groups = param_groups_fn(problem.model.named_parameters()) - for group in param_groups: - group["criterion"] = keep_all + param_groups = param_groups_fn(problem.model.named_parameters(), keep_all) backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn( param_groups, subsampling @@ -97,9 +95,7 @@ def test_ggn_eigh_eigenvectors( """ problem.set_up() - param_groups = param_groups_fn(problem.model.named_parameters()) - for group in param_groups: - group["criterion"] = keep_nonzero + param_groups = param_groups_fn(problem.model.named_parameters(), keep_nonzero) backpack_eigh = BackpackLinalgExtensions(problem).eigh_ggn( param_groups, subsampling diff --git a/test/linalg/test_eigvalsh.py b/test/linalg/test_eigvalsh.py index 168601a..e9404de 100644 --- a/test/linalg/test_eigvalsh.py +++ b/test/linalg/test_eigvalsh.py @@ -39,9 +39,8 @@ def test_ggn_eigvalsh( """ problem.set_up() - param_groups = param_groups_fn(problem.model.named_parameters()) - for group in param_groups: - group["criterion"] = keep_all + # TODO Remove dependency on 'criterion' from autograd implementation + param_groups = param_groups_fn(problem.model.named_parameters(), keep_all) backpack_result = BackpackLinalgExtensions(problem).eigvalsh_ggn( param_groups, subsampling diff --git a/test/optim/settings.py b/test/optim/settings.py index b23fd75..bdca357 100644 --- a/test/optim/settings.py +++ b/test/optim/settings.py @@ -44,59 +44,29 @@ def criterion(evals): TOP_K = [1, 10] -TOP_K_IDS = [f"top_k={k}" for k in TOP_K] -TOP_K = [make_criterion(k) for k in TOP_K] +CRITERIA_IDS = [f"criterion=top_{k}" for k in TOP_K] +CRITERIA = [make_criterion(k) for k in TOP_K] -SUBSAMPLINGS_DIRECTIONS = [None, [0, 0, 1, 0, 1]] -SUBSAMPLINGS_DIRECTIONS_IDS = [ - f"subsampling_directions={sub}" for sub in SUBSAMPLINGS_DIRECTIONS -] +SUBSAMPLINGS_GGN = [None, [0, 1]] +SUBSAMPLINGS_GGN_IDS = [f"subsampling_ggn={sub}" for sub in SUBSAMPLINGS_GGN] -SUBSAMPLINGS_FIRST = [None, [0, 0, 1, 0, 1]] -SUBSAMPLINGS_FIRST_IDS = [f"subsampling_first={sub}" for sub in 
SUBSAMPLINGS_FIRST] - -SUBSAMPLINGS_SECOND = [None, [0, 0, 1, 0, 1]] -SUBSAMPLINGS_SECOND_IDS = [f"subsampling_second={sub}" for sub in SUBSAMPLINGS_SECOND] +SUBSAMPLINGS_GRAD = [None, [0, 1]] +SUBSAMPLINGS_GRAD_IDS = [f"subsampling_grad={sub}" for sub in SUBSAMPLINGS_GRAD] PARAM_BLOCKS_FN = [] PARAM_BLOCKS_FN_IDS = [] -def one_group(named_parameters): +def one_group(named_parameters, criterion): """All parameters in all group.""" - return [{"params": [p for (_, p) in named_parameters]}] + return [{"params": [p for (_, p) in named_parameters], "criterion": criterion}] PARAM_BLOCKS_FN.append(one_group) PARAM_BLOCKS_FN_IDS.append("param_groups=one") -def per_param(named_parameters): - """One parameter group for each parameter. Only group last two parameters. - - Grouping the last two parameters is a fix to avoid degenerate eigenspaces - (which will then result in arbitrary directions and differing directional - derivatives). Consider for instance a last linear layer in a neural net - with ``MSELoss``. Then, the GGN w.r.t. only the last bias is proportional - to the identity matrix, hence its eigenspace is degenerate. - """ - parameters = list(named_parameters) - num_params = len(parameters) - - if num_params <= 2: - return one_group(parameters) - else: - parameters = [p for (_, p) in parameters] - return [{"params": list(parameters)[-2:]}] + [ - {"params": [p]} for p in list(parameters)[: num_params - 2] - ] - - -PARAM_BLOCKS_FN.append(per_param) -PARAM_BLOCKS_FN_IDS.append("param_groups=per_param") - - -def weights_and_biases(parameters): +def weights_and_biases(parameters, criterion): """Group weights in one, biases in other group.""" parameters = list(parameters) @@ -108,8 +78,14 @@ def is_bias(name, param): else: return param.dim() == 1 - weights = {"params": [p for (n, p) in parameters if is_bias(n, p)]} - biases = {"params": [p for (n, p) in parameters if not is_bias(n, p)]} + weights = { + "params": [p for (n, p) in parameters if is_bias(n, p)], + "criterion": criterion, + } + biases = { + "params": [p for (n, p) in parameters if not is_bias(n, p)], + "criterion": criterion, + } if len(biases["params"]) == 1: raise ValueError( @@ -123,10 +99,3 @@ def is_bias(name, param): PARAM_BLOCKS_FN.append(weights_and_biases) PARAM_BLOCKS_FN_IDS.append("param_groups=weights_and_biases") - - -def insert_criterion(param_groups, criterion): - """Add 'criterion' entry for each parameter group.""" - criterion_entry = {"criterion": criterion} - for group in param_groups: - group.update(criterion_entry) diff --git a/test/optim/test_computations.py b/test/optim/test_computations.py deleted file mode 100644 index ce95e47..0000000 --- a/test/optim/test_computations.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Test ``vivit.optim.computations``.""" - -from test.implementation.optim_autograd import AutogradOptimExtensions -from test.implementation.optim_backpack import BackpackOptimExtensions -from test.optim.settings import ( - IDS_REDUCTION_MEAN, - PARAM_BLOCKS_FN, - PARAM_BLOCKS_FN_IDS, - PROBLEMS_REDUCTION_MEAN, - SUBSAMPLINGS_DIRECTIONS, - SUBSAMPLINGS_DIRECTIONS_IDS, - SUBSAMPLINGS_FIRST, - SUBSAMPLINGS_FIRST_IDS, - SUBSAMPLINGS_SECOND, - SUBSAMPLINGS_SECOND_IDS, - TOP_K, - TOP_K_IDS, - insert_criterion, -) -from test.utils import check_sizes_and_values - -import pytest - -from vivit.optim.damping import ConstantDamping - -CONSTANT_DAMPING_VALUES = [1.0] -DAMPINGS = [ConstantDamping(const) for const in CONSTANT_DAMPING_VALUES] -DAMPINGS_IDS = [ - f"damping=ConstantDamping({const})" for const in 
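# A minimal sketch (editor's illustration, not part of the patch) of a top-k criterion
# like those built by make_criterion(k) in test/optim/settings.py: it maps the Gram
# eigenvalues to the indices of the directions to keep. The clipping/ordering details of
# the real make_criterion may differ.
from typing import Callable, List

from torch import Tensor


def make_top_k_criterion(k: int) -> Callable[[Tensor], List[int]]:
    def criterion(evals: Tensor) -> List[int]:
        k_eff = min(k, evals.numel())  # never ask for more directions than exist
        return evals.argsort(descending=True)[:k_eff].tolist()

    return criterion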
CONSTANT_DAMPING_VALUES -] - - -@pytest.mark.parametrize("param_block_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS) -@pytest.mark.parametrize("damping", DAMPINGS, ids=DAMPINGS_IDS) -@pytest.mark.parametrize( - "subsampling_directions", SUBSAMPLINGS_DIRECTIONS, ids=SUBSAMPLINGS_DIRECTIONS_IDS -) -@pytest.mark.parametrize( - "subsampling_first", SUBSAMPLINGS_FIRST, ids=SUBSAMPLINGS_FIRST_IDS -) -@pytest.mark.parametrize( - "subsampling_second", SUBSAMPLINGS_SECOND, ids=SUBSAMPLINGS_SECOND_IDS -) -@pytest.mark.parametrize("top_k", TOP_K, ids=TOP_K_IDS) -@pytest.mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN) -def test_computations_newton_step( - problem, - top_k, - damping, - subsampling_directions, - subsampling_first, - subsampling_second, - param_block_fn, -): - """Compare damped Newton step along leading GGN eigenvectors with autograd. - - Args: - top_k (function): Criterion to select Gram space directions. - problem (ExtensionsTestProblem): Test case. - damping (vivit.optim.damping._Damping): Policy for selecting dampings along - a direction from first- and second- order directional derivatives. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. - param_block_fn (function): Function to group model parameters. - """ - problem.set_up() - - param_groups = param_block_fn(problem.model.named_parameters()) - insert_criterion(param_groups, top_k) - - autograd_res = AutogradOptimExtensions(problem).newton_step( - param_groups, - damping, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - ) - backpack_res = BackpackOptimExtensions(problem).newton_step( - param_groups, - damping, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - ) - - atol = 5e-5 - rtol = 1e-4 - - assert len(autograd_res) == len(backpack_res) == len(param_groups) - for autograd_step, backpack_step in zip(autograd_res, backpack_res): - check_sizes_and_values(autograd_step, backpack_step, atol=atol, rtol=rtol) - - problem.tear_down() diff --git a/test/optim/test_damped_newton.py b/test/optim/test_damped_newton.py deleted file mode 100644 index 3d87d36..0000000 --- a/test/optim/test_damped_newton.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Test ``vivit.optim.damped_newton``.""" - -from test.implementation.optim_autograd import AutogradOptimExtensions -from test.implementation.optim_backpack import BackpackOptimExtensions -from test.optim.settings import ( - IDS_REDUCTION_MEAN, - PARAM_BLOCKS_FN, - PARAM_BLOCKS_FN_IDS, - PROBLEMS_REDUCTION_MEAN, - SUBSAMPLINGS_DIRECTIONS, - SUBSAMPLINGS_DIRECTIONS_IDS, - SUBSAMPLINGS_FIRST, - SUBSAMPLINGS_FIRST_IDS, - SUBSAMPLINGS_SECOND, - SUBSAMPLINGS_SECOND_IDS, - TOP_K, - TOP_K_IDS, - insert_criterion, -) -from test.problem import ExtensionsTestProblem -from test.utils import check_sizes_and_values -from typing import Any, Callable, Dict, Iterator, List, Union - -import pytest -from torch import Tensor - -from vivit.optim.damping import ConstantDamping, _DirectionalCoefficients - -CONSTANT_DAMPING_VALUES = [1.0] -DAMPINGS = [ConstantDamping(const) for const in CONSTANT_DAMPING_VALUES] -DAMPINGS_IDS = [ - 
f"damping=ConstantDamping({const})" for const in CONSTANT_DAMPING_VALUES -] - -USE_CLOSURE = [False, True] -USE_CLOSURE_IDS = [f"use_closure={use}" for use in USE_CLOSURE] - - -@pytest.mark.parametrize("use_closure", USE_CLOSURE, ids=USE_CLOSURE_IDS) -@pytest.mark.parametrize("param_block_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS) -@pytest.mark.parametrize("damping", DAMPINGS, ids=DAMPINGS_IDS) -@pytest.mark.parametrize( - "subsampling_directions", SUBSAMPLINGS_DIRECTIONS, ids=SUBSAMPLINGS_DIRECTIONS_IDS -) -@pytest.mark.parametrize( - "subsampling_first", SUBSAMPLINGS_FIRST, ids=SUBSAMPLINGS_FIRST_IDS -) -@pytest.mark.parametrize( - "subsampling_second", SUBSAMPLINGS_SECOND, ids=SUBSAMPLINGS_SECOND_IDS -) -@pytest.mark.parametrize("top_k", TOP_K, ids=TOP_K_IDS) -@pytest.mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN) -def test_optim_newton_step( - problem: ExtensionsTestProblem, - top_k: Callable[[Tensor], List[int]], - damping: _DirectionalCoefficients, - subsampling_directions: Union[List[int], None], - subsampling_first: Union[List[int], None], - subsampling_second: Union[List[int], None], - param_block_fn: Callable[[Iterator[Tensor]], List[Dict[str, Any]]], - use_closure: bool, -): - """Compare damped Newton step along leading GGN eigenvectors with autograd. - - Use ``DampedNewton`` optimizer to compute Newton steps. - - Args: - top_k: Criterion to select Gram space directions. - problem: Test case. - damping: Policy for selecting dampings along - a direction from first- and second- order directional derivatives. - subsampling_directions: Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first: Sample indices used for individual gradients. - subsampling_second: Sample indices used for individual curvature matrices. - param_block_fn: Function to group model parameters. - use_closure: Whether to use a closure for computing the Newton step. 
- """ - problem.set_up() - - param_groups = param_block_fn(problem.model.named_parameters()) - insert_criterion(param_groups, top_k) - - autograd_res = AutogradOptimExtensions(problem).newton_step( - param_groups, - damping, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - ) - backpack_res = BackpackOptimExtensions(problem).optim_newton_step( - param_groups, - damping, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - use_closure=use_closure, - ) - - atol = 5e-5 - rtol = 1e-4 - - assert len(autograd_res) == len(backpack_res) == len(param_groups) - for autograd_step, backpack_step in zip(autograd_res, backpack_res): - check_sizes_and_values(autograd_step, backpack_step, atol=atol, rtol=rtol) - - problem.tear_down() diff --git a/test/optim/test_dampings.py b/test/optim/test_dampings.py deleted file mode 100644 index 222c093..0000000 --- a/test/optim/test_dampings.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Integration tests for damping policies.""" - -from test.utils import get_available_devices - -import pytest -import torch - -from vivit.optim.damping import BootstrapDamping - -DAMPING_GRIDS = [torch.logspace(-3, 2, 150)] -DAMPING_GRIDS_IDS = ["damping_grid=torch.logspace(-3, 2, 150)"] - -PERCENTILES = [95] -PERCENTILES_IDS = [f"percentile={percentile}" for percentile in PERCENTILES] - -NUM_RESAMPLES = [100] -NUM_RESAMPLES_IDS = [f"num_resample={num_resample}" for num_resample in NUM_RESAMPLES] - -DEVICES = get_available_devices() -DEVICES_IDS = [f"device={device}" for device in DEVICES] - -SEED_VALS = [0, 1, 42] -SEED_VALS_IDS = [f"seed_val={seed_val}" for seed_val in SEED_VALS] - - -@pytest.mark.parametrize("damping_grid", DAMPING_GRIDS, ids=DAMPING_GRIDS_IDS) -@pytest.mark.parametrize("percentile", PERCENTILES, ids=PERCENTILES_IDS) -@pytest.mark.parametrize("num_resamples", NUM_RESAMPLES, ids=NUM_RESAMPLES_IDS) -@pytest.mark.parametrize("device", DEVICES, ids=DEVICES_IDS) -@pytest.mark.parametrize("seed_val", SEED_VALS, ids=SEED_VALS_IDS) -def test_bootstrap_damping(damping_grid, percentile, num_resamples, device, seed_val): - - # Make deterministic - torch.manual_seed(seed_val) - - # Define Setting - N_1 = 5 # Number of 1st derivative samples for for each direction - N_2 = 6 # Number of 2nd derivative samples for for each direction - D = 3 # Number of directions - - # Sample 1st and 2nd order derivatives for each direction - first_lower = -0.5 - first_upper = 2.0 - first_derivs = (first_upper - first_lower) * torch.rand(N_1, D) + first_lower - first_derivs = first_derivs.to(device) - - second_lower = 1.0 - second_upper = 1.5 - second_derivs = (second_upper - second_lower) * torch.rand(N_2, D) + second_lower - second_derivs = second_derivs.to(device) - - # Compute dampings - damping = BootstrapDamping(damping_grid, num_resamples, percentile) - _ = damping(first_derivs, second_derivs) diff --git a/test/optim/test_directional_derivatives.py b/test/optim/test_directional_derivatives.py new file mode 100644 index 0000000..a36cd9b --- /dev/null +++ b/test/optim/test_directional_derivatives.py @@ -0,0 +1,70 @@ +"""Test ``vivit.optim.directional_derivatives``.""" + +from test.implementation.optim_autograd import AutogradOptimExtensions +from test.implementation.optim_backpack import BackpackOptimExtensions +from test.optim.settings import ( + CRITERIA, + CRITERIA_IDS, + IDS_REDUCTION_MEAN, + PARAM_BLOCKS_FN, + PARAM_BLOCKS_FN_IDS, + 
PROBLEMS_REDUCTION_MEAN, + SUBSAMPLINGS_GGN, + SUBSAMPLINGS_GGN_IDS, + SUBSAMPLINGS_GRAD, + SUBSAMPLINGS_GRAD_IDS, +) +from test.problem import ExtensionsTestProblem +from test.utils import check_sizes_and_values +from typing import Callable, List, Union + +from pytest import mark +from torch import Tensor + + +@mark.parametrize("param_groups_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS) +@mark.parametrize("subsampling_ggn", SUBSAMPLINGS_GGN, ids=SUBSAMPLINGS_GGN_IDS) +@mark.parametrize("subsampling_grad", SUBSAMPLINGS_GRAD, ids=SUBSAMPLINGS_GRAD_IDS) +@mark.parametrize("criterion", CRITERIA, ids=CRITERIA_IDS) +@mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN) +def test_directional_derivatives( + problem: ExtensionsTestProblem, + criterion: Callable[[Tensor], List[int]], + subsampling_grad: Union[List[int], None], + subsampling_ggn: Union[List[int], None], + param_groups_fn: Callable, +): + """Compare 1ˢᵗ- and 2ⁿᵈ-order directional derivatives along GGN eigenvectors. + + Args: + problem: Test case. + criterion: Filter function to select directions from eigenvalues. + subsampling_grad: Indices of samples used for gradient sub-sampling. + ``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch + samples to compute directional gradients . Defaults to ``None`` (no + gradient sub-sampling). + subsampling_ggn: Indices of samples used for GGN curvature sub-sampling. + ``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch + samples to compute directions and directional curvatures. Defaults to + ``None`` (no curvature sub-sampling). + param_groups_fn: Function that creates the `param_groups` from the model's + named parameters and ``criterion``. + """ + problem.set_up() + + param_groups = param_groups_fn(problem.model.named_parameters(), criterion) + + ag_gammas, ag_lambdas = AutogradOptimExtensions(problem).directional_derivatives( + param_groups, subsampling_grad=subsampling_grad, subsampling_ggn=subsampling_ggn + ) + bp_gammas, bp_lambdas = BackpackOptimExtensions(problem).directional_derivatives( + param_groups, subsampling_grad=subsampling_grad, subsampling_ggn=subsampling_ggn + ) + + # directions can vary in sign, leading to same magnitude but opposite sign + ag_abs_gammas = [g.abs() for g in ag_gammas] + bp_abs_gammas = [g.abs() for g in bp_gammas] + check_sizes_and_values(ag_abs_gammas, bp_abs_gammas, rtol=1e-5, atol=1e-4) + check_sizes_and_values(ag_lambdas, bp_lambdas, rtol=1e-5, atol=1e-5) + + problem.tear_down() diff --git a/test/optim/test_gram_computations.py b/test/optim/test_gram_computations.py deleted file mode 100644 index ecfbf6b..0000000 --- a/test/optim/test_gram_computations.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test ``vivit.optim.gram_computations``.""" - -from test.implementation.optim_autograd import AutogradOptimExtensions -from test.implementation.optim_backpack import BackpackOptimExtensions -from test.optim.settings import ( - IDS_REDUCTION_MEAN, - PARAM_BLOCKS_FN, - PARAM_BLOCKS_FN_IDS, - PROBLEMS_REDUCTION_MEAN, - SUBSAMPLINGS_DIRECTIONS, - SUBSAMPLINGS_DIRECTIONS_IDS, - SUBSAMPLINGS_FIRST, - SUBSAMPLINGS_FIRST_IDS, - SUBSAMPLINGS_SECOND, - SUBSAMPLINGS_SECOND_IDS, - TOP_K, - TOP_K_IDS, - insert_criterion, -) -from test.utils import check_sizes_and_values - -import pytest - - -@pytest.mark.parametrize("param_block_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS) -@pytest.mark.parametrize( - "subsampling_directions", SUBSAMPLINGS_DIRECTIONS, ids=SUBSAMPLINGS_DIRECTIONS_IDS -) -@pytest.mark.parametrize( - 
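# A minimal sketch (editor's illustration, not part of the patch) of why the test above
# compares |gammas| but raw lambdas: an eigenvector e is only determined up to sign, so
# e^T g flips with e while e^T G e does not. Plain torch, no vivit involved.
import torch

torch.manual_seed(0)
g = torch.randn(3)
G = torch.randn(3, 3)
G = G @ G.T                                    # some PSD curvature matrix
e = torch.linalg.eigh(G).eigenvectors[:, -1]   # leading eigenvector, sign arbitrary

for direction in (e, -e):
    gamma = direction @ g                      # 1st-order: changes sign with the direction
    lam = direction @ G @ direction            # 2nd-order: invariant under e -> -e
    print(f"gamma={gamma:+.4f}  lambda={lam:.4f}")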
"subsampling_first", SUBSAMPLINGS_FIRST, ids=SUBSAMPLINGS_FIRST_IDS -) -@pytest.mark.parametrize("top_k", TOP_K, ids=TOP_K_IDS) -@pytest.mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN) -def test_computations_gammas_ggn( - problem, top_k, subsampling_directions, subsampling_first, param_block_fn -): - """Compare optimizer's 1st-order directional derivatives ``γ[n, d]`` along leading - GGN eigenvectors with autograd. - - Args: - problem (ExtensionsTestProblem): Test case. - top_k (function): Criterion to select Gram space directions. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int], optional): Sample indices used for individual - gradients. - param_block_fn (function): Function to group model parameters. - """ - problem.set_up() - - param_groups = param_block_fn(problem.model.named_parameters()) - insert_criterion(param_groups, top_k) - - autograd_res = AutogradOptimExtensions(problem).gammas_ggn( - param_groups, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - ) - backpack_res = BackpackOptimExtensions(problem).gammas_ggn( - param_groups, - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - ) - - # directions can vary in sign, leading to same magnitude but opposite sign. - autograd_res = [res.abs() for res in autograd_res] - backpack_res = [res.abs() for res in backpack_res] - - rtol = 5e-3 - atol = 1e-4 - - check_sizes_and_values(autograd_res, backpack_res, atol=atol, rtol=rtol) - problem.tear_down() - - -@pytest.mark.parametrize("param_block_fn", PARAM_BLOCKS_FN, ids=PARAM_BLOCKS_FN_IDS) -@pytest.mark.parametrize( - "subsampling_directions", SUBSAMPLINGS_DIRECTIONS, ids=SUBSAMPLINGS_DIRECTIONS_IDS -) -@pytest.mark.parametrize( - "subsampling_second", SUBSAMPLINGS_SECOND, ids=SUBSAMPLINGS_SECOND_IDS -) -@pytest.mark.parametrize("top_k", TOP_K, ids=TOP_K_IDS) -@pytest.mark.parametrize("problem", PROBLEMS_REDUCTION_MEAN, ids=IDS_REDUCTION_MEAN) -def test_computations_lambdas_ggn( - problem, top_k, subsampling_directions, subsampling_second, param_block_fn -): - """Compare optimizer's 2nd-order directional derivatives ``λ[n, d]`` along leading - GGN eigenvectors with autograd. - - Args: - problem (ExtensionsTestProblem): Test case. - top_k (function): Criterion to select Gram space directions. - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_second ([int], optional): Sample indices used for individual - curvature matrices. - param_block_fn (function): Function to group model parameters. 
- """ - problem.set_up() - - param_groups = param_block_fn(problem.model.named_parameters()) - insert_criterion(param_groups, top_k) - - autograd_res = AutogradOptimExtensions(problem).lambdas_ggn( - param_groups, - subsampling_directions=subsampling_directions, - subsampling_second=subsampling_second, - ) - backpack_res = BackpackOptimExtensions(problem).lambdas_ggn( - param_groups, - subsampling_directions=subsampling_directions, - subsampling_second=subsampling_second, - ) - - rtol, atol = 1e-5, 1e-5 - check_sizes_and_values(autograd_res, backpack_res, rtol=rtol, atol=atol) - problem.tear_down() diff --git a/test/optim/test_lambdas_gammas.py b/test/optim/test_lambdas_gammas.py deleted file mode 100644 index 14afedd..0000000 --- a/test/optim/test_lambdas_gammas.py +++ /dev/null @@ -1,402 +0,0 @@ -"""In the existing test for gamma and lamdba (test_gammas.py and test_lambdas.py), the -vivit-computations are compared to autograd. There might be the chance that there is -the "same" mistake in both versions. So, here is another apporach to test the gammas and -lambdas: We use a very simple linear network. In this case, we can give the loss, its -gradient and GGN in closed-form. We use these closed-form expressions to compute -reference lambdas and gammas, that we can compare the vivit-computations with. - -The following tests are performed: -- TEST 1 (Loss value): We compare the loss evaluated on the actual model with the loss - that we derived theoretically -- TEST 2 (Loss gradient): We compare the loss gradient computed by pytorch with the loss - gradient that we derived theoretically -- TEST 3 (Loss GGN): We compare the loss GGN computed by autograd (see section - "Auxiliary Functions (3)") with the loss GGN that we derived theoretically -- TEST 4, 5 (gammas and lambdas): We compute the lambdas and gammas with the vivit- - utilities. As a comparison, we also compute the theoretically derived GGN, its - eigenvectors and compute the lambdas and gammas "manually". -- TEST 6 (Newton step): Finally, we compare the Newton step computed by vivit with a - "manual" computation. -""" - -from test.optim.settings import make_criterion -from test.utils import check_sizes_and_values - -import pytest -import torch -from backpack import backpack, extend -from backpack.hessianfree.hvp import hessian_vector_product -from backpack.utils.convert_parameters import vector_to_parameter_list - -from vivit.optim.computations import BaseComputations -from vivit.optim.damping import ConstantDamping - -# ====================================================================================== -# Auxiliary Functions (1) -# Set weights and biases for linear layer and choose if these parameters are trainable -# ====================================================================================== - - -def set_weights(linear_layer, weights, req_grad): - """ - Set weights in linear layer and choose if these parameters are trainable. - """ - - # Check if weights has the right shape - w = linear_layer.weight - if weights.shape == w.data.shape: - - # Set weights and requires_grad - w.data = weights - w.requires_grad = req_grad - - else: - raise ValueError("weights dont have the right shape") - - -def set_biases(linear_layer, biases, req_grad): - """ - Set biases in linear layer and choose if these parameters are trainable. 
- """ - - # Check if biases has the right shape - b = linear_layer.bias - if biases.shape == b.data.shape: - - # Set biases and requires_grad - b.data = biases - b.requires_grad = req_grad - - else: - raise ValueError("biases dont have the right shape") - - -# ====================================================================================== -# Auxiliary Functions (2) -# The MSE-loss corresponds to Phi. Here, we define functions for evaluating Phi, its -# sample gadients and GGNs. -# ====================================================================================== - - -def Phi(x, theta, MSE_reduction, W_1, W_2): - """ - Computes MSE-loss at (x, theta) manually. - """ - - # Make sure N == 1 - assert x.shape[0] == 1, "N has to be one such that model output is a vector" - - # Compute output of model - theta_re = theta.reshape(1, OUT_1) - y2 = (x @ W_1.T + theta_re) @ W_2.T - - # Compute MSE loss manually - if MSE_reduction == "mean": - return (1 / OUT_2) * (y2 @ y2.T).item() - elif MSE_reduction == "sum": - return (y2 @ y2.T).item() - else: - raise ValueError("Unknown MSE_reduction") - - -def Phi_batch(X, theta, MSE_reduction, W_1, W_2): - """ - Computes MSE-loss for batch X containing N samples (rows) by averaging or summing - the individual sample losses. - """ - - N = X.shape[0] - - # Accumulate loss over all batches - loss_batch = 0.0 - for n in range(N): - x = X[n, :].reshape(1, IN_1) - loss_batch += Phi(x, theta, MSE_reduction, W_1, W_2) - - # Return accumulated loss or return average - if MSE_reduction == "mean": - return (1 / N) * loss_batch - elif MSE_reduction == "sum": - return loss_batch - else: - raise ValueError("Unknown MSE_reduction") - - -def Phi_grad(x, theta, MSE_reduction, W_1, W_2): - """ - Computes gradient of MSE-loss at (x, theta) manually. - """ - - # Make sure N == 1 - assert x.shape[0] == 1, "N has to be one such that model output is a vector" - - # Compute MSE loss gradient manually - theta_re = theta.reshape(1, OUT_1) - grad = 2 * (W_2.T @ W_2 @ (W_1 @ x.T + theta_re.T)).reshape(OUT_1) - if MSE_reduction == "mean": - return grad / OUT_2 - elif MSE_reduction == "sum": - return grad - else: - raise ValueError("Unknown MSE_reduction") - - -def Phi_grads_list(X, theta, MSE_reduction, W_1, W_2): - """ - Computes MSE-loss gradients for batch X containing N samples (rows) and retuns them - as a list - """ - - N = X.shape[0] - - grads_list = [] - for n in range(N): - x = X[n, :].reshape(1, IN_1) - grads_list.append(Phi_grad(x, theta, MSE_reduction, W_1, W_2)) - - return grads_list - - -def Phi_GGN(x, theta, MSE_reduction, W_1, W_2): - """ - Computes Hessian (= GGN) of MSE-loss at (x, theta) manually. - """ - - # Make sure N == 1 - assert x.shape[0] == 1, "N has to be one such that model output is a vector" - - # Compute MSE loss Hessian (= GGN) manually - GGN = 2 * W_2.T @ W_2 - if MSE_reduction == "mean": - return GGN / OUT_2 - elif MSE_reduction == "sum": - return GGN - else: - raise ValueError("Unknown MSE_reduction") - - -def Phi_GGNs_list(X, theta, MSE_reduction, W_1, W_2): - """ - Computes MSE-loss GGNs for batch X containing N samples (rows) and retuns them - as a list - """ - - N = X.shape[0] - - GGNs_list = [] - for n in range(N): - x = X[n, :].reshape(1, IN_1) - GGNs_list.append(Phi_GGN(x, theta, MSE_reduction, W_1, W_2)) - - return GGNs_list - - -def reduce_list(the_list, reduction): - """ - Auxiliary function that computes the sum or mean over all list entries. The list - entries are assumed to be torch.Tensors. 
- """ - - # Check that list entries are torch.Tensors - if not torch.is_tensor(the_list[0]): - raise ValueError("List entries have to be torch.Tensors") - - # Sum over list entries - sum_over_list_entries = torch.zeros_like(the_list[0]) - for i in range(len(the_list)): - sum_over_list_entries += the_list[i] - - if reduction == "mean": - return sum_over_list_entries / len(the_list) - elif reduction == "sum": - return sum_over_list_entries - else: - raise ValueError("Unknown reduction") - - -# ====================================================================================== -# Auxiliary Functions (3) -# Utilities for computing the Hessian for a given model. We will use this as a -# comparison to Phi_GGN -# ====================================================================================== - - -def autograd_hessian_columns(loss, params, concat=False): - """Return an iterator of the Hessian columns computed via ``torch.autograd``. - Args: - loss (torch.Tensor): Loss whose Hessian is investigated. - params ([torch.Tensor]): List of torch.Tensors holding the network's - parameters. - concat (bool): If ``True``, flatten and concatenate the columns over all - parameters. - """ - D = sum(p.numel() for p in params) - device = loss.device - for d in range(D): - e_d = torch.zeros(D, device=device) - e_d[d] = 1.0 - e_d_list = vector_to_parameter_list(e_d, params) - hessian_e_d = hessian_vector_product(loss, params, e_d_list) - if concat: - hessian_e_d = torch.cat([tensor.flatten() for tensor in hessian_e_d]) - yield hessian_e_d - - -def autograd_hessian(loss, params): - """Compute the full Hessian via ``torch.autograd``. - Flatten and concatenate the columns over all parameters, such that the result - is a ``[D, D]`` tensor, where ``D`` denotes the total number of parameters. - Args: - params ([torch.Tensor]): List of torch.Tensors holding the network's - parameters. 
- Returns: - torch.Tensor: 2d tensor containing the Hessian matrix - """ - return torch.stack(list(autograd_hessian_columns(loss, params, concat=True))) - - -# ====================================================================================== -# Define Test Parameters -# ====================================================================================== - -# Test tolerances -ATOL = 1e-5 -RTOL = 1e-4 - -# Choose dimensions and -N = 8 -IN_1 = 10 # Layer 1 -OUT_1 = 11 -IN_2 = OUT_1 # Layer 2 -OUT_2 = IN_2 -if OUT_2 < IN_2: - print("Warning: The GGN won't have full rank") - - -# ====================================================================================== -# Run Tests -# ====================================================================================== - -# MSE-reductions -MSE_REDUCTIONS = ["mean"] -IDS_MSE_REDUCTIONS = [ - f"MSE_reduction={MSE_reduction}" for MSE_reduction in MSE_REDUCTIONS -] - -# Dampings -DAMPINGS = [1.0, 2.5] -IDS_DAMPINGS = [f"Damping={delta}" for delta in DAMPINGS] - -# Seed values -SEED_VALS = [0, 1, 42] -IDS_SEED_VALS = [f"SeedVal={seed_val}" for seed_val in SEED_VALS] - - -@pytest.mark.parametrize("MSE_reduction", MSE_REDUCTIONS, ids=IDS_MSE_REDUCTIONS) -@pytest.mark.parametrize("delta", DAMPINGS, ids=IDS_DAMPINGS) -@pytest.mark.parametrize("seed_val", SEED_VALS, ids=IDS_SEED_VALS) -def test_lambda_gamma(MSE_reduction, delta, seed_val): - - # Set torch seed - torch.manual_seed(seed_val) - - # Initialize weight matrices, theta and X - W_1 = 2 * torch.rand(OUT_1, IN_1) - 1 - W_2 = 2 * torch.rand(OUT_2, IN_2) - 1 - theta = torch.rand(OUT_1) - X = torch.rand(N, IN_1) - - # Initialize layers, create model and loss function - L_1 = torch.nn.Linear(IN_1, OUT_1, bias=True) - L_2 = torch.nn.Linear(IN_2, OUT_2, bias=False) - set_weights(L_1, W_1, False) - set_biases(L_1, theta, True) - set_weights(L_2, W_2, False) - model = extend(torch.nn.Sequential(L_1, L_2)) - loss_func = extend(torch.nn.MSELoss(reduction=MSE_reduction)) - - # ========================== - # TEST 1: Loss value - # ========================== - phi = torch.Tensor([Phi_batch(X, theta, MSE_reduction, W_1, W_2)]).reshape(1, 1) - loss = loss_func(model(X), torch.zeros(N, OUT_2)).reshape(1, 1) - check_sizes_and_values(loss, phi, atol=ATOL, rtol=RTOL) - - # ========================== - # TEST 2: Loss gradient - # ========================== - phi_grads_list = Phi_grads_list(X, theta, MSE_reduction, W_1, W_2) - phi_batch_grad = reduce_list(phi_grads_list, MSE_reduction) - model.zero_grad() - loss.backward(retain_graph=True) # Retain graph for computing Hessian later - loss_grad = list(model.parameters())[1].grad - check_sizes_and_values(loss_grad, phi_batch_grad, atol=ATOL, rtol=RTOL) - - # ========================== - # TEST 3: Loss GGN - # ========================== - phi_GGNs_list = Phi_GGNs_list(X, theta, MSE_reduction, W_1, W_2) - phi_batch_GGN = reduce_list(phi_GGNs_list, MSE_reduction) - theta_params = list(model.parameters())[1] - loss_GGN = autograd_hessian(loss, [theta_params]) - check_sizes_and_values(loss_GGN, phi_batch_GGN, atol=ATOL, rtol=RTOL) - - # Go through all eigenvectors and compute lambdas and gammas - eigvals, eigvecs = torch.symeig(phi_batch_GGN, eigenvectors=True) - phi_lambdas = torch.zeros(N, OUT_1) - phi_gammas = torch.zeros(N, OUT_1) - for i in range(N): - phi_grad = phi_grads_list[i] - phi_GGN = phi_GGNs_list[i] - for j in range(OUT_1): - eigvec = eigvecs[:, j] - - # Compute gammas and lambdas - phi_gammas[i, j] = torch.dot(eigvec, phi_grad).item() - phi_lambdas[i, j] 
= torch.dot(eigvec @ phi_GGN, eigvec).item() - - # Now, compute lambdas and gammas with vivit-utilities - top_k = make_criterion(k=OUT_1) - param_groups = [ - { - "params": [p for p in model.parameters() if p.requires_grad], - "criterion": top_k, - } - ] - computations = BaseComputations() - savefield = "test_newton_step" - const_damping = ConstantDamping(delta) - - # Forward and backward pass - loss = loss_func(model(X), torch.zeros(N, OUT_2)) - with backpack( - *computations.get_extensions(param_groups), - extension_hook=computations.get_extension_hook( - param_groups, const_damping, savefield - ), - ): - loss.backward() - - # ========================== - # Test 4: gammas - # ========================== - gammas_abs = torch.abs(list(computations._gram_computation._gammas.values())[0]) - check_sizes_and_values(gammas_abs, torch.abs(phi_gammas), atol=ATOL, rtol=RTOL) - - # ========================== - # Test 5: lambdas - # ========================== - lambdas = list(computations._gram_computation._lambdas.values())[0] - check_sizes_and_values(lambdas, phi_lambdas, atol=ATOL, rtol=RTOL) - - # ========================== - # Test 6: Newton step - # ========================== - newton_step = [ - [getattr(param, savefield) for param in group["params"]] - for group in param_groups - ][0][0] - damped_GGN = phi_batch_GGN + delta * torch.eye(OUT_1) - phi_newton_step = torch.solve(-phi_batch_grad.reshape(OUT_1, 1), damped_GGN) - phi_newton_step = phi_newton_step.solution.reshape(-1) - check_sizes_and_values(newton_step, phi_newton_step, atol=ATOL, rtol=RTOL) diff --git a/test/utils.py b/test/utils.py index 386174a..3c8919e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -62,7 +62,9 @@ def check_sizes(*plists): for params in zip(*plists): for i in range(len(params) - 1): - assert params[i].size() == params[i + 1].size() + assert ( + params[i].size() == params[i + 1].size() + ), f"{params[i].size()} vs. 
{params[i + 1].size()}" def check_values(list1, list2, atol=atol, rtol=rtol): diff --git a/test/utils/test_subsampling.py b/test/utils/test_subsampling.py deleted file mode 100644 index 9db0621..0000000 --- a/test/utils/test_subsampling.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Test subsampling utilities.""" -import pytest -from backpack.extensions import BatchGrad, SqrtGGNExact - -from vivit.utils.subsampling import ( - merge_extensions, - merge_multiple_subsamplings, - merge_subsamplings, - sample_output_mapping, -) - - -def test_sample_output_mapping(): - """Test mapping from samples to output indices.""" - assert sample_output_mapping(None, None) is None - - assert sample_output_mapping([0, 1], None) == [0, 1] - assert sample_output_mapping([0, 1], [2, 1, 0]) == [2, 1] - - assert sample_output_mapping([0, 0], [2, 1, 0]) == [2, 2] - assert sample_output_mapping([2, 0], [2, 1, 0]) == [0, 2] - - with pytest.raises(ValueError): - sample_output_mapping([2, 1, 0], [0, 1]) - - with pytest.raises(ValueError): - sample_output_mapping(None, [0, 1]) - - -def test_merge_subsamplings(): - """Test merging of sub-samplings.""" - assert merge_subsamplings(None, None) is None - assert merge_subsamplings(None, [0, 1]) is None - - assert merge_subsamplings([0, 1], [2, 3]) == [0, 1, 2, 3] - - assert merge_subsamplings([0, 1], [0, 1]) == [0, 1] - - assert merge_subsamplings([0, 3, 1], [7, 0]) == [0, 1, 3, 7] - - -def test_multiple_subsamplings(): - """Test merging of multiple sub-samplings.""" - assert merge_multiple_subsamplings([1, 0]) == [0, 1] - - assert merge_multiple_subsamplings([0, 0, 4, 2], [0, 0, 8], [4, 2]) == [0, 2, 4, 8] - - assert merge_multiple_subsamplings([0, 0, 4, 2], None, [4, 2]) is None - - assert merge_multiple_subsamplings(None, [0, 1], None) is None - - with pytest.raises(ValueError): - merge_multiple_subsamplings() - - -def test_merge_extensions(): - """Test merging of sub-sampled extensions.""" - - assert merge_extensions( - [(BatchGrad, None), (BatchGrad, [0, 1]), (SqrtGGNExact, [0])] - ) == {BatchGrad: None, SqrtGGNExact: [0]} - - assert merge_extensions( - [(BatchGrad, [2, 5, 0, 0]), (BatchGrad, [0, 1]), (SqrtGGNExact, [1, 0])] - ) == { - BatchGrad: [0, 1, 2, 5], - SqrtGGNExact: [0, 1], - } - - assert merge_extensions( - [(BatchGrad, [2, 5, 0, 0]), (BatchGrad, [0, 1]), (BatchGrad, [1, 1, 0])] - ) == { - BatchGrad: [0, 1, 2, 5], - } diff --git a/vivit/optim/__init__.py b/vivit/optim/__init__.py index f5e8b83..c548c24 100644 --- a/vivit/optim/__init__.py +++ b/vivit/optim/__init__.py @@ -1,13 +1,7 @@ """Optimization methods using low-rank representations of the GGN/Fisher.""" -from vivit.optim.computations import BaseComputations -from vivit.optim.damped_newton import DampedNewton -from vivit.optim.damping import ConstantDamping -from vivit.optim.gram_computations import GramComputations +from vivit.optim.directional_derivatives import DirectionalDerivativesComputation __all__ = [ - "DampedNewton", - "ConstantDamping", - "BaseComputations", - "GramComputations", + "DirectionalDerivativesComputation", ] diff --git a/vivit/optim/base.py b/vivit/optim/base.py deleted file mode 100644 index 49e125d..0000000 --- a/vivit/optim/base.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Base class for optimizers that use a closure with BackPACK extensions.""" - -from typing import Callable - -from torch import Tensor -from torch.optim import Optimizer - - -class BackpackOptimizer(Optimizer): - """Base class for optimizers that use a closure with BackPACK extensions. 
- - Note: - For better control of the backward pass, the closure has different - responsibilities in comparison to the official documentation - (https://pytorch.org/docs/stable/optim.html): It only performs a forward - pass and returns the loss. This optimizer class needs to take care of - clearing the gradients performing the backward pass. - """ - - def step(self, closure: Callable[[], Tensor]): - """Perform a singel optimization step (parameter update). - - Args: - closure: Function that evaluates the model and returns the loss. - - Raises: - NotImplementedError: Must be implemented by subclasses. - """ - raise NotImplementedError diff --git a/vivit/optim/computations.py b/vivit/optim/computations.py deleted file mode 100644 index 1414753..0000000 --- a/vivit/optim/computations.py +++ /dev/null @@ -1,487 +0,0 @@ -"""Handle damped Newton step computations after and during backpropagation.""" - -import math -from functools import partial - -from backpack.extensions import SqrtGGNExact - -from vivit.optim.damping import _DirectionalCoefficients -from vivit.optim.gram_computations import GramComputations -from vivit.utils.ggn import V_mat_prod -from vivit.utils.hooks import ParameterGroupsHook - - -class BaseComputations: - """Base class for assigning mini-batch samples in a mini-batch to computations. - - The algorithms rely on three fundamental steps, to which samples may be assigned: - - - Computing the Newton directions ``{e[d]}``. - - Computing the first-order derivatives ``{γ[n,d]}`` along the directions. - - Computing the second-order derivatives ``{λ[n,d]}`` along the directions. - - The three mini-batch subsets used for each task need not be disjoint. - """ - - def __init__( - self, - subsampling_directions=None, - subsampling_first=None, - subsampling_second=None, - extension_cls_directions=SqrtGGNExact, - extension_cls_second=SqrtGGNExact, - verbose=False, - ): - """Store indices of samples used for each task. - - Args: - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int] or None): Indices of samples used to compute first- - order directional derivatives along the Newton directions. If ``None``, - all samples in the batch will be used. - subsampling_second ([int] or None): Indices of samples used to compute - second-order directional derivatives along the Newton directions. If - ``None``, all samples in the batch will be used. - extension_cls_directions (backpack.backprop_extension.BackpropExtension): - BackPACK extension class used to compute descent directions. - extension_cls_second (backpack.backprop_extension.BackpropExtension): - BackPACK extension class used to compute second-order directional - derivatives. - verbose (bool, optional): Turn on verbose mode. Default: ``False``. - """ - self._gram_computation = GramComputations( - subsampling_directions=subsampling_directions, - subsampling_first=subsampling_first, - subsampling_second=subsampling_second, - extension_cls_directions=extension_cls_directions, - extension_cls_second=extension_cls_second, - verbose=verbose, - ) - self._verbose = verbose - - # filled via side effects during update step computation, keys are group ids - self._coefficients = {} - self._newton_step = {} - - def get_extensions(self, param_groups): - """Return the instantiated BackPACK extensions required in the backward pass. - - Args: - param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. 
- - Returns: - [backpack.extensions.backprop_extension.BackpropExtension]: List of - extensions that can be handed into a ``with backpack(...)`` context. - """ - return self._gram_computation.get_extensions(param_groups) - - def get_extension_hook( - self, - param_groups, - coefficients: _DirectionalCoefficients, - savefield, - keep_gram_mat=True, - keep_gram_evals=True, - keep_gram_evecs=True, - keep_gammas=True, - keep_lambdas=True, - keep_batch_size=True, - keep_coefficients: bool = True, - keep_newton_step=True, - keep_backpack_buffers=True, - ): - """Return hook to be executed right after a BackPACK extension during backprop. - - Args: - param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. - coefficients: Instance for computing Newton step coefficients from first- - and second-order directional derivatives. - savefield (str): Name of the attribute created in the parameters. - keep_gram_mat (bool, optional): Keep buffers for Gram matrix under group id - in ``self._gram_computation._gram_mat``. Default: ``True`` - keep_gram_evals (bool, optional): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_computation._gram_evals``. - Default: ``True`` - keep_gram_evecs (bool, optional): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_computation._gram_evecs``. - Default: ``True`` - keep_gammas (bool, optional): Keep buffers for first-order directional - derivatives under group id in ``self._gram_computation._gammas``. - Default: ``True`` - keep_lambdas (bool, optional): Keep buffers for second-order directional - derivatives under group id in ``self._gram_computation._lambdas``. - Default: ``True`` - keep_batch_size (bool, optional): Keep batch size for under group id - in ``self._gram_computation._lambdas``. Default: ``True`` - keep_coefficients: Keep Newton step coefficients under group id in - ``self._coefficients``. Default: ``True``. - keep_newton_step (bool, optional): Keep damped Newton step under group id - in ``self._newton_step``. Default: ``True``. - keep_backpack_buffers (bool, optional): Keep buffers from used BackPACK - extensions during backpropagation. Default: ``True``. - - Returns: - callable or None: Hook function that can be handed into a - ``with backpack(...)`` context. ``None`` signifies no action will be - performed. - """ - hook_store_batch_size = self._gram_computation._get_hook_store_batch_size( - param_groups - ) - - param_computation = self.get_param_computation() - group_hook = self.get_group_hook( - coefficients, - savefield, - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - keep_coefficients=keep_coefficients, - keep_newton_step=keep_newton_step, - keep_backpack_buffers=keep_backpack_buffers, - ) - accumulate = self.get_accumulate() - - hook = ParameterGroupsHook.from_functions( - param_groups, param_computation, group_hook, accumulate - ) - - def extension_hook(module): - """Extension hook executed right after BackPACK extensions during backprop. - - Chains together all the required computations. - - Args: - module (torch.nn.Module): Layer on which the hook is executed. 
- """ - if self._verbose: - print(f"Extension hook on module {id(module)} {module}") - hook_store_batch_size(module) - hook(module) - - if self._verbose: - print("ID map groups → params") - for group in param_groups: - print(f"{id(group)} → {[id(p) for p in group['params']]}") - - return extension_hook - - def get_param_computation(self): - """Set up the ``param_computation`` function of the ``ParameterGroupsHook``. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. - Performs an action on the accumulated results over parameters for a - group. - """ - return self._gram_computation.get_param_computation(keep_backpack_buffers=True) - - def get_group_hook( - self, - coefficients: _DirectionalCoefficients, - savefield, - keep_gram_mat, - keep_gram_evals, - keep_gram_evecs, - keep_gammas, - keep_lambdas, - keep_batch_size, - keep_coefficients: bool, - keep_newton_step, - keep_backpack_buffers, - ): - """Set up the ``group_hook`` function of the ``ParameterGroupsHook``. - - Args: - coefficients: Instance for computing Newton step coefficients from first- - and second-order directional derivatives. - savefield (str): Name of the attribute created in the parameters. - keep_gram_mat (bool): Keep buffers for Gram matrix under group id in - ``self._gram_computation._gram_mat``. - keep_gram_evals (bool): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_computation._gram_evals``. - keep_gram_evecs (bool): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_computation._gram_evecs``. - keep_gammas (bool): Keep buffers for first-order directional - derivatives under group id in ``self._gram_computation._gammas``. - keep_lambdas (bool): Keep buffers for second-order directional - derivatives under group id in ``self._gram_computation._lambdas``. - keep_batch_size (bool): Keep batch size for under group id - in ``self._gram_computation._lambdas``. Default: ``True`` - keep_coefficients: Keep Newton step coefficients under group id in - ``self._coefficients``. - keep_newton_step (bool): Keep damped Newton step under group id - in ``self._newton_step``. - keep_backpack_buffers (bool): Keep buffers from used BackPACK - extensions during backpropagation. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. - Performs an action on the accumulated results over parameters for a - group. - """ - group_hook_gram = self._gram_computation.get_group_hook( - keep_gram_mat=True, - keep_gram_evals=True, - keep_gram_evecs=True, - keep_gammas=True, - keep_lambdas=True, - keep_batch_size=True, - ) - group_hook_newton_step = partial( - self._group_hook_newton_step, coefficients=coefficients - ) - group_hook_load_to_params = partial( - self._group_hook_load_to_params, savefield=savefield - ) - group_hook_memory_cleanup = partial( - self._group_hook_memory_cleanup, - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - keep_coefficients=keep_coefficients, - keep_newton_step=keep_newton_step, - keep_backpack_buffers=keep_backpack_buffers, - ) - - def group_hook(self, accumulation, group): - """Compute Newton step, load to parameter, clean up. - - Args: - self (ParameterGroupsHook): Group hook to which this function will be - bound. - accumulation (dict): Accumulated dot products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. 
- """ - group_hook_gram(self, accumulation, group) - group_hook_newton_step(accumulation, group) - group_hook_load_to_params(accumulation, group) - group_hook_memory_cleanup(accumulation, group) - - return group_hook - - def get_accumulate(self): - """Set up the ``accumulate`` function of the ``ParameterGroupsHook``. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. - Accumulates the parameter computations. - """ - return self._gram_computation.get_accumulate() - - # group hooks - - def _group_hook_memory_cleanup( - self, - accumulation, - group, - keep_gram_mat, - keep_gram_evals, - keep_gram_evecs, - keep_gammas, - keep_lambdas, - keep_batch_size, - keep_coefficients: bool, - keep_newton_step, - keep_backpack_buffers, - ): - """Free cached information for an optimizer group. - - Modifies temporary buffers. - - Args: - accumulation (dict): Dictionary with accumulated information. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - keep_gram_mat (bool): Keep buffers for Gram matrix under group id in - ``self._gram_computation._gram_mat``. - keep_gram_evals (bool): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_computation._gram_evals``. - keep_gram_evecs (bool): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_computation._gram_evecs``. - keep_gammas (bool): Keep buffers for first-order directional - derivatives under group id in ``self._gram_computation._gammas``. - keep_lambdas (bool): Keep buffers for second-order directional - derivatives under group id in ``self._gram_computation._lambdas``. - keep_batch_size (bool): Keep batch size for under group id - in ``self._gram_computation._lambdas``. Default: ``True`` - keep_coefficients: Keep Newton step coefficients under group id in - ``self._coefficients``. - keep_newton_step (bool): Keep damped Newton step under group id - in ``self._newton_step``. - keep_backpack_buffers (bool): Keep buffers from used BackPACK - extensions during backpropagation. - """ - self._gram_computation._group_hook_memory_cleanup( - accumulation, - group, - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - ) - - savefields = { - self._gram_computation._savefield_directions, - self._gram_computation._savefield_first, - self._gram_computation._savefield_second, - } - - if not keep_backpack_buffers: - for param in group["params"]: - for savefield in savefields: - - if self._verbose: - print(f"Param {id(param)}: Delete '{savefield}'") - - delattr(param, savefield) - - buffers = [] - - if not keep_newton_step: - buffers.append("_newton_step") - - if not keep_coefficients: - buffers.append("_coefficients") - - group_id = id(group) - for b in buffers: - - if self._verbose: - print(f"Group {group_id}: Delete '{b}'") - - getattr(self, b).pop(group_id) - - def _group_hook_newton_step( - self, accumulation, group, coefficients: _DirectionalCoefficients - ): - """Evaluate the damped Newton update. - - Sets the following entries under the id of ``group``: - - - In ``self._coefficients``: Newton step coefficients. - - In ``self._newton_step``: Damped Newton step. - - Args: - accumulation (dict): Dictionary with accumulated information. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - coefficients: Instance for computing Newton step coefficients. 
- """ - group_id = id(group) - - gram_evals = self._gram_computation._gram_evals[group_id] - gram_evecs = self._gram_computation._gram_evecs[group_id] - gammas = self._gram_computation._gammas[group_id] - lambdas = self._gram_computation._lambdas[group_id] - N_dir = ( - self._gram_computation._batch_size[group_id] - if self._gram_computation._subsampling_directions is None - else len(self._gram_computation._subsampling_directions) - ) - C_dir = gram_evecs.shape[0] // N_dir - V_mp = self._get_V_mat_prod(group) - - newton_coefficients = coefficients.compute_coefficients(gammas, lambdas) - self._coefficients[group_id] = newton_coefficients - - if self._verbose: - print(f"Group {id(group)}: Store '_coefficients'") - """ - Don't expand directions in parameter space. Instead, use - - ``eₖ = V ẽₖ / √λₖ`` - - to perform the summation over ``k`` in the Gram space, - - ``∑ₖ cₖ eₖ = V [∑ₖ (cₖ / √λₖ) ẽₖ]``. - """ - gram_step = (newton_coefficients / gram_evals.sqrt() * gram_evecs).sum(1) - gram_step = gram_step.reshape(1, C_dir, N_dir) - newton_step = [V_g.squeeze(0) for V_g in V_mp(gram_step)] - - # compensate scale of V - N = self._gram_computation._batch_size[group_id] - newton_step = [math.sqrt(N / N_dir) * step for step in newton_step] - - self._newton_step[group_id] = newton_step - - if self._verbose: - print(f"Group {id(group)}: Store '_newton_step'") - - def _group_hook_load_to_params(self, accumulation, group, savefield): - """Copy the damped Newton step to the group parameters. - - Creates a ``savefield`` attribute in each parameter of ``group``. - - Args: - accumulation (dict): Dictionary with accumulated information. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - savefield (str): Name of the attribute created in the parameters. - """ - group_id = id(group) - - params = group["params"] - newton_step = self._newton_step[group_id] - - for param, newton in zip(params, newton_step): - self._save_to_param(param, newton, savefield) - - def _get_V_mat_prod(self, group): - """Get multiplication with curvature matrix square root used by directions. - - Args: - group (dict): Parameter group of a ``torch.optim.Optimizer``. - - Returns: - function: Vectorized multiplication with curvature matrix square root ``V``. - """ - return partial( - V_mat_prod, - parameters=group["params"], - savefield=self._gram_computation._savefield_directions, - subsampling=self._gram_computation._access_directions, - ) - - def _load_newton_step_to_params(self, group, savefield): - """Copy the damped Newton step to the group parameters. - - Must be called after ``self._eval_newton``. - - Creates a ``savefield`` attribute in each parameter of ``group``. - - Args: - group (dict): Parameter group of a ``torch.optim.Optimizer``. - savefield (str): Name of the attribute created in the parameters. - """ - group_id = id(group) - - params = group["params"] - newton_step = self._newton_step[group_id] - - for param, newton in zip(params, newton_step): - self._save_to_param(param, newton, savefield) - - @staticmethod - def _save_to_param(param, value, savefield): - """Save ``value`` in ``param`` under ``savefield``. - - Args: - param (torch.nn.Parameter): Parameter to which ``value`` is attached. - value (any): Saved quantity. - savefield (str): Name of the attribute to save ``value`` in. - - Raises: - ValueError: If the attribute field is already occupied. 
- """ - if hasattr(param, savefield): - raise ValueError(f"Savefield {savefield} already exists.") - else: - setattr(param, savefield, value) diff --git a/vivit/optim/damped_newton.py b/vivit/optim/damped_newton.py deleted file mode 100644 index 60de979..0000000 --- a/vivit/optim/damped_newton.py +++ /dev/null @@ -1,210 +0,0 @@ -"""PyTorch optimizer with damped Newton updates.""" - -from typing import Callable, List - -from backpack import backpack -from backpack.extensions.backprop_extension import BackpropExtension -from torch import Tensor -from torch.optim import Optimizer - -from vivit.optim.computations import BaseComputations -from vivit.optim.damping import _DirectionalCoefficients - - -class DampedNewton(Optimizer): - """ - Newton optimizer damped via bootstrapped 1st- and 2nd-order directional derivatives. - - Attributes: - SAVEFIELD: Field under which the damped Newton update is stored in a parameter. - """ - - SAVEFIELD: str = "damped_newton_step" - - def __init__( - self, - parameters: List[Tensor], - coefficients: _DirectionalCoefficients, - computations: BaseComputations, - criterion: Callable[[Tensor], List[int]], - ): - """Initialize the optimizer, specifying the damping damping and sample split. - - Args: - parameters: List of parameters to be trained. - coefficients: Policy for computing Newton step coefficients from first- - and second- order directional derivatives. - computations: Assignment of mini-batch samples to the different - computational tasks (finding directions, computing first- and - second-order derivatives along them). - criterion: Maps eigenvalues to indices of eigenvalues that are - kept as directions. Assumes eigenvalues to be sorted in ascending order. - """ - defaults = {"criterion": criterion} - super().__init__(parameters, defaults=defaults) - - self._coefficients = coefficients - self._computations = computations - - def get_extensions(self) -> List[BackpropExtension]: - """Return the required extensions for BackPACK. - - They can directly be placed inside a ``with backpack(...)`` context. - - Returns: - List of extensions that can be handed into a ``with backpack(...)`` context. - """ - return self._computations.get_extensions(self.param_groups) - - def get_extension_hook( - self, - keep_gram_mat=False, - keep_gram_evals=False, - keep_gram_evecs=False, - keep_gammas=False, - keep_lambdas=False, - keep_batch_size=False, - keep_coefficients: bool = False, - keep_newton_step=False, - keep_backpack_buffers=False, - ): - """Return hook to be executed right after a BackPACK extension during backprop. - - Args: - keep_gram_mat (bool, optional): Keep buffers for Gram matrix under group id - in ``self._computations._gram_computation._gram_mat``. - Default: ``False`` - keep_gram_evals (bool, optional): Keep buffers for filtered Gram matrix - eigenvalues under group id in - ``self._computations._gram_computation._gram_evals``. Default: ``False`` - keep_gram_evecs (bool, optional): Keep buffers for filtered Gram matrix - eigenvectors under group id in - ``self._computations._gram_computation._gram_evecs``. Default: ``False`` - keep_gammas (bool, optional): Keep buffers for first-order directional - derivatives under group id in - ``self._computations._gram_computation._gammas``. Default: ``False`` - keep_lambdas (bool, optional): Keep buffers for second-order directional - derivatives under group id in - ``self._computations._gram_computation._lambdas``. 
Default: ``False`` - keep_batch_size (bool, optional): Keep batch size for under group id - in ``self._computations._gram_computation._lambdas``. Default: ``False`` - keep_coefficients: Keep Newton step coefficients under group id in - ``self._computations._coefficients``. Default: ``False``. - keep_newton_step (bool, optional): Keep damped Newton step under group id - in ``self._computations._newton_step``. Default: ``False``. - keep_backpack_buffers (bool, optional): Keep buffers from used BackPACK - extensions during backpropagation. Default: ``False``. - - Returns: - callable or None: Hook function that can be handed into a - ``with backpack(...)`` context. ``None`` signifies no action will be - performed. - """ - return self._computations.get_extension_hook( - self.param_groups, - self._coefficients, - self.SAVEFIELD, - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - keep_coefficients=keep_coefficients, - keep_newton_step=keep_newton_step, - keep_backpack_buffers=keep_backpack_buffers, - ) - - def step( - self, - closure: Callable[[], Tensor] = None, - lr: float = 1.0, - keep_gram_mat: bool = False, - keep_gram_evals: bool = False, - keep_gram_evecs: bool = False, - keep_gammas: bool = False, - keep_lambdas: bool = False, - keep_batch_size: bool = False, - keep_coefficients: bool = False, - keep_newton_step: bool = False, - keep_backpack_buffers: bool = False, - ): - """Apply damped Newton step to all parameters. - - Modifies the ``.data`` entry of each parameter. - - Args: - closure: Function to reevaluate the model and return the loss. This - function should only perform the forward pass, BUT NOT the additional - steps outlined in https://pytorch.org/docs/stable/optim.html. - lr: Learning rate. The Newton step is scaled by this value before - it is applied to the network parameters. The default value is ``1.0``. - keep_gram_mat: (only relevant if closure us passed) Keep buffers for Gram - matrix under group id in - ``self._computations._gram_computation._gram_mat``. Default: ``False`` - keep_gram_evals: (only relevant if closure us passed) Keep buffers for - filtered Gram matrix eigenvalues under group id in - ``self._computations._gram_computation._gram_evals``. Default: ``False`` - keep_gram_evecs: (only relevant if closure us passed) Keep buffers for - filtered Gram matrix eigenvectors under group id in - ``self._computations._gram_computation._gram_evecs``. Default: ``False`` - keep_gammas: (only relevant if closure us passed) Keep buffers for - first-order directional derivatives under group id in - ``self._computations._gram_computation._gammas``. Default: ``False`` - keep_lambdas: (only relevant if closure us passed) Keep buffers for - second-order directional derivatives under group id in - ``self._computations._gram_computation._lambdas``. Default: ``False`` - keep_batch_size: (only relevant if closure us passed) Keep batch size for - under group id in ``self._computations._gram_computation._lambdas``. - Default: ``False`` - keep_coefficients: Keep Newton step coefficients under group id in - ``self._computations._coefficients``. Default: ``False``. - keep_newton_step: (only relevant if closure us passed) Keep damped Newton - step under group id in ``self._computations._newton_step``. - Default: ``False``. 
- keep_backpack_buffers: (only relevant if closure us passed) Keep buffers - from used BackPACK extensions during backpropagation. Default: - ``False``. - """ - if closure is not None: - self.zero_grad() - self.zero_newton() - loss = closure() - extensions = self.get_extensions() - hook = self.get_extension_hook( - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - keep_coefficients=keep_coefficients, - keep_newton_step=keep_newton_step, - keep_backpack_buffers=keep_backpack_buffers, - ) - with backpack(*extensions, extension_hook=hook): - loss.backward() - - for group in self.param_groups: - self.step_group(group, lr) - - def step_group(self, group, lr=1.0): - """Apply damped Newton step to a parameter group. - - Modifies the ``.data`` entry of each group parameter. - - Args: - group (dict): Parameter group. Entry of a ``torch.optim.Optimizer``'s - ``param_groups`` list. - lr (float): Learning rate. The Newton step is scaled by this value before - it is applied to the network parameters. The default value is ``1.0``. - """ - for param in group["params"]: - param.data.add_(getattr(param, self.SAVEFIELD), alpha=lr) - - def zero_newton(self): - """Delete the parameter attributes used to store the Newton steps.""" - for group in self.param_groups: - for param in group["params"]: - if hasattr(param, self.SAVEFIELD): - delattr(param, self.SAVEFIELD) diff --git a/vivit/optim/damping.py b/vivit/optim/damping.py deleted file mode 100644 index ecd5c0b..0000000 --- a/vivit/optim/damping.py +++ /dev/null @@ -1,453 +0,0 @@ -"""Damping policies from first- and second-order directional derivatives.""" - -from typing import Dict, Tuple - -import torch -from torch import Tensor - - -class _DirectionalCoefficients: - """Base class defining the interface for computing Newton step coefficients.""" - - def compute_coefficients( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - """Compute the Newton step coefficients. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the - gradient projections ``γ[n, d]`` of sample ``n`` along direction ``d``. - second_derivatives: 2d tensor of shape ``[N₂, D]`` with the - curvature projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: # noqa: DAR202 - 1d tensor of shape ``[D]`` with coefficients ``c[d]`` along direction ``d``. - - Raises: - NotImplementedError: Must be implemented by a child class. - """ - raise NotImplementedError - - -class _Damping(_DirectionalCoefficients): - """Base class for policies to determine the damping parameter. - - To create a new damping policy, the following methods need to be implemented by - a child class: - - - ``__call__`` - - """ - - def __init__(self, save_history: bool = False): - """Initialize damping, enable saving of previously computed values. - - Args: - save_history: Whether to store the computed dampings. Default: ``False``. - Only use this option if you need access to the damping values (e.g. - for logging). 
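For reference, a rough sketch of how the removed ``DampedNewton`` was driven via a closure (assumptions: ``BaseComputations`` can be constructed with its defaults, and the toy model, data, and criterion below are made up for illustration):

from backpack import extend
from torch import nn, rand, randint

from vivit.optim.computations import BaseComputations
from vivit.optim.damped_newton import DampedNewton
from vivit.optim.damping import ConstantDamping

model = extend(nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)))
loss_func = extend(nn.CrossEntropyLoss())
X, y = rand(8, 10), randint(0, 3, (8,))

opt = DampedNewton(
    list(model.parameters()),
    coefficients=ConstantDamping(damping=1.0),
    computations=BaseComputations(),
    criterion=lambda evals: [evals.numel() - 1],  # keep only the leading direction
)

def closure():
    # forward pass only; ``step`` runs the extended backward pass itself
    return loss_func(model(X), y)

opt.step(closure=closure, lr=0.1)  # scales the damped Newton step by lr, adds it to .data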
- """ - self._save_history = save_history - self._history: Dict[Tuple[int, int], Tensor] = {} - - def compute_coefficients( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - """Compute Newton step coefficients ``cₖ = - γₖ / (λₖ + δₖ)``. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the - gradient projections ``γ[n, d]`` of sample ``n`` along direction ``d``. - second_derivatives: 2d tensor of shape ``[N₂, D]`` with the - curvature projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: - 1d tensor of shape ``[D]`` with coefficients ``c[d]`` along direction ``d``. - """ - batch_axis = 0 - gammas_mean = first_derivatives.mean(batch_axis) - lambdas_mean = second_derivatives.mean(batch_axis) - - deltas = self.__call__(first_derivatives, second_derivatives) - - return -gammas_mean / (lambdas_mean + deltas) - - def __call__(self, first_derivatives: Tensor, second_derivatives: Tensor) -> Tensor: - """Determine damping parameter for each direction. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the - gradient projections ``γ[n, d]`` of sample ``n`` along direction ``d``. - second_derivatives: 2d tensor of shape ``[N₂, D]`` with the - curvature projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: - 1d tensor of shape ``[D]`` with dampings ``δ[d]`` along direction ``d``. - """ - damping = self.compute_damping(first_derivatives, second_derivatives) - - if self._save_history: - key = (id(first_derivatives), id(second_derivatives)) - self._history[key] = damping - - return damping - - def get_from_history( - self, first_derivatives: Tensor, second_derivatives: Tensor, pop: bool = False - ) -> Tensor: - """Load previously computed damping values from history. - - Args: - first_derivatives: First input used for damping in ``compute_damping``. - second_derivatives: Second input used for damping in ``compute_damping``. - pop: Whether to pop the returned value from the internal saved ones. - Default: ``False``. - - Returns: - Damping value from history. - """ - key = (id(first_derivatives), id(second_derivatives)) - - return self._history.pop(key) if pop else self._history[key] - - def compute_damping( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - """Compute the damping for each direction. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the gradient - projections ``γ[n, d]`` of sample ``n`` along direction ``d``. - second_derivatives: 2d tensor of shape ``[N₂, D]`` with the curvature - projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: # noqa: DAR202 - 1d tensor of shape ``[D]`` with dampings ``δ[d]`` along direction ``d``. - - Raises: - NotImplementedError: Must be implemented by a child class. - """ - raise NotImplementedError - - -class ConstantDamping(_Damping): - """Constant isotropic damping.""" - - def __init__(self, damping: float = 1.0, save_history: bool = False): - """Store damping constant. 
- - Args: - damping: Damping constant. Default value uses ``1.0``. - save_history: Whether to store the computed dampings. Default: ``False``. - Only use this option if you need access to the damping values (e.g. - for logging). - """ - super().__init__(save_history=save_history) - - self._damping = damping - - def compute_damping( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - num_directions = first_derivatives.shape[1] - device = first_derivatives.device - - return self._damping * torch.ones(num_directions, device=device) - - -class BootstrapDamping(_Damping): - """Adaptive damping, uses Bootstrap to generate gain samples.""" - - DEFAULT_DAMPING_GRID = torch.logspace(-3, 2, 100) - - def __init__( - self, - damping_grid: Tensor = None, - num_resamples: int = 100, - percentile: float = 95.0, - save_history: bool = False, - ): - """Store ``damping_grid``, ``num_resamples`` and ``percentile``. - - Args: - damping_grid: The Bootstrap generates gain samples for all damping values - in ``damping_grid``. Default is a log-equidistant grid between - ``1e-3`` and ``1e2``. - num_resamples: Number of gain samples that are generated using the - Bootstrap. The default value is ``100``. - percentile: Policy for delta finds a curve (among the Bootstrap gain - samples), such that ``percentile`` percent of the gain samples lie - above it. The default value is ``95.0``. - save_history: Whether to store the computed dampings. Default: ``False``. - Only use this option if you need access to the damping values (e.g. - for logging). - """ - super().__init__(save_history=save_history) - - self._damping_grid = ( - damping_grid if damping_grid is not None else self.DEFAULT_DAMPING_GRID - ) - self._num_resamples = num_resamples - self._percentile = percentile - - def _resample(self, sample): - """Create resample of ``sample``. - - Args: - sample (torch.Tensor): 1d ``torch.Tensor`` - - Returns: - torch.Tensor: A 1d ``torch.Tensor`` whose size is the same as ``sample`` - and whose entries are sampled with replacement from ``sample``. - """ - - N = len(sample) - return sample[torch.randint(low=0, high=N, size=(N,))] - - def _delta_policy(self, gains): - """Compute damping based on gains generated by the Bootstrap. - - Args: - gains (torch.Tensor): 2d ``torch.Tensor`` of shape ``[num_resamples, - num_dampings]``, i.e. each row corresponds to one gain resample, - where the gain is evaluated for all dampings in ``damping_grid``. - - Returns: - float or float("inf"): The "optimal" damping. In case no reasonable - damping is found, it will return ``float("inf")``. - """ - - # Compute gain percentile - q = 1 - self._percentile / 100.0 - gain_perc = torch.quantile(gains, q, dim=0) - - # Filter for positive entries in gain_perc - ge_zero = gain_perc >= 0 - if torch.any(ge_zero): - damping_grid_filtered = self._damping_grid[ge_zero] - gain_perc_filtered = gain_perc[ge_zero] - max_idx = torch.argmax(gain_perc_filtered) - return damping_grid_filtered[max_idx] - else: - return float("inf") - - def compute_damping( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - """Determine damping parameter for each direction. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the - gradient projections ``γ[n, d]`` of sample ``n`` along direction ``d``. 
- second_derivatives: 2d tensor of shape ``[N₂, D]`` with the - curvature projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: - 1d tensor of shape ``[D]`` with dampings ``δ[d]`` along direction ``d``. - """ - D = first_derivatives.shape[1] - num_dampings = len(self._damping_grid) - device = first_derivatives.device - - self._damping_grid = self._damping_grid.to(device) - - # Vector for dampings for each direction - dampings = torch.zeros(D, device=device) - - for D_idx in range(D): - - # Extract first and second derivatives for current direction - first = first_derivatives[:, D_idx] - second = second_derivatives[:, D_idx] - - # Create gain samples for every delta in self._damping_grid - gains = torch.zeros(self._num_resamples, num_dampings).to(device) - - for resample_idx in range(self._num_resamples): - - # Resample gamma_hat and lambda_hat - gam_hat_re = torch.mean(self._resample(first)) - lam_hat_re = torch.mean(self._resample(second)) - - # Resample tau_hat - tau_hat_re = -torch.mean(self._resample(first)) / ( - torch.mean(self._resample(second)) + self._damping_grid - ) - - # Compute gain and store sample in gains - gain = -gam_hat_re * tau_hat_re - 0.5 * lam_hat_re * tau_hat_re**2 - gains[resample_idx, :] = gain - - # Compute damping based on gains - dampings[D_idx] = self._delta_policy(gains) - - return dampings - - -class BootstrapDamping2(_Damping): - """Adaptive damping, uses Bootstrap to generate gain samples. - - This version differs from ``BootstrapDamping`` with regard to two aspects: - - - First, we don't resample the Newton step ``tau_hat_re``, i.e. we assume this step - to be fixed and we only evaluate the corresponding gain for thsi step in different - (resampled) metrics. - - Second, when resampling gamma and lambda, we use the same resampling indices for - both vectors, because gamma and lambda may be correlated and we loose this - correlation, when we resample them independently. That means: We assume, that - the gammas and lambdas are evaluated on the same samples. - """ - - DEFAULT_DAMPING_GRID = torch.logspace(-3, 2, 100) - - def __init__( - self, - damping_grid: Tensor = None, - num_resamples: int = 100, - percentile: float = 95.0, - save_history: bool = False, - ): - """Store ``damping_grid``, ``num_resamples`` and ``percentile``. - - Args: - damping_grid: The Bootstrap generates gain samples - for all damping values in ``damping_grid``. Default is a log- - equidistant grid between ``1e-3`` and ``1e2``. - num_resamples: This is the number of gain samples that are - generated using the Bootstrap. The default value is ``100``. - percentile: The policy for delta finds a curve (among the - Bootstrap gain samples), such that ``percentile`` percent of the - gain samples lie above it. The default value is ``95.0``. - save_history: Whether to store the computed dampings. Default: ``False``. - Only use this option if you need access to the damping values (e.g. - for logging). - """ - super().__init__(save_history=save_history) - - self._damping_grid = ( - damping_grid if damping_grid is not None else self.DEFAULT_DAMPING_GRID - ) - self._num_resamples = num_resamples - self._percentile = percentile - - def _resample(self, sample): - """Create resample of ``sample``. - - Args: - sample (torch.Tensor): 1d ``torch.Tensor`` - - Returns: - torch.Tensor: A 1d ``torch.Tensor`` whose size is the same as ``sample`` - and whose entries are sampled with replacement from ``sample``. 
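Stripped of the class scaffolding, the bootstrap selection above boils down to the following loop for a single direction (standalone sketch with random inputs; grid, resample count, and percentile follow the documented defaults):

import torch

torch.manual_seed(0)
gammas = torch.randn(32)                   # γ[n] along one direction
lambdas = torch.rand(32) + 0.5             # λ[n] along the same direction
damping_grid = torch.logspace(-3, 2, 100)  # candidate δ values
num_resamples, percentile = 100, 95.0

def resample(x):
    return x[torch.randint(0, x.numel(), (x.numel(),))]  # draw with replacement

gains = torch.zeros(num_resamples, damping_grid.numel())
for r in range(num_resamples):
    gam_hat = resample(gammas).mean()
    lam_hat = resample(lambdas).mean()
    tau_hat = -resample(gammas).mean() / (resample(lambdas).mean() + damping_grid)
    gains[r] = -gam_hat * tau_hat - 0.5 * lam_hat * tau_hat**2  # quadratic-model gain

gain_perc = torch.quantile(gains, 1 - percentile / 100.0, dim=0)
viable = gain_perc >= 0
delta = damping_grid[viable][gain_perc[viable].argmax()] if viable.any() else float("inf")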
- """ - N = len(sample) - return sample[torch.randint(low=0, high=N, size=(N,))] - - def _delta_policy(self, gains): - """Compute damping based on gains generated by the Bootstrap. - - Args: - gains (torch.Tensor): 2d ``torch.Tensor`` of shape ``[num_resamples, - num_dampings]``, i.e. each row corresponds to one gain resample, - where the gain is evaluated for all dampings in ``damping_grid``. - - Returns: - float or float("inf"): The "optimal" damping. In case no reasonable - damping is found, it will return ``float("inf")``. - """ - # Compute gain percentile - q = 1 - self._percentile / 100.0 - gain_perc = torch.quantile(gains, q, dim=0) - - # Filter for positive entries in gain_perc - ge_zero = gain_perc >= 0 - if torch.any(ge_zero): - damping_grid_filtered = self._damping_grid[ge_zero] - gain_perc_filtered = gain_perc[ge_zero] - max_idx = torch.argmax(gain_perc_filtered) - return damping_grid_filtered[max_idx] - - # gain_median_filtered = torch.quantile(gains, 0.5, dim=0)[ge_zero] - # max_idx = torch.argmax(gain_median_filtered) - # return damping_grid_filtered[max_idx] - else: - return float("inf") - - def compute_damping( - self, first_derivatives: Tensor, second_derivatives: Tensor - ) -> Tensor: - """Determine damping parameter for each direction. - - Let ``N₁`` and ``N₂`` denote the number of samples used for computing first- - and second-order derivatives respectively. Let ``D`` be the number of - directions. - - Args: - first_derivatives: 2d tensor of shape ``[N₁,D]`` with the - gradient projections ``γ[n, d]`` of sample ``n`` along direction ``d``. - second_derivatives: 2d tensor of shape ``[N₂, D]`` with the - curvature projections ``λ[n, d]`` of sample ``n`` along direction ``d``. - - Returns: - 1d tensor of shape ``[D]`` with dampings ``δ[d]`` along direction ``d``. 
- """ - D = first_derivatives.shape[1] - num_dampings = len(self._damping_grid) - device = first_derivatives.device - - self._damping_grid = self._damping_grid.to(device) - - # Vector for dampings for each direction - dampings = torch.zeros(D, device=device) - - # Make sure that N_1 = N_2 - assert first_derivatives.shape[0] == second_derivatives.shape[0] - - for D_idx in range(D): - - # Extract first and second derivatives for current direction - first = first_derivatives[:, D_idx] - second = second_derivatives[:, D_idx] - - # Determine the step for this direction - step = -torch.mean(first) / (torch.mean(second) + self._damping_grid) - - # Create gain samples for every delta in self._damping_grid - gains = torch.zeros(self._num_resamples, num_dampings).to(device) - - for resample_idx in range(self._num_resamples): - - # Sample one index vector and evaluate both the gammas and lambdas - N = first.numel() # = second.numel() - rand_idx = torch.randint(low=0, high=N, size=(N,)) - gam_hat_re = torch.mean(first[rand_idx]) - lam_hat_re = torch.mean(second[rand_idx]) - - # Compute gain and store sample in gains - gain = -gam_hat_re * step - 0.5 * lam_hat_re * step**2 - gains[resample_idx, :] = gain - - # Compute damping based on gains - dampings[D_idx] = self._delta_policy(gains) - - return dampings diff --git a/vivit/optim/directional_derivatives.py b/vivit/optim/directional_derivatives.py new file mode 100644 index 0000000..21db334 --- /dev/null +++ b/vivit/optim/directional_derivatives.py @@ -0,0 +1,567 @@ +"""Manage computations in Gram space through extension hooks.""" + +import math +from typing import Callable, Dict, List, Optional + +import torch +from backpack.extensions import BatchGrad, SqrtGGNExact, SqrtGGNMC +from backpack.extensions.backprop_extension import BackpropExtension +from torch.nn import Module + +from vivit.optim.utils import get_sqrt_ggn_extension +from vivit.utils.checks import check_subsampling_unique +from vivit.utils.eig import stable_symeig +from vivit.utils.gram import partial_contract, reshape_as_square +from vivit.utils.hooks import ParameterGroupsHook + + +class DirectionalDerivativesComputation: + """Provide BackPACK extension and hook for 1ˢᵗ/2ⁿᵈ-order directional derivatives. + + The directions are given by the GGN eigenvectors. First-order directional + derivatives are denoted ``γ``, second-order directional derivatives as ``λ``. + """ + + def __init__( + self, + subsampling_grad: Optional[List[int]] = None, + subsampling_ggn: Optional[List[int]] = None, + mc_samples_ggn: Optional[int] = 0, + verbose: Optional[bool] = False, + ): + """Specify GGN and gradient approximations. Use no approximations by default. + + Args: + subsampling_grad: Indices of samples used for gradient sub-sampling. + ``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch + samples to compute directional gradients . Defaults to ``None`` (no + gradient sub-sampling). + subsampling_ggn: Indices of samples used for GGN curvature sub-sampling. + ``None`` (equivalent to ``list(range(batch_size))``) uses all mini-batch + samples to compute directions and directional curvatures. Defaults to + ``None`` (no curvature sub-sampling). + mc_samples_ggn: If ``0``, don't Monte-Carlo (MC) approximate the GGN + (using the same samples to compute the directions and directional + curvatures). Otherwise, specifies the number of MC samples used to + approximate the backpropagated loss Hessian. Default: ``0`` (no MC + approximation). + verbose: Turn on verbose mode. 
If enabled, this will print what's happening + during backpropagation to command line (consider it a debugging tool). + Defaults to ``False``. + """ + check_subsampling_unique(subsampling_grad) + check_subsampling_unique(subsampling_ggn) + + self._mc_samples_ggn = mc_samples_ggn + + if self._mc_samples_ggn != 0: + assert mc_samples_ggn == 1 + self._extension_cls_ggn = SqrtGGNMC + else: + self._extension_cls_ggn = SqrtGGNExact + + self._extension_cls_grad = BatchGrad + self._savefield_grad = self._extension_cls_grad().savefield + self._subsampling_grad = subsampling_grad + + self._savefield_ggn = self._extension_cls_ggn().savefield + self._subsampling_ggn = subsampling_ggn + + self._verbose = verbose + + # filled via side effects during update step computation, keys are group ids + self._gram_evals = {} + self._gram_evecs = {} + self._gram_mat = {} + self._gammas = {} + self._lambdas = {} + self._batch_size = {} + + def get_extensions(self) -> List[BackpropExtension]: + """Instantiate the BackPACK extensions to compute GGN directional derivatives. + + Returns: + BackPACK extensions, to compute directional 1ˢᵗ- and 2ⁿᵈ-order directional + derivatives along GGN eigenvectors, that should be extracted and passed to + the :py:class:`with backpack(...) ` context. + """ + return [ + self._extension_cls_grad(subsampling=self._subsampling_grad), + get_sqrt_ggn_extension( + subsampling=self._subsampling_ggn, mc_samples=self._mc_samples_ggn + ), + ] + + def get_extension_hook(self, param_groups: List[Dict]) -> Callable[[Module], None]: + """Instantiate BackPACK extension hook to compute GGN directional derivatives. + + Args: + param_groups: Parameter groups list as required by a + ``torch.optim.Optimizer``. Specifies the block structure: Each group + must specify the ``'params'`` key which contains a list of the + parameters that form a GGN block, and a ``'criterion'`` entry that + specifies a filter function to select eigenvalues as directions along + which to compute directional derivatives (details below). + + Examples for ``'params'``: + + - ``[{'params': list(p for p in model.parameters()}]`` uses the full + GGN (one block). + - ``[{'params': [p]} for p in model.parameters()]`` uses a per-parameter + block-diagonal GGN approximation. + + The function specified under ``'criterion'`` is a + ``Callable[[Tensor], List[int]]``. It receives the eigenvalues (in + ascending order) and returns the indices of eigenvalues whose + eigenvectors should be used as directions to evaluate directional + derivatives. Examples: + + - ``{'criterion': lambda evals: [evals.numel() - 1]}`` discards all + directions except for the leading eigenvector. + - ``{'criterion': lambda evals: list(range(evals.numel()))}`` computes + directional derivatives along all Gram matrix eigenvectors. + + Returns: + BackPACK extension hook, to compute directional derivatives, that should be + passed to the :py:class:`with backpack(...) ` context. + The hook computes GGN directional derivatives during backpropagation and + stores them internally (under ``self._gammas`` and ``self._lambdas``). + """ + hook_store_batch_size = self._get_hook_store_batch_size(param_groups) + + param_computation = self.get_param_computation() + group_hook = self.get_group_hook() + accumulate = self.get_accumulate() + + hook = ParameterGroupsHook.from_functions( + param_groups, param_computation, group_hook, accumulate + ) + + def extension_hook(module): + """Extension hook executed right after BackPACK extensions during backprop. 
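Putting ``get_extensions`` and ``get_extension_hook`` together, a minimal end-to-end sketch (the toy model and data are assumptions; reading the results from the internal ``_gammas``/``_lambdas`` buffers keyed by group id follows the description above):

from backpack import backpack, extend
from torch import nn, rand, randint

from vivit.optim import DirectionalDerivativesComputation

model = extend(nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 3)))
loss_func = extend(nn.CrossEntropyLoss())
X, y = rand(8, 10), randint(0, 3, (8,))

computation = DirectionalDerivativesComputation()
param_groups = [
    {
        "params": [p for p in model.parameters() if p.requires_grad],
        "criterion": lambda evals: [evals.numel() - 1],  # keep the leading direction
    }
]

loss = loss_func(model(X), y)
with backpack(
    *computation.get_extensions(),
    extension_hook=computation.get_extension_hook(param_groups),
):
    loss.backward()

group_id = id(param_groups[0])
gammas = computation._gammas[group_id]    # [N, D] first-order derivatives γ[n, d]
lambdas = computation._lambdas[group_id]  # [N, D] second-order derivatives λ[n, d]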
+ + Chains together all the required computations. + + Args: + module (torch.nn.Module): Layer on which the hook is executed. + """ + if self._verbose: + print(f"Extension hook on module {id(module)} {module}") + hook_store_batch_size(module) + hook(module) + + if self._verbose: + print("ID map groups → params") + for group in param_groups: + print(f"{id(group)} → {[id(p) for p in group['params']]}") + + return extension_hook + + def get_param_computation(self): + """Set up the ``param_computation`` function of the ``ParameterGroupsHook``. + + Returns: + function: Function that can be bound to a ``ParameterGroupsHook`` instance. + Performs an action on the accumulated results over parameters for a + group. + """ + param_computation_V_t_V = self._param_computation_V_t_V + param_computation_V_t_g_n = self._param_computation_V_t_g_n + param_computation_memory_cleanup = self._param_computation_memory_cleanup + + def param_computation(self, param): + """Compute dot products for a parameter used in directional derivatives. + + Args: + self (ParameterGroupsHook): Group hook to which this function will be + bound. + param (torch.Tensor): Parameter of a neural net. + + Returns: + dict: Dictionary with results of the different dot products. Has key + ``"V_t_g_n"``. + """ + result = { + "V_t_V": param_computation_V_t_V(param), + "V_t_g_n": param_computation_V_t_g_n(param), + } + + param_computation_memory_cleanup(param) + + return result + + return param_computation + + def get_group_hook(self): + """Set up the ``group_hook`` function of the ``ParameterGroupsHook``. + + Returns: + function: Function that can be bound to a ``ParameterGroupsHook`` instance. + Performs an action on the accumulated results over parameters for a + group. + """ + group_hook_directions = self._group_hook_directions + group_hook_filter_directions = self._group_hook_filter_directions + group_hook_gammas = self._group_hook_gammas + group_hook_lambdas = self._group_hook_lambdas + group_hook_memory_cleanup = self._group_hook_memory_cleanup + + def group_hook(self, accumulation, group): + """Compute Gram space directions. Evaluate directional derivatives. + + Args: + self (ParameterGroupsHook): Group hook to which this function will be + bound. + accumulation (dict): Accumulated dot products. + group (dict): Parameter group of a ``torch.optim.Optimizer``. + """ + group_hook_directions(accumulation, group) + group_hook_filter_directions(accumulation, group) + group_hook_gammas(accumulation, group) + group_hook_lambdas(accumulation, group) + group_hook_memory_cleanup(accumulation, group) + + return group_hook + + def get_accumulate(self): + """Set up the ``accumulate`` function of the ``ParameterGroupsHook``. + + Returns: + function: Function that can be bound to a ``ParameterGroupsHook`` instance. + Accumulates the parameter computations. + """ + verbose = self._verbose + + def accumulate(self, existing, update): + """Update existing results with computation result of a parameter. + + Args: + self (ParameterGroupsHook): Group hook to which this function will be + bound. + existing (dict): Dictionary containing the different accumulated scalar + products. Must have same keys as ``update``. + update (dict): Dictionary containing the different scalar products for + a parameter. + + Returns: + dict: Updated scalar products. + + Raises: + ValueError: If the two inputs don't have the same keys. + ValueError: If the two values associated to a key have different type. 
+ NotImplementedError: If the rule to accumulate a data type is missing. + """ + same_keys = set(existing.keys()) == set(update.keys()) + if not same_keys: + raise ValueError("Cached and new results have different keys.") + + for key in existing.keys(): + current, new = existing[key], update[key] + + same_type = type(current) is type(new) + if not same_type: + raise ValueError(f"Value for key '{key}' have different types.") + + if isinstance(current, torch.Tensor): + current.add_(new) + elif current is None: + pass + else: + raise NotImplementedError(f"No rule for {type(current)}") + + existing[key] = current + + if verbose: + print(f"Accumulate group entry '{key}'") + + return existing + + return accumulate + + # parameter computations + + def _param_computation_V_t_V(self, param): + """Perform scalar products ``V_t_V`` for a parameter. + + Args: + param (torch.Tensor): Parameter of a neural net. + + Returns: + torch.Tensor: Scalar products ``V_t_V``. + """ + savefields = (self._savefield_ggn, self._savefield_ggn) + subsamplings = (self._subsampling_ggn, self._subsampling_ggn) + start_dims = (2, 2) # only applies to GGN and GGN-MC + + tensors = self._get_subsampled_tensors( + param, start_dims, savefields, subsamplings + ) + + if self._verbose: + print(f"Param {id(param)}: Compute 'V_t_V'") + + return partial_contract(*tensors, start_dims) + + def _param_computation_V_t_g_n(self, param): + """Perform scalar products ``V_t_g_n`` for a parameter. + + Args: + param (torch.Tensor): Parameter of a neural net. + + Returns: + torch.Tensor: Scalar products ``V_t_g_n``. + """ + savefields = (self._savefield_ggn, self._savefield_grad) + subsamplings = (self._subsampling_ggn, self._subsampling_grad) + start_dims = (2, 1) # only applies to (GGN or GGN-MC, BatchGrad) + + tensors = self._get_subsampled_tensors( + param, start_dims, savefields, subsamplings + ) + + if self._verbose: + print(f"Param {id(param)}: Compute 'V_t_g_n'") + + return partial_contract(*tensors, start_dims) + + @staticmethod + def _get_subsampled_tensors(param, start_dims, savefields, subsamplings): + """Fetch the scalar product inputs and apply sub-sampling if necessary. + + Args: + param (torch.Tensor): Parameter of a neural net. + savefields ([str, str]): List containing the attribute names under which + the processed tensors are stored inside a parameter. + start_dims ([int, int]): List holding the dimensions at which the dot + product contractions starts. + subsamplings([[int], [int]]): Sub-samplings that should be applied to the + processed tensors before the scalar product operation. The batch axis + is automatically identified as the last before the contracted + dimensions. An entry of ``None`` does not apply subsampling. Default: + ``(None, None)`` + + Returns: + [torch.Tensor]: List of sub-sampled inputs for the scalar product. + """ + tensors = [] + + for start_dim, savefield, subsampling in zip( + start_dims, savefields, subsamplings + ): + tensor = getattr(param, savefield) + + if subsampling is not None: + batch_axis = start_dim - 1 + select = torch.tensor( + subsampling, dtype=torch.int64, device=tensor.device + ) + tensor = tensor.index_select(batch_axis, select) + + tensors.append(tensor) + + return tensors + + def _param_computation_memory_cleanup(self, param): + """Free buffers in a parameter that are not required anymore. + + Args: + param (torch.Tensor): Parameter of a neural net. 
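The per-parameter contractions ``V_t_V`` and ``V_t_g_n`` above can be accumulated over a group because the parameter-space inner products decompose over parameter blocks. A small check of that decomposition, with plain random tensors in place of BackPACK's saved buffers:

import torch

torch.manual_seed(0)
C, N = 3, 5                                  # classes, samples
param_sizes = [4, 7, 2]                      # three parameter blocks of a group

# per-parameter slices of the GGN square root V and of the individual gradients
V_blocks = [torch.rand(C, N, p) for p in param_sizes]
g_blocks = [torch.rand(N, p) for p in param_sizes]

# per-parameter contractions, summed over the group (what the hook accumulates)
V_t_V = sum(torch.einsum("cni,dmi->cndm", V, V) for V in V_blocks)
V_t_g_n = sum(torch.einsum("cni,mi->cnm", V, g) for V, g in zip(V_blocks, g_blocks))

# same results from the concatenated (full parameter space) tensors
V_full = torch.cat(V_blocks, dim=2)
g_full = torch.cat(g_blocks, dim=1)
assert torch.allclose(V_t_V, torch.einsum("cni,dmi->cndm", V_full, V_full), atol=1e-6)
assert torch.allclose(V_t_g_n, torch.einsum("cni,mi->cnm", V_full, g_full), atol=1e-6)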
+ """ + savefields = { + self._savefield_ggn, + self._savefield_grad, + self._savefield_ggn, + } + + for savefield in savefields: + delattr(param, savefield) + + if self._verbose: + print(f"Param {id(param)}: Delete '{savefield}'") + + # group hooks + + def _group_hook_directions(self, accumulation, group): + """Evaluate and store directions of quadratic model in the Gram space. + + Sets the following entries under the id of ``group``: + + - In ``self._gram_evals``: Eigenvalues, sorted in ascending order. + - In ``self._gram_evecs``: Normalized eigenvectors, stacked column-wise. + - In ``self._gram_mat``: The Gram matrix ``Vᵀ V``. + + Args: + accumulation (dict): Dictionary with accumulated scalar products. + group (dict): Parameter group of a ``torch.optim.Optimizer``. + """ + group_id = id(group) + gram_mat = accumulation["V_t_V"] + + # compensate subsampling scale + if self._subsampling_ggn is not None: + N_dir = len(self._subsampling_ggn) + N = self._batch_size[group_id] + gram_mat *= N / N_dir + + gram_evals, gram_evecs = stable_symeig( + reshape_as_square(gram_mat), eigenvectors=True + ) + + # save + self._gram_mat[group_id] = gram_mat + self._gram_evals[group_id] = gram_evals + self._gram_evecs[group_id] = gram_evecs + + if self._verbose: + print(f"Group {id(group)}: Store 'gram_mat', 'gram_evals', 'gram_evecs'") + + def _group_hook_filter_directions(self, accumulation, group): + """Filter Gram directions depending on their eigenvalues. + + Modifies the group entries in ``self._gram_evals`` and ``self._gram_evecs``. + + Args: + accumulation (dict): Dictionary with accumulated scalar products. + group (dict): Parameter group. + """ + group_id = id(group) + + evals = self._gram_evals[group_id] + evecs = self._gram_evecs[group_id] + + keep = group["criterion"](evals) + + self._gram_evals[group_id] = evals[keep] + self._gram_evecs[group_id] = evecs[:, keep] + + if self._verbose: + before, after = len(evals), len(keep) + print(f"Group {id(group)}: Filter directions ({before} → {after})") + + def _group_hook_gammas(self, accumulation, group): + """Evaluate and store first-order directional derivatives ``γ[n, d]``. + + Sets the following entries under the id of ``group``: + + - In ``self._gammas``: First-order directional derivatives. + + Args: + accumulation (dict): Dictionary with accumulated scalar products. + group (dict): Parameter group of a ``torch.optim.Optimizer``. + """ + group_id = id(group) + + # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale + N = self._batch_size[group_id] + + V_t_g_n = N * accumulation["V_t_g_n"] + + # compensate subsampling scale + if self._subsampling_ggn is not None: + N_dir = len(self._subsampling_ggn) + N = self._batch_size[group_id] + V_t_g_n *= math.sqrt(N / N_dir) + + # NOTE Flipping the order (g_n_t_V) may be more efficient + V_t_g_n = V_t_g_n.flatten( + start_dim=0, end_dim=1 + ) # only applies to GGN and GGN-MC + + gammas = ( + torch.einsum("in,id->nd", V_t_g_n, self._gram_evecs[group_id]) + / self._gram_evals[group_id].sqrt() + ) + + self._gammas[group_id] = gammas + + if self._verbose: + print(f"Group {id(group)}: Store 'gammas'") + + def _group_hook_lambdas(self, accumulation, group): + """Evaluate and store second-order directional derivatives ``λ[n, d]``. + + Sets the following entries under the id of ``group``: + + - In ``self._lambdas``: Second-order directional derivatives. + + Args: + accumulation (dict): Dictionary with accumulated scalar products. + group (dict): Parameter group of a ``torch.optim.Optimizer``. 
+ """ + group_id = id(group) + + gram_evals = self._gram_evals[group_id] + gram_evecs = self._gram_evecs[group_id] + gram_mat = self._gram_mat[group_id] + + C_dir, N_dir = gram_mat.shape[:2] + + V_n_T_V = gram_mat.reshape(C_dir, N_dir, C_dir * N_dir) + + if self._subsampling_ggn is not None: + V_n_T_V = V_n_T_V[:, self._subsampling_ggn, :] + + # compensate scale of V_n + V_n_T_V *= math.sqrt(N_dir) + + V_n_T_V_e_d = torch.einsum("cni,id->cnd", V_n_T_V, gram_evecs) + + lambdas = (V_n_T_V_e_d**2).sum(0) / gram_evals + + self._lambdas[group_id] = lambdas + + if self._verbose: + print(f"Group {id(group)}: Store 'lambdas'") + + def _group_hook_memory_cleanup(self, accumulation, group): + """Free up buffers which are not required anymore for a group. + + Modifies temporary buffers. + + Args: + accumulation (dict): Dictionary with accumulated scalar products. + group (dict): Parameter group of a ``torch.optim.Optimizer``. + """ + group_id = id(group) + buffers = ["_gram_mat", "_gram_evals", "_gram_evecs", "_batch_size"] + + for b in buffers: + + if self._verbose: + print(f"Group {group_id}: Delete '{b}'") + + getattr(self, b).pop(group_id) + + def _get_hook_store_batch_size(self, param_groups): + """Create extension hook that stores the batch size during backpropagation. + + Args: + param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. + + Returns: + callable: Hook function to hand into a ``with backpack(...)`` context. + Stores the batch size under the ``self._batch_size`` dictionary for each + group. + """ + + def hook_store_batch_size(module): + """Store batch size internally. + + Modifies ``self._batch_size``. + + Args: + module (torch.nn.Module): The module on which the hook is executed. + """ + if self._batch_size == {}: + batch_axis = 0 + batch_size = module.input0.shape[batch_axis] + + for group in param_groups: + group_id = id(group) + + if self._verbose: + print(f"Group {group_id}: Store 'batch_size'") + + self._batch_size[group_id] = batch_size + + return hook_store_batch_size diff --git a/vivit/optim/gram_computations.py b/vivit/optim/gram_computations.py deleted file mode 100644 index d035644..0000000 --- a/vivit/optim/gram_computations.py +++ /dev/null @@ -1,796 +0,0 @@ -"""Manage computations in Gram space through extension hooks.""" - -import math -import warnings -from functools import partial - -import torch -from backpack.extensions import BatchGrad, SqrtGGNExact - -from vivit.utils.eig import stable_symeig -from vivit.utils.gram import partial_contract, reshape_as_square -from vivit.utils.hooks import ParameterGroupsHook -from vivit.utils.subsampling import is_subset, merge_extensions, sample_output_mapping - - -class GramComputations: - """Compute directions ``{(λₖ, ẽₖ)}``, slopes ``γ[n,k]`` & curvatures ``λ[n,k]``. - - Different samples may be assigned to the three steps. The computation happens - during backpropagation via extension hooks and allows the used buffers to be - discarded immediately afterwards. - - - ``{(λₖ, ẽₖ)}``: Directions in Gram space with associated eigenvalues. - - ``γ[n,k]``: 1st-order directional derivative along ``eₖ`` (implied by ``ẽₖ``). - - ``λ[n,k]``: 2nd-order directional derivative along ``eₖ`` (implied by ``ẽₖ``). - """ - - def __init__( - self, - subsampling_directions=None, - subsampling_first=None, - subsampling_second=None, - extension_cls_directions=SqrtGGNExact, - extension_cls_second=SqrtGGNExact, - compute_gammas=True, - compute_lambdas=True, - verbose=False, - ): - """Store indices of samples used for each task. 
- - Args: - subsampling_directions ([int] or None): Indices of samples used to compute - Newton directions. If ``None``, all samples in the batch will be used. - subsampling_first ([int] or None): Indices of samples used to compute first- - order directional derivatives along the Newton directions. If ``None``, - all samples in the batch will be used. - subsampling_second ([int] or None): Indices of samples used to compute - second-order directional derivatives along the Newton directions. If - ``None``, all samples in the batch will be used. - extension_cls_directions (backpack.backprop_extension.BackpropExtension): - BackPACK extension class used to compute descent directions. - extension_cls_second (backpack.backprop_extension.BackpropExtension): - BackPACK extension class used to compute second-order directional - derivatives. - compute_gammas (bool, optional): Whether to compute first-order directional - derivatives. Default: ``True`` - compute_lambdas (bool, optional): Whether to compute second-order - directional derivatives. Default: ``True`` - verbose (bool, optional): Turn on verbose mode. Default: ``False``. - """ - self._extension_cls_first = BatchGrad - self._savefield_first = self._extension_cls_first().savefield - self._subsampling_first = subsampling_first - - self._extension_cls_second = extension_cls_second - self._savefield_second = extension_cls_second().savefield - self._subsampling_second = subsampling_second - - self._extension_cls_directions = extension_cls_directions - self._savefield_directions = self._extension_cls_directions().savefield - self._subsampling_directions = subsampling_directions - - # different tasks may use different samples of the same extension - self._merged_extensions = merge_extensions( - [ - (self._extension_cls_first, self._subsampling_first), - (self._extension_cls_second, self._subsampling_second), - (self._extension_cls_directions, self._subsampling_directions), - ] - ) - - # how to access samples from the computed quantities - merged_subsampling_first = self._merged_extensions[self._extension_cls_first] - self._access_first = sample_output_mapping( - self._subsampling_first, merged_subsampling_first - ) - - merged_subsampling_second = self._merged_extensions[self._extension_cls_second] - self._access_second = sample_output_mapping( - self._subsampling_second, merged_subsampling_second - ) - - merged_subsampling_directions = self._merged_extensions[ - self._extension_cls_directions - ] - self._access_directions = sample_output_mapping( - self._subsampling_directions, merged_subsampling_directions - ) - - self._verbose = verbose - - self._compute_gammas = compute_gammas - self._compute_lambdas = compute_lambdas - - # safe guards if directional derivatives are not computed - if not self._compute_gammas: - assert subsampling_first is None - if not self._compute_lambdas: - assert extension_cls_second == extension_cls_directions - assert subsampling_second == subsampling_directions - - # filled via side effects during update step computation, keys are group ids - self._gram_evals = {} - self._gram_evecs = {} - self._gram_mat = {} - self._gammas = {} - self._lambdas = {} - self._batch_size = {} - - def get_extensions(self, param_groups): - """Return the instantiated BackPACK extensions required in the backward pass. - - Args: - param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. 
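The exact behavior of ``merge_extensions`` and ``sample_output_mapping`` lives in ``vivit.utils.subsampling``; the following is only a conceptual guess at their role, reduced to toy helpers over plain index lists (not the library's API): sub-samplings requested for the same extension are merged into one computation, and each task later picks its samples out of the merged result.

def merge_subsamplings(subsamplings):
    """Merge requested sample indices; ``None`` stands for the full mini-batch."""
    if any(sub is None for sub in subsamplings):
        return None
    return sorted({idx for sub in subsamplings for idx in sub})

def output_mapping(requested, merged):
    """Positions of the requested samples inside the merged computation's output."""
    if requested is None or merged is None:
        return requested
    return [merged.index(idx) for idx in requested]

merged = merge_subsamplings([[0, 2], [2, 5]])  # one extension evaluated on [0, 2, 5]
print(output_mapping([2, 5], merged))          # that task reads positions [1, 2]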
- - Returns: - [backpack.extensions.backprop_extension.BackpropExtension]: List of - extensions that can be handed into a ``with backpack(...)`` context. - """ - extensions = [ - ext_cls(subsampling=subsampling) - for ext_cls, subsampling in self._merged_extensions.items() - ] - - if not self._compute_gammas: - extensions = [ - ext - for ext in extensions - if not isinstance(ext, self._extension_cls_first) - ] - - return extensions - - def get_extension_hook( - self, - param_groups, - keep_gram_mat=True, - keep_gram_evals=True, - keep_gram_evecs=True, - keep_gammas=True, - keep_lambdas=True, - keep_batch_size=True, - keep_backpack_buffers=True, - ): - """Return hook to be executed right after a BackPACK extension during backprop. - - Args: - param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. - keep_gram_mat (bool, optional): Keep buffers for Gram matrix under group id - in ``self._gram_mat``. Default: ``True`` - keep_gram_evals (bool, optional): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_evals``. Default: ``True`` - keep_gram_evecs (bool, optional): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_evecs``. Default: ``True`` - keep_gammas (bool, optional): Keep buffers for first-order directional - derivatives under group id in ``self._gammas``. Default: ``True`` - keep_lambdas (bool, optional): Keep buffers for second-order directional - derivatives under group id in ``self._lambdas``. Default: ``True`` - keep_batch_size (bool, optional): Keep batch size for under group id - in ``self._lambdas``. Default: ``True`` - keep_backpack_buffers (bool, optional): Keep buffers from used BackPACK - extensions during backpropagation. Default: ``True``. - - Returns: - ParameterGroupsHook: Hook that can be handed into a ``with backpack(...)``. - """ - hook_store_batch_size = self._get_hook_store_batch_size(param_groups) - - param_computation = self.get_param_computation(keep_backpack_buffers) - group_hook = self.get_group_hook( - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - ) - accumulate = self.get_accumulate() - - hook = ParameterGroupsHook.from_functions( - param_groups, param_computation, group_hook, accumulate - ) - - def extension_hook(module): - """Extension hook executed right after BackPACK extensions during backprop. - - Chains together all the required computations. - - Args: - module (torch.nn.Module): Layer on which the hook is executed. - """ - if self._verbose: - print(f"Extension hook on module {id(module)} {module}") - hook_store_batch_size(module) - hook(module) - - if self._verbose: - print("ID map groups → params") - for group in param_groups: - print(f"{id(group)} → {[id(p) for p in group['params']]}") - - return extension_hook - - def get_param_computation(self, keep_backpack_buffers): - """Set up the ``param_computation`` function of the ``ParameterGroupsHook``. - - Args: - keep_backpack_buffers (bool): Keep buffers from used BackPACK extensions - during backpropagation. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. - Performs an action on the accumulated results over parameters for a - group. 
- """ - param_computation_V_t_V = self._param_computation_V_t_V - param_computation_V_t_g_n = self._param_computation_V_t_g_n - param_computation_V_n_t_V = self._param_computation_V_n_t_V - param_computation_memory_cleanup = partial( - self._param_computation_memory_cleanup, - keep_backpack_buffers=keep_backpack_buffers, - ) - - compute_gammas = self._compute_gammas - compute_lambdas = self._compute_lambdas - - def param_computation(self, param): - """Compute dot products for a parameter used in directional derivatives. - - Args: - self (ParameterGroupsHook): Group hook to which this function will be - bound. - param (torch.Tensor): Parameter of a neural net. - - Returns: - dict: Dictionary with results of the different dot products. Has key - ``"V_t_g_n"``. - """ - result = {} - - result["V_t_V"] = param_computation_V_t_V(param) - - if compute_gammas: - result["V_t_g_n"] = param_computation_V_t_g_n(param) - if compute_lambdas: - result["V_n_t_V"] = param_computation_V_n_t_V(param) - - param_computation_memory_cleanup(param) - - return result - - return param_computation - - def get_group_hook( - self, - keep_gram_mat, - keep_gram_evals, - keep_gram_evecs, - keep_gammas, - keep_lambdas, - keep_batch_size, - ): - """Set up the ``group_hook`` function of the ``ParameterGroupsHook``. - - Args: - keep_gram_mat (bool): Keep buffers for Gram matrix under group id - in ``self._gram_mat``. - keep_gram_evals (bool): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_evals``. - keep_gram_evecs (bool): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_evecs``. - keep_gammas (bool): Keep buffers for first-order directional - derivatives under group id in ``self._gammas``. - keep_lambdas (bool): Keep buffers for second-order directional - derivatives under group id in ``self._lambdas``. - keep_batch_size (bool): Keep batch size for under group id - in ``self._lambdas``. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. - Performs an action on the accumulated results over parameters for a - group. - """ - group_hook_directions = self._group_hook_directions - group_hook_filter_directions = self._group_hook_filter_directions - group_hook_gammas = self._group_hook_gammas - group_hook_lambdas = self._group_hook_lambdas - group_hook_memory_cleanup = partial( - self._group_hook_memory_cleanup, - keep_gram_mat=keep_gram_mat, - keep_gram_evals=keep_gram_evals, - keep_gram_evecs=keep_gram_evecs, - keep_gammas=keep_gammas, - keep_lambdas=keep_lambdas, - keep_batch_size=keep_batch_size, - ) - - compute_gammas = self._compute_gammas - compute_lambdas = self._compute_lambdas - - def group_hook(self, accumulation, group): - """Compute Gram space directions. Evaluate directional derivatives. - - Args: - self (ParameterGroupsHook): Group hook to which this function will be - bound. - accumulation (dict): Accumulated dot products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - """ - group_hook_directions(accumulation, group) - group_hook_filter_directions(accumulation, group) - if compute_gammas: - group_hook_gammas(accumulation, group) - if compute_lambdas: - group_hook_lambdas(accumulation, group) - group_hook_memory_cleanup(accumulation, group) - - return group_hook - - def get_accumulate(self): - """Set up the ``accumulate`` function of the ``ParameterGroupsHook``. - - Returns: - function: Function that can be bound to a ``ParameterGroupsHook`` instance. 
- Accumulates the parameter computations. - """ - verbose = self._verbose - - def accumulate(self, existing, update): - """Update existing results with computation result of a parameter. - - Args: - self (ParameterGroupsHook): Group hook to which this function will be - bound. - existing (dict): Dictionary containing the different accumulated scalar - products. Must have same keys as ``update``. - update (dict): Dictionary containing the different scalar products for - a parameter. - - Returns: - dict: Updated scalar products. - - Raises: - ValueError: If the two inputs don't have the same keys. - ValueError: If the two values associated to a key have different type. - NotImplementedError: If the rule to accumulate a data type is missing. - """ - same_keys = set(existing.keys()) == set(update.keys()) - if not same_keys: - raise ValueError("Cached and new results have different keys.") - - for key in existing.keys(): - current, new = existing[key], update[key] - - same_type = type(current) is type(new) - if not same_type: - raise ValueError(f"Value for key '{key}' have different types.") - - if isinstance(current, torch.Tensor): - current.add_(new) - elif current is None: - pass - else: - raise NotImplementedError(f"No rule for {type(current)}") - - existing[key] = current - - if verbose: - print(f"Accumulate group entry '{key}'") - - return existing - - return accumulate - - # parameter computations - - def _param_computation_V_t_V(self, param): - """Perform scalar products ``V_t_V`` for a parameter. - - Args: - param (torch.Tensor): Parameter of a neural net. - - Returns: - torch.Tensor: Scalar products ``V_t_V``. - """ - savefields = (self._savefield_directions, self._savefield_directions) - subsamplings = (self._access_directions, self._access_directions) - start_dims = (2, 2) # only applies to GGN and GGN-MC - - tensors = self._get_subsampled_tensors( - param, start_dims, savefields, subsamplings - ) - - if self._verbose: - print(f"Param {id(param)}: Compute 'V_t_V'") - - return partial_contract(*tensors, start_dims) - - def _param_computation_V_t_g_n(self, param): - """Perform scalar products ``V_t_g_n`` for a parameter. - - Args: - param (torch.Tensor): Parameter of a neural net. - - Returns: - torch.Tensor: Scalar products ``V_t_g_n``. - """ - savefields = (self._savefield_directions, self._savefield_first) - subsamplings = (self._access_directions, self._access_first) - start_dims = (2, 1) # only applies to (GGN or GGN-MC, BatchGrad) - - tensors = self._get_subsampled_tensors( - param, start_dims, savefields, subsamplings - ) - - if self._verbose: - print(f"Param {id(param)}: Compute 'V_t_g_n'") - - return partial_contract(*tensors, start_dims) - - def _param_computation_V_n_t_V(self, param): - """Perform scalar products ``V_t_g_n`` if not fully contained in ``V_t_V``. - - Args: - param (torch.Tensor): Parameter of a neural net. - - Returns: - None or torch.Tensor: ``None`` if all scalar products are already computed - through ``V_t_V``. Else returns the scalar products. 
- """ - # assume same extensions for directions and derivatives - self._different_curvatures_not_supported() - - if self._verbose: - print(f"Param {id(param)}: Compute 'V_n_t_V'") - - # ``V_n_t_V`` already computed through ``V_t_V`` - if is_subset(self._subsampling_second, self._subsampling_directions): - return None - else: - # TODO Recycle scalar products that are available from the Gram matrix - # and only compute the missing ones - self._warn_inefficient_subsamplings() - - # re-compute everything, easier but less efficient - savefields = (self._savefield_second, self._savefield_directions) - subsamplings = (self._access_second, self._access_directions) - start_dims = (2, 2) # only applies to (GGN or GGN-MC) - - tensors = self._get_subsampled_tensors( - param, start_dims, savefields, subsamplings - ) - - return partial_contract(*tensors, start_dims) - - @staticmethod - def _get_subsampled_tensors(param, start_dims, savefields, subsamplings): - """Fetch the scalar product inputs and apply sub-sampling if necessary. - - Args: - param (torch.Tensor): Parameter of a neural net. - savefields ([str, str]): List containing the attribute names under which - the processed tensors are stored inside a parameter. - start_dims ([int, int]): List holding the dimensions at which the dot - product contractions starts. - subsamplings([[int], [int]]): Sub-samplings that should be applied to the - processed tensors before the scalar product operation. The batch axis - is automatically identified as the last before the contracted - dimensions. An entry of ``None`` does not apply subsampling. Default: - ``(None, None)`` - - Returns: - [torch.Tensor]: List of sub-sampled inputs for the scalar product. - """ - tensors = [] - - for start_dim, savefield, subsampling in zip( - start_dims, savefields, subsamplings - ): - tensor = getattr(param, savefield) - - if subsampling is not None: - batch_axis = start_dim - 1 - select = torch.tensor( - subsampling, dtype=torch.int64, device=tensor.device - ) - tensor = tensor.index_select(batch_axis, select) - - tensors.append(tensor) - - return tensors - - def _param_computation_memory_cleanup(self, param, keep_backpack_buffers): - """Free buffers in a parameter that are not required anymore. - - Args: - param (torch.Tensor): Parameter of a neural net. - keep_backpack_buffers (bool): Keep buffers from used BackPACK - extensions during backpropagation. - """ - if keep_backpack_buffers: - savefields = [] - else: - savefields = { - self._savefield_directions, - self._savefield_first, - self._savefield_second, - } - - if not self._compute_gammas: - savefields.remove(self._savefield_first) - - for savefield in savefields: - delattr(param, savefield) - - if self._verbose: - print(f"Param {id(param)}: Delete '{savefield}'") - - # group hooks - - def _group_hook_directions(self, accumulation, group): - """Evaluate and store directions of quadratic model in the Gram space. - - Sets the following entries under the id of ``group``: - - - In ``self._gram_evals``: Eigenvalues, sorted in ascending order. - - In ``self._gram_evecs``: Normalized eigenvectors, stacked column-wise. - - In ``self._gram_mat``: The Gram matrix ``Vᵀ V``. - - Args: - accumulation (dict): Dictionary with accumulated scalar products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. 
- """ - group_id = id(group) - gram_mat = accumulation["V_t_V"] - - # compensate subsampling scale - if self._subsampling_directions is not None: - N_dir = len(self._subsampling_directions) - N = self._batch_size[group_id] - gram_mat *= N / N_dir - - gram_evals, gram_evecs = stable_symeig( - reshape_as_square(gram_mat), eigenvectors=True - ) - - # save - self._gram_mat[group_id] = gram_mat - self._gram_evals[group_id] = gram_evals - self._gram_evecs[group_id] = gram_evecs - - if self._verbose: - print(f"Group {id(group)}: Store 'gram_mat', 'gram_evals', 'gram_evecs'") - - def _group_hook_filter_directions(self, accumulation, group): - """Filter Gram directions depending on their eigenvalues. - - Modifies the group entries in ``self._gram_evals`` and ``self._gram_evecs``. - - Args: - accumulation (dict): Dictionary with accumulated scalar products. - group (dict): Parameter group. - """ - group_id = id(group) - - evals = self._gram_evals[group_id] - evecs = self._gram_evecs[group_id] - - keep = group["criterion"](evals) - - self._gram_evals[group_id] = evals[keep] - self._gram_evecs[group_id] = evecs[:, keep] - - if self._verbose: - before, after = len(evals), len(keep) - print(f"Group {id(group)}: Filter directions ({before} → {after})") - - def _group_hook_gammas(self, accumulation, group): - """Evaluate and store first-order directional derivatives ``γ[n, d]``. - - Sets the following entries under the id of ``group``: - - - In ``self._gammas``: First-order directional derivatives. - - Args: - accumulation (dict): Dictionary with accumulated scalar products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - """ - group_id = id(group) - - # L = ¹/ₙ ∑ᵢ ℓᵢ, BackPACK's BatchGrad computes ¹/ₙ ∇ℓᵢ, we have to rescale - N = self._batch_size[group_id] - - V_t_g_n = N * accumulation["V_t_g_n"] - - # compensate subsampling scale - if self._subsampling_directions is not None: - N_dir = len(self._subsampling_directions) - N = self._batch_size[group_id] - V_t_g_n *= math.sqrt(N / N_dir) - - # NOTE Flipping the order (g_n_t_V) may be more efficient - V_t_g_n = V_t_g_n.flatten( - start_dim=0, end_dim=1 - ) # only applies to GGN and GGN-MC - - gammas = ( - torch.einsum("in,id->nd", V_t_g_n, self._gram_evecs[group_id]) - / self._gram_evals[group_id].sqrt() - ) - - self._gammas[group_id] = gammas - - if self._verbose: - print(f"Group {id(group)}: Store 'gammas'") - - def _group_hook_lambdas(self, accumulation, group): - """Evaluate and store second-order directional derivatives ``λ[n, d]``. - - Sets the following entries under the id of ``group``: - - - In ``self._lambdas``: Second-order directional derivatives. - - Args: - accumulation (dict): Dictionary with accumulated scalar products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. 
- """ - # assume same extensions for directions and derivatives - self._different_curvatures_not_supported() - - group_id = id(group) - - gram_evals = self._gram_evals[group_id] - gram_evecs = self._gram_evecs[group_id] - gram_mat = self._gram_mat[group_id] - - C_dir, N_dir = gram_mat.shape[:2] - batch_size = self._batch_size[group_id] - - # all info in Gram matrix, just slice the relevant info - if is_subset(self._subsampling_second, self._subsampling_directions): - V_n_T_V = gram_mat.reshape(C_dir, N_dir, C_dir * N_dir) - - idx = sample_output_mapping( - self._subsampling_second, self._subsampling_directions - ) - if idx is not None: - V_n_T_V = V_n_T_V[:, idx, :] - - # compensate scale of V_n - V_n_T_V *= math.sqrt(N_dir) - - else: - # TODO Recycle scalar products that are available from the Gram matrix - # and only compute the missing ones - self._warn_inefficient_subsamplings() - - # re-compute everything, easier but less efficient - V_n_T_V = accumulation["V_n_t_V"] - - C_second, N_second = V_n_T_V.shape[:2] - V_n_T_V = V_n_T_V.reshape(C_second, N_second, C_dir * N_dir) - - # compensate scale of V_n - V_n_T_V *= batch_size / math.sqrt(N_dir) - - V_n_T_V_e_d = torch.einsum("cni,id->cnd", V_n_T_V, gram_evecs) - - lambdas = (V_n_T_V_e_d**2).sum(0) / gram_evals - - self._lambdas[group_id] = lambdas - - if self._verbose: - print(f"Group {id(group)}: Store 'lambdas'") - - def _group_hook_memory_cleanup( - self, - accumulation, - group, - keep_gram_mat, - keep_gram_evals, - keep_gram_evecs, - keep_gammas, - keep_lambdas, - keep_batch_size, - ): - """Free up buffers which are not required anymore for a group. - - Modifies temporary buffers. - - Args: - accumulation (dict): Dictionary with accumulated scalar products. - group (dict): Parameter group of a ``torch.optim.Optimizer``. - keep_gram_mat (bool): Keep buffers for Gram matrix under group id - in ``self._gram_mat``. - keep_gram_evals (bool): Keep buffers for filtered Gram matrix - eigenvalues under group id in ``self._gram_evals``. - keep_gram_evecs (bool): Keep buffers for filtered Gram matrix - eigenvectors under group id in ``self._gram_evecs``. - keep_gammas (bool): Keep buffers for first-order directional - derivatives under group id in ``self._gammas``. - keep_lambdas (bool): Keep buffers for second-order directional - derivatives under group id in ``self._lambdas``. - keep_batch_size (bool): Keep batch size for under group id - in ``self._lambdas``. - """ - buffers = [] - - if not keep_gram_mat: - buffers.append("_gram_mat") - if not keep_gram_evals: - buffers.append("_gram_evals") - if not keep_gram_evecs: - buffers.append("_gram_evecs") - if not keep_gammas and self._compute_gammas: - buffers.append("_gammas") - if not keep_lambdas and self._compute_lambdas: - buffers.append("_lambdas") - if not keep_batch_size: - buffers.append("_batch_size") - - group_id = id(group) - for b in buffers: - - if self._verbose: - print(f"Group {group_id}: Delete '{b}'") - - getattr(self, b).pop(group_id) - - def _get_hook_store_batch_size(self, param_groups): - """Create extension hook that stores the batch size during backpropagation. - - Args: - param_groups (list): Parameter group list from a ``torch.optim.Optimizer``. - - Returns: - callable: Hook function to hand into a ``with backpack(...)`` context. - Stores the batch size under the ``self._batch_size`` dictionary for each - group. - """ - - def hook_store_batch_size(module): - """Store batch size internally. - - Modifies ``self._batch_size``. 
-
-            Args:
-                module (torch.nn.Module): The module on which the hook is executed.
-            """
-            if self._batch_size == {}:
-                batch_axis = 0
-                batch_size = module.input0.shape[batch_axis]
-
-                for group in param_groups:
-                    group_id = id(group)
-
-                    if self._verbose:
-                        print(f"Group {group_id}: Store 'batch_size'")
-
-                    self._batch_size[group_id] = batch_size
-
-        return hook_store_batch_size
-
-    def _different_curvatures_not_supported(self):
-        """Raise exception if curvatures for directions and derivatives deviate.
-
-        Raises:
-            NotImplementedError: If different extensions/curvature matrices are used
-                for directions and second-order directional derivatives, respectively.
-        """
-        if self._extension_cls_directions != self._extension_cls_second:
-            raise NotImplementedError(
-                "Different extensions for (directions, second) not supported."
-            )
-
-    def _warn_inefficient_subsamplings(self):
-        """Issue a warning if samples for ``λ[n,k]`` are not used in the Gram matrix.
-
-        This requires more pairwise scalar products be evaluated and makes the
-        computation less efficient.
-        """
-        warnings.warn(
-            "If subsampling_second is not a subset of subsampling_directions,"
-            + " all required dot products will be re-evaluated. This is not"
-            + " the most efficient, but less complex implementation."
-        )
diff --git a/vivit/optim/utils.py b/vivit/optim/utils.py
new file mode 100644
index 0000000..1b9c286
--- /dev/null
+++ b/vivit/optim/utils.py
@@ -0,0 +1,25 @@
+"""Utility functions for ``vivit.optim``."""
+
+from typing import List, Union
+
+from backpack.extensions import SqrtGGNExact, SqrtGGNMC
+
+
+def get_sqrt_ggn_extension(
+    subsampling: Union[None, List[int]], mc_samples: int
+) -> Union[SqrtGGNExact, SqrtGGNMC]:
+    """Instantiate a ``SqrtGGNExact`` or ``SqrtGGNMC`` extension.
+
+    Args:
+        subsampling: Indices of active samples.
+        mc_samples: Number of MC-samples to approximate the loss Hessian. ``0``
+            uses the exact loss Hessian.
+
+    Returns:
+        Instantiated SqrtGGN extension.
+    """
+    return (
+        SqrtGGNExact(subsampling=subsampling)
+        if mc_samples == 0
+        else SqrtGGNMC(subsampling=subsampling, mc_samples=mc_samples)
+    )
diff --git a/vivit/utils/subsampling.py b/vivit/utils/subsampling.py
deleted file mode 100644
index 675b1ac..0000000
--- a/vivit/utils/subsampling.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""Utility functions for subsampling."""
-
-
-def is_subset(subsampling, reference):
-    """Return whether indices specified by ``subsampling`` are subset of the reference.
-
-    Args:
-        subsampling ([int] or None): Sample indices
-        reference ([int] or None): Reference set.
-
-    Returns:
-        bool: Whether all indices are contained in the reference set.
-    """
-    if reference is None:
-        return True
-    elif subsampling is None and reference is not None:
-        return False
-    else:
-        return set(subsampling).issubset(set(reference))
-
-
-def sample_output_mapping(idx_samples, idx_all):
-    """Return access indices for sub-sampled BackPACK quantities.
-
-    Args:
-        idx_samples ([int]): Mini-batch sample indices of samples to be accessed.
-        idx_all ([int] or None): Sub-sampling indices used in the BackPACK extension
-            whose savefield is being accessed. ``None`` signifies the entire batch
-            was used.
-
-    Example:
-        Let's say we want to compute individual gradients for samples 0, 2, 3 from a
-        mini-batch with ``N = 5`` samples.
Those samples are described by the indices - - ``samples = [0, 1, 2, 3, 4]`` - - Calling ``BatchGrad`` with ``subsampling = [0, 2, 3]``, will result in - - ``grad_batch = [∇f₀, ∇f₂, ∇f₃]`` - - To access the gradient for sample 3, we need a mapping: - - ``mapping = [2]`` - - Then, ``[∇f₃] = grad_batch[mapping]``. - - Returns: - [int] or None: Index mapping for samples to output index. ``None`` if the - mapping is the identity. - - Raises: - ValueError: If one of the requested samples is not contained in all samples. - """ - if not is_subset(idx_samples, idx_all): - raise ValueError(f"Requested samples {idx_samples} must be subset of {idx_all}") - - if idx_all is None: - mapping = idx_samples - else: - mapping = [idx_all.index(sample) for sample in idx_samples] - - return mapping - - -def merge_subsamplings(subsampling, other): - """Merge indices of sub-samplings, removing duplicates and sorting indices. - - Args: - subsampling ([int] or None): Sub-sampling indices for use in a BackPACK - extension as ``subsampling`` argument. - other ([int] or None): Sub-sampling indices for use in a BackPACK - extension as ``subsampling`` argument. - - Returns: - [int]: Indices corresponding to the merged sub-samplings. - """ - if subsampling is None or other is None: - merged = None - else: - merged = sorted(set(subsampling).union(set(other))) - - return merged - - -def merge_multiple_subsamplings(*subsamplings): - """Merge a sequence of sub-samplings, removing duplicates and sorting indices. - - Args: - subsamplings ([[int] or None]): Sub-sampling sequence. - - Returns: - [int]: Indices corresponding to the merged sub-samplings. - - Raises: - ValueError: If no arguments are handed in - """ - if len(subsamplings) == 0: - raise ValueError("Expecting one or more inputs. Got {subsamplings}.") - - subsampling = [] - - for other in subsamplings: - subsampling = merge_subsamplings(subsampling, other) - - return subsampling - - -def merge_extensions(extension_subsampling_list): - """Combine subsamplings of same extensions. - - Args: - extension_subsampling_list ([tuple]): List of extension-subsampling - pairs to be merged. - - Returns: - dict: Keys are extension classes, values are subsamplings. - """ - unique = {extension for (extension, _) in extension_subsampling_list} - - merged_subsamplings = {} - - for extension in unique: - subsamplings = [ - sub for (ext, sub) in extension_subsampling_list if ext == extension - ] - merged_subsamplings[extension] = merge_multiple_subsamplings(*subsamplings) - - return merged_subsamplings
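
A minimal usage sketch for the new ``get_sqrt_ggn_extension`` helper added in ``vivit/optim/utils.py``. The model, loss function, and data below are illustrative placeholders; only the helper's signature, the ``with backpack(...)`` pattern, and the ``savefield`` attribute are taken from the code in this patch.

from backpack import backpack, extend
from torch import nn, rand, randint

from vivit.optim.utils import get_sqrt_ggn_extension

# Illustrative toy setup (any BackPACK-supported model and loss works the same way).
model = extend(nn.Linear(10, 3))
loss_func = extend(nn.CrossEntropyLoss())
X, y = rand(8, 10), randint(0, 3, (8,))

loss = loss_func(model(X), y)

# mc_samples=0 selects the exact GGN square root (SqrtGGNExact);
# subsampling restricts the computation to samples 0-3 of the mini-batch.
extension = get_sqrt_ggn_extension(subsampling=[0, 1, 2, 3], mc_samples=0)

with backpack(extension):
    loss.backward()

for param in model.parameters():
    # The extension's savefield holds the GGN square root for the selected samples.
    print(getattr(param, extension.savefield).shape)

With ``mc_samples > 0``, the helper instead returns ``SqrtGGNMC`` with the requested number of Monte Carlo samples for the loss-Hessian approximation.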