Commit

Merge b25c72c into 36c6bb3
f-dangel committed Jul 22, 2021
2 parents 36c6bb3 + b25c72c commit fff6b44
Showing 29 changed files with 596 additions and 329 deletions.
73 changes: 24 additions & 49 deletions backpack/core/derivatives/avgpoolnd.py
@@ -3,35 +3,22 @@
Average pooling can be expressed as convolution over grouped channels with a constant
kernel.
"""
from typing import Any, Tuple
from typing import Any, List, Tuple

import torch.nn
from einops import rearrange
from torch.nn import (
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
Module,
)
from torch import Tensor, ones_like
from torch.nn import Module

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.utils.conv import get_conv_module
from backpack.utils.conv_transpose import get_conv_transpose_module


class AvgPoolNDDerivatives(BaseDerivatives):
def __init__(self, N):
def __init__(self, N: int):
self.conv = get_conv_module(N)
self.convt = get_conv_transpose_module(N)
self.N = N
if self.N == 1:
self.conv = Conv1d
self.convt = ConvTranspose1d
elif self.N == 2:
self.conv = Conv2d
self.convt = ConvTranspose2d
elif self.N == 3:
self.conv = Conv3d
self.convt = ConvTranspose3d
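
The if/elif dispatch above is replaced by get_conv_module(N) and get_conv_transpose_module(N). A minimal sketch of what such helpers could look like (the actual implementations live in backpack/utils and may differ):

# Hedged sketch of the dimension dispatch assumed above; not the commit's code.
from torch.nn import (
    Conv1d,
    Conv2d,
    Conv3d,
    ConvTranspose1d,
    ConvTranspose2d,
    ConvTranspose3d,
)


def get_conv_module(N: int):
    """Return the convolution module class for spatial dimension N (sketch)."""
    return {1: Conv1d, 2: Conv2d, 3: Conv3d}[N]


def get_conv_transpose_module(N: int):
    """Return the transpose convolution module class for spatial dimension N (sketch)."""
    return {1: ConvTranspose1d, 2: ConvTranspose2d, 3: ConvTranspose3d}[N]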

def check_parameters(self, module: Module) -> None:
assert module.count_include_pad, (
@@ -101,31 +88,31 @@ def __apply_jacobian_of(self, module, mat):
).to(module.input0.device)

convnd.weight.requires_grad = False
avg_kernel = torch.ones_like(convnd.weight) / convnd.weight.numel()
avg_kernel = ones_like(convnd.weight) / convnd.weight.numel()
convnd.weight.data = avg_kernel

return convnd(mat)

def __check_jmp_out_as_pool(self, mat, jmp_as_pool, module):
V = mat.size(0)
if self.N == 1:
N, C_out, L_out = module.output.shape
assert jmp_as_pool.shape == (V * N * C_out, 1, L_out)
elif self.N == 2:
N, C_out, H_out, W_out = module.output.shape
assert jmp_as_pool.shape == (V * N * C_out, 1, H_out, W_out)
elif self.N == 3:
N, C_out, D_out, H_out, W_out = module.output.shape
assert jmp_as_pool.shape == (V * N * C_out, 1, D_out, H_out, W_out)

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
V = mat.shape[0]
N, C_out = module.output.shape[:2]

assert jmp_as_pool.shape == (V * N * C_out, 1) + module.output.shape[2:]

def _jac_t_mat_prod(
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
self.check_parameters(module)

mat_as_pool = self.__make_single_channel(mat, module)
jmp_as_pool = self.__apply_jacobian_t_of(module, mat_as_pool)
self.__check_jmp_in_as_pool(mat, jmp_as_pool, module)

return self.reshape_like_input(jmp_as_pool, module)
return self.reshape_like_input(jmp_as_pool, module, subsampling=subsampling)

def __apply_jacobian_t_of(self, module, mat):
stride, kernel_size, padding = self.get_avg_pool_parameters(module)
@@ -141,22 +128,10 @@ def __apply_jacobian_t_of(self, module, mat):
).to(module.input0.device)

convnd_t.weight.requires_grad = False
avg_kernel = torch.ones_like(convnd_t.weight) / convnd_t.weight.numel()
avg_kernel = ones_like(convnd_t.weight) / convnd_t.weight.numel()
convnd_t.weight.data = avg_kernel

V_N_C_in = mat.size(0)
output_size = (V_N_C_in, C_for_conv_t) + tuple(module.input0.shape[2:])

return convnd_t(mat, output_size=output_size)

def __check_jmp_in_as_pool(self, mat, jmp_as_pool, module):
V = mat.size(0)
if self.N == 1:
N, C_in, L_in = module.input0.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, L_in)
elif self.N == 2:
N, C_in, H_in, W_in = module.input0.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, H_in, W_in)
elif self.N == 3:
N, C_in, D_in, H_in, W_in = module.input0.size()
assert jmp_as_pool.shape == (V * N * C_in, 1, D_in, H_in, W_in)
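
As a sanity check for the identity this module's docstring relies on (not part of the commit): average pooling equals a grouped convolution whose kernel entries are all 1 / kernel size. A minimal sketch:

# Hedged sketch: average pooling expressed as a grouped convolution with a constant kernel.
import torch
from torch.nn import AvgPool2d, Conv2d

x = torch.rand(2, 3, 8, 8)
pool = AvgPool2d(kernel_size=2, stride=2)

# One group per channel, constant kernel with entries 1 / (2 * 2).
conv = Conv2d(3, 3, kernel_size=2, stride=2, groups=3, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0 / conv.weight[0].numel())

assert torch.allclose(pool(x), conv(x), atol=1e-6)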
48 changes: 34 additions & 14 deletions backpack/core/derivatives/basederivatives.py
@@ -7,6 +7,7 @@
from torch.nn import Module

from backpack.core.derivatives import shape_check
from backpack.utils.subsampling import get_batch_axis


class BaseDerivatives(ABC):
@@ -80,7 +81,12 @@ def _jac_mat_prod(
@shape_check.jac_t_mat_prod_accept_vectors
@shape_check.jac_t_mat_prod_check_shapes
def jac_t_mat_prod(
self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
"""Apply transposed input-ouput Jacobian of module output to a matrix.
@@ -93,16 +99,25 @@ def jac_t_mat_prod(
g_inp: input gradients
g_out: output gradients
mat: Matrix the transposed Jacobian will be applied to.
Must have shape [V, N, C_out, H_out, ...].
Must have shape ``[V, *module.output.shape]``; but if used with
sub-sampling, the batch dimension is replaced by ``len(subsampling)``.
subsampling: Indices of samples along the output's batch dimension that
should be considered. Defaults to ``None`` (use all samples).
Returns:
Transposed Jacobian-matrix product.
Has shape [V, N, C_in, H_in, ...].
Has shape ``[V, *module.input0.shape]``; but if used with sub-sampling,
the batch dimension is replaced by ``len(subsampling)``.
"""
return self._jac_t_mat_prod(module, g_inp, g_out, mat)
return self._jac_t_mat_prod(module, g_inp, g_out, mat, subsampling=subsampling)

def _jac_t_mat_prod(
self, module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor], mat: Tensor
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
raise NotImplementedError

@@ -244,34 +259,39 @@ def _residual_mat_prod(
raise NotImplementedError

@staticmethod
def _reshape_like(mat: Tensor, like: Tensor) -> Tensor:
def _reshape_like(mat: Tensor, shape: Tuple[int]) -> Tensor:
"""Reshape as like with trailing and additional 0th dimension.
If like is [N, C, H, ...], returns shape [-1, N, C, H, ...]
Args:
mat: matrix to reshape
like: matrix with target shape
mat: Matrix to reshape.
shape: Trailing target shape.
Returns:
reshaped matrix
"""
V = -1
shape = (V, *like.shape)
return mat.reshape(shape)
return mat.reshape(-1, *shape)

@classmethod
def reshape_like_input(cls, mat: Tensor, module: Module) -> Tensor:
def reshape_like_input(
cls, mat: Tensor, module: Module, subsampling: List[int] = None
) -> Tensor:
"""Reshapes matrix according to input.
Args:
mat: matrix to reshape
module: module which input shape is used
subsampling: Indices of active samples. ``None`` means use all samples.
Returns:
reshaped matrix
"""
return cls._reshape_like(mat, module.input0)
shape = list(module.input0.shape)
if subsampling is not None:
shape[get_batch_axis(module)] = len(subsampling)

return cls._reshape_like(mat, shape)

@classmethod
def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor:
@@ -284,7 +304,7 @@ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor:
Returns:
reshaped matrix
"""
return cls._reshape_like(mat, module.output)
return cls._reshape_like(mat, module.output.shape)


class BaseParameterDerivatives(BaseDerivatives, ABC):
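
To illustrate the new subsampling handling in reshape_like_input, here is a self-contained sketch (hypothetical helper mirroring the logic above, not a call into BackPACK): the batch dimension of the target shape is replaced by the number of active samples before the free dimension V is prepended.

# Hedged, self-contained sketch of the sub-sampling aware reshape.
from typing import List, Optional

import torch


def reshape_like_input_sketch(
    mat: torch.Tensor,
    input0_shape: tuple,
    batch_axis: int = 0,
    subsampling: Optional[List[int]] = None,
) -> torch.Tensor:
    """Replace the batch dimension by the number of active samples, then prepend V."""
    shape = list(input0_shape)
    if subsampling is not None:
        shape[batch_axis] = len(subsampling)
    return mat.reshape(-1, *shape)


V, N, C, H, W = 4, 8, 3, 5, 5
subsampling = [0, 2, 5]  # only these samples are active

# Flattened VJP result, e.g. [V * len(subsampling) * C, 1, H, W] from the pooling trick above.
flat = torch.rand(V * len(subsampling) * C, 1, H, W)
out = reshape_like_input_sketch(flat, (N, C, H, W), subsampling=subsampling)
print(out.shape)  # torch.Size([4, 3, 3, 5, 5]) = [V, len(subsampling), C, H, W]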
7 changes: 7 additions & 0 deletions backpack/core/derivatives/batchnorm_nd.py
@@ -90,10 +90,17 @@ def _jac_t_mat_prod(
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
self._check_parameters(module)
N: int = self._get_n_axis(module)
if module.training:

if subsampling is not None:
raise NotImplementedError(
"BatchNorm VJP sub-sampling is not defined in train mode."
)

denominator: int = self._get_denominator(module)
x_hat, var = self._get_normalized_input_and_var(module)
ivar = 1.0 / (var + module.eps).sqrt()
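
A small illustration (assumed example, not from the commit) of why the VJP rejects sub-sampling in train mode: the batch statistics couple the samples, so perturbing one input changes the outputs of all the others, while in eval mode the map acts per sample.

# Hedged illustration of sample coupling through batch statistics.
import torch
from torch.nn import BatchNorm1d

bn = BatchNorm1d(num_features=3)
x = torch.rand(4, 3)
x_perturbed = x.clone()
x_perturbed[0] += 1.0  # only sample 0 changes

bn.train()
print(torch.allclose(bn(x)[1:], bn(x_perturbed)[1:]))  # False: batch statistics couple samples

bn.eval()
print(torch.allclose(bn(x)[1:], bn(x_perturbed)[1:]))  # True: per-sample map using running statistics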
57 changes: 23 additions & 34 deletions backpack/core/derivatives/conv_transposend.py
@@ -4,48 +4,29 @@
from einops import rearrange
from numpy import prod
from torch import Tensor, einsum
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d
from torch.nn.functional import (
conv1d,
conv2d,
conv3d,
conv_transpose1d,
conv_transpose2d,
conv_transpose3d,
)
from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module
from torch.nn.grad import _grad_input_padding

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils.conv_transpose import unfold_by_conv_transpose
from backpack.utils.conv import get_conv_function
from backpack.utils.conv_transpose import (
get_conv_transpose_function,
unfold_by_conv_transpose,
)
from backpack.utils.subsampling import subsample


class ConvTransposeNDDerivatives(BaseParameterDerivatives):
"""Base class for partial derivatives of transpose convolution."""

def __init__(self, N):
"""Store convolution dimension and operations.
def __init__(self, N: int):
"""Store transpose convolution dimension and operations.
Args:
N (int): Convolution dimension. Must be ``1``, ``2``, or ``3``.
Raises:
ValueError: If convolution dimension is unsupported.
N: Transpose convolution dimension.
"""
if N == 1:
self.module = ConvTranspose1d
self.conv_func = conv1d
self.conv_transpose_func = conv_transpose1d
elif N == 2:
self.module = ConvTranspose2d
self.conv_func = conv2d
self.conv_transpose_func = conv_transpose2d
elif N == 3:
self.module = ConvTranspose3d
self.conv_func = conv3d
self.conv_transpose_func = conv_transpose3d
else:
raise ValueError(f"ConvTranspose{N}d not supported.")
self.conv_func = get_conv_function(N)
self.conv_transpose_func = get_conv_transpose_function(N)
self.conv_dims = N

def hessian_is_zero(self, module):
@@ -150,7 +131,7 @@ def __jac(self, module, mat):
dilation=module.dilation,
)

jac_t_mat = conv_transpose1d(
jac_t_mat = self.conv_transpose_func(
input=mat,
weight=module.weight,
bias=None,
@@ -160,14 +141,22 @@
groups=module.groups,
dilation=module.dilation,
)

return jac_t_mat

def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
def _jac_t_mat_prod(
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: List[int] = None,
) -> Tensor:
mat_as_conv = rearrange(mat, "v n c ... -> (v n) c ...")
jmp_as_conv = self.__jac_t(module, mat_as_conv)
return self.reshape_like_input(jmp_as_conv, module)
return self.reshape_like_input(jmp_as_conv, module, subsampling=subsampling)

def __jac_t(self, module, mat):
def __jac_t(self, module: Module, mat: Tensor) -> Tensor:
jac_t = self.conv_func(
mat,
module.weight,
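
Related sanity check (assumed example, not from the commit) for why __jac_t applies self.conv_func: the input VJP of a transpose convolution is a forward convolution with the same weight.

# Hedged check with default stride/padding/dilation/groups.
import torch
from torch.nn import ConvTranspose2d
from torch.nn.functional import conv2d

module = ConvTranspose2d(in_channels=2, out_channels=3, kernel_size=3, bias=False)
x = torch.rand(4, 2, 6, 6, requires_grad=True)
v = torch.rand(4, 3, 8, 8)  # shape of module(x)

(vjp,) = torch.autograd.grad(module(x), x, grad_outputs=v)
assert torch.allclose(vjp, conv2d(v, module.weight), atol=1e-6)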
