[ADD] Application of RNN/LSTM param_jac_t for a subset of samples #197

Merged
13 commits merged on Jul 6, 2021
60 changes: 44 additions & 16 deletions backpack/core/derivatives/basederivatives.py
@@ -450,6 +450,7 @@ def bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. bias_ih_l0 to a matrix.

@@ -458,16 +459,21 @@ def bias_ih_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).

Returns:
Jacobian-matrix product.
Has shape [V, T, N, H] if `sum_batch == False`.
Has shape [V, T, H] if `sum_batch == True`.
Has shape [V, N, *module.bias_ih_l0.shape] if ``sum_batch == False``; but if
used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.bias_ih_l0.shape] if
``sum_batch == True``.
"""
return self._bias_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _bias_ih_l0_jac_t_mat_prod(
@@ -477,6 +483,7 @@ def _bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError
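
A small shape illustration of the new ``subsampling`` argument (sizes are made up and not part of the diff; for ``nn.LSTM``, ``bias_ih_l0`` has shape ``[4 * H]``):

# Illustration only: how ``subsampling`` shrinks the batch dimension of ``mat``
# and of the returned Jacobian-matrix product (sizes are made up).
from torch import rand

V, T, N, H = 2, 5, 8, 3        # vectors, sequence length, batch size, hidden size
subsampling = [0, 3, 7]        # indices of the active samples

mat = rand(V, T, N, H)                  # expected shape without sub-sampling
mat_sub = mat[:, :, subsampling, :]     # expected shape [V, T, len(subsampling), H]

# Result shapes for an LSTM, where bias_ih_l0 has shape [4 * H]:
#   sum_batch=False: [V, N, 4 * H], or [V, len(subsampling), 4 * H] with sub-sampling
#   sum_batch=True:  [V, 4 * H]     (batch dimension summed out)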

@@ -489,6 +496,7 @@ def bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. bias_hh_l0 to a matrix.

@@ -497,16 +505,21 @@ def bias_hh_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).

Returns:
Jacobian-matrix product.
Has shape [V, T, N, H] if `sum_batch == False`.
Has shape [V, T, H] if `sum_batch == True`.
Has shape [V, N, *module.bias_hh_l0.shape] if ``sum_batch == False``; but if
used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.bias_hh_l0.shape] if
``sum_batch == True``.
"""
return self._bias_hh_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _bias_hh_l0_jac_t_mat_prod(
@@ -516,6 +529,7 @@ def _bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -528,6 +542,7 @@ def weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. weight_ih_l0 to a matrix.

@@ -536,16 +551,21 @@ def weight_ih_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).

Returns:
Jacobian-matrix product.
Has shape [V, T, N, H, I] if `sum_batch == False`.
Has shape [V, T, H, I] if `sum_batch == True`.
Has shape [V, N, *module.weight_ih_l0.shape] if ``sum_batch == False``; but
if used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.weight_ih_l0.shape] if
``sum_batch == True``.
"""
return self._weight_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_ih_l0_jac_t_mat_prod(
@@ -555,6 +575,7 @@ def _weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -567,6 +588,7 @@ def weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed Jacobian of the output w.r.t. weight_hh_l0 to a matrix.

@@ -575,16 +597,21 @@ def weight_hh_l0_jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, T, N, H].
Must have shape [V, T, N, H]; but if used with sub-sampling, the batch
dimension is replaced by ``len(subsampling)``.
sum_batch: Whether to sum over the batch dimension on the fly.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).

Returns:
Jacobian-matrix product.
Has shape [V, T, N, H, I] if `sum_batch == False`.
Has shape [V, T, H, I] if `sum_batch == True`.
Has shape [V, N, *module.weight_hh_l0.shape] if ``sum_batch == False``; but
if used with sub-sampling, the batch dimension is replaced by
``len(subsampling)``. Has shape [V, *module.weight_hh_l0.shape] if
``sum_batch == True``.
"""
return self._weight_hh_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_hh_l0_jac_t_mat_prod(
@@ -594,6 +621,7 @@ def _weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

55 changes: 40 additions & 15 deletions backpack/core/derivatives/lstm.py
@@ -1,11 +1,12 @@
"""Partial derivatives for nn.LSTM."""
from typing import Tuple
from typing import List, Tuple

from torch import Tensor, cat, einsum, sigmoid, tanh, zeros
from torch.nn import LSTM

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION, VERSION_1_8_0
from backpack.utils.subsampling import get_batch_axis, subsample


class LSTMDerivatives(BaseParameterDerivatives):
@@ -54,7 +55,7 @@ def _check_parameters(module: LSTM) -> None:

@staticmethod
def _forward_pass(
module: LSTM, mat: Tensor
module: LSTM, mat: Tensor, subsampling: List[int] = None
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""This performs an additional forward pass and returns the hidden variables.

@@ -65,7 +66,8 @@ def _forward_pass(

Args:
module: module
mat: matrix, used to extract device and shapes
mat: matrix, used to extract device and shapes.
subsampling: Indices of active samples. Defaults to ``None`` (all samples).

Returns:
ifgo, c, c_tanh, h
@@ -83,16 +85,19 @@ def _forward_pass(
c: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)
c_tanh: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)
h: Tensor = zeros(T, N, H, device=mat.device, dtype=mat.dtype)

N_axis = get_batch_axis(module)
input0 = subsample(module.input0, dim=N_axis, subsampling=subsampling)
output = subsample(module.output, dim=N_axis, subsampling=subsampling)

for t in range(T):
ifgo[t] = (
einsum("hi,ni->nh", module.weight_ih_l0, module.input0[t])
einsum("hi,ni->nh", module.weight_ih_l0, input0[t])
+ module.bias_ih_l0
+ module.bias_hh_l0
)
if t != 0:
ifgo[t] += einsum(
"hg,ng->nh", module.weight_hh_l0, module.output[t - 1]
)
ifgo[t] += einsum("hg,ng->nh", module.weight_hh_l0, output[t - 1])
ifgo[t, :, H0:H1] = sigmoid(ifgo[t, :, H0:H1])
ifgo[t, :, H1:H2] = sigmoid(ifgo[t, :, H1:H2])
ifgo[t, :, H2:H3] = tanh(ifgo[t, :, H2:H3])
@@ -106,7 +111,9 @@ def _forward_pass(
return ifgo, c, c_tanh, h
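
The forward pass now reads ``input0`` and ``output`` through ``subsample`` from ``backpack.utils.subsampling``; since ``input0`` is ``[T, N, I]`` here, the batch sits on axis 1, which is presumably what ``get_batch_axis`` returns. A rough sketch of what such a helper does (illustration only; the real implementation may differ):

# Sketch of the assumed behaviour of ``subsample`` (illustration only; the real
# helper is imported from ``backpack.utils.subsampling``).
from typing import List

from torch import Tensor


def subsample_sketch(
    tensor: Tensor, dim: int = 0, subsampling: List[int] = None
) -> Tensor:
    """Restrict ``tensor`` to the samples listed in ``subsampling`` along ``dim``.

    ``subsampling=None`` returns the tensor unchanged, so call sites behave as
    before when no sub-sampling is requested.
    """
    if subsampling is None:
        return tensor
    return tensor[(slice(None),) * dim + (subsampling,)]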

@classmethod
def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor:
def _ifgo_jac_t_mat_prod(
cls, module: LSTM, mat: Tensor, subsampling: List[int] = None
) -> Tensor:
V: int = mat.shape[0]
T: int = mat.shape[1]
N: int = mat.shape[2]
@@ -117,7 +124,7 @@ def _ifgo_jac_t_mat_prod(cls, module: LSTM, mat: Tensor) -> Tensor:
H3: int = 3 * H
H4: int = 4 * H

ifgo, c, c_tanh, _ = cls._forward_pass(module, mat)
ifgo, c, c_tanh, _ = cls._forward_pass(module, mat, subsampling=subsampling)

# backward pass
H_prod_t: Tensor = zeros(V, N, H, device=mat.device, dtype=mat.dtype)
@@ -288,10 +295,13 @@ def _bias_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(f"vtnh->v{'' if sum_batch else 'n'}h", IFGO_prod)

@@ -302,9 +312,10 @@ def _bias_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
return self._bias_ih_l0_jac_t_mat_prod(
module, g_inp, g_out, mat, sum_batch=sum_batch
module, g_inp, g_out, mat, sum_batch=sum_batch, subsampling=subsampling
)

def _weight_ih_l0_jac_t_mat_prod(
@@ -314,13 +325,20 @@ def _weight_ih_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(
f"vtnh,tni->v{'' if sum_batch else 'n'}hi", IFGO_prod, module.input0
f"vtnh,tni->v{'' if sum_batch else 'n'}hi",
IFGO_prod,
subsample(
module.input0, dim=get_batch_axis(module), subsampling=subsampling
),
)
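
The contraction ``vtnh,tni->vhi`` pairs the IFGO products with the (optionally sub-sampled) input sequence and yields the ``weight_ih_l0`` shape. A small shape check with made-up sizes:

# Illustration only: the weight_ih_l0 product contracts over time (and batch).
from torch import einsum, rand

V, T, N, H4, I = 2, 5, 4, 12, 7   # H4 stands for 4 * hidden_size
IFGO_prod = rand(V, T, N, H4)
input0 = rand(T, N, I)            # (sub-sampled) input sequence, batch on axis 1

per_sample = einsum("vtnh,tni->vnhi", IFGO_prod, input0)   # [V, N, 4 * H, I]
summed = einsum("vtnh,tni->vhi", IFGO_prod, input0)        # [V, 4 * H, I] = [V, *weight_ih_l0.shape]
assert per_sample.shape == (V, N, H4, I) and summed.shape == (V, H4, I)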

def _weight_hh_l0_jac_t_mat_prod(
@@ -330,21 +348,28 @@ def _weight_hh_l0_jac_t_mat_prod(
g_out: Tuple[Tensor],
mat: Tensor,
sum_batch: bool = True,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)

N: int = mat.shape[2]
H: int = module.hidden_size

IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(module, mat)
IFGO_prod: Tensor = self._ifgo_jac_t_mat_prod(
module, mat, subsampling=subsampling
)

return einsum(
f"vtnh,tng->v{'' if sum_batch else 'n'}hg",
IFGO_prod,
cat(
[
zeros(1, N, H, device=mat.device, dtype=mat.dtype),
module.output[0:-1],
subsample(
module.output,
dim=get_batch_axis(module),
subsampling=subsampling,
)[0:-1],
],
dim=0,
),
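
The ``cat`` with a block of zeros shifts the hidden states by one step, so the contraction uses the previous state ``h_{t-1}`` (with ``h_0 = 0``) that ``weight_hh_l0`` multiplies in the recurrence. A minimal sketch of that shift with made-up sizes:

# Illustration only: build the sequence h_0, ..., h_{T-1} from the layer output.
from torch import cat, rand, zeros

T, N, H = 5, 4, 3
output = rand(T, N, H)    # stands in for the (sub-sampled) hidden states h_1, ..., h_T

h_prev = cat([zeros(1, N, H), output[0:-1]], dim=0)   # prepend h_0 = 0, drop h_T
assert h_prev.shape == (T, N, H)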