[ADD] Application of RNN/LSTM param_jac_t for a subset of samples (#197)

Prepares the `core` functionality to support #12; a usage sketch of the new argument follows the change list below.

* [ADD] Extend `weight_jac_t` by subsampling argument

* [ADD] Sub-sampling for batch normalization, adapt interface

* [REF] Separate fixtures for filtering weights and matrix generation

* [REF] Shorten instantiation fixture name

* [TEST] Add BN setting to `weight_jac_t` tests

* [DOC] Add batch norm test setting to fully-documented

* [ADD] Support subsampling in `bias_jac_t`

* [ADD] Support sub-sampling in RNN/LSTM `param_weight_jac_t`

* [FIX] flake8

* [FMT] Squeeze some lines

* [FIX] Typo in exception message

* [DOC] Correct shapes
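
For orientation, here is a minimal, hypothetical sketch of how the extended transposed-Jacobian interface might be called for a subset of samples. It is not part of the commit: the tensor sizes, the manual assignment of `input0`/`output` (normally stored by BackPACK's extension hooks during an extended forward pass), and passing `None` for the unused gradient arguments are assumptions made purely for illustration.

```python
from torch import nn, rand

from backpack.core.derivatives.lstm import LSTMDerivatives

T, N, H, I, V = 3, 4, 5, 2, 6  # time steps, batch size, hidden size, input size, matrix columns

lstm = nn.LSTM(input_size=I, hidden_size=H)
x = rand(T, N, I)
out, _ = lstm(x)

# BackPACK's hooks normally store these on the module; set them by hand here
# so the derivatives object can re-run its internal forward pass.
lstm.input0, lstm.output = x, out

subsampling = [0, 2]                   # only consider the first and third sample
mat = rand(V, T, len(subsampling), H)  # batch dimension shrunk to len(subsampling)

derivatives = LSTMDerivatives()
# Delegates to the _weight_ih_l0_jac_t_mat_prod implementation shown in the diff;
# g_inp/g_out are not used there, hence the None placeholders in this sketch.
jac_t_mat = derivatives.weight_ih_l0_jac_t_mat_prod(
    lstm, None, None, mat, sum_batch=True, subsampling=subsampling
)
print(jac_t_mat.shape)  # expected: [V, *lstm.weight_ih_l0.shape] = [6, 20, 2]
```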

Co-authored-by: Felix Dangel <fdangel@tuebingen.mpg.de>
f-dangel and Felix Dangel committed Jul 6, 2021
1 parent 68870ed commit cbee344
Showing 10 changed files with 313 additions and 144 deletions.
60 changes: 44 additions & 16 deletions backpack/core/derivatives/basederivatives.py
@@ -450,6 +450,7 @@ def bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix.
@@ -458,16 +459,21 @@ def bias_ih_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).
Returns:
Jacobian-matrix product.
Has shape [V, T, N, H] if `sum_batch == False`.
Has shape [V, T, H] if `sum_batch == True`.
Has shape [V, N, *module.bias_ih_l0.shape] if ``sum_batch == False``; but if
used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.bias_ih_l0.shape] if
``sum_batch == True``.
"""
return self._bias_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _bias_ih_l0_jac_t_mat_prod(
@@ -477,6 +483,7 @@ def _bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -489,6 +496,7 @@ def bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix.
@@ -497,16 +505,21 @@ def bias_hh_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).
Returns:
Jacobian-matrix product.
Has shape [V, T, N, H] if `sum_batch == False`.
Has shape [V, T, H] if `sum_batch == True`.
Has shape [V, N, *module.bias_hh_l0.shape] if ``sum_batch == False``; but if
used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.bias_hh_l0.shape] if
``sum_batch == True``.
"""
return self._bias_hh_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _bias_hh_l0_jac_t_mat_prod(
@@ -516,6 +529,7 @@ def _bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -528,6 +542,7 @@ def weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix.
@@ -536,16 +551,21 @@ def weight_ih_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).
Returns:
Jacobian-matrix product.
Has shape [V, T, N, H, I] if `sum_batch == False`.
Has shape [V, T, H, I] if `sum_batch == True`.
Has shape [V, N, *module.weight_ih_l0.shape] if ``sum_batch == False``; but
if used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.weight_ih_l0.shape] if
``sum_batch == True``.
"""
return self._weight_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_ih_l0_jac_t_mat_prod(
@@ -555,6 +575,7 @@ def _weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -567,6 +588,7 @@ def weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix.
@@ -575,16 +597,21 @@ def weight_hh_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).
Returns:
Jacobian-matrix product.
Has shape [V, T, N, H, I] if `sum_batch == False`.
Has shape [V, T, H, I] if `sum_batch == True`.
Has shape [V, N, *module.weight_hh_l0.shape] if ``sum_batch == False``; but
if used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.weight_hh_l0.shape] if
``sum_batch == True``.
"""
return self._weight_hh_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_hh_l0_jac_t_mat_prod(
@@ -594,6 +621,7 @@ def _weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

55 changes: 40 additions & 15 deletions backpack/core/derivatives/lstm.py
@@ -1,11 +1,12 @@
"""Partial derivatives for nn.LSTM."""
from typing import Tuple
from typing import List, Tuple

from torch import Tensor, cat, einsum, sigmoid, tanh, zeros
from torch.nn import LSTM

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION, VERSION_1_8_0
from backpack.utils.subsampling import get_batch_axis, subsample


class LSTMDerivatives(BaseParameterDerivatives):
@@ -54,7 +55,7 @@ def _check_parameters(module: LSTM) -> None:

@staticmethod
def _forward_pass(
module: LSTM, mat: Tensor
module: LSTM, mat: Tensor, subsampling: List[int] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""This performs an additional forward pass and returns the hidden variables.
@@ -65,7 +66,8 @@ def _forward_pass(
Args:
module: module
mat: matrix, used to extract device and shapes
mat: matrix, used to extract device and shapes.
subsampling: Indices of active samples. Defaults to ``None`` (all samples).
Returns:
ifgo, c, c_tanh, h
@@ -83,16 +85,19 @@ def _forward_pass(
c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)
c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)
h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)

N_axis = get_batch_axis(module)
input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling)
output = subsample(module.output, dim=N_axis, subsampling=subsampling)

for t in range(T):
ifgo[t] = (
einsum("hi,ni->nh", module.weight_ih_l0, module.input0[t])
einsum("hi,ni->nh", module.weight_ih_l0, input0[t])
+ module.bias_ih_l0
+ module.bias_hh_l0
)
if t != 0:
ifgo[t] += einsum(
"hg,ng->nh", module.weight_hh_l0, module.output[t - 1]
)
ifgo[t] += einsum("hg,ng->nh", module.weight_hh_l0, output[t - 1])
ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1])
ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2])
ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3])
@@ -106,7 +111,9 @@ def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor:
return ifgo, c, c_tanh, h

@classmethod
def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor:
def _ifgo_jac_t_mat_prod(
cls, module: LSTM, mat: Tensor, subsampling: List[int] = None
) -> Tensor:
V: int = mat.shape[0]
T: int = mat.shape[1]
N: int = mat.shape[2]
@@ -117,7 +124,7 @@ def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor:
H3: int = 3 * H
H4: int = 4 * H

ifgo, c, c_tanh, _ = cls._forward_pass(module, mat)
ifgo, c, c_tanh, _ = cls._forward_pass(module, mat, subsampling=subsampling)

# backward pass
H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
@@ -288,10 +295,13 @@ def _bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod)

@@ -302,9 +312,10 @@ def _bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
return self._bias_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_ih_l0_jac_t_mat_prod(
@@ -314,13 +325,20 @@ def _weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(
f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, module.input0
f"vtnh,tni->v{'' if sum_batch else 'n'}hi",
IFGO_prod,
subsample(
module.input0, dim=get_batch_axis(module), subsampling=subsampling
),
)

def _weight_hh_l0_jac_t_mat_prod(
@@ -330,21 +348,28 @@ def _weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

N: int = mat.shape[2]
H: int = module.hidden_size

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(
f"vtnh,tng->v{'' if sum_batch else 'n'}hg",
IFGO_prod,
cat(
[
zeros(1, N, H, device=mat.device, dtype=mat.dtype),
module.output[0:-1],
subsample(
module.output,
dim=get_batch_axis(module),
subsampling=subsampling,
)[0:-1],
],
dim=0,
),
(Diffs for the remaining 8 changed files are not shown.)
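
The lstm.py changes rely on `subsample` and `get_batch_axis` from `backpack.utils.subsampling`, whose definitions are not part of this diff. As a rough mental model only (the actual helpers may differ), `subsample` behaves like an index-select along the batch axis that passes the tensor through unchanged when `subsampling` is `None`, while `get_batch_axis` presumably returns the batch dimension of the module's input (1 for the sequence-first LSTM above). The stand-in below is hypothetical:

```python
from typing import List

from torch import Tensor


def subsample_sketch(tensor: Tensor, dim: int = 0, subsampling: List[int] = None) -> Tensor:
    """Keep only the samples listed in ``subsampling`` along dimension ``dim``.

    Hypothetical stand-in for ``backpack.utils.subsampling.subsample``.
    """
    if subsampling is None:
        return tensor  # no sub-sampling requested: use all samples
    # index-select the requested samples along the batch dimension
    return tensor[(slice(None),) * dim + (subsampling,)]
```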
