diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 524a72fd..35ca27d3 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -450,6 +450,7 @@ def bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix. @@ -458,16 +459,21 @@ def bias_ih_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H] if `sum_batch == False`. - Has shape [V, T, H] if `sum_batch == True`. + Has shape [V, N, *module.bias_ih_l0.shape] if ``sum_batch == False``; but if + used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.bias_ih_l0.shape] if + ``sum_batch == True``. """ return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _bias_ih_l0_jac_t_mat_prod( @@ -477,6 +483,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -489,6 +496,7 @@ def bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix. @@ -497,16 +505,21 @@ def bias_hh_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H] if `sum_batch == False`. - Has shape [V, T, H] if `sum_batch == True`. + Has shape [V, N, *module.bias_hh_l0.shape] if ``sum_batch == False``; but if + used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.bias_hh_l0.shape] if + ``sum_batch == True``. """ return self._bias_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _bias_hh_l0_jac_t_mat_prod( @@ -516,6 +529,7 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -528,6 +542,7 @@ def weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix. 
@@ -536,16 +551,21 @@ def weight_ih_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H, I] if `sum_batch == False`. - Has shape [V, T, H, I] if `sum_batch == True`. + Has shape [V, N, *module.weight_ih_l0.shape] if ``sum_batch == False``; but + if used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.weight_ih_l0.shape] if + ``sum_batch == True``. """ return self._weight_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -555,6 +575,7 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError @@ -567,6 +588,7 @@ def weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix. @@ -575,16 +597,21 @@ def weight_hh_l0_jac_t_mat_prod( g_inp: input gradients g_out: output gradients mat: Matrix the transposed Jacobian will be applied to. - Must have shape [V, T, N, H]. + Must have shape [V, T, N, H]; but if used with sub-sampling, the batch + dimension is replaced by ``len(subsampling)``. sum_batch: Whether to sum over the batch dimension on the fly. + subsampling: Indices of samples along the output's batch dimension that + should be considered. Defaults to ``None`` (use all samples). Returns: Jacobian-matrix product. - Has shape [V, T, N, H, I] if `sum_batch == False`. - Has shape [V, T, H, I] if `sum_batch == True`. + Has shape [V, N, *module.weight_hh_l0.shape] if ``sum_batch == False``; but + if used with sub-sampling, the batch dimension is replaced by + ``len(subsampling)``. Has shape [V, *module.weight_hh_l0.shape] if + ``sum_batch == True``. 
""" return self._weight_hh_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_hh_l0_jac_t_mat_prod( @@ -594,6 +621,7 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: raise NotImplementedError diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index 1360103e..8eb88386 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -1,11 +1,12 @@ """Partial derivatives for nn.LSTM.""" -from typing import Tuple +from typing import List, Tuple from torch import Tensor, cat, einsum, sigmoid, tanh, zeros from torch.nn import LSTM from backpack.core.derivatives.basederivatives import BaseParameterDerivatives from backpack.utils import TORCH_VERSION, VERSION_1_8_0 +from backpack.utils.subsampling import get_batch_axis, subsample class LSTMDerivatives(BaseParameterDerivatives): @@ -54,7 +55,7 @@ def _check_parameters(module: LSTM) -> None: @staticmethod def _forward_pass( - module: LSTM, mat: Tensor + module: LSTM, mat: Tensor, subsampling: List[int] = None ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """This performs an additional forward pass and returns the hidden variables. @@ -65,7 +66,8 @@ def _forward_pass( Args: module: module - mat: matrix, used to extract device and shapes + mat: matrix, used to extract device and shapes. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: ifgo, c, c_tanh, h @@ -83,16 +85,19 @@ def _forward_pass( c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype) + + N_axis = get_batch_axis(module) + input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling) + output = subsample(module.output, dim=N_axis, subsampling=subsampling) + for t in range(T): ifgo[t] = ( - einsum("hi,ni->nh", module.weight_ih_l0, module.input0[t]) + einsum("hi,ni->nh", module.weight_ih_l0, input0[t]) + module.bias_ih_l0 + module.bias_hh_l0 ) if t != 0: - ifgo[t] += einsum( - "hg,ng->nh", module.weight_hh_l0, module.output[t - 1] - ) + ifgo[t] += einsum("hg,ng->nh", module.weight_hh_l0, output[t - 1]) ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1]) ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2]) ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3]) @@ -106,7 +111,9 @@ def _forward_pass( return ifgo, c, c_tanh, h @classmethod - def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor: + def _ifgo_jac_t_mat_prod( + cls, module: LSTM, mat: Tensor, subsampling: List[int] = None + ) -> Tensor: V: int = mat.shape[0] T: int = mat.shape[1] N: int = mat.shape[2] @@ -117,7 +124,7 @@ def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor: H3: int = 3 * H H4: int = 4 * H - ifgo, c, c_tanh, _ = cls._forward_pass(module, mat) + ifgo, c, c_tanh, _ = cls._forward_pass(module, mat, subsampling=subsampling) # backward pass H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype) @@ -288,10 +295,13 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum(f"vtnh->v{'' if sum_batch 
else 'n'}h", IFGO_prod) @@ -302,9 +312,10 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -314,13 +325,20 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum( - f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, module.input0 + f"vtnh,tni->v{'' if sum_batch else 'n'}hi", + IFGO_prod, + subsample( + module.input0, dim=get_batch_axis(module), subsampling=subsampling + ), ) def _weight_hh_l0_jac_t_mat_prod( @@ -330,13 +348,16 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: self._check_parameters(module) N: int = mat.shape[2] H: int = module.hidden_size - IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat) + IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod( + module, mat, subsampling=subsampling + ) return einsum( f"vtnh,tng->v{'' if sum_batch else 'n'}hg", @@ -344,7 +365,11 @@ def _weight_hh_l0_jac_t_mat_prod( cat( [ zeros(1, N, H, device=mat.device, dtype=mat.dtype), - module.output[0:-1], + subsample( + module.output, + dim=get_batch_axis(module), + subsampling=subsampling, + )[0:-1], ], dim=0, ), diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 434b987b..2277c038 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -6,6 +6,7 @@ from torch.nn import RNN from backpack.core.derivatives.basederivatives import BaseParameterDerivatives +from backpack.utils.subsampling import get_batch_axis, subsample class RNNDerivatives(BaseParameterDerivatives): @@ -134,6 +135,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_ih_l0. @@ -143,6 +145,7 @@ def _bias_ih_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product @@ -152,9 +155,13 @@ def _bias_ih_l0_jac_t_mat_prod( dim: List[int] = [1, 2] else: dim: int = 1 - return self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat).sum( - dim=dim - ) + return self._a_jac_t_mat_prod( + subsample( + module.output, dim=get_batch_axis(module), subsampling=subsampling + ), + module.weight_hh_l0, + mat, + ).sum(dim=dim) def _bias_hh_l0_jac_t_mat_prod( self, @@ -163,6 +170,7 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. bias_hh_l0. @@ -172,12 +180,13 @@ def _bias_hh_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). 
Returns: product """ return self._bias_ih_l0_jac_t_mat_prod( - module, g_inp, g_out, mat, sum_batch=sum_batch + module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling ) def _weight_ih_l0_jac_t_mat_prod( @@ -187,6 +196,7 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_ih_l0. @@ -196,15 +206,21 @@ def _weight_ih_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product """ self._check_parameters(module) + N_axis = get_batch_axis(module) return einsum( "vtnh,tnj->" + ("vhj" if sum_batch else "vnhj"), - self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), - module.input0, + self._a_jac_t_mat_prod( + subsample(module.output, dim=N_axis, subsampling=subsampling), + module.weight_hh_l0, + mat, + ), + subsample(module.input0, dim=N_axis, subsampling=subsampling), ) def _weight_hh_l0_jac_t_mat_prod( @@ -214,6 +230,7 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: Tuple[Tensor], mat: Tensor, sum_batch: bool = True, + subsampling: List[int] = None, ) -> Tensor: """Apply transposed Jacobian of the output w.r.t. weight_hh_l0. @@ -223,21 +240,21 @@ def _weight_hh_l0_jac_t_mat_prod( g_out: output gradient mat: matrix to multiply sum_batch: Whether to sum along batch axis. Defaults to True. + subsampling: Indices of active samples. Defaults to ``None`` (all samples). Returns: product """ self._check_parameters(module) - N: int = mat.shape[2] + N_axis = get_batch_axis(module) + N: int = mat.shape[N_axis + 1] H: int = mat.shape[3] + output = subsample(module.output, dim=N_axis, subsampling=subsampling) return einsum( "vtnh,tnk->" + ("vhk" if sum_batch else "vnhk"), - self._a_jac_t_mat_prod(module.output, module.weight_hh_l0, mat), + self._a_jac_t_mat_prod(output, module.weight_hh_l0, mat), cat( - [ - zeros(1, N, H, device=mat.device, dtype=mat.dtype), - module.output[0:-1], - ], + [zeros(1, N, H, device=mat.device, dtype=mat.dtype), output[0:-1]], dim=0, ), ) diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index a9ba070c..c5bdc6e2 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.nn import Module -from backpack.utils.subsampling import subsample +from backpack.utils.subsampling import get_batch_axis, subsample ############################################################################### @@ -62,7 +62,9 @@ def _check_same_V_dim(mat1, mat2): def _check_like(mat, module, name, diff=1, *args, **kwargs): if name == "output" and "subsampling" in kwargs.keys(): - compare = subsample(module.output, subsampling=kwargs["subsampling"]) + compare = subsample( + module.output, dim=get_batch_axis(module), subsampling=kwargs["subsampling"] + ) else: compare = getattr(module, name) diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index f429d580..68a02c34 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -2,6 +2,7 @@ from typing import List from torch import Tensor +from torch.nn import LSTM, RNN, Module def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor: @@ -17,12 +18,32 @@ def subsample(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) 
-> Te Tensor of same rank that is sub-sampled along the dimension. Raises: - NotImplementedError: If dimension differs from ``0``. + NotImplementedError: If dimension differs from ``0`` or ``1``. """ if subsampling is None: return tensor else: if dim == 0: return tensor[subsampling] + elif dim == 1: + return tensor[:, subsampling] else: - raise NotImplementedError(f"Only supports dim = 0. Got {dim}.") + raise NotImplementedError(f"Only supports dim = 0,1. Got {dim}.") + + +def get_batch_axis(module: Module) -> int: + """Return the batch axis assumed by the module. + + Args: + module: A module. + + Returns: + Batch axis + """ + if isinstance(module, (RNN, LSTM)): + if module.batch_first: + return 0 + else: + return 1 + else: + return 0 diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index d4475f2a..3478c6e6 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -23,9 +23,10 @@ import pytest import torch from pytest import fixture, skip -from torch import Size, Tensor +from torch import Tensor from backpack.core.derivatives.convnd import weight_jac_t_save_memory +from backpack.utils.subsampling import get_batch_axis PROBLEMS = make_test_problems(SETTINGS) IDS = [problem.make_id() for problem in PROBLEMS] @@ -124,112 +125,144 @@ def test_jac_t_mat_prod(problem: DerivativesTestProblem, request, V: int = 3) -> IDS_WITH_WEIGHTS.append(problem_id) +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_bias_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_bias_ih_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to bias_ih_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).bias_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_bias_hh_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_bias_hh_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to bias_hh_l0. 
Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).bias_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_weight_ih_l0_jac_t_mat_prod(problem, sum_batch, V=3): +def test_weight_ih_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to weight_ih_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. """ problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).weight_ih_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) problem.tear_down() +@pytest.mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @pytest.mark.parametrize( "sum_batch", [True, False], ids=["sum_batch=True", "sum_batch=False"] ) @pytest.mark.parametrize( "problem", RNN_PROBLEMS + LSTM_PROBLEMS, ids=RNN_IDS + LSTM_IDS ) -def test_weight_hh_l0_jac_t_mat_prod(problem, sum_batch, V=4): +def test_weight_hh_l0_jac_t_mat_prod( + problem: DerivativesTestProblem, + sum_batch: bool, + subsampling: Union[List[int], None], + V: int = 3, +) -> None: """Test the transposed Jacobian-matrix product w.r.t. to weight_hh_l0. Args: - problem (DerivativesProblem): Problem for derivative test. - sum_batch (bool): Sum results over the batch dimension. - V (int): Number of vectorized transposed Jacobian-vector products. + problem: Problem for derivative test. + sum_batch: Sum results over the batch dimension. + subsampling: Indices of active samples. + V: Number of vectorized transposed Jacobian-vector products. 
""" problem.set_up() - mat = torch.rand(V, *problem.output_shape).to(problem.device) + _skip_if_subsampling_conflict(problem, subsampling) + mat = rand_mat_like_output(V, problem, subsampling=subsampling).to(problem.device) autograd_res = AutogradDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) backpack_res = BackpackDerivatives(problem).weight_hh_l0_jac_t_mat_prod( - mat, sum_batch + mat, sum_batch, subsampling=subsampling ) check_sizes_and_values(autograd_res, backpack_res) @@ -271,7 +304,7 @@ def test_weight_jac_t_mat_prod( def rand_mat_like_output( - V: int, output_shape: Size, subsampling: List[int] = None + V: int, problem: DerivativesTestProblem, subsampling: List[int] = None ) -> Tensor: """Generate random matrix whose columns are shaped like the layer output. @@ -280,16 +313,16 @@ def rand_mat_like_output( Args: V: Number of rows. - output_shape: Shape of the module output. + problem: Test case. subsampling: Indices of samples used by sub-sampling. Returns: Random matrix with (subsampled) output shape. """ - subsample_shape = list(output_shape) + subsample_shape = list(problem.output_shape) if subsampling is not None: - N_axis = 0 + N_axis = get_batch_axis(problem.module) subsample_shape[N_axis] = len(subsampling) return torch.rand(V, *subsample_shape) @@ -523,11 +556,8 @@ def problem_weight(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a weight parameter. """ - has_weight = getattr(problem.module, "weight", None) is not None - if has_weight: - yield problem - else: - skip("Test case has no weight parameter.") + _skip_if_no_param(problem, "weight") + yield problem @fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @@ -547,21 +577,44 @@ def problem_weight_jac_t_mat( problem with weight, subsampling, matrix for weight_jac_t """ subsampling: Union[None, List[int]] = request.param - N = problem_weight.input_shape[0] - enough_samples = subsampling is None or N >= max(subsampling) - - if not enough_samples: - skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + _skip_if_subsampling_conflict(problem_weight, subsampling) V = 3 - mat = rand_mat_like_output( - V, problem_weight.output_shape, subsampling=subsampling - ).to(problem_weight.device) + mat = rand_mat_like_output(V, problem_weight, subsampling=subsampling).to( + problem_weight.device + ) yield (problem_weight, subsampling, mat) del mat +def _skip_if_subsampling_conflict( + problem: DerivativesTestProblem, subsampling: Union[List[int], None] +) -> None: + """Skip if some samples in subsampling are not contained in input. + + Args: + problem: Test case. + subsampling: Indices of active samples. + """ + N = problem.input_shape[get_batch_axis(problem.module)] + enough_samples = subsampling is None or N >= max(subsampling) + if not enough_samples: + skip("Not enough samples.") + + +def _skip_if_no_param(problem: DerivativesTestProblem, param_str: str) -> None: + """Skip if test case does not contain the parameter. + + Args: + problem: Test case. + param_str: Parameter name. + """ + has_param = getattr(problem.module, param_str, None) is not None + if not has_param: + skip(f"Test case has no {param_str} parameter.") + + @fixture def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: """Filter out cases that don't have a bias parameter. 
@@ -572,11 +625,8 @@ def problem_bias(problem: DerivativesTestProblem) -> DerivativesTestProblem: Yields: Instantiated cases that have a bias parameter. """ - has_bias = getattr(problem.module, "bias", None) is not None - if has_bias: - yield problem - else: - skip("Test case has no bias parameter.") + _skip_if_no_param(problem, "bias") + yield problem @fixture(params=SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @@ -596,16 +646,12 @@ def problem_bias_jac_t_mat( problem with bias, subsampling, matrix for bias_jac_t """ subsampling: Union[None, List[int]] = request.param - N = problem_bias.input_shape[0] - enough_samples = subsampling is None or N >= max(subsampling) - - if not enough_samples: - skip(f"Not enough samples: sub-sampling {subsampling}, batch_size {N}") + _skip_if_subsampling_conflict(problem_bias, subsampling) V = 3 - mat = rand_mat_like_output( - V, problem_bias.output_shape, subsampling=subsampling - ).to(problem_bias.device) + mat = rand_mat_like_output(V, problem_bias, subsampling=subsampling).to( + problem_bias.device + ) yield (problem_bias, subsampling, mat) del mat diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index d17f8a1e..68060e99 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -52,17 +52,29 @@ def bias_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 "bias", mat, sum_batch, subsampling=subsampling ) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("bias_ih_l0", mat, sum_batch, axis_batch=1) + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 + return self.param_jac_t_mat_prod( + "bias_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("weight_ih_l0", mat, sum_batch, axis_batch=1) + def weight_ih_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 + return self.param_jac_t_mat_prod( + "weight_ih_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) - def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 - return self.param_jac_t_mat_prod("weight_hh_l0", mat, sum_batch, axis_batch=1) + def weight_hh_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 + return self.param_jac_t_mat_prod( + "weight_hh_l0", mat, sum_batch, axis_batch=1, subsampling=subsampling + ) def param_jac_t_vec_prod( self, diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 2662c684..cced2fc9 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -70,28 +70,52 @@ def bias_jac_mat_prod(self, mat): # noqa: D102 self.problem.module, None, None, mat ) - def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def bias_ih_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_ih_l0_jac_t_mat_prod( - self.problem.module, None, 
None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def bias_hh_l0_jac_t_mat_prod(self, mat, sum_batch, subsampling=None): # noqa: D102 self.store_forward_io() return self.problem.derivative.bias_hh_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def weight_ih_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def weight_ih_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_ih_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) - def weight_hh_l0_jac_t_mat_prod(self, mat, sum_batch): # noqa: D102 + def weight_hh_l0_jac_t_mat_prod( + self, mat, sum_batch, subsampling=None + ): # noqa: D102 self.store_forward_io() return self.problem.derivative.weight_hh_l0_jac_t_mat_prod( - self.problem.module, None, None, mat, sum_batch=sum_batch + self.problem.module, + None, + None, + mat, + sum_batch=sum_batch, + subsampling=subsampling, ) def ea_jac_t_mat_jac_prod(self, mat): # noqa: D102 diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index b6eedb42..f6411535 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -99,12 +99,15 @@ def bias_jac_mat_prod(self, mat: Tensor) -> Tensor: raise NotImplementedError @abstractmethod - def bias_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def bias_ih_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -112,12 +115,15 @@ def bias_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def bias_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def bias_hh_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -125,12 +131,15 @@ def bias_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def weight_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def weight_ih_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product @@ -138,12 +147,15 @@ def weight_ih_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: raise NotImplementedError @abstractmethod - def weight_hh_l0_jac_t_mat_prod(self, mat: Tensor, sum_batch: bool) -> Tensor: + def weight_hh_l0_jac_t_mat_prod( + self, mat: Tensor, sum_batch: bool, subsampling: List[int] = None + ) -> Tensor: """Product of jacobian and matrix. 
Args: mat: matrix sum_batch: whether to sum along batch axis + subsampling: Active samples in the output. Default: ``None`` (all). Returns: product diff --git a/test/extensions/secondorder/diag_ggn/diagggn_settings.py b/test/extensions/secondorder/diag_ggn/diagggn_settings.py deleted file mode 100644 index d0d80667..00000000 --- a/test/extensions/secondorder/diag_ggn/diagggn_settings.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Test cases for BackPACK extensions for the GGN diagonal. - -Includes -- ``DiagGGNExact`` -- ``DiagGGNMC`` -- ``BatchDiagGGNExact`` -- ``BatchDiagGGNMC`` - -Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. -Additional local cases can be defined here through ``LOCAL_SETTINGS``. -""" - -from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS - -SHARED_SETTINGS = SECONDORDER_SETTINGS -LOCAL_SETTINGS = [] - -DiagGGN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS