[core] Support subsampling in jac_t_mat_prod
f-dangel committed Mar 17, 2021
1 parent 7283969 commit ba2ff39
Showing 24 changed files with 291 additions and 125 deletions.
29 changes: 17 additions & 12 deletions backpack/core/derivatives/avgpoolnd.py
@@ -16,6 +16,7 @@
)

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.core.derivatives.subsampling import subsample_input


class AvgPoolNDDerivatives(BaseDerivatives):
@@ -116,16 +117,18 @@ def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module):
N, C_out, D_out, H_out, W_out = module.output.shape
assert jmp_as_pool.shape == (V * N * C_out, 1, D_out, H_out, W_out)

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
self.check_exotic_parameters(module)

mat_as_pool = self.__make_single_channel(mat, module)
jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool)
self.__check_jmp_in_as_pool(mat, jmp_as_pool, module)
jmp_as_pool = self.__apply_jacobian_t_of(
module, mat_as_pool, subsampling=subsampling
)
self.__check_jmp_in_as_pool(mat, jmp_as_pool, module, subsampling=subsampling)

return self.reshape_like_input(jmp_as_pool, module)
return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling)

def __apply_jacobian_t_of(self, module, mat):
def __apply_jacobian_t_of(self, module, mat, subsampling=None):
C_for_conv_t = 1

convnd_t = self.convt(
@@ -142,26 +145,28 @@ def __apply_jacobian_t_of(self, module, mat):
convnd_t.weight.data = avg_kernel

V_N_C_in = mat.size(0)
input = subsample_input(module, subsampling=subsampling)
if self.N == 1:
_, _, L_in = module.input0.size()
_, _, L_in = input.size()
output_size = (V_N_C_in, C_for_conv_t, L_in)
elif self.N == 2:
_, _, H_in, W_in = module.input0.size()
_, _, H_in, W_in = input.size()
output_size = (V_N_C_in, C_for_conv_t, H_in, W_in)
elif self.N == 3:
_, _, D_in, H_in, W_in = module.input0.size()
_, _, D_in, H_in, W_in = input.size()
output_size = (V_N_C_in, C_for_conv_t, D_in, H_in, W_in)

return convnd_t(mat, output_size=output_size)

def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module):
def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module, subsampling=None):
V = mat.size(0)
input = subsample_input(module, subsampling=subsampling)
if self.N == 1:
N, C_in, L_in = module.input0.size()
N, C_in, L_in = input.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, L_in)
elif self.N == 2:
N, C_in, H_in, W_in = module.input0.size()
N, C_in, H_in, W_in = input.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, H_in, W_in)
elif self.N == 3:
N, C_in, D_in, H_in, W_in = module.input0.size()
N, C_in, D_in, H_in, W_in = input.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, D_in, H_in, W_in)
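Note: `subsample_input` (imported at the top of this file from `backpack.core.derivatives.subsampling`) is defined in one of the changed files not shown in this excerpt. A minimal sketch of the behavior the code above relies on, assuming the helper simply slices the stored input along the batch dimension:

# Hypothetical sketch, not the implementation shipped in this commit.
def subsample_input(module, subsampling=None):
    """Return module.input0, restricted to the requested batch indices."""
    if subsampling is None:
        return module.input0
    return module.input0[subsampling]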
46 changes: 37 additions & 9 deletions backpack/core/derivatives/basederivatives.py
@@ -1,7 +1,13 @@
"""Base classes for more flexible Jacobians and second-order information."""
import warnings

import torch

from backpack.core.derivatives import shape_check
from backpack.core.derivatives.subsampling import (
subsampled_input_shape,
subsampled_output_shape,
)


class BaseDerivatives:
@@ -71,7 +77,7 @@ def _jac_mat_prod(self, module, g_inp, g_out, mat):

@shape_check.jac_t_mat_prod_accept_vectors
@shape_check.jac_t_mat_prod_check_shapes
def jac_t_mat_prod(self, module, g_inp, g_out, mat):
def jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
"""Apply transposed input-ouput Jacobian of module output to a matrix.
Implicit application of Jᵀ:
@@ -83,16 +89,18 @@ def jac_t_mat_prod(self, module, g_inp, g_out, mat):
mat: torch.Tensor
Matrix the transposed Jacobian will be applied to.
Must have shape [V, N, C_out, H_out, ...].
subsampling: list(int)
Indices of samples to be considered. If ``None``, use all samples.
Returns:
--------
result: torch.Tensor
Transposed Jacobian-matrix product.
Has shape [V, N, C_in, H_in, ...].
"""
return self._jac_t_mat_prod(module, g_inp, g_out, mat)
return self._jac_t_mat_prod(module, g_inp, g_out, mat, subsampling=subsampling)

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
"""Internal implementation of transposed Jacobian."""
raise NotImplementedError

@@ -163,19 +171,39 @@ def _residual_mat_prod(self, module, g_inp, g_out, mat):
def _reshape_like(mat, like):
"""Reshape as like with trailing and additional 0th dimension.
If like is [N, C, H, ...], returns shape [-1, N, C, H, ...]
Args:
mat (torch.Tensor): Tensor that will be reshaped.
like (torch.Tensor, torch.Size, tuple, list): Tensor or shape
that will be used as trailing dimensions in the output.
Returns:
torch.Tensor: Reshape of ``mat`` with trailing dimensions identical
to ``like``. For example, if ``like`` is ``[N, C, H, ...]``, and
``mat`` has shape ``[V * N * C * H * ...]``, the result has shape
``[V, N, C, H, ...]``.
"""
V = -1
shape = (V, *like.shape)

if isinstance(like, torch.Tensor):
shape = (V, *like.shape)
else:
shape = (V, *like)

return mat.reshape(shape)

@classmethod
def reshape_like_input(cls, mat, module):
return cls._reshape_like(mat, module.input0)
def reshape_like_input(cls, mat, module, subsampling=None):
"""Reshape matrix like module input."""
input_shape = subsampled_input_shape(module, subsampling=subsampling)

return cls._reshape_like(mat, input_shape)

@classmethod
def reshape_like_output(cls, mat, module):
return cls._reshape_like(mat, module.output)
def reshape_like_output(cls, mat, module, subsampling=None):
"""Reshape matrix like module output."""
output_shape = subsampled_output_shape(module, subsampling=subsampling)

return cls._reshape_like(mat, output_shape)


class BaseParameterDerivatives(BaseDerivatives):
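The shape helpers `subsampled_input_shape` and `subsampled_output_shape` imported above also live in `backpack/core/derivatives/subsampling.py`, outside this excerpt. A hedged sketch of what they presumably return, consistent with `_reshape_like` now accepting a plain shape (torch.Size, tuple, or list) as well as a tensor:

# Hypothetical sketches; the real implementations are not part of the hunks shown here.
def subsampled_input_shape(module, subsampling=None):
    """Shape of module.input0 with the batch dimension restricted to the subsample."""
    shape = list(module.input0.shape)
    if subsampling is not None:
        shape[0] = len(subsampling)
    return shape


def subsampled_output_shape(module, subsampling=None):
    """Shape of module.output with the batch dimension restricted to the subsample."""
    shape = list(module.output.shape)
    if subsampling is not None:
        shape[0] = len(subsampling)
    return shape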
17 changes: 12 additions & 5 deletions backpack/core/derivatives/batchnorm1d.py
@@ -15,7 +15,7 @@ def hessian_is_diagonal(self):
def _jac_mat_prod(self, module, g_inp, g_out, mat):
return self._jac_t_mat_prod(module, g_inp, g_out, mat)

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
"""
Note:
-----
@@ -29,7 +29,14 @@ def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
-----------
https://kevinzakka.github.io/2016/09/14/batch_normalization/
https://chrisyeh96.github.io/2017/08/28/deriving-batchnorm-backprop.html
Raises:
-------
NotImplementedError
If subsampling is enabled.
"""
self._no_subsampling(subsampling)
assert module.affine is True

N = module.input0.size(0)
@@ -99,7 +106,7 @@ def _weight_jac_t_mat_prod(
self, module, g_inp, g_out, mat, sum_batch, subsampling=None
):
self._maybe_warn_no_batch_summation(sum_batch)
self._no_subsampling_support(subsampling)
self._no_subsampling(subsampling)

x_hat, _ = self.get_normalized_input_and_var(module)
equation = "vni,ni->v{}i".format("" if sum_batch is True else "n")
@@ -114,7 +121,7 @@ def _bias_jac_t_mat_prod(
self, module, g_inp, g_out, mat, sum_batch=True, subsampling=None
):
self._maybe_warn_no_batch_summation(sum_batch)
self._no_subsampling_support(subsampling)
self._no_subsampling(subsampling)

if not sum_batch:
return mat
@@ -136,8 +143,8 @@ def _maybe_warn_no_batch_summation(sum_batch):
)

@staticmethod
def _no_subsampling_support(subsampling):
"""Subsampling is not supported.
def _no_subsampling(subsampling):
"""Raise exception if subsampling is enabled.
Args:
subsampling ([int] or None): Indices of samples to be considered. ``None``
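The body of the renamed `_no_subsampling` guard sits outside the shown hunk; judging from the updated docstrings, it raises `NotImplementedError` when a subset of samples is requested. A hypothetical sketch (a staticmethod on the BatchNorm1d derivatives class in the real code):

# Sketch only; the actual body is not part of this excerpt.
def _no_subsampling(subsampling):
    """Raise NotImplementedError if a subset of samples is requested."""
    if subsampling is not None:
        raise NotImplementedError("Subsampling is not supported for BatchNorm1d derivatives.")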
4 changes: 2 additions & 2 deletions backpack/core/derivatives/conv_transposend.py
@@ -152,10 +152,10 @@ def __jac(self, module, mat):
)
return jac_t_mat

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
jmp_as_conv = self.__jac_t(module, mat_as_conv)
return self.reshape_like_input(jmp_as_conv, module)
return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)

def __jac_t(self, module, mat):
jac_t = self.conv_func(
7 changes: 3 additions & 4 deletions backpack/core/derivatives/convnd.py
@@ -73,14 +73,13 @@ def _jac_mat_prod(self, module, g_inp, g_out, mat):
)
return self.reshape_like_output(jmp_as_conv, module)

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
jmp_as_conv = self.__jac_t(module, mat_as_conv)
return self.reshape_like_input(jmp_as_conv, module)
return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)

def __jac_t(self, module, mat):
input_size = list(module.input0.size())
input_size[0] = mat.size(0)
input_size = [mat.size(0)] + list(module.input0.size())[1:]

grad_padding = _grad_input_padding(
grad_output=mat,
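The rewritten `input_size` line takes the leading dimension from `mat` instead of the stored input, so the transposed convolution produces the right number of rows when `mat` carries only the subsampled (and vectorized) samples. A small illustration of the shape arithmetic, with made-up values:

# Illustrative values only.
V, N, N_sub = 2, 8, 3                           # vectors, full batch, subsampled batch
C_in, H_in, W_in = 3, 32, 32

full_input_size = [N, C_in, H_in, W_in]         # list(module.input0.size())
mat_rows = V * N_sub                            # mat.size(0) after "v n c ... -> (v n) c ..."

input_size = [mat_rows] + full_input_size[1:]   # -> [6, 3, 32, 32]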
11 changes: 7 additions & 4 deletions backpack/core/derivatives/elementwise.py
@@ -20,16 +20,19 @@ class ElementwiseDerivatives(BaseDerivatives):
- If the activation is piece-wise linear: `hessian_is_zero`, else `d2f`.
"""

def df(self, module, g_inp, g_out):
def df(self, module, g_inp, g_out, subsampling=None):
"""Elementwise first derivative.
Args:
module (torch.nn.Module): PyTorch activation function module.
g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
subsampling (list(int)): Indices of samples to be considered. If ``None``,
use all samples.
Returns:
(torch.Tensor): Tensor containing the derivatives `f'(input[i]) ∀ i`.
(torch.Tensor): Tensor containing the derivatives ``f'(input[i]) ∀ i``
(for the entire, or a subset of the, mini-batch).
"""

raise NotImplementedError("First derivatives not implemented")
@@ -73,10 +76,10 @@ def hessian_is_diagonal(self):
"""
return True

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
self._no_inplace(module)

df_elementwise = self.df(module, g_inp, g_out)
df_elementwise = self.df(module, g_inp, g_out, subsampling=subsampling)
return einsum("...,v...->v...", (df_elementwise, mat))

def _jac_mat_prod(self, module, g_inp, g_out, mat):
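For elementwise activations the transposed Jacobian is a pointwise product with `f'(input)`, so subsampling only changes which rows of the stored input enter `df`. A standalone sketch of the contraction used in `_jac_t_mat_prod`, using an ELU-style derivative and made-up shapes:

import torch

V, N, D = 3, 4, 10                               # vectors, full batch size, features
alpha = 1.0
subsampling = [0, 2]

inp = torch.rand(N, D) - 0.5                     # full mini-batch input
inp_sub = inp[subsampling]                       # what subsample_input would return

# ELU'(x) = alpha * e^x if x < 0 else 1
df = torch.where(inp_sub > 0, torch.ones_like(inp_sub), alpha * inp_sub.exp())
mat = torch.rand(V, len(subsampling), D)         # matrix restricted to the same samples

result = torch.einsum("...,v...->v...", df, mat) # same contraction as in the hunk above
print(result.shape)                              # torch.Size([3, 2, 10])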
11 changes: 8 additions & 3 deletions backpack/core/derivatives/elu.py
@@ -1,17 +1,22 @@
from torch import exp, gt

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.core.derivatives.subsampling import subsample_input


class ELUDerivatives(ElementwiseDerivatives):
def hessian_is_zero(self):
"""`ELU''(x) ≠ 0`."""
return False

def df(self, module, g_inp, g_out):
def df(self, module, g_inp, g_out, subsampling=None):
"""First ELU derivative: `ELU'(x) = alpha * e^x if x < 0 else 1`. """
df_ELU = gt(module.input0, 0).float()
df_ELU[df_ELU == 0] = module.alpha * exp(module.input0[df_ELU == 0])
input = subsample_input(module, subsampling=subsampling)

df_ELU = gt(input, 0).float()
idx_zero = df_ELU == 0
df_ELU[idx_zero] = module.alpha * exp(input[idx_zero])

return df_ELU

def d2f(self, module, g_inp, g_out):
4 changes: 2 additions & 2 deletions backpack/core/derivatives/flatten.py
@@ -8,8 +8,8 @@ def hessian_is_zero(self):
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
return mat

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
return self.reshape_like_input(mat, module)
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
return self.reshape_like_input(mat, module, subsampling=subsampling)

def _jac_mat_prod(self, module, g_inp, g_out, mat):
return self.reshape_like_output(mat, module)
8 changes: 6 additions & 2 deletions backpack/core/derivatives/leakyrelu.py
@@ -1,16 +1,20 @@
from torch import gt

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.core.derivatives.subsampling import subsample_input


class LeakyReLUDerivatives(ElementwiseDerivatives):
def hessian_is_zero(self):
"""`LeakyReLU''(x) = 0`."""
return True

def df(self, module, g_inp, g_out):
def df(self, module, g_inp, g_out, subsampling=None):
"""First LeakyReLU derivative:
`LeakyReLU'(x) = negative_slope if x < 0 else 1`."""
df_leakyrelu = gt(module.input0, 0).float()
input = subsample_input(module, subsampling=subsampling)

df_leakyrelu = gt(input, 0).float()
df_leakyrelu[df_leakyrelu == 0] = module.negative_slope

return df_leakyrelu
2 changes: 1 addition & 1 deletion backpack/core/derivatives/linear.py
@@ -18,7 +18,7 @@ class LinearDerivatives(BaseParameterDerivatives):
def hessian_is_zero(self):
return True

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(self, module, g_inp, g_out, mat, subsampling=None):
"""Apply transposed Jacobian of the output w.r.t. the input."""
d_input = module.weight.data
return einsum("oi,vno->vni", (d_input, mat))
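For `Linear`, the transposed Jacobian is a per-sample multiplication with the weight matrix, which is presumably why the body needs no change beyond accepting the new argument: `subsampling` only affects how many samples `mat` carries. A standalone sketch with made-up shapes:

import torch

V, N_sub, D_in, D_out = 3, 2, 10, 5
weight = torch.rand(D_out, D_in)                   # plays the role of module.weight.data
mat = torch.rand(V, N_sub, D_out)                  # already restricted to the subsample

result = torch.einsum("oi,vno->vni", weight, mat)  # same contraction as above
print(result.shape)                                # torch.Size([3, 2, 10])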
7 changes: 5 additions & 2 deletions backpack/core/derivatives/logsigmoid.py
@@ -1,16 +1,19 @@
from torch import exp

from backpack.core.derivatives.elementwise import ElementwiseDerivatives
from backpack.core.derivatives.subsampling import subsample_input


class LogSigmoidDerivatives(ElementwiseDerivatives):
def hessian_is_zero(self):
"""`logsigmoid''(x) ≠ 0`."""
return False

def df(self, module, g_inp, g_out):
def df(self, module, g_inp, g_out, subsampling=None):
"""First Logsigmoid derivative: `logsigmoid'(x) = 1 / (e^x + 1) `."""
return 1 / (exp(module.input0) + 1)
input = subsample_input(module, subsampling=subsampling)

return 1 / (exp(input) + 1)

def d2f(self, module, g_inp, g_out):
"""Second Logsigmoid derivative: `logsigmoid''(x) = - e^x / (e^x + 1)^2`."""